Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

fid_evaluation.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  1. """
  2. Contains code for logging approximate FID scores during training.
  3. If you want to output ground-truth images from the training dataset, you can
  4. run this file as a script.
  5. """
  6. import os
  7. import shutil
  8. import torch
  9. import copy
  10. import argparse
  11. from torchvision.utils import save_image
  12. from pytorch_fid import fid_score
  13. from tqdm import tqdm
  14. import datasets
  15. import curriculums
  16. def output_real_images(dataloader, num_imgs, real_dir):
  17. img_counter = 0
  18. batch_size = dataloader.batch_size
  19. dataloader = iter(dataloader)
  20. for i in range(num_imgs//batch_size):
  21. real_imgs, _ = next(dataloader)
  22. for img in real_imgs:
  23. save_image(img, os.path.join(real_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1))
  24. img_counter += 1
  25. def setup_evaluation(dataset_name, generated_dir, target_size=128, num_imgs=8000):
  26. # Only make real images if they haven't been made yet
  27. real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
  28. if not os.path.exists(real_dir):
  29. os.makedirs(real_dir)
  30. dataloader, CHANNELS = datasets.get_dataset(dataset_name, img_size=target_size)
  31. print('outputting real images...')
  32. output_real_images(dataloader, num_imgs, real_dir)
  33. print('...done')
  34. if generated_dir is not None:
  35. os.makedirs(generated_dir, exist_ok=True)
  36. return real_dir
  37. def output_images(generator, input_metadata, rank, world_size, output_dir, num_imgs=2048):
  38. metadata = copy.deepcopy(input_metadata)
  39. metadata['img_size'] = 128
  40. metadata['batch_size'] = 4
  41. metadata['h_stddev'] = metadata.get('h_stddev_eval', metadata['h_stddev'])
  42. metadata['v_stddev'] = metadata.get('v_stddev_eval', metadata['v_stddev'])
  43. metadata['sample_dist'] = metadata.get('sample_dist_eval', metadata['sample_dist'])
  44. metadata['psi'] = 1
  45. img_counter = rank
  46. generator.eval()
  47. img_counter = rank
  48. if rank == 0: pbar = tqdm("generating images", total = num_imgs)
  49. with torch.no_grad():
  50. while img_counter < num_imgs:
  51. z = torch.randn((metadata['batch_size'], generator.module.z_dim), device=generator.module.device)
  52. generated_imgs, _ = generator.module.staged_forward(z, **metadata)
  53. for img in generated_imgs:
  54. save_image(img, os.path.join(output_dir, f'{img_counter:0>5}.jpg'), normalize=True, range=(-1, 1))
  55. img_counter += world_size
  56. if rank == 0: pbar.update(world_size)
  57. if rank == 0: pbar.close()
  58. def calculate_fid(dataset_name, generated_dir, target_size=256):
  59. real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
  60. fid = fid_score.calculate_fid_given_paths([real_dir, generated_dir], 128, 'cuda', 2048)
  61. torch.cuda.empty_cache()
  62. return fid
  63. if __name__ == '__main__':
  64. parser = argparse.ArgumentParser()
  65. parser.add_argument('--dataset', type=str, default='CelebA')
  66. parser.add_argument('--img_size', type=int, default=128)
  67. parser.add_argument('--num_imgs', type=int, default=8000)
  68. opt = parser.parse_args()
  69. real_images_dir = setup_evaluation(opt.dataset, None, target_size=opt.img_size, num_imgs=opt.num_imgs)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...