train_text_summarizer.py
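A Hydra-driven training entry point for a BART-based text summarizer written against JAX: it optionally logs runs to Weights & Biases, hides GPUs from TensorFlow so the tf.data input pipeline does not compete with JAX for memory, builds train/validation datasets with create_text_summarizer_dataset, and delegates the training loop to TextSummarizerTrainer.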

try:
    import wandb
    WANDB_FLAG = True
except ImportError:
    WANDB_FLAG = False

import os
import logging

import jax
import hydra
from omegaconf import DictConfig, OmegaConf
from transformers import BartTokenizer
import tensorflow as tf

from schema import register_configs
from utils.for_training import TextSummarizerTrainer, create_text_summarizer_dataset

logger = logging.getLogger(__name__)

# Register the dataclasses behind each Hydra config group so that their
# type annotations are validated when the config is loaded.
register_configs()


@hydra.main(config_path='config', config_name='text_summarization.yaml', version_base="1.2")
def main(cfg: DictConfig):
    ## 0. Experiment Management
    # Set up Weights & Biases.
    if WANDB_FLAG:
        config = OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        )
        tag_list = [
            "bs_{}".format(cfg.training.batch_size),
            "tstep_{}".format(cfg.training.num_train_steps),
            "opt_{}".format(cfg.opt.mode),
            "lr_{}".format(cfg.opt.peak_value if cfg.opt.mode == "scheduler" else cfg.opt.lr),
        ]
        if cfg.opt.mode == "scheduler":
            tag_list.append("wstep_{}".format(cfg.opt.warmup_steps))
        wandb.init(
            config=config,
            project="slide_generation",
            # name="",
            group="bart_debug",
            tags=tag_list,
            reinit=True,
            settings=wandb.Settings(start_method="thread"),
            notes=""  # add notes here if you want to log more detailed messages
        )

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    # Log basic information.
    logger.info("Jax local devices: {}".format(jax.local_devices()))

    ## 1. Training Preparation
    # Get the path to the result directory and the project root.
    result_dir = os.getcwd()
    project_root = hydra.utils.get_original_cwd()

    # Prepare each group of the configuration.
    training_cfg = cfg.training
    layout_model_cfg = cfg.model
    dataset_cfg = cfg.dataset
    optax_cfg = cfg.opt

    # The global batch must split evenly across the local devices.
    n_devices = jax.local_device_count()
    assert cfg.training.batch_size % n_devices == 0

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-xsum")
    dataset_dir_path = os.path.join(project_root, dataset_cfg.dataset_dir)
    train_dataset, val_dataset = create_text_summarizer_dataset(
        dataset_dir_path=dataset_dir_path,
        tokenizer=tokenizer,
        batch_size=cfg.training.batch_size,
        max_length=cfg.training.max_seq_length,
        layout_resolution_w=dataset_cfg.resolution_w,
        layout_resolution_h=dataset_cfg.resolution_h,
    )
    # Repeat both datasets indefinitely (count=None); training length is
    # governed by num_train_steps rather than by a number of epochs.
    train_dataset = train_dataset.repeat(count=None)
    val_dataset = val_dataset.repeat(count=None)
    logger.info("Finished dataset creation.")

    rng = jax.random.PRNGKey(cfg.seed)
    train_rng, param_rng = jax.random.split(rng)

    # Make a trainer instance.
    trainer = TextSummarizerTrainer.create_trainer(
        rng=param_rng,
        training_cfg=training_cfg,
        model_cfg=layout_model_cfg,
        dataset_cfg=dataset_cfg,
        optax_cfg=optax_cfg,
        save_folder=result_dir,
    )
    trainer.train(
        rng=train_rng,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )


if __name__ == '__main__':
    main()  # pylint: disable=no-value-for-parameter
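To launch a run, Hydra's command-line override syntax can set any of the config keys the script reads, for example (the values here are illustrative, not defaults from the repository):

    python train_text_summarizer.py training.batch_size=32 opt.mode=scheduler opt.peak_value=5e-5 opt.warmup_steps=1000

The assert on cfg.training.batch_size % n_devices suggests the trainer splits each global batch evenly across local devices for data-parallel training, as is conventional with jax.pmap. A minimal sketch of that reshape, assuming a hypothetical shard_batch helper and example shapes that are not taken from the repository:

import jax
import jax.numpy as jnp

def shard_batch(batch, n_devices):
    # Reshape every leaf from [batch, ...] to [n_devices, batch // n_devices, ...]
    # so jax.pmap can map the leading axis onto the local devices.
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:]),
        batch,
    )

n_devices = jax.local_device_count()
# Hypothetical tokenized batch: batch_size=8, max_seq_length=512.
batch = {"input_ids": jnp.zeros((8, 512), dtype=jnp.int32)}
assert 8 % n_devices == 0  # mirrors the divisibility check in the script above
sharded = shard_batch(batch, n_devices)
print(sharded["input_ids"].shape)  # (n_devices, 8 // n_devices, 512)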