Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ensemble.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
  1. import os
  2. import sys
  3. import yaml
  4. import torch
  5. import importlib
  6. import numpy as np
  7. import pandas as pd
  8. from tqdm import tqdm
  9. from functools import reduce
  10. from pathlib import Path
  11. from dotenv import load_dotenv
# Load environment variables (e.g. OUTPUT_PATH, MODEL_PATH) from the project env file
# before any os.getenv() calls below rely on them.
load_dotenv('envs/.env')
# Global config/hyper-parameter dict shared by every inference run in this script.
with open('params.yaml', 'r') as f:
    PARAMS = yaml.safe_load(f)
  15. def inference(bert_model, pretrained_model, method='lstm'):
  16. try:
  17. model_module = importlib.import_module(f'model.{bert_model}.{method}')
  18. model = model_module.Model(
  19. **PARAMS[bert_model], **PARAMS[bert_model][method],
  20. pretrained_model=pretrained_model
  21. )
  22. except Exception as e:
  23. raise e
  24. if torch.cuda.is_available():
  25. device = torch.device('cuda', PARAMS.get('gpu', 0))
  26. else:
  27. device = torch.device('cpu')
  28. model_path = Path(os.getenv('OUTPUT_PATH'), f'{bert_model}-{pretrained_model}-{method}_{os.getenv("MODEL_PATH")}')
  29. model.load_model(model_path)
  30. model.to(device)
  31. df = pd.read_csv('data/test.csv')
  32. try:
  33. dataloader_module = importlib.import_module(f'data_loader.{bert_model}_dataloaders')
  34. except Exception as e:
  35. raise e
  36. df[PARAMS['label']] = 0
  37. with torch.no_grad():
  38. all_preds = list()
  39. inference_dataloader = dataloader_module.DataFrameDataLoader(
  40. df, pretrained_model=pretrained_model,
  41. do_lower_case=PARAMS[bert_model]['do_lower_case'],
  42. batch_size=PARAMS['evaluate']['batch_size'], max_len=PARAMS[bert_model]['eval_max_len']
  43. )
  44. for idx, (label, text, offsets) in enumerate(tqdm(inference_dataloader)):
  45. predicted_label = model(text, offsets)
  46. predicted_label = predicted_label.squeeze(dim=-1)
  47. all_preds += [predicted_label.detach().cpu().numpy()]
  48. all_preds = np.concatenate(all_preds, axis=0)
  49. df[PARAMS['label']] = all_preds
  50. return df
  51. if __name__ == '__main__':
  52. xlnet_df = inference('xlnet', 'xlnet-large-cased', 'basic')
  53. xlnet_df = xlnet_df[['ID', PARAMS['label']]]
  54. bert_df = inference('bert', 'bert-large-uncased', 'basic')
  55. bert_df = bert_df[['ID', PARAMS['label']]]
  56. roberta_df = inference('roberta', 'roberta-large', 'basic')
  57. roberta_df = roberta_df[['ID', PARAMS['label']]]
  58. df = reduce(
  59. lambda left, right: pd.merge(left, right, on='ID', how='left'),
  60. [bert_df, roberta_df, xlnet_df]
  61. )
  62. submission_path = Path(os.getenv('OUTPUT_PATH'), 'ensemble.csv')
  63. df.to_csv(submission_path, index=False)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...