Commit 351a6331
Authored 9 years ago by Davis King
Parent: 04526626

    Added loss_multiclass_log_

Showing 1 changed file with 105 additions and 0 deletions.

dlib/dnn/loss.h (+105, -0)
@@ -204,6 +204,111 @@ namespace dlib
    template <typename SUBNET>
    using loss_binary_log = add_loss_layer<loss_binary_log_, SUBNET>;
// ----------------------------------------------------------------------------------------

    class loss_multiclass_log_
    {
    public:

        const static unsigned int sample_expansion_factor = 1;
        typedef unsigned long label_type;
        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            DLIB_CASSERT(output_tensor.nr() == 1 &&
                         output_tensor.nc() == 1, "");
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(), "");

            // Note that output_tensor.k() should match the number of labels.

            const float* out_data = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                // The index of the largest output for this sample is the label.
                *iter++ = index_of_max(rowm(mat(output_tensor), i));
            }
        }
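In formula terms, to_label applies the standard argmax decision rule (a restatement of the loop's comment above, not new behavior): for sample i with per-class outputs z_{i,j},

    \hat{y}_i = \operatorname*{arg\,max}_{0 \le j < k} z_{i,j}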
        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss (
            const tensor& input_tensor,
            const_label_iterator truth,
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();
            DLIB_CASSERT(input_tensor.num_samples() != 0, "");
            DLIB_CASSERT(input_tensor.num_samples() % sample_expansion_factor == 0, "");
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples(), "");
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(), "");
            DLIB_CASSERT(output_tensor.nr() == 1 &&
                         output_tensor.nc() == 1, "");
            DLIB_CASSERT(grad.nr() == 1 &&
                         grad.nc() == 1, "");

            tt::softmax(grad, output_tensor);

            // The loss we output is the average loss over the mini-batch.
            const double scale = 1.0 / output_tensor.num_samples();
            double loss = 0;
            float* g = grad.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
            {
                const long y = (long)*truth++;
                // The network must produce a number of outputs that is equal to the number
                // of labels when using this type of loss.
                DLIB_CASSERT(y < output_tensor.k(), "y: " << y << ", output_tensor.k(): " << output_tensor.k());
                for (long k = 0; k < output_tensor.k(); ++k)
                {
                    const unsigned long idx = i * output_tensor.k() + k;
                    if (k == y)
                    {
                        loss += scale * -std::log(g[idx]);
                        g[idx] = scale * (g[idx] - 1);
                    }
                    else
                    {
                        g[idx] = scale * g[idx];
                    }
                }
            }
            return loss;
        }
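compute_loss implements averaged softmax cross-entropy. Restating the loop above in math, with p_{i,j} = softmax(z_i)_j (the probabilities tt::softmax writes into grad), y_i the true label for sample i, and N = output_tensor.num_samples():

    L = -\frac{1}{N} \sum_{i=1}^{N} \log p_{i,\,y_i}
    \qquad
    \frac{\partial L}{\partial z_{i,j}} = \frac{1}{N}\bigl(p_{i,j} - \mathbf{1}[j = y_i]\bigr)

The two branches of the inner loop write exactly these scaled gradient values back into g.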
        friend void serialize(const loss_multiclass_log_&, std::ostream& out)
        {
            serialize("loss_multiclass_log_", out);
        }

        friend void deserialize(loss_multiclass_log_&, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "loss_multiclass_log_")
                throw serialization_error("Unexpected version found while deserializing dlib::loss_multiclass_log_.");
        }

    };
    template <typename SUBNET>
    using loss_multiclass_log = add_loss_layer<loss_multiclass_log_, SUBNET>;
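A minimal usage sketch for the new alias, assuming dlib's DNN layer templates (input, fc, relu) and dnn_trainer; the layer sizes, class count, and input type here are illustrative assumptions, not part of this commit:

    #include <dlib/dnn.h>
    #include <vector>

    using namespace dlib;

    // Hypothetical 10-class network: two fully connected layers feeding the
    // new multiclass log loss.  Sizes are made up for illustration.
    using net_type = loss_multiclass_log<
                         fc<10,
                         relu<fc<84,
                         input<matrix<unsigned char>>
                         >>>>;

    int main()
    {
        net_type net;
        std::vector<matrix<unsigned char>> samples;
        std::vector<unsigned long> labels;  // label_type is unsigned long
        // ... fill samples/labels, then:
        // dnn_trainer<net_type> trainer(net);
        // trainer.train(samples, labels);
    }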
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

}