Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 5.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  1. # Copyright (c) 2022, NVIDIA CORPORATION.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List
  15. from enum import Enum
  16. import re
  17. import braceexpand
  18. import os
  19. from copy import deepcopy
  20. from omegaconf import DictConfig, open_dict
  21. import torch.utils.data as pt_data
  22. from pytorch_lightning.trainer.trainer import Trainer
  23. from nemo.utils import logging
  24. from .csv_dataset import MoleculeCsvDataset
  25. from .molecule_binary_dataset import MoleculeBinaryDataset
# Public API of this module. NOTE(review): `check_paths_exist` is importable but
# deliberately (?) excluded here — confirm whether it should be public.
__all__ = ['DatasetTypes', 'expand_dataset_paths', 'build_train_valid_test_datasets']
class DatasetTypes(Enum):
    """Supported dataset source types (only zinc CSV is defined so far)."""
    zinc_csv = 0
  29. def expand_dataset_paths(filepath: str, ext: str) -> List[str]:
  30. """Expand dataset paths from braces"""
  31. filepath = filepath + ext if ext else filepath
  32. # TODO this should eventually be moved to a Nemo fileutils module or similar
  33. filepath = re.sub(r"""\(|\[|\<|_OP_""", '{', filepath) # replaces '(', '[', '<' and '_OP_' with '{'
  34. filepath = re.sub(r"""\)|\]|\>|_CL_""", '}', filepath) # replaces ')', ']', '>' and '_CL_' with '}'
  35. dataset_paths = list(braceexpand.braceexpand(filepath))
  36. return dataset_paths
  37. def check_paths_exist(dataset_paths, dataset_format):
  38. """Check that the expanded dataset paths are valid and they exist."""
  39. errors = []
  40. for filepath in dataset_paths:
  41. if dataset_format == "csv":
  42. if not os.path.exists(filepath):
  43. errors.append(filepath)
  44. if dataset_format == "bin":
  45. binfile = filepath + ".bin"
  46. if not os.path.exists(binfile):
  47. errors.append(binfile)
  48. return errors
def _build_train_valid_test_datasets(
    cfg: DictConfig,
    trainer: Trainer,
    num_samples: int,
    filepath: str,
    metadata_path: str,
    dataset_format: str
):
    """Build one dataset split from a (possibly brace-expandable) file path.

    Args:
        cfg: Dataset config; copied and augmented with ``metadata_path``.
        trainer: PyTorch Lightning trainer, forwarded to the binary dataset.
        num_samples: Sample budget for the "bin" format (unused for "csv").
        filepath: Path pattern for this split's data files.
        metadata_path: Path to the split's metadata file, or None.
        dataset_format: Either "csv" or "bin".

    Returns:
        A single dataset object (possibly a ConcatDataset for "bin").

    Raises:
        AssertionError: If any expanded data file is missing.
        ValueError: If ``dataset_format`` is neither "csv" nor "bin".
    """
    # TODO num_samples is currently not used
    cfg = deepcopy(cfg)  # do not mutate the caller's config
    with open_dict(cfg):
        cfg['metadata_path'] = metadata_path
    # Get datasets and load data
    logging.info(f'Loading data from {filepath}')
    # "csv" patterns get the extension appended before expansion; "bin" paths
    # are expanded as-is (the ".bin" suffix is handled in check_paths_exist).
    dataset_paths = expand_dataset_paths(filepath, ".csv") if dataset_format == "csv" else expand_dataset_paths(filepath, None)
    errors = check_paths_exist(dataset_paths, dataset_format)
    assert len(errors) == 0, "Following files do not exist %s" % ' '.join(errors)
    logging.info(f'Loading data from {dataset_paths}')
    dataset_list = []
    if dataset_format == "csv":
        # The CSV dataset consumes all expanded paths at once.
        dataset = MoleculeCsvDataset(dataset_paths=dataset_paths, cfg=cfg)
    elif dataset_format == "bin":
        # Accumulate binary shards until the sample budget is exhausted.
        for path in dataset_paths:
            data = MoleculeBinaryDataset(filepath=path, cfg=cfg, trainer=trainer, num_samples=num_samples)
            dataset_list.append(data)
            num_samples -= len(data)
            if num_samples < 1:
                break
        if len(dataset_list) == 1:
            dataset = dataset_list[0]
        else:
            dataset = pt_data.ConcatDataset(dataset_list)
    else:
        raise ValueError("Unrecognized data format. Expected csv or bin.")
    return dataset
  84. def build_train_valid_test_datasets(
  85. cfg: DictConfig,
  86. trainer: Trainer,
  87. train_valid_test_num_samples: List[int]
  88. ):
  89. # TODO metadata_file is currently not used
  90. cfg = deepcopy(cfg)
  91. with open_dict(cfg):
  92. dataset_path = cfg.pop('dataset_path', '')
  93. # dataset = cfg.pop('dataset')
  94. metadata_file = cfg.pop('metadata_file', None)
  95. dataset_format = cfg.pop('dataset_format')
  96. ds_train = cfg.dataset.train
  97. ds_val = cfg.dataset.val
  98. ds_test = cfg.dataset.test
  99. cfg.pop('dataset')
  100. # Build individual datasets.
  101. filepath = os.path.join(dataset_path, 'train', ds_train)
  102. metadata_path = os.path.join(dataset_path, 'train', metadata_file) if metadata_file else None
  103. train_dataset = _build_train_valid_test_datasets(cfg, trainer, train_valid_test_num_samples[0],
  104. filepath, metadata_path, dataset_format)
  105. filepath = os.path.join(dataset_path, 'val', ds_val)
  106. metadata_path = os.path.join(dataset_path, 'val', metadata_file) if metadata_file else None
  107. validation_dataset = _build_train_valid_test_datasets(cfg, trainer, train_valid_test_num_samples[1],
  108. filepath, metadata_path, dataset_format)
  109. filepath = os.path.join(dataset_path, 'test', ds_test)
  110. metadata_path = os.path.join(dataset_path, 'test', metadata_file) if metadata_file else None
  111. test_dataset = _build_train_valid_test_datasets(cfg, trainer, train_valid_test_num_samples[2],
  112. filepath, metadata_path, dataset_format)
  113. return (train_dataset, validation_dataset, test_dataset)
Tip!

Press p or to see the previous file or, n or to see the next file

Comments

Loading...