train.py

from data import DataLoader  # project data loading utilities
from model import FEModel  # facial-expression model definition
import hydra  # configuration management
import tf2onnx  # Keras -> ONNX conversion
import tensorflow as tf
from omegaconf import OmegaConf  # config objects used by Hydra
import matplotlib.pyplot as plt  # training-history plots
import mlflow  # experiment tracking

EXPERIMENT_NAME = "facial-expression-recognition"
MLFLOW_TRACKING_URI = "https://dagshub.com/Marshall-mk/Face-Expression.mlflow"

# Point MLflow at the remote tracking server before touching any experiment,
# then select (or create) the experiment and enable TensorFlow autologging.
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.tensorflow.autolog()


@hydra.main(config_path="./configs", config_name="configs")
def main(cfg):
    # Print the resolved configuration for this run.
    print(OmegaConf.to_yaml(cfg, resolve=True))

    # Define the data and the model.
    fe_data = DataLoader()
    fe_model = FEModel()

    # Compile the model with the optimizer, loss, and metrics from the config.
    fe_model.compile(optimizer=cfg.train.optimizer,
                     loss=cfg.train.loss,
                     metrics=cfg.train.metrics)

    # Model callbacks: early stopping on validation loss and best-weights checkpointing.
    earlystopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", mode="min", verbose=1, patience=15)
    checkpointer = tf.keras.callbacks.ModelCheckpoint(
        filepath=cfg.model.ckpt_path, save_weights_only=True, save_best_only=True)

    # Train the model; the context manager ends the MLflow run automatically.
    with mlflow.start_run():
        model_info = fe_model.train(
            fe_data.load_train_data(cfg.model.Train_path),
            batch_size=cfg.train.batch_size,
            epochs=cfg.train.epochs,
            validation_data=fe_data.load_val_data(cfg.model.Train_path),
            callbacks=[earlystopping, checkpointer])  # , WandbCallback()

        # Evaluate the model on the test set.
        print(f"Model evaluation metrics: "
              f"{fe_model.evaluate(fe_data.load_test_data(cfg.model.Test_path))}")

    # Save the trained model.
    fe_model.save(cfg.model.save_path)

    # Convert the model to ONNX.
    spec = (tf.TensorSpec((None, 48, 48, 1), tf.float32, name="input"),)
    output_path = cfg.model.onnx_path
    model_proto, _ = tf2onnx.convert.from_keras(
        fe_model, input_signature=spec, opset=13, output_path=output_path)

    # Plot the training history.
    _model_history(model_info=model_info, cfg=cfg)


def _model_history(model_info, cfg):
    """Plots training/validation accuracy and loss curves and saves them to disk."""
    accuracy = model_info.history["accuracy"]
    val_accuracy = model_info.history["val_accuracy"]
    loss = model_info.history["loss"]
    val_loss = model_info.history["val_loss"]
    epochs = range(1, len(accuracy) + 1)

    plt.figure(figsize=(20, 10))
    plt.plot(epochs, accuracy, "g-", label="Training accuracy")
    plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
    plt.title("Training and validation accuracy")
    plt.grid()
    plt.legend()  # add the legend before saving so it appears in the figure
    plt.savefig(f"{cfg.model.history_path}accuracy.png", dpi=300, bbox_inches="tight")

    plt.figure(figsize=(20, 10))
    plt.plot(epochs, loss, "g-", label="Training loss")
    plt.plot(epochs, val_loss, "b", label="Validation loss")
    plt.title("Training and validation loss")
    plt.legend()
    plt.grid()
    plt.savefig(f"{cfg.model.history_path}loss.png", bbox_inches="tight", dpi=300)


if __name__ == "__main__":
    main()
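
The script resolves its settings from configs/configs.yaml via Hydra. Below is a minimal sketch of the structure that file is assumed to have, built only from the cfg keys referenced in train.py; every value is an illustrative placeholder, not taken from the repository.

# Hypothetical configs/configs.yaml structure, expressed as an OmegaConf object;
# all values are placeholders for illustration only.
from omegaconf import OmegaConf

example_cfg = OmegaConf.create(
    {
        "train": {
            "optimizer": "adam",                    # hypothetical value
            "loss": "categorical_crossentropy",     # hypothetical value
            "metrics": ["accuracy"],                # hypothetical value
            "batch_size": 64,                       # hypothetical value
            "epochs": 100,                          # hypothetical value
        },
        "model": {
            "Train_path": "data/train",             # hypothetical path
            "Test_path": "data/test",               # hypothetical path
            "ckpt_path": "checkpoints/best.h5",     # hypothetical path
            "save_path": "models/fe_model",         # hypothetical path
            "onnx_path": "models/fe_model.onnx",    # hypothetical path
            "history_path": "plots/",               # hypothetical path
        },
    }
)
print(OmegaConf.to_yaml(example_cfg))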