Commit adec3eef authored by Davis King

merged

parents 8bb4a421 486cf56b
......@@ -1278,7 +1278,7 @@ namespace dlib
{
private:
// We don't want anyone making these no_label_type objects. They are here only to
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which voids
// allow add_loss_layer::label_type and dnn_trainer::label_type to exist which avoids
// needing to overload add_loss_layer and dnn_trainer for supervised and unsupervised
// losses. It also can be a type to use in template metaprogramming to indicate
// "no label". So here we make the constructor private with the exception that
......
......@@ -6,6 +6,7 @@
#ifdef DLIB_USE_CUDA
#include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include <cublas_v2.h>
......@@ -52,6 +53,7 @@ namespace dlib
// Creates the cuBLAS handle held by this object and records which CUDA device was
// active on the calling thread when it was created.
cublas_context()
{
    // Query the device before creating the handle. cudaGetDevice() acquires
    // nothing, so if CHECK_CUDA reports a failure by throwing (its definition is
    // not visible here) no handle has been created yet and nothing is left
    // unreleased. A destructor does not run for an object whose constructor failed.
    CHECK_CUDA(cudaGetDevice(&device_id));
    CHECK_CUBLAS(cublasCreate(&handle));
}
~cublas_context()
{
......@@ -59,18 +61,27 @@ namespace dlib
}
// Returns the cuBLAS handle to use on the calling thread's currently active CUDA
// device. If the active device differs from the one recorded in device_id, a new
// handle is created, the old one is destroyed, and device_id is updated.
cublasHandle_t get_handle (
)
{
    // Check if the active device for the current thread changed.  If so then
    // regenerate our cuBLAS handle so it will use the currently selected
    // device.
    int new_device_id;
    CHECK_CUDA(cudaGetDevice(&new_device_id));
    if (new_device_id != device_id)
    {
        // Create the replacement first so that a failure here leaves the
        // existing handle and device_id untouched and still valid.
        cublasHandle_t new_handle;
        CHECK_CUBLAS(cublasCreate(&new_handle));

        // Remember the new device.  Without this update the comparison above
        // would stay true forever and the handle would be destroyed and
        // recreated on every single call after the first device switch.
        cublasHandle_t old_handle = handle;
        handle = new_handle;
        device_id = new_device_id;

        // Release the old handle last; this object already holds the new one,
        // so its state stays consistent even if the release reports an error.
        CHECK_CUBLAS(cublasDestroy(old_handle));
    }
    return handle;
}
private:
cublasHandle_t handle;
int device_id;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context()
{
thread_local cublas_context c;
......
......@@ -70,6 +70,7 @@ namespace dlib
// Creates the cuDNN handle held by this object and records which CUDA device was
// active on the calling thread when it was created.
cudnn_context()
{
    // Query the device before creating the handle. cudaGetDevice() acquires
    // nothing, so if CHECK_CUDA reports a failure by throwing (its definition is
    // not visible here) no handle has been created yet and nothing is left
    // unreleased. A destructor does not run for an object whose constructor failed.
    CHECK_CUDA(cudaGetDevice(&device_id));
    CHECK_CUDNN(cudnnCreate(&handle));
}
~cudnn_context()
......@@ -78,10 +79,24 @@ namespace dlib
}
// Returns the cuDNN handle to use on the calling thread's currently active CUDA
// device. If the active device differs from the one recorded in device_id, a new
// handle is created, the old one is destroyed, and device_id is updated.
cudnnHandle_t get_handle (
)
{
    // Check if the active device for the current thread changed.  If so then
    // regenerate our cuDNN handle so it will use the currently selected
    // device.
    int new_device_id;
    CHECK_CUDA(cudaGetDevice(&new_device_id));
    if (new_device_id != device_id)
    {
        // Create the replacement first so that a failure here leaves the
        // existing handle and device_id untouched and still valid.
        cudnnHandle_t new_handle;
        CHECK_CUDNN(cudnnCreate(&new_handle));

        // Remember the new device.  Without this update the comparison above
        // would stay true forever and the handle would be destroyed and
        // recreated on every single call after the first device switch.
        cudnnHandle_t old_handle = handle;
        handle = new_handle;
        device_id = new_device_id;

        // Release the old handle last; this object already holds the new one,
        // so its state stays consistent even if the release reports an error.
        CHECK_CUDNN(cudnnDestroy(old_handle));
    }
    return handle;
}
private:
cudnnHandle_t handle;
int device_id;
};
static cudnnHandle_t context()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment