render_video_interpolation.py
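
This script renders latent-interpolation videos from a trained 3D-aware GAN checkpoint. Judging from its identifiers (generator.siren.mapping_network, curriculums, staged_forward_with_frequencies), it targets a pi-GAN-style SIREN generator: for each seed it samples a latent code, truncates the code's frequencies and phase shifts toward their average, linearly blends them into the next seed's code while the camera sweeps a 'front' or 'orbit' trajectory, and streams the frames to an MP4 via skvideo.
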
import argparse
import math
import os

import numpy as np
import skvideo.io
import torch
from PIL import Image
from tqdm import tqdm

import curriculums

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser()
parser.add_argument('path', type=str)
parser.add_argument('--seeds', nargs='+', type=int, default=[0, 1, 2])
parser.add_argument('--output_dir', type=str, default='vids')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--max_batch_size', type=int, default=2400000)
parser.add_argument('--depth_map', action='store_true')
parser.add_argument('--lock_view_dependence', action='store_true')
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--ray_step_multiplier', type=int, default=2)
parser.add_argument('--num_frames', type=int, default=36)
parser.add_argument('--curriculum', type=str, default='CelebA')
parser.add_argument('--trajectory', type=str, default='front')
parser.add_argument('--psi', type=float, default=0.7)
opt = parser.parse_args()

os.makedirs(opt.output_dir, exist_ok=True)

# Start from the named training curriculum and override the settings that
# matter at render time. The integer keys hold per-stage training schedules;
# curriculum[0] is the first stage, whose num_steps is scaled up for quality.
curriculum = getattr(curriculums, opt.curriculum)
curriculum['num_steps'] = curriculum[0]['num_steps'] * opt.ray_step_multiplier
curriculum['img_size'] = opt.image_size
curriculum['psi'] = opt.psi
curriculum['v_stddev'] = 0
curriculum['h_stddev'] = 0
curriculum['lock_view_dependence'] = opt.lock_view_dependence
curriculum['last_back'] = curriculum.get('eval_last_back', False)
curriculum['num_frames'] = opt.num_frames
curriculum['nerf_noise'] = 0
# Keep only the string-keyed settings; drop the integer-keyed stage schedules.
curriculum = {key: value for key, value in curriculum.items() if type(key) is str}


class FrequencyInterpolator:
    """Linearly interpolates between the truncated SIREN frequencies and
    phase shifts of two latent codes, so one identity morphs into the next."""

    def __init__(self, generator, z1, z2, psi=0.5):
        avg_frequencies, avg_phase_shifts = generator.generate_avg_frequencies()
        raw_frequencies1, raw_phase_shifts1 = generator.siren.mapping_network(z1)
        # Truncation trick: pull each sample toward the average; psi=1 keeps
        # the raw sample, psi=0 collapses to the average.
        self.truncated_frequencies1 = avg_frequencies + psi * (raw_frequencies1 - avg_frequencies)
        self.truncated_phase_shifts1 = avg_phase_shifts + psi * (raw_phase_shifts1 - avg_phase_shifts)
        raw_frequencies2, raw_phase_shifts2 = generator.siren.mapping_network(z2)
        self.truncated_frequencies2 = avg_frequencies + psi * (raw_frequencies2 - avg_frequencies)
        self.truncated_phase_shifts2 = avg_phase_shifts + psi * (raw_phase_shifts2 - avg_phase_shifts)

    def forward(self, t):
        frequencies = self.truncated_frequencies1 * (1 - t) + self.truncated_frequencies2 * t
        phase_shifts = self.truncated_phase_shifts1 * (1 - t) + self.truncated_phase_shifts2 * t
        return frequencies, phase_shifts


def tensor_to_PIL(img):
    # Map the image from [-1, 1] to [0, 255] and convert CHW to an HWC PIL image.
    img = img.squeeze() * 0.5 + 0.5
    return Image.fromarray(img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())


# Load the generator and swap in its EMA weights for evaluation. The EMA file
# is expected to sit next to the checkpoint, named <prefix>ema.pth.
generator = torch.load(opt.path, map_location=torch.device(device))
ema_file = opt.path.split('generator')[0] + 'ema.pth'
ema = torch.load(ema_file, map_location=torch.device(device))
ema.copy_to(generator.parameters())
generator.set_device(device)
generator.eval()

# Build the camera trajectory as (t, pitch, yaw, fov) tuples.
if opt.trajectory == 'front':
    # Gentle sway around the frontal view, with a breathing field of view.
    trajectory = []
    for t in np.linspace(0, 1, curriculum['num_frames']):
        pitch = 0.2 * np.cos(t * 2 * math.pi) + math.pi / 2
        yaw = 0.4 * np.sin(t * 2 * math.pi) + math.pi / 2
        fov = curriculum['fov'] + 5 + np.sin(t * 2 * math.pi) * 5
        trajectory.append((t, pitch, yaw, fov))
elif opt.trajectory == 'orbit':
    # Full 360-degree orbit at a raised pitch with a fixed field of view.
    trajectory = []
    for t in np.linspace(0, 1, curriculum['num_frames']):
        pitch = 0.2 * np.cos(t * 2 * math.pi) + math.pi / 4
        yaw = t * 2 * math.pi
        fov = curriculum['fov']
        trajectory.append((t, pitch, yaw, fov))
else:
    raise ValueError(f'Unknown trajectory: {opt.trajectory}')

output_name = 'interp.mp4'
writer = skvideo.io.FFmpegWriter(os.path.join(opt.output_dir, output_name),
                                 outputdict={'-pix_fmt': 'yuv420p', '-crf': '21'})

print(opt.seeds)
for i, seed in enumerate(opt.seeds):
    frames = []
    # Sample this seed's latent and the next seed's latent (wrapping around),
    # then blend between them as the camera moves along the trajectory.
    torch.manual_seed(seed)
    z_current = torch.randn(1, 256, device=device)
    torch.manual_seed(opt.seeds[(i + 1) % len(opt.seeds)])
    z_next = torch.randn(1, 256, device=device)
    frequency_interpolator = FrequencyInterpolator(generator, z_current, z_next, psi=opt.psi)
    with torch.no_grad():
        for t, pitch, yaw, fov in tqdm(trajectory):
            curriculum['h_mean'] = yaw
            curriculum['v_mean'] = pitch
            curriculum['fov'] = fov
            curriculum['h_stddev'] = 0
            curriculum['v_stddev'] = 0
            frame, depth_map = generator.staged_forward_with_frequencies(
                *frequency_interpolator.forward(t),
                max_batch_size=opt.max_batch_size,
                depth_map=opt.depth_map,
                **curriculum)
            frames.append(tensor_to_PIL(frame))
    for frame in frames:
        writer.writeFrame(np.array(frame))
writer.close()
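
A minimal invocation sketch; the checkpoint path below is hypothetical, but note the script expects the path to contain the substring 'generator', since it locates the EMA weights by replacing everything from 'generator' onward with 'ema.pth':

    python render_video_interpolation.py checkpoints/CelebA/generator.pth --seeds 0 1 2 --trajectory orbit --num_frames 72

With the defaults this writes vids/interp.mp4. Because the next latent is seeds[(i + 1) % len(seeds)], the clip morphs seed 0 into 1, 1 into 2, and 2 back into 0; --psi controls the truncation strength (0 renders the average identity, 1 the raw samples).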