Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

csv_dataset.py 2.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  1. # Copyright (c) 2022, NVIDIA CORPORATION.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import pickle
  17. from typing import Optional
  18. from dataclasses import dataclass
  19. import torch
  20. import numpy as np
  21. from nemo.core import Dataset
  22. from nemo.utils import logging
  23. from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import CSVMemMapDataset
  # apex is an optional dependency: record whether its parallel-state helper
  # can be imported instead of failing at module import time.
  # NOTE(review): get_rank_info is not referenced by the code visible in this file.
  try:
      from apex.transformer.parallel_state import get_rank_info
  except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
      HAVE_APEX = False
  else:
      HAVE_APEX = True

  # Names exported by `from <module> import *`.
  __all__ = ['MoleculeCsvDatasetConfig', 'MoleculeCsvDataset', 'DatasetFileConfig']
  @dataclass
  class DatasetFileConfig:
      """Per-split dataset file specification.

      Every field defaults to ``None``, so the annotations are ``Optional[str]``
      (the previous ``str`` annotation did not admit the default). Presumably
      ``None`` means the split is not provided — confirm against the consumers
      of this config.
      """

      train: Optional[str] = None
      test: Optional[str] = None
      val: Optional[str] = None
  @dataclass
  class MoleculeCsvDatasetConfig:
      """Configuration for building a ``MoleculeCsvDataset``.

      ``MoleculeCsvDataset.__init__`` reads ``newline_int``, ``header_lines``,
      ``sort_dataset_paths``, ``data_col`` and ``data_sep`` from the config and
      forwards them to ``CSVMemMapDataset``. The remaining fields are not read by
      the code shown in this file; they are presumably consumed by callers (e.g.
      dataloader construction) — confirm before removing any of them.
      """

      # NOTE(review): presumably the directory holding the dataset files — confirm.
      dataset_path: str = ''
      # Per-split file specification; None when not configured.
      dataset: Optional[DatasetFileConfig] = None

      # Parsing options forwarded to CSVMemMapDataset by MoleculeCsvDataset.
      newline_int: int = 10  # integer code of the line separator (10 == ord('\n'))
      header_lines: int = 1  # number of leading (header) lines to skip
      data_col: int = 1  # forwarded as CSVMemMapDataset `data_col`
      data_sep: str = ','  # forwarded as CSVMemMapDataset `data_sep`
      sort_dataset_paths: bool = True  # forwarded as CSVMemMapDataset `sort_dataset_paths`

      # FIXME: remove unneeded config variables
      # The fields below are not read by the code in this file.
      skip_lines: int = 0
      micro_batch_size: int = 1
      encoder_augment: bool = False
      encoder_mask: bool = False
      decoder_augment: bool = False
      decoder_mask: bool = False
      canonicalize_input: bool = True
      dataloader_type: str = 'single'
      drop_last: bool = False
      pin_memory: bool = False  # must be False with CSV dataset
      num_workers: Optional[int] = None
  60. def __init__(self,
  61. dataset_paths,
  62. cfg,
  63. workers=None):
  64. super().__init__(
  65. dataset_paths=dataset_paths,
  66. newline_int=cfg.get('newline_int'),
  67. header_lines=cfg.get('header_lines'), # skip first N lines
  68. workers=workers,
  69. tokenizer=None,
  70. sort_dataset_paths=cfg.get('sort_dataset_paths'),
  71. data_col=cfg.get('data_col'),
  72. data_sep=cfg.get('data_sep'),
  73. )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...