Commit 8466d332 authored by Davis King's avatar Davis King

Made the tensor dot() function use cuBLAS.

parent 8424083e
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "cublas_dlibapi.h" #include "cublas_dlibapi.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "tensor.h"
#include <cublas_v2.h> #include <cublas_v2.h>
...@@ -88,6 +89,23 @@ namespace dlib ...@@ -88,6 +89,23 @@ namespace dlib
return c.get_handle(); return c.get_handle();
} }
// -----------------------------------------------------------------------------------
float dot (
    const tensor& a,
    const tensor& b
)
/*!
    Computes the dot product of a and b, treating each as a flat
    a.size()-dimensional vector, using cuBLAS (cublasSdot).
    requires
        - a.size() == b.size()
    ensures
        - returns sum over i of a[i]*b[i], copied back to the host.
!*/
{
    DLIB_CASSERT(a.size() == b.size(), "");

    // cublasSdot takes an int element count while tensor sizes are size_t.
    // Narrow explicitly and assert no overflow rather than truncating
    // silently for tensors with more than INT_MAX elements.
    DLIB_CASSERT(a.size() <= static_cast<size_t>(std::numeric_limits<int>::max()), "");

    float result = 0;
    // NOTE(review): &result is a host pointer, so this assumes the cuBLAS
    // handle uses the default CUBLAS_POINTER_MODE_HOST (in which case the
    // call blocks until the result is available) — confirm context() does
    // not change the pointer mode.
    CHECK_CUBLAS(cublasSdot(context(),
                            static_cast<int>(a.size()),
                            a.device(), 1,
                            b.device(), 1,
                            &result));
    return result;
}
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void gemm ( void gemm (
......
...@@ -5,14 +5,30 @@ ...@@ -5,14 +5,30 @@
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "cuda_errors.h" #include "cuda_errors.h"
namespace dlib namespace dlib
{ {
class tensor;
namespace cuda namespace cuda
{ {
// -----------------------------------------------------------------------------------
float dot (
const tensor& a,
const tensor& b
);
/*!
requires
- a.size() == b.size()
ensures
- returns the dot product between a and b when they are both treated as
a.size() dimensional vectors. That is, this function pointwise
multiplies the vectors together, then sums the result and returns it.
- The computation runs on the GPU via cuBLAS and the scalar result is
returned on the host. Only available when DLIB_USE_CUDA is defined.
!*/
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void gemm ( void gemm (
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <cstring> #include <cstring>
#include "../matrix.h" #include "../matrix.h"
#include "cudnn_dlibapi.h" #include "cudnn_dlibapi.h"
#include "cublas_dlibapi.h"
#include "gpu_data.h" #include "gpu_data.h"
#include <memory> #include <memory>
...@@ -405,7 +406,9 @@ namespace dlib ...@@ -405,7 +406,9 @@ namespace dlib
const tensor& b const tensor& b
) )
{ {
// TODO, do on GPU? #ifdef DLIB_USE_CUDA
return cuda::dot(a,b);
#else
DLIB_CASSERT(a.size() == b.size(), ""); DLIB_CASSERT(a.size() == b.size(), "");
const float* da = a.host(); const float* da = a.host();
const float* db = b.host(); const float* db = b.host();
...@@ -413,6 +416,7 @@ namespace dlib ...@@ -413,6 +416,7 @@ namespace dlib
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
sum += da[i]*db[i]; sum += da[i]*db[i];
return sum; return sum;
#endif
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment