Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_layout_generator.py 6.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
  1. try:
  2. import wandb
  3. WANDB_FLAG = True
  4. except ImportError:
  5. WANDB_FLAG = False
  6. import os
  7. import glob
  8. import argparse
  9. import numpy as np
  10. import jax
  11. from omegaconf import OmegaConf
  12. import tensorflow as tf
  13. from transformers import BartTokenizer
  14. from utils.for_training import LayoutGeneratorTrainer, create_layout_generator_dataset, visualize_slides
  15. from models import MyFlaxBartForConditionalGeneration
  16. import logging
  17. logger = logging.getLogger(__name__)
  18. def create_arg_parser():
  19. parser = argparse.ArgumentParser(
  20. description=
  21. 'hogehoge'
  22. )
  23. parser.add_argument(
  24. '--train_result_dir',
  25. type=str,
  26. required=True,
  27. help='Path to the training result directory automatically generated by hydra when training the model.'
  28. )
  29. parser.add_argument(
  30. '--weights_type',
  31. choices=['latest', 'val_smallest'],
  32. help=
  33. 'A flag that determines whether to use the weights of the epoch with the smallest validation loss during ' \
  34. 'training or the weights of the last epoch.'
  35. )
  36. parser.add_argument(
  37. '--device',
  38. type=int,
  39. required=True,
  40. help='The cuda device number used for preprocessing.'
  41. )
  42. parser.add_argument(
  43. '--seed',
  44. type=int,
  45. required=True,
  46. help='The cuda device number used for preprocessing.'
  47. )
  48. parser.add_argument(
  49. '--vis',
  50. action='store_true',
  51. help=
  52. 'If specified, save the visualized ground-truth/predicted segmentations to the `visualization` directory, ' \
  53. 'which will be made under the same hierarchy as the weight path.'
  54. )
  55. return parser
  56. def main():
  57. arg_parser = create_arg_parser()
  58. args = arg_parser.parse_args([
  59. "--train_result_dir", "outputs/SLIDE/BLT/2023-01-07/bs_32-tstep_3000/opt_scheduler-lr_5e-05",
  60. "--weights_type", "val_smallest",
  61. "--device", "3",
  62. "--seed", "42",
  63. "--vis"
  64. ])
  65. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
  66. ckpt_path_list = glob.glob(os.path.join(args.train_result_dir, "ckpt_*"))
  67. if len(ckpt_path_list) == 1:
  68. ckpt_path = ckpt_path_list[0]
  69. logger.warning("Only one checkpoint file could be found under the `{}` directory, so that file will be used for testing regardless of your specification.".format(args.train_result_dir))
  70. else:
  71. train_step_list = [int(each_path.split("_")[-1]) for each_path in ckpt_path_list]
  72. if args.weights_type == "val_smallest":
  73. idx = np.argmax(train_step_list)
  74. else:
  75. idx = np.where(train_step_list==np.sort(train_step_list)[-2])[0][0]
  76. ckpt_path = ckpt_path_list[idx]
  77. conf_filepath = os.path.join(args.train_result_dir, ".hydra/config.yaml")
  78. with open(conf_filepath, mode='r', encoding="utf-8") as f:
  79. cfg = OmegaConf.load(f)
  80. wandb.init(
  81. project="slide_generation",
  82. # name="",
  83. group="blt_visualization",
  84. # tags=tag_list,
  85. notes="" # leave comments if you want to log more detailed messages
  86. )
  87. # logger = get_my_logger(name="Preprocess", filepath=os.path.join(args.dataset_root, "preprocess_log.log"))
  88. # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
  89. # it unavailable to JAX.
  90. tf.config.experimental.set_visible_devices([], "GPU")
  91. # Log basic information,
  92. logger.info("Jax local devices: {}".format(jax.local_devices()))
  93. ## 1. Test Preparation
  94. # Get the path to the result directory and the project root.
  95. project_root = os.getcwd()
  96. result_dir = os.path.join(args.train_result_dir, "inference")
  97. os.makedirs(result_dir, exist_ok=True)
  98. # Prepare each group of configuration.
  99. training_cfg = cfg.training
  100. layout_model_cfg = cfg.model
  101. dataset_cfg = cfg.dataset
  102. optax_cfg = cfg.opt
  103. n_devices = jax.local_device_count()
  104. test_batch_size = 4
  105. assert test_batch_size % n_devices == 0
  106. tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
  107. bart_model = MyFlaxBartForConditionalGeneration.from_pretrained(
  108. pretrained_model_name_or_path="facebook/bart-large-cnn",
  109. layout_resolution_w=dataset_cfg.resolution_w,
  110. layout_resolution_h=dataset_cfg.resolution_h,
  111. num_layout_class=len(dataset_cfg.label_list)
  112. )
  113. logger.info("Instantiated bart model for snippet embeddings.")
  114. dataset_dir_path = os.path.join(project_root, dataset_cfg.dataset_dir)
  115. test_dataset, vocab_size, logit_masks, offset = create_layout_generator_dataset(
  116. dataset_cfg=dataset_cfg,
  117. dataset_dir_path=dataset_dir_path,
  118. batch_size=test_batch_size,
  119. max_length=cfg.training.max_seq_length,
  120. tokenizer=tokenizer,
  121. add_bos=layout_model_cfg.autoregressive,
  122. test=True
  123. )
  124. logger.info("Finished dataset creation.")
  125. rng = jax.random.PRNGKey(args.seed)
  126. test_rng, param_rng = jax.random.split(rng)
  127. # Make a trainer class instance.
  128. trainer = LayoutGeneratorTrainer.create_trainer(
  129. rng=param_rng,
  130. training_cfg=training_cfg,
  131. model_cfg=layout_model_cfg,
  132. dataset_cfg=dataset_cfg,
  133. optax_cfg=optax_cfg,
  134. vocab_size=vocab_size,
  135. save_folder=result_dir,
  136. ckpt_path=ckpt_path,
  137. test=True
  138. )
  139. generated_samples, real_samples = trainer.inference(
  140. rng=test_rng,
  141. test_dataset=test_dataset,
  142. embedder=bart_model.encode,
  143. logit_masks=logit_masks,
  144. offset=offset,
  145. iterative_nums=[5,5,5],
  146. test_batch_size=test_batch_size,
  147. condition="none",
  148. sampling_method="topp",
  149. num_samples=None,
  150. prior=None,
  151. max_elem_num=22,
  152. )
  153. for sampled_data, prefix in zip([generated_samples, real_samples], ["gen", "real"]):
  154. visualize_slides(
  155. batch_data=sampled_data,
  156. class_label_list=dataset_cfg.label_list,
  157. save_dir=result_dir,
  158. prefix=prefix,
  159. resolution_w=dataset_cfg.resolution_w,
  160. resolution_h=dataset_cfg.resolution_h
  161. )
  162. if __name__ == '__main__':
  163. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...