Commit 5c058ea1 authored by Davis King's avatar Davis King

Made cuBLAS and cuDNN automatically switch their library handles to the

currently active device id if the user changes the active device via a call to
cudaSetDevice().
parent ccb148b4
......@@ -6,6 +6,7 @@
#ifdef DLIB_USE_CUDA
#include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include <cublas_v2.h>
......@@ -52,6 +53,7 @@ namespace dlib
cublas_context()
{
CHECK_CUBLAS(cublasCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cublas_context()
{
......@@ -59,18 +61,27 @@ namespace dlib
}
cublasHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuBLAS handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUBLAS(cublasDestroy(handle));
CHECK_CUBLAS(cublasCreate(&handle));
}
return handle;
}
private:
cublasHandle_t handle;
int device_id;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context()
{
thread_local cublas_context c;
......
......@@ -70,6 +70,7 @@ namespace dlib
cudnn_context()
{
CHECK_CUDNN(cudnnCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
}
~cudnn_context()
......@@ -78,10 +79,24 @@ namespace dlib
}
cudnnHandle_t get_handle (
) const { return handle; }
)
{
// Check if the active device for the current thread changed. If so then
// regenerate our cuDNN handle so it will use the currently selected
// device.
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
if (new_device_id != device_id)
{
CHECK_CUDNN(cudnnDestroy(handle));
CHECK_CUDNN(cudnnCreate(&handle));
}
return handle;
}
private:
cudnnHandle_t handle;
int device_id;
};
static cudnnHandle_t context()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment