diff --git a/model.py b/model.py index e6bd592..8bf8cf2 100644 --- a/model.py +++ b/model.py @@ -92,21 +92,21 @@ def load_encoder_arch(c, L): # if L >= 3: encoder.layer2.register_forward_hook(get_activation(pool_layers[pool_cnt])) - if 'wide' in c.enc_arch: + if ('wide' or 'resnet50') in c.enc_arch: pool_dims.append(encoder.layer2[-1].conv3.out_channels) else: pool_dims.append(encoder.layer2[-1].conv2.out_channels) pool_cnt = pool_cnt + 1 if L >= 2: encoder.layer3.register_forward_hook(get_activation(pool_layers[pool_cnt])) - if 'wide' in c.enc_arch: + if ('wide' or 'resnet50') in c.enc_arch: pool_dims.append(encoder.layer3[-1].conv3.out_channels) else: pool_dims.append(encoder.layer3[-1].conv2.out_channels) pool_cnt = pool_cnt + 1 if L >= 1: encoder.layer4.register_forward_hook(get_activation(pool_layers[pool_cnt])) - if 'wide' in c.enc_arch: + if ('wide' or 'resnet50') in c.enc_arch: pool_dims.append(encoder.layer4[-1].conv3.out_channels) else: pool_dims.append(encoder.layer4[-1].conv2.out_channels)