1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
- """
- Contains code for logging approximate FID scores during training.
- If you want to output ground-truth images from the training dataset, you can
- run this file as a script.
- """
- import os
- import shutil
- import torch
- import copy
- import argparse
- from torchvision.utils import save_image
- from pytorch_fid import fid_score
- from tqdm import tqdm
- import datasets
- import curriculums
def output_real_images(dataloader, num_imgs, real_dir):
    """Save up to ``num_imgs`` ground-truth images from ``dataloader`` as JPEGs.

    Files are written into ``real_dir`` with zero-padded sequential names
    (``00000.jpg``, ``00001.jpg``, ...). Pixel values in [-1, 1] are rescaled
    by ``save_image`` before writing.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches; only the
            images are used. Presumably a ``torch.utils.data.DataLoader`` —
            any iterable of ``(batch, _)`` pairs works.
        num_imgs: Number of images to write. Values <= 0 write nothing.
        real_dir: Existing directory that receives the images.

    Returns:
        The number of images actually written. This is smaller than
        ``num_imgs`` only when the dataloader is exhausted first.
    """
    img_counter = 0
    if num_imgs <= 0:
        # Avoid fetching (and decoding) a batch that would be thrown away.
        return img_counter
    for real_imgs, _ in dataloader:
        for img in real_imgs:
            save_image(img, os.path.join(real_dir, f'{img_counter:0>5}.jpg'),
                       normalize=True, value_range=(-1, 1))
            img_counter += 1
            # Stop exactly at num_imgs, even if it is not a multiple of the batch size.
            if img_counter >= num_imgs:
                return img_counter
    return img_counter
def setup_evaluation(dataset_name, generated_dir, target_size=128, num_imgs=8000):
    """Prepare the on-disk directories used for FID evaluation.

    The real-image folder is ``EvalImages/<dataset_name>_real_images_<target_size>``.
    It is filled, by asking ``output_real_images`` for ``num_imgs`` images,
    only when the folder does not exist yet. Later calls reuse whatever is
    already there. ``generated_dir`` is also created when it is not ``None``.

    Returns:
        Path to the real-image folder, relative to the current working directory.
    """
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))

    # NOTE(review): existence alone marks the dump as complete, so an
    # interrupted dump leaves a partial folder that is never refilled.
    needs_dump = not os.path.exists(real_dir)
    if needs_dump:
        os.makedirs(real_dir)
        real_loader, _ = datasets.get_dataset(dataset_name, img_size=target_size)
        print('outputting real images...')
        output_real_images(real_loader, num_imgs, real_dir)
        print('...done')

    if generated_dir is not None:
        os.makedirs(generated_dir, exist_ok=True)

    return real_dir
def output_images(generator, input_metadata, rank, world_size, output_dir, num_imgs=2048, img_size=128):
    """Render generator samples into ``output_dir`` for FID scoring.

    Work is sharded across processes. Process ``rank`` writes the files whose
    index ``i`` satisfies ``i % world_size == rank`` and ``i < num_imgs``.
    When every process is called with the same ``num_imgs`` and
    ``output_dir``, the set of files is ``00000.jpg`` .. ``num_imgs - 1``.

    Args:
        generator: Module whose ``.module`` exposes ``z_dim``, ``device`` and
            ``staged_forward(z, **metadata)``. Presumably a
            DistributedDataParallel-wrapped generator — TODO confirm.
        input_metadata: Curriculum/metadata dict. It is deep-copied and never
            mutated. It must contain ``h_stddev``, ``v_stddev`` and
            ``sample_dist``; optional ``*_eval`` keys override those values.
        rank: Index of this process among ``world_size`` processes.
        world_size: Number of processes writing into ``output_dir``.
        output_dir: Existing directory that receives the ``.jpg`` files.
        num_imgs: Total number of images across all processes.
        img_size: Resolution passed to the generator as ``metadata['img_size']``.

    The generator's train/eval mode is restored on return, even on error.
    """
    metadata = copy.deepcopy(input_metadata)
    metadata['img_size'] = img_size
    metadata['batch_size'] = 4
    # Prefer evaluation-specific camera/sampling settings when the curriculum defines them.
    metadata['h_stddev'] = metadata.get('h_stddev_eval', metadata['h_stddev'])
    metadata['v_stddev'] = metadata.get('v_stddev_eval', metadata['v_stddev'])
    metadata['sample_dist'] = metadata.get('sample_dist_eval', metadata['sample_dist'])
    metadata['psi'] = 1  # NOTE(review): presumably disables truncation — confirm in the generator

    was_training = generator.training
    generator.eval()
    # tqdm's first positional argument is the iterable, so the label belongs in ``desc``.
    pbar = tqdm(desc="generating images", total=num_imgs) if rank == 0 else None
    img_counter = rank
    try:
        with torch.no_grad():
            while img_counter < num_imgs:
                z = torch.randn((metadata['batch_size'], generator.module.z_dim),
                                device=generator.module.device)
                generated_imgs, _ = generator.module.staged_forward(z, **metadata)
                for img in generated_imgs:
                    # The final batch can hold more samples than remain; never write past num_imgs.
                    if img_counter >= num_imgs:
                        break
                    save_image(img, os.path.join(output_dir, f'{img_counter:0>5}.jpg'),
                               normalize=True, value_range=(-1, 1))
                    img_counter += world_size
                    if pbar is not None:
                        # Rank 0 estimates global progress, assuming all ranks advance in step.
                        pbar.update(world_size)
    finally:
        if pbar is not None:
            pbar.close()
        generator.train(was_training)
def calculate_fid(dataset_name, generated_dir, target_size=256, batch_size=128, device='cuda', dims=2048):
    """Compute the FID between the cached real images and ``generated_dir``.

    The real images are read from
    ``EvalImages/<dataset_name>_real_images_<target_size>``. That is the same
    path ``setup_evaluation`` builds, but note the differing defaults (256
    here, 128 there). Pass the same ``target_size`` to both calls.

    Args:
        dataset_name: Dataset name used in the real-image folder name.
        generated_dir: Folder of generated images to score.
        target_size: Resolution suffix of the real-image folder.
        batch_size: Batch size forwarded to ``calculate_fid_given_paths``.
        device: Device forwarded to ``calculate_fid_given_paths`` (defaults to ``'cuda'``).
        dims: Feature dimensionality forwarded to ``calculate_fid_given_paths``.

    Returns:
        The score returned by ``fid_score.calculate_fid_given_paths``.
    """
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    # Passed positionally to stay compatible with pytorch_fid's (paths, batch_size, device, dims) order.
    fid = fid_score.calculate_fid_given_paths([real_dir, generated_dir], batch_size, device, dims)
    # Hand cached, now-unused GPU memory back before the caller resumes work.
    torch.cuda.empty_cache()
    return fid
if __name__ == '__main__':
    # Script mode: dump ground-truth images for a dataset so later FID runs
    # can reuse them. No generated-image directory is created here.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='CelebA')
    cli.add_argument('--img_size', type=int, default=128)
    cli.add_argument('--num_imgs', type=int, default=8000)
    opt = cli.parse_args()
    real_images_dir = setup_evaluation(
        opt.dataset,
        None,
        target_size=opt.img_size,
        num_imgs=opt.num_imgs,
    )
|