Commit 55d3ab44 authored by Yuxin Wu, committed by Francisco Massa

Fix indexing error in nms kernel (#224)

* Fix indexing error in nms kernel

The return statement indexes a CUDA tensor with CPU indices. This used to work, but
after https://github.com/pytorch/pytorch/commit/006505bb8f9dcf0f38b32518308d071c8a1ccec6 it results in memory corruption.

* Use the device of other tensors
parent 0f61b004
@@ -124,5 +124,8 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
   THCudaFree(state, mask_dev);
   // TODO improve this part
-  return std::get<0>(order_t.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}).sort(0, false));
+  return std::get<0>(order_t.index({
+      keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
+          order_t.device(), keep.scalar_type())
+  }).sort(0, false));
 }
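
The invariant behind this fix (index tensors should live on the same device as the tensor they index) can be illustrated with a small toy model. This is only a sketch under stated assumptions: `ToyTensor`, `Device`, and `select_kept` are hypothetical stand-ins invented for illustration and are not ATen types or functions. The toy `index` throws on a device mismatch, whereas the commit message reports memory corruption in the real kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for illustration only; these are NOT ATen types.
enum class Device { CPU, CUDA };

struct ToyTensor {
  std::vector<std::int64_t> data;
  Device device;

  // Loosely mimics narrow() along dim 0: keeps elements [start, start + length).
  ToyTensor narrow(std::size_t start, std::size_t length) const {
    if (start + length > data.size()) {
      throw std::out_of_range("narrow: range exceeds tensor size");
    }
    return ToyTensor{
        std::vector<std::int64_t>(data.begin() + start,
                                  data.begin() + start + length),
        device};
  }

  // Loosely mimics to(device): returns a copy tagged with the target device.
  ToyTensor to(Device target) const { return ToyTensor{data, target}; }

  // Gathers elements by index. The toy refuses mismatched devices outright.
  ToyTensor index(const ToyTensor& indices) const {
    if (indices.device != device) {
      throw std::invalid_argument("index: indices live on a different device");
    }
    ToyTensor out{{}, device};
    out.data.reserve(indices.data.size());
    for (std::int64_t i : indices.data) {
      // at() throws std::out_of_range for negative or too-large indices.
      out.data.push_back(data.at(static_cast<std::size_t>(i)));
    }
    return out;
  }
};

// Mirrors the shape of the patched return statement: narrow the index tensor,
// move it to the device of the indexed tensor, index, then sort ascending.
ToyTensor select_kept(const ToyTensor& order, const ToyTensor& keep,
                      std::size_t num_to_keep) {
  ToyTensor result = order.index(keep.narrow(0, num_to_keep).to(order.device));
  assert(result.device == order.device);
  std::sort(result.data.begin(), result.data.end());
  return result;
}
```

In this model, calling `order.index(keep)` directly with a CPU-tagged `keep` and a CUDA-tagged `order` throws, while `select_kept` succeeds because it moves the indices first, which is the same ordering of operations the patch introduces.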