1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
|
- try:
- import wandb
- WANDB_FLAG = True
- except ImportError:
- WANDB_FLAG = False
- import os
- import glob
- import argparse
- import numpy as np
- import jax
- from omegaconf import OmegaConf
- import tensorflow as tf
- from transformers import BartTokenizer
- from utils.for_training import LayoutGeneratorTrainer, create_layout_generator_dataset, visualize_slides
- from models import MyFlaxBartForConditionalGeneration
- import logging
- logger = logging.getLogger(__name__)
def create_arg_parser():
    """Build the command-line parser for the layout-generator inference script.

    Returns:
        argparse.ArgumentParser: parser accepting --train_result_dir,
        --weights_type, --device, --seed, and --vis.
    """
    parser = argparse.ArgumentParser(
        description=
        # The original description was the placeholder 'hogehoge'.
        'Run inference with a trained layout-generation model and optionally '
        'visualize the generated slides.'
    )
    parser.add_argument(
        '--train_result_dir',
        type=str,
        required=True,
        help='Path to the training result directory automatically generated by hydra when training the model.'
    )
    parser.add_argument(
        '--weights_type',
        choices=['latest', 'val_smallest'],
        help=
        'A flag that determines whether to use the weights of the epoch with the smallest validation loss during ' \
        'training or the weights of the last epoch.'
    )
    parser.add_argument(
        '--device',
        type=int,
        required=True,
        help='The cuda device number used for preprocessing.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        required=True,
        # Fixed: the help text was a copy-paste of the --device description.
        help='The random seed used to initialize the JAX PRNG key.'
    )
    parser.add_argument(
        '--vis',
        action='store_true',
        help=
        'If specified, save the visualized ground-truth/predicted segmentations to the `visualization` directory, ' \
        'which will be made under the same hierarchy as the weight path.'
    )
    return parser
def main():
    """Run layout-generation inference from a trained checkpoint.

    Loads the hydra config and checkpoint produced by training, runs the
    layout generator on the test split, and — when --vis is given — saves
    visualizations of both generated and ground-truth slides.
    """
    arg_parser = create_arg_parser()
    # Fixed: the original passed a hard-coded debug argument list to
    # parse_args(), so anything supplied on the real command line was
    # silently ignored.
    args = arg_parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

    # Locate checkpoint directories named ckpt_<train_step>.
    ckpt_path_list = glob.glob(os.path.join(args.train_result_dir, "ckpt_*"))
    if len(ckpt_path_list) == 1:
        ckpt_path = ckpt_path_list[0]
        logger.warning("Only one checkpoint file could be found under the `{}` directory, so that file will be used for testing regardless of your specification.".format(args.train_result_dir))
    else:
        # Use an ndarray so the element-wise comparison below is well defined.
        train_step_list = np.array([int(each_path.split("_")[-1]) for each_path in ckpt_path_list])
        # NOTE(review): "val_smallest" picks the checkpoint with the LARGEST
        # train step, and the other branch picks the second largest. This
        # mapping looks inverted relative to the flag names — confirm against
        # how checkpoints are written during training before relying on it.
        if args.weights_type == "val_smallest":
            idx = np.argmax(train_step_list)
        else:
            idx = np.where(train_step_list == np.sort(train_step_list)[-2])[0][0]
        ckpt_path = ckpt_path_list[idx]

    # Reload the hydra config saved alongside the training run.
    conf_filepath = os.path.join(args.train_result_dir, ".hydra/config.yaml")
    with open(conf_filepath, mode='r', encoding="utf-8") as f:
        cfg = OmegaConf.load(f)

    # Fixed: wandb.init was called unconditionally even though the module
    # guards the import with WANDB_FLAG — a NameError when wandb is missing.
    if WANDB_FLAG:
        wandb.init(
            project="slide_generation",
            # name="",
            group="blt_visualization",
            # tags=tag_list,
            notes=""  # leave comments if you want to log more detailed messages
        )

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")
    logger.info("Jax local devices: {}".format(jax.local_devices()))

    ## 1. Test Preparation
    # Get the path to the result directory and the project root.
    project_root = os.getcwd()
    result_dir = os.path.join(args.train_result_dir, "inference")
    os.makedirs(result_dir, exist_ok=True)

    # Unpack each configuration group from the hydra config.
    training_cfg = cfg.training
    layout_model_cfg = cfg.model
    dataset_cfg = cfg.dataset
    optax_cfg = cfg.opt

    n_devices = jax.local_device_count()
    test_batch_size = 4
    # The batch is split across local devices, so it must divide evenly.
    assert test_batch_size % n_devices == 0

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    bart_model = MyFlaxBartForConditionalGeneration.from_pretrained(
        pretrained_model_name_or_path="facebook/bart-large-cnn",
        layout_resolution_w=dataset_cfg.resolution_w,
        layout_resolution_h=dataset_cfg.resolution_h,
        num_layout_class=len(dataset_cfg.label_list)
    )
    logger.info("Instantiated bart model for snippet embeddings.")

    dataset_dir_path = os.path.join(project_root, dataset_cfg.dataset_dir)
    test_dataset, vocab_size, logit_masks, offset = create_layout_generator_dataset(
        dataset_cfg=dataset_cfg,
        dataset_dir_path=dataset_dir_path,
        batch_size=test_batch_size,
        max_length=cfg.training.max_seq_length,
        tokenizer=tokenizer,
        add_bos=layout_model_cfg.autoregressive,
        test=True
    )
    logger.info("Finished dataset creation.")

    rng = jax.random.PRNGKey(args.seed)
    test_rng, param_rng = jax.random.split(rng)

    # Make a trainer class instance restored from the selected checkpoint.
    trainer = LayoutGeneratorTrainer.create_trainer(
        rng=param_rng,
        training_cfg=training_cfg,
        model_cfg=layout_model_cfg,
        dataset_cfg=dataset_cfg,
        optax_cfg=optax_cfg,
        vocab_size=vocab_size,
        save_folder=result_dir,
        ckpt_path=ckpt_path,
        test=True
    )
    generated_samples, real_samples = trainer.inference(
        rng=test_rng,
        test_dataset=test_dataset,
        embedder=bart_model.encode,
        logit_masks=logit_masks,
        offset=offset,
        iterative_nums=[5, 5, 5],
        test_batch_size=test_batch_size,
        condition="none",
        sampling_method="topp",
        num_samples=None,
        prior=None,
        max_elem_num=22,
    )

    # Fixed: visualization previously ran unconditionally even though the
    # --vis flag exists precisely to opt into it.
    if args.vis:
        for sampled_data, prefix in zip([generated_samples, real_samples], ["gen", "real"]):
            visualize_slides(
                batch_data=sampled_data,
                class_label_list=dataset_cfg.label_list,
                save_dir=result_dir,
                prefix=prefix,
                resolution_w=dataset_cfg.resolution_w,
                resolution_h=dataset_cfg.resolution_h
            )


if __name__ == '__main__':
    main()
|