Commit e4c7a6fc authored by Davis King's avatar Davis King

Added overloads of the randomize_samples() functions that allow

the user to supply a random number generator.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403338
parent 7bca0ec6
...@@ -72,6 +72,19 @@ namespace dlib ...@@ -72,6 +72,19 @@ namespace dlib
!*/ !*/
}; };
// ----------------------------------------------------------------------------------------
template <typename T>
struct is_rand : public default_is_kind_value
{
/*!
- if (T is an implementation of rand/rand_kernel_abstract.h) then
- is_rand<T>::value == true
- else
- is_rand<T>::value == false
!*/
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// Implementation details // Implementation details
......
...@@ -89,6 +89,14 @@ namespace dlib ...@@ -89,6 +89,14 @@ namespace dlib
rand_float_1<rand_base>& b rand_float_1<rand_base>& b
) { a.swap(b); } ) { a.swap(b); }
// ----------------------------------------------------------------------------------------
template <typename rand_base>
struct is_rand<rand_float_1<rand_base> >
{
static const bool value = true;
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "../algs.h" #include "../algs.h"
#include "rand_kernel_abstract.h" #include "rand_kernel_abstract.h"
#include "mersenne_twister.h" #include "mersenne_twister.h"
#include "../is_kind.h"
namespace dlib namespace dlib
{ {
...@@ -115,6 +116,13 @@ namespace dlib ...@@ -115,6 +116,13 @@ namespace dlib
rand_kernel_1& b rand_kernel_1& b
) { a.swap(b); } ) { a.swap(b); }
template <>
struct is_rand<rand_kernel_1>
{
static const bool value = true;
};
} }
#endif // DLIB_RAND_KERNEl_1_ #endif // DLIB_RAND_KERNEl_1_
......
...@@ -804,15 +804,18 @@ namespace dlib ...@@ -804,15 +804,18 @@ namespace dlib
folds); folds);
} }
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename T, typename T,
typename U typename U,
typename rand_type
> >
typename enable_if<is_matrix<T>,void>::type randomize_samples ( typename enable_if<is_matrix<T>,void>::type randomize_samples (
T& t, T& t,
U& u U& u,
rand_type& r
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -825,8 +828,6 @@ namespace dlib ...@@ -825,8 +828,6 @@ namespace dlib
<< "\n\t is_vector(u): " << (is_vector(u)? "true" : "false") << "\n\t is_vector(u): " << (is_vector(u)? "true" : "false")
); );
rand::kernel_1a r;
long n = t.size()-1; long n = t.size()-1;
while (n > 0) while (n > 0)
{ {
...@@ -848,11 +849,13 @@ namespace dlib ...@@ -848,11 +849,13 @@ namespace dlib
template < template <
typename T, typename T,
typename U typename U,
typename rand_type
> >
typename disable_if<is_matrix<T>,void>::type randomize_samples ( typename disable_if<is_matrix<T>,void>::type randomize_samples (
T& t, T& t,
U& u U& u,
rand_type& r
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -863,8 +866,6 @@ namespace dlib ...@@ -863,8 +866,6 @@ namespace dlib
<< "\n\t u.size(): " << u.size() << "\n\t u.size(): " << u.size()
); );
rand::kernel_1a r;
long n = t.size()-1; long n = t.size()-1;
while (n > 0) while (n > 0)
{ {
...@@ -885,10 +886,27 @@ namespace dlib ...@@ -885,10 +886,27 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename T typename T,
typename U
> >
typename enable_if<is_matrix<T>,void>::type randomize_samples ( typename disable_if<is_rand<U>,void>::type randomize_samples (
T& t T& t,
U& u
)
{
rand::kernel_1a r;
randomize_samples(t,u,r);
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename rand_type
>
typename enable_if_c<is_matrix<T>::value && is_rand<rand_type>::value,void>::type randomize_samples (
T& t,
rand_type& r
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -898,8 +916,6 @@ namespace dlib ...@@ -898,8 +916,6 @@ namespace dlib
<< "\n\t is_vector(t): " << (is_vector(t)? "true" : "false") << "\n\t is_vector(t): " << (is_vector(t)? "true" : "false")
); );
rand::kernel_1a r;
long n = t.size()-1; long n = t.size()-1;
while (n > 0) while (n > 0)
{ {
...@@ -919,14 +935,14 @@ namespace dlib ...@@ -919,14 +935,14 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename T typename T,
typename rand_type
> >
typename disable_if<is_matrix<T>,void>::type randomize_samples ( typename disable_if_c<(is_matrix<T>::value==true)||(is_rand<rand_type>::value==false),void>::type randomize_samples (
T& t T& t,
rand_type& r
) )
{ {
rand::kernel_1a r;
long n = t.size()-1; long n = t.size()-1;
while (n > 0) while (n > 0)
{ {
...@@ -943,6 +959,19 @@ namespace dlib ...@@ -943,6 +959,19 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <
typename T
>
void randomize_samples (
T& t
)
{
rand::kernel_1a r;
randomize_samples(t,r);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -365,6 +365,7 @@ namespace dlib ...@@ -365,6 +365,7 @@ namespace dlib
- std::bad_alloc - std::bad_alloc
!*/ !*/
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -387,6 +388,40 @@ namespace dlib ...@@ -387,6 +388,40 @@ namespace dlib
ensures ensures
- randomizes the order of the samples and labels but preserves - randomizes the order of the samples and labels but preserves
the pairing between each sample and its label the pairing between each sample and its label
- A default initialized random number generator is used to perform the randomizing.
Note that this means that each call this this function does the same thing.
That is, the random number generator always uses the same seed.
- for all valid i:
- let r == the random index samples(i) was moved to. then:
- #labels(r) == labels(i)
!*/
// ----------------------------------------------------------------------------------------
template <
typename T,
typename U,
typename rand_type
>
void randomize_samples (
T& samples,
U& labels,
rand_type& rnd
);
/*!
requires
- T == a matrix object or an object compatible with std::vector that contains
a swappable type.
- U == a matrix object or an object compatible with std::vector that contains
a swappable type.
- if samples or labels are matrix objects then is_vector(samples) == true and
is_vector(labels) == true
- samples.size() == labels.size()
- rand_type == a type that implements the dlib/rand/rand_kernel_abstract.h interface
ensures
- randomizes the order of the samples and labels but preserves
the pairing between each sample and its label
- the given rnd random number generator object is used to do the randomizing
- for all valid i: - for all valid i:
- let r == the random index samples(i) was moved to. then: - let r == the random index samples(i) was moved to. then:
- #labels(r) == labels(i) - #labels(r) == labels(i)
...@@ -407,9 +442,33 @@ namespace dlib ...@@ -407,9 +442,33 @@ namespace dlib
- if samples is a matrix then is_vector(samples) == true - if samples is a matrix then is_vector(samples) == true
ensures ensures
- randomizes the order of the elements inside samples - randomizes the order of the elements inside samples
- A default initialized random number generator is used to perform the randomizing.
Note that this means that each call this this function does the same thing.
That is, the random number generator always uses the same seed.
!*/
// ----------------------------------------------------------------------------------------
template <
typename T,
typename rand_type
>
void randomize_samples (
T& samples,
rand_type& rnd
);
/*!
requires
- T == a matrix object or an object compatible with std::vector that contains
a swappable type.
- if samples is a matrix then is_vector(samples) == true
ensures
- randomizes the order of the elements inside samples
- the given rnd random number generator object is used to do the randomizing
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment