Commit c79ae7d4 authored by Davis King's avatar Davis King

Changed gemm so it doesn't cause device to host copies in the assert statements.

parent d059f3c1
...@@ -85,43 +85,47 @@ namespace dlib ...@@ -85,43 +85,47 @@ namespace dlib
const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N; const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N; const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
const int dest_nr = dest.num_samples();
const int dest_nc = dest.size()/dest_nr;
const int lhs_nr = lhs.num_samples();
const int lhs_nc = lhs.size()/lhs_nr;
const int rhs_nr = rhs.num_samples();
const int rhs_nc = rhs.size()/rhs_nr;
if (trans_lhs && trans_rhs) if (trans_lhs && trans_rhs)
{ {
DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() && DLIB_ASSERT( dest_nr == lhs_nc &&
mat(dest).nc() == trans(mat(rhs)).nc() && dest_nc == rhs_nr &&
trans(mat(lhs)).nc() == trans(mat(rhs)).nr(),"") lhs_nr == rhs_nc,"")
} }
else if (!trans_lhs && trans_rhs) else if (!trans_lhs && trans_rhs)
{ {
DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() && DLIB_ASSERT( dest_nr == lhs_nr &&
mat(dest).nc() == trans(mat(rhs)).nc() && dest_nc == rhs_nr &&
mat(lhs).nc() == trans(mat(rhs)).nr(),"") lhs_nc == rhs_nc,"")
} }
else if (trans_lhs && !trans_rhs) else if (trans_lhs && !trans_rhs)
{ {
DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() && DLIB_ASSERT( dest_nr == lhs_nc &&
mat(dest).nc() == mat(rhs).nc() && dest_nc == rhs_nc &&
trans(mat(lhs)).nc() == mat(rhs).nr(),"") lhs_nr == rhs_nr,"")
} }
else else
{ {
DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() && DLIB_ASSERT( dest_nr == lhs_nr &&
mat(dest).nc() == mat(rhs).nc() && dest_nc == rhs_nc &&
mat(lhs).nc() == mat(rhs).nr(),"") lhs_nc == rhs_nr,"")
} }
const int m = mat(dest).nc(); const int k = trans_rhs ? rhs_nc : rhs_nr;
const int n = mat(dest).nr();
const int k = trans_rhs ? mat(rhs).nc() : mat(rhs).nr();
check(cublasSgemm(context(), check(cublasSgemm(context(),
transb, transb,
transa, transa,
m, n, k, dest_nc, dest_nr, k,
&alpha, &alpha,
rhs.device(), mat(rhs).nc(), rhs.device(), rhs_nc,
lhs.device(), mat(lhs).nc(), lhs.device(), lhs_nc,
&beta, &beta,
dest.device(), mat(dest).nc())); dest.device(),dest_nc));
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment