Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  1. # Original code from: https://www.tensorflow.org/tutorials/images/transfer_learning
  2. from dvclive.keras import DvcLiveCallback
  3. import tensorflow as tf
  4. from tensorflow.keras.callbacks import ModelCheckpoint
  5. from tensorflow.keras.preprocessing import image_dataset_from_directory
  6. from scripts.params import (
  7. BACKBONE,
  8. BATCH_SIZE,
  9. DATASET_DIR,
  10. EPOCHS_FROZEN,
  11. EPOCHS_UNFROZEN,
  12. FINE_TUNE_AT,
  13. IMG_SIZE,
  14. LEARNING_RATE,
  15. PREPROCESS_INPUT,
  16. TRAIN_DIR,
  17. )
  18. #%% Load dataset
  19. train_dataset = image_dataset_from_directory(
  20. DATASET_DIR / "train",
  21. shuffle=True,
  22. batch_size=BATCH_SIZE,
  23. image_size=IMG_SIZE,
  24. ).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  25. validation_dataset = image_dataset_from_directory(
  26. DATASET_DIR / "val",
  27. shuffle=True,
  28. batch_size=BATCH_SIZE,
  29. image_size=IMG_SIZE,
  30. ).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  31. #%% Define model
  32. # Data augmentation layers
  33. data_augmentation = tf.keras.Sequential([
  34. tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  35. tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
  36. ])
  37. # Create the base model from the pre-trained model MobileNet V2
  38. IMG_SHAPE = IMG_SIZE + (3,)
  39. base_model = BACKBONE(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
  40. inputs = tf.keras.Input(shape=IMG_SHAPE)
  41. x = data_augmentation(inputs)
  42. x = PREPROCESS_INPUT(x)
  43. x = base_model(x, training=False)
  44. x = tf.keras.layers.GlobalAveragePooling2D()(x)
  45. x = tf.keras.layers.Dropout(0.2)(x)
  46. outputs = tf.keras.layers.Dense(1)(x)
  47. model = tf.keras.Model(inputs, outputs)
  48. callbacks = [
  49. # Use dvclive's Keras callback
  50. DvcLiveCallback(),
  51. ModelCheckpoint(str(TRAIN_DIR / "best_weights.h5"), save_best_only=True),
  52. ]
  53. #%% Freeze the base model and train 10 epochs
  54. base_model.trainable = False
  55. model.compile(
  56. optimizer=tf.keras.optimizers.Adam(lr=LEARNING_RATE),
  57. loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
  58. metrics=["accuracy"],
  59. )
  60. model.summary()
  61. history = model.fit(
  62. train_dataset,
  63. epochs=EPOCHS_FROZEN,
  64. validation_data=validation_dataset,
  65. callbacks=callbacks,
  66. )
  67. #%% Unfreeze the base model
  68. base_model.trainable = True
  69. # Let's take a look to see how many layers are in the base model
  70. print("Number of layers in the base model: ", len(base_model.layers))
  71. # Freeze all the layers before the `FINE_TUNE_AT` layer
  72. for layer in base_model.layers[:FINE_TUNE_AT]:
  73. layer.trainable = False
  74. model.compile(
  75. loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
  76. optimizer=tf.keras.optimizers.RMSprop(lr=LEARNING_RATE/10),
  77. metrics=["accuracy"],
  78. )
  79. model.summary()
  80. history_fine = model.fit(
  81. train_dataset,
  82. epochs=EPOCHS_FROZEN + EPOCHS_UNFROZEN,
  83. initial_epoch=EPOCHS_FROZEN,
  84. validation_data=validation_dataset,
  85. callbacks=callbacks,
  86. )
  87. #%% Load best weights and save model
  88. model.load_weights(str(TRAIN_DIR / "best_weights.h5"))
  89. tf.saved_model.save(model, str(TRAIN_DIR / "model"))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...