Commit 2083f9d1 authored by Jimmy Wu's avatar Jimmy Wu Committed by Francisco Massa

Fix iteration count (#145)

parent 80eae227
...@@ -55,6 +55,7 @@ def do_train( ...@@ -55,6 +55,7 @@ def do_train(
end = time.time() end = time.time()
for iteration, (images, targets, _) in enumerate(data_loader, start_iter): for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
data_time = time.time() - end data_time = time.time() - end
iteration = iteration + 1
arguments["iteration"] = iteration arguments["iteration"] = iteration
scheduler.step() scheduler.step()
...@@ -82,7 +83,7 @@ def do_train( ...@@ -82,7 +83,7 @@ def do_train(
eta_seconds = meters.time.global_avg * (max_iter - iteration) eta_seconds = meters.time.global_avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if iteration % 20 == 0 or iteration == (max_iter - 1): if iteration % 20 == 0 or iteration == max_iter:
logger.info( logger.info(
meters.delimiter.join( meters.delimiter.join(
[ [
...@@ -100,7 +101,7 @@ def do_train( ...@@ -100,7 +101,7 @@ def do_train(
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
) )
) )
if iteration % checkpoint_period == 0 and iteration > 0: if iteration % checkpoint_period == 0:
checkpointer.save("model_{:07d}".format(iteration), **arguments) checkpointer.save("model_{:07d}".format(iteration), **arguments)
checkpointer.save("model_{:07d}".format(iteration), **arguments) checkpointer.save("model_{:07d}".format(iteration), **arguments)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment