Commit f0318794 authored by kaiJIN's avatar kaiJIN Committed by Francisco Massa

Support for running on arbitrary CUDA device. (#537)

* support for any one cuda device

* Revert "support for any one cuda device"

This reverts commit 0197e4e2ef18ec41cc155f3ae2a0face5b77e1e9.

* support running on any one CUDA device

* using safe CUDAGuard rather than intrinsic CUDASetDevice

* supplement a header dependency (test passed)

* Support for  arbitrary GPU device.

* Support for arbitrary GPU device.

* add docs for two method to control devices
parent 9063850d
...@@ -68,6 +68,27 @@ image = ... ...@@ -68,6 +68,27 @@ image = ...
predictions = coco_demo.run_on_opencv_image(image) predictions = coco_demo.run_on_opencv_image(image)
``` ```
### Use it on an arbitrary GPU device
In some cases, even though multiple GPU devices are installed in a machine, we may
only have access to one specific GPU device (e.g. CUDA:1 or CUDA:2) for inference, testing or training.
Here, the repository currently supports two methods to control devices.
#### 1. Using the `CUDA_VISIBLE_DEVICES` environment variable (recommended)
Here is an example for Mask R-CNN R-50 FPN quick on the second device (CUDA:1):
```bash
export CUDA_VISIBLE_DEVICES=1
python tools/train_net.py --config-file=configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml
```
Now the session will run entirely on the second GPU device (CUDA:1).
#### 2. Using the `MODEL.DEVICE` flag
In addition, the program can run on a specific GPU device by setting the `MODEL.DEVICE` flag.
```bash
python tools/train_net.py --config-file=configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml MODEL.DEVICE cuda:1
```
Here, we add the `MODEL.DEVICE cuda:1` flag to configure the target device.
*Note that a small amount of memory is still allocated on `cuda:0` for some reason.*
## Perform training on COCO dataset ## Perform training on COCO dataset
For the following examples to work, you need to first install `maskrcnn_benchmark`. For the following examples to work, you need to first install `maskrcnn_benchmark`.
......
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>
#include <THC/THC.h> #include <THC/THC.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
...@@ -263,6 +264,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, ...@@ -263,6 +264,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0); auto num_rois = rois.size(0);
auto channels = input.size(1); auto channels = input.size(1);
auto height = input.size(2); auto height = input.size(2);
...@@ -311,6 +314,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, ...@@ -311,6 +314,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
const int sampling_ratio) { const int sampling_ratio) {
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(grad.device());
auto num_rois = rois.size(0); auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
......
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>
#include <THC/THC.h> #include <THC/THC.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
...@@ -115,6 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input, ...@@ -115,6 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0); auto num_rois = rois.size(0);
auto channels = input.size(1); auto channels = input.size(1);
auto height = input.size(2); auto height = input.size(2);
...@@ -167,6 +170,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, ...@@ -167,6 +170,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
// TODO add more checks // TODO add more checks
at::cuda::CUDAGuard device_guard(grad.device());
auto num_rois = rois.size(0); auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
// cyfu@cs.unc.edu // cyfu@cs.unc.edu
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>
#include <THC/THC.h> #include <THC/THC.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
...@@ -111,6 +112,8 @@ at::Tensor SigmoidFocalLoss_forward_cuda( ...@@ -111,6 +112,8 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor"); AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
at::cuda::CUDAGuard device_guard(logits.device());
const int num_samples = logits.size(0); const int num_samples = logits.size(0);
auto losses = at::empty({num_samples, logits.size(1)}, logits.options()); auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
...@@ -156,7 +159,9 @@ at::Tensor SigmoidFocalLoss_backward_cuda( ...@@ -156,7 +159,9 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
const int num_samples = logits.size(0); const int num_samples = logits.size(0);
AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes"); AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
at::cuda::CUDAGuard device_guard(logits.device());
auto d_logits = at::zeros({num_samples, num_classes}, logits.options()); auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
auto d_logits_size = num_samples * logits.size(1); auto d_logits_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>
#include <THC/THC.h> #include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh> #include <THC/THCDeviceUtils.cuh>
...@@ -70,6 +71,8 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, ...@@ -70,6 +71,8 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
using scalar_t = float; using scalar_t = float;
AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor"); AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes.device());
auto scores = boxes.select(1, 4); auto scores = boxes.select(1, 4);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t); auto boxes_sorted = boxes.index_select(0, order_t);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment