Commit 9b53d15c authored by zimenglan, committed by Francisco Massa

use 'kaiming_uniform' to initialize resnet, disable gn after fc layer (#377)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block', and use 'BaseStem' & 'Bottleneck' to simplify the code

* modify the 'group_norm' and 'conv_with_kaiming_uniform' functions

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead
parent 7cdf122d
......@@ -244,6 +244,10 @@ class Bottleneck(nn.Module):
),
norm_func(out_channels),
)
for modules in [self.downsample,]:
for l in modules.modules():
if isinstance(l, Conv2d):
nn.init.kaiming_uniform_(l.weight, a=1)
if dilation > 1:
stride = 1 # reset to be 1
......@@ -280,6 +284,9 @@ class Bottleneck(nn.Module):
)
self.bn3 = norm_func(out_channels)
for l in [self.conv1, self.conv2, self.conv3,]:
nn.init.kaiming_uniform_(l.weight, a=1)
def forward(self, x):
identity = x
......@@ -314,6 +321,9 @@ class BaseStem(nn.Module):
)
self.bn1 = norm_func(out_channels)
for l in [self.conv1,]:
nn.init.kaiming_uniform_(l.weight, a=1)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
......
......@@ -77,7 +77,7 @@ def make_conv3x3(
return conv
def make_fc(dim_in, hidden_dim, use_gn):
def make_fc(dim_in, hidden_dim, use_gn=False):
'''
Caffe2 implementation uses XavierFill, which in fact
corresponds to kaiming_uniform_ in PyTorch
......
......@@ -33,6 +33,7 @@ class ResNet50Conv5ROIFeatureExtractor(nn.Module):
stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1,
stride_init=None,
res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS,
dilation=config.MODEL.RESNETS.RES5_DILATION
)
self.pooler = pooler
......@@ -131,7 +132,7 @@ class FPNXconv1fcFeatureExtractor(nn.Module):
input_size = conv_head_dim * resolution ** 2
representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
self.fc6 = make_fc(input_size, representation_size, use_gn)
self.fc6 = make_fc(input_size, representation_size, use_gn=False)
def forward(self, x, proposals):
x = self.pooler(x, proposals)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment