Commit 679a6517 authored by Davis King

Made the thread-local variables that hold the cuDNN and cuBLAS context objects
not destruct and recreate themselves when you switch devices.  Instead, they
keep a table of context objects, one for each thread and device, reusing them as necessary.

This prevents churn in the context objects when you are switching back and
forth between devices inside a single thread.
parent f194bdc9
......@@ -9,6 +9,7 @@
#include "cuda_utils.h"
#include <cublas_v2.h>
#include <vector>
static const char* cublas_get_error_string(cublasStatus_t s)
{
......@@ -52,34 +53,37 @@ namespace dlib
// Constructs the context: creates one cuBLAS handle, records which CUDA
// device is currently selected, and pre-sizes the per-device handle table.
cublas_context()
{
// Create the handle stored in the `handle` member.
CHECK_CUBLAS(cublasCreate(&handle));
// Remember the id of the device that is current at construction time.
CHECK_CUDA(cudaGetDevice(&device_id));
// Start the table with 16 value-initialized (null) slots; entries are
// indexed by CUDA device id (see get_handle()).
handles.resize(16);
}
// Releases every cuBLAS handle this object holds: first the single `handle`
// member, then each populated slot of the per-device `handles` table.  The
// status returned by cublasDestroy() is not checked here.
~cublas_context()
{
    cublasDestroy(handle);

    for (auto slot = handles.begin(); slot != handles.end(); ++slot)
    {
        // Null slots never had a handle created for them, so skip those.
        if (*slot != nullptr)
            cublasDestroy(*slot);
    }
}
// Returns the cuBLAS handle that belongs to the CUDA device currently
// selected for the calling thread.
//
// Handles are kept in a table indexed by device id and are created lazily the
// first time a device is seen, so switching back and forth between devices
// reuses existing handles instead of destroying and recreating one on every
// switch.  The legacy single `handle` member is no longer consulted here.
//
// Errors reported by cudaGetDevice() or cublasCreate() are handled by the
// CHECK_CUDA / CHECK_CUBLAS macros.
cublasHandle_t get_handle (
)
{
    // Ask the CUDA runtime which device is active right now.
    int new_device_id;
    CHECK_CUDA(cudaGetDevice(&new_device_id));

    // Use an unsigned index so the comparison against size() is well defined.
    const std::size_t idx = static_cast<std::size_t>(new_device_id);

    // make room for more devices if needed (new slots are null)
    if (idx >= handles.size())
        handles.resize(idx+16);

    // If we don't have a handle already for this device then make one.  A
    // cuBLAS handle is bound to the device that is current when it is created.
    if (!handles[idx])
        CHECK_CUBLAS(cublasCreate(&handles[idx]));

    // Finally, return the handle for the current device
    return handles[idx];
}
private:
// Single cuBLAS handle created in the constructor and released in the
// destructor.
cublasHandle_t handle;
// Id of the CUDA device that was current when the constructor ran.
int device_id;
// Table of cuBLAS handles indexed by CUDA device id; a null entry means no
// handle is stored for that device.
std::vector<cublasHandle_t> handles;
};
static cublasHandle_t context()
......
......@@ -10,6 +10,7 @@
#include <cudnn.h>
#include <iostream>
#include <string>
#include <vector>
#include "cuda_utils.h"
#include "cpu_dlib.h"
#include "cuda_dlib.h"
......@@ -70,40 +71,43 @@ namespace dlib
// Holds the cuDNN handles used through one cudnn_context object.
//
// A cuDNN handle is bound to the CUDA device that is current when it is
// created, so this class keeps a table with one handle per device id.  Handles
// are created lazily on first use and reused afterwards, which avoids
// destroying and recreating a handle every time the active device changes.
class cudnn_context
{
public:
    // not copyable
    cudnn_context(const cudnn_context&) = delete;
    cudnn_context& operator=(const cudnn_context&) = delete;

    // Pre-sizes the handle table with 16 null slots; no cuDNN handle is
    // created until get_handle() is called.
    cudnn_context()
    {
        handles.resize(16);
    }

    // Destroys every handle that was created.  The status returned by
    // cudnnDestroy() is not checked here.
    ~cudnn_context()
    {
        for (auto h : handles)
        {
            // Null slots never had a handle created for them.
            if (h)
                cudnnDestroy(h);
        }
    }

    // Returns the cuDNN handle for the CUDA device currently selected for the
    // calling thread, creating it on first use.  Errors reported by
    // cudaGetDevice() or cudnnCreate() are handled by the CHECK_CUDA /
    // CHECK_CUDNN macros.
    cudnnHandle_t get_handle (
    )
    {
        // Ask the CUDA runtime which device is active right now.
        int new_device_id;
        CHECK_CUDA(cudaGetDevice(&new_device_id));

        // Use an unsigned index so the comparison against size() is well
        // defined.
        const std::size_t idx = static_cast<std::size_t>(new_device_id);

        // make room for more devices if needed (new slots are null)
        if (idx >= handles.size())
            handles.resize(idx+16);

        // If we don't have a handle already for this device then make one
        if (!handles[idx])
            CHECK_CUDNN(cudnnCreate(&handles[idx]));

        // Finally, return the handle for the current device
        return handles[idx];
    }

private:
    // Table of cuDNN handles indexed by CUDA device id; a null entry means no
    // handle has been created for that device yet.
    std::vector<cudnnHandle_t> handles;
};
static cudnnHandle_t context()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment