Commit 946ba8d0 authored by Ashwin Bharambe, committed by Ashwin Bharambe

Refactor { train_net, utils/net }

Summary:
This diff starts a series of diffs to refactor the detectron codebase
so it can be used in an elastic data parallel context (see
https://our.intern.facebook.com/intern/dex/caffe2/elastic-data-parallel-model-for-distributed-training/)
without any forks.

Specifically, this diff does the following:

 - Splits out the `TrainingStats` class so it can be reused / composed
 - Slightly refactors `initialize_from_weights_file()`
 - Reduces the verbosity of some logs

Hopefully, nothing controversial here :)

Reviewed By: rbgirshick

Differential Revision: D6826820

fbshipit-source-id: fc15209a0ff50e5d09281e36173198c77aa77a12
parent 8c013e44
@@ -188,12 +188,12 @@ def _compute_and_log_stats(roidb):
(entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['gt_classes'][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
logger.info('Ground-truth class histogram:')
logger.debug('Ground-truth class histogram:')
for i, v in enumerate(gt_hist):
logger.info(
logger.debug(
'{:d}{:s}: {:d}'.format(
i, classes[i].rjust(char_len), v))
logger.info('-' * char_len)
logger.info(
logger.debug('-' * char_len)
logger.debug(
'{:s}: {:d}'.format(
'total'.rjust(char_len), np.sum(gt_hist)))
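For reference, the histogram these (now DEBUG-level) lines print is just an accumulated np.histogram over fixed class bins. A minimal, self-contained sketch; the class count, labels, and hist_bins construction are made up for illustration (Detectron derives them from the dataset):

    import numpy as np

    num_classes = 4                          # assumed; includes background class 0
    hist_bins = np.arange(num_classes + 1)   # assumed stand-in for hist_bins
    gt_hist = np.zeros(num_classes, dtype=np.int64)

    # Pretend these are gt_classes arrays from two roidb entries' foreground boxes.
    for gt_classes in (np.array([1, 1, 3]), np.array([2, 3, 3, 3])):
        gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]

    print(gt_hist)          # [0 2 1 4] -- one per-class log line each
    print(np.sum(gt_hist))  # 7 -- the 'total' line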
@@ -32,7 +32,6 @@ from ops.collect_and_distribute_fpn_rpn_proposals \
import CollectAndDistributeFpnRpnProposalsOp
from ops.generate_proposal_labels import GenerateProposalLabelsOp
from ops.generate_proposals import GenerateProposalsOp
from utils import lr_policy
import roi_data.fast_rcnn
import utils.c2 as c2_utils
@@ -417,14 +416,13 @@ class DetectionModelHelper(cnn.CNNModelHelper):
self.use_cudnn = self.prev_use_cudnn
self.prev_use_cudnn = prev_use_cudnn
def UpdateWorkspaceLr(self, cur_iter):
def UpdateWorkspaceLr(self, cur_iter, new_lr):
"""Updates the model's current learning rate and the workspace (learning
rate and update history/momentum blobs).
"""
# The workspace is the one source of truth for the lr
# The lr is always the same on all GPUs
cur_lr = workspace.FetchBlob('gpu_0/lr')[0]
new_lr = lr_policy.get_lr_at_iter(cur_iter)
# There are no type conversions between the lr in Python and the lr in
the GPU (both are float32), so exact comparison is ok
if cur_lr != new_lr:
......
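With the lr_policy import removed and new_lr added to the signature, the schedule lookup now lives at the call site. A sketch of the new calling convention, mirroring the train loop hunk further down and assuming the usual train_net.py surroundings (model, cfg, workspace, start_iter):

    from utils import lr_policy

    for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
        # The caller owns the schedule; UpdateWorkspaceLr only applies the value.
        lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
        workspace.RunNet(model.net.Proto().name)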
@@ -169,7 +169,7 @@ def add_ResNet_roi_conv5_head_for_keypoints(
)
# Using the prefix '_[pose]_' to 'res5' enables initializing the head's
# parameters using pretrained 'res5' parameters if given (see
# utils.net.initialize_gpu_0_from_weights_file)
# utils.net.initialize_from_weights_file)
s, dim_in = ResNet.add_stage(
model,
'_[pose]_res5',
......
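The renamed helper relies on the surgery-prefix convention this comment describes: a head parameter named '_[pose]_res5_...' can be initialized from a pretrained 'res5_...' blob by stripping everything up to ']_'. A rough, illustrative sketch of that mapping (the real matching code lives in utils/net.py and is only partly visible in this diff; the names below are examples, not real blobs):

    name = '_[pose]_res5_2_branch2a_w'
    if name.find(']_') >= 0:                   # same test the loader applies
        src_name = name[name.find(']_') + 2:]  # 'res5_2_branch2a_w'
        print('%s <- pretrained %s' % (name, src_name))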
@@ -42,16 +42,19 @@ def initialize_from_weights_file(model, weights_file, broadcast=True):
multiple GPUs are used, the loaded weights are synchronized on all GPUs,
unless 'broadcast' is False.
"""
initialize_gpu_0_from_weights_file(model, weights_file)
initialize_gpu_from_weights_file(model, weights_file, gpu_id=0)
if broadcast:
broadcast_parameters(model)
def initialize_gpu_0_from_weights_file(model, weights_file):
"""Initialize a network with ops on GPU 0. Note that we always use GPU 0 and
rely on proper usage of CUDA_VISIBLE_DEVICES.
def initialize_gpu_from_weights_file(model, weights_file, gpu_id=0):
"""Initialize a network with ops on a specific GPU.
If you use CUDA_VISIBLE_DEVICES to target specific GPUs, Caffe2 will
automatically map logical GPU ids (starting from 0) to the physical GPUs
specified in CUDA_VISIBLE_DEVICES.
"""
logger.info('Loading from: {}'.format(weights_file))
logger.info('Loading weights from: {}'.format(weights_file))
ws_blobs = workspace.Blobs()
with open(weights_file, 'r') as f:
src_blobs = pickle.load(f)
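After this hunk there are two entry points: the wrapper, which loads on GPU 0 and optionally broadcasts, and the gpu_id-parameterized worker, which is presumably what an elastic/per-replica setup would call directly. A usage sketch, assuming a built model and the usual cfg.TRAIN.WEIGHTS path:

    import utils.net as nu

    # 1) Load on GPU 0 and broadcast to the other GPUs (default behavior).
    nu.initialize_from_weights_file(model, cfg.TRAIN.WEIGHTS)

    # 2) Load directly onto a chosen logical GPU, no broadcast -- the new
    #    gpu_id parameter introduced by this diff.
    nu.initialize_gpu_from_weights_file(model, cfg.TRAIN.WEIGHTS, gpu_id=0)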
@@ -62,11 +65,11 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
# Backwards compat--dictionary used to be only blobs, now they are
# stored under the 'blobs' key
src_blobs = src_blobs['blobs']
# Initialize weights on GPU 0 only
# Initialize weights on GPU gpu_id only
unscoped_param_names = OrderedDict() # Print these out in model order
for blob in model.params:
unscoped_param_names[c2_utils.UnscopeName(str(blob))] = True
with c2_utils.NamedCudaScope(0):
with c2_utils.NamedCudaScope(gpu_id):
for unscoped_param_name in unscoped_param_names.keys():
if (unscoped_param_name.find(']_') >= 0 and
unscoped_param_name not in src_blobs):
@@ -85,10 +88,12 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
dst_name = core.ScopedName(unscoped_param_name)
has_momentum = src_name + '_momentum' in src_blobs
has_momentum_str = ' [+ momentum]' if has_momentum else ''
logger.info('{:s}{:} loaded from weights file into {:s}: {}'.
format(
src_name, has_momentum_str,
dst_name, src_blobs[src_name].shape))
logger.debug(
'{:s}{:} loaded from weights file into {:s}: {}'.format(
src_name, has_momentum_str, dst_name, src_blobs[src_name]
.shape
)
)
if dst_name in ws_blobs:
# If the blob is already in the workspace, make sure that it
# matches the shape of the loaded blob
@@ -121,7 +126,7 @@ def initialize_gpu_0_from_weights_file(model, weights_file):
with c2_utils.CpuScope():
workspace.FeedBlob(
'__preserve__/{:s}'.format(src_name), src_blobs[src_name])
logger.info(
logger.debug(
'{:s} preserved in workspace (unused)'.format(src_name))
......
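As far as these hunks show, the weights file is a pickled dict: either blob name -> array directly (the old format) or nested under a 'blobs' key, with optional '<name>_momentum' companions. A minimal writer sketch for the newer layout; the blob name, shape, and file name are made up:

    import cPickle as pickle  # Python 2, matching the loader's plain pickle.load
    import numpy as np

    blobs = {
        'conv1_w': np.zeros((64, 3, 7, 7), dtype=np.float32),           # hypothetical
        'conv1_w_momentum': np.zeros((64, 3, 7, 7), dtype=np.float32),  # optional
    }
    with open('example_weights.pkl', 'wb') as f:
        pickle.dump({'blobs': blobs}, f, pickle.HIGHEST_PROTOCOL)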
#!/usr/bin/env python2
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Utilities for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import datetime
import numpy as np
from caffe2.python import utils as c2_py_utils
from core.config import cfg
from utils.logging import log_json_stats
from utils.logging import SmoothedValue
from utils.timer import Timer
import utils.net as nu
class TrainingStats(object):
"""Track vital training statistics."""
def __init__(self, model):
# Window size for smoothing tracked values (with median filtering)
self.WIN_SZ = 20
# Output logging period in SGD iterations
self.LOG_PERIOD = 20
self.smoothed_losses_and_metrics = {
key: SmoothedValue(self.WIN_SZ)
for key in model.losses + model.metrics
}
self.losses_and_metrics = {
key: 0
for key in model.losses + model.metrics
}
self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
self.iter_total_loss = np.nan
self.iter_timer = Timer()
self.model = model
def IterTic(self):
self.iter_timer.tic()
def IterToc(self):
return self.iter_timer.toc(average=False)
def ResetIterTimer(self):
self.iter_timer.reset()
def UpdateIterStats(self):
"""Update tracked iteration statistics."""
for k in self.losses_and_metrics.keys():
if k in self.model.losses:
self.losses_and_metrics[k] = nu.sum_multi_gpu_blob(k)
else:
self.losses_and_metrics[k] = nu.average_multi_gpu_blob(k)
for k, v in self.smoothed_losses_and_metrics.items():
v.AddValue(self.losses_and_metrics[k])
self.iter_total_loss = np.sum(
np.array([self.losses_and_metrics[k] for k in self.model.losses])
)
self.smoothed_total_loss.AddValue(self.iter_total_loss)
self.smoothed_mb_qsize.AddValue(
self.model.roi_data_loader._minibatch_queue.qsize()
)
def LogIterStats(self, cur_iter, lr):
"""Log the tracked statistics."""
if (cur_iter % self.LOG_PERIOD == 0 or
cur_iter == cfg.SOLVER.MAX_ITER - 1):
stats = self.GetStats(cur_iter, lr)
log_json_stats(stats)
def GetStats(self, cur_iter, lr):
eta_seconds = self.iter_timer.average_time * (
cfg.SOLVER.MAX_ITER - cur_iter
)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
mem_stats = c2_py_utils.GetGPUMemoryUsageStats()
mem_usage = np.max(mem_stats['max_by_gpu'][:cfg.NUM_GPUS])
stats = dict(
iter=cur_iter,
lr=float(lr),
time=self.iter_timer.average_time,
loss=self.smoothed_total_loss.GetMedianValue(),
eta=eta,
mb_qsize=int(
np.round(self.smoothed_mb_qsize.GetMedianValue())
),
mem=int(np.ceil(mem_usage / 1024 / 1024))
)
for k, v in self.smoothed_losses_and_metrics.items():
stats[k] = v.GetMedianValue()
return stats
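The class above is driven from the training loop in train_net.py. The exact placement of the IterToc/UpdateIterStats/LogIterStats calls is not shown in this diff, but the natural usage, with model/workspace setup assumed, looks like:

    from utils.training_stats import TrainingStats
    from utils import lr_policy

    training_stats = TrainingStats(model)
    for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
        training_stats.IterTic()
        lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
        workspace.RunNet(model.net.Proto().name)
        training_stats.IterToc()
        training_stats.UpdateIterStats()
        training_stats.LogIterStats(cur_iter, lr)  # JSON stats every LOG_PERIOD iters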
@@ -24,7 +24,6 @@ from __future__ import unicode_literals
import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import datetime
import logging
import numpy as np
import os
@@ -34,7 +33,6 @@ import sys
import test_net
from caffe2.python import memonger
from caffe2.python import utils as c2_py_utils
from caffe2.python import workspace
from core.config import assert_and_infer_cfg
@@ -44,10 +42,9 @@ from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from datasets.roidb import combined_roidb_for_training
from modeling import model_builder
from utils.logging import log_json_stats
from utils import lr_policy
from utils.logging import setup_logging
from utils.logging import SmoothedValue
from utils.timer import Timer
from utils.training_stats import TrainingStats
import utils.c2
import utils.env as envu
import utils.net as nu
@@ -95,80 +92,6 @@ def parse_args():
return parser.parse_args()
class TrainingStats(object):
"""Track vital training statistics."""
def __init__(self, model):
# Window size for smoothing tracked values (with median filtering)
self.WIN_SZ = 20
# Output logging period in SGD iterations
self.LOG_PERIOD = 20
self.smoothed_losses_and_metrics = {
key: SmoothedValue(self.WIN_SZ)
for key in model.losses + model.metrics
}
self.losses_and_metrics = {
key: 0
for key in model.losses + model.metrics
}
self.smoothed_total_loss = SmoothedValue(self.WIN_SZ)
self.smoothed_mb_qsize = SmoothedValue(self.WIN_SZ)
self.iter_total_loss = np.nan
self.iter_timer = Timer()
self.model = model
def IterTic(self):
self.iter_timer.tic()
def IterToc(self):
return self.iter_timer.toc(average=False)
def ResetIterTimer(self):
self.iter_timer.reset()
def UpdateIterStats(self):
"""Update tracked iteration statistics."""
for k in self.losses_and_metrics.keys():
if k in self.model.losses:
self.losses_and_metrics[k] = nu.sum_multi_gpu_blob(k)
else:
self.losses_and_metrics[k] = nu.average_multi_gpu_blob(k)
for k, v in self.smoothed_losses_and_metrics.items():
v.AddValue(self.losses_and_metrics[k])
self.iter_total_loss = np.sum(
np.array([self.losses_and_metrics[k] for k in self.model.losses])
)
self.smoothed_total_loss.AddValue(self.iter_total_loss)
self.smoothed_mb_qsize.AddValue(
self.model.roi_data_loader._minibatch_queue.qsize()
)
def LogIterStats(self, cur_iter, lr):
"""Log the tracked statistics."""
if (cur_iter % self.LOG_PERIOD == 0 or
cur_iter == cfg.SOLVER.MAX_ITER - 1):
eta_seconds = self.iter_timer.average_time * (
cfg.SOLVER.MAX_ITER - cur_iter
)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
mem_stats = c2_py_utils.GetGPUMemoryUsageStats()
mem_usage = np.max(mem_stats['max_by_gpu'][:cfg.NUM_GPUS])
stats = dict(
iter=cur_iter,
lr=float(lr),
time=self.iter_timer.average_time,
loss=self.smoothed_total_loss.GetMedianValue(),
eta=eta,
mb_qsize=int(
np.round(self.smoothed_mb_qsize.GetMedianValue())
),
mem=int(np.ceil(mem_usage / 1024 / 1024))
)
for k, v in self.smoothed_losses_and_metrics.items():
stats[k] = v.GetMedianValue()
log_json_stats(stats)
def main():
# Initialize C2
workspace.GlobalInit(
@@ -213,7 +136,7 @@ def train_model():
for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
training_stats.IterTic()
lr = model.UpdateWorkspaceLr(cur_iter)
lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
workspace.RunNet(model.net.Proto().name)
if cur_iter == start_iter:
nu.print_net(model)
@@ -309,7 +232,7 @@ def setup_model_for_training(model, output_dir):
if cfg.TRAIN.WEIGHTS:
# Override random weight initialization with weights from a saved model
nu.initialize_gpu_0_from_weights_file(model, cfg.TRAIN.WEIGHTS)
nu.initialize_gpu_from_weights_file(model, cfg.TRAIN.WEIGHTS, gpu_id=0)
# Even if we're randomly initializing we still need to synchronize
# parameters across GPUs
nu.broadcast_parameters(model)
......
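broadcast_parameters itself is not shown in this diff. Conceptually, "synchronize parameters across GPUs" means copying each gpu_0-scoped parameter blob into the matching scope on the other GPUs. A hedged sketch of that idea only (not the real implementation; the helper name and loop are mine):

    from collections import OrderedDict
    from caffe2.python import workspace
    import utils.c2 as c2_utils

    def broadcast_from_gpu_0(model, num_gpus):
        """Copy every gpu_0 parameter blob to gpu_1 .. gpu_{num_gpus-1}."""
        # Dedupe names across GPU name scopes, like the weights loader above does.
        names = OrderedDict(
            (c2_utils.UnscopeName(str(p)), True) for p in model.params)
        for name in names:
            data = workspace.FetchBlob('gpu_0/{}'.format(name))
            for gpu_id in range(1, num_gpus):
                with c2_utils.NamedCudaScope(gpu_id):
                    workspace.FeedBlob('gpu_{}/{}'.format(gpu_id, name), data)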