Commit dae8929a authored Nov 04, 2015 by Davis King
Added cuda::gemm()
parent 7fb29dae
Showing 2 changed files with 102 additions and 33 deletions:

    dlib/dnn/cublas_dlibapi.cpp   +89  -8
    dlib/dnn/cublas_dlibapi.h     +13 -25
dlib/dnn/cublas_dlibapi.cpp
@@ -14,24 +14,63 @@ namespace dlib
     namespace cuda
     {

     // -----------------------------------------------------------------------------------

-        cublas_context::
-        cublas_context(
-        )
+        // TODO, make into a macro that prints more information like the line number, etc.
+        static void check(cublasStatus_t s)
         {
-            // TODO
+            switch(s)
+            {
+                case CUBLAS_STATUS_SUCCESS: return;
+                case CUBLAS_STATUS_NOT_INITIALIZED:
+                    throw cublas_error("CUDA Runtime API initialization failed.");
+                case CUBLAS_STATUS_ALLOC_FAILED:
+                    throw cublas_error("CUDA Resources could not be allocated.");
+                default:
+                    throw cublas_error("A call to cuBLAS failed");
+            }
         }

-        cublas_context::
-        ~cublas_context(
-        )
-        {
-            // TODO
-        }
+    // -----------------------------------------------------------------------------------
+
+        class cublas_context
+        {
+        public:
+            // not copyable
+            cublas_context(const cublas_context&) = delete;
+            cublas_context& operator=(const cublas_context&) = delete;
+
+            cublas_context()
+            {
+                check(cublasCreate(&handle));
+            }
+            ~cublas_context()
+            {
+                cublasDestroy(handle);
+            }
+
+            cublasHandle_t get_handle (
+            ) const { return handle; }
+
+        private:
+
+            cublasHandle_t handle;
+        };
+
+        // TODO, there should probably be some function that is like dlibCudaSetDevice().
+        // Because people will call cudaSetDevice() expecting to set the device but for
+        // cuBLAS and cuDNN, since they have these handles, they will keep using the old
+        // devices.  So we should have something that resets these handles and does a
+        // "dlibCudaSetDevice()"
+        static cublasHandle_t context()
+        {
+            thread_local cublas_context c;
+            return c.get_handle();
+        }

     // -----------------------------------------------------------------------------------

         void gemm (
-            cublas_context& context,
             float beta,
             tensor& dest,
             float alpha,
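The first hunk replaces the stubbed-out cublas_context constructor and destructor with a working RAII wrapper and hides it behind a file-local context() helper: each calling thread lazily creates its own cublasHandle_t, which is torn down automatically when the thread exits. Below is a minimal, CUDA-free sketch of that pattern; fake_handle_t, fake_create(), and fake_destroy() are hypothetical stand-ins for cublasHandle_t, cublasCreate(), and cublasDestroy().

    #include <iostream>

    // Hypothetical stand-ins for the cuBLAS handle API.
    using fake_handle_t = int*;
    void fake_create(fake_handle_t* h)  { *h = new int(42); }
    void fake_destroy(fake_handle_t h)  { delete h; }

    class context_wrapper
    {
    public:
        // Not copyable: copying would double-destroy the handle.
        context_wrapper(const context_wrapper&) = delete;
        context_wrapper& operator=(const context_wrapper&) = delete;

        context_wrapper()  { fake_create(&handle); }
        ~context_wrapper() { fake_destroy(handle); }

        fake_handle_t get_handle() const { return handle; }

    private:
        fake_handle_t handle;
    };

    // Each thread that calls context() lazily creates its own handle,
    // which is destroyed automatically when that thread exits.
    static fake_handle_t context()
    {
        thread_local context_wrapper c;
        return c.get_handle();
    }

    int main()
    {
        std::cout << *context() << "\n";  // prints 42
    }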
@@ -41,6 +80,48 @@ namespace dlib
             bool trans_rhs
         )
         {
+            // Recall that BLAS uses column major order so to deal with that we flip the
+            // order of the lhs and rhs arguments.
+            const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+            const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+            if (trans_lhs && trans_rhs)
+            {
+                DLIB_CASSERT(mat(dest).nr() == trans(mat(lhs)).nr() &&
+                             mat(dest).nc() == trans(mat(rhs)).nc() &&
+                             trans(mat(lhs)).nc() == trans(mat(rhs)).nr(), "")
+            }
+            else if (!trans_lhs && trans_rhs)
+            {
+                DLIB_CASSERT(mat(dest).nr() == mat(lhs).nr() &&
+                             mat(dest).nc() == trans(mat(rhs)).nc() &&
+                             mat(lhs).nc() == trans(mat(rhs)).nr(), "")
+            }
+            else if (trans_lhs && !trans_rhs)
+            {
+                DLIB_CASSERT(mat(dest).nr() == trans(mat(lhs)).nr() &&
+                             mat(dest).nc() == mat(rhs).nc() &&
+                             trans(mat(lhs)).nc() == mat(rhs).nr(), "")
+            }
+            else
+            {
+                DLIB_CASSERT(mat(dest).nr() == mat(lhs).nr() &&
+                             mat(dest).nc() == mat(rhs).nc() &&
+                             mat(lhs).nc() == mat(rhs).nr(), "")
+            }
+
+            const int m = mat(dest).nr();
+            const int n = mat(dest).nc();
+            const int k = trans_rhs ? mat(rhs).nc() : mat(rhs).nr();
+            check(cublasSgemm(context(),
+                              transb,
+                              transa,
+                              m, n, k,
+                              &alpha,
+                              rhs.device(), mat(rhs).nc(),
+                              lhs.device(), mat(lhs).nc(),
+                              &beta,
+                              dest.device(), mat(dest).nc()));
         }

     // ------------------------------------------------------------------------------------
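The second hunk is the substance of the commit. After the DLIB_CASSERT dimension checks (one branch per combination of transpose flags), it bridges dlib's row-major matrices to cuBLAS, which assumes column-major storage: a row-major buffer reinterpreted as column-major is the transpose of the original matrix, so dest = lhs*rhs can be computed as dest^T = rhs^T * lhs^T, which is why the operands are handed to cublasSgemm in reverse order with the transpose flags swapped. A small CPU sketch of that identity follows; colmajor_gemm() is a hypothetical naive stand-in for what a column-major BLAS computes.

    #include <iostream>
    #include <vector>

    // Naive column-major reference: C(MxN) = A(MxK) * B(KxN),
    // all column-major with leading dimensions equal to their row counts.
    void colmajor_gemm(int M, int N, int K,
                       const float* A, const float* B, float* C)
    {
        for (int j = 0; j < N; ++j)
            for (int i = 0; i < M; ++i)
            {
                float sum = 0;
                for (int p = 0; p < K; ++p)
                    sum += A[i + p*M] * B[p + j*K];
                C[i + j*M] = sum;
            }
    }

    int main()
    {
        // Row-major A (2x3) and B (3x2); we want row-major C = A*B (2x2).
        std::vector<float> A = {1, 2, 3,
                                4, 5, 6};
        std::vector<float> B = {7,  8,
                                9,  10,
                                11, 12};
        std::vector<float> C(4);

        // A 2x3 row-major buffer *is* a 3x2 column-major matrix (its
        // transpose), so C^T = B^T * A^T: pass B first, A second.
        // This is the same flip cuda::gemm performs before cublasSgemm.
        colmajor_gemm(/*M=*/2, /*N=*/2, /*K=*/3, B.data(), A.data(), C.data());

        // C now holds row-major A*B = [[58, 64], [139, 154]].
        std::cout << C[0] << " " << C[1] << "\n"
                  << C[2] << " " << C[3] << "\n";
    }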
dlib/dnn/cublas_dlibapi.h
@@ -20,34 +20,9 @@ namespace dlib
         cublas_error(const std::string& message) : error(message) {}
     };

-    // -----------------------------------------------------------------------------------
-
-        class cublas_context
-        {
-        public:
-            // not copyable
-            cublas_context(const cublas_context&) = delete;
-            cublas_context& operator=(const cublas_context&) = delete;
-            // but is movable
-            cublas_context(cublas_context&& item) : cublas_context() { swap(item); }
-            cublas_context& operator=(cublas_context&& item) { swap(item); return *this; }
-
-            cublas_context();
-            ~cublas_context();
-
-            const void* get_handle (
-            ) const { return handle; }
-
-        private:
-
-            void swap(cublas_context& item) { std::swap(handle, item.handle); }
-
-            void* handle;
-        };
-
     // -----------------------------------------------------------------------------------

         void gemm (
-            cublas_context& context,
             float beta,
             tensor& dest,
             float alpha,

@@ -56,6 +31,19 @@ namespace dlib
             const tensor& rhs,
             bool trans_rhs
         );
+        /*!
+            requires
+                - The dimensions of lhs and rhs must be compatible for matrix
+                  multiplication.  In particular:
+                    - Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
+                    - Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
+                    - Let D == mat(dest)
+                    - D.nr() == L.nr() && D.nc() == R.nc()
+                      (i.e. dest must be preallocated and have the correct output dimensions)
+                    - L.nc() == R.nr()
+            ensures
+                - performs: dest = alpha*L*R + beta*mat(dest)
+        !*/

     // ------------------------------------------------------------------------------------
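For illustration, a hypothetical call site matching the documented contract. The tensor construction assumes the dnn tensor interface of this branch (resizable_tensor and set_size() are not shown in this diff), and the two elided middle parameters are presumed to be lhs and trans_lhs by analogy with the visible rhs and trans_rhs.

    #include <dlib/dnn/tensor.h>           // assumed header for resizable_tensor
    #include <dlib/dnn/cublas_dlibapi.h>

    void example()
    {
        using namespace dlib;

        // Assumed API: resizable_tensor and set_size() are not part of
        // this diff, so treat them as illustrative placeholders.
        resizable_tensor lhs, rhs, dest;
        lhs.set_size(2, 3);   // mat(lhs) is 2x3
        rhs.set_size(3, 4);   // mat(rhs) is 3x4
        dest.set_size(2, 4);  // preallocated: D.nr()==L.nr(), D.nc()==R.nc()

        // Per the contract: dest = 1*mat(lhs)*mat(rhs) + 0*mat(dest)
        cuda::gemm(0, dest, 1, lhs, false, rhs, false);
    }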