1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
- """Datasets"""
- import os
- import torch
- from torch.utils.data import DataLoader, Dataset
- from torchvision import datasets
- import torchvision.transforms as transforms
- import torchvision
- import glob
- import PIL
- import random
- import math
- import pickle
- import numpy as np
class CelebA(Dataset):
    """CelebA face-image dataset read from files matching a glob pattern.

    Each item is ``(image_tensor, 0)``; the second element is a constant
    placeholder label.
    """

    def __init__(self, dataset_path, img_size, **kwargs):
        """
        Args:
            dataset_path: Glob pattern matching the image files.
            img_size: Side length of the square output image.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If the pattern matches no files.
        """
        super().__init__()
        # glob.glob returns matches in arbitrary filesystem order; sort so the
        # index -> file mapping is identical in every process / rank.
        self.data = sorted(glob.glob(dataset_path))
        assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
        self.transform = transforms.Compose(
            [
                transforms.Resize(320),
                transforms.CenterCrop(256),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
                transforms.RandomHorizontalFlip(p=0.5),
                # interpolation=0 is PIL's NEAREST filter constant.
                transforms.Resize((img_size, img_size), interpolation=0),
            ]
        )

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.data)

    def __getitem__(self, index):
        """Load, convert to RGB, and transform the image at ``index``."""
        # Use a context manager so the file handle is released promptly, and
        # force 3-channel RGB so grayscale/RGBA files still collate with the
        # rest of a batch.
        with PIL.Image.open(self.data[index]) as raw:
            rgb = raw.convert("RGB")
        return self.transform(rgb), 0
class Cats(Dataset):
    """Cat-image dataset read from files matching a glob pattern.

    Each item is ``(image_tensor, 0)``; the second element is a constant
    placeholder label.
    """

    def __init__(self, dataset_path, img_size, **kwargs):
        """
        Args:
            dataset_path: Glob pattern matching the image files.
            img_size: Side length of the square output image.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If the pattern matches no files.
        """
        super().__init__()

        # glob.glob returns matches in arbitrary filesystem order; sort so the
        # index -> file mapping is identical in every process / rank.
        self.data = sorted(glob.glob(dataset_path))
        assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
        self.transform = transforms.Compose(
            [
                # interpolation=0 is PIL's NEAREST filter constant.
                transforms.Resize((img_size, img_size), interpolation=0),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.data)

    def __getitem__(self, index):
        """Load, convert to RGB, and transform the image at ``index``."""
        # Use a context manager so the file handle is released promptly, and
        # force 3-channel RGB so grayscale/RGBA files still collate with the
        # rest of a batch.
        with PIL.Image.open(self.data[index]) as raw:
            rgb = raw.convert("RGB")
        return self.transform(rgb), 0
class Carla(Dataset):
    """Carla image dataset read from files matching a glob pattern.

    Each item is ``(image_tensor, 0)``; the second element is a constant
    placeholder label. Unlike the other datasets here, no random flip is
    applied.
    """

    def __init__(self, dataset_path, img_size, **kwargs):
        """
        Args:
            dataset_path: Glob pattern matching the image files.
            img_size: Side length of the square output image.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If the pattern matches no files.
        """
        super().__init__()

        # glob.glob returns matches in arbitrary filesystem order; sort so the
        # index -> file mapping is identical in every process / rank.
        self.data = sorted(glob.glob(dataset_path))
        assert len(self.data) > 0, "Can't find data; make sure you specify the path to your dataset"
        self.transform = transforms.Compose(
            [
                # interpolation=0 is PIL's NEAREST filter constant.
                transforms.Resize((img_size, img_size), interpolation=0),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5: maps ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.data)

    def __getitem__(self, index):
        """Load, convert to RGB, and transform the image at ``index``."""
        # Use a context manager so the file handle is released promptly, and
        # force 3-channel RGB so grayscale/RGBA files (e.g. PNGs with alpha)
        # still collate with the rest of a batch.
        with PIL.Image.open(self.data[index]) as raw:
            rgb = raw.convert("RGB")
        return self.transform(rgb), 0
def get_dataset(name, subsample=None, batch_size=1, **kwargs):
    """Build a shuffling, non-distributed DataLoader for a named dataset class.

    Args:
        name: Name of a module-level dataset class (looked up in ``globals()``).
        subsample: Accepted for interface compatibility; not used.
        batch_size: Number of samples per batch.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        ``(dataloader, 3)``; the ``3`` is a fixed constant (presumably the
        image channel count - confirm against callers).

    Raises:
        KeyError: If ``name`` is not defined at module level.
    """
    dataset_cls = globals()[name]
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        # Incomplete trailing batches are discarded.
        drop_last=True,
        pin_memory=False,
        num_workers=8,
    )
    loader = torch.utils.data.DataLoader(dataset_cls(**kwargs), **loader_options)
    return loader, 3
def get_dataset_distributed(name, world_size, rank, batch_size, **kwargs):
    """Build a DataLoader that reads only this rank's shard of a named dataset.

    Args:
        name: Name of a module-level dataset class (looked up in ``globals()``).
        world_size: Total number of replicas the data is split across.
        rank: Index of the current replica.
        batch_size: Number of samples per batch on this rank.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        ``(dataloader, 3)``; the ``3`` is a fixed constant (presumably the
        image channel count - confirm against callers).

    Raises:
        KeyError: If ``name`` is not defined at module level.
    """
    dataset_cls = globals()[name]
    dataset = dataset_cls(**kwargs)
    # The sampler decides which indices this rank sees, so the loader itself
    # must not shuffle (DataLoader disallows sampler together with shuffle=True).
    # NOTE(review): callers presumably call sampler.set_epoch() per epoch - confirm.
    shard_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=shard_sampler,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
        num_workers=4,
    )
    return loader, 3
|