ddrnet_segmetation_example.py

  1. """
  2. TODO: REFACTOR AS YAML FILES RECIPE
  3. Train DDRNet23 according to the paper
  4. Usage:
  5. python -m torch.distributed.launch --nproc_per_node=4 ddrnet_segmentation_example.py [-s for slim]
  6. [-d $n for decinet_$n backbone] --pretrained_bb_path <path>
  7. Training time:
  8. DDRNet23: 19H (on 4 x 2080Ti)
  9. DDRNet23 slim: 13H (on 4 x 2080Ti)
  10. Validation mIoU:
  11. DDRNet23: 78.94 (paper: 79.1±0.3)
  12. DDRNet23 slim: 76.79 (paper: 77.3±0.4)
  13. Official git repo:
  14. https://github.com/ydhongHIT/DDRNet
  15. Paper:
  16. https://arxiv.org/pdf/2101.06085.pdf
  17. Pretained checkpoints:
  18. Backbones (trained by the original authors):
  19. s3://deci-model-safe-research/DDRNet/DDRNet23_bb_imagenet.pth
  20. s3://deci-model-safe-research/DDRNet/DDRNet23s_bb_imagenet.pth
  21. Segmentation (trained using this recipe:
  22. s3://deci-model-safe-research/DDRNet/DDRNet23_new/ckpt_best.pth
  23. s3://deci-model-safe-research/DDRNet/DDRNet23s_new/ckpt_best.pth
  24. Comments:
  25. * Pretrained backbones were used
  26. * To pretrain the backbone on imagenet - see ddrnet_classification_example
  27. """
import argparse

import torch
import torchvision.transforms as transforms

import super_gradients
from super_gradients.training import SgModel, MultiGPUMode
from super_gradients.training.datasets.dataset_interfaces.dataset_interface import CityscapesDatasetInterface
from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CITYSCAPES_IGNORE_LABEL
from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
from super_gradients.training.metrics.segmentation_metrics import IoU, PixelAccuracy
from super_gradients.training.utils.segmentation_utils import RandomFlip, PadShortToCropSize, CropImageAndMask, RandomRescale
super_gradients.init_trainer()

parser = argparse.ArgumentParser()
parser.add_argument("--reload", action="store_true")
parser.add_argument("--max_epochs", type=int, default=485)
parser.add_argument("--batch", type=int, default=3)
parser.add_argument("--img_size", type=int, default=1024)
parser.add_argument("--experiment_name", type=str, default="ddrnet_23")
parser.add_argument("--pretrained_bb_path", type=str)
parser.add_argument("-s", "--slim", action="store_true", help="train the slim version of DDRNet23")
args, _ = parser.parse_known_args()
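
# Under torch.distributed.launch each process drives a single GPU (DDP), so it
# counts as one device; otherwise all visible GPUs are used via DataParallel
# (see the multi_gpu argument to SgModel below).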
distributed = super_gradients.is_distributed()
devices = torch.cuda.device_count() if not distributed else 1
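
# Cityscapes training augmentation: random horizontal flip, random rescale in
# [0.5, 2.0], pad the short side up to the crop size, then take a random crop.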
dataset_params = {
    "batch_size": args.batch,
    "val_batch_size": args.batch,
    "dataset_dir": "/home/ofri/cityscapes/",
    "crop_size": args.img_size,
    "img_size": args.img_size,
    "image_mask_transforms_aug": transforms.Compose([
        # ColorJitterSeg(brightness=0.5, contrast=0.5, saturation=0.5),  # TODO - add
        RandomFlip(),
        RandomRescale(scales=(0.5, 2.0)),
        PadShortToCropSize(args.img_size,
                           fill_mask=CITYSCAPES_IGNORE_LABEL,
                           # Legacy padding color that works best with this recipe
                           fill_image=(CITYSCAPES_IGNORE_LABEL, 0, 0)),
        CropImageAndMask(crop_size=args.img_size, mode="random"),
    ]),
    "image_mask_transforms": transforms.Compose([])  # no transform for evaluation
}

# num_classes for IoU includes the ignore label
train_metrics_list = [PixelAccuracy(ignore_label=CITYSCAPES_IGNORE_LABEL),
                      IoU(num_classes=20, ignore_index=CITYSCAPES_IGNORE_LABEL)]
valid_metrics_list = [PixelAccuracy(ignore_label=CITYSCAPES_IGNORE_LABEL),
                      IoU(num_classes=20, ignore_index=CITYSCAPES_IGNORE_LABEL)]
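
# Optimization follows the paper: SGD (momentum 0.9, weight decay 5e-4) with a
# polynomially decaying learning rate starting at 1e-2. EMA and best-model
# averaging are additions on top of the paper's setup (see the inline notes).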
train_params = {"max_epochs": args.max_epochs,
                "initial_lr": 1e-2,
                "loss": DDRNetLoss(ignore_label=CITYSCAPES_IGNORE_LABEL, num_pixels_exclude_ignored=False),
                "lr_mode": "poly",
                "ema": True,  # unlike the paper (not specified in paper)
                "average_best_models": True,
                "optimizer": "SGD",
                "mixed_precision": False,
                "optimizer_params": {"weight_decay": 5e-4,
                                     "momentum": 0.9},
                "train_metrics_list": train_metrics_list,
                "valid_metrics_list": valid_metrics_list,
                "loss_logging_items_names": ["main_loss", "aux_loss", "Loss"],
                "metric_to_watch": "IoU",
                "greater_metric_to_watch_is_better": True}
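
# 19 Cityscapes classes; the auxiliary head adds the paper's deep-supervision
# loss branch, and sync_bn keeps batch-norm statistics consistent across GPUs.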
arch_params = {"num_classes": 19, "aux_head": True, "sync_bn": True}

model = SgModel(args.experiment_name,
                multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL,
                device="cuda")

dataset_interface = CityscapesDatasetInterface(dataset_params=dataset_params, cache_labels=False)
model.connect_dataset_interface(dataset_interface, data_loader_num_workers=8 * devices)
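
# When a pretrained backbone is given, load backbone weights only (load_backbone
# with load_weights_only); --reload instead resumes this experiment's own checkpoint.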
model.build_model(architecture="ddrnet_23_slim" if args.slim else "ddrnet_23",
                  arch_params=arch_params,
                  load_checkpoint=args.reload,
                  load_weights_only=args.pretrained_bb_path is not None,
                  load_backbone=args.pretrained_bb_path is not None,
                  external_checkpoint_path=args.pretrained_bb_path)

model.train(training_params=train_params)
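
After training, the best checkpoint can be reused outside this script. The snippet below is a minimal sketch, not part of the recipe: the checkpoint path is a placeholder (only the ckpt_best.pth filename comes from the docstring above), and the "net" / "ema_net" key names are assumptions about this SuperGradients version's checkpoint layout, so inspect the printed keys if they differ.

import torch

# Placeholder path: point this at your experiment's ckpt_best.pth.
checkpoint = torch.load("checkpoints/ddrnet_23/ckpt_best.pth", map_location="cpu")
print(list(checkpoint.keys()))  # inspect the layout before extracting weights

# Assumed keys: "net" holds the model weights; "ema_net" holds the EMA weights
# when EMA was enabled (both are assumptions -- verify against the printout).
state_dict = checkpoint.get("ema_net", checkpoint["net"])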