Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. from pathlib import Path
  2. import tensorflow as tf
  3. from dvclive.keras import DVCLiveCallback
  4. from dvc.api import params_show
  5. # Set the paths to the train and validation directories
  6. BASE_DIR = Path(__file__).parent.parent
  7. data_dir = BASE_DIR / "data"
  8. # Load the parameters from params.yaml
  9. params = params_show()["train"]
  10. ##############################################################################
  11. # Below, image width, height and batch_size parameters come from params.yaml #
  12. ##############################################################################
  13. # Create an ImageDataGenerator object for the train set with augmentation
  14. train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
  15. rescale=1.0 / 255,
  16. rotation_range=10,
  17. width_shift_range=0.1,
  18. height_shift_range=0.1,
  19. zoom_range=0.15,
  20. fill_mode="nearest",
  21. )
  22. train_generator = train_datagen.flow_from_directory(
  23. data_dir / "prepared" / "train",
  24. target_size=(params["image_width"], params["image_height"]),
  25. batch_size=params["batch_size"],
  26. class_mode="categorical",
  27. )
  28. # Do the same for test
  29. test_dataget = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0 / 255)
  30. test_generator = test_dataget.flow_from_directory(
  31. data_dir / "prepared" / "test",
  32. target_size=(params["image_width"], params["image_height"]),
  33. batch_size=params["batch_size"],
  34. class_mode="categorical",
  35. )
  36. def get_model():
  37. """Define the model to be fit"""
  38. # Define a CNN model
  39. model = tf.keras.models.Sequential(
  40. [
  41. tf.keras.layers.Conv2D(
  42. filters=32,
  43. kernel_size=3,
  44. activation="relu",
  45. input_shape=(params["image_width"], params["image_height"], 3),
  46. ),
  47. tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu"),
  48. tf.keras.layers.MaxPooling2D(2, 2),
  49. tf.keras.layers.BatchNormalization(axis=-1),
  50. tf.keras.layers.Conv2D(filters=128, kernel_size=3, activation="relu"),
  51. tf.keras.layers.Conv2D(filters=256, kernel_size=3, activation="relu"),
  52. tf.keras.layers.MaxPooling2D(2, 2),
  53. tf.keras.layers.BatchNormalization(axis=-1),
  54. tf.keras.layers.Flatten(),
  55. tf.keras.layers.Dense(512, activation="relu"),
  56. tf.keras.layers.BatchNormalization(),
  57. tf.keras.layers.Dropout(0.5),
  58. tf.keras.layers.Dense(43, activation="softmax"),
  59. ]
  60. )
  61. # Compile the model
  62. model.compile(
  63. loss=tf.keras.losses.categorical_crossentropy,
  64. # Learning rate is loaded from `params.yaml`
  65. optimizer=tf.keras.optimizers.Adam(learning_rate=params["learning_rate"]),
  66. metrics=["accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
  67. )
  68. return model
  69. def main():
  70. # Get the model
  71. model = get_model()
  72. # Create a path to save the model
  73. model_path = BASE_DIR / "models"
  74. model_path.mkdir(parents=True, exist_ok=True)
  75. # Define callbacks
  76. callbacks = [
  77. tf.keras.callbacks.ModelCheckpoint(
  78. model_path / "model.keras", monitor="val_accuracy", save_best_only=True
  79. ),
  80. tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5),
  81. DVCLiveCallback(dir="evaluation"),
  82. ]
  83. # Fit the model
  84. history = model.fit(
  85. train_generator,
  86. steps_per_epoch=len(train_generator),
  87. # Number of epochs loaded from `params.yaml`
  88. epochs=params["n_epochs"],
  89. validation_data=test_generator,
  90. callbacks=callbacks,
  91. )
  92. if __name__ == "__main__":
  93. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...