Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stage_03_training.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import argparse
  2. import os
  3. import shutil
  4. from tqdm import tqdm
  5. import logging
  6. from src.utils.all_utils import read_yaml, create_directory
  7. import joblib
  8. import numpy as np
  9. from sklearn.ensemble import RandomForestClassifier
  10. STAGE = "Three"
  11. logging.basicConfig(
  12. filename=os.path.join("logs", 'running_logs.log'),
  13. level=logging.INFO,
  14. format="[%(asctime)s: %(levelname)s: %(module)s]: %(message)s",
  15. filemode="a"
  16. )
  17. def main(config_path, params_path):
  18. config = read_yaml(config_path)
  19. params = read_yaml(params_path)
  20. artifacts = config["artifacts"]
  21. featurized_data_dir_path = os.path.join(artifacts["ARTIFACTS_DIR"], artifacts["FEATURIZED_DATA"])
  22. featurized_train_data_path = os.path.join(featurized_data_dir_path, artifacts["FEATURIZED_OUT_TRAIN"])
  23. model_dir_path = os.path.join(artifacts["ARTIFACTS_DIR"], artifacts["MODEL_DIR"])
  24. create_directory([model_dir_path])
  25. model_path = os.path.join(model_dir_path, artifacts["MODEL_NAME"])
  26. # Load matrix
  27. matrix = joblib.load(featurized_train_data_path)
  28. labels = np.squeeze(matrix[:, 1].toarray())
  29. X = matrix[:,2:]
  30. logging.info(f"input matrix size: {matrix.shape}")
  31. logging.info(f"X matrix size: {X.shape}")
  32. logging.info(f"Y matrix size or labels size: {labels.shape}")
  33. seed = params["train"]["seed"]
  34. n_est = params["train"]["n_est"]
  35. min_split = params["train"]['min_split']
  36. model = RandomForestClassifier(n_estimators=n_est,
  37. min_samples_split=min_split,
  38. random_state=seed)
  39. model.fit(X, labels)
  40. joblib.dump(model, model_path)
  41. if __name__ == '__main__':
  42. args = argparse.ArgumentParser()
  43. args.add_argument("--config", "-c", default="config/config.yaml")
  44. args.add_argument("--params", "-p", default="params.yaml")
  45. parsed_args = args.parse_args()
  46. try:
  47. logging.info("\n********************")
  48. logging.info(f">>>>> stage {STAGE} started <<<<<")
  49. main(config_path=parsed_args.config, params_path=parsed_args.params)
  50. logging.info(f">>>>> stage {STAGE} completed!<<<<<\n")
  51. except Exception as e:
  52. logging.exception(e)
  53. raise e
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...