Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
  1. import time
  2. import json
  3. import numpy as np
  4. import math
  5. from contextlib import contextmanager
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. def calculate_model_losses(args, model, bbox, bbox_pred, angles, angles_pred, mu=None, logvar=None, KL_weight=None):
  10. dtype_f = bbox_pred.data.type()
  11. total_loss = 0.0
  12. losses = {}
  13. loss_bbox = F.l1_loss(bbox_pred, bbox)
  14. total_loss = add_loss(total_loss, loss_bbox, losses, 'bbox_pred', 1)
  15. loss_angle = F.nll_loss(angles_pred, angles)
  16. total_loss = add_loss(total_loss, loss_angle, losses, 'angle_pred', 1)
  17. if not args.use_AE:
  18. try:
  19. loss_gauss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)
  20. except:
  21. print("blowup!!!")
  22. print("logvar", torch.sum(logvar.data), torch.sum(torch.abs(logvar.data)), torch.max(logvar.data),
  23. torch.min(logvar.data))
  24. print("mu", torch.sum(mu.data), torch.sum(torch.abs(mu.data)), torch.max(mu.data), torch.min(mu.data))
  25. return total_loss, losses
  26. total_loss = add_loss(total_loss, loss_gauss, losses, 'KLD_Gauss', KL_weight)
  27. return total_loss, losses
  28. def compute_rel(box1, box2, name1, name2):
  29. center1 = np.array([(box1[0] + box1[3]) / 2, (box1[1] + box1[4]) / 2, (box1[2] + box1[5]) / 2])
  30. center2 = np.array([(box2[0] + box2[3]) / 2, (box2[1] + box2[4]) / 2, (box2[2] + box2[5]) / 2])
  31. if name2 == "__room__":
  32. p = "__in_room__"
  33. else:
  34. # "on" relationship
  35. p = None
  36. if center1[0] >= box2[0] and center1[0] <= box2[3]:
  37. if center1[2] >= box2[2] and center1[2] <= box2[5]:
  38. delta1 = center1[1] - center2[1]
  39. delta2 = (box1[4] - box1[1] + box2[4] - box2[1]) / 2
  40. if abs(delta1 - delta2) < 0.05:
  41. p = 'on'
  42. return p
  43. # random relationship
  44. sx0, sy0, sz0, sx1, sy1, sz1 = box1
  45. ox0, oy0, oz0, ox1, oy1, oz1 = box2
  46. d = center1 - center2
  47. theta = math.atan2(d[2], d[0]) # range -pi to pi
  48. area_s = (sx1 - sx0) * (sz1 - sz0)
  49. area_o = (ox1 - ox0) * (oz1 - oz0)
  50. ix0, ix1 = max(sx0, ox0), min(sx1, ox1)
  51. iz0, iz1 = max(sz0, oz0), min(sz1, oz1)
  52. area_i = max(0, ix1 - ix0) * max(0, iz1 - iz0)
  53. iou = area_i / (area_s + area_o - area_i)
  54. touching = 0.0001 < iou < 0.5
  55. if sx0 < ox0 and sx1 > ox1 and sz0 < oz0 and sz1 > oz1:
  56. p = 'surrounding'
  57. elif sx0 > ox0 and sx1 < ox1 and sz0 > oz0 and sz1 < oz1:
  58. p = 'inside'
  59. elif theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4:
  60. p = 'right touching' if touching else 'left of'
  61. elif -3 * math.pi / 4 <= theta < -math.pi / 4:
  62. p = 'behind touching' if touching else 'behind'
  63. elif -math.pi / 4 <= theta < math.pi / 4:
  64. p = 'left touching' if touching else 'right of'
  65. elif math.pi / 4 <= theta < 3 * math.pi / 4:
  66. p = 'front touching' if touching else 'in front of'
  67. return p
  68. def load_json(json_file):
  69. with open(json_file, 'r') as f:
  70. var = json.load(f)
  71. return var
  72. def write_json(json_file, data):
  73. with open(json_file, 'w') as f:
  74. json.dump(data, f)
  75. def int_tuple(s):
  76. return tuple(int(i) for i in s.split(','))
  77. def float_tuple(s):
  78. return tuple(float(i) for i in s.split(','))
  79. def str_tuple(s):
  80. return tuple(s.split(','))
  81. def bool_flag(s):
  82. if s == '1':
  83. return True
  84. elif s == '0':
  85. return False
  86. msg = 'Invalid value "%s" for bool flag (should be 0 or 1)'
  87. raise ValueError(msg % s)
  88. def tensor_aug(tensors, volatile=False, use_gpu=True):
  89. var_list = []
  90. for tensor in tensors:
  91. if use_gpu:
  92. var = tensor.cuda()
  93. else:
  94. var = tensor
  95. if volatile:
  96. var.requires_grad = False
  97. var_list.append(var)
  98. return tuple(var_list)
  99. @contextmanager
  100. def timeit(msg, should_time=True):
  101. if should_time:
  102. torch.cuda.synchronize()
  103. t0 = time.time()
  104. yield
  105. if should_time:
  106. torch.cuda.synchronize()
  107. t1 = time.time()
  108. duration = (t1 - t0) * 1000.0
  109. print('%s: %.2f ms' % (msg, duration))
  110. def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1):
  111. curr_loss_weighted = curr_loss * weight
  112. loss_dict[loss_name] = curr_loss_weighted.item()
  113. if total_loss is not None:
  114. return total_loss + curr_loss_weighted
  115. else:
  116. return curr_loss_weighted
  117. return 0
  118. def get_model_attr(_object, attr):
  119. if isinstance(_object, nn.DataParallel):
  120. return getattr(_object.module, attr)
  121. else:
  122. return getattr(_object, attr)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...