Commit f66c1d40 authored by benjaminrwilson, committed by Francisco Massa

Added distributed training check (#287)

parent a0d6edd9
@@ -13,18 +13,24 @@ import torch


 def get_world_size():
+    if not torch.distributed.is_available():
+        return 1
     if not torch.distributed.is_initialized():
         return 1
     return torch.distributed.get_world_size()


 def get_rank():
+    if not torch.distributed.is_available():
+        return 0
     if not torch.distributed.is_initialized():
         return 0
     return torch.distributed.get_rank()


 def is_main_process():
+    if not torch.distributed.is_available():
+        return True
     if not torch.distributed.is_initialized():
         return True
     return torch.distributed.get_rank() == 0
@@ -35,6 +41,8 @@ def synchronize():
     Helper function to synchronize between multiple processes when
     using distributed training
     """
+    if not torch.distributed.is_available():
+        return
     if not torch.distributed.is_initialized():
         return
     world_size = torch.distributed.get_world_size()
@@ -103,6 +111,8 @@ def scatter_gather(data):
     # each process will then serialize the data to the folder defined by
     # the main process, and then the main process reads all of the serialized
     # files and returns them in a list
+    if not torch.distributed.is_available():
+        return [data]
     if not torch.distributed.is_initialized():
         return [data]
     synchronize()
...
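For context, the sketch below is not part of the commit: it restates the guarded get_world_size() helper from the diff and shows why the extra is_available() check matters. On a PyTorch build compiled without distributed support, or before init_process_group() has been called, the helper falls back to single-process defaults instead of raising. The script body, variable names, and printed values are illustrative assumptions.

import torch


def get_world_size():
    # torch.distributed.is_available() is False when PyTorch was built
    # without distributed support; is_initialized() is False until
    # init_process_group() has been called. In both cases, fall back to
    # single-process behaviour instead of erroring.
    if not torch.distributed.is_available():
        return 1
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()


if __name__ == "__main__":
    # Safe to call in a plain single-process run as well as under a
    # distributed launcher; without the is_available() guard added in this
    # commit, a build lacking distributed support could fail here.
    world_size = get_world_size()
    loss = 4.0
    # Typical use: average a quantity over however many processes exist.
    print("world size:", world_size, "averaged loss:", loss / world_size)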