Unverified Commit fd204722 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub

Update ImageList to work with 3d tensors (#543)

parent de42d895
...@@ -41,6 +41,8 @@ def to_image_list(tensors, size_divisible=0): ...@@ -41,6 +41,8 @@ def to_image_list(tensors, size_divisible=0):
return tensors return tensors
elif isinstance(tensors, torch.Tensor): elif isinstance(tensors, torch.Tensor):
# single tensor shape can be inferred # single tensor shape can be inferred
if tensors.dim() == 3:
tensors = tensors[None]
assert tensors.dim() == 4 assert tensors.dim() == 4
image_sizes = [tensor.shape[-2:] for tensor in tensors] image_sizes = [tensor.shape[-2:] for tensor in tensors]
return ImageList(tensors, image_sizes) return ImageList(tensors, image_sizes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment