Commit c5ca36fc authored by 夜阑听风's avatar 夜阑听风 Committed by Francisco Massa

use dist.barrier to synchronize (#393)

parent 5f2a8263
...@@ -40,22 +40,9 @@ def synchronize(): ...@@ -40,22 +40,9 @@ def synchronize():
if not dist.is_initialized(): if not dist.is_initialized():
return return
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank()
if world_size == 1: if world_size == 1:
return return
dist.barrier()
def _send_and_wait(r):
if rank == r:
tensor = torch.tensor(0, device="cuda")
else:
tensor = torch.tensor(1, device="cuda")
dist.broadcast(tensor, r)
while tensor.item() == 1:
time.sleep(1)
_send_and_wait(0)
# now sync on the main process
_send_and_wait(1)
def all_gather(data): def all_gather(data):
......
...@@ -44,6 +44,7 @@ def main(): ...@@ -44,6 +44,7 @@ def main():
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="nccl", init_method="env://" backend="nccl", init_method="env://"
) )
synchronize()
cfg.merge_from_file(args.config_file) cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts) cfg.merge_from_list(args.opts)
......
...@@ -139,6 +139,7 @@ def main(): ...@@ -139,6 +139,7 @@ def main():
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="nccl", init_method="env://" backend="nccl", init_method="env://"
) )
synchronize()
cfg.merge_from_file(args.config_file) cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts) cfg.merge_from_list(args.opts)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment