Commit 10caab26 authored by Davis King's avatar Davis King

Updated projection_hash creation functions to allow user to supply

the random number generator that gets used.
parent 97d6125f
...@@ -21,7 +21,8 @@ namespace dlib ...@@ -21,7 +21,8 @@ namespace dlib
> >
projection_hash create_random_projection_hash ( projection_hash create_random_projection_hash (
const vector_type& v, const vector_type& v,
const int bits const int bits,
dlib::rand& rnd
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -65,7 +66,6 @@ namespace dlib ...@@ -65,7 +66,6 @@ namespace dlib
std::vector<double> temp; std::vector<double> temp;
// build a random projection matrix // build a random projection matrix
dlib::rand rnd;
matrix<double> proj(bits, v[0].size()); matrix<double> proj(bits, v[0].size());
for (long r = 0; r < proj.nr(); ++r) for (long r = 0; r < proj.nr(); ++r)
for (long c = 0; c < proj.nc(); ++c) for (long c = 0; c < proj.nc(); ++c)
...@@ -114,6 +114,20 @@ namespace dlib ...@@ -114,6 +114,20 @@ namespace dlib
return projection_hash(proj, offset); return projection_hash(proj, offset);
} }
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
projection_hash create_random_projection_hash (
const vector_type& v,
const int bits
)
{
dlib::rand rnd;
return create_random_projection_hash(v,bits,rnd);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -122,7 +136,8 @@ namespace dlib ...@@ -122,7 +136,8 @@ namespace dlib
projection_hash create_max_margin_projection_hash ( projection_hash create_max_margin_projection_hash (
const vector_type& v, const vector_type& v,
const int bits, const int bits,
const double C = 10 const double C,
dlib::rand& rnd
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -155,7 +170,6 @@ namespace dlib ...@@ -155,7 +170,6 @@ namespace dlib
matrix<double> whiten = trans(chol(pinv(rc.covariance()))); matrix<double> whiten = trans(chol(pinv(rc.covariance())));
const matrix<double,0,1> meanval = whiten*rc.mean(); const matrix<double,0,1> meanval = whiten*rc.mean();
dlib::rand rnd;
typedef matrix<double,0,1> sample_type; typedef matrix<double,0,1> sample_type;
...@@ -195,6 +209,21 @@ namespace dlib ...@@ -195,6 +209,21 @@ namespace dlib
return projection_hash(proj*whiten, offset-proj*meanval); return projection_hash(proj*whiten, offset-proj*meanval);
} }
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
projection_hash create_max_margin_projection_hash (
const vector_type& v,
const int bits,
const double C = 10
)
{
dlib::rand rnd;
return create_max_margin_projection_hash(v,bits,C,rnd);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -16,7 +16,8 @@ namespace dlib ...@@ -16,7 +16,8 @@ namespace dlib
> >
projection_hash create_random_projection_hash ( projection_hash create_random_projection_hash (
const vector_type& v, const vector_type& v,
const int bits const int bits,
dlib::rand& rnd
); );
/*! /*!
requires requires
...@@ -30,6 +31,7 @@ namespace dlib ...@@ -30,6 +31,7 @@ namespace dlib
- v[i].size() == v[j].size() - v[i].size() == v[j].size()
- i.e. v contains only column vectors and all the column vectors - i.e. v contains only column vectors and all the column vectors
have the same non-zero length have the same non-zero length
- rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
ensures ensures
- returns a hash function H such that: - returns a hash function H such that:
- H.num_hash_bins() == pow(2,bits) - H.num_hash_bins() == pow(2,bits)
...@@ -39,6 +41,35 @@ namespace dlib ...@@ -39,6 +41,35 @@ namespace dlib
particular, each plane normal vector is filled with Gaussian random particular, each plane normal vector is filled with Gaussian random
numbers and we also perform basic centering to ensure the plane passes numbers and we also perform basic centering to ensure the plane passes
though the data. though the data.
- This function uses the supplied random number generator, rnd, to drive part
of it's processing. Therefore, giving different random number generators
will produce different outputs.
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
projection_hash create_random_projection_hash (
const vector_type& v,
const int bits
);
/*!
requires
- 0 < bits <= 32
- v.size() > 1
- vector_type == a std::vector or compatible type containing dlib::matrix
objects, each representing a column vector of the same size.
- for all valid i, j:
- is_col_vector(v[i]) == true
- v[i].size() > 0
- v[i].size() == v[j].size()
- i.e. v contains only column vectors and all the column vectors
have the same non-zero length
ensures
- returns create_random_projection_hash(v,bits,dlib::rand())
(i.e. calls the above function with a default initialized random number generator)
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -49,7 +80,8 @@ namespace dlib ...@@ -49,7 +80,8 @@ namespace dlib
projection_hash create_max_margin_projection_hash ( projection_hash create_max_margin_projection_hash (
const vector_type& v, const vector_type& v,
const int bits, const int bits,
const double C = 10 const double C,
dlib::rand& rnd
); );
/*! /*!
requires requires
...@@ -63,6 +95,7 @@ namespace dlib ...@@ -63,6 +95,7 @@ namespace dlib
- v[i].size() == v[j].size() - v[i].size() == v[j].size()
- i.e. v contains only column vectors and all the column vectors - i.e. v contains only column vectors and all the column vectors
have the same non-zero length have the same non-zero length
- rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
ensures ensures
- returns a hash function H such that: - returns a hash function H such that:
- H.num_hash_bins() == pow(2,bits) - H.num_hash_bins() == pow(2,bits)
...@@ -74,6 +107,36 @@ namespace dlib ...@@ -74,6 +107,36 @@ namespace dlib
In particular, we use the svm_c_linear_dcd_trainer to generate planes. In particular, we use the svm_c_linear_dcd_trainer to generate planes.
We train it on randomly selected and randomly labeled points from v. We train it on randomly selected and randomly labeled points from v.
The C SVM parameter is set to the given C argument. The C SVM parameter is set to the given C argument.
- This function uses the supplied random number generator, rnd, to drive part
of it's processing. Therefore, giving different random number generators
will produce different outputs.
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
projection_hash create_max_margin_projection_hash (
const vector_type& v,
const int bits,
const double C = 10
);
/*!
requires
- 0 < bits <= 32
- v.size() > 1
- vector_type == a std::vector or compatible type containing dlib::matrix
objects, each representing a column vector of the same size.
- for all valid i, j:
- is_col_vector(v[i]) == true
- v[i].size() > 0
- v[i].size() == v[j].size()
- i.e. v contains only column vectors and all the column vectors
have the same non-zero length
ensures
- returns create_max_margin_projection_hash(v,bits,C,dlib::rand())
(i.e. calls the above function with a default initialized random number generator)
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment