1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
|
# Stdlib and third-party imports for the fine-tuning wrapper script.
import os
import sys
from pathlib import Path
import fire
import logging
import wandb
import pandas as pd
from discord import SyncWebhook
from addict import Dict
import yaml
import runpod
from huggingface_hub import Repository, create_repo
from transformers.utils import get_full_repo_name
from transformers.trainer_callback import TrainerCallback
from accelerate import Accelerator
from accelerate.tracking import on_main_process
import torch
import numpy as np
from huggingface_hub import login

# Authenticate with the Hugging Face Hub at import time so later repo/model
# operations can push; requires HUGGINGFACE_TOKEN in the environment.
login(os.environ.get("HUGGINGFACE_TOKEN"), add_to_git_credential=True)

# Locate the axolotl checkout (defaults to a sibling of the project root) and
# put its src/ and scripts/ directories on sys.path BEFORE the imports below.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
axolotl_root = os.getenv("AXOLOTL_ROOT", os.path.abspath(os.path.join(project_root, "../axolotl")))
src_dir = os.path.join(axolotl_root, "src")
scripts_dir = os.path.join(axolotl_root, "scripts")
sys.path.insert(0, src_dir)
sys.path.insert(0, scripts_dir)

# These resolve against the sys.path entries inserted above; keep them after
# the inserts.
import finetune
import axolotl
from axolotl.utils.trainer import setup_trainer as setup_trainer_orig
from axolotl.utils.models import load_tokenizer as load_tokenizer_orig, load_model as load_model_orig
from axolotl.utils.dict import DictDefault

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

# Shared mutable state passed between the patched axolotl hooks below
# (tokenizer/model captured at load time, reused in setup_trainer_ex).
context = {}
# TODO: avoid code dup
def notify_discord(msg):
    """Best-effort send of *msg* to the Discord webhook in DISCORD_WEBHOOK_URL.

    A missing webhook URL or a Discord outage must never take down the
    training run, so both cases are logged and swallowed instead of raised.
    """
    url = os.getenv("DISCORD_WEBHOOK_URL")
    if not url:
        # Previously this crashed with an invalid-URL error when unset.
        logging.warning("DISCORD_WEBHOOK_URL not set; skipping Discord notification")
        return
    try:
        webhook = SyncWebhook.from_url(url)
        webhook.send(msg)
    except Exception as ex:  # notifications are best-effort by design
        logging.error(f"Failed to send Discord notification: {ex}", exc_info=ex)
def edit_discord_message(last_msg, msg):
    """Replace the content of a previously-sent Discord message.

    Thin pass-through to ``last_msg.edit``; returns whatever the Discord
    client returns (the updated message object).
    """
    updated = last_msg.edit(content=msg)
    return updated
def log_info(msg):
    """Log *msg* at INFO level, then mirror it to the Discord webhook."""
    logging.info(msg)
    notify_discord(msg)
def log_error(msg, exc_info=None):
    """Log *msg* at ERROR level (with optional exception info) and notify Discord.

    The Discord message includes the exception text when one is supplied.
    """
    logging.error(msg, exc_info=exc_info)
    text = msg if exc_info is None else f'{msg}: {exc_info}'
    notify_discord(text)
def local_rank():
    """Return this process's distributed local rank (0 when LOCAL_RANK is unset)."""
    return int(os.getenv("LOCAL_RANK", "0"))
def parse_config(config, kwargs):
    """Load the axolotl YAML config at *config* and overlay CLI overrides.

    Mostly borrowed from
    https://github.com/utensil/axolotl/blob/local_dataset/scripts/finetune.py

    Each key in *kwargs* overwrites the YAML value when the key already
    exists in the file, or unconditionally when ``cfg.strict`` is falsy.
    Values that were booleans in the YAML are coerced with ``bool()``.
    """
    # TODO: avoid code dup
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))
    known_keys = cfg.keys()
    for key in kwargs:
        # When strict, only keys already present in the YAML may be set.
        if key not in known_keys and cfg.strict:
            continue
        override = kwargs[key]
        # Preserve boolean typing for values that were booleans in the YAML.
        cfg[key] = bool(override) if isinstance(cfg[key], bool) else override
    return cfg
def init_accelerator_with_trackers(cfg):
    """Create an Accelerator, wiring up wandb tracking when configured.

    When ``cfg.wandb_project`` is set, reuse ``cfg.wandb_run_id`` (or mint a
    fresh id) so restarted pods resume the same wandb run, and export it via
    WANDB_RUN_ID for code that reads the environment directly.
    """
    # os.environ["WANDB_RESUME"] = "auto"
    if cfg.wandb_project is None:
        return Accelerator()
    run_id = cfg.wandb_run_id or wandb.util.generate_id()
    accelerator = Accelerator(log_with="wandb")
    accelerator.init_trackers(cfg.wandb_project, init_kwargs={"id": run_id})
    # run = wandb.init(project=cfg.wandb_project, id=run_id) #, resume=True)
    os.environ["WANDB_RUN_ID"] = run_id
    return accelerator
def train_ex(
    config,
    prepare_ds_only: bool = False,
    **kwargs,
):
    # Entry point wrapping finetune.train with wandb tracking, Discord
    # notifications and optional RunPod self-termination.
    #
    # config: path to the axolotl YAML config file.
    # prepare_ds_only: when True (or cfg.runpod.prepare_ds_only), only
    #   prepare the dataset instead of training.
    # kwargs: CLI overrides merged into the YAML config (see parse_config).
    config = Path(config.strip())
    # Only the main process notifies, to avoid duplicate Discord messages.
    if local_rank() == 0:
        log_info(f"Prepare training with config: {config}")
    cfg = parse_config(config, kwargs)
    accelerator = init_accelerator_with_trackers(cfg)
    try:
        logging.info('train_ex before')
        finetune.train(config, cfg.runpod.prepare_ds_only or prepare_ds_only, **kwargs)
        logging.info('train_ex after')
    except Exception as ex:
        # Swallow after reporting so the pod can still be terminated below
        # instead of hanging (and billing) forever on an unhandled error.
        log_error(f"Error during training: {ex}", exc_info=ex)
    finally:
        accelerator.end_training()
    # If we need it stay alive for inspection, we should set one_shot to false
    if cfg.runpod.one_shot:
        runpod.api_key = os.getenv("RUNPOD_API_KEY")
        pod_id = os.getenv("RUNPOD_POD_ID")
        # Notify BEFORE terminating: after terminate_pod this process dies.
        log_info(f"Pod {pod_id} terminated on train end")
        runpod.terminate_pod(pod_id)
def log_data(name, data, tokenizer):
    """Decode a batch of token ids (or logits) and log it; histogram each row to wandb.

    Accepts a 2-D array of token ids, or a 3-D logits array which is reduced
    with argmax over the last axis. -100 label padding is mapped to the
    tokenizer's pad token before decoding. Errors are logged, never raised.
    """
    # logging.info(f'{name}(type={type(data)}, shape={data.shape}):\n{data}')
    try:
        if data.ndim == 3:
            # Logits: collapse the vocabulary dimension to token ids.
            data = torch.argmax(torch.from_numpy(data), dim=-1)
        if data.ndim != 2:
            raise ValueError(f'Invalid data shape: {type(data)} {data.shape}')

        data = np.where(data != -100, data, tokenizer.pad_token_id)
        logging.info(f'{name}:\n{tokenizer.batch_decode(data, skip_special_tokens=True)}')

        if wandb.run:
            for row in data:
                hist = wandb.Histogram(row) #, num_bins=512)
                wandb.log({f"histogram/{name}": hist})
    except Exception as ex:
        logging.error(f'Error logging {name}: {ex}', exc_info=ex)
def decode_data(name, data, tokenizer):
    """Decode a batch of token ids (or logits) into a list of strings.

    3-D input is treated as logits and reduced with argmax over the last
    axis; -100 label padding is replaced by the tokenizer's pad token.
    On any failure an empty string per row is returned instead of raising.
    """
    try:
        ids = data
        if ids.ndim == 3:
            ids = torch.argmax(torch.from_numpy(ids), dim=-1)
        if ids.ndim != 2:
            raise ValueError(f'Invalid data shape: {type(data)} {data.shape}')

        cleaned = np.where(ids != -100, ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(cleaned)
    except Exception as ex:
        logging.error(f'Error decoding {name}, returning empty strings: {ex}', exc_info=ex)
        return ['' for _ in range(len(data))]
def log_eval_prediction_debug(ep, tokenizer):
    """Dump an EvalPrediction's inputs, predictions and labels via log_data."""
    fields = (
        ('inputs', ep.inputs),
        ('predictions', ep.predictions),
        ('labels', ep.label_ids),
    )
    for name, values in fields:
        log_data(name, values, tokenizer)
def log_eval_prediction(ep, tokenizer):
    """Upload decoded eval inputs/predictions/labels to wandb as a table artifact.

    No-op unless a wandb run is active; any failure is logged, never raised.
    """
    if not wandb.run:
        return
    try:
        frame = pd.DataFrame({
            'input': decode_data('inputs', ep.inputs, tokenizer),
            'prediction': decode_data('predictions', ep.predictions, tokenizer),
            'labels': decode_data('label_ids', ep.label_ids, tokenizer),
        })
        artifact = wandb.Artifact('eval', type="dataset")
        artifact.add(wandb.Table(dataframe=frame), 't_eval')
        wandb.run.log_artifact(artifact)
    except Exception as ex:
        logging.error(f'Error logging eval predictions: {ex}', exc_info=ex)
# Adapted from transformers.keras_callbacks.PushToHubCallback.__init__
# Need to run it earlier
def init_output_dir_from_hub_for_lora(cfg):
    """Clone (creating if needed) the Hub repo for cfg.hub_model_id into
    cfg.output_dir and point cfg.lora_model_dir at it, so LoRA training
    resumes from the weights already on the Hub."""
    repo_id = cfg.hub_model_id
    if "/" not in repo_id:
        repo_id = get_full_repo_name(repo_id)
    create_repo(repo_id, exist_ok=True)
    out_dir = Path(cfg.output_dir)
    repo = Repository(str(out_dir), clone_from=repo_id)
    cfg.lora_model_dir = str(out_dir)
-
def setup_trainer_ex(cfg, train_dataset, eval_dataset, model, tokenizer):
    # Patched replacement for axolotl's setup_trainer: builds the trainer via
    # the original, then layers on deepspeed config pass-through, optional
    # eval-prediction logging, and a Discord "training started" notification.
    # logging.info(f'cfg.runpod.one_shot = {cfg.runpod.one_shot}')
    if os.environ.get('ACCELERATE_USE_DEEPSPEED', 'false') == 'true':
        # Forward the deepspeed config path injected by `accelerate launch`.
        cfg.deepspeed = os.environ.get('DEEPSPEED_CONFIG_PATH', False)
    logging.info('setup_trainer_ex before')
    trainer = setup_trainer_orig(cfg, train_dataset, eval_dataset, model, tokenizer)
    logging.info('setup_trainer_ex after')
    # Required so compute_metrics receives ep.inputs for decoding below.
    trainer.args.include_inputs_for_metrics = True
    compute_metrics_orig = trainer.compute_metrics
    # NOTE(review): shadows the `tokenizer` parameter with the instance
    # captured by load_tokenizer_ex — presumably the same object; confirm.
    tokenizer = context['tokenizer']
    logging.info(f'context.tokenizer: {tokenizer}')
    def compute_metrics(ep):
        # Chain any original metrics, then log decoded eval predictions.
        metrics = compute_metrics_orig(ep) if compute_metrics_orig else {}
        log_eval_prediction(ep, tokenizer)
        return metrics
    if cfg.runpod.log_eval:
        trainer.compute_metrics = compute_metrics
    # Only the main process can get `wandb.run` for multiple-GPU training.
    if wandb.run:
        log_info(f"Training started: {wandb.run.get_url()}")

    return trainer
def load_tokenizer_ex(
    tokenizer_config,
    tokenizer_type,
    cfg,
):
    """Wrap axolotl's load_tokenizer, stashing the result in `context`
    so setup_trainer_ex can reach it later."""
    loaded = load_tokenizer_orig(tokenizer_config, tokenizer_type, cfg)
    context['tokenizer'] = loaded
    return loaded
def load_model_ex(
    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
    """Wrap axolotl's load_model, preferring a pre-downloaded local copy
    under ./models/ when both model and config directories exist; the loaded
    model is stashed in `context` for later hooks."""
    # if cfg.hf_use_auth_token and cfg.hub_model_id and cfg.adapter:
    #     init_output_dir_from_hub_for_lora(cfg)
    # logging.info(f'cfg.lora_model_dir: {cfg.lora_model_dir}')
    def _local_path(repo_id):
        # "org/name" -> "models/org_name" (the pre-download layout).
        return os.path.join('models', f"{'_'.join(repo_id.split('/')[-2:])}")
    local_model_path = _local_path(base_model)
    local_model_config_path = _local_path(base_model_config)
    # The model should be pre-downloaded before this training script
    if Path(local_model_path).exists() and Path(local_model_config_path).exists():
        log_info(f'Loading model from local: base_model={local_model_path} base_model_config={local_model_config_path}')
        model = load_model_orig(local_model_path, local_model_config_path, model_type, tokenizer, cfg, adapter)
    else:
        model = load_model_orig(base_model, base_model_config, model_type, tokenizer, cfg, adapter)
    context['model'] = model
    return model
if __name__ == "__main__":
    # Monkey-patch axolotl's finetune module so training picks up our
    # instrumented trainer/tokenizer/model hooks, then hand the CLI to fire.
    finetune.setup_trainer = setup_trainer_ex
    finetune.load_tokenizer = load_tokenizer_ex
    finetune.load_model = load_model_ex
    fire.Fire(train_ex)
-
|