Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

reconstruction.py 11 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
  1. import os
  2. import sys
  3. import inspect
  4. currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
  5. parentdir = os.path.dirname(currentdir)
  6. sys.path.insert(0, parentdir)
  7. import yaml
  8. import torch
  9. import click
  10. import torch
  11. import pickle
  12. from pathlib import Path
  13. import nvdiffrast.torch as dr
  14. from argparse import Namespace
  15. from nvdiffrec.render import obj
  16. from nvdiffrec.render import light
  17. from nvdiffrec.geometry.dlmesh import DLMesh
  18. from nvdiffrec.supports.uvmap import xatlas_uvmap
  19. from nvdiffrec.geometry.dmtet import DMTetGeometry
  20. from nvdiffrec.supports.training import optimize_mesh
  21. from nvdiffrec.dataset.dataset_nerf import DatasetNERF
  22. from nvdiffrec.supports.validation_and_testing import validate
  23. from nvdiffrec.supports.material_utility import initial_guess_material
  24. RADIUS = 3.0
  25. ROOT = Path(__file__).parent.parent
  26. def set_flags(ref_mesh, out_dir):
  27. FLAGS = Namespace()
  28. FLAGS.iter = 5000
  29. FLAGS.batch = 1
  30. FLAGS.spp = 1
  31. FLAGS.layers = 1
  32. FLAGS.train_res = [512, 512]
  33. FLAGS.display_res = None
  34. FLAGS.texture_res = [1024, 1024]
  35. FLAGS.display_interval = 0
  36. FLAGS.save_interval = 500
  37. FLAGS.learning_rate = 0.01
  38. FLAGS.min_roughness = 0.08
  39. FLAGS.custom_mip = True
  40. FLAGS.random_textures = True
  41. FLAGS.background = "white"
  42. FLAGS.loss = "logl1"
  43. FLAGS.out_dir = out_dir
  44. FLAGS.ref_mesh = ref_mesh
  45. FLAGS.base_mesh = None
  46. FLAGS.validate = True
  47. FLAGS.mtl_override = None # Override material of model
  48. FLAGS.dmtet_grid = 256 # Resolution of initial tet grid.
  49. FLAGS.mesh_scale = 2.5 # Scale of tet grid box. Adjust to cover the model
  50. FLAGS.env_scale = 1.0 # Env map intensity multiplier
  51. FLAGS.envmap = None # HDR environment probe
  52. FLAGS.display = [
  53. {"latlong": True}, {"bsdf": "kd"}, {"bsdf": "ks"}, {"bsdf": "normal"}
  54. ]
  55. FLAGS.camera_space_light = False # Fixed light in camera space.
  56. FLAGS.lock_light = False # Disable light optimization in the second pass
  57. FLAGS.lock_pos = False # Disable vertex position optimization in the second pass
  58. FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
  59. FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"]
  60. FLAGS.laplace_scale = (
  61. 7500 # Weight for sdf regularizer. Default is relative with large weight
  62. )
  63. FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training
  64. FLAGS.kd_min = [0.0, 0.0, 0.0, 0.0] # Limits for kd
  65. FLAGS.kd_max = [1.0, 1.0, 1.0, 1.0]
  66. FLAGS.ks_min = [0.0, 0.2, 0.0] # Limits for ks
  67. FLAGS.ks_max = [1.0, 1.0, 1.0]
  68. FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
  69. FLAGS.nrm_max = [1.0, 1.0, 1.0]
  70. FLAGS.cam_near_far = [0.1, 1000.0]
  71. FLAGS.learn_light = True
  72. FLAGS.local_rank = 0
  73. FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
  74. if FLAGS.multi_gpu:
  75. if "MASTER_ADDR" not in os.environ:
  76. os.environ["MASTER_ADDR"] = "localhost"
  77. if "MASTER_PORT" not in os.environ:
  78. os.environ["MASTER_PORT"] = "23456"
  79. FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
  80. torch.cuda.set_device(FLAGS.local_rank)
  81. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  82. # save best params
  83. with open(ROOT / Path("params.yaml"), "r") as stream:
  84. data = yaml.safe_load(stream)
  85. for key in data:
  86. FLAGS.__dict__[key] = data[key]
  87. # if FLAGS.config is not None:
  88. # data = json.load(open(FLAGS.config, "r"))
  89. # for key in data:
  90. # FLAGS.__dict__[key] = data[key]
  91. if FLAGS.display_res is None:
  92. FLAGS.display_res = FLAGS.train_res
  93. if FLAGS.local_rank == 0:
  94. print("Config / Flags:")
  95. print("---------")
  96. for key in FLAGS.__dict__.keys():
  97. print(key, FLAGS.__dict__[key])
  98. print("---------")
  99. os.makedirs(FLAGS.out_dir, exist_ok=True)
  100. return FLAGS
  101. @click.group()
  102. def main():
  103. """
  104. Entry point for training scripts
  105. """
  106. pass
  107. @main.command()
  108. @click.option("--ref_mesh", type=str, default="data/processed/configs/", help="Config file")
  109. @click.option("--out_dir", type=str, default="data/results", help="Config file")
  110. def base_run(ref_mesh, out_dir):
  111. FLAGS = set_flags(ref_mesh, out_dir)
  112. glctx = dr.RasterizeGLContext()
  113. # ===============================================================================
  114. # Create data pipeline
  115. # ===============================================================================
  116. dataset_train = DatasetNERF(
  117. os.path.join(FLAGS.ref_mesh, "nerf_transforms.json"),
  118. FLAGS,
  119. examples=(FLAGS.iter + 1) * FLAGS.batch,
  120. )
  121. dataset_validate = DatasetNERF(
  122. os.path.join(FLAGS.ref_mesh, "nerf_transforms.json"), FLAGS
  123. )
  124. # ===============================================================================
  125. # Create env light with trainable parameters
  126. # ===============================================================================
  127. if FLAGS.learn_light:
  128. print("we are learning light")
  129. lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
  130. print(lgt.type)
  131. else:
  132. lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)
  133. # Setup geometry for optimization
  134. geometry = DMTetGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)
  135. # Setup textures, make initial guess from reference if possible
  136. mat = initial_guess_material(geometry, True, FLAGS)
  137. # Run optimization
  138. geometry, mat = optimize_mesh(
  139. glctx,
  140. geometry,
  141. mat,
  142. lgt,
  143. dataset_train,
  144. dataset_validate,
  145. FLAGS,
  146. pass_idx=0,
  147. pass_name="dmtet_pass1",
  148. optimize_light=FLAGS.learn_light,
  149. )
  150. if FLAGS.local_rank == 0 and FLAGS.validate:
  151. validate(
  152. glctx,
  153. geometry,
  154. mat,
  155. lgt,
  156. dataset_validate,
  157. FLAGS.out_dir,
  158. FLAGS,
  159. )
  160. # path to artefacts
  161. path_to_pickles = os.path.join(FLAGS.out_dir, "artefact_storage")
  162. os.makedirs(path_to_pickles, exist_ok=True)
  163. # save mesh without textures
  164. eval_mesh = geometry.getMesh(mat)
  165. obj.write_obj(path_to_pickles, eval_mesh)
  166. # save materials
  167. torch.save(mat.state_dict(), os.path.join(path_to_pickles, "mat.pt"))
  168. # save geometry
  169. with open(os.path.join(path_to_pickles, "geometry.pickle"), "wb") as file:
  170. pickle.dump(geometry, file)
  171. # save lgt
  172. with open(os.path.join(path_to_pickles, "lgt.pickle"), "wb") as file:
  173. pickle.dump(lgt, file)
  174. # save FLAGS
  175. with open(os.path.join(path_to_pickles, "FLAGS.pickle"), "wb") as file:
  176. pickle.dump(FLAGS, file)
  177. @main.command()
  178. @click.option(
  179. "--path_to_flags",
  180. type=str,
  181. default="data/results/artefact_storage/FLAGS.pickle",
  182. help="Config file"
  183. )
  184. def refinement_run(path_to_flags):
  185. with open(path_to_flags, "rb") as file:
  186. FLAGS = pickle.load(file)
  187. # ===============================================================================
  188. # Create data pipeline
  189. # ===============================================================================
  190. dataset_train = DatasetNERF(
  191. os.path.join(FLAGS.ref_mesh, "nerf_transforms.json"),
  192. FLAGS,
  193. examples=(FLAGS.iter + 1) * FLAGS.batch,
  194. )
  195. dataset_validate = DatasetNERF(
  196. os.path.join(FLAGS.ref_mesh, "nerf_transforms.json"), FLAGS
  197. )
  198. path_to_pickles = os.path.join(FLAGS.out_dir, "artefact_storage")
  199. # load geometry
  200. with open(os.path.join(path_to_pickles, "geometry.pickle"), "rb") as file:
  201. geometry = pickle.load(file)
  202. # load light
  203. with open(os.path.join(path_to_pickles, "lgt.pickle"), "rb") as file:
  204. lgt = pickle.load(file)
  205. # load materials
  206. mat = initial_guess_material(geometry, True, FLAGS)
  207. mat.load_state_dict(torch.load(os.path.join(path_to_pickles, "mat.pt")))
  208. # load mesh
  209. eval_mesh = obj.load_obj_without_mat(os.path.join(path_to_pickles, "remesh.obj"))
  210. eval_mesh.material = mat
  211. print("Creating RasterizeGLContext")
  212. # load glctx
  213. glctx = dr.RasterizeGLContext()
  214. print("Done!\n")
  215. print(f"Now we are going to create textures")
  216. # Trying to create textured mesh from result
  217. base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS, eval_mesh)
  218. print("Done!\n")
  219. # Free temporaries / cached memory
  220. torch.cuda.empty_cache()
  221. mat["kd_ks_normal"].cleanup()
  222. del mat["kd_ks_normal"]
  223. lgt = lgt.clone()
  224. geometry = DLMesh(base_mesh, FLAGS)
  225. if FLAGS.local_rank == 0:
  226. # Dump mesh for debugging.
  227. os.makedirs(os.path.join(FLAGS.out_dir, "dmtet_mesh"), exist_ok=True)
  228. obj.write_obj(os.path.join(FLAGS.out_dir, "dmtet_mesh/"), base_mesh)
  229. light.save_env_map(os.path.join(FLAGS.out_dir, "dmtet_mesh/probe.hdr"), lgt)
  230. # ==========================================================================
  231. # Pass 2: Train with fixed topology (mesh)
  232. # ==========================================================================
  233. geometry, mat = optimize_mesh(
  234. glctx,
  235. geometry,
  236. base_mesh.material,
  237. lgt,
  238. dataset_train,
  239. dataset_validate,
  240. FLAGS,
  241. pass_idx=1,
  242. pass_name="mesh_pass",
  243. warmup_iter=100,
  244. optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
  245. optimize_geometry=not FLAGS.lock_pos,
  246. )
  247. # ==========================================================================
  248. # Validate
  249. # ==========================================================================
  250. if FLAGS.validate and FLAGS.local_rank == 0:
  251. validate(
  252. glctx,
  253. geometry,
  254. mat,
  255. lgt,
  256. dataset_validate,
  257. os.path.join(FLAGS.out_dir, "validate"),
  258. FLAGS,
  259. )
  260. # ===========================================================================
  261. # Dump output
  262. # ===========================================================================
  263. if FLAGS.local_rank == 0:
  264. final_mesh = geometry.getMesh(mat)
  265. os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
  266. obj.write_obj(os.path.join(FLAGS.out_dir, "mesh/"), final_mesh)
  267. light.save_env_map(os.path.join(FLAGS.out_dir, "mesh/probe.hdr"), lgt)
  268. # ----------------------------------------------------------------------------
  269. if __name__ == "__main__":
  270. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...