Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stage_04_train.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. import argparse
  2. import os
  3. from tqdm import tqdm
  4. import logging
  5. from src.utils.common import read_yaml_file, create_directories
  6. from src.utils.model import load_full_model, get_unqique_path_to_save_model
  7. from src.utils.callbacks import get_callbacks
  8. from src.utils.data_management import train_valid_generator
  9. logging.basicConfig(
  10. filename=os.path.join("logs", 'running_logs.log'),
  11. level=logging.INFO,
  12. format="[%(asctime)s: %(levelname)s: %(module)s]: %(message)s",
  13. filemode="a"
  14. )
  15. def train_model(config_path: str, params_path: str) -> None:
  16. """function to train the model and save it into
  17. Args:
  18. config_path (str): path to config file
  19. params_path (str): path to params file
  20. """
  21. config = read_yaml_file(config_path)
  22. params = read_yaml_file(params_path)
  23. artifacts = config["artifacts"]
  24. ### get the untrained full model
  25. artifacts_dir = artifacts["ARTIFACTS_DIR"]
  26. train_model_dir_path = os.path.join(artifacts_dir, artifacts["TRAINED_MODEL_DIR"])
  27. create_directories([train_model_dir_path])
  28. untrained_full_model_path = os.path.join(artifacts_dir,
  29. artifacts["BASE_MODEL_DIR"], artifacts["UPDATED_BASE_MODEL_NAME"])
  30. model = load_full_model(untrained_full_model_path)
  31. ### get the callbacks
  32. callback_dir_path = os.path.join(artifacts_dir, artifacts["CALLBACKS_DIR"])
  33. callbacks = get_callbacks(callback_dir_path)
  34. ### get the data to create data generator
  35. train_generator, valid_generator = train_valid_generator(
  36. data_dir=artifacts["DATA_DIR"],
  37. IMAGE_SIZE=tuple(params["IMAGE_SIZE"][:-1]),
  38. BATCH_SIZE=params["BATCH_SIZE"],
  39. do_data_augmention=params["AUGMENTATION"]
  40. )
  41. ### train the model
  42. steps_per_epoch = train_generator.samples // train_generator.batch_size
  43. validation_steps = valid_generator.samples // valid_generator.batch_size
  44. model.fit(
  45. train_generator,
  46. validation_data=valid_generator,
  47. epochs=params["EPOCHS"],
  48. steps_per_epoch=steps_per_epoch,
  49. validation_steps=validation_steps,
  50. callbacks=callbacks
  51. )
  52. ### save the trained model
  53. trained_model_dir = os.path.join(artifacts_dir, artifacts["TRAINED_MODEL_DIR"])
  54. create_directories([trained_model_dir])
  55. model_file_path = get_unqique_path_to_save_model(trained_model_dir)
  56. model.save(model_file_path)
  57. logging.info(f"trained models is saved at: \n{model_file_path}")
  58. if __name__ == '__main__':
  59. args = argparse.ArgumentParser()
  60. args.add_argument("--config", "-c", default="configs/config.yaml")
  61. args.add_argument("--params", "-p", default="params.yaml")
  62. parsed_args = args.parse_args()
  63. try:
  64. logging.info("\n********************")
  65. logging.info(">>>>> stage four started <<<<<")
  66. train_model(config_path=parsed_args.config, params_path=parsed_args.params)
  67. logging.info(">>>>> stage four completed! training completed <<<<<n")
  68. except Exception as e:
  69. logging.exception(e)
  70. raise e
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...