Unverified Commit 1c168f8a authored by Soumith Chintala, committed by GitHub

Merge pull request #555 from vishwakftw/fix-dispatch-aten

Fix dispatch breakage
parents 90080e60 8df030c6
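Newer ATen expects the scalar type (at::ScalarType) as the dispatch argument, so the fix swaps the deprecated Tensor::type() for Tensor::scalar_type() at every AT_DISPATCH_FLOATING_TYPES call site. A minimal sketch of the updated pattern, assuming a hypothetical scale_cpu kernel that is not part of this repository:

#include <ATen/ATen.h>

// Illustrative only: dispatch on the tensor's scalar type instead of the
// deprecated Tensor::type() argument.
at::Tensor scale_cpu(const at::Tensor& input, double factor) {
  auto in = input.contiguous();
  at::Tensor output = at::empty_like(in);
  AT_DISPATCH_FLOATING_TYPES(in.scalar_type(), "scale_cpu", [&] {
    const scalar_t* src = in.data<scalar_t>();
    scalar_t* dst = output.data<scalar_t>();
    for (int64_t i = 0; i < in.numel(); ++i) {
      dst[i] = src[i] * static_cast<scalar_t>(factor);
    }
  });
  return output;
}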
@@ -239,7 +239,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
     return output;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
     ROIAlignForward_cpu_kernel<scalar_t>(
         output_size,
         input.data<scalar_t>(),
...
@@ -68,7 +68,7 @@ at::Tensor nms_cpu(const at::Tensor& dets,
                    const at::Tensor& scores,
                    const float threshold) {
   at::Tensor result;
-  AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
     result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
   });
   return result;
...
@@ -280,7 +280,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
     return output;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data<scalar_t>(),
@@ -326,7 +326,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
     return grad_input;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] {
     RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.contiguous().data<scalar_t>(),
...
@@ -134,7 +134,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
     return std::make_tuple(output, argmax);
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIPool_forward", [&] {
     RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data<scalar_t>(),
@@ -182,7 +182,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
     return grad_input;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIPool_backward", [&] {
     RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.contiguous().data<scalar_t>(),
...
@@ -125,7 +125,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
     return losses;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
     SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
         losses_size,
         logits.contiguous().data<scalar_t>(),
@@ -169,7 +169,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
     return d_logits;
   }
 
-  AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
     SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
         d_logits_size,
         logits.contiguous().data<scalar_t>(),
...