Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_optical_flow_metric.py 1.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  1. import torch
  2. import unittest
  3. from super_gradients.training.metrics.optical_flow_metric import EPE
  4. class TestOpticalFlowMetric(unittest.TestCase):
  5. def test_epe_metric(self):
  6. # Specific example data
  7. pred_flow = [torch.ones(1, 2, 100, 100)]
  8. gt_flow = torch.zeros(1, 2, 100, 100)
  9. valid = torch.ones(100, 100)
  10. # Create instances of delta metrics
  11. max_flow = 400
  12. metric = EPE(max_flow=max_flow)
  13. # Update metrics with specific example data
  14. metric.update(pred_flow, (gt_flow, valid))
  15. # Expected metric
  16. mag = torch.sum(gt_flow**2, dim=1).sqrt()
  17. valid = (valid >= 0.5) & (mag < max_flow)
  18. expected_epe = torch.sum((pred_flow[-1] - gt_flow) ** 2, dim=1).sqrt()
  19. expected_epe = expected_epe.view(-1)[valid.view(-1)]
  20. expected_epe = expected_epe.mean().item()
  21. # Compute and assert the delta metrics
  22. self.assertAlmostEqual(metric.compute()["epe"], expected_epe)
  23. if __name__ == "__main__":
  24. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...