Commit 0508fe2b authored Oct 18, 2015 by Davis King
Fully implemented the gpu_data object and also cleaned up a few other minor
details.
parent 1e623983
Showing 7 changed files with 322 additions and 80 deletions
dlib/dnn/cuda_dlib.cu       +23  -27
dlib/dnn/cuda_dlib.h        +3   -3
dlib/dnn/cuda_errors.h      +29  -0
dlib/dnn/cuda_utils.h       +105 -0
dlib/dnn/cudnn_dlibapi.cpp  +131 -0
dlib/dnn/cudnn_dlibapi.h    +1   -7
dlib/dnn/tensor.h           +30  -43
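For orientation, here is a minimal sketch (not part of the commit) of how the pieces below are meant to fit together: a caller fills two gpu_data buffers on the host, dlib::cuda::add_arrays() lazily moves them to the device and launches the kernel, and reading the result copies it back. It assumes dlib is built with DLIB_USE_CUDA and that gpu_data also offers the writable host() overload present in the full tensor.h (only the const overload is visible in the hunks below).

// Illustrative sketch only -- exercises gpu_data and dlib::cuda::add_arrays() from this commit.
#include <iostream>
#include <dlib/dnn/cuda_dlib.h>   // declares dlib::cuda::add_arrays(); pulls in gpu_data via tensor.h

int main()
{
    dlib::gpu_data a, b, out;
    a.set_size(1024);
    b.set_size(1024);
    out.set_size(1024);
    for (size_t i = 0; i < a.size(); ++i)
    {
        a.host()[i]   = 1;   // writes go to the host copy; the device copy becomes stale
        b.host()[i]   = 2;
        out.host()[i] = 0;   // cuda_add_arrays() accumulates with +=, so start from zero
    }

    // add_arrays() calls device() on each argument, which performs the lazy
    // host-to-device copies implemented in cudnn_dlibapi.cpp, then launches the kernel.
    dlib::cuda::add_arrays(a, b, out);

    // Reading out.host() triggers the device-to-host copy of the result.
    std::cout << "out[0] = " << out.host()[0] << std::endl;   // should print 3
    return 0;
}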
dlib/dnn/cuda_dlib.cu
 // Copyright (C) 2015  Davis E. King (davis@dlib.net)
 // License: Boost Software License   See LICENSE.txt for the full license.

 #include <stdlib.h>
 #include <stdio.h>
+#include "cuda_utils.h"
 #include "cuda_dlib.h"

-#define CHECK(call)                                               \
-{                                                                 \
-    const cudaError_t error = call;                               \
-    if (error != cudaSuccess)                                     \
-    {                                                             \
-        fprintf(stderr, "Error: %s:%d, ", __FILE__, __LINE__);    \
-        fprintf(stderr, "code: %d, reason: %s\n", error,          \
-                cudaGetErrorString(error));                       \
-        exit(1);                                                  \
-    }                                                             \
-}

 namespace dlib
 {
     namespace cuda
     {

-        __global__ void helloFromGPU()
-        {
-            printf("Hello World from GPU!\n");
-        }

    // ------------------------------------------------------------------------------------

-        void hello_cuda()
-        {
-            printf("Hello World from CPU!\n");
-            helloFromGPU<<<1, 10>>>();
-            CHECK(cudaDeviceReset());
-        }
+        __global__ void cuda_add_arrays(const float* a, const float* b, float* out, size_t n)
+        {
+            for (auto i : grid_stride_range(0, n))
+            {
+                out[i] += a[i]+b[i];
+            }
+        }
+
+        void add_arrays(const gpu_data& a, const gpu_data& b, gpu_data& out)
+        {
+            DLIB_CASSERT(a.size() == b.size(),"");
+            out.set_size(a.size());
+            cuda_add_arrays<<<512,512>>>(a.device(), b.device(), out.device(), a.size());
+        }

-#ifndef DLIB_USE_CUDA
-#error why is this not defined?
-#endif

    // ------------------------------------------------------------------------------------

-        auto x = 4;
     }
 }
dlib/dnn/cuda_dlib.h
...
@@ -7,14 +7,14 @@
 #include "tensor.h"

-// TODO, remove this cruft
-void hello_cuda();
-
 namespace dlib
 {
     namespace cuda
     {
+        // TODO, remove this
+        void add_arrays(const gpu_data& a, const gpu_data& b, gpu_data& out);
+
    // -----------------------------------------------------------------------------------

         void affine_transform(
...
dlib/dnn/cuda_errors.h  0 → 100644
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_CUDA_ERRORs_H_
#define DLIB_CUDA_ERRORs_H_

#ifdef DLIB_USE_CUDA

#include "../error.h"

namespace dlib
{
    namespace cuda
    {
        struct cuda_error : public error
        {
            cuda_error(const std::string& message) : error(message) {}
        };

        struct cudnn_error : public cuda_error
        {
            cudnn_error(const std::string& message) : cuda_error(message) {}
        };
    }
}

#endif // DLIB_USE_CUDA

#endif // DLIB_CUDA_ERRORs_H_
dlib/dnn/cuda_utils.h  0 → 100644
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_CUDA_UtILS_H_
#define DLIB_CUDA_UtILS_H_

#include "cuda_errors.h"

#include <cuda.h>
#include <sstream>

// Check the return value of a call to the CUDA runtime for an error condition.
#define CHECK_CUDA(call)                                                       \
{                                                                              \
    const cudaError_t error = call;                                            \
    if (error != cudaSuccess)                                                  \
    {                                                                          \
        std::ostringstream sout;                                               \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << error << ", reason: " << cudaGetErrorString(error);\
        throw dlib::cuda::cuda_error(sout.str());                              \
    }                                                                          \
}

// ----------------------------------------------------------------------------------------

#ifdef __CUDACC__

class grid_stride_range
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is a tool for making a for loop that loops over an entire block of memory
            inside a kernel, but doing so in a way that parallelizes appropriately across
            all the threads in a kernel launch.  For example, the following kernel would
            add the vector a to the vector b and store the output in out (assuming all
            vectors are of dimension n):

                __global__ void add_arrays(
                    const float* a,
                    const float* b,
                    float* out,
                    size_t n
                )
                {
                    for (auto i : grid_stride_range(0, n))
                    {
                        out[i] = a[i]+b[i];
                    }
                }
    !*/

public:
    __device__ grid_stride_range(
        size_t ibegin_,
        size_t iend_
    ) :
        ibegin(ibegin_),
        iend(iend_)
    {}

    class iterator
    {
    public:
        __device__ iterator() {}
        __device__ iterator(size_t pos_) : pos(pos_) {}

        __device__ size_t operator*() const
        {
            return pos;
        }

        __device__ iterator& operator++()
        {
            pos += gridDim.x * blockDim.x;
            return *this;
        }

        __device__ bool operator!=(const iterator& item) const
        {
            return pos < item.pos;
        }

    private:
        size_t pos;
    };

    __device__ iterator begin() const
    {
        return iterator(ibegin + blockDim.x * blockIdx.x + threadIdx.x);
    }

    __device__ iterator end() const
    {
        return iterator(iend);
    }

private:
    size_t ibegin;
    size_t iend;
};

#endif // __CUDACC__

// ----------------------------------------------------------------------------------------

#endif // DLIB_CUDA_UtILS_H_
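The doc comment above shows the kernel side of grid_stride_range; for completeness, here is a hedged host-side sketch (not part of the commit). Because each thread starts at blockDim.x*blockIdx.x + threadIdx.x and strides by gridDim.x*blockDim.x, any <<<blocks, threads>>> choice covers all n elements; cuda_dlib.cu above just happens to pick <<<512,512>>>. Assumes compilation with nvcc (C++11), -DDLIB_USE_CUDA so that cuda_errors.h defines cuda_error, and the dlib root on the include path.

// Illustrative .cu translation unit (not from this commit).
#include <cstdio>
#include <dlib/dnn/cuda_utils.h>   // grid_stride_range and CHECK_CUDA

__global__ void scale_inplace(float* v, float alpha, size_t n)
{
    // Each thread visits i, i + gridDim.x*blockDim.x, i + 2*gridDim.x*blockDim.x, ...
    for (auto i : grid_stride_range(0, n))
        v[i] *= alpha;
}

int main()
{
    const size_t n = 1 << 20;
    float* v = nullptr;
    CHECK_CUDA(cudaMalloc(&v, n*sizeof(float)));
    CHECK_CUDA(cudaMemset(v, 0, n*sizeof(float)));

    // The launch geometry is independent of n; 64 blocks of 128 threads still touch
    // every element because of the grid stride.
    scale_inplace<<<64,128>>>(v, 2.0f, n);
    CHECK_CUDA(cudaGetLastError());
    CHECK_CUDA(cudaDeviceSynchronize());

    CHECK_CUDA(cudaFree(v));
    std::printf("done\n");
    return 0;
}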
dlib/dnn/cudnn_dlibapi.cpp
...
@@ -8,9 +8,140 @@
#include "cudnn_dlibapi.h"
#include "tensor.h"
#include <cudnn.h>
#include <iostream>
#include "cuda_utils.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                              gpu_data member functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    // TODO, add error handling
    void gpu_data::
    wait_for_transfer_to_finish() const
    {
        if (have_active_transfer)
        {
            std::cout << "wait for cudaStreamSynchronize()" << std::endl;
            CHECK_CUDA(cudaStreamSynchronize((cudaStream_t)cuda_stream.get()));
            have_active_transfer = false;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    copy_to_device() const
    {
        wait_for_transfer_to_finish();
        if (!device_current)
        {
            std::cout << "cudaMemcpy to device" << std::endl;
            CHECK_CUDA(cudaMemcpy(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice));
            device_current = true;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    copy_to_host() const
    {
        wait_for_transfer_to_finish();
        if (!host_current)
        {
            std::cout << "cudaMemcpy to host" << std::endl;
            CHECK_CUDA(cudaMemcpy(data_host.get(), data_device.get(), data_size*sizeof(float), cudaMemcpyDeviceToHost));
            host_current = true;
            // Check for errors.  These calls to cudaGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_CUDA(cudaGetLastError());
        }
    }

    void gpu_data::
    async_copy_to_device()
    {
        if (!device_current)
        {
            std::cout << "cudaMemcpyAsync to device" << std::endl;
            CHECK_CUDA(cudaMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), cudaMemcpyHostToDevice, (cudaStream_t)cuda_stream.get()));
            have_active_transfer = true;
            device_current = true;
        }
    }

    void gpu_data::
    set_size(
        size_t new_size
    )
    {
        wait_for_transfer_to_finish();
        if (new_size == 0)
        {
            data_size = 0;
            host_current = true;
            device_current = true;
            data_host.reset();
            data_device.reset();
        }
        else if (new_size != data_size)
        {
            data_size = new_size;
            host_current = true;
            device_current = true;

            try
            {
                void* data;
                CHECK_CUDA(cudaMallocHost(&data, new_size*sizeof(float)));
                // Note that we don't throw exceptions since the free calls are invariably
                // called in destructors.  They also shouldn't fail anyway unless someone
                // is resetting the GPU card in the middle of their program.
                data_host.reset((float*)data, [](float* ptr){
                    auto err = cudaFreeHost(ptr);
                    if (err != cudaSuccess)
                        std::cerr << "cudaFreeHost() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                });

                CHECK_CUDA(cudaMalloc(&data, new_size*sizeof(float)));
                data_device.reset((float*)data, [](float* ptr){
                    auto err = cudaFree(ptr);
                    if (err != cudaSuccess)
                        std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                });

                if (!cuda_stream)
                {
                    cudaStream_t cstream;
                    CHECK_CUDA(cudaStreamCreateWithFlags(&cstream, cudaStreamNonBlocking));
                    cuda_stream.reset(cstream, [](void* ptr){
                        auto err = cudaStreamDestroy((cudaStream_t)ptr);
                        if (err != cudaSuccess)
                            std::cerr << "cudaStreamDestroy() failed. Reason: " << cudaGetErrorString(err) << std::endl;
                    });
                }
            }
            catch(...)
            {
                set_size(0);
                throw;
            }
        }
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    namespace cuda
    {
...
dlib/dnn/cudnn_dlibapi.h
...
@@ -5,7 +5,7 @@
 #ifdef DLIB_USE_CUDA

-#include "../error.h"
+#include "cuda_errors.h"

 namespace dlib
 {
...
@@ -17,12 +17,6 @@ namespace dlib
    // -----------------------------------------------------------------------------------

-        struct cudnn_error : public error
-        {
-            cudnn_error(const std::string& message) : error(message) {}
-        };
-
    // ------------------------------------------------------------------------------------

         class cudnn_context
         {
...
dlib/dnn/tensor.h
...
@@ -22,17 +22,23 @@ namespace dlib
                 - if (data_device) then
                     - data_device == a pointer to size() floats in device memory.

+                - if (there might be an active transfer between host and device) then
+                    - have_active_transfer == true
+
                 - We use the host_current and device_current bools to keep track of which
                   copy of the data (or both) are most current.  e.g. if the CPU has
                   modified the tensor and it hasn't been copied to the device yet then
                   host_current==true and device_current==false.

             THREAD SAFETY
                 This object is not thread-safe.  Don't touch it from multiple threads at the
                 same time.
         !*/
     public:

         gpu_data(
-        ) : data_size(0), host_current(true), device_current(false)
+        ) : data_size(0), host_current(true), device_current(true), have_active_transfer(false)
         {
         }
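An aside on the invariants in the doc comment above: the following hedged walk-through (illustrative, not from the commit) spells out how the flags are intended to change as the object is used; the member names come from this diff and the behavior described is the one implemented in cudnn_dlibapi.cpp. Nothing here is asserted by the library itself.

// Hedged sketch of the host_current / device_current / have_active_transfer bookkeeping.
#include <dlib/dnn/tensor.h>

void bookkeeping_walkthrough()
{
    dlib::gpu_data d;
    d.set_size(128);              // fresh buffer: host_current == true, device_current == true

    const dlib::gpu_data& cd = d;
    const float* h = cd.host();   // host copy is current, so copy_to_host() is a no-op
    (void)h;

    // A kernel reads through device(); when a host-side write has made device_current
    // false, device() calls copy_to_device(), which performs the cudaMemcpy and flips
    // device_current back to true.

    d.async_copy_to_device();     // a no-op here since device_current is still true; after a
                                  // host-side write it would instead queue a cudaMemcpyAsync
                                  // on the object's stream and set have_active_transfer, so
                                  // the next host()/device() call first runs
                                  // wait_for_transfer_to_finish() (cudaStreamSynchronize).
}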
...
@@ -44,13 +50,21 @@ namespace dlib
         gpu_data(gpu_data&&) = default;
         gpu_data& operator=(gpu_data&&) = default;

+#ifdef DLIB_USE_CUDA
+        void async_copy_to_device();
+        void set_size(size_t new_size);
+#else
+        // Note that calls to host() or device() will block until any async transfers are complete.
         void async_copy_to_device(){}
         void set_size(size_t new_size)
         {
             if (new_size == 0)
             {
                 data_size = 0;
                 host_current = true;
-                device_current = false;
+                device_current = true;
                 data_host.reset();
                 data_device.reset();
             }
...
@@ -58,25 +72,12 @@ namespace dlib
             {
                 data_size = new_size;
                 host_current = true;
-                device_current = false;
-                data_host.reset(new float[new_size]);
+                device_current = true;
+                data_host.reset(new float[new_size], std::default_delete<float[]>());
                 data_device.reset();
             }
         }

-        void async_copy_to_device()
-        {
-#ifdef DLIB_USE_CUDA
-            // TODO
-#endif
-        }
-        void async_copy_to_host()
-        {
-#ifdef DLIB_USE_CUDA
-            // TODO
-#endif
-        }

         const float* host() const
         {
...
@@ -115,39 +116,31 @@ namespace dlib
     private:

-        void copy_to_device() const
-        {
-            if (!device_current)
-            {
-#ifdef DLIB_USE_CUDA
-                // TODO, cudamemcpy()
-#endif
-                device_current = true;
-            }
-        }
-        void copy_to_host() const
-        {
-            if (!host_current)
-            {
-#ifdef DLIB_USE_CUDA
-                // TODO, cudamemcpy()
-#endif
-                host_current = true;
-            }
-        }
+#ifdef DLIB_USE_CUDA
+        void copy_to_device() const;
+        void copy_to_host() const;
+        void wait_for_transfer_to_finish() const;
+#else
+        void copy_to_device() const{}
+        void copy_to_host() const{}
+        void wait_for_transfer_to_finish() const{}
+#endif

         size_t data_size;
         mutable bool host_current;
         mutable bool device_current;
+        mutable bool have_active_transfer;

-        std::unique_ptr<float[]> data_host;
-        std::unique_ptr<float[]> data_device;
+        std::shared_ptr<float> data_host;
+        std::shared_ptr<float> data_device;
+        std::shared_ptr<void> cuda_stream;
     };

     inline void serialize(const gpu_data& item, std::ostream& out)
     {
         int version = 1;
         serialize(version, out);
         serialize(item.size(), out);
         auto data = item.host();
         for (size_t i = 0; i < item.size(); ++i)
...
@@ -193,11 +186,6 @@ namespace dlib
         long k() const { return m_k; }

         size_t size() const { return data.size(); }

-        void async_copy_to_host()
-        {
-            data.async_copy_to_host();
-        }
         void async_copy_to_device()
         {
             data.async_copy_to_device();
...
@@ -306,7 +294,6 @@ namespace dlib
             std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float));
 #ifdef DLIB_USE_CUDA
             cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
 #endif
             return *this;
         }
...