Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

shelfnet_pascal_aug.py 3.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
  1. # TODO: REFACTOR AS YAML FILES RECIPE
  2. import super_gradients
  3. import torch
  4. from super_gradients.training.datasets import PascalAUG2012SegmentationDataSetInterface
  5. from super_gradients.training import SgModel, MultiGPUMode
  6. from super_gradients.training.sg_model.sg_model import StrictLoad
  7. from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
  8. super_gradients.init_trainer()
  9. pascal_aug_dataset_params = {"batch_size": 16,
  10. "test_batch_size": 16,
  11. "dataset_dir": "/data/pascal_voc_2012/VOCaug/dataset/",
  12. "s3_link": None,
  13. "img_size": 512,
  14. "train_loader_drop_last": True,
  15. }
  16. shelfnet_lw_pascal_aug_training_params = {"max_epochs": 250, "initial_lr": 1e-2, "loss": "shelfnet_ohem_loss",
  17. "optimizer": "SGD", "mixed_precision": False, "lr_mode": "poly",
  18. "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4,
  19. "nesterov": False},
  20. "load_opt_params": False, "train_metrics_list": [PixelAccuracy(), IoU(21)],
  21. "valid_metrics_list": [PixelAccuracy(), IoU(21)],
  22. "loss_logging_items_names": ["Loss1/4", "Loss1/8", "Loss1/16", "Loss"],
  23. "metric_to_watch": "IoU",
  24. "greater_metric_to_watch_is_better": True}
  25. shelfnet_lw_arch_params = {"num_classes": 21, "load_checkpoint": True, "strict_load": StrictLoad.ON,
  26. "multi_gpu_mode": MultiGPUMode.OFF, "load_weights_only": True,
  27. "load_backbone": True, "source_ckpt_folder_name": 'resnet_backbones'}
  28. if torch.cuda.is_available() and torch.cuda.device_count() > 1:
  29. data_loader_num_workers = 16
  30. shelfnet_lw_pascal_aug_training_params["initial_lr"] = shelfnet_lw_pascal_aug_training_params["initial_lr"] / 2.
  31. else:
  32. # SINGLE GPU TRAINING
  33. data_loader_num_workers = 8
  34. epoc_metrics_headers = {"Epoch": 0, "gpu_mem": 0.0, "Loss1/4": 0.0, "Loss1/8": 0.0, "Loss1/16": 0.0,
  35. "TrainLoss": 0.0, "targets": 0, "img_size": 0}
  36. results_titles = ['LossP', 'Loss8', 'Loss16', 'Train Loss', 'pixAcc', 'mIOU', 'Test Loss']
  37. # SET THE *LIGHT-WEIGHT* SHELFNET ARCHITECTURE SIZE (UN-COMMENT TO TRAIN)
  38. model_size_str = '34'
  39. # model_size_str = '18'
  40. # BUILD THE LIGHT-WEIGHT SHELFNET ARCHITECTURE FOR TRAINING
  41. experiment_name_prefix = 'shelfnet_lw_'
  42. experiment_name_dataset_suffix = '_pascal_aug_encoding_dataset_train_250_epochs_no_batchnorm_decoder'
  43. experiment_name = experiment_name_prefix + model_size_str + experiment_name_dataset_suffix
  44. model = SgModel(experiment_name, model_checkpoints_location='local', multi_gpu=True,
  45. ckpt_name='resnet' + model_size_str + '.pth',
  46. epoch_metric_headers=epoc_metrics_headers,
  47. results_titles=results_titles)
  48. pascal_aug_datasaet_interface = PascalAUG2012SegmentationDataSetInterface(
  49. dataset_params=pascal_aug_dataset_params,
  50. cache_labels=False)
  51. model.connect_dataset_interface(pascal_aug_datasaet_interface, data_loader_num_workers=data_loader_num_workers)
  52. model.build_model('shelfnet' + model_size_str, arch_params=shelfnet_lw_arch_params)
  53. print('Training ShelfNet-LW model: ' + experiment_name)
  54. model.train(training_params=shelfnet_lw_pascal_aug_training_params)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...