Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

build_dataset_model.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
  1. from data.suncg_dataset import SuncgDataset, suncg_collate_fn
  2. from torch.utils.data import DataLoader
  3. import json
  4. from models.Sg2ScVAE_model import Sg2ScVAEModel
  def build_suncg_dsets(args):
      """Build the SUNCG training and validation datasets.

      Both splits are constructed with identical keyword arguments except
      ``data_dir``: ``args.suncg_train_dir`` for training and
      ``args.suncg_val_dir`` for validation.

      Args:
          args: Namespace-like object providing ``suncg_train_dir``,
              ``suncg_val_dir``, ``train_3d`` and ``use_attr_30``.

      Returns:
          A ``(vocab, train_dset, val_dset)`` tuple, where ``vocab`` is a
          JSON round-tripped copy of the training dataset's vocabulary.

      Raises:
          ValueError: If the training and validation vocabularies differ.
      """
      dset_kwargs = {
          'data_dir': args.suncg_train_dir,
          'train_3d': args.train_3d,
          'use_attr_30': args.use_attr_30,
      }
      train_dset = SuncgDataset(**dset_kwargs)
      num_objs = train_dset.total_objects()
      num_imgs = len(train_dset)
      print('Training dataset has %d scenes and %d objects' % (num_imgs, num_objs))
      # An empty training split would otherwise raise ZeroDivisionError here,
      # purely while printing a statistic.
      if num_imgs > 0:
          print('(%.2f objects per image)' % (float(num_objs) / num_imgs))

      # Only the data directory differs between the two splits.
      dset_kwargs['data_dir'] = args.suncg_val_dir
      val_dset = SuncgDataset(**dset_kwargs)
      # Explicit check instead of ``assert`` so it is not stripped under
      # ``python -O``.
      if train_dset.vocab != val_dset.vocab:
          raise ValueError('Training and validation vocabularies differ')
      # The JSON round-trip yields a detached deep copy made only of JSON types
      # (tuples become lists, non-string dict keys become strings).
      vocab = json.loads(json.dumps(train_dset.vocab))
      return vocab, train_dset, val_dset
  def build_loaders(args):
      """Create the training and validation DataLoaders for SUNCG.

      The training loader shuffles its samples; the validation loader keeps
      dataset order. Both use ``suncg_collate_fn`` and share batch size and
      worker count.

      Args:
          args: Namespace-like object providing ``batch_size`` and
              ``loader_num_workers``, plus everything ``build_suncg_dsets``
              reads.

      Returns:
          A ``(vocab, train_loader, val_loader)`` tuple.
      """
      vocab, train_dset, val_dset = build_suncg_dsets(args)
      # Settings common to both splits; only ``shuffle`` differs per loader.
      shared_opts = dict(
          batch_size=args.batch_size,
          num_workers=args.loader_num_workers,
          collate_fn=suncg_collate_fn,
      )
      train_loader = DataLoader(train_dset, shuffle=True, **shared_opts)
      val_loader = DataLoader(val_dset, shuffle=False, **shared_opts)
      return vocab, train_loader, val_loader
  def build_model(args, vocab):
      """Construct an ``Sg2ScVAEModel`` from parsed arguments.

      Args:
          args: Namespace-like object providing ``multigpu`` and the model
              hyperparameters copied into ``kwargs`` below.
          vocab: Vocabulary object passed straight through to the model.

      Returns:
          A ``(model, kwargs)`` tuple: the constructed model and the exact
          keyword arguments used to build it.

      Raises:
          NotImplementedError: If ``args.multigpu`` is truthy.
      """
      # Fail fast, before building the model, and with a real exception:
      # the previous ``assert False`` was stripped under ``python -O`` and the
      # unwrapped model was then returned silently.
      if args.multigpu:
          raise NotImplementedError('Multi-GPU not supported')
      kwargs = {
          'vocab': vocab,
          'batch_size': args.batch_size,
          'train_3d': args.train_3d,
          'decoder_cat': args.decoder_cat,
          'embedding_dim': args.embedding_dim,
          'gconv_mode': args.gconv_mode,
          'gconv_num_layers': args.gconv_num_layers,
          'mlp_normalization': args.mlp_normalization,
          'vec_noise_dim': args.vec_noise_dim,
          'layout_noise_dim': args.layout_noise_dim,
          'use_AE': args.use_AE,
      }
      model = Sg2ScVAEModel(**kwargs)
      return model, kwargs
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...