Commit 5386f3c5, authored by Maxim Berman and committed by Francisco Massa

Use dist.get_rank() instead of local_rank to detect master process (#40)

parent b5de47b7
...@@ -18,6 +18,12 @@ def get_world_size(): ...@@ -18,6 +18,12 @@ def get_world_size():
return torch.distributed.deprecated.get_world_size() return torch.distributed.deprecated.get_world_size()
def get_rank():
if not torch.distributed.deprecated.is_initialized():
return 0
return torch.distributed.deprecated.get_rank()
def is_main_process(): def is_main_process():
if not torch.distributed.deprecated.is_initialized(): if not torch.distributed.deprecated.is_initialized():
return True return True
......
...@@ -4,11 +4,11 @@ import os ...@@ -4,11 +4,11 @@ import os
import sys import sys
def setup_logger(name, save_dir, local_rank): def setup_logger(name, save_dir, distributed_rank):
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
# don't log results for the non-master process # don't log results for the non-master process
if local_rank > 0: if distributed_rank > 0:
return logger return logger
ch = logging.StreamHandler(stream=sys.stdout) ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG) ch.setLevel(logging.DEBUG)
......
...@@ -13,7 +13,7 @@ from maskrcnn_benchmark.engine.inference import inference ...@@ -13,7 +13,7 @@ from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logger import setup_logger from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir from maskrcnn_benchmark.utils.miscellaneous import mkdir
...@@ -50,7 +50,7 @@ def main(): ...@@ -50,7 +50,7 @@ def main():
cfg.freeze() cfg.freeze()
save_dir = "" save_dir = ""
logger = setup_logger("maskrcnn_benchmark", save_dir, args.local_rank) logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
logger.info("Using {} GPUs".format(num_gpus)) logger.info("Using {} GPUs".format(num_gpus))
logger.info(cfg) logger.info(cfg)
......
...@@ -20,7 +20,7 @@ from maskrcnn_benchmark.engine.trainer import do_train ...@@ -20,7 +20,7 @@ from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.imports import import_file from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir from maskrcnn_benchmark.utils.miscellaneous import mkdir
...@@ -46,7 +46,7 @@ def train(cfg, local_rank, distributed): ...@@ -46,7 +46,7 @@ def train(cfg, local_rank, distributed):
output_dir = cfg.OUTPUT_DIR output_dir = cfg.OUTPUT_DIR
save_to_disk = local_rank == 0 save_to_disk = get_rank() == 0
checkpointer = DetectronCheckpointer( checkpointer = DetectronCheckpointer(
cfg, model, optimizer, scheduler, output_dir, save_to_disk cfg, model, optimizer, scheduler, output_dir, save_to_disk
) )
...@@ -147,7 +147,7 @@ def main(): ...@@ -147,7 +147,7 @@ def main():
if output_dir: if output_dir:
mkdir(output_dir) mkdir(output_dir)
logger = setup_logger("maskrcnn_benchmark", output_dir, args.local_rank) logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
logger.info("Using {} GPUs".format(num_gpus)) logger.info("Using {} GPUs".format(num_gpus))
logger.info(args) logger.info(args)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment