Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

datasets.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. """Datasets"""
  2. import os
  3. import torch
  4. from torch.utils.data import DataLoader, Dataset
  5. from torchvision import datasets
  6. import torchvision.transforms as transforms
  7. import torchvision
  8. import glob
  9. import PIL
  10. import random
  11. import math
  12. import pickle
  13. import numpy as np
  14. class CelebA(Dataset):
  15. """CelelebA Dataset"""
  16. def __init__(self, dataset_path, img_size, **kwargs):
  17. super().__init__()
  18. self.data = glob.glob(dataset_path)
  19. assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
  20. self.transform = transforms.Compose(
  21. [transforms.Resize(320), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), transforms.RandomHorizontalFlip(p=0.5), transforms.Resize((img_size, img_size), interpolation=0)])
  22. def __len__(self):
  23. return len(self.data)
  24. def __getitem__(self, index):
  25. X = PIL.Image.open(self.data[index])
  26. X = self.transform(X)
  27. return X, 0
  28. class Cats(Dataset):
  29. """Cats Dataset"""
  30. def __init__(self, dataset_path, img_size, **kwargs):
  31. super().__init__()
  32. self.data = glob.glob(dataset_path)
  33. assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
  34. self.transform = transforms.Compose(
  35. [transforms.Resize((img_size, img_size), interpolation=0), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), transforms.RandomHorizontalFlip(p=0.5)])
  36. def __len__(self):
  37. return len(self.data)
  38. def __getitem__(self, index):
  39. X = PIL.Image.open(self.data[index])
  40. X = self.transform(X)
  41. return X, 0
  42. class Carla(Dataset):
  43. """Carla Dataset"""
  44. def __init__(self, dataset_path, img_size, **kwargs):
  45. super().__init__()
  46. self.data = glob.glob(dataset_path)
  47. assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
  48. self.transform = transforms.Compose(
  49. [transforms.Resize((img_size, img_size), interpolation=0), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
  50. def __len__(self):
  51. return len(self.data)
  52. def __getitem__(self, index):
  53. X = PIL.Image.open(self.data[index])
  54. X = self.transform(X)
  55. return X, 0
  56. def get_dataset(name, subsample=None, batch_size=1, **kwargs):
  57. dataset = globals()[name](**kwargs)
  58. dataloader = torch.utils.data.DataLoader(
  59. dataset,
  60. batch_size=batch_size,
  61. shuffle=True,
  62. drop_last=True,
  63. pin_memory=False,
  64. num_workers=8
  65. )
  66. return dataloader, 3
  67. def get_dataset_distributed(name, world_size, rank, batch_size, **kwargs):
  68. dataset = globals()[name](**kwargs)
  69. sampler = torch.utils.data.distributed.DistributedSampler(
  70. dataset,
  71. num_replicas=world_size,
  72. rank=rank,
  73. )
  74. dataloader = torch.utils.data.DataLoader(
  75. dataset,
  76. sampler=sampler,
  77. batch_size=batch_size,
  78. shuffle=False,
  79. drop_last=True,
  80. pin_memory=True,
  81. num_workers=4,
  82. )
  83. return dataloader, 3
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...