Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

stage_02_prepare_base_model.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import argparse
  2. import os
  3. from src.utils.common import read_yaml_file, create_directories
  4. from src.utils.model import get_VGG16_model, prepare_full_model
  5. from tqdm import tqdm
  6. import logging
  7. logging.basicConfig(
  8. filename=os.path.join("logs", 'running_logs.log'),
  9. level=logging.INFO,
  10. format="[%(asctime)s: %(levelname)s: %(module)s]: %(message)s",
  11. filemode="a"
  12. )
  13. def prepare_base_model(config_path: str, params_path: str) -> None:
  14. """prepares and saves the untrained model
  15. that can be used for training later on the given data
  16. Args:
  17. config_path (str): path to configuration file
  18. params_path (str): path to params file
  19. """
  20. config = read_yaml_file(config_path)
  21. params = read_yaml_file(params_path)
  22. artifacts = config["artifacts"]
  23. artifacts_dir = artifacts["ARTIFACTS_DIR"]
  24. base_model_dir = artifacts["BASE_MODEL_DIR"]
  25. base_model_name = artifacts["BASE_MODEL_NAME"]
  26. base_model_dir_path = os.path.join(artifacts_dir, base_model_dir)
  27. create_directories([base_model_dir_path])
  28. base_model_path = os.path.join(base_model_dir_path, base_model_name)
  29. base_model = get_VGG16_model(
  30. input_shape=params["IMAGE_SIZE"],
  31. model_path=base_model_path)
  32. full_model = prepare_full_model(
  33. base_model,
  34. learning_rate=params["LEARNING_RATE"],
  35. CLASSES = 2,
  36. freeze_all=True,
  37. freeze_till=None)
  38. updated_full_model_path = os.path.join(base_model_dir_path, artifacts["UPDATED_BASE_MODEL_NAME"])
  39. full_model.save(updated_full_model_path)
  40. logging.info(f"full untrained model is saved at {updated_full_model_path}")
  41. if __name__ == '__main__':
  42. args = argparse.ArgumentParser()
  43. args.add_argument("--config", "-c", default="configs/config.yaml")
  44. args.add_argument("--params", "-p", default="params.yaml")
  45. parsed_args = args.parse_args()
  46. try:
  47. logging.info("\n********************")
  48. logging.info(">>>>> stage two started <<<<<")
  49. prepare_base_model(config_path=parsed_args.config, params_path=parsed_args.params)
  50. logging.info(">>>>> stage two completed! base model is created and saved successfully <<<<<\n")
  51. except Exception as e:
  52. logging.exception(e)
  53. raise e
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...