Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

eval_voc.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  1. import _add_project_path
  2. import os
  3. import tqdm
  4. import pickle
  5. import tensorflow as tf
  6. from termcolor import colored
  7. from absl import flags, app
  8. from calc4ap.voc import CalcVOCmAP
  9. from libs.utils import yolo_output2boxes, box_postp2use
  10. from datasets.voc_tfds.voc import GetVoc
  11. from datasets.voc_tfds.libs import prep_voc_data, VOC_CLS_MAP
  12. from datasets.voc_tfds.eval.prepare_eval import get_labels
  13. from configs import ProjectPath, cfg
  14. FLAGS = flags.FLAGS
  15. flags.DEFINE_integer('batch_size', default=cfg.batch_size, help='Batch size')
  16. flags.DEFINE_string('pb_dir', default=os.path.join(ProjectPath.VOC_CKPTS_DIR.value, 'yolo_voc_448x448'), help='Save pb directory path')
  17. flags.DEFINE_float('val_ds_sample_ratio', default=cfg.val_ds_sample_ratio, help='Validation dataset sampling ratio')
  18. def main(_argv):
  19. yolo = tf.saved_model.load(
  20. export_dir=FLAGS.pb_dir,
  21. tags=None,
  22. options=None,
  23. )
  24. voc = GetVoc(batch_size=FLAGS.batch_size)
  25. val_ds = voc.get_val_ds(sample_ratio=FLAGS.val_ds_sample_ratio)
  26. val_preds = list()
  27. val_labels_path = os.path.join(ProjectPath.DATASETS_DIR.value, 'voc_tfds', 'eval', 'val_labels_448_full.pickle')
  28. if FLAGS.val_ds_sample_ratio == 1. and os.path.exists(val_labels_path):
  29. val_labels = pickle.load(open(val_labels_path, 'rb'))
  30. else:
  31. val_labels = get_labels(ds=val_ds, input_height=cfg.input_height, input_width=cfg.input_width, cls_map=VOC_CLS_MAP, full_save=False)
  32. img_id = 0
  33. for _step, batch_data in tqdm.tqdm(enumerate(val_ds, 1), total=len(val_ds), desc='Validation'):
  34. batch_imgs, _batch_labels = prep_voc_data(batch_data, input_height=cfg.input_height, input_width=cfg.input_width, val=True)
  35. yolo_output_raw = yolo(batch_imgs, training=False)
  36. yolo_boxes = yolo_output2boxes(yolo_output_raw, cfg.input_height, cfg.input_width, cfg.cell_size, cfg.boxes_per_cell)
  37. for i in range(len(yolo_boxes)):
  38. output_boxes = box_postp2use(yolo_boxes[i], cfg.nms_iou_thr, 0.)
  39. if output_boxes.size == 0:
  40. img_id += 1
  41. continue
  42. for output_box in output_boxes:
  43. *pts, conf, cls_idx = output_box
  44. cls_name = VOC_CLS_MAP[cls_idx]
  45. val_preds.append([*map(round, pts), conf, cls_name, img_id])
  46. img_id += 1
  47. voc_ap = CalcVOCmAP(labels=val_labels, preds=val_preds, iou_thr=0.5, conf_thr=0.0)
  48. ap_summary = voc_ap.get_summary()
  49. mAP = ap_summary.pop('mAP')
  50. APs_log = '\n====== mAP ======\n' + f'* mAP: {mAP:<8.4f}\n'
  51. for cls_name, ap in ap_summary.items():
  52. APs_log += f'- {cls_name}: {ap:<8.4f}\n'
  53. APs_log += '====== ====== ======\n'
  54. APs_log_colored = colored(APs_log, 'magenta')
  55. print(APs_log_colored)
  56. if __name__ == '__main__':
  57. app.run(main)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...