Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_model.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  1. from src.utils.file import read_dvc_params
  2. dvc_params = read_dvc_params(__file__)
  3. pipeline_params = dvc_params['train-pipeline']
  4. from src.scripts.dataset import load_train_val
  5. train_set, val_set = load_train_val(pipeline_params['data'])
  6. training_params = dvc_params['train']
  7. from src.scripts.config import CustomConfig
  8. config = CustomConfig(training_params['configs'])
  9. config.display()
  10. from mrcnn.model import MaskRCNN
  11. model = MaskRCNN(
  12. mode='training',
  13. model_dir=pipeline_params['train_dir'],
  14. config=config
  15. )
  16. weight_config = training_params['weights']
  17. model.load_weights(
  18. weight_config['init'],
  19. by_name=True,
  20. exclude=weight_config['exclude']
  21. )
  22. with open(pipeline_params['summary'], 'w+') as f:
  23. model.keras_model.summary(print_fn=lambda x: f.write(x + '\n'))
  24. from src.utils.benchmark import bench
  25. training_benchmark = bench(
  26. 'Training', model.train,
  27. train_set, val_set,
  28. learning_rate = config.LEARNING_RATE,
  29. epochs = training_params['epochs'],
  30. layers = training_params['layers']
  31. )
  32. output_params = pipeline_params['output']
  33. model.keras_model.save_weights(output_params['model'])
  34. from src.utils.output import write_file, check_create
  35. general_metrics = {
  36. 'train_time': training_benchmark['time']
  37. }
  38. metric_output_folder = output_params['metric']['folder']
  39. check_create(metric_output_folder)
  40. write_file(output_params['metric']['general'], general_metrics, formatter='json', folder=metric_output_folder)
  41. # Get other training metrics wrote in the log folder (TensorBoard)
  42. from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
  43. event_acc = EventAccumulator(model.log_dir)
  44. event_acc.Reload()
  45. keys = ['timestamp', 'iteration', 'value']
  46. # ! we actually only care about bounding box loss for this project
  47. other_metrics = output_params['metric']['others']
  48. for metric_name in other_metrics:
  49. metric = {
  50. metric_name: [
  51. dict(zip(keys ,values))
  52. for values in event_acc.Scalars(metric_name)
  53. ]
  54. }
  55. write_file(metric_name + '.json', metric, 'json', metric_output_folder)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...