Commit f25c6cff authored by 夜阑听风's avatar 夜阑听风 Committed by Francisco Massa

fix a bug in loss reduction for log(#309) (#310)

sort keys before reduction
parent 49b96394
...@@ -22,9 +22,9 @@ def reduce_loss_dict(loss_dict): ...@@ -22,9 +22,9 @@ def reduce_loss_dict(loss_dict):
with torch.no_grad(): with torch.no_grad():
loss_names = [] loss_names = []
all_losses = [] all_losses = []
for k, v in loss_dict.items(): for k in sorted(loss_dict.keys()):
loss_names.append(k) loss_names.append(k)
all_losses.append(v) all_losses.append(loss_dict[k])
all_losses = torch.stack(all_losses, dim=0) all_losses = torch.stack(all_losses, dim=0)
dist.reduce(all_losses, dst=0) dist.reduce(all_losses, dst=0)
if dist.get_rank() == 0: if dist.get_rank() == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment