Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

qat_integration_test.py 3.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. import unittest
  2. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  3. from super_gradients.training import Trainer, MultiGPUMode, models
  4. from super_gradients.training.metrics.classification_metrics import Accuracy
  5. import os
  6. from super_gradients.training.utils.quantization_utils import PostQATConversionCallback
  class QATIntegrationTest(unittest.TestCase):
      """Integration smoke tests for quantization-aware training (QAT) of a pretrained ResNet-18.

      Each test runs a 2-epoch training job on ``classification_test_dataloader()`` with
      ``enable_qat=True`` and a ``PostQATConversionCallback``; a test passes when training
      completes without raising. The calibrated-checkpoint test additionally asserts that the
      calibration run wrote the checkpoint file that the second run is told to load.
      """

      def _get_trainer(self, experiment_name):
          """Return a fresh ``(trainer, model)`` pair for *experiment_name*.

          The trainer is created with multi-GPU disabled; the model is ``resnet18`` built with
          ``pretrained_weights="imagenet"``.
          """
          trainer = Trainer(experiment_name, multi_gpu=MultiGPUMode.OFF)
          model = models.get("resnet18", pretrained_weights="imagenet")
          return trainer, model

      def _get_qat_params(self, start_epoch, calibrate=True, calibrated_model_path=None):
          """Build the ``qat_params`` dict shared by every test.

          :param start_epoch: value for ``"start_epoch"`` (the tests use 0 to run QAT from the
              start and 1 to exercise a transition into QAT).
          :param calibrate: value for ``"calibrate"``.
          :param calibrated_model_path: when given, added as ``"calibrated_model_path"`` so the
              run starts from an existing calibrated checkpoint; omitted otherwise.
          :return: the qat_params dict passed through to ``_get_train_params``.
          """
          qat_params = {
              "start_epoch": start_epoch,
              "quant_modules_calib_method": "percentile",
              "calibrate": calibrate,
              "num_calib_batches": 2,
              "percentile": 99.99,
          }
          if calibrated_model_path is not None:
              qat_params["calibrated_model_path"] = calibrated_model_path
          return qat_params

      def _get_train_params(self, qat_params):
          """Return the training-params dict for a short QAT run using *qat_params*.

          Two epochs of SGD on cross-entropy, tracking ``Accuracy`` (higher is better), with
          QAT enabled and a ``PostQATConversionCallback`` using a (1, 3, 224, 224) dummy input.
          """
          train_params = {
              "max_epochs": 2,
              "lr_mode": "step",
              "optimizer": "SGD",
              "lr_updates": [],
              "lr_decay_factor": 0.1,
              "initial_lr": 0.001,
              "loss": "cross_entropy",
              "train_metrics_list": [Accuracy()],
              "valid_metrics_list": [Accuracy()],
              "metric_to_watch": "Accuracy",
              "greater_metric_to_watch_is_better": True,
              "average_best_models": False,
              "enable_qat": True,
              "qat_params": qat_params,
              "phase_callbacks": [PostQATConversionCallback(dummy_input_size=(1, 3, 224, 224))],
          }
          return train_params

      def _train(self, trainer, net, qat_params):
          """Train *net* with *trainer* on the test classification loaders using *qat_params*."""
          trainer.train(
              model=net,
              training_params=self._get_train_params(qat_params=qat_params),
              train_loader=classification_test_dataloader(),
              valid_loader=classification_test_dataloader(),
          )

      def test_qat_from_start(self):
          """QAT with percentile calibration starting at epoch 0."""
          trainer, net = self._get_trainer("test_qat_from_start")
          self._train(trainer, net, self._get_qat_params(start_epoch=0))

      def test_qat_transition(self):
          """QAT starting at epoch 1, so the run transitions into QAT mid-training."""
          trainer, net = self._get_trainer("test_qat_transition")
          self._train(trainer, net, self._get_qat_params(start_epoch=1))

      def test_qat_from_calibrated_ckpt(self):
          """Calibrate in a first run, then start QAT from the saved calibrated checkpoint."""
          trainer, net = self._get_trainer("generate_calibrated_model")
          self._train(trainer, net, self._get_qat_params(start_epoch=0))

          # NOTE(review): the file name appears to encode the calibration method and percentile
          # from _get_qat_params ("percentile", 99.99); keep them in sync if either changes.
          calibrated_model_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_calibrated_percentile_99.99.pth")
          # Fail here with a clear message instead of letting the second run proceed (with
          # calibrate=False) against a checkpoint that was never written.
          self.assertTrue(
              os.path.isfile(calibrated_model_path),
              f"calibrated checkpoint was not written to {calibrated_model_path}",
          )

          trainer, net = self._get_trainer("test_qat_from_calibrated_ckpt")
          self._train(
              trainer,
              net,
              self._get_qat_params(start_epoch=0, calibrate=False, calibrated_model_path=calibrated_model_path),
          )
  if __name__ == "__main__":
      # Run the tests in this module when the file is executed directly (module="__main__" is
      # unittest.main's default, spelled out for clarity).
      unittest.main(module="__main__")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...