Commit 8466d332 authored by Davis King

Made the tensor dot() function use cuBLAS.

parent 8424083e
......@@ -7,6 +7,7 @@
#include "cublas_dlibapi.h"
#include "cuda_utils.h"
#include "tensor.h"
#include <cublas_v2.h>
......@@ -88,6 +89,23 @@ namespace dlib
return c.get_handle();
}
// -----------------------------------------------------------------------------------
// Returns the dot product of a and b, treating both tensors as flat vectors of
// a.size() floats.  The work is delegated to cublasSdot(), which reads the
// operands through their device() pointers with a stride of 1.
//
// requires: a.size() == b.size(), and a.size() must be representable as an int
// (cublasSdot() takes its element count as an int).
float dot (
    const tensor& a,
    const tensor& b
)
{
    DLIB_CASSERT(a.size() == b.size(), "");

    // cublasSdot() takes the number of elements as an int.  Passing a.size()
    // directly would silently truncate very large tensors and yield a wrong
    // answer, so verify the narrowing conversion is lossless first.
    const int num = static_cast<int>(a.size());
    DLIB_CASSERT(num >= 0 && static_cast<size_t>(num) == a.size(),
        "tensor has too many elements to be passed to cublasSdot()");

    // NOTE(review): &result is a host address, so this presumably relies on the
    // cuBLAS handle being in host pointer mode -- confirm against context().
    float result = 0;
    CHECK_CUBLAS(cublasSdot(context(),
                            num,
                            a.device(), 1,
                            b.device(), 1,
                            &result));
    return result;
}
// -----------------------------------------------------------------------------------
void gemm (
......
......@@ -5,14 +5,30 @@
#ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "cuda_errors.h"
namespace dlib
{
class tensor;
namespace cuda
{
// -----------------------------------------------------------------------------------
float dot (
const tensor& a,
const tensor& b
);
/*!
    requires
        - a.size() == b.size()
    ensures
        - returns the dot product between a and b when they are both treated as
          a.size() dimensional vectors.  That is, this function pointwise
          multiplies the vectors together, then sums the result and returns it.
        - the result is returned as a single float.
        - the computation is carried out by cuBLAS (cublasSdot) using the
          pointers returned by a.device() and b.device().
!*/
// -----------------------------------------------------------------------------------
void gemm (
......
......@@ -7,6 +7,7 @@
#include <cstring>
#include "../matrix.h"
#include "cudnn_dlibapi.h"
#include "cublas_dlibapi.h"
#include "gpu_data.h"
#include <memory>
......@@ -405,7 +406,9 @@ namespace dlib
const tensor& b
)
{
// TODO, do on GPU?
#ifdef DLIB_USE_CUDA
return cuda::dot(a,b);
#else
DLIB_CASSERT(a.size() == b.size(), "");
const float* da = a.host();
const float* db = b.host();
......@@ -413,6 +416,7 @@ namespace dlib
for (size_t i = 0; i < a.size(); ++i)
sum += da[i]*db[i];
return sum;
#endif
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment