Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 1.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  1. from config import MODELS_DIR, IMAGES
  2. import matplotlib.pyplot as plt
  3. import torch
  4. from fastai.vision.all import *
  5. from fastai.metrics import error_rate, accuracy
  6. torch.cuda.set_device(0)
  7. import mlflow
  8. import os
  9. mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
  10. def get_experiment_id(name):
  11. exp = mlflow.get_experiment_by_name(name)
  12. if exp is None:
  13. exp_id = mlflow.create_experiment(name)
  14. return exp_id
  15. return exp.experiment_id
  16. exp_id = get_experiment_id("yoga_lover")
  17. if __name__ == "__main__":
  18. data = ImageDataLoaders.from_folder(IMAGES, valid_pct=0.2, item_tfms=Resize(224))
  19. print("Training the model...")
  20. learn = vision_learner(data, resnet34, metrics=[accuracy, error_rate])
  21. mlflow.fastai.autolog()
  22. learn.remove_cb(ProgressCallback)
  23. with mlflow.start_run(experiment_id=exp_id):
  24. learn.fine_tune(3)
  25. print("Training completed.")
  26. model_path="/content/Yoga-Pose-Classification/Saved_Models/"
  27. os.mkdir(model_path)
  28. print("Saving the model...")
  29. learn.save(model_path+MODELS_DIR, with_opt=False)
  30. print("done.")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...