train_model.py
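The script below reads its hyperparameters from a params.yaml file in the working directory, using the keys under a training section. A minimal sketch of that file, with purely illustrative values (the real values are not shown on this page), might look like:

training:
  num_classes: 10    # number of output classes (illustrative)
  nb_epoch: 30       # number of training epochs (illustrative)
  base_lr: 0.001     # starting learning rate for the scheduler (illustrative)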

# define and train model
import keras
from keras.applications.vgg16 import VGG16
from keras.layers.core import Dense, Flatten
from keras.models import Model
from keras.engine.input_layer import Input
import yaml
import matplotlib.pyplot as plt
import json
import numpy as np
from keras import optimizers

params = yaml.safe_load(open("params.yaml"))["training"]
img_shape = (224, 224, 3)
num_classes = params["num_classes"]
nb_epoch = params["nb_epoch"]
base_lr = params["base_lr"]

train_data = np.load('data/train_data.npy')
train_label = np.load('data/train_label.npy')
val_data = np.load('data/val_data.npy')
val_label = np.load('data/val_label.npy')
test_data = np.load('data/test_data.npy')
test_label = np.load('data/test_label.npy')


def model_def():
    model_vgg16_conv = VGG16(weights='imagenet', include_top=False)
    # Create your own input format
    keras_input = Input(shape=img_shape, name='image_input')
    # Use the generated model
    output_vgg16_conv = model_vgg16_conv(keras_input)
    # Add the fully-connected layers
    x = Flatten(name='flatten')(output_vgg16_conv)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(1024, activation='relu', name='fc2')(x)
    x = Dense(128, activation='relu', name='fc3')(x)
    x = Dense(64, activation='relu', name='fc4')(x)
    x = Dense(num_classes, activation='softmax', name='predictions')(x)
    # Create your own model
    model = Model(inputs=keras_input, outputs=x)
    return model


def schedule(epoch, decay=0.9):
    # exponential learning-rate decay starting from base_lr
    return base_lr * decay ** epoch


def train():
    model = model_def()
    model.summary()
    # the initial lr here is overridden by the LearningRateScheduler callback at each epoch
    sgd = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    # .{epoch:02d}-{val_loss:.2f}
    callbacks = [keras.callbacks.ModelCheckpoint('saved-models/weights.h5',
                                                 verbose=1, save_best_only=True,
                                                 save_weights_only=True),
                 keras.callbacks.LearningRateScheduler(schedule)]
    # train model
    result = model.fit(train_data, train_label, epochs=nb_epoch,
                       validation_data=(val_data, val_label),
                       callbacks=callbacks, verbose=1)

    # evaluate on the held-out test set and record the metrics
    test_score = model.evaluate(test_data, test_label)
    print('loss', test_score[0])
    print('accuracy', test_score[1])
    with open("scores.json", "w") as fd:
        json.dump({"loss": test_score[0], "accuracy": test_score[1]}, fd, indent=4)

    # plot and save the loss curves
    plt.figure(figsize=[8, 6])
    plt.plot(result.history['loss'], 'r', linewidth=3.0)
    plt.plot(result.history['val_loss'], 'b', linewidth=3.0)
    plt.legend(['Training Loss', 'Validation Loss'], fontsize=12)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Loss Curves', fontsize=12)
    plt.savefig('loss.png')

    # plot and save the accuracy curves
    plt.figure(figsize=[8, 6])
    plt.plot(result.history['accuracy'], 'r', linewidth=3.0)
    plt.plot(result.history['val_accuracy'], 'b', linewidth=3.0)
    plt.legend(['Training Accuracy', 'Validation Accuracy'], fontsize=12)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Accuracy Curves', fontsize=12)
    plt.savefig('accuracy.png')


if __name__ == '__main__':
    train()
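Because the ModelCheckpoint callback is configured with save_weights_only=True, saved-models/weights.h5 contains only the weights, so the architecture must be rebuilt before they can be loaded. A minimal inference sketch under that assumption (reusing model_def and the preprocessed test data from the script above; the sample index is illustrative) might look like:

# hypothetical sketch: restore the best checkpoint and classify one sample
# assumes it runs in the same context as train_model.py (model_def, test_data, np defined)
model = model_def()                            # rebuild the same architecture
model.load_weights('saved-models/weights.h5')  # weights-only checkpoint from training

sample = np.expand_dims(test_data[0], axis=0)  # add a batch dimension of 1
probs = model.predict(sample)
print('predicted class:', int(np.argmax(probs, axis=1)[0]))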