Commit 2083f9d1 authored by Jimmy Wu's avatar Jimmy Wu Committed by Francisco Massa

Fix iteration count (#145)

parent 80eae227
......@@ -55,6 +55,7 @@ def do_train(
end = time.time()
for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
data_time = time.time() - end
iteration = iteration + 1
arguments["iteration"] = iteration
scheduler.step()
......@@ -82,7 +83,7 @@ def do_train(
eta_seconds = meters.time.global_avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if iteration % 20 == 0 or iteration == (max_iter - 1):
if iteration % 20 == 0 or iteration == max_iter:
logger.info(
meters.delimiter.join(
[
......@@ -100,7 +101,7 @@ def do_train(
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
)
)
if iteration % checkpoint_period == 0 and iteration > 0:
if iteration % checkpoint_period == 0:
checkpointer.save("model_{:07d}".format(iteration), **arguments)
checkpointer.save("model_{:07d}".format(iteration), **arguments)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment