|
@@ -9,7 +9,12 @@ from super_gradients.common.environment.checkpoints_dir_utils import get_checkpo
|
|
class ShortenedRecipesAccuracyTests(unittest.TestCase):
|
|
class ShortenedRecipesAccuracyTests(unittest.TestCase):
|
|
@classmethod
|
|
@classmethod
|
|
def setUp(cls):
|
|
def setUp(cls):
|
|
- cls.experiment_names = ["shortened_cifar10_resnet_accuracy_test", "shortened_coco2017_yolox_n_map_test", "shortened_cityscapes_regseg48_iou_test"]
|
|
|
|
|
|
+ cls.experiment_names = [
|
|
|
|
+ "shortened_cifar10_resnet_accuracy_test",
|
|
|
|
+ "shortened_coco2017_yolox_n_map_test",
|
|
|
|
+ "shortened_cityscapes_regseg48_iou_test",
|
|
|
|
+ "shortened_coco2017_pose_dekr_w32_ap_test",
|
|
|
|
+ ]
|
|
|
|
|
|
def test_shortened_cifar10_resnet_accuracy(self):
|
|
def test_shortened_cifar10_resnet_accuracy(self):
|
|
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cifar10_resnet_accuracy_test", metric_value=0.9167, delta=0.05))
|
|
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cifar10_resnet_accuracy_test", metric_value=0.9167, delta=0.05))
|
|
@@ -24,6 +29,9 @@ class ShortenedRecipesAccuracyTests(unittest.TestCase):
|
|
def test_shortened_cityscapes_regseg48_iou(self):
|
|
def test_shortened_cityscapes_regseg48_iou(self):
|
|
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cityscapes_regseg48_iou_test", metric_value=0.263, delta=0.05))
|
|
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cityscapes_regseg48_iou_test", metric_value=0.263, delta=0.05))
|
|
|
|
|
|
|
|
+ def test_shortened_coco_dekr_32_ap_test(self):
|
|
|
|
+ self.assertTrue(self._reached_goal_metric(experiment_name="shortened_coco2017_pose_dekr_w32_ap_test", metric_value=0.000154, delta=0.0001))
|
|
|
|
+
|
|
@classmethod
|
|
@classmethod
|
|
def _reached_goal_metric(cls, experiment_name: str, metric_value: float, delta: float):
|
|
def _reached_goal_metric(cls, experiment_name: str, metric_value: float, delta: float):
|
|
checkpoints_dir_path = get_checkpoints_dir_path(experiment_name=experiment_name)
|
|
checkpoints_dir_path = get_checkpoints_dir_path(experiment_name=experiment_name)
|