-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
21 lines (13 loc) · 736 Bytes
/
utils.py
File metadata and controls
21 lines (13 loc) · 736 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import os
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
def _to_symmetric_range(t):
    """Affinely map *t* from [0, 1] to [-1, 1] (computes ``2 * t - 1``)."""
    return (t * 2) - 1


def _to_unit_range(t):
    """Affinely map *t* from [-1, 1] back to [0, 1] (inverse of ``_to_symmetric_range``)."""
    return (t + 1) / 2


# Input transform for the MNIST datasets: ToTensor converts the PIL image to a
# float tensor (scaled to [0, 1] for uint8/PIL input), which is then rescaled
# to [-1, 1].
# Named module-level functions are used instead of lambdas so the transform can
# be pickled; DataLoader workers started with the "spawn" method (Windows,
# macOS default) pickle the dataset together with its transform.
loader = T.Compose([T.ToTensor(),
                    T.Lambda(_to_symmetric_range)])
# Inverse of the rescaling step in `loader`: maps [-1, 1] back to [0, 1].
unloader = T.Lambda(_to_unit_range)
def get_loaders(config):
    """Build the MNIST train and test DataLoaders.

    Both splits are downloaded (if not already present) into the dataset
    directory and transformed with the module-level ``loader`` transform.

    Args:
        config: Mapping with the keys
            ``"bs"`` -- batch size used by both loaders;
            ``"num_workers"`` -- number of DataLoader worker processes;
            ``"data_root"`` (optional) -- dataset directory, default ``"data/"``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles
        every epoch; the test loader iterates in dataset order. Both use
        ``drop_last=True``, so every batch has exactly ``config["bs"]``
        samples.
    """
    root = config.get("data_root", "data/")
    train_data = MNIST(root=root, train=True, download=True, transform=loader)
    test_data = MNIST(root=root, train=False, download=True, transform=loader)

    # Settings shared by both loaders.
    common = {
        "batch_size": config["bs"],
        "num_workers": config["num_workers"],
        "drop_last": True,
    }
    train_loader = DataLoader(train_data, shuffle=True, **common)
    # NOTE: drop_last=True also applies to evaluation, so the final partial
    # batch (up to bs-1 test samples) is never yielded.
    test_loader = DataLoader(test_data, **common)
    return train_loader, test_loader