Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

logger.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import os
  2. import numpy as np
  3. import pickle
  4. import tensorboardX
  5. import pathlib
  6. from torchvision import transforms
  7. class Logger(object):
  8. def __init__(self, args, experiment_dir):
  9. super(Logger, self).__init__()
  10. self.num_iter = {'train': 0, 'test': 0}
  11. self.no_disk_write_ops = args.no_disk_write_ops
  12. self.rank = args.rank
  13. if not self.no_disk_write_ops:
  14. self.experiment_dir = experiment_dir
  15. for phase in ['train', 'test']:
  16. os.makedirs(experiment_dir / 'images' / phase, exist_ok=True)
  17. self.to_image = transforms.ToPILImage()
  18. if args.rank == 0:
  19. if args.which_epoch != 'none' and args.init_experiment_dir == '':
  20. self.losses = pickle.load(open(self.experiment_dir / 'losses.pkl', 'rb'))
  21. else:
  22. self.losses = {}
  23. self.writer = tensorboardX.SummaryWriter('/tensorboard')
  24. def output_logs(self, phase, visuals, losses, time):
  25. if not self.no_disk_write_ops:
  26. # Increment iter counter
  27. self.num_iter[phase] += 1
  28. # Save visuals
  29. self.to_image(visuals).save(self.experiment_dir / 'images' / phase / ('%04d_%02d.jpg' % (self.num_iter[phase], self.rank)))
  30. if self.rank != 0:
  31. return
  32. self.writer.add_image(f'results_{phase}', visuals, self.num_iter[phase])
  33. # Save losses
  34. for key, value in losses.items():
  35. if key in self.losses:
  36. self.losses[key].append(value)
  37. else:
  38. self.losses[key] = [value]
  39. self.writer.add_scalar(f'{key}_{phase}', value, self.num_iter[phase])
  40. # Save losses
  41. pickle.dump(self.losses, open(self.experiment_dir / 'losses.pkl', 'wb'))
  42. elif self.rank != 0:
  43. return
  44. # Print losses
  45. print(', '.join('%s: %.3f' % (key, value) for key, value in losses.items()) + ', time: %.3f' % time)
  46. def set_num_iter(self, train_iter, test_iter):
  47. self.num_iter = {
  48. 'train': train_iter,
  49. 'test': test_iter}
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...