Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  1. import pathlib
  2. import pandas as pd
  3. from tensorflow import keras
  4. from tensorflow.keras import layers
  5. from zntrack import Node, utils, zn
  6. from zntrack.core import ZnTrackOption
  class TFModel(ZnTrackOption):
      """zntrack option that persists a Keras model for a Node.

      The model is registered through the ``outs`` DVC option and stored at
      ``nodes/<node_name>/model``. It is written with ``Model.save`` and read
      back with ``keras.models.load_model``.
      """

      dvc_option = "outs"
      zn_type = utils.ZnTypes.RESULTS

      def get_filename(self, instance) -> pathlib.Path:
          """Return ``nodes/<instance.node_name>/model`` as a Path."""
          return pathlib.Path("nodes") / instance.node_name / "model"

      def save(self, instance):
          """Fetch the model held by *instance* and save it to its path."""
          keras_model = self.__get__(instance, self.owner)
          keras_model.save(self.get_filename(instance))

      def get_data_from_files(self, instance):
          """Load and return the model stored for *instance*."""
          return keras.models.load_model(self.get_filename(instance))
  class MLModel(Node):
      """zntrack Node that builds and trains a small Keras CNN classifier.

      Reads ``train_data`` (used via its ``features`` and ``labels``
      attributes) and produces the per-epoch training history, the final-epoch
      metrics, and the trained Keras model.
      """

      # dependencies
      train_data = zn.deps()
      # outputs
      training_history = zn.plots()
      metrics = zn.metrics()
      # custom model output
      model = TFModel()
      # parameters
      epochs = zn.params()
      filters = zn.params([4])  # Conv2D filter count, one entry per conv block
      dense = zn.params([4])  # units, one entry per hidden Dense layer
      optimizer = zn.params("adam")

      def run(self):
          """Primary Node method: build the network, then train it."""
          self.build_model()
          self.train_model()

      def train_model(self):
          """Compile and fit ``self.model``, then record history and metrics.

          30% of the training data is held out via Keras ``validation_split``.
          ``training_history`` gets one row per epoch (index named "epoch"),
          and ``metrics`` holds the values from the last epoch's row.
          """
          # NOTE(review): categorical_crossentropy expects one-hot labels;
          # confirm train_data.labels is one-hot encoded.
          self.model.compile(
              optimizer=self.optimizer,
              loss="categorical_crossentropy",
              metrics=["accuracy"],
          )
          # summary() prints the table itself and returns None; do not wrap
          # it in print().
          self.model.summary()
          history = self.model.fit(
              self.train_data.features,
              self.train_data.labels,
              validation_split=0.3,
              epochs=self.epochs,
              batch_size=64,
          )
          self.training_history = pd.DataFrame(history.history)
          self.training_history.index.name = "epoch"
          # use the last epoch's values as the model metrics
          self.metrics = dict(self.training_history.iloc[-1])

      def build_model(self):
          """Build the CNN with the Keras functional API.

          Layout: for each entry in ``filters``, a Conv2D (3x3, same padding,
          ReLU) followed by MaxPooling2D (2x2); then Flatten; then one ReLU
          Dense layer per entry in ``dense``; then a 25-unit softmax output.
          The input shape is (28, 28, 1).
          """
          # NOTE(review): the input shape and the class count (25) are
          # hard-coded and must match the data in train_data.
          inputs = keras.Input(shape=(28, 28, 1))
          x = inputs
          for n_filters in self.filters:
              x = layers.Conv2D(
                  filters=n_filters,
                  kernel_size=(3, 3),
                  padding="same",
                  activation="relu",
              )(x)
              x = layers.MaxPooling2D((2, 2))(x)
          x = layers.Flatten()(x)
          for n_units in self.dense:
              x = layers.Dense(n_units, activation="relu")(x)
          output = layers.Dense(25, activation="softmax")(x)
          self.model = keras.Model(inputs=inputs, outputs=output)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...