Commit 5f2a8263 authored by wat3rBro's avatar wat3rBro Committed by Francisco Massa

use all_gather to gather results from all gpus (#383)

parent 9b53d15c
...@@ -9,7 +9,7 @@ from tqdm import tqdm ...@@ -9,7 +9,7 @@ from tqdm import tqdm
from maskrcnn_benchmark.data.datasets.evaluation import evaluate from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process from ..utils.comm import is_main_process
from ..utils.comm import scatter_gather from ..utils.comm import all_gather
from ..utils.comm import synchronize from ..utils.comm import synchronize
...@@ -30,7 +30,7 @@ def compute_on_dataset(model, data_loader, device): ...@@ -30,7 +30,7 @@ def compute_on_dataset(model, data_loader, device):
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu): def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
all_predictions = scatter_gather(predictions_per_gpu) all_predictions = all_gather(predictions_per_gpu)
if not is_main_process(): if not is_main_process():
return return
# merge the list of dicts # merge the list of dicts
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
""" """
This file contains primitives for multi-gpu communication. This file contains primitives for multi-gpu communication.
This is useful when doing distributed training. This is useful when doing distributed training.
""" """
import os
import pickle import pickle
import tempfile
import time import time
import torch import torch
import torch.distributed as dist
def get_world_size(): def get_world_size():
if not torch.distributed.is_available(): if not dist.is_available():
return 1 return 1
if not torch.distributed.is_initialized(): if not dist.is_initialized():
return 1 return 1
return torch.distributed.get_world_size() return dist.get_world_size()
def get_rank(): def get_rank():
if not torch.distributed.is_available(): if not dist.is_available():
return 0 return 0
if not torch.distributed.is_initialized(): if not dist.is_initialized():
return 0 return 0
return torch.distributed.get_rank() return dist.get_rank()
def is_main_process(): def is_main_process():
if not torch.distributed.is_available(): return get_rank() == 0
return True
if not torch.distributed.is_initialized():
return True
return torch.distributed.get_rank() == 0
def synchronize(): def synchronize():
""" """
Helper function to synchronize between multiple processes when Helper function to synchronize (barrier) among all processes when
using distributed training using distributed training
""" """
if not torch.distributed.is_available(): if not dist.is_available():
return return
if not torch.distributed.is_initialized(): if not dist.is_initialized():
return return
world_size = torch.distributed.get_world_size() world_size = dist.get_world_size()
rank = torch.distributed.get_rank() rank = dist.get_rank()
if world_size == 1: if world_size == 1:
return return
...@@ -55,7 +49,7 @@ def synchronize(): ...@@ -55,7 +49,7 @@ def synchronize():
tensor = torch.tensor(0, device="cuda") tensor = torch.tensor(0, device="cuda")
else: else:
tensor = torch.tensor(1, device="cuda") tensor = torch.tensor(1, device="cuda")
torch.distributed.broadcast(tensor, r) dist.broadcast(tensor, r)
while tensor.item() == 1: while tensor.item() == 1:
time.sleep(1) time.sleep(1)
...@@ -64,94 +58,73 @@ def synchronize(): ...@@ -64,94 +58,73 @@ def synchronize():
_send_and_wait(1) _send_and_wait(1)
def _encode(encoded_data, data): def all_gather(data):
# gets a byte representation for the data """
encoded_bytes = pickle.dumps(data) Run all_gather on arbitrary picklable data (not necessarily tensors)
# convert this byte string into a byte tensor Args:
storage = torch.ByteStorage.from_buffer(encoded_bytes) data: any picklable object
tensor = torch.ByteTensor(storage).to("cuda") Returns:
# encoding: first byte is the size and then rest is the data list[data]: list of data gathered from each rank
s = tensor.numel() """
assert s <= 255, "Can't encode data greater than 255 bytes" world_size = get_world_size()
# put the encoded data in encoded_data if world_size == 1:
encoded_data[0] = s return [data]
encoded_data[1 : (s + 1)] = tensor
def _decode(encoded_data):
size = encoded_data[0]
encoded_tensor = encoded_data[1 : (size + 1)].to("cpu")
return pickle.loads(bytearray(encoded_tensor.tolist()))
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# TODO try to use tensor in shared-memory instead of serializing to disk # obtain Tensor size of each rank
# this involves getting the all_gather to work local_size = torch.IntTensor([tensor.numel()]).to("cuda")
def scatter_gather(data): size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
""" dist.all_gather(size_list, local_size)
This function gathers data from multiple processes, and returns them size_list = [int(size.item()) for size in size_list]
in a list, as they were obtained from each process. max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
This function is useful for retrieving data from multiple processes, data_list = []
when launching the code with torch.distributed.launch for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
Note: this function is slow and should not be used in tight loops, i.e., return data_list
do not use it in the training loop.
Arguments:
data: the object to be gathered from multiple processes.
It must be serializable
Returns: def reduce_dict(input_dict, average=True):
result (list): a list with as many elements as there are processes,
where each element i in the list corresponds to the data that was
gathered from the process of rank i.
""" """
# strategy: the main process creates a temporary directory, and communicates Args:
# the location of the temporary directory to all other processes. input_dict (dict): all the values will be reduced
# each process will then serialize the data to the folder defined by average (bool): whether to do average or sum
# the main process, and then the main process reads all of the serialized Reduce the values in the dictionary from all processes so that process with rank
# files and returns them in a list 0 has the averaged results. Returns a dict with the same fields as
if not torch.distributed.is_available(): input_dict, after reduction.
return [data] """
if not torch.distributed.is_initialized(): world_size = get_world_size()
return [data] if world_size < 2:
synchronize() return input_dict
# get rank of the current process with torch.no_grad():
rank = torch.distributed.get_rank() names = []
values = []
# the data to communicate should be small # sort the keys so that they are consistent across processes
data_to_communicate = torch.empty(256, dtype=torch.uint8, device="cuda") for k in sorted(input_dict.keys()):
if rank == 0: names.append(k)
# manually creates a temporary directory, that needs to be cleaned values.append(input_dict[k])
# afterwards values = torch.stack(values, dim=0)
tmp_dir = tempfile.mkdtemp() dist.reduce(values, dst=0)
_encode(data_to_communicate, tmp_dir) if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
synchronize() # world_size in this case
# the main process (rank=0) communicates the data to all processes values /= world_size
torch.distributed.broadcast(data_to_communicate, 0) reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
# get the data that was communicated
tmp_dir = _decode(data_to_communicate)
# each process serializes to a different file
file_template = "file{}.pth"
tmp_file = os.path.join(tmp_dir, file_template.format(rank))
torch.save(data, tmp_file)
# synchronize before loading the data
synchronize()
# only the master process returns the data
if rank == 0:
data_list = []
world_size = torch.distributed.get_world_size()
for r in range(world_size):
file_path = os.path.join(tmp_dir, file_template.format(r))
d = torch.load(file_path)
data_list.append(d)
# cleanup
os.remove(file_path)
# cleanup
os.rmdir(tmp_dir)
return data_list
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment