Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

segment.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
  1. import pickle
  2. import warnings
  3. from typing import Tuple
  4. import matplotlib
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pandas as pd
  8. from omegaconf import DictConfig
  9. from prefect import flow, task
  10. from sklearn.cluster import KMeans
  11. from sklearn.decomposition import PCA
  12. from yellowbrick.cluster import KElbowVisualizer
  13. from helper import create_parent_directory, load_config
  14. warnings.simplefilter(action="ignore", category=DeprecationWarning)
  15. @task
  16. def read_process_data(config: DictConfig):
  17. return pd.read_csv(config.intermediate.path)
  18. @task
  19. def get_pca_model(data: pd.DataFrame) -> PCA:
  20. pca = PCA(n_components=3)
  21. pca.fit(data)
  22. return pca
  23. @task
  24. def reduce_dimension(df: pd.DataFrame, pca: PCA) -> pd.DataFrame:
  25. return pd.DataFrame(pca.transform(df), columns=["col1", "col2", "col3"])
  26. @task
  27. def get_3d_projection(pca_df: pd.DataFrame) -> dict:
  28. """A 3D Projection Of Data In The Reduced Dimensionality Space"""
  29. return {"x": pca_df["col1"], "y": pca_df["col2"], "z": pca_df["col3"]}
  30. @task
  31. def get_best_k_cluster(
  32. pca_df: pd.DataFrame, config: DictConfig
  33. ) -> pd.DataFrame:
  34. matplotlib.use("svg")
  35. fig = plt.figure(figsize=(10, 8))
  36. fig.add_subplot(111)
  37. elbow = KElbowVisualizer(KMeans(), metric="distortion")
  38. elbow.fit(pca_df)
  39. create_parent_directory(config.image.kmeans)
  40. elbow.fig.savefig(config.image.kmeans)
  41. k_best = elbow.elbow_value_
  42. return k_best
  43. @task
  44. def get_clusters_model(
  45. pca_df: pd.DataFrame, k: int
  46. ) -> Tuple[pd.DataFrame, pd.DataFrame]:
  47. model = KMeans(n_clusters=k)
  48. # Fit model
  49. return model.fit(pca_df)
  50. @task
  51. def predict(model, pca_df: pd.DataFrame):
  52. return model.predict(pca_df)
  53. @task
  54. def insert_clusters_to_df(
  55. df: pd.DataFrame, clusters: np.ndarray
  56. ) -> pd.DataFrame:
  57. return df.assign(clusters=clusters)
  58. @task
  59. def plot_clusters(
  60. pca_df: pd.DataFrame,
  61. preds: np.ndarray,
  62. projections: dict,
  63. config: DictConfig,
  64. ) -> None:
  65. pca_df["clusters"] = preds
  66. matplotlib.use("svg")
  67. plt.figure(figsize=(10, 8))
  68. ax = plt.subplot(111, projection="3d")
  69. ax.scatter(
  70. projections["x"],
  71. projections["y"],
  72. projections["z"],
  73. s=40,
  74. c=pca_df["clusters"],
  75. marker="o",
  76. cmap="Accent",
  77. )
  78. ax.set_title("The Plot Of The Clusters")
  79. plt.savefig(config.image.clusters)
  80. @task
  81. def save_data_and_model(data: pd.DataFrame, model: KMeans, config: DictConfig):
  82. create_parent_directory(config.final.path)
  83. data.to_csv(config.final.path, index=False)
  84. create_parent_directory(config.model.path)
  85. pickle.dump(model, open(config.model.path, "wb"))
  86. @flow(name="Segment customers")
  87. def segment() -> None:
  88. config = load_config()
  89. data = read_process_data(config)
  90. pca = get_pca_model(data)
  91. pca_df = reduce_dimension(data, pca)
  92. projections = get_3d_projection(pca_df)
  93. k_best = get_best_k_cluster(pca_df, config)
  94. model = get_clusters_model(pca_df, k_best)
  95. pred = predict(model, pca_df)
  96. data = insert_clusters_to_df(data, pred)
  97. plot_clusters(
  98. pca_df,
  99. pred,
  100. projections,
  101. config,
  102. )
  103. save_data_and_model(data, model, config)
# Run the flow only when this file is executed as a script; importing the
# module does not start it.
if __name__ == "__main__":
    segment()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...