模型代码的一部分:
class DownsamplerBlock(nn.Module):
    """Halve the spatial resolution by concatenating a strided conv with a max-pool.

    The input goes through two parallel branches whose outputs are
    concatenated along dim 1 (channels):

    * a 3x3 convolution, stride 2, padding 1, producing ``noutput - ninput``
      channels;
    * a 2x2 max-pool, stride 2, which keeps the ``ninput`` input channels.

    The resulting ``noutput``-channel tensor is batch-normalised
    (``eps=1e-3``) and passed through a ReLU.

    Args:
        ninput: Number of channels of the input tensor.
        noutput: Number of channels of the output tensor. Must be at least
            ``ninput`` because the pooled branch alone already contributes
            ``ninput`` channels.

    Raises:
        ValueError: If ``noutput < ninput``.
    """

    def __init__(self, ninput, noutput):
        super().__init__()
        if noutput < ninput:
            raise ValueError(
                f"noutput ({noutput}) must be >= ninput ({ninput}): the "
                "max-pool branch already contributes ninput channels"
            )
        self.conv = nn.Conv2d(
            ninput, noutput - ninput, (3, 3), stride=2, padding=1, bias=True
        )
        # ceil_mode=True makes the pooled size ceil(H/2) x ceil(W/2), which
        # matches the stride-2 / padding-1 conv above. With the default floor
        # mode, an odd H or W gave the two branches different sizes and
        # torch.cat raised. For even sizes the result is unchanged, and
        # MaxPool2d has no parameters, so saved state_dicts remain compatible.
        self.pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        """Downsample ``input`` and return the activated feature map.

        Expects a batched 4-D ``(N, ninput, H, W)`` tensor: the branches are
        concatenated along dim 1, which is only the channel axis for batched
        input. Returns a ``(N, noutput, ceil(H/2), ceil(W/2))`` tensor with
        non-negative values (ReLU).
        """
        output = torch.cat([self.conv(input), self.pool(input)], 1)
        output = self.bn(output)
        return F.relu(output)
请登录后评论