train.py
import os
import sys
from pathlib import Path
import fire
import logging
import wandb
import pandas as pd
from discord import SyncWebhook
from addict import Dict
import yaml
import runpod
from huggingface_hub import Repository, create_repo, login
from transformers.utils import get_full_repo_name
from transformers.trainer_callback import TrainerCallback
from accelerate import Accelerator
from accelerate.tracking import on_main_process
import torch
import numpy as np

login(os.environ.get("HUGGINGFACE_TOKEN"), add_to_git_credential=True)
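
# Environment variables this script reads (names taken from the calls in this file):
#   HUGGINGFACE_TOKEN   - Hub auth token for login() above
#   AXOLOTL_ROOT        - axolotl checkout location (defaults to ../axolotl)
#   LOG_LEVEL           - logging verbosity, defaults to INFO
#   DISCORD_WEBHOOK_URL - webhook used by notify_discord()
#   LOCAL_RANK          - set by the distributed launcher, 0 on a single process
#   RUNPOD_API_KEY / RUNPOD_POD_ID - used to self-terminate the pod after training
#   ACCELERATE_USE_DEEPSPEED / DEEPSPEED_CONFIG_PATH - DeepSpeed wiring in setup_trainer_ex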
# Make the axolotl sources importable without installing the package.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
axolotl_root = os.getenv("AXOLOTL_ROOT", os.path.abspath(os.path.join(project_root, "../axolotl")))
src_dir = os.path.join(axolotl_root, "src")
scripts_dir = os.path.join(axolotl_root, "scripts")
sys.path.insert(0, src_dir)
sys.path.insert(0, scripts_dir)

import finetune
import axolotl
from axolotl.utils.trainer import setup_trainer as setup_trainer_orig
from axolotl.utils.models import load_tokenizer as load_tokenizer_orig, load_model as load_model_orig
from axolotl.utils.dict import DictDefault

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

# Shared state between the patched load_* hooks and setup_trainer_ex below.
context = {}


# TODO: avoid code dup
def notify_discord(msg):
    webhook = SyncWebhook.from_url(os.getenv("DISCORD_WEBHOOK_URL"))
    webhook.send(msg)


def edit_discord_message(last_msg, msg):
    return last_msg.edit(content=msg)


def log_info(msg):
    logging.info(msg)
    notify_discord(msg)


def log_error(msg, exc_info=None):
    logging.error(msg, exc_info=exc_info)
    if exc_info is not None:
        notify_discord(f'{msg}: {exc_info}')
    else:
        notify_discord(msg)


def local_rank():
    return int(os.environ.get("LOCAL_RANK", 0))


def parse_config(config, kwargs):
    # TODO: avoid code dup
    # Load the config from the YAML file.
    # Mostly borrowed from https://github.com/utensil/axolotl/blob/local_dataset/scripts/finetune.py
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))
    # If an option passed on the CLI looks valid against the YAML, overwrite the value.
    cfg_keys = cfg.keys()
    for k, _ in kwargs.items():
        # If not strict, allow writing to cfg even if the key isn't in the YAML already.
        if k in cfg_keys or not cfg.strict:
            # Handle booleans.
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]
    return cfg
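
# For illustration (hypothetical values): given a YAML containing
# `micro_batch_size: 4` and `strict: true`, parse_config(path, {"micro_batch_size": 8})
# yields cfg.micro_batch_size == 8, while an unknown key is silently ignored;
# with strict falsy, unknown keys are written into cfg as well.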
def init_accelerator_with_trackers(cfg):
    # os.environ["WANDB_RESUME"] = "auto"
    if cfg.wandb_project is not None:
        run_id = cfg.wandb_run_id or wandb.util.generate_id()
        accelerator = Accelerator(log_with="wandb")
        # init_kwargs is keyed by tracker name, so the run id must be nested under "wandb".
        accelerator.init_trackers(cfg.wandb_project, init_kwargs={"wandb": {"id": run_id}})
        # run = wandb.init(project=cfg.wandb_project, id=run_id) #, resume=True)
        os.environ["WANDB_RUN_ID"] = run_id
        return accelerator
    else:
        return Accelerator()
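
# Exporting WANDB_RUN_ID presumably lets any wandb.init() call made later by the
# underlying trainer attach to this same run instead of starting a new one.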
def train_ex(
    config,
    prepare_ds_only: bool = False,
    **kwargs,
):
    config = Path(config.strip())
    if local_rank() == 0:
        log_info(f"Prepare training with config: {config}")
    cfg = parse_config(config, kwargs)
    accelerator = init_accelerator_with_trackers(cfg)
    try:
        logging.info('train_ex before')
        finetune.train(config, cfg.runpod.prepare_ds_only or prepare_ds_only, **kwargs)
        logging.info('train_ex after')
    except Exception as ex:
        log_error(f"Error during training: {ex}", exc_info=ex)
    finally:
        accelerator.end_training()
        # If the pod needs to stay alive for inspection, set one_shot to false.
        if cfg.runpod.one_shot:
            runpod.api_key = os.getenv("RUNPOD_API_KEY")
            pod_id = os.getenv("RUNPOD_POD_ID")
            log_info(f"Pod {pod_id} terminating on train end")
            runpod.terminate_pod(pod_id)
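
# Typical invocation via fire (config path is illustrative):
#   python train.py examples/my-config.yml
#   python train.py examples/my-config.yml --prepare_ds_only=True
# Any extra --key=value flags are merged into the YAML config by parse_config.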
def log_data(name, data, tokenizer):
    # logging.info(f'{name}(type={type(data)}, shape={data.shape}):\n{data}')
    try:
        # Logits arrive as (batch, seq, vocab); reduce them to token ids.
        if data.ndim == 3:
            data = torch.argmax(torch.from_numpy(data), dim=-1)
        if data.ndim != 2:
            raise ValueError(f'Invalid data shape: {type(data)} {data.shape}')
        data = np.where(data != -100, data, tokenizer.pad_token_id)
        logging.info(f'{name}:\n{tokenizer.batch_decode(data, skip_special_tokens=True)}')
        if wandb.run:
            for i in range(len(data)):
                hist = wandb.Histogram(data[i])  # , num_bins=512)
                wandb.log({f"histogram/{name}": hist})
    except Exception as ex:
        logging.error(f'Error logging {name}: {ex}', exc_info=ex)


def decode_data(name, data, tokenizer):
    try:
        if data.ndim == 3:
            data = torch.argmax(torch.from_numpy(data), dim=-1)
        if data.ndim != 2:
            raise ValueError(f'Invalid data shape: {type(data)} {data.shape}')
        data = np.where(data != -100, data, tokenizer.pad_token_id)
        return tokenizer.batch_decode(data)
    except Exception as ex:
        logging.error(f'Error decoding {name}, returning empty strings: {ex}', exc_info=ex)
        return ['' for _ in range(len(data))]
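
# Note: -100 is the label id that Hugging Face loss functions ignore; swapping it
# for pad_token_id above is what makes batch_decode produce readable text.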
def log_eval_prediction_debug(ep, tokenizer):
    log_data('inputs', ep.inputs, tokenizer)
    log_data('predictions', ep.predictions, tokenizer)
    log_data('labels', ep.label_ids, tokenizer)


def log_eval_prediction(ep, tokenizer):
    if wandb.run:
        try:
            data = {
                'input': decode_data('inputs', ep.inputs, tokenizer),
                'prediction': decode_data('predictions', ep.predictions, tokenizer),
                'labels': decode_data('label_ids', ep.label_ids, tokenizer)
            }
            df = pd.DataFrame(data)
            table = wandb.Table(dataframe=df)
            artifact = wandb.Artifact('eval', type="dataset")
            artifact.add(table, 't_eval')
            wandb.run.log_artifact(artifact)
        except Exception as ex:
            logging.error(f'Error logging eval predictions: {ex}', exc_info=ex)


# Adapted from transformers.keras_callbacks.PushToHubCallback.__init__;
# it needs to run earlier than the callback would.
def init_output_dir_from_hub_for_lora(cfg):
    hub_model_id = cfg.hub_model_id
    output_dir = Path(cfg.output_dir)
    if "/" not in hub_model_id:
        hub_model_id = get_full_repo_name(hub_model_id)
    create_repo(hub_model_id, exist_ok=True)
    repo = Repository(str(output_dir), clone_from=hub_model_id)
    cfg.lora_model_dir = str(output_dir)
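
# Note: init_output_dir_from_hub_for_lora is currently only referenced from the
# commented-out block in load_model_ex below, kept for when resuming LoRA
# checkpoints from the Hub is re-enabled.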
def setup_trainer_ex(cfg, train_dataset, eval_dataset, model, tokenizer):
    # logging.info(f'cfg.runpod.one_shot = {cfg.runpod.one_shot}')
    if os.environ.get('ACCELERATE_USE_DEEPSPEED', 'false') == 'true':
        cfg.deepspeed = os.environ.get('DEEPSPEED_CONFIG_PATH', False)
    logging.info('setup_trainer_ex before')
    trainer = setup_trainer_orig(cfg, train_dataset, eval_dataset, model, tokenizer)
    logging.info('setup_trainer_ex after')
    # Required so that EvalPrediction.inputs is populated for log_eval_prediction.
    trainer.args.include_inputs_for_metrics = True
    compute_metrics_orig = trainer.compute_metrics
    tokenizer = context['tokenizer']
    logging.info(f'context.tokenizer: {tokenizer}')

    def compute_metrics(ep):
        metrics = compute_metrics_orig(ep) if compute_metrics_orig else {}
        log_eval_prediction(ep, tokenizer)
        return metrics

    if cfg.runpod.log_eval:
        trainer.compute_metrics = compute_metrics
    # Only the main process can get `wandb.run` for multi-GPU training.
    if wandb.run:
        log_info(f"Training started: {wandb.run.get_url()}")
    return trainer


def load_tokenizer_ex(
    tokenizer_config,
    tokenizer_type,
    cfg,
):
    tokenizer = load_tokenizer_orig(tokenizer_config, tokenizer_type, cfg)
    context['tokenizer'] = tokenizer
    return tokenizer


def load_model_ex(
    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
    # if cfg.hf_use_auth_token and cfg.hub_model_id and cfg.adapter:
    #     init_output_dir_from_hub_for_lora(cfg)
    #     logging.info(f'cfg.lora_model_dir: {cfg.lora_model_dir}')
    local_model_path = os.path.join('models', '_'.join(base_model.split('/')[-2:]))
    local_model_config_path = os.path.join('models', '_'.join(base_model_config.split('/')[-2:]))
    # The model should be pre-downloaded before this training script runs.
    if Path(local_model_path).exists() and Path(local_model_config_path).exists():
        log_info(f'Loading model from local: base_model={local_model_path} base_model_config={local_model_config_path}')
        model = load_model_orig(local_model_path, local_model_config_path, model_type, tokenizer, cfg, adapter)
    else:
        model = load_model_orig(base_model, base_model_config, model_type, tokenizer, cfg, adapter)
    context['model'] = model
    return model
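
# Path mapping example (model name is illustrative): base_model "huggyllama/llama-7b"
# resolves to models/huggyllama_llama-7b; when that directory and the matching
# config directory both exist, the local copies are used instead of the Hub.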
if __name__ == "__main__":
    # Monkeypatch axolotl's finetune entry points with the extended versions above.
    finetune.setup_trainer = setup_trainer_ex
    finetune.load_tokenizer = load_tokenizer_ex
    finetune.load_model = load_model_ex
    fire.Fire(train_ex)
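
# A minimal sketch of the custom `runpod` section this wrapper expects in the
# YAML config (key names taken from the cfg.runpod.* reads above; values are
# illustrative):
#
#   runpod:
#     one_shot: true          # terminate the pod when training ends
#     prepare_ds_only: false  # only prepare the dataset, skip training
#     log_eval: true          # log decoded eval predictions to wandb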