Commit f66c1d40 authored by benjaminrwilson, committed by Francisco Massa

Added distributed training check (#287)

parent a0d6edd9
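
The added is_available() checks handle PyTorch builds that ship without distributed support at all (for example, some Windows/macOS wheels at the time); on those builds the existing is_initialized() check may itself fail, so the helpers now fall back to single-process defaults first. A small standalone illustration of the two conditions (hypothetical script, not part of this commit):

import torch

# is_available(): this PyTorch build was compiled with distributed support.
# is_initialized(): init_process_group() has been called in this process.
if not torch.distributed.is_available():
    print("distributed support not compiled in; use world_size=1, rank=0")
elif not torch.distributed.is_initialized():
    print("plain single-process run; use world_size=1, rank=0")
else:
    print("world size:", torch.distributed.get_world_size())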
@@ -13,18 +13,24 @@ import torch


 def get_world_size():
+    if not torch.distributed.is_available():
+        return 1
     if not torch.distributed.is_initialized():
         return 1
     return torch.distributed.get_world_size()


 def get_rank():
+    if not torch.distributed.is_available():
+        return 0
     if not torch.distributed.is_initialized():
         return 0
     return torch.distributed.get_rank()


 def is_main_process():
+    if not torch.distributed.is_available():
+        return True
     if not torch.distributed.is_initialized():
         return True
     return torch.distributed.get_rank() == 0
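
A typical use of these helpers is to restrict side effects such as checkpoint writing to the main process; with the new fallbacks the same code also runs under a plain, non-distributed invocation. A hypothetical usage sketch follows; the import path is an assumption about where this comm module lives, not something shown in the diff:

import torch

# Assumed import path for the helpers in the hunk above.
from maskrcnn_benchmark.utils.comm import get_rank, is_main_process

def save_checkpoint(model, path):
    # Only rank 0 touches the filesystem; in a non-distributed run
    # is_main_process() returns True, even on builds without
    # torch.distributed support.
    if is_main_process():
        torch.save(model.state_dict(), path)

print("running as rank", get_rank())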
@@ -35,6 +41,8 @@ def synchronize():
     Helper function to synchronize between multiple processes when
     using distributed training
     """
+    if not torch.distributed.is_available():
+        return
     if not torch.distributed.is_initialized():
         return
     world_size = torch.distributed.get_world_size()
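
The hunk is truncated after the world_size lookup. A sketch of how such a barrier helper would plausibly continue under these guards; this is an assumption about the rest of the function, not the verbatim file contents:

import torch

def synchronize():
    """Barrier across all workers; a no-op outside distributed training."""
    if not torch.distributed.is_available():
        return
    if not torch.distributed.is_initialized():
        return
    world_size = torch.distributed.get_world_size()
    if world_size == 1:
        # Nothing to wait for with a single process.
        return
    torch.distributed.barrier()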
@@ -103,6 +111,8 @@ def scatter_gather(data):
     # each process will then serialize the data to the folder defined by
     # the main process, and then the main process reads all of the serialized
     # files and returns them in a list
+    if not torch.distributed.is_available():
+        return [data]
     if not torch.distributed.is_initialized():
         return [data]
     synchronize()
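
The comments above describe a filesystem-based gather: the main process picks a folder, every rank serializes its data there, and the main process reads all the files back as a list. A rough sketch of that approach, assuming a directory visible to every rank is passed in and that the synchronize() helper shown earlier is in scope; how the real function communicates the folder chosen by the main process is not covered here:

import os
import torch

def scatter_gather_sketch(data, tmp_dir):
    # Non-distributed or uninitialized runs just wrap the local data.
    if not torch.distributed.is_available():
        return [data]
    if not torch.distributed.is_initialized():
        return [data]
    synchronize()  # barrier so every rank starts from the same point
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    # Each rank serializes its shard into the shared folder.
    torch.save(data, os.path.join(tmp_dir, "rank_{}.pth".format(rank)))
    synchronize()  # wait until every shard has been written
    if rank != 0:
        return []
    # The main process loads every shard and returns them as a list.
    return [torch.load(os.path.join(tmp_dir, "rank_{}.pth".format(r)))
            for r in range(world_size)]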