metric_utils.py
import numpy as np
import torch
from torchmetrics import MetricCollection

from super_gradients.training.metrics.detection_metrics import ap_per_class
from super_gradients.training.utils.utils import AverageMeter


def calc_batch_prediction_detection_metrics_per_class(metrics, dataset_interface, iou_thres, silent_mode, images_counter,
                                                      per_class_verbosity, class_names, test_loss):
    # metrics IS A LIST OF PER-BATCH STATISTIC TUPLES; CONCATENATE EACH FIELD
    # ACROSS BATCHES BEFORE COMPUTING THE PER-CLASS AP
    metrics = [np.concatenate(x, 0) for x in list(zip(*metrics))]
    if len(metrics):
        precision, recall, average_precision, f1, ap_class = ap_per_class(*metrics)
        if iou_thres.is_range():
            # WHEN EVALUATING OVER AN IoU RANGE: KEEP P/R AT THE FIRST THRESHOLD,
            # AVERAGE AP OVER THE RANGE, AND REPORT AP AT THE FIRST THRESHOLD IN
            # THE F1 SLOT (AS IN THE UPSTREAM YOLO EVALUATION CODE)
            precision, recall, average_precision, f1 = precision[:, 0], recall[:, 0], average_precision.mean(1), \
                average_precision[:, 0]
        mean_precision, mean_recall, map, mf1 = precision.mean(), recall.mean(), average_precision.mean(), f1.mean()
        targets_per_class = np.bincount(metrics[3].astype(np.int64),
                                        minlength=len(dataset_interface.testset.classes))
    else:
        # NO PREDICTIONS WERE ACCUMULATED - ZERO THE SUMMARY VALUES SO THE
        # PRINTS AND THE RESULTS TUPLE BELOW STAY WELL-DEFINED
        mean_precision, mean_recall, map, mf1 = 0.0, 0.0, 0.0, 0.0
        targets_per_class = torch.zeros(1)

    if not silent_mode:
        # PRINT RESULTS
        map_str = 'mAP@%.1f' % iou_thres[0] if not iou_thres.is_range() else 'mAP@%.2f:%.2f' % iou_thres
        print(('%15s' * 7) % ('Class', 'Images', 'Targets', 'Precision', 'Recall', map_str, 'F1'))
        pf = '%15s' + '%15.3g' * 6  # print format
        print(pf % ('all', images_counter, targets_per_class.sum(), mean_precision, mean_recall, map, mf1))

        # PRINT RESULTS PER CLASS
        if len(dataset_interface.testset.classes) > 1 and len(metrics) and per_class_verbosity:
            for i, c in enumerate(ap_class):
                print(pf % (class_names[c], images_counter, targets_per_class[c], precision[i], recall[i],
                            average_precision[i], f1[i]))

    results_tuple = (mean_precision, mean_recall, map, mf1, *test_loss.average)
    return results_tuple


def get_logging_values(loss_loggings: AverageMeter, metrics: MetricCollection, criterion=None):
    """
    @param loss_loggings: AverageMeter running average for the loss items
    @param metrics: MetricCollection object for running user specified metrics
    @param criterion: the object the loss_loggings average meter is monitoring; when set to None, only the metrics
        values are computed and returned.
    @return: tuple of the computed values
    """
    if criterion is not None:
        loss_logging_avg = loss_loggings.average
        if not isinstance(loss_logging_avg, tuple):
            loss_logging_avg = (loss_logging_avg,)
        logging_vals = loss_logging_avg + get_metrics_results_tuple(metrics)
    else:
        logging_vals = get_metrics_results_tuple(metrics)

    return logging_vals


def get_metrics_titles(metrics_collection: MetricCollection):
    """
    @param metrics_collection: MetricCollection object for running user specified metrics
    @return: list of all the names of the computed values list(str)
    """
    titles = []
    for metric_name, metric in metrics_collection.items():
        if metric_name == "additional_items":
            continue
        elif hasattr(metric, "component_names"):
            # COMPOUND METRICS EXPOSE ONE TITLE PER COMPONENT
            titles += metric.component_names
        else:
            titles.append(metric_name)

    return titles


def get_metrics_results_tuple(metrics_collection: MetricCollection):
    """
    @param metrics_collection: metrics collection of the user specified metrics
    @type metrics_collection: MetricCollection
    @return: tuple of metrics values
    """
    if metrics_collection is None:
        results_tuple = ()
    else:
        results_tuple = tuple(flatten_metrics_dict(metrics_collection.compute()).values())

    return results_tuple


def flatten_metrics_dict(metrics_dict: dict):
    """
    @param metrics_dict: dictionary of metric values where values can also be dictionaries containing subvalues
        (in the case of compound metrics)
    @return: flattened dict of metric values i.e. {metric1_name: metric1_value, ...}
    """
    flattened = {}
    for metric_name, metric_val in metrics_dict.items():
        if metric_name == "additional_items":
            continue
        # COLLECT ALL OF THE COMPONENTS IN THE CASE OF COMPOUND METRICS
        elif isinstance(metric_val, dict):
            for sub_metric_name, sub_metric_val in metric_val.items():
                flattened[sub_metric_name] = sub_metric_val
        else:
            flattened[metric_name] = metric_val

    return flattened


def get_metrics_dict(metrics_tuple, metrics_collection, loss_logging_item_names):
    """
    Returns a dictionary with the epoch results as values and their names as keys.
    @param metrics_tuple: the result tuple
    @param metrics_collection: MetricCollection
    @param loss_logging_item_names: loss components' names.
    @return: dict
    """
    keys = loss_logging_item_names + get_metrics_titles(metrics_collection)
    metrics_dict = dict(zip(keys, list(metrics_tuple)))
    return metrics_dict


def get_train_loop_description_dict(metrics_tuple, metrics_collection, loss_logging_item_names, **log_items):
    """
    Returns a dictionary with the epoch's logging items as values and their names as keys, with the purpose of
    passing it as a description to tqdm's progress bar.
    @param metrics_tuple: the result tuple
    @param metrics_collection: MetricCollection
    @param loss_logging_item_names: loss components' names.
    @param log_items: additional logging items to be rendered.
    @return: dict
    """
    log_items.update(get_metrics_dict(metrics_tuple, metrics_collection, loss_logging_item_names))

    # TQDM DESCRIPTIONS NEED PLAIN PYTHON SCALARS, NOT TENSORS
    for key, value in log_items.items():
        if isinstance(value, torch.Tensor):
            log_items[key] = value.detach().item()

    return log_items
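
Usage sketch: how the logging helpers above compose in a training/validation loop. This is an illustrative driver, not part of metric_utils.py; it assumes a recent torchmetrics, and it stands in for super_gradients' AverageMeter with a hypothetical class that exposes only the `.average` attribute the helpers actually read.

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision


class FakeLossMeter:
    # HYPOTHETICAL STAND-IN FOR AverageMeter: ONLY `average` IS READ ABOVE
    average = (0.42,)  # ONE RUNNING-AVERAGE VALUE PER LOSS COMPONENT


metrics = MetricCollection({"accuracy": BinaryAccuracy(), "precision": BinaryPrecision()})
metrics.update(torch.tensor([0.9, 0.2, 0.8, 0.4]), torch.tensor([1, 0, 1, 1]))

titles = get_metrics_titles(metrics)  # ['accuracy', 'precision']

# ANY NON-None criterion MAKES THE LOSS AVERAGE PART OF THE OUTPUT TUPLE
values = get_logging_values(FakeLossMeter(), metrics, criterion=object())

# BUILDS {'epoch': 1, 'loss': 0.42, 'accuracy': 0.75, 'precision': 1.0},
# WITH TENSORS DETACHED TO PLAIN FLOATS FOR tqdm
desc = get_train_loop_description_dict(values, metrics, ['loss'], epoch=1)
print(desc)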
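
The `component_names` branch in get_metrics_titles is the hook for compound metrics that report several values at once. A minimal sketch with a hypothetical compound metric; the class and its names are invented for illustration, and only the `component_names` attribute is the convention the file relies on.

import torch
from torchmetrics import Metric, MetricCollection


class TinyPrecisionRecall(Metric):
    # HYPOTHETICAL COMPOUND METRIC: DECLARING component_names MAKES
    # get_metrics_titles() EXPAND IT INTO ONE TITLE PER COMPONENT
    component_names = ["precision", "recall"]

    def update(self, preds, target):
        pass  # STATE UPDATES OMITTED IN THIS SKETCH

    def compute(self):
        return {"precision": torch.tensor(0.81), "recall": torch.tensor(0.74)}


collection = MetricCollection({"pr": TinyPrecisionRecall()})
print(get_metrics_titles(collection))  # ['precision', 'recall']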
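
flatten_metrics_dict is what keeps those component values aligned with the titles: nested dicts produced by compound metrics are splatted into the top level, and the reserved "additional_items" key is dropped. A hand-built example (the names and values are invented):

nested = {
    "pr": {"precision": 0.81, "recall": 0.74},  # COMPOUND METRIC OUTPUT
    "accuracy": 0.93,
    "additional_items": {"debug_info": None},   # RESERVED KEY, NEVER LOGGED
}

print(flatten_metrics_dict(nested))
# -> {'precision': 0.81, 'recall': 0.74, 'accuracy': 0.93}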