Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 4.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
  1. from data import DataLoader # the data
  2. from model import TSModel # the model
  3. import hydra # for configurations
  4. import tf2onnx # model conversion
  5. import tensorflow as tf
  6. from omegaconf.omegaconf import OmegaConf # configs
  7. import matplotlib.pyplot as plt # plots
  8. import mlflow # for tracking
  9. #EXPERIMENT_NAME = "TweetsSentiment"
  10. #EXPERIMENT_ID = mlflow.create_experiment(EXPERIMENT_NAME)
  11. MLFLOW_TRACKING_URI="https://dagshub.com/Marshall-mk/TweetsSentimentProject.mlflow"
  12. mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
  13. mlflow.tensorflow.autolog()
  14. @hydra.main(config_path="./configs", config_name="configs")
  15. def main(cfg):
  16. OmegaConf.to_yaml(cfg, resolve=True)
  17. """defines the data and the model"""
  18. ts_data = DataLoader()
  19. ts_model = TSModel()
  20. max_length = 100
  21. max_tokens = 2000
  22. text_vectorization = tf.keras.layers.TextVectorization(max_tokens=max_tokens, output_mode="int", output_sequence_length=max_length,)
  23. text_only_train_ds = ts_data.load_train_data(cfg.model.Train_path).map(lambda x, y: x)
  24. text_vectorization.adapt(text_only_train_ds)
  25. train_ds = ts_data.load_train_data(cfg.model.Train_path).map(lambda x, y: (text_vectorization(x), y),num_parallel_calls=4)
  26. val_ds = ts_data.load_val_data(cfg.model.Val_path).map(lambda x, y: (text_vectorization(x), y),num_parallel_calls=4)
  27. test_ds = ts_data.load_test_data(cfg.model.Test_path).map(lambda x, y: (text_vectorization(x), y),num_parallel_calls=4)
  28. # we need to save the text vectorization layer in order to use it in the inference
  29. vectorize_layer_model = tf.keras.models.Sequential()
  30. vectorize_layer_model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
  31. vectorize_layer_model.add(text_vectorization)
  32. filepath = cfg.model.vectorize_layer_path
  33. vectorize_layer_model.save(filepath, save_format="tf")
  34. """Compiles and trains the model"""
  35. ts_model.compile(optimizer= cfg.train.optimizer, loss = cfg.train.loss, metrics= cfg.train.metrics)
  36. """Model callbacks"""
  37. earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5)
  38. checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=cfg.model.ckpt_path, save_weights_only=True, save_best_only=True)
  39. """Trains the model"""
  40. with mlflow.start_run():
  41. # ... Define a model
  42. model_info = ts_model.fit(
  43. train_ds,
  44. batch_size=cfg.train.batch_size,
  45. epochs=cfg.train.epochs,
  46. validation_data=val_ds,
  47. callbacks= [earlystopping, checkpointer])
  48. """Evaluates the model on the test set"""
  49. print(f'Model evaluation metrics: {ts_model.evaluate(test_ds)}')
  50. mlflow.end_run()
  51. """Saving the model"""
  52. ts_model.save(cfg.model.save_path)
  53. # """converting the model to onnx"""
  54. spec = (tf.TensorSpec((None,None,), tf.float32, name="input"),)
  55. output_path = cfg.model.onnx_path
  56. model_proto, _ = tf2onnx.convert.from_keras(ts_model, input_signature=spec, opset=13, output_path=output_path)
  57. """Model training history """
  58. _model_history(model_info=model_info, cfg=cfg)
  59. def _model_history(model_info, cfg):
  60. accuracy = model_info.history["accuracy"]
  61. val_accuracy = model_info.history["val_accuracy"]
  62. loss = model_info.history["loss"]
  63. val_loss = model_info.history["val_loss"]
  64. epochs = range(1, len(accuracy) + 1)
  65. plt.figure(figsize=(20,10))
  66. plt.plot(epochs, accuracy, "g-", label="Training accuracy")
  67. plt.plot(epochs, val_accuracy, "b", label="Validation accuracy")
  68. plt.title("Training and validation accuracy")
  69. plt.grid()
  70. plt.savefig(f'{cfg.model.history_path}accuracy.png', dpi=300, bbox_inches='tight')
  71. plt.legend()
  72. plt.figure(figsize=(20,10))
  73. plt.plot(epochs, loss, "g-", label="Training loss")
  74. plt.plot(epochs, val_loss, "b", label="Validation loss")
  75. plt.title("Training and validation loss")
  76. plt.legend()
  77. plt.grid()
  78. plt.savefig(f'{cfg.model.history_path}loss.png', bbox_inches='tight', dpi=300)
  79. if __name__ == "__main__":
  80. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...