Commit adec3eef authored by Davis King

merged

parents 8bb4a421 486cf56b
...@@ -1278,7 +1278,7 @@ namespace dlib ...@@ -1278,7 +1278,7 @@ namespace dlib
{ {
private: private:
// We don't want anyone making these no_label_type objects. They are here only to // We don't want anyone making these no_label_type objects. They are here only to
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which voids // allow add_loss_layer::label_type and dnn_trainer::label_type to exist which avoids
// needing to overload add_loss_layer and dnn_trainer for supervised and unsupervised // needing to overload add_loss_layer and dnn_trainer for supervised and unsupervised
// losses. It also can be a type to use in template metaprogramming to indicate // losses. It also can be a type to use in template metaprogramming to indicate
// "no label". So here we make the constructor private with the exception that // "no label". So here we make the constructor private with the exception that
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
#include "cublas_dlibapi.h" #include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include <cublas_v2.h> #include <cublas_v2.h>
...@@ -52,6 +53,7 @@ namespace dlib ...@@ -52,6 +53,7 @@ namespace dlib
cublas_context() cublas_context()
{ {
CHECK_CUBLAS(cublasCreate(&handle)); CHECK_CUBLAS(cublasCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
} }
~cublas_context() ~cublas_context()
{ {
...@@ -59,18 +61,27 @@ namespace dlib ...@@ -59,18 +61,27 @@ namespace dlib
} }
cublasHandle_t get_handle ( cublasHandle_t get_handle (
) const { return handle; } )
{
    // Returns the cuBLAS handle to use on the calling thread's currently
    // active CUDA device.
    //
    // The handle was created while device_id was the active device.  If the
    // active device has changed since then, regenerate the handle so it will
    // use the currently selected device.
    int new_device_id;
    CHECK_CUDA(cudaGetDevice(&new_device_id));
    if (new_device_id != device_id)
    {
        CHECK_CUBLAS(cublasDestroy(handle));
        CHECK_CUBLAS(cublasCreate(&handle));
        // Record which device the regenerated handle belongs to.  Without
        // this, every later call would compare against the stale id and
        // needlessly rebuild the handle, and switching back to the original
        // device would return a handle that was created for a different one.
        device_id = new_device_id;
    }
    return handle;
}
private: private:
cublasHandle_t handle; cublasHandle_t handle;
int device_id;
}; };
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context() static cublasHandle_t context()
{ {
thread_local cublas_context c; thread_local cublas_context c;
......
...@@ -70,6 +70,7 @@ namespace dlib ...@@ -70,6 +70,7 @@ namespace dlib
cudnn_context() cudnn_context()
{ {
CHECK_CUDNN(cudnnCreate(&handle)); CHECK_CUDNN(cudnnCreate(&handle));
CHECK_CUDA(cudaGetDevice(&device_id));
} }
~cudnn_context() ~cudnn_context()
...@@ -78,10 +79,24 @@ namespace dlib ...@@ -78,10 +79,24 @@ namespace dlib
} }
cudnnHandle_t get_handle ( cudnnHandle_t get_handle (
) const { return handle; } )
{
    // Returns the cuDNN handle to use on the calling thread's currently
    // active CUDA device.
    //
    // The handle was created while device_id was the active device.  If the
    // active device has changed since then, regenerate the handle so it will
    // use the currently selected device.
    int new_device_id;
    CHECK_CUDA(cudaGetDevice(&new_device_id));
    if (new_device_id != device_id)
    {
        CHECK_CUDNN(cudnnDestroy(handle));
        CHECK_CUDNN(cudnnCreate(&handle));
        // Record which device the regenerated handle belongs to.  Without
        // this, every later call would compare against the stale id and
        // needlessly rebuild the handle, and switching back to the original
        // device would return a handle that was created for a different one.
        device_id = new_device_id;
    }
    return handle;
}
private: private:
cudnnHandle_t handle; cudnnHandle_t handle;
int device_id;
}; };
static cudnnHandle_t context() static cudnnHandle_t context()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment