Skip to content

Commit 7d49d2a

Browse files
committed
split convert_test and inference_test
1 parent aac0505 commit 7d49d2a

1 file changed

Lines changed: 54 additions & 25 deletions

File tree

tests/model_test.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,14 @@
1414
logger = logging.getLogger("mmdet2trt")
1515

1616

17-
def model_test(test_folder,
18-
cfg_path,
19-
checkpoint,
20-
save_folder,
21-
opt_shape_param=None,
22-
max_workspace_size=1 << 25,
23-
device="cuda:0",
24-
score_thr=0.3,
25-
fp16=True,
26-
enable_mask=False):
27-
28-
if not osp.exists(save_folder):
29-
os.mkdir(save_folder)
30-
trt_model_path = osp.join(save_folder, 'trt_model.pth')
31-
17+
def convert_test(cfg_path,
18+
checkpoint,
19+
trt_model_path,
20+
opt_shape_param=None,
21+
max_workspace_size=1 << 25,
22+
device="cuda:0",
23+
fp16=True,
24+
enable_mask=False):
3225
logger.info("creating {} trt model.".format(cfg_path))
3326
trt_model = mmdet2trt(cfg_path,
3427
checkpoint,
@@ -39,9 +32,15 @@ def model_test(test_folder,
3932
enable_mask=enable_mask)
4033
logger.info("finish, save trt_model in {}".format(trt_model_path))
4134
torch.save(trt_model.state_dict(), trt_model_path)
35+
return trt_model
4236

43-
trt_model = init_detector(trt_model_path)
4437

38+
def inference_test(trt_model,
39+
cfg_path,
40+
device,
41+
test_folder,
42+
save_folder,
43+
score_thr=0.3):
4544
file_list = os.listdir(test_folder)
4645

4746
for file_name in tqdm.tqdm(file_list):
@@ -74,6 +73,9 @@ def model_test(test_folder,
7473
cv2.imwrite(osp.join(save_folder, file_name), image)
7574

7675

76+
TEST_MODE_DICT = {'convert': 1, 'inference': 1 << 1, 'all': 0b11}
77+
78+
7779
def main():
7880
parser = ArgumentParser()
7981
parser.add_argument('test_folder', help='folder contain test images')
@@ -82,6 +84,10 @@ def main():
8284
parser.add_argument(
8385
'save_folder',
8486
help='tensorrt model and test images results save folder')
87+
parser.add_argument('--trt_model_path',
88+
default='',
89+
help='save and inference model. '
90+
'default [save_folder]/trt_model.pth')
8591
parser.add_argument(
8692
'--opt_shape_param',
8793
default='[ [ [1,3,800,800], [1,3,800,1344], [1,3,1344,1344] ] ]',
@@ -102,17 +108,40 @@ def main():
102108
parser.add_argument('--enable_mask',
103109
action='store_true',
104110
help="enable mask output")
111+
parser.add_argument('--test-mode',
112+
default='all',
113+
help='what to do in the test',
114+
choices=['convert', 'inference', 'all'])
105115
args = parser.parse_args()
106116

107-
model_test(args.test_folder,
108-
args.config,
109-
args.checkpoint,
110-
args.save_folder,
111-
opt_shape_param=eval(args.opt_shape_param),
112-
max_workspace_size=args.max_workspace_size,
113-
device=args.device,
114-
score_thr=args.score_thr,
115-
fp16=args.fp16)
117+
trt_model_path = args.trt_model_path
118+
if len(trt_model_path) == 0:
119+
trt_model_path = osp.join(args.save_folder, 'test_model.pth')
120+
121+
if not osp.exists(args.save_folder):
122+
os.mkdir(args.save_folder)
123+
124+
test_mode = TEST_MODE_DICT[args.test_mode]
125+
126+
if test_mode & TEST_MODE_DICT['convert'] > 0:
127+
convert_test(args.config,
128+
args.checkpoint,
129+
trt_model_path,
130+
opt_shape_param=eval(args.opt_shape_param),
131+
max_workspace_size=args.max_workspace_size,
132+
device=args.device,
133+
fp16=args.fp16)
134+
trt_model = init_detector(trt_model_path)
135+
else:
136+
trt_model = init_detector(trt_model_path)
137+
138+
if test_mode & TEST_MODE_DICT['inference'] > 0:
139+
inference_test(trt_model,
140+
args.config,
141+
args.device,
142+
args.test_folder,
143+
args.save_folder,
144+
score_thr=args.score_thr)
116145

117146

118147
if __name__ == '__main__':

0 commit comments

Comments
 (0)