Commit 2009ed5e authored by zimenglan, committed by Francisco Massa

replacing all instances of torch.distributed.deprecated with torch.distributed (#248)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed
parent c2619ed4
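
The change itself is mechanical: every reference to the torch.distributed.deprecated namespace is rewritten to plain torch.distributed, which exposes the same process-group queries and collectives. A minimal sketch of the before/after call-site pattern (illustrative only; the helper name below is not from the diff):

# before (old THD-based namespace):
#   import torch.distributed.deprecated as dist
# after (c10d namespace, same query API):
import torch.distributed as dist

def world_size_or_one():
    # hypothetical helper showing the call-site pattern used throughout the diff
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
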
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-# Code is copy-pasted exactly as in torch.utils.data.distributed,
-# with a modification in the import to use the deprecated backend
+# Code is copy-pasted exactly as in torch.utils.data.distributed.
 # FIXME remove this once c10d fixes the bug it has
 import math
 import torch
-import torch.distributed.deprecated as dist
+import torch.distributed as dist
 from torch.utils.data.sampler import Sampler
...
@@ -65,8 +65,8 @@ def inference(
     # convert to a torch.device for efficiency
     device = torch.device(device)
     num_devices = (
-        torch.distributed.deprecated.get_world_size()
-        if torch.distributed.deprecated.is_initialized()
+        torch.distributed.get_world_size()
+        if torch.distributed.is_initialized()
         else 1
     )
     logger = logging.getLogger("maskrcnn_benchmark.inference")
...
@@ -4,7 +4,7 @@ import logging
 import time
 import torch
-from torch.distributed import deprecated as dist
+import torch.distributed as dist
 from maskrcnn_benchmark.utils.comm import get_world_size
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
...
@@ -13,21 +13,21 @@ import torch
 def get_world_size():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return 1
-    return torch.distributed.deprecated.get_world_size()
+    return torch.distributed.get_world_size()


 def get_rank():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return 0
-    return torch.distributed.deprecated.get_rank()
+    return torch.distributed.get_rank()


 def is_main_process():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return True
-    return torch.distributed.deprecated.get_rank() == 0
+    return torch.distributed.get_rank() == 0


 def synchronize():
@@ -35,10 +35,10 @@ def synchronize():
     Helper function to synchronize between multiple processes when
     using distributed training
     """
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return
-    world_size = torch.distributed.deprecated.get_world_size()
-    rank = torch.distributed.deprecated.get_rank()
+    world_size = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
     if world_size == 1:
         return
@@ -47,7 +47,7 @@ def synchronize():
             tensor = torch.tensor(0, device="cuda")
         else:
             tensor = torch.tensor(1, device="cuda")
-        torch.distributed.deprecated.broadcast(tensor, r)
+        torch.distributed.broadcast(tensor, r)
         while tensor.item() == 1:
             time.sleep(1)
@@ -103,11 +103,11 @@ def scatter_gather(data):
     # each process will then serialize the data to the folder defined by
     # the main process, and then the main process reads all of the serialized
     # files and returns them in a list
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return [data]
     synchronize()
     # get rank of the current process
-    rank = torch.distributed.deprecated.get_rank()
+    rank = torch.distributed.get_rank()
     # the data to communicate should be small
     data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda")
@@ -119,7 +119,7 @@ def scatter_gather(data):
     synchronize()
     # the main process (rank=0) communicates the data to all processes
-    torch.distributed.deprecated.broadcast(data_to_communicate, 0)
+    torch.distributed.broadcast(data_to_communicate, 0)
     # get the data that was communicated
     tmp_dir = _decode(data_to_communicate)
@@ -135,7 +135,7 @@ def scatter_gather(data):
     # only the master process returns the data
     if rank == 0:
         data_list = []
-        world_size = torch.distributed.deprecated.get_world_size()
+        world_size = torch.distributed.get_world_size()
         for r in range(world_size):
             file_path = os.path.join(tmp_dir, file_template.format(r))
             d = torch.load(file_path)
...
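
The comm helpers above all short-circuit when no process group has been initialized, so single-GPU runs take the return 1 / return 0 / return True fast paths without touching c10d. A small usage sketch, assuming maskrcnn_benchmark is installed (the logging call site is hypothetical):

from maskrcnn_benchmark.utils.comm import (
    get_rank, get_world_size, is_main_process, synchronize
)

def checkpoint_and_log(message):
    # hypothetical call site: wait for all ranks, then let only rank 0 write
    synchronize()
    if is_main_process():
        print("[rank {}/{}] {}".format(get_rank(), get_world_size(), message))
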
@@ -41,7 +41,7 @@ def main():
     if distributed:
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.deprecated.init_process_group(
+        torch.distributed.init_process_group(
             backend="nccl", init_method="env://"
         )
...
@@ -35,7 +35,7 @@ def train(cfg, local_rank, distributed):
     scheduler = make_lr_scheduler(cfg, optimizer)
     if distributed:
-        model = torch.nn.parallel.deprecated.DistributedDataParallel(
+        model = torch.nn.parallel.DistributedDataParallel(
             model, device_ids=[local_rank], output_device=local_rank,
             # this should be removed if we update BatchNorm stats
             broadcast_buffers=False,
@@ -136,7 +136,7 @@ def main():
     if args.distributed:
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.deprecated.init_process_group(
+        torch.distributed.init_process_group(
             backend="nccl", init_method="env://"
         )
...
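
Taken together, the launcher-facing scripts now follow the stock c10d recipe: pin the CUDA device from the local rank, initialize the NCCL process group from environment variables, and wrap the model in the non-deprecated DistributedDataParallel. A minimal end-to-end sketch under those assumptions (the linear model and argument parsing are placeholders, not the repo's config):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(8, 2).cuda()  # placeholder model
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    # matching the diff: buffers (e.g. BatchNorm stats) are not broadcast
    broadcast_buffers=False,
)

With init_method="env://", the MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE variables are expected to be set by the launcher, e.g. python -m torch.distributed.launch --nproc_per_node=<num_gpus> <train script>.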