Commit 2009ed5e authored by zimenglan, committed by Francisco Massa

replacing all instances of torch.distributed.deprecated with torch.distributed (#248)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed
parent c2619ed4
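
For reference, here is a minimal, hypothetical launcher sketch of the non-deprecated API that this commit migrates to; the NCCL backend, the env:// init method, and the --local_rank argument are taken from the hunks below, while the standalone script itself and its names are illustrative only:

    # hypothetical sketch: initialize torch.distributed (instead of
    # torch.distributed.deprecated), one process per GPU, launched with
    # `python -m torch.distributed.launch --nproc_per_node=N this_script.py`
    import argparse

    import torch
    import torch.distributed as dist

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)  # filled in by the launcher
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

    # after initialization, the comm.py helpers patched below reduce to these calls
    print("rank {} of {}".format(dist.get_rank(), dist.get_world_size()))
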
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-# Code is copy-pasted exactly as in torch.utils.data.distributed,
-# with a modification in the import to use the deprecated backend
+# Code is copy-pasted exactly as in torch.utils.data.distributed.
 # FIXME remove this once c10d fixes the bug it has
 import math
 import torch
-import torch.distributed.deprecated as dist
+import torch.distributed as dist
 from torch.utils.data.sampler import Sampler
@@ -65,8 +65,8 @@ def inference(
     # convert to a torch.device for efficiency
     device = torch.device(device)
     num_devices = (
-        torch.distributed.deprecated.get_world_size()
-        if torch.distributed.deprecated.is_initialized()
+        torch.distributed.get_world_size()
+        if torch.distributed.is_initialized()
         else 1
     )
     logger = logging.getLogger("maskrcnn_benchmark.inference")
@@ -4,7 +4,7 @@ import logging
 import time
 import torch
-from torch.distributed import deprecated as dist
+import torch.distributed as dist
 from maskrcnn_benchmark.utils.comm import get_world_size
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
@@ -13,21 +13,21 @@ import torch
 def get_world_size():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return 1
-    return torch.distributed.deprecated.get_world_size()
+    return torch.distributed.get_world_size()
 def get_rank():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return 0
-    return torch.distributed.deprecated.get_rank()
+    return torch.distributed.get_rank()
 def is_main_process():
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return True
-    return torch.distributed.deprecated.get_rank() == 0
+    return torch.distributed.get_rank() == 0
 def synchronize():
@@ -35,10 +35,10 @@ def synchronize():
     Helper function to synchronize between multiple processes when
     using distributed training
     """
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return
-    world_size = torch.distributed.deprecated.get_world_size()
-    rank = torch.distributed.deprecated.get_rank()
+    world_size = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
     if world_size == 1:
         return
@@ -47,7 +47,7 @@ def synchronize():
             tensor = torch.tensor(0, device="cuda")
         else:
             tensor = torch.tensor(1, device="cuda")
-        torch.distributed.deprecated.broadcast(tensor, r)
+        torch.distributed.broadcast(tensor, r)
         while tensor.item() == 1:
             time.sleep(1)
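
The broadcast-and-poll loop above is the module's own barrier: one rank broadcasts a flag tensor and the other ranks sleep until they receive it. As a point of comparison only (this is not what the patch uses), the non-deprecated backend also exposes torch.distributed.barrier(), which blocks every rank directly; a minimal sketch with a hypothetical helper name:

    # hypothetical alternative to the broadcast/sleep loop above;
    # dist.barrier() blocks each rank until all ranks have entered it
    import torch.distributed as dist

    def synchronize_with_barrier():
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return
        dist.barrier()
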
@@ -103,11 +103,11 @@ def scatter_gather(data):
     # each process will then serialize the data to the folder defined by
     # the main process, and then the main process reads all of the serialized
     # files and returns them in a list
-    if not torch.distributed.deprecated.is_initialized():
+    if not torch.distributed.is_initialized():
         return [data]
     synchronize()
     # get rank of the current process
-    rank = torch.distributed.deprecated.get_rank()
+    rank = torch.distributed.get_rank()
     # the data to communicate should be small
     data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda")
@@ -119,7 +119,7 @@ def scatter_gather(data):
     synchronize()
     # the main process (rank=0) communicates the data to all processes
-    torch.distributed.deprecated.broadcast(data_to_communicate, 0)
+    torch.distributed.broadcast(data_to_communicate, 0)
     # get the data that was communicated
     tmp_dir = _decode(data_to_communicate)
@@ -135,7 +135,7 @@ def scatter_gather(data):
     # only the master process returns the data
     if rank == 0:
         data_list = []
-        world_size = torch.distributed.deprecated.get_world_size()
+        world_size = torch.distributed.get_world_size()
         for r in range(world_size):
             file_path = os.path.join(tmp_dir, file_template.format(r))
             d = torch.load(file_path)
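
The comments above describe the gather protocol: the main process broadcasts a temporary-directory path, every rank serializes its data into that directory, and rank 0 reads the files back. In newer PyTorch releases (1.8 and later, so well after this commit), torch.distributed.all_gather_object covers the same use case without temporary files; a hypothetical sketch, not part of this patch, with the difference that every rank receives the gathered list rather than only the master:

    # hypothetical file-free variant of the scatter_gather pattern, assuming
    # PyTorch >= 1.8 where all_gather_object is available
    import torch.distributed as dist

    def gather_to_all(data):
        if not dist.is_initialized():
            return [data]
        output = [None] * dist.get_world_size()
        dist.all_gather_object(output, data)  # gathered list lands on every rank
        return output
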
@@ -41,7 +41,7 @@ def main():
     if distributed:
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.deprecated.init_process_group(
+        torch.distributed.init_process_group(
             backend="nccl", init_method="env://"
         )
@@ -35,7 +35,7 @@ def train(cfg, local_rank, distributed):
     scheduler = make_lr_scheduler(cfg, optimizer)
     if distributed:
-        model = torch.nn.parallel.deprecated.DistributedDataParallel(
+        model = torch.nn.parallel.DistributedDataParallel(
             model, device_ids=[local_rank], output_device=local_rank,
             # this should be removed if we update BatchNorm stats
             broadcast_buffers=False,
@@ -136,7 +136,7 @@ def main():
     if args.distributed:
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.deprecated.init_process_group(
+        torch.distributed.init_process_group(
             backend="nccl", init_method="env://"
         )