-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
36 lines (25 loc) · 821 Bytes
/
eval.py
File metadata and controls
36 lines (25 loc) · 821 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
torch.manual_seed(159753)
np.random.seed(159753)
from torchvision.utils import save_image
from utils import unloader
from modules import Generator
if __name__ == '__main__':
    # Script entry point: load a trained generator checkpoint, draw a batch of
    # samples from random noise, and write them to a single image file.

    # Local import: only this entry point needs it.
    import os

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config = {
        "noise_dim": 128,
        "gen_hidden": 512,
        "input_shape": (28, 28),
    }
    num_samples = 64  # number of images generated and written to the output

    generator = Generator(
        input_dim=config["noise_dim"],
        hidden_dim=config["gen_hidden"],
        # Flattened image size: product of the input_shape entries.
        output_dim=int(np.array(config["input_shape"]).prod()),
    ).to(device)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from a GPU run can still be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    generator.load_state_dict(
        torch.load('saved_models/generator_93700.pt', map_location=device)
    )
    generator.eval()

    # Draw the noise on the CPU and then move it: this keeps the values tied
    # to the CPU RNG seeded at the top of the file, whatever the device is.
    noise = torch.randn(num_samples, config["noise_dim"]).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        sample = generator(noise)

    # NOTE(review): unloader comes from utils; presumably it maps generator
    # output back to image value range -- confirm against utils.
    sample = unloader(sample)

    # save_image does not create missing directories.
    os.makedirs('results', exist_ok=True)
    save_image(
        sample.view(num_samples, 1, *config["input_shape"]),
        'results/output.png',
    )