Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

optical_flow_loss_test.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  1. import torch
  2. import torch.nn as nn
  3. import unittest
  4. from super_gradients.training.losses import OpticalFlowLoss
  5. from super_gradients.training.losses.loss_utils import apply_reduce
  class OpticalFlowLossTest(unittest.TestCase):
      """Unit tests for ``OpticalFlowLoss``.

      Each test passes a list of constant-valued flow predictions plus an all-zero ground-truth flow
      through ``OpticalFlowLoss``. It then compares the result with a reference value computed
      independently by ``_compute_expected_flow_loss``.
      """

      def setUp(self) -> None:
          self.img_size = 100  # height == width of every synthetic flow tensor
          self.gamma = 0.8  # per-prediction weight decay passed to OpticalFlowLoss
          self.max_flow = 400  # GT flow magnitudes >= this are treated as invalid by the reference
          self.reduction = "mean"
          self.batch_size = 1

      def _get_default_predictions_tensor(self, n_predictions: int, fill_value: float) -> list:
          """Return ``n_predictions`` flow predictions of shape (B, 2, H, W), each filled with ``fill_value``."""
          shape = (self.batch_size, 2, self.img_size, self.img_size)
          # dtype is pinned to the default float dtype so an integer ``fill_value`` still yields a
          # floating-point tensor, matching ``torch.empty(...).fill_(...)``.
          return [torch.full(shape, fill_value, dtype=torch.get_default_dtype()) for _ in range(n_predictions)]

      def _get_default_target_tensor(self):
          """Return ``(flow_gt, valid)``.

          ``flow_gt`` is an all-zero integer flow of shape (B, 2, H, W). ``valid`` is an all-ones mask
          of shape (B, H, W), so every pixel counts as valid.
          """
          flow_gt = torch.zeros(self.batch_size, 2, self.img_size, self.img_size).long()
          # The mask carries the batch dimension so it lines up with flow_gt for any batch_size.
          valid = torch.ones(self.batch_size, self.img_size, self.img_size)
          return flow_gt, valid

      def _assertion_flow_loss_torch_values(self, expected_value: torch.Tensor, found_value: torch.Tensor, rtol: float = 1e-5):
          """Assert that ``found_value`` matches ``expected_value`` within relative tolerance ``rtol``."""
          self.assertTrue(
              torch.allclose(found_value, expected_value, rtol=rtol),
              msg=f"Unequal flow loss: expected: {expected_value}, found: {found_value}",
          )

      def _compute_expected_flow_loss(self, predictions, target, valid, elementwise_dist) -> torch.Tensor:
          """Compute the reference gamma-weighted sequence flow loss.

          :param predictions:      Predicted flows, each (B, 2, H, W). Entry ``i`` is weighted by
                                   ``gamma ** (len(predictions) - i - 1)``, so the last entry has weight 1.
          :param target:           Ground-truth flow, (B, 2, H, W).
          :param valid:            Validity mask. A pixel contributes only if ``valid >= 0.5`` and its GT
                                   flow magnitude is below ``max_flow``.
          :param elementwise_dist: Maps the per-element error ``prediction - target`` to a per-element
                                   distance, e.g. ``torch.abs`` for L1 or squaring for MSE.
          :return: Scalar tensor holding the expected loss.
          """
          magnitude = torch.sum(target**2, dim=1).sqrt()
          valid_mask = (valid >= 0.5) & (magnitude < self.max_flow)
          n_predictions = len(predictions)
          expected_flow_loss = 0.0
          for i, prediction in enumerate(predictions):
              i_weight = self.gamma ** (n_predictions - i - 1)
              # Mask element-wise *before* reducing so invalid pixels contribute zero error.
              i_loss = i_weight * (valid_mask[:, None] * elementwise_dist(prediction - target))
              expected_flow_loss += apply_reduce(i_loss, self.reduction)
          # as_tensor avoids the copy-construct warning that torch.tensor() emits for tensor inputs.
          return torch.as_tensor(expected_flow_loss)

      def test_flow_loss_l1_criterion(self):
          predictions = self._get_default_predictions_tensor(3, 2.5)
          target, valid = self._get_default_target_tensor()
          loss_fn = OpticalFlowLoss(criterion=nn.L1Loss(), gamma=self.gamma, max_flow=self.max_flow, reduction=self.reduction)
          flow_loss = loss_fn(predictions, (target, valid))

          expected_flow_loss = self._compute_expected_flow_loss(predictions, target, valid, elementwise_dist=torch.abs)  # L1 dist
          self._assertion_flow_loss_torch_values(expected_flow_loss, flow_loss)

      def test_flow_loss_mse_criterion(self):
          predictions = self._get_default_predictions_tensor(3, 2.5)
          target, valid = self._get_default_target_tensor()
          loss_fn = OpticalFlowLoss(criterion=nn.MSELoss(), gamma=self.gamma, max_flow=self.max_flow, reduction=self.reduction)
          flow_loss = loss_fn(predictions, (target, valid))

          # Squared error is masked per element and reduced afterwards. Averaging before masking
          # would only be correct when every pixel is valid.
          expected_flow_loss = self._compute_expected_flow_loss(predictions, target, valid, elementwise_dist=lambda err: err**2)  # mse dist
          self._assertion_flow_loss_torch_values(expected_flow_loss, flow_loss)
  if __name__ == "__main__":
      # Entry point so the module can be run directly, outside a test runner.
      unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...