Commit b3cab7fc authored by zimenglan's avatar zimenglan Committed by Francisco Massa

add 'once_differentiable' for dcn and modify 'configs/cityscapes/README.md' (#701)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block' and use 'BaseStem'&'Bottleneck' to simply codes

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead

* agnostic-regression for bbox

* please set 'STRIDE_IN_1X1' to be 'False' when backbone use GN

* add README.md for GN

* add dcn from mmdetection

* add documentation for finetuning cityscapes

* add documentation for finetuning cityscapes

* add documentation for finetuning cityscapes

* add 'once_differentiable' for dcn and modify 'configs/cityscapes/README.md'
parent 66c3e56c
...@@ -195,10 +195,13 @@ def clip_weights_from_pretrain_of_coco_to_cityscapes(f, out_file): ...@@ -195,10 +195,13 @@ def clip_weights_from_pretrain_of_coco_to_cityscapes(f, out_file):
print("f: {}\nout_file: {}".format(f, out_file)) print("f: {}\nout_file: {}".format(f, out_file))
torch.save(m, out_file) torch.save(m, out_file)
``` ```
Step 3: modify the `input&solver` configuration in the `yaml` file, like this: Step 3: modify the `input&weight&solver` configuration in the `yaml` file, like this:
``` ```
MODEL:
WEIGHT: "xxx.pth" # the model u save from above code
INPUT: INPUT:
MIN_SIZE_TRAIN: (800, 832, 863, 896, 928, 960, 992, 1024, 1024) MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024, 1024)
MAX_SIZE_TRAIN: 2048 MAX_SIZE_TRAIN: 2048
MIN_SIZE_TEST: 1024 MIN_SIZE_TEST: 1024
MAX_SIZE_TEST: 2048 MAX_SIZE_TEST: 2048
...@@ -210,4 +213,5 @@ SOLVER: ...@@ -210,4 +213,5 @@ SOLVER:
STEPS: (3000,) STEPS: (3000,)
MAX_ITER: 4000 MAX_ITER: 4000
``` ```
Step 4: train the model.
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C from maskrcnn_benchmark import _C
...@@ -67,6 +68,7 @@ class DeformConvFunction(Function): ...@@ -67,6 +68,7 @@ class DeformConvFunction(Function):
return output return output
@staticmethod @staticmethod
@once_differentiable
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors input, offset, weight = ctx.saved_tensors
...@@ -201,6 +203,7 @@ class ModulatedDeformConvFunction(Function): ...@@ -201,6 +203,7 @@ class ModulatedDeformConvFunction(Function):
return output return output
@staticmethod @staticmethod
@once_differentiable
def backward(ctx, grad_output): def backward(ctx, grad_output):
if not grad_output.is_cuda: if not grad_output.is_cuda:
raise NotImplementedError raise NotImplementedError
......
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable
from maskrcnn_benchmark import _C from maskrcnn_benchmark import _C
...@@ -60,6 +61,7 @@ class DeformRoIPoolingFunction(Function): ...@@ -60,6 +61,7 @@ class DeformRoIPoolingFunction(Function):
return output return output
@staticmethod @staticmethod
@once_differentiable
def backward(ctx, grad_output): def backward(ctx, grad_output):
if not grad_output.is_cuda: if not grad_output.is_cuda:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment