1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
- import os
- import sys
- import json
- import yaml
- import torch
- import importlib
- import pandas as pd
- from pathlib import Path
- from dotenv import load_dotenv
- import logging
# Send every record (DEBUG and above) both to ``debug.log`` and to the
# console. The log file is opened as UTF-8 explicitly so that non-ASCII
# messages are written identically on every platform instead of depending on
# the locale's default encoding.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("debug.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)

# Populate the process environment from the project's dotenv file.
load_dotenv('envs/.env')

# Parsed contents of ``params.yaml``, shared by the rest of this module.
# The file is decoded as UTF-8 explicitly rather than with the
# platform-dependent default of ``open()``.
with open('params.yaml', 'r', encoding='utf-8') as f:
    PARAMS = yaml.safe_load(f)
def start_training(bert_model, pretrained_model, method='basic'):
    """Build a model, train it on ``data/all.csv`` and collect the outcome.

    The concrete implementations are located dynamically from the arguments:

    * ``model.<bert_model>.<method>`` must provide ``Model``
    * ``data_loader.<bert_model>_dataloaders`` must provide
      ``DataFrameDataLoader``
    * ``training.<bert_model>`` must provide ``Trainer``

    Args:
        bert_model: Name used both as a key into ``PARAMS`` and as the
            package/module name component for the dynamic imports.
        pretrained_model: Identifier forwarded unchanged to the model and to
            the data loader.
        method: Sub-module of ``model.<bert_model>`` to load and key of
            ``PARAMS[bert_model]`` holding its extra keyword arguments.

    Returns:
        A ``(results, losses_df)`` tuple. ``results`` is the first value
        returned by ``trainer.train()``; ``losses_df`` is a DataFrame built
        from the second value, with its columns taken from the keys of the
        first loss record (an empty DataFrame when there are no records).

    Raises:
        ImportError: If one of the dynamically imported modules is missing.
        KeyError: If a required entry is absent from ``PARAMS`` (assuming
            ``PARAMS`` is a plain mapping loaded from YAML).
    """
    # Any failure in the steps below propagates unchanged to the caller.
    model_module = importlib.import_module(f'model.{bert_model}.{method}')
    model_params = PARAMS[bert_model]
    model = model_module.Model(
        **model_params, **model_params[method],
        pretrained_model=pretrained_model
    )

    # Use the configured GPU index (default 0) when CUDA is present,
    # otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda', PARAMS.get('gpu', 0))
    else:
        device = torch.device('cpu')
    model.to(device)

    df = pd.read_csv('data/all.csv')

    dataloader_module = importlib.import_module(
        f'data_loader.{bert_model}_dataloaders'
    )
    # NOTE(review): ``shuffle`` is read from the 'validate' section while
    # ``batch_size`` comes from 'train' -- confirm this is intentional.
    dataloader = dataloader_module.DataFrameDataLoader(
        df, pretrained_model=pretrained_model,
        do_lower_case=model_params['do_lower_case'],
        batch_size=PARAMS['train']['batch_size'],
        shuffle=PARAMS['validate']['shuffle'],
        max_len=model_params['max_len']
    )

    trainer_module = importlib.import_module(f'training.{bert_model}')
    bert_model_name = f'{bert_model}-{pretrained_model}-{method}'
    trainer = trainer_module.Trainer(
        model, dataloader, method=bert_model_name, mode='train'
    )

    results, losses = trainer.train()

    # Columns follow the first loss record. ``len`` is used instead of
    # truthiness so the guard works for any sized sequence; with no records
    # pandas is left to build an empty frame rather than hitting IndexError.
    columns = list(losses[0].keys()) if len(losses) else None
    losses_df = pd.DataFrame(losses, columns=columns)
    return results, losses_df
if __name__ == '__main__':
    # Command line: BERT_MODEL PRETRAINED_MODEL [METHOD]
    # METHOD is optional and falls back to 'basic', the same default that
    # ``start_training`` declares.
    if len(sys.argv) < 3:
        sys.exit(
            f'usage: python {sys.argv[0]} BERT_MODEL PRETRAINED_MODEL [METHOD]'
        )
    bert_model, pretrained_model = sys.argv[1], sys.argv[2]
    method = sys.argv[3] if len(sys.argv) > 3 else 'basic'

    # Resolve every output location BEFORE training, so that a missing
    # environment variable aborts immediately instead of discarding the
    # outcome of a finished training run.
    required_env = ('OUTPUT_PATH', 'RESULTS_PATH', 'PLOTS_PATH')
    missing_env = [name for name in required_env if os.getenv(name) is None]
    if missing_env:
        sys.exit(
            'missing required environment variable(s): '
            + ', '.join(missing_env)
        )

    run_name = f'{bert_model}-{pretrained_model}-{method}'
    results_path = Path(
        os.getenv('OUTPUT_PATH'),
        f'{run_name}_{os.getenv("RESULTS_PATH")}'
    )
    plots_path = Path(
        os.getenv('OUTPUT_PATH'),
        f'{run_name}_{os.getenv("PLOTS_PATH")}'
    )
    # Create the target directories up front; permission problems then
    # surface before any training time has been spent.
    for target in (results_path, plots_path):
        target.parent.mkdir(parents=True, exist_ok=True)

    try:
        results, losses_df = start_training(bert_model, pretrained_model, method)
    except Exception:
        # Record the full traceback in the log, then let the error propagate
        # so the process still exits with a failure status.
        logging.exception('Training failed for %s', run_name)
        raise

    # ``results`` is written as JSON, the per-step losses as CSV.
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f)
    losses_df.to_csv(plots_path, index=False)
|