Commit dae8929a authored by Davis King

Added cuda::gemm()

parent 7fb29dae
......@@ -14,24 +14,63 @@ namespace dlib
namespace cuda
{
// ----------------------------------------------------------------------------------------
// TODO, make into a macro that prints more information like the line number, etc.
static void check(cublasStatus_t s)
{
switch(s)
{
case CUBLAS_STATUS_SUCCESS: return;
case CUBLAS_STATUS_NOT_INITIALIZED:
throw cublas_error("CUDA Runtime API initialization failed.");
case CUBLAS_STATUS_ALLOC_FAILED:
throw cublas_error("CUDA Resources could not be allocated.");
default:
throw cublas_error("A call to cuBLAS failed");
}
}
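// The TODO above asks for a macro that also reports where the failing call was
// made.  A minimal sketch of that idea (illustrative only, not part of this
// commit; the macro name is hypothetical and <string> is assumed to be included):
#define DLIB_CUBLAS_CHECK(expr)                                              \
    do {                                                                     \
        const cublasStatus_t dlib_cublas_status_ = (expr);                   \
        if (dlib_cublas_status_ != CUBLAS_STATUS_SUCCESS)                    \
            throw cublas_error(std::string("cuBLAS call failed at ")         \
                               + __FILE__ + ":" + std::to_string(__LINE__)); \
    } while(0)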
// -----------------------------------------------------------------------------------
class cublas_context
{
public:
// not copyable
cublas_context(const cublas_context&) = delete;
cublas_context& operator=(const cublas_context&) = delete;
cublas_context()
{
check(cublasCreate(&handle));
}
~cublas_context()
{
cublasDestroy(handle);
}
cublasHandle_t get_handle (
) const { return handle; }
private:
cublasHandle_t handle;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static cublasHandle_t context()
{
thread_local cublas_context c;
return c.get_handle();
}
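// The "dlibCudaSetDevice()" idea from the TODO above could look roughly like the
// sketch below (illustrative only, not part of this commit, and all names are
// hypothetical).  It assumes the thread-local context is kept behind a pointer so
// it can be discarded after a device switch and lazily recreated on the new device:
//
//   static std::unique_ptr<cublas_context>& context_slot()
//   {
//       thread_local std::unique_ptr<cublas_context> c;
//       return c;
//   }
//   void dlib_cuda_set_device(int device_id)
//   {
//       cudaSetDevice(device_id);
//       context_slot().reset();  // next cuBLAS call creates a handle on the new device
//   }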
// -----------------------------------------------------------------------------------
void gemm (
float beta,
tensor& dest,
float alpha,
......@@ -41,6 +80,48 @@ namespace dlib
bool trans_rhs
)
{
// Recall that BLAS uses column major order so to deal with that we flip the
// order of the lhs and rhs arguments.
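// Concretely: if dest = L*R with dest m x n in row-major storage, then the same
// memory viewed column-major holds dest^T = R^T * L^T, an n x m result.  That is
// why rhs is passed to cuBLAS as the first operand, lhs as the second, and the
// problem dimensions are given in the order (n, m, k) below.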
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
if (trans_lhs && trans_rhs)
{
DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() &&
mat(dest).nc() == trans(mat(rhs)).nc() &&
trans(mat(lhs)).nc() == trans(mat(rhs)).nr(),"")
}
else if (!trans_lhs && trans_rhs)
{
DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() &&
mat(dest).nc() == trans(mat(rhs)).nc() &&
mat(lhs).nc() == trans(mat(rhs)).nr(),"")
}
else if (trans_lhs && !trans_rhs)
{
DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() &&
mat(dest).nc() == mat(rhs).nc() &&
trans(mat(lhs)).nc() == mat(rhs).nr(),"")
}
else
{
DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() &&
mat(dest).nc() == mat(rhs).nc() &&
mat(lhs).nc() == mat(rhs).nr(),"")
}
const int m = mat(dest).nr();
const int n = mat(dest).nc();
const int k = trans_rhs ? mat(rhs).nc() : mat(rhs).nr();
check(cublasSgemm(context(),
transb,
transa,
n, m, k,
&alpha,
rhs.device(), mat(rhs).nc(),
lhs.device(), mat(lhs).nc(),
&beta,
dest.device(), mat(dest).nc()));
}
// ------------------------------------------------------------------------------------
......
......@@ -20,34 +20,9 @@ namespace dlib
cublas_error(const std::string& message): error(message) {}
};
// -----------------------------------------------------------------------------------
void gemm (
float beta,
tensor& dest,
float alpha,
......@@ -56,6 +31,19 @@ namespace dlib
const tensor& rhs,
bool trans_rhs
);
/*!
requires
- The dimensions of lhs and rhs must be compatible for matrix
multiplication. In particular:
- Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
- Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
- Let D == mat(dest)
- D.nr() == L.nr() && D.nc() == R.nc()
(i.e. dest must be preallocated and have the correct output dimensions)
- L.nc() == R.nr()
ensures
- performs: dest = alpha*L*R + beta*mat(dest)
!*/
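/*
    Illustrative usage sketch (not part of the spec above; assumes
    dlib::resizable_tensor from dlib/dnn/tensor.h, where mat() views a tensor as a
    matrix with num_samples rows):

        resizable_tensor A(2,3), B(3,4), C(2,4);  // mat(A) is 2x3, mat(B) is 3x4, mat(C) is 2x4
        A = 1;  B = 2;  C = 0;
        gemm(0, C, 1, A, false, B, false);        // mat(C) = 1*mat(A)*mat(B) + 0*mat(C)
        // every element of mat(C) is now 2+2+2 = 6
*/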
// ------------------------------------------------------------------------------------
......