-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
144 lines (129 loc) · 6.82 KB
/
eval.py
File metadata and controls
144 lines (129 loc) · 6.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as thd
import core.dataset as cd
from tqdm import tqdm
from collections import defaultdict
import argparse
import gc
CKPT_DIR = os.path.join('checkpoints')

# Experiment directories under CKPT_DIR, listed once and sorted so the
# "most recent" defaults below are deterministic.
_ckpt_dirs = sorted(os.listdir(CKPT_DIR))
# Feature caches accepted for every ensemble member. The secondary prefixes
# used to lack '24mel256' although the primary one accepted it.
_cache_prefixes = ['mel256', 'wavelet', '44mel256', '24mel256']

# Collect arguments (if any)
parser = argparse.ArgumentParser()
# Cache prefix of the primary experiment
parser.add_argument('cache_prefix', nargs='?', type=str, choices=_cache_prefixes,
                    default='mel256', help="Mel spectrogram or wavelets.")
# Primary checkpoint directory (defaults to the last one in sorted order)
parser.add_argument('-dir', '--ckpt_dir', type=str, choices=_ckpt_dirs,
                    default=(_ckpt_dirs[-1] if _ckpt_dirs else None), help="Checkpoints dir.")
# Second checkpoint directory (defaults to the second-to-last one, when there is one)
parser.add_argument('-dir2', '--ckpt_dir2', type=str, choices=_ckpt_dirs,
                    default=(_ckpt_dirs[-2] if len(_ckpt_dirs) > 1 else None),
                    help="Second checkpoints dir.")
# Cache prefix of the second experiment
parser.add_argument('--cache_prefix2', type=str, choices=_cache_prefixes,
                    default='mel256', help="Mel spectrogram or wavelets.")
# Third checkpoint directory (optional)
parser.add_argument('-dir3', '--ckpt_dir3', type=str, choices=_ckpt_dirs,
                    default=None, help="Third checkpoints dir.")
# Cache prefix of the third experiment
parser.add_argument('--cache_prefix3', type=str, choices=_cache_prefixes,
                    default='mel256', help="Mel spectrogram or wavelets.")
# Fourth checkpoint directory (optional)
parser.add_argument('-dir4', '--ckpt_dir4', type=str, choices=_ckpt_dirs,
                    default=None, help="Fourth checkpoints dir.")
# Cache prefix of the fourth experiment
parser.add_argument('--cache_prefix4', type=str, choices=_cache_prefixes,
                    default='mel256', help="Mel spectrogram or wavelets.")
# Type of evaluation
parser.add_argument('-t', '--type', type=str, choices=['all', 'last', 'combine-last', 'combine-all'],
                    default='last', help="Type of experiment evaluation.")
# Batch size
parser.add_argument('-bs', '--batch_size', type=int, default=64, help='Batch size.')
# Number of processes
parser.add_argument('-nw', '--num_workers', type=int, default=6, help='Number of processes (workers).')
# Select device "cuda" for GPU or "cpu"
parser.add_argument('--device', type=str, default=("cuda" if torch.cuda.is_available() else "cpu"),
                    choices=['cuda', 'cpu'], help='Device to use. Choose "cuda" for GPU or "cpu".')
# Select GPU device
# NOTE(review): this value is parsed but nothing in this script reads it.
parser.add_argument('--gpu_device', type=int, default=None,
                    help='ID of a GPU to use when multiple GPUs are available.')
# Use multiple GPUs?
parser.add_argument('--multi_gpu', action='store_true', help='Flag whether to use all available GPUs.')
args = parser.parse_args()

# Fail with a usage message instead of an IndexError raised while the
# defaults were being computed.
if args.ckpt_dir is None:
    parser.error(f"no checkpoint directories found in '{CKPT_DIR}'")
if args.type.startswith('combine') and args.ckpt_dir2 is None:
    parser.error("combine-* evaluation needs a second checkpoint directory (-dir2)")
print(f"Loading snapshots from experiment: {args.ckpt_dir}")
# Class-index -> label lookup used when the submission is written.
idx2label = cd.SoundData(prevent_cache=True).idx2label
# Test data for the primary experiment's feature cache; rebuilt later if a
# run from another experiment uses a different cache prefix.
testset = cd.TestDset(cache_prefix=args.cache_prefix, num_processes=args.num_workers,
                      transform=cd.data_transforms[f'{args.cache_prefix}_test'])
testloader = thd.DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)
device = torch.device(args.device)
# The submission file is written into the primary experiment's directory.
RES_DIR = os.path.join(CKPT_DIR, args.ckpt_dir)


def _list_runs(ckpt_name):
    """Return the run directories under ``<CKPT_DIR>/<ckpt_name>/snaps``, sorted by name."""
    snaps_dir = os.path.join(CKPT_DIR, ckpt_name, 'snaps')
    return [os.path.join(snaps_dir, run_name) for run_name in sorted(os.listdir(snaps_dir))]


# `runs` and `prefixes` are parallel lists: prefixes[i] is the feature cache
# that the models stored in runs[i] expect.
runs = _list_runs(args.ckpt_dir)
prefixes = [args.cache_prefix] * len(runs)
active_prefix = args.cache_prefix
# NOTE(review): indentation was lost in the copy this was reconstructed from;
# the extra experiments are assumed to be used only for combine-* evaluation
# and to be independent of each other - confirm against the original file.
if args.type.startswith('combine'):
    extra_experiments = [
        (args.ckpt_dir2, args.cache_prefix2),
        (args.ckpt_dir3, args.cache_prefix3),
        (args.ckpt_dir4, args.cache_prefix4),
    ]
    for ckpt_name, prefix in extra_experiments:
        if ckpt_name is None:
            continue
        # List each snapshot directory once so both lists stay the same length.
        extra_runs = _list_runs(ckpt_name)
        runs += extra_runs
        prefixes += [prefix] * len(extra_runs)
is_ensemble = len(runs) > 1
def eval_model(loader, model, model_num):
    """Run ``model`` over ``loader`` and average its class probabilities per id.

    Args:
        loader: Sized iterable of ``(inputs, ids)`` batches. ``inputs`` is
            moved to the module-level ``device`` and cast to float; each
            element of ``ids`` must support ``.item()``.
        model: Callable mapping a batch of inputs to per-class scores; a
            softmax is taken over dimension 1 of its output.
        model_num: Label shown in the progress-bar description only.

    Returns:
        dict: ``id -> numpy array`` holding the mean of the softmax rows
        of all samples that share that id, with keys in ascending order.
    """
    predictions = defaultdict(list)
    # len(loader) is the exact number of batches. The previous
    # `len(dataset) // batch_size + 1` overshot by one whenever the dataset
    # size was a multiple of the batch size, and read the global batch size.
    pbar = tqdm(loader, total=len(loader), desc=f'Evaluation model {model_num}')
    with torch.no_grad():
        for inputs, ids in pbar:
            inputs = inputs.to(device).float()
            # Move the whole batch to the CPU once rather than row by row.
            probs = F.softmax(model(inputs), dim=1).cpu()
            for sample_id, row in zip(ids, probs):
                predictions[sample_id.item()].append(row)
    return {
        idx: torch.stack(predictions[idx], dim=0).numpy().mean(axis=0)
        for idx in sorted(predictions)
    }
# List of dictionaries, one per evaluated model (as returned by eval_model)
results = []
# The --type choices all end in either 'last' or 'all'.
only_last = args.type.endswith('last')
model_suffix = 'last.model' if only_last else '.model'
for split_num, run_dir in enumerate(runs):
    # Sorted so the evaluation order (and the log output) is reproducible.
    for model_num, mname in enumerate(sorted(os.listdir(run_dir))):
        if not mname.endswith(model_suffix):
            continue
        print(f"Evaluating model {mname}")
        # SECURITY: torch.load unpickles the file, so only trusted checkpoints
        # may be loaded. map_location + .to() keep the model on the same
        # device the inputs are sent to, e.g. for GPU checkpoints on a CPU host.
        model = torch.load(os.path.join(run_dir, mname), map_location=device)
        model = model.to(device)
        # Inference mode for every model; the 'all' path used to skip this.
        model.eval()
        if args.multi_gpu:
            model = nn.DataParallel(model)
        # Rebuild the test set when this run expects another feature cache
        # than the one currently loaded (previously done for 'last' only).
        if prefixes[split_num] != active_prefix:
            testset = cd.TestDset(cache_prefix=prefixes[split_num], num_processes=args.num_workers,
                                  transform=cd.data_transforms[f'{prefixes[split_num]}_test'])
            testloader = thd.DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                                        num_workers=args.num_workers)
            active_prefix = prefixes[split_num]
        # Progress-bar label: run index alone, or "run / model" when several
        # snapshots per run are evaluated.
        tag = split_num if only_last else f'{split_num} / {model_num}'
        results.append(eval_model(testloader, model, tag))
if not results:
    # Without this, results[0] below fails with a bare IndexError.
    raise RuntimeError(f"No snapshot matched evaluation type '{args.type}'; nothing to evaluate.")

# Dictionary of np.arrays: for every id, the class probabilities averaged
# over all evaluated models.
results = {k: np.array([dic[k] for dic in results]).mean(axis=0) for k in results[0]}

# One row per id: file name plus the three most probable labels, best first,
# joined by single spaces.
rows = []
for key, probs in results.items():
    top3 = probs.argsort()[-3:][::-1].tolist()
    rows.append({
        'fname': testset.idx2fname[key],
        'label': ' '.join(idx2label[i] for i in top3),
    })
subm = pd.DataFrame(rows, columns=['fname', 'label'])

# These three files always receive a fixed placeholder label.
# NOTE(review): presumably they are absent from the test set - confirm.
# Appending by concatenation cannot overwrite an existing row, which
# `subm.loc[subm.shape[0]]` could do when the id index is not contiguous.
placeholders = pd.DataFrame({
    'fname': ['0b0427e2.wav', '6ea0099f.wav', 'b39975f5.wav'],
    'label': 'Laughter Hi-Hat Flute',
})
subm = pd.concat([subm, placeholders], ignore_index=True)

# Sanity check on the submission size. A leftover pdb breakpoint used to stop
# the script here; warn instead so the evaluation results are still written.
EXPECTED_ROWS = 9400
if subm.shape[0] != EXPECTED_ROWS:
    print(f"WARNING: submission has {subm.shape[0]} rows, expected {EXPECTED_ROWS}")
subm.to_csv(os.path.join(RES_DIR, f'submission-{args.type}.csv'), index=False)