download-huge.py 6.8 KB

# Tested with the following commands
#
# 1. One complete run
#
# > python helper/download-huge.py bigcode/the-stack 'data/arc/*' 'data/cmake/*'|grep -v Redirecting
# Resume downloading of 0 LFS files
# Found 4 LFS files to download from 4 matched files
#
# 2. Interrupt and resume downloading
#
# Resume downloading of 4 LFS files
# Found 1 LFS files to download from 4 matched files
#
# 3. Remove a file and resume downloading
#
# > rm datasets/bigcode_the-stack/data/cmake/train-00002-of-00003.parquet
# > python helper/download-huge.py bigcode/the-stack 'data/arc/*' 'data/cmake/*'|grep -v Redirecting
#
# Resume downloading of 5 LFS files
# Found 0 LFS files to download from 3 matched files
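#
# Note: the script shells out to aria2c, so aria2c must be installed and on PATH;
# set HUGGINGFACE_TOKEN in the environment if the repo requires authentication
# (the token is passed to aria2c as a Bearer header).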
import argparse
import json
import os
import re
import subprocess
import sys
from pathlib import Path

from huggingface_hub import Repository
from tqdm.contrib.concurrent import thread_map

parser = argparse.ArgumentParser()
parser.add_argument('REPO', type=str, default=None, nargs='?')
parser.add_argument('PATHS', type=str, default=None, nargs='*', help='Relative glob pattern(s) selecting which files of the dataset/model to download')
parser.add_argument('-m', '--model', action='store_true', help='By default this script downloads datasets; pass -m to download models instead')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from')
parser.add_argument('--output', type=str, default=None, help='The folder where the dataset/model should be saved.')
parser.add_argument('--threads', type=int, default=8, help='Number of files to download simultaneously.')
args = parser.parse_args()

def sanitize_branch_name(branch_name):
    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
    if pattern.match(branch_name):
        return branch_name
    else:
        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
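
# get_file_by_aria2 resolves each LFS entry with aria2c: -c resumes an interrupted
# download, -x 16 / -s 16 allow up to 16 connections/splits per file, and -k 1M sets
# the minimum split size. A leftover <file>.aria2 control file marks an unfinished download.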
def get_file_by_aria2(url_info, output_folder):
    url = url_info['link']
    filename = url_info['path']
    dir_and_basename = filename.rsplit('/', 1)
    subdir = dir_and_basename[0] if len(dir_and_basename) == 2 else ''
    if not (output_folder / subdir).exists():
        (output_folder / subdir).mkdir(parents=True, exist_ok=True)
    total_size = url_info['size']
    output_file = output_folder / Path(filename)
    # Skip files that are already complete (right size and no leftover .aria2 control file)
    if output_file.exists() and output_file.stat().st_size == total_size and not (output_folder / Path(f"{filename}.aria2")).exists():
        print(f"Downloaded: {output_file}")
        return
    aria_command = f"aria2c -c -x 16 -s 16 -k 1M {url} -d {output_folder} -o {filename}"
    print(f"Running: {aria_command}")
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if token:
        aria_command = f"{aria_command} --header='authorization: Bearer {token}'"
    # call command-line aria2c to download
    subprocess.run(aria_command, shell=True, check=True)
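
# download_files fans get_file_by_aria2 out over a thread pool (tqdm's thread_map),
# downloading up to --threads files in parallel with a progress bar.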
def download_files(file_list, output_folder, num_threads=8):
    thread_map(lambda url_info: get_file_by_aria2(url_info, output_folder), file_list, max_workers=num_threads)
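
# Because the repo is cloned with skip_lfs_files=True, large files exist locally only
# as small Git LFS pointer files, plain text of the form:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<hash>
#   size <bytes>
# parse_lfs_file reads those lines and builds a download entry with the resolve URL,
# checksum and expected size.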
def parse_lfs_file(file, root, repo_type, repo, branch):
    file = Path(file)
    base_url = "https://huggingface.co/datasets/" if repo_type == 'dataset' else "https://huggingface.co/"
    fname = file.relative_to(root)
    url_info = None
    # LFS pointer files are tiny; anything larger is already the real file
    if not file.is_dir() and file.stat().st_size < 2048:
        with open(file, 'r', encoding='latin1') as f:
            for line in f:
                if line.startswith('version https://git-lfs.github.com/spec/'):
                    url_info = {}
                    url_info["path"] = f"{fname}"
                    url_info["link"] = f"{base_url}{repo}/resolve/{branch}/{fname}"
                elif line.startswith('oid sha256:'):
                    url_info["sha256sum"] = line.split('sha256:')[1].strip()
                elif line.startswith('size '):
                    url_info["size"] = int(line.split('size ')[1].strip())
                    break
                else:
                    break
    return url_info

def get_download_links(paths, output_folder, repo_type, remote_repo, branch):
    all_files = []
    lfs_files = []
    for path_pattern in paths:
        files = output_folder.glob(path_pattern)
        for file in files:
            fname = f'{file.relative_to(output_folder)}'
            if not fname.endswith('.aria2'):
                all_files.append(fname)
                url_info = parse_lfs_file(file, output_folder, repo_type, remote_repo, branch)
                if url_info is not None:
                    lfs_files.append(url_info)
    return lfs_files, all_files
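
# Main flow: clone the repo with LFS payloads skipped, merge the matching pointer
# entries with links.json from any previous run, then download the previously
# recorded files followed by the newly found ones.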
if __name__ == '__main__':
    # determine the root of the repo and cd to it
    ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
    os.chdir(ROOT)
    print(f"Working directory changed to: {ROOT}")
    HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
    remote_repo = args.REPO
    if remote_repo is None:
        print("Error: Please specify a dataset or a model to download.")
        sys.exit()
    if remote_repo[-1] == '/':
        remote_repo = remote_repo[:-1]
    branch = args.branch
    if branch is None:
        branch = "main"
    else:
        try:
            branch = sanitize_branch_name(branch)
        except ValueError as err_branch:
            print(f"Error: {err_branch}")
            sys.exit()
    repo_type = 'model' if args.model else 'dataset'
    if args.output is not None:
        base_folder = args.output
    else:
        base_folder = 'models' if args.model else 'datasets'
    output_folder = f"{'_'.join(remote_repo.split('/')[-2:])}"
    if branch != 'main':
        output_folder += f'_{branch}'
    output_folder = Path(base_folder) / output_folder
    # Clone only the pointer files; the actual LFS payloads are fetched with aria2c below
    repo = Repository(local_dir=output_folder, clone_from=remote_repo, repo_type=repo_type, skip_lfs_files=True, use_auth_token=HF_TOKEN)
    links_file = output_folder / 'links.json'
    # Keep links from previous runs so a partial download does not lose LFS information
    lfs_files_old = []
    if links_file.exists():
        with open(links_file, 'r') as f:
            lfs_files_old = json.load(f)
    lfs_files_new, all_files = get_download_links(args.PATHS, output_folder, repo_type, remote_repo, branch)
    lfs_files = lfs_files_old + lfs_files_new
    with open(links_file, 'w') as f:
        json.dump(lfs_files, f)
    print(f"Resume downloading of {len(lfs_files_old)} LFS files")
    download_files(lfs_files_old, output_folder, args.threads)
    print(f"Found {len(lfs_files_new)} LFS files to download from {len(all_files)} matched files")
    download_files(lfs_files_new, output_folder, args.threads)