1414logger = logging .getLogger ("mmdet2trt" )
1515
1616
17- def model_test (test_folder ,
18- cfg_path ,
19- checkpoint ,
20- save_folder ,
21- opt_shape_param = None ,
22- max_workspace_size = 1 << 25 ,
23- device = "cuda:0" ,
24- score_thr = 0.3 ,
25- fp16 = True ,
26- enable_mask = False ):
27-
28- if not osp .exists (save_folder ):
29- os .mkdir (save_folder )
30- trt_model_path = osp .join (save_folder , 'trt_model.pth' )
31-
17+ def convert_test (cfg_path ,
18+ checkpoint ,
19+ trt_model_path ,
20+ opt_shape_param = None ,
21+ max_workspace_size = 1 << 25 ,
22+ device = "cuda:0" ,
23+ fp16 = True ,
24+ enable_mask = False ):
3225 logger .info ("creating {} trt model." .format (cfg_path ))
3326 trt_model = mmdet2trt (cfg_path ,
3427 checkpoint ,
@@ -39,9 +32,15 @@ def model_test(test_folder,
3932 enable_mask = enable_mask )
4033 logger .info ("finish, save trt_model in {}" .format (trt_model_path ))
4134 torch .save (trt_model .state_dict (), trt_model_path )
35+ return trt_model
4236
43- trt_model = init_detector (trt_model_path )
4437
38+ def inference_test (trt_model ,
39+ cfg_path ,
40+ device ,
41+ test_folder ,
42+ save_folder ,
43+ score_thr = 0.3 ):
4544 file_list = os .listdir (test_folder )
4645
4746 for file_name in tqdm .tqdm (file_list ):
@@ -74,6 +73,9 @@ def model_test(test_folder,
7473 cv2 .imwrite (osp .join (save_folder , file_name ), image )
7574
7675
76+ TEST_MODE_DICT = {'convert' : 1 , 'inference' : 1 << 1 , 'all' : 0b11 }
77+
78+
7779def main ():
7880 parser = ArgumentParser ()
7981 parser .add_argument ('test_folder' , help = 'folder contain test images' )
@@ -82,6 +84,10 @@ def main():
8284 parser .add_argument (
8385 'save_folder' ,
8486 help = 'tensorrt model and test images results save folder' )
87+ parser .add_argument ('--trt_model_path' ,
88+ default = '' ,
89+ help = 'save and inference model. '
90+ 'default [save_folder]/trt_model.pth' )
8591 parser .add_argument (
8692 '--opt_shape_param' ,
8793 default = '[ [ [1,3,800,800], [1,3,800,1344], [1,3,1344,1344] ] ]' ,
@@ -102,17 +108,40 @@ def main():
102108 parser .add_argument ('--enable_mask' ,
103109 action = 'store_true' ,
104110 help = "enable mask output" )
111+ parser .add_argument ('--test-mode' ,
112+ default = 'all' ,
113+ help = 'what to do in the test' ,
114+ choices = ['convert' , 'inference' , 'all' ])
105115 args = parser .parse_args ()
106116
107- model_test (args .test_folder ,
108- args .config ,
109- args .checkpoint ,
110- args .save_folder ,
111- opt_shape_param = eval (args .opt_shape_param ),
112- max_workspace_size = args .max_workspace_size ,
113- device = args .device ,
114- score_thr = args .score_thr ,
115- fp16 = args .fp16 )
117+ trt_model_path = args .trt_model_path
118+ if len (trt_model_path ) == 0 :
119+ trt_model_path = osp .join (args .save_folder , 'test_model.pth' )
120+
121+ if not osp .exists (args .save_folder ):
122+ os .mkdir (args .save_folder )
123+
124+ test_mode = TEST_MODE_DICT [args .test_mode ]
125+
126+ if test_mode & TEST_MODE_DICT ['convert' ] > 0 :
127+ convert_test (args .config ,
128+ args .checkpoint ,
129+ trt_model_path ,
130+ opt_shape_param = eval (args .opt_shape_param ),
131+ max_workspace_size = args .max_workspace_size ,
132+ device = args .device ,
133+ fp16 = args .fp16 )
134+ trt_model = init_detector (trt_model_path )
135+ else :
136+ trt_model = init_detector (trt_model_path )
137+
138+ if test_mode & TEST_MODE_DICT ['inference' ] > 0 :
139+ inference_test (trt_model ,
140+ args .config ,
141+ args .device ,
142+ args .test_folder ,
143+ args .save_folder ,
144+ score_thr = args .score_thr )
116145
117146
118147if __name__ == '__main__' :
0 commit comments