Commit 946ba8d0 authored by Ashwin Bharambe, committed by Ashwin Bharambe

Refactor { train_net, utils/net }

Summary:
This diff starts a series of diffs to refactor the detectron codebase
so it can be used in an elastic data parallel context (see
https://our.intern.facebook.com/intern/dex/caffe2/elastic-data-parallel-model-for-distributed-training/)
without any forks.

Specifically, this diff does the following:

 - Splits out the `TrainingStats` class so it can be reused / composed
 - Slightly refactors `initialize_gpu_0_from_weights_file()` into `initialize_gpu_from_weights_file()`, which takes a `gpu_id` argument
 - Reduces the verbosity of some logs

Hopefully, nothing controversial here :)

Reviewed By: rbgirshick

Differential Revision: D6826820

fbshipit-source-id: fc15209a0ff50e5d09281e36173198c77aa77a12
parent 8c013e44
...@@ -188,12 +188,12 @@ def _compute_and_log_stats(roidb): ...@@ -188,12 +188,12 @@ def _compute_and_log_stats(roidb):
(entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0] (entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['gt_classes'][gt_inds] gt_classes = entry['gt_classes'][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0] gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
logger.info('Ground-truth class histogram:') logger.debug('Ground-truth class histogram:')
for i, v in enumerate(gt_hist): for i, v in enumerate(gt_hist):
logger.info( logger.debug(
'{:d}{:s}: {:d}'.format( '{:d}{:s}: {:d}'.format(
i, classes[i].rjust(char_len), v)) i, classes[i].rjust(char_len), v))
logger.info('-' * char_len) logger.debug('-' * char_len)
logger.info( logger.debug(
'{:s}: {:d}'.format( '{:s}: {:d}'.format(
'total'.rjust(char_len), np.sum(gt_hist))) 'total'.rjust(char_len), np.sum(gt_hist)))
...@@ -32,7 +32,6 @@ from ops.collect_and_distribute_fpn_rpn_proposals \ ...@@ -32,7 +32,6 @@ from ops.collect_and_distribute_fpn_rpn_proposals \
import CollectAndDistributeFpnRpnProposalsOp import CollectAndDistributeFpnRpnProposalsOp
from ops.generate_proposal_labels import GenerateProposalLabelsOp from ops.generate_proposal_labels import GenerateProposalLabelsOp
from ops.generate_proposals import GenerateProposalsOp from ops.generate_proposals import GenerateProposalsOp
from utils import lr_policy
import roi_data.fast_rcnn import roi_data.fast_rcnn
import utils.c2 as c2_utils import utils.c2 as c2_utils
...@@ -417,14 +416,13 @@ class DetectionModelHelper(cnn.CNNModelHelper): ...@@ -417,14 +416,13 @@ class DetectionModelHelper(cnn.CNNModelHelper):
self.use_cudnn = self.prev_use_cudnn self.use_cudnn = self.prev_use_cudnn
self.prev_use_cudnn = prev_use_cudnn self.prev_use_cudnn = prev_use_cudnn
def UpdateWorkspaceLr(self, cur_iter): def UpdateWorkspaceLr(self, cur_iter, new_lr):
"""Updates the model's current learning rate and the workspace (learning """Updates the model's current learning rate and the workspace (learning
rate and update history/momentum blobs). rate and update history/momentum blobs).
""" """
# The workspace is the one source of truth for the lr # The workspace is the one source of truth for the lr
# The lr is always the same on all GPUs # The lr is always the same on all GPUs
cur_lr = workspace.FetchBlob('gpu_0/lr')[0] cur_lr = workspace.FetchBlob('gpu_0/lr')[0]
new_lr = lr_policy.get_lr_at_iter(cur_iter)
# There are no type conversions between the lr in Python and the lr in # There are no type conversions between the lr in Python and the lr in
# the GPU (both are float32), so exact comparison is ok # the GPU (both are float32), so exact comparison is ok
if cur_lr != new_lr: if cur_lr != new_lr:
......
...@@ -169,7 +169,7 @@ def add_ResNet_roi_conv5_head_for_keypoints( ...@@ -169,7 +169,7 @@ def add_ResNet_roi_conv5_head_for_keypoints(
) )
# Using the prefix '_[pose]_' to 'res5' enables initializing the head's # Using the prefix '_[pose]_' to 'res5' enables initializing the head's
# parameters using pretrained 'res5' parameters if given (see # parameters using pretrained 'res5' parameters if given (see
# utils.net.initialize_gpu_0_from_weights_file) # utils.net.initialize_from_weights_file)
s, dim_in = ResNet.add_stage( s, dim_in = ResNet.add_stage(
model, model,
'_[pose]_res5', '_[pose]_res5',
......
...@@ -42,16 +42,19 @@ def initialize_from_weights_file(model, weights_file, broadcast=True): ...@@ -42,16 +42,19 @@ def initialize_from_weights_file(model, weights_file, broadcast=True):
multiple GPUs are used, the loaded weights are synchronized on all GPUs, multiple GPUs are used, the loaded weights are synchronized on all GPUs,
unless 'broadcast' is False. unless 'broadcast' is False.
""" """
initialize_gpu_0_from_weights_file(model, weights_file) initialize_gpu_from_weights_file(model, weights_file, gpu_id=0)
if broadcast: if broadcast:
broadcast_parameters(model) broadcast_parameters(model)
def initialize_gpu_0_from_weights_file(model, weights_file): def initialize_gpu_from_weights_file(model, weights_file, gpu_id=0):
"""Initialize a network with ops on GPU 0. Note that we always use GPU 0 and """Initialize a network with ops on a specific GPU.
rely on proper usage of CUDA_VISIBLE_DEVICES.
If you use CUDA_VISIBLE_DEVICES to target specific GPUs, Caffe2 will
automatically map logical GPU ids (starting from 0) to the physical GPUs
specified in CUDA_VISIBLE_DEVICES.
""" """
logger.info('Loading from: {}'.format(weights_file)) logger.info('Loading weights from: {}'.format(weights_file))
ws_blobs = workspace.Blobs() ws_blobs = workspace.Blobs()
with open(weights_file, 'r') as f: with open(weights_file, 'r') as f:
src_blobs = pickle.load(f) src_blobs = pickle.load(f)
...@@ -62,11 +65,11 @@ def initialize_gpu_0_from_weights_file(model, weights_file): ...@@ -62,11 +65,11 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
# Backwards compat--dictionary used to be only blobs, now they are # Backwards compat--dictionary used to be only blobs, now they are
# stored under the 'blobs' key # stored under the 'blobs' key
src_blobs = src_blobs['blobs'] src_blobs = src_blobs['blobs']
# Initialize weights on GPU 0 only # Initialize weights on GPU gpu_id only
unscoped_param_names = OrderedDict() # Print these out in model order unscoped_param_names = OrderedDict() # Print these out in model order
for blob in model.params: for blob in model.params:
unscoped_param_names[c2_utils.UnscopeName(str(blob))] = True unscoped_param_names[c2_utils.UnscopeName(str(blob))] = True
with c2_utils.NamedCudaScope(0): with c2_utils.NamedCudaScope(gpu_id):
for unscoped_param_name in unscoped_param_names.keys(): for unscoped_param_name in unscoped_param_names.keys():
if (unscoped_param_name.find(']_') >= 0 and if (unscoped_param_name.find(']_') >= 0 and
unscoped_param_name not in src_blobs): unscoped_param_name not in src_blobs):
...@@ -85,10 +88,12 @@ def initialize_gpu_0_from_weights_file(model, weights_file): ...@@ -85,10 +88,12 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
dst_name = core.ScopedName(unscoped_param_name) dst_name = core.ScopedName(unscoped_param_name)
has_momentum = src_name + '_momentum' in src_blobs has_momentum = src_name + '_momentum' in src_blobs
has_momentum_str = ' [+ momentum]' if has_momentum else '' has_momentum_str = ' [+ momentum]' if has_momentum else ''
logger.info('{:s}{:} loaded from weights file into {:s}: {}'. logger.debug(
format( '{:s}{:} loaded from weights file into {:s}: {}'.format(
src_name, has_momentum_str, src_name, has_momentum_str, dst_name, src_blobs[src_name]
dst_name, src_blobs[src_name].shape)) .shape
)
)
if dst_name in ws_blobs: if dst_name in ws_blobs:
# If the blob is already in the workspace, make sure that it # If the blob is already in the workspace, make sure that it
# matches the shape of the loaded blob # matches the shape of the loaded blob
...@@ -121,7 +126,7 @@ def initialize_gpu_0_from_weights_file(model, weights_file): ...@@ -121,7 +126,7 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
with c2_utils.CpuScope(): with c2_utils.CpuScope():
workspace.FeedBlob( workspace.FeedBlob(
'__preserve__/{:s}'.format(src_name), src_blobs[src_name]) '__preserve__/{:s}'.format(src_name), src_blobs[src_name])
logger.info( logger.debug(
'{:s} preserved in workspace (unused)'.format(src_name)) '{:s} preserved in workspace (unused)'.format(src_name))
......
#!/usr/bin/env python2
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Utilities for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import datetime
import numpy as np
from caffe2.python import utils as c2_py_utils
from core.config import cfg
from utils.logging import log_json_stats
from utils.logging import SmoothedValue
from utils.timer import Timer
import utils.net as nu
class TrainingStats(object):
    """Collect, smooth, and report per-iteration training statistics."""

    def __init__(self, model):
        """Create one tracker per loss/metric listed on `model`.

        `model` must expose `losses` and `metrics` sequences that can be
        concatenated with `+`. A reference to `model` is kept because
        `UpdateIterStats` reads from it on every call.
        """
        # Window size handed to every SmoothedValue tracker (their median
        # is what gets reported)
        self.WIN_SZ = 20
        # Statistics are logged once every LOG_PERIOD SGD iterations
        self.LOG_PERIOD = 20
        tracked_names = model.losses + model.metrics
        self.smoothed_losses_and_metrics = dict(
            (name, SmoothedValue(self.WIN_SZ)) for name in tracked_names
        )
        # Most recent raw (unsmoothed) value of each loss/metric
        self.losses_and_metrics = dict.fromkeys(tracked_names, 0)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
        # NaN until the first call to UpdateIterStats
        self.iter_total_loss = np.nan
        self.iter_timer = Timer()
        self.model = model

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop the iteration timer and return `toc(average=False)`."""
        elapsed = self.iter_timer.toc(average=False)
        return elapsed

    def ResetIterTimer(self):
        """Reset the iteration timer."""
        self.iter_timer.reset()

    def UpdateIterStats(self):
        """Refresh every tracked value from the current iteration."""
        model = self.model
        for name in list(self.losses_and_metrics):
            # Losses are summed over GPUs; everything else is averaged
            reduce_blob = (
                nu.sum_multi_gpu_blob
                if name in model.losses
                else nu.average_multi_gpu_blob
            )
            self.losses_and_metrics[name] = reduce_blob(name)
        for name, smoother in self.smoothed_losses_and_metrics.items():
            smoother.AddValue(self.losses_and_metrics[name])
        per_loss_values = np.array(
            [self.losses_and_metrics[name] for name in model.losses]
        )
        self.iter_total_loss = np.sum(per_loss_values)
        self.smoothed_total_loss.AddValue(self.iter_total_loss)
        # Record how many minibatches are waiting in the loader's queue
        queue_size = model.roi_data_loader._minibatch_queue.qsize()
        self.smoothed_mb_qsize.AddValue(queue_size)

    def LogIterStats(self, cur_iter, lr):
        """Log statistics every LOG_PERIOD iterations and on the last one."""
        if (cur_iter % self.LOG_PERIOD != 0 and
                cur_iter != cfg.SOLVER.MAX_ITER - 1):
            return
        log_json_stats(self.GetStats(cur_iter, lr))

    def GetStats(self, cur_iter, lr):
        """Build and return the dict of statistics for `cur_iter`.

        The dict holds the iteration number, learning rate, average
        iteration time, median-smoothed total loss, ETA string, rounded
        median minibatch queue size, peak GPU memory, and the
        median-smoothed value of every tracked loss and metric.
        """
        timer = self.iter_timer
        # Remaining time = average iteration time * iterations left
        eta_seconds = timer.average_time * (cfg.SOLVER.MAX_ITER - cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        gpu_mem = c2_py_utils.GetGPUMemoryUsageStats()
        # Peak usage over the first NUM_GPUS devices
        mem_usage = np.max(gpu_mem['max_by_gpu'][:cfg.NUM_GPUS])
        median_qsize = np.round(self.smoothed_mb_qsize.GetMedianValue())
        stats = dict(
            iter=cur_iter,
            lr=float(lr),
            time=timer.average_time,
            loss=self.smoothed_total_loss.GetMedianValue(),
            eta=eta,
            mb_qsize=int(median_qsize),
            # presumably bytes -> MB, rounded up; confirm the unit
            # reported by GetGPUMemoryUsageStats
            mem=int(np.ceil(mem_usage / 1024 / 1024))
        )
        stats.update(
            (name, smoother.GetMedianValue())
            for name, smoother in self.smoothed_losses_and_metrics.items()
        )
        return stats
...@@ -24,7 +24,6 @@ from __future__ import unicode_literals ...@@ -24,7 +24,6 @@ from __future__ import unicode_literals
import argparse import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2) import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import datetime
import logging import logging
import numpy as np import numpy as np
import os import os
...@@ -34,7 +33,6 @@ import sys ...@@ -34,7 +33,6 @@ import sys
import test_net import test_net
from caffe2.python import memonger from caffe2.python import memonger
from caffe2.python import utils as c2_py_utils
from caffe2.python import workspace from caffe2.python import workspace
from core.config import assert_and_infer_cfg from core.config import assert_and_infer_cfg
...@@ -44,10 +42,9 @@ from core.config import merge_cfg_from_file ...@@ -44,10 +42,9 @@ from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list from core.config import merge_cfg_from_list
from datasets.roidb import combined_roidb_for_training from datasets.roidb import combined_roidb_for_training
from modeling import model_builder from modeling import model_builder
from utils.logging import log_json_stats from utils import lr_policy
from utils.logging import setup_logging from utils.logging import setup_logging
from utils.logging import SmoothedValue from utils.training_stats import TrainingStats
from utils.timer import Timer
import utils.c2 import utils.c2
import utils.env as envu import utils.env as envu
import utils.net as nu import utils.net as nu
...@@ -95,80 +92,6 @@ def parse_args(): ...@@ -95,80 +92,6 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
class TrainingStats(object):
    """Collect, smooth, and log per-iteration training statistics."""

    def __init__(self, model):
        """Create one tracker per loss/metric listed on `model`.

        `model` must expose `losses` and `metrics` sequences that can be
        concatenated with `+`. A reference to `model` is kept because
        `UpdateIterStats` reads from it on every call.
        """
        # Window size handed to every SmoothedValue tracker (their median
        # is what gets reported)
        self.WIN_SZ = 20
        # Statistics are logged once every LOG_PERIOD SGD iterations
        self.LOG_PERIOD = 20
        tracked_names = model.losses + model.metrics
        self.smoothed_losses_and_metrics = dict(
            (name, SmoothedValue(self.WIN_SZ)) for name in tracked_names
        )
        # Most recent raw (unsmoothed) value of each loss/metric
        self.losses_and_metrics = dict.fromkeys(tracked_names, 0)
        self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
        self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
        # NaN until the first call to UpdateIterStats
        self.iter_total_loss = np.nan
        self.iter_timer = Timer()
        self.model = model

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop the iteration timer and return `toc(average=False)`."""
        elapsed = self.iter_timer.toc(average=False)
        return elapsed

    def ResetIterTimer(self):
        """Reset the iteration timer."""
        self.iter_timer.reset()

    def UpdateIterStats(self):
        """Refresh every tracked value from the current iteration."""
        model = self.model
        for name in list(self.losses_and_metrics):
            # Losses are summed over GPUs; everything else is averaged
            reduce_blob = (
                nu.sum_multi_gpu_blob
                if name in model.losses
                else nu.average_multi_gpu_blob
            )
            self.losses_and_metrics[name] = reduce_blob(name)
        for name, smoother in self.smoothed_losses_and_metrics.items():
            smoother.AddValue(self.losses_and_metrics[name])
        per_loss_values = np.array(
            [self.losses_and_metrics[name] for name in model.losses]
        )
        self.iter_total_loss = np.sum(per_loss_values)
        self.smoothed_total_loss.AddValue(self.iter_total_loss)
        # Record how many minibatches are waiting in the loader's queue
        queue_size = model.roi_data_loader._minibatch_queue.qsize()
        self.smoothed_mb_qsize.AddValue(queue_size)

    def LogIterStats(self, cur_iter, lr):
        """Log statistics every LOG_PERIOD iterations and on the last one.

        The logged dict holds the iteration number, learning rate, average
        iteration time, median-smoothed total loss, ETA string, rounded
        median minibatch queue size, peak GPU memory, and the
        median-smoothed value of every tracked loss and metric.
        """
        if (cur_iter % self.LOG_PERIOD != 0 and
                cur_iter != cfg.SOLVER.MAX_ITER - 1):
            return
        timer = self.iter_timer
        # Remaining time = average iteration time * iterations left
        eta_seconds = timer.average_time * (cfg.SOLVER.MAX_ITER - cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        gpu_mem = c2_py_utils.GetGPUMemoryUsageStats()
        # Peak usage over the first NUM_GPUS devices
        mem_usage = np.max(gpu_mem['max_by_gpu'][:cfg.NUM_GPUS])
        median_qsize = np.round(self.smoothed_mb_qsize.GetMedianValue())
        stats = dict(
            iter=cur_iter,
            lr=float(lr),
            time=timer.average_time,
            loss=self.smoothed_total_loss.GetMedianValue(),
            eta=eta,
            mb_qsize=int(median_qsize),
            # presumably bytes -> MB, rounded up; confirm the unit
            # reported by GetGPUMemoryUsageStats
            mem=int(np.ceil(mem_usage / 1024 / 1024))
        )
        stats.update(
            (name, smoother.GetMedianValue())
            for name, smoother in self.smoothed_losses_and_metrics.items()
        )
        log_json_stats(stats)
def main(): def main():
# Initialize C2 # Initialize C2
workspace.GlobalInit( workspace.GlobalInit(
...@@ -213,7 +136,7 @@ def train_model(): ...@@ -213,7 +136,7 @@ def train_model():
for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER): for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
training_stats.IterTic() training_stats.IterTic()
lr = model.UpdateWorkspaceLr(cur_iter) lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
workspace.RunNet(model.net.Proto().name) workspace.RunNet(model.net.Proto().name)
if cur_iter == start_iter: if cur_iter == start_iter:
nu.print_net(model) nu.print_net(model)
...@@ -309,7 +232,7 @@ def setup_model_for_training(model, output_dir): ...@@ -309,7 +232,7 @@ def setup_model_for_training(model, output_dir):
if cfg.TRAIN.WEIGHTS: if cfg.TRAIN.WEIGHTS:
# Override random weight initialization with weights from a saved model # Override random weight initialization with weights from a saved model
nu.initialize_gpu_0_from_weights_file(model, cfg.TRAIN.WEIGHTS) nu.initialize_gpu_from_weights_file(model, cfg.TRAIN.WEIGHTS, gpu_id=0)
# Even if we're randomly initializing we still need to synchronize # Even if we're randomly initializing we still need to synchronize
# parameters across GPUs # parameters across GPUs
nu.broadcast_parameters(model) nu.broadcast_parameters(model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment