Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
  1. from options.options import Options
  2. import os
  3. import torch
  4. from build_dataset_model import build_loaders, build_model
  5. from utils import get_model_attr, calculate_model_losses, tensor_aug
  6. from collections import defaultdict
  7. import math
  8. def main(args):
  9. vocab, train_loader, val_loader = build_loaders(args)
  10. model, model_kwargs = build_model(args, vocab)
  11. print(model)
  12. model.float().cuda()
  13. optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
  14. restore_path = None
  15. if args.restore_from_checkpoint:
  16. restore_path = '%s_with_model.pt' % args.checkpoint_name
  17. restore_path = os.path.join(args.output_dir, restore_path)
  18. if restore_path is not None and os.path.isfile(restore_path):
  19. print('Restoring from checkpoint:')
  20. print(restore_path)
  21. checkpoint = torch.load(restore_path)
  22. get_model_attr(model, 'load_state_dict')(checkpoint['model_state'])
  23. optimizer.load_state_dict(checkpoint['optim_state'])
  24. t = checkpoint['counters']['t']
  25. if 0 <= args.eval_mode_after <= t:
  26. model.eval()
  27. else:
  28. model.train()
  29. epoch = checkpoint['counters']['epoch']
  30. else:
  31. t, epoch = 0, 0
  32. checkpoint = {
  33. 'args': args.__dict__,
  34. 'vocab': vocab,
  35. 'model_kwargs': model_kwargs,
  36. 'losses_ts': [],
  37. 'losses': defaultdict(list),
  38. 'd_losses': defaultdict(list),
  39. 'checkpoint_ts': [],
  40. 'train_batch_data': [],
  41. 'train_samples': [],
  42. 'train_iou': [],
  43. 'val_batch_data': [],
  44. 'val_samples': [],
  45. 'val_losses': defaultdict(list),
  46. 'val_iou': [],
  47. 'counters': {
  48. 't': None,
  49. 'epoch': None,
  50. },
  51. 'model_state': None,
  52. 'optim_state': None,
  53. }
  54. while True:
  55. if t >= args.num_iterations:
  56. break
  57. epoch += 1
  58. print('Starting epoch %d' % epoch)
  59. for batch in train_loader:
  60. if t == args.eval_mode_after:
  61. print('switching to eval mode')
  62. model.eval()
  63. t += 1
  64. if t%50 ==0:
  65. print("Currently on batch {}".format(t))
  66. ids, objs, boxes, triples, angles, attributes, obj_to_img, triple_to_img = tensor_aug(batch)
  67. model_out = model(objs, triples, boxes, angles, attributes, obj_to_img)
  68. mu, logvar, boxes_pred, angles_pred = model_out
  69. if args.KL_linear_decay:
  70. KL_weight = 10 ** (t // 1e5 - 6)
  71. else:
  72. KL_weight = args.KL_loss_weight
  73. total_loss, losses = calculate_model_losses(args, model, boxes, boxes_pred, angles, angles_pred, mu=mu, logvar=logvar, KL_weight=KL_weight)
  74. losses['total_loss'] = total_loss.item()
  75. if not math.isfinite(losses['total_loss']):
  76. print('WARNING: Got loss = NaN, not backpropping')
  77. continue
  78. optimizer.zero_grad()
  79. total_loss.backward()
  80. optimizer.step()
  81. if t % args.print_every == 0:
  82. print("On batch {} out of {}".format(t, args.num_iterations))
  83. for name, val in losses.items():
  84. print(' [%s]: %.4f' % (name, val))
  85. checkpoint['losses'][name].append(val)
  86. checkpoint['losses_ts'].append(t)
  87. if t % args.checkpoint_every == 0:
  88. checkpoint['model_state'] = get_model_attr(model, 'state_dict')()
  89. checkpoint['optim_state'] = optimizer.state_dict()
  90. checkpoint['counters']['t'] = t
  91. checkpoint['counters']['epoch'] = epoch
  92. checkpoint_path = os.path.join(args.output_dir, 'latest_%s_with_model.pt' % args.checkpoint_name)
  93. print('Saving checkpoint to ', checkpoint_path)
  94. torch.save(checkpoint, checkpoint_path)
  95. if t % args.snapshot_every == 0:
  96. snapshot_name = args.checkpoint_name + 'snapshot_%06dK' % (t // 1000)
  97. snapshot_path = os.path.join(args.output_dir, snapshot_name)
  98. print('Saving snapshot to ', snapshot_path)
  99. torch.save(checkpoint, snapshot_path)
  100. checkpoint_path = os.path.join(args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
  101. key_blacklist = ['model_state', 'optim_state']
  102. small_checkpoint = {}
  103. for k, v in checkpoint.items():
  104. if k not in key_blacklist:
  105. small_checkpoint[k] = v
  106. torch.save(small_checkpoint, checkpoint_path)
  107. if __name__ == '__main__':
  108. args = Options().parse()
  109. if (args.output_dir is not None) and (not os.path.isdir(args.output_dir)):
  110. os.mkdir(args.output_dir)
  111. if (args.test_dir is not None) and (not os.path.isdir(args.test_dir)):
  112. os.mkdir(args.test_dir)
  113. main(args)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...