Commit dd43ba44 authored by Davis King

Added an implementation of the least-squares policy iteration algorithm.

parent 7d32c4d1
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONTRoL_
#define DLIB_CONTRoL_
#include "control/lspi.h"
#endif // DLIB_CONTRoL_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
#include "approximate_linear_models_abstract.h"
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
struct process_sample
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            One observed transition of a process: performing `action` while in
            `state` moved the process to `next_state` and yielded `reward`.
            The state and action types are taken from the feature_extractor.
    !*/

    typedef typename feature_extractor::state_type state_type;
    typedef typename feature_extractor::action_type action_type;

    // Value-initialize every field.  Without this a default constructed sample
    // would carry an indeterminate reward (and indeterminate state/action values
    // whenever those types are built-in types such as int), which is then read
    // by serialize() or by a trainer.
    process_sample(
    ) : state(), action(), next_state(), reward(0) {}

    // Builds a sample from its four components: s = state, a = action taken,
    // n = resulting next state, r = reward received.
    process_sample(
        const state_type& s,
        const action_type& a,
        const state_type& n,
        const double& r
    ) : state(s), action(a), next_state(n), reward(r) {}

    state_type  state;
    action_type action;
    state_type  next_state;
    double      reward;
};
// Writes a process_sample to `out`.  The fields are emitted in the order
// state, action, next_state, reward; deserialize() must read them back in
// exactly this order.
template <typename feature_extractor>
void serialize (
    const process_sample<feature_extractor>& item,
    std::ostream& out
)
{
    serialize(item.state,      out);
    serialize(item.action,     out);
    serialize(item.next_state, out);
    serialize(item.reward,     out);
}
// Reads a process_sample from `in`.  The fields are consumed in the order
// state, action, next_state, reward, mirroring the matching serialize().
template <typename feature_extractor>
void deserialize (
    process_sample<feature_extractor>& item,
    std::istream& in
)
{
    deserialize(item.state,      in);
    deserialize(item.action,     in);
    deserialize(item.next_state, in);
    deserialize(item.reward,     in);
}
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class policy
{
public:
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
policy (
)
{
w.set_size(fe.num_features());
w = 0;
}
policy (
const matrix<double,0,1>& weights_,
const feature_extractor& fe_
) : w(weights_), fe(fe_) {}
action_type operator() (
const state_type& state
) const
{
return fe.find_best_action(state,w);
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
const matrix<double,0,1>& get_weights (
) const { return w; }
private:
matrix<double,0,1> w;
feature_extractor fe;
};
// Writes a policy to `out` as: format version (currently 1), the feature
// extractor, then the weight vector.
template <typename feature_extractor>
inline void serialize (
    const policy<feature_extractor>& item,
    std::ostream& out
)
{
    int format_version = 1;
    serialize(format_version, out);

    serialize(item.get_feature_extractor(), out);
    serialize(item.get_weights(), out);
}
// Reads a policy from `in`.  Expects the layout produced by the matching
// serialize(): version number, feature extractor, weight vector.  Throws
// serialization_error when the stored version is not 1.  `item` is only
// overwritten after both parts have been read.
template <typename feature_extractor>
inline void deserialize (
    policy<feature_extractor>& item,
    std::istream& in
)
{
    int stored_version = 0;
    deserialize(stored_version, in);
    if (stored_version != 1)
    {
        throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
    }

    feature_extractor extractor;
    matrix<double,0,1> weights;
    deserialize(extractor, in);
    deserialize(weights, in);

    item = policy<feature_extractor>(weights, extractor);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
#ifdef DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
namespace dlib
{
}
#endif // DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LSPI_Hh_
#define DLIB_LSPI_Hh_
#include "lspi_abstract.h"
#include "approximate_linear_models.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
class lspi
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A trainer implementing least-squares policy iteration.  Given a set
            of observed transitions (objects with state, action, next_state and
            reward members, e.g. process_sample) train() returns a
            policy<feature_extractor>.  The weights are found by repeatedly
            solving a regularized linear system until they change by no more
            than get_epsilon() or get_max_iterations() passes have been made.

            Defaults (see init()): lambda 0.01, discount 0.8, epsilon 0.01,
            quiet, at most 100 iterations.
    !*/

public:

    typedef typename feature_extractor::state_type state_type;
    typedef typename feature_extractor::action_type action_type;

    // Creates a trainer that uses a copy of fe_ and the default settings.
    explicit lspi(
        const feature_extractor& fe_
    ) : fe(fe_)
    {
        init();
    }

    // Creates a trainer with a default constructed feature extractor and the
    // default settings.
    lspi(
    )
    {
        init();
    }

    // Discount applied to the features of the next state/action pair in train().
    double get_discount (
    ) const { return discount; }

    void set_discount (
        double value
    )
    {
        discount = value;
    }

    const feature_extractor& get_feature_extractor (
    ) const { return fe; }

    // When verbose, train() prints the iteration number and weight change of
    // every pass to std::cout.
    void be_verbose (
    )
    {
        verbose = true;
    }

    void be_quiet (
    )
    {
        verbose = false;
    }

    // Sets the stopping tolerance: train() stops once the length of the weight
    // change between two passes is <= eps_.  Requires eps_ > 0.
    void set_epsilon (
        double eps_
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(eps_ > 0,
            "\t void lspi::set_epsilon(eps_)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t eps_: " << eps_
        );
        eps = eps_;
    }

    double get_epsilon (
    ) const
    {
        return eps;
    }

    // Sets the regularization strength: train() starts each linear system from
    // lambda_ times the identity matrix.  Requires lambda_ >= 0.
    void set_lambda (
        double lambda_
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(lambda_ >= 0,
            "\t void lspi::set_lambda(lambda_)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t lambda_: " << lambda_
        );
        lambda = lambda_;
    }

    double get_lambda (
    ) const
    {
        return lambda;
    }

    /*!
        ensures
            - #get_max_iterations() == max_iter
    !*/
    void set_max_iterations (
        unsigned long max_iter
    ) { max_iterations = max_iter; }

    /*!
        ensures
            - returns the maximum number of policy iteration passes train() is
              allowed to run before it is required to stop and return a result.
    !*/
    unsigned long get_max_iterations (
    ) const { return max_iterations; }

    /*!
        requires
            - samples.size() > 0
            - samples[i] has state, action, next_state and reward members
        ensures
            - runs least-squares policy iteration on the samples and returns a
              policy holding the learned weights and a copy of the feature
              extractor.
    !*/
    template <typename vector_type>
    policy<feature_extractor> train (
        const vector_type& samples
    ) const
    {
        // make sure requires clause is not broken.  With no samples b would
        // stay empty and pinv(A)*b below would multiply mismatched dimensions.
        DLIB_ASSERT(samples.size() > 0,
            "\t policy lspi::train(samples)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t samples.size(): " << samples.size()
        );

        // Start from the all-zero weight vector.
        matrix<double,0,1> w(fe.num_features());
        w = 0;
        matrix<double,0,1> prev_w, b, f1, f2;
        matrix<double> A;
        double change;
        unsigned long iter = 0;
        do
        {
            // Accumulate the linear system A*w = b for the current policy:
            //   A = lambda*I + sum_i f1_i * trans(f1_i - discount*f2_i)
            //   b = sum_i f1_i * reward_i
            // where f1_i are the features of the observed (state, action) pair
            // and f2_i the features of the next state paired with the action
            // the current weights w consider best there.
            A = identity_matrix<double>(fe.num_features())*lambda;
            b = 0;
            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                fe.get_features(samples[i].state, samples[i].action, f1);
                fe.get_features(samples[i].next_state,
                                fe.find_best_action(samples[i].next_state,w),
                                f2);
                A += f1*trans(f1 - discount*f2);
                b += f1*samples[i].reward;
            }

            prev_w = w;
            if (feature_extractor::force_last_weight_to_1)
            {
                // The last weight is pinned to 1, so move its column of A to
                // the right hand side and solve only for the other weights.
                w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0));
            }
            else
            {
                w = pinv(A)*b;
            }

            change = length(w-prev_w);
            ++iter;
            if (verbose)
                std::cout << "iteration: " << iter << "\tchange: " << change << std::endl;
        } while(change > eps && iter < max_iterations);

        return policy<feature_extractor>(w,fe);
    }

private:

    // Applies the default settings shared by both constructors.
    void init()
    {
        lambda = 0.01;
        discount = 0.8;
        eps = 0.01;
        verbose = false;
        max_iterations = 100;
    }

    double lambda;
    double discount;
    double eps;
    bool verbose;
    unsigned long max_iterations;
    feature_extractor fe;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_LSPI_Hh_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_LSPI_ABSTRACT_Hh_
#ifdef DLIB_LSPI_ABSTRACT_Hh_
namespace dlib
{
}
#endif // DLIB_LSPI_ABSTRACT_Hh_
......@@ -68,6 +68,7 @@ set (tests
learning_to_track.cpp
least_squares.cpp
linear_manifold_regularizer.cpp
lspi.cpp
lz77_buffer.cpp
map.cpp
matrix2.cpp
......
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/control.h>
#include <vector>
#include <sstream>
#include <ctime>
namespace
{
// Make the test framework, dlib and std names visible inside this file-local
// (anonymous) namespace.
using namespace test;
using namespace dlib;
using namespace std;
// Logger used by the tests in this file, constructed with the name "test.lspi".
dlib::logger dlog("test.lspi");
// Feature extractor used by the lspi tests: a chain of num_states states with
// two actions per state.  Each (state, action) pair gets its own indicator
// feature at index state*2 + action.  When have_prior is true an extra last
// feature carries a per-pair offset and the trainer is told (through
// force_last_weight_to_1) to keep that feature's weight fixed at 1.
template <bool have_prior>
struct chain_model
{
    typedef int state_type;
    typedef int action_type; // 0 is move left, 1 is move right

    const static bool force_last_weight_to_1 = have_prior;
    const static int num_states = 4; // not required in the model interface

    // One offset per (state, action) pair, indexed by state*2 + action.
    matrix<double,8,1> offset;

    chain_model()
    {
        offset = 2.048, 2.56,
                 2.048, 3.2,
                 2.56,  4,
                 3.2,   5;

        // Without a prior the offsets play no role, so zero them out.
        if (!have_prior)
            offset = 0;
    }

    // Number of features: one per (state, action) pair, plus the prior
    // feature when have_prior is true.
    unsigned long num_features() const
    {
        return have_prior ? num_states*2 + 1 : num_states*2;
    }

    // Picks the action with the larger weight-plus-offset score for `state`;
    // ties go to action 0.
    action_type find_best_action (
        const state_type& state,
        const matrix<double,0,1>& w
    ) const
    {
        const double score_left  = w(state*2)   + offset(state*2);
        const double score_right = w(state*2+1) + offset(state*2+1);
        return (score_left >= score_right) ? 0 : 1;
    }

    // Fills `feats` with the feature vector of the given (state, action) pair.
    void get_features (
        const state_type& state,
        const action_type& action,
        matrix<double,0,1>& feats
    ) const
    {
        const int pair_index = state*2 + action;

        feats.set_size(num_features());
        feats = 0;
        feats(pair_index) = 1;
        if (have_prior)
            feats(num_features()-1) = offset(pair_index);
    }
};
void test_lspi_prior1()
{
print_spinner();
typedef process_sample<chain_model<true> > sample_type;
std::vector<sample_type> samples;
samples.push_back(sample_type(0,0,0,0));
samples.push_back(sample_type(0,1,1,0));
samples.push_back(sample_type(1,0,0,0));
samples.push_back(sample_type(1,1,2,0));
samples.push_back(sample_type(2,0,1,0));
samples.push_back(sample_type(2,1,3,0));
samples.push_back(sample_type(3,0,2,0));
samples.push_back(sample_type(3,1,3,1));
lspi<chain_model<true> > trainer;
//trainer.be_verbose();
trainer.set_lambda(0);
policy<chain_model<true> > pol = trainer.train(samples);
dlog << LINFO << pol.get_weights();
matrix<double,0,1> w = pol.get_weights();
DLIB_TEST(pol.get_weights().size() == 9);
DLIB_TEST(w(w.size()-1) == 1);
w(w.size()-1) = 0;
DLIB_TEST_MSG(length(w) < 1e-12, length(w));
dlog << LINFO << "action: " << pol(0);
dlog << LINFO << "action: " << pol(1);
dlog << LINFO << "action: " << pol(2);
dlog << LINFO << "action: " << pol(3);
DLIB_TEST(pol(0) == 1);
DLIB_TEST(pol(1) == 1);
DLIB_TEST(pol(2) == 1);
DLIB_TEST(pol(3) == 1);
}
void test_lspi_prior2()
{
print_spinner();
typedef process_sample<chain_model<true> > sample_type;
std::vector<sample_type> samples;
samples.push_back(sample_type(0,0,0,0));
samples.push_back(sample_type(0,1,1,0));
samples.push_back(sample_type(1,0,0,0));
samples.push_back(sample_type(1,1,2,0));
samples.push_back(sample_type(2,0,1,0));
samples.push_back(sample_type(2,1,3,1));
samples.push_back(sample_type(3,0,2,0));
samples.push_back(sample_type(3,1,3,0));
lspi<chain_model<true> > trainer;
//trainer.be_verbose();
trainer.set_lambda(0);
policy<chain_model<true> > pol = trainer.train(samples);
dlog << LINFO << "action: " << pol(0);
dlog << LINFO << "action: " << pol(1);
dlog << LINFO << "action: " << pol(2);
dlog << LINFO << "action: " << pol(3);
DLIB_TEST(pol(0) == 1);
DLIB_TEST(pol(1) == 1);
DLIB_TEST(pol(2) == 1);
DLIB_TEST(pol(3) == 0);
}
void test_lspi_noprior1()
{
print_spinner();
typedef process_sample<chain_model<false> > sample_type;
std::vector<sample_type> samples;
samples.push_back(sample_type(0,0,0,0));
samples.push_back(sample_type(0,1,1,0));
samples.push_back(sample_type(1,0,0,0));
samples.push_back(sample_type(1,1,2,0));
samples.push_back(sample_type(2,0,1,0));
samples.push_back(sample_type(2,1,3,0));
samples.push_back(sample_type(3,0,2,0));
samples.push_back(sample_type(3,1,3,1));
lspi<chain_model<false> > trainer;
//trainer.be_verbose();
trainer.set_lambda(0.01);
policy<chain_model<false> > pol = trainer.train(samples);
dlog << LINFO << pol.get_weights();
DLIB_TEST(pol.get_weights().size() == 8);
dlog << LINFO << "action: " << pol(0);
dlog << LINFO << "action: " << pol(1);
dlog << LINFO << "action: " << pol(2);
dlog << LINFO << "action: " << pol(3);
DLIB_TEST(pol(0) == 1);
DLIB_TEST(pol(1) == 1);
DLIB_TEST(pol(2) == 1);
DLIB_TEST(pol(3) == 1);
}
void test_lspi_noprior2()
{
print_spinner();
typedef process_sample<chain_model<false> > sample_type;
std::vector<sample_type> samples;
samples.push_back(sample_type(0,0,0,0));
samples.push_back(sample_type(0,1,1,0));
samples.push_back(sample_type(1,0,0,0));
samples.push_back(sample_type(1,1,2,1));
samples.push_back(sample_type(2,0,1,0));
samples.push_back(sample_type(2,1,3,0));
samples.push_back(sample_type(3,0,2,0));
samples.push_back(sample_type(3,1,3,0));
lspi<chain_model<false> > trainer;
//trainer.be_verbose();
trainer.set_lambda(0.01);
policy<chain_model<false> > pol = trainer.train(samples);
dlog << LINFO << pol.get_weights();
DLIB_TEST(pol.get_weights().size() == 8);
dlog << LINFO << "action: " << pol(0);
dlog << LINFO << "action: " << pol(1);
dlog << LINFO << "action: " << pol(2);
dlog << LINFO << "action: " << pol(3);
DLIB_TEST(pol(0) == 1);
DLIB_TEST(pol(1) == 1);
DLIB_TEST(pol(2) == 0);
DLIB_TEST(pol(3) == 0);
}
// Test-suite entry for the lspi tests.  The tester base class is given the
// command line name of the test, its description and the number of command
// line arguments it takes; perform_test() runs the individual test functions.
class lspi_tester : public tester
{
public:

    lspi_tester()
        : tester(
            "test_lspi",                      // the command line argument name for this test
            "Run tests on the lspi object.",  // the command line argument description
            0                                 // the number of command line arguments for this test
          )
    {
    }

    // Runs the two prior-based tests followed by the two tests without a prior.
    void perform_test()
    {
        test_lspi_prior1();
        test_lspi_prior2();
        test_lspi_noprior1();
        test_lspi_noprior2();
    }
};
// File-scope instance of the tester defined above.
// NOTE(review): presumably constructing it is what registers the test with the
// test framework -- confirm against tester.h.
lspi_tester a;
}
......@@ -83,6 +83,7 @@ SRC += kmeans.cpp
SRC += learning_to_track.cpp
SRC += least_squares.cpp
SRC += linear_manifold_regularizer.cpp
SRC += lspi.cpp
SRC += lz77_buffer.cpp
SRC += map.cpp
SRC += matrix2.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment