Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

feature_loader.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
  1. import torch
  2. import numpy as np
  3. import h5py
  4. class SimpleHDF5Dataset:
  5. def __init__(self, file_handle = None, depth=False):
  6. self.depth = depth
  7. if file_handle == None:
  8. self.f = ''
  9. self.all_feats_dset = []
  10. self.all_labels = []
  11. self.total = 0
  12. else:
  13. self.f = file_handle
  14. self.all_feats_dset = self.f['all_feats'][...]
  15. self.all_labels = self.f['all_labels'][...]
  16. # if depth:
  17. # self.all_feats_depth = self.f['all_feats_depth'][...]
  18. self.total = self.f['count'][0]
  19. # print('here')
  20. def __getitem__(self, i):
  21. if self.depth:
  22. return torch.Tensor(self.all_feats_dset[i,:]), torch.Tensor(self.all_feats_depth[i,:]), int(self.all_labels[i])
  23. else:
  24. return torch.Tensor(self.all_feats_dset[i,:]), int(self.all_labels[i])
  25. def __len__(self):
  26. return self.total
  27. def init_loader(filename, depth=False):
  28. with h5py.File(filename, 'r') as f:
  29. fileset = SimpleHDF5Dataset(f, depth=depth)
  30. #labels = [ l for l in fileset.all_labels if l != 0]
  31. feats = fileset.all_feats_dset
  32. labels = fileset.all_labels
  33. # if depth:
  34. # feats_depth = fileset.all_feats_depth
  35. while np.sum(feats[-1]) == 0:
  36. feats = np.delete(feats,-1,axis = 0)
  37. labels = np.delete(labels,-1,axis = 0)
  38. # if depth:
  39. # feats_depth = np.delete(feats_depth,-1,axis = 0)
  40. class_list = np.unique(np.array(labels)).tolist()
  41. inds = range(len(labels))
  42. cl_data_file = {}
  43. for cl in class_list:
  44. cl_data_file[cl] = []
  45. for ind in inds:
  46. # if depth:
  47. # cl_data_file[labels[ind]].append((feats[ind],feats_depth[ind]))
  48. # else:
  49. cl_data_file[labels[ind]].append(feats[ind])
  50. return cl_data_file
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...