import torchvision.models as models from torch import nn def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = models.resnet18(num_classes=512, **kwargs) return model def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = models.resnet34(num_classes=512, **kwargs) return model def resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = models.resnet50(num_classes=512, **kwargs) return model def resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = models.resnet101(num_classes=512, **kwargs) return model def resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = models.resnet152(num_classes=512, **kwargs) return model def sphere20(): return sphere20a() class sphere20a(nn.Module): def __init__(self): super(sphere20a, self).__init__() #input = B*3*112*96 self.conv1_1 = nn.Conv2d(3,64,3,2,1) #=>B*64*56*48 self.relu1_1 = nn.PReLU(64) self.conv1_2 = nn.Conv2d(64,64,3,1,1) self.relu1_2 = nn.PReLU(64) self.conv1_3 = nn.Conv2d(64,64,3,1,1) self.relu1_3 = nn.PReLU(64) self.conv2_1 = nn.Conv2d(64,128,3,2,1) #=>B*128*28*24 self.relu2_1 = nn.PReLU(128) self.conv2_2 = nn.Conv2d(128,128,3,1,1) self.relu2_2 = nn.PReLU(128) self.conv2_3 = nn.Conv2d(128,128,3,1,1) self.relu2_3 = nn.PReLU(128) self.conv2_4 = nn.Conv2d(128,128,3,1,1) #=>B*128*28*24 self.relu2_4 = nn.PReLU(128) self.conv2_5 = nn.Conv2d(128,128,3,1,1) self.relu2_5 = nn.PReLU(128) self.conv3_1 = nn.Conv2d(128,256,3,2,1) #=>B*256*14*12 self.relu3_1 = nn.PReLU(256) self.conv3_2 = nn.Conv2d(256,256,3,1,1) self.relu3_2 = nn.PReLU(256) self.conv3_3 = nn.Conv2d(256,256,3,1,1) self.relu3_3 = nn.PReLU(256) self.conv3_4 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 self.relu3_4 = nn.PReLU(256) self.conv3_5 = nn.Conv2d(256,256,3,1,1) self.relu3_5 = nn.PReLU(256) self.conv3_6 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 self.relu3_6 = nn.PReLU(256) self.conv3_7 = nn.Conv2d(256,256,3,1,1) self.relu3_7 = nn.PReLU(256) self.conv3_8 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 self.relu3_8 = nn.PReLU(256) self.conv3_9 = nn.Conv2d(256,256,3,1,1) self.relu3_9 = nn.PReLU(256) self.conv4_1 = nn.Conv2d(256,512,3,2,1) #=>B*512*7*6 self.relu4_1 = nn.PReLU(512) self.conv4_2 = nn.Conv2d(512,512,3,1,1) self.relu4_2 = nn.PReLU(512) self.conv4_3 = nn.Conv2d(512,512,3,1,1) self.relu4_3 = nn.PReLU(512) self.fc5 = nn.Linear(512*14*14,512) # ORIGINAL for 112x96: self.fc5 = nn.Linear(512*7*6,512) def forward(self, x): x = self.relu1_1(self.conv1_1(x)) x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x)))) x = self.relu2_1(self.conv2_1(x)) x = x + self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x)))) x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x)))) x = self.relu3_1(self.conv3_1(x)) x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x)))) x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x)))) x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x)))) x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x)))) x = self.relu4_1(self.conv4_1(x)) x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x)))) x = x.view(x.size(0),-1) x = self.fc5(x) return x