Commit 30005b7e authored by Davis King

Wrapped new dot() function into the tt namespace and gave it a CPU version.

parent d248a225
......@@ -830,6 +830,23 @@ namespace dlib
d[i] = d[i]>thresh ? 1:0;
}
// Accumulates the dot product of a and b into element idx of result, working
// on the host-side buffers returned by host().
//
// Preconditions (enforced with DLIB_CASSERT):
//   - a.size() == b.size()
//   - idx < result.size()
//
// The existing value of result.host()[idx] is kept and added to; it is not
// reset before accumulation.
void dot (
    const tensor& a,
    const tensor& b,
    tensor& result,
    size_t idx
)
{
    DLIB_CASSERT(a.size() == b.size(), "");
    DLIB_CASSERT(idx < result.size(), "");

    // Fetch the host buffers in the same order as before: inputs first, then output.
    const auto lhs = a.host();
    const auto rhs = b.host();
    auto out = result.host();

    const size_t count = a.size();
    size_t k = 0;
    while (k < count)
    {
        // Write through out[idx] on every step so the running sum lives in the
        // destination element itself.
        out[idx] += lhs[k]*rhs[k];
        ++k;
    }
}
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
......
......@@ -154,6 +154,13 @@ namespace dlib
float thresh
);
// Adds the dot product of a and b to element idx of result.
// Callers are expected to ensure a.size() == b.size() and idx < result.size().
// NOTE(review): this appears to be the declaration of the CPU implementation
// that tt::dot() forwards to when CUDA is not in use -- confirm against the
// enclosing namespace, which is not visible in this hunk.
void dot (
const tensor& a,
const tensor& b,
tensor& result,
size_t idx
);
// -----------------------------------------------------------------------------------
void softmax (
......
......@@ -307,6 +307,20 @@ namespace dlib { namespace tt
#endif
}
// Adds the dot product of a and b to element idx of result by forwarding the
// arguments unchanged to a backend selected at compile time.
void dot (
const tensor& a,
const tensor& b,
tensor& result,
size_t idx
)
{
#ifdef DLIB_USE_CUDA
// CUDA build: use the GPU implementation.
cuda::dot(a,b,result,idx);
#else
// Non-CUDA build: use the CPU implementation.
cpu::dot(a,b,result,idx);
#endif
}
// ----------------------------------------------------------------------------------------
void add(
......
......@@ -398,6 +398,27 @@ namespace dlib { namespace tt
- #data.host()[i] == data.host()[i]>thresh ? 1 : 0
!*/
void dot (
const tensor& a,
const tensor& b,
tensor& result,
size_t idx
);
/*!
requires
- a.size() == b.size()
- idx < result.size()
ensures
- #result.host()[idx] == result.host()[idx] + dot(a,b);
I.e. Adds the dot product between a and b into the idx-th element of result.
The reason you might want to use this more complex version of dot() is
because, when using CUDA, it runs by generating asynchronous kernel launches
whereas the version of dot() that returns the result immediately as a scalar
must block the host while we wait for the result to be computed and then
transferred from the GPU to the host for return by dot(). So this version of
dot() might be much faster in some cases.
!*/
// ----------------------------------------------------------------------------------------
void add(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment