decomposition.py
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Patch for broken CTRL+C handler; must be set before numpy/scipy load the
# Intel Fortran runtime, hence before the scientific imports below.
# https://github.com/ContinuumIO/anaconda-issues/issues/905
import os
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'

import copy
import datetime
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.stats import special_ortho_group
from tqdm import trange

sys.path.append('./models/stylegan2')
import dnnlib
import dnnlib.tflib as tflib
import pretrained_networks

from estimators import get_estimator
from utils import *

SEED_SAMPLING = 1
DEFAULT_BATCH_SIZE = 20
SEED_RANDOM_DIRS = 2
B = 20  # effective batch size; overwritten by compute_pca()


def get_random_dirs(components, dimensions):
    # Reproducible random directions, normalized to unit length; used as a
    # baseline to compare against the variance explained by PCA components
    gen = np.random.RandomState(seed=SEED_RANDOM_DIRS)
    dirs = gen.normal(size=(components, dimensions))
    dirs /= np.sqrt(np.sum(dirs**2, axis=1, keepdims=True))
    return dirs.astype(np.float32)
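
# Sanity-check sketch (illustrative, not part of the pipeline): every row
# returned above is unit-norm, so projections onto these directions are
# directly comparable across components:
#
#   dirs = get_random_dirs(80, 512)
#   assert dirs.shape == (80, 512)
#   assert np.allclose(np.linalg.norm(dirs, axis=1), 1.0, atol=1e-4)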


def load_network(out_class, model=2):
    # 'out_classes' (network pickles per model version and class) is expected
    # to come from the star-import of utils
    network = out_classes[model][out_class]
    _G, _D, Gs = pretrained_networks.load_networks(network)

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False

    # Freeze the per-layer noise inputs so synthesis is deterministic
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    rnd = np.random.RandomState(0)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

    return Gs, Gs_kwargs
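
# Illustrative usage (a sketch; 'ffhq' is a hypothetical out_classes key):
#
#   Gs, Gs_kwargs = load_network('ffhq')
#   z = np.random.RandomState(0).randn(1, *Gs.input_shape[1:])
#   images = Gs.run(z, None, **Gs_kwargs)  # uint8 NHWC after output_transform
#   Image.fromarray(images[0]).save('sample.png')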


def pca(Gs, stylegan_version, out_class, estimator='ipca', batch_size=20,
        num_components=80, num_samples=1_000_000, use_w=True,
        force_recompute=False, seed_compute=None):
    dump_name = "{}-{}_{}_c{}_n{}{}{}.npz".format(
        f'stylegan{stylegan_version}',
        out_class.replace(' ', '_'),
        estimator.lower(),
        num_components,
        num_samples,
        '_w' if use_w else '',
        f'_seed{seed_compute}' if seed_compute else ''
    )

    dump_path = Path(f'./cache/components/{dump_name}')

    if not dump_path.is_file() or force_recompute:
        os.makedirs(dump_path.parent, exist_ok=True)
        compute_pca(Gs, estimator, batch_size, num_components, num_samples, use_w, seed_compute, dump_path)

    return dump_path
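
# Illustrative usage (a sketch; 'ffhq' is a hypothetical class name):
#
#   path = pca(Gs, 2, 'ffhq', estimator='ipca', num_samples=1_000_000)
#   data = np.load(path)
#   lat_comp, lat_mean, lat_stdev = data['lat_comp'], data['lat_mean'], data['lat_stdev']
#
# Results are cached under ./cache/components/, so repeat calls with the same
# configuration return immediately unless force_recompute=True.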


def compute_pca(Gs, estimator, batch_size, num_components, num_samples, use_w, seed, dump_path):
    global B

    timestamp = lambda: datetime.datetime.now().strftime("%d.%m %H:%M")
    print(f'[{timestamp()}] Computing', dump_path.name)

    # Ensure reproducibility
    np.random.seed(0)

    # Regress back to w space
    if use_w:
        print('Using W latent space')

    sample_shape = Gs.components.mapping.run(np.random.randn(1, *Gs.input_shape[1:]), None, dlatent_broadcast=None).shape
    sample_dims = np.prod(sample_shape)
    print("Feature shape: ", sample_shape)
    print("Feature dims: ", sample_dims)

    input_shape = (1, *Gs.input_shape[1:])
    input_dims = np.prod(input_shape)

    components = min(num_components, sample_dims)
    transformer = get_estimator(estimator, components, 1.0)

    X = None
    X_global_mean = None

    # Figure out batch size if not provided
    B = batch_size or DEFAULT_BATCH_SIZE

    # Make N divisible by B (the remainder is ignored in the output name)
    N = num_samples // B * B

    w_avg = Gs.get_var('dlatent_avg')

    # Compute maximum sample count based on RAM + pagefile budget
    target_bytes = 20 * 1_000_000_000  # 20 GB
    feat_size_bytes = sample_dims * np.dtype('float64').itemsize
    N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
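    # Worked example of the budget above: for a 512-dim W space,
    # feat_size_bytes = 512 * 8 = 4096 bytes per float64 sample, so
    # N_limit_RAM = 20e9 // 4096 ≈ 4.88M samples fit before a non-batched
    # estimator would blow past the 20 GB budget.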
    if not transformer.batch_support and N > N_limit_RAM:
        print('WARNING: estimator does not support batching, '
              'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))

    print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)

    # Must not depend on chosen batch size (reproducibility)
    NB = max(B, max(2_000, 3*components))  # ipca: as large as possible!

    samples = None
    if not transformer.batch_support:
        samples = np.zeros((N + NB, sample_dims), dtype=np.float32)

    np.random.seed(seed or SEED_SAMPLING)

    # Use exactly the same latents regardless of batch size
    # Store in main memory, since N might be huge (1M+)
    # Run in batches, since sample_latent() might perform Z -> W mapping
    n_lat = ((N + NB - 1) // B + 1) * B
    latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32)
    for i in trange(n_lat // B, desc='Sampling latents'):
        seed_global = np.random.randint(np.iinfo(np.int32).max)  # use (reproducible) global rand state
        rng = np.random.RandomState(seed_global)
        # z = np.random.randn(B, *input_shape[1:])
        z = rng.standard_normal(512 * B).reshape(B, 512)
        if use_w:
            w = Gs.components.mapping.run(z, None, dlatent_broadcast=None)
            latents[i*B:(i+1)*B] = w
        else:
            latents[i*B:(i+1)*B] = z
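
    # Worked example of the padding above: with N = 1_000_000, NB = 2_000 and
    # B = 20, n_lat = ((1_002_000 - 1) // 20 + 1) * 20 = 1_002_000, i.e. just
    # enough latents to fill every NB-sized fitting block below.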

    # Decomposition on non-Gaussian latent space
    samples_are_latents = use_w

    canceled = False
    try:
        X = np.ones((NB, sample_dims), dtype=np.float32)
        action = 'Fitting' if transformer.batch_support else 'Collecting'
        for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
            for mb in range(0, NB, B):
                z = latents[gi+mb:gi+mb+B]
                batch = z.reshape((B, -1))
                space_left = min(B, NB - mb)
                X[mb:mb+space_left] = batch[:space_left]

            if transformer.batch_support:
                if not transformer.fit_partial(X.reshape(-1, sample_dims)):
                    break
            else:
                samples[gi:gi+NB, :] = X.copy()
    except KeyboardInterrupt:
        if not transformer.batch_support:
            sys.exit(1)  # no progress yet

        dump_name = dump_path.parent / dump_path.name.replace(f'n{N}', f'n{gi}')
        print(f'Saving current state to "{dump_name.name}" before exiting')
        canceled = True

    if not transformer.batch_support:
        X = samples  # Use all samples
        X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32)
        X -= X_global_mean

        print(f'[{timestamp()}] Fitting whole batch')
        t_start_fit = datetime.datetime.now()

        transformer.fit(X)
        print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')

        assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
    else:
        X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
        X = X.reshape(-1, sample_dims)
        X -= X_global_mean

    X_comp, X_stdev, X_var_ratio = transformer.get_components()

    assert X_comp.shape[1] == sample_dims \
        and X_comp.shape[0] == components \
        and X_global_mean.shape[1] == sample_dims \
        and X_stdev.shape[0] == components, 'Invalid shape'

    Z_comp = X_comp
    Z_global_mean = X_global_mean

    # Normalize
    Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True)

    # Random projections
    # We expect these to explain much less of the variance
    random_dirs = get_random_dirs(components, np.prod(sample_shape))
    n_rand_samples = min(5000, X.shape[0])
    X_view = X[:n_rand_samples, :].T
    assert np.shares_memory(X_view, X), "Error: slice produced copy"
    X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)

    # Inflate back to proper shapes (for easier broadcasting)
    X_comp = X_comp.reshape(-1, *sample_shape)
    X_global_mean = X_global_mean.reshape(sample_shape)
    Z_comp = Z_comp.reshape(-1, *input_shape)
    Z_global_mean = Z_global_mean.reshape(input_shape)

    # Compute stdev in latent space if non-Gaussian
    lat_stdev = np.ones_like(X_stdev)
    if use_w:
        seed_global = np.random.randint(np.iinfo(np.int32).max)  # use (reproducible) global rand state
        rng = np.random.RandomState(seed_global)
        z = rng.standard_normal(512 * 5000).reshape(5000, 512)
        samples = Gs.components.mapping.run(z, None, dlatent_broadcast=None).reshape(5000, input_dims)
        coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T)
        lat_stdev = coords.std(axis=1)

    np.savez_compressed(dump_path, **{
        'act_comp': X_comp.astype(np.float32),
        'act_mean': X_global_mean.astype(np.float32),
        'act_stdev': X_stdev.astype(np.float32),
        'lat_comp': Z_comp.astype(np.float32),
        'lat_mean': Z_global_mean.astype(np.float32),
        'lat_stdev': lat_stdev.astype(np.float32),
        'var_ratio': X_var_ratio.astype(np.float32),
        'random_stdevs': X_stdev_random.astype(np.float32),
    })

    if canceled:
        sys.exit(1)

    del X
    del X_comp
    del random_dirs
    del batch
    del samples
    del latents
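

# Illustrative driver (a sketch, not part of the original pipeline): computes
# components for a hypothetical 'ffhq' entry in out_classes and renders a
# GANSpace-style edit along the first principal direction in W space.
if __name__ == '__main__':
    Gs, Gs_kwargs = load_network('ffhq', model=2)  # 'ffhq' key is an assumption
    data = np.load(pca(Gs, stylegan_version=2, out_class='ffhq'))

    # Move the mean latent 2 standard deviations along component 0
    w = data['lat_mean'] + 2.0 * data['lat_stdev'][0] * data['lat_comp'][0]  # (1, 512)

    # Broadcast the edited latent across all synthesis layers and render
    num_layers = Gs.components.synthesis.input_shape[1]
    w_broadcast = np.tile(w[:, np.newaxis, :], (1, num_layers, 1))
    img = Gs.components.synthesis.run(w_broadcast, **Gs_kwargs)[0]
    Image.fromarray(img).save('edit_preview.png')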