Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

extract_shapes.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  1. import plyfile
  2. import argparse
  3. import torch
  4. import numpy as np
  5. import skimage.measure
  6. import scipy
  7. import mrcfile
  8. import os
  9. def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
  10. # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
  11. voxel_origin = np.array(voxel_origin) - cube_length/2
  12. voxel_size = cube_length / (N - 1)
  13. overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
  14. samples = torch.zeros(N ** 3, 3)
  15. # transform first 3 columns
  16. # to be the x, y, z index
  17. samples[:, 2] = overall_index % N
  18. samples[:, 1] = (overall_index.float() / N) % N
  19. samples[:, 0] = ((overall_index.float() / N) / N) % N
  20. # transform first 3 columns
  21. # to be the x, y, z coordinate
  22. samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
  23. samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
  24. samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
  25. num_samples = N ** 3
  26. return samples.unsqueeze(0), voxel_origin, voxel_size
  27. def sample_generator(generator, z, max_batch=100000, voxel_resolution=256, voxel_origin=[0,0,0], cube_length=2.0, psi=0.5):
  28. head = 0
  29. samples, voxel_origin, voxel_size = create_samples(voxel_resolution, voxel_origin, cube_length)
  30. samples = samples.to(z.device)
  31. sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=z.device)
  32. transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=z.device)
  33. transformed_ray_directions_expanded[..., -1] = -1
  34. generator.generate_avg_frequencies()
  35. with torch.no_grad():
  36. raw_frequencies, raw_phase_shifts = generator.siren.mapping_network(z)
  37. truncated_frequencies = generator.avg_frequencies + psi * (raw_frequencies - generator.avg_frequencies)
  38. truncated_phase_shifts = generator.avg_phase_shifts + psi * (raw_phase_shifts - generator.avg_phase_shifts)
  39. with torch.no_grad():
  40. while head < samples.shape[1]:
  41. coarse_output = generator.siren.forward_with_frequencies_phase_shifts(samples[:, head:head+max_batch], truncated_frequencies, truncated_phase_shifts, ray_directions=transformed_ray_directions_expanded[:, :samples.shape[1]-head]).reshape(samples.shape[0], -1, 4)
  42. sigmas[:, head:head+max_batch] = coarse_output[:, :, -1:]
  43. head += max_batch
  44. sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy()
  45. return sigmas
  46. if __name__ == '__main__':
  47. parser = argparse.ArgumentParser()
  48. parser.add_argument('path', type=str)
  49. parser.add_argument('--seeds', nargs='+', default=[0, 1, 2])
  50. parser.add_argument('--cube_size', type=float, default=0.3)
  51. parser.add_argument('--voxel_resolution', type=int, default=256)
  52. parser.add_argument('--output_dir', type=str, default='shapes')
  53. opt = parser.parse_args()
  54. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  55. generator = torch.load(opt.path, map_location=torch.device(device))
  56. ema = torch.load(opt.path.split('generator')[0] + 'ema.pth')
  57. ema.copy_to(generator.parameters())
  58. generator.set_device(device)
  59. generator.eval()
  60. for seed in opt.seeds:
  61. torch.manual_seed(seed)
  62. z = torch.randn(1, 256, device=device)
  63. voxel_grid = sample_generator(generator, z, cube_length=opt.cube_size, voxel_resolution=opt.voxel_resolution)
  64. os.makedirs(opt.output_dir, exist_ok=True)
  65. with mrcfile.new_mmap(os.path.join(opt.output_dir, f'{seed}.mrc'), overwrite=True, shape=voxel_grid.shape, mrc_mode=2) as mrc:
  66. mrc.data[:] = voxel_grid
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...