Commit 1c269270 authored by Davis King

Added testing and cross validation routines for the python sequence segmenter interface.

parent a4590776
......@@ -355,6 +355,9 @@ void configure_trainer (
{
pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences.");
pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences.");
pyassert(params.window_size != 0, "Invalid window_size parameter, it must be > 0.");
pyassert(params.epsilon > 0, "Invalid epsilon parameter, it must be > 0.");
pyassert(params.C > 0, "Invalid C parameter, it must be > 0.");
const long dims = samples[0][0].size();
trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size));
......@@ -532,11 +535,252 @@ segmenter_type train_sparse (
// ----------------------------------------------------------------------------------------
// Holds the three scores produced by the sequence segmenter testing and cross
// validation routines in this file: precision, recall and F1.
struct segmenter_test
{
    // Start every score at 0.  This type is exposed to Python with a default
    // constructor, so without this a freshly created object would carry
    // indeterminate values.
    segmenter_test() : precision(0), recall(0), f1(0) {}

    double precision;
    double recall;
    double f1;
};
// Writes the scores stored in item to out, always in the order
// precision, recall, f1.  deserialize() must read them back in that same order.
void serialize(const segmenter_test& item, std::ostream& out)
{
    const double* const fields[] = { &item.precision, &item.recall, &item.f1 };
    for (int i = 0; i < 3; ++i)
        serialize(*fields[i], out);
}
// Reads the scores of item from in, in the order precision, recall, f1
// (the order in which serialize() wrote them).
void deserialize(segmenter_test& item, std::istream& in)
{
    double* const fields[] = { &item.precision, &item.recall, &item.f1 };
    for (int i = 0; i < 3; ++i)
        deserialize(*fields[i], in);
}
// Formats item as a single line of the form
//   precision: P recall: R f1-score: F
// This is the function bound to the Python __str__ of segmenter_test.
std::string segmenter_test__str__(const segmenter_test& item)
{
    std::ostringstream text;
    text << "precision: " << item.precision;
    text << " recall: " << item.recall;
    text << " f1-score: " << item.f1;
    return text.str();
}
// Returns the __str__ text of item wrapped in "< " and " >".
std::string segmenter_test__repr__(const segmenter_test& item)
{
    std::string text("< ");
    text += segmenter_test__str__(item);
    text += " >";
    return text;
}
// ----------------------------------------------------------------------------------------
// Runs test_sequence_segmenter() for a segmenter operating on dense_vect samples
// and repackages the resulting 1x3 matrix as a segmenter_test.
//
// Raises a Python error (via pyassert) if samples and segments do not form a
// valid sequence segmentation problem, and throws dlib::error if segmenter.mode
// is not one of the modes 0-7 handled here.
const segmenter_test test_sequence_segmenter1 (
    const segmenter_type& segmenter,
    const std::vector<std::vector<dense_vect> >& samples,
    const std::vector<ranges>& segments
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");

    // Evaluate with whichever segmenter member corresponds to the stored mode.
    matrix<double,1,3> scores;
    switch (segmenter.mode)
    {
        case 0:  scores = test_sequence_segmenter(segmenter.segmenter0, samples, segments); break;
        case 1:  scores = test_sequence_segmenter(segmenter.segmenter1, samples, segments); break;
        case 2:  scores = test_sequence_segmenter(segmenter.segmenter2, samples, segments); break;
        case 3:  scores = test_sequence_segmenter(segmenter.segmenter3, samples, segments); break;
        case 4:  scores = test_sequence_segmenter(segmenter.segmenter4, samples, segments); break;
        case 5:  scores = test_sequence_segmenter(segmenter.segmenter5, samples, segments); break;
        case 6:  scores = test_sequence_segmenter(segmenter.segmenter6, samples, segments); break;
        case 7:  scores = test_sequence_segmenter(segmenter.segmenter7, samples, segments); break;
        default: throw dlib::error("Invalid mode");
    }

    // Element 0 is stored as precision, 1 as recall and 2 as f1.
    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// Runs test_sequence_segmenter() for a segmenter operating on sparse_vect samples
// and repackages the resulting 1x3 matrix as a segmenter_test.
//
// Raises a Python error (via pyassert) if samples and segments do not form a
// valid sequence segmentation problem, and throws dlib::error if segmenter.mode
// is not one of the modes 8-15 handled here.
const segmenter_test test_sequence_segmenter2 (
    const segmenter_type& segmenter,
    const std::vector<std::vector<sparse_vect> >& samples,
    const std::vector<ranges>& segments
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");

    // Evaluate with whichever segmenter member corresponds to the stored mode.
    matrix<double,1,3> scores;
    switch (segmenter.mode)
    {
        case 8:  scores = test_sequence_segmenter(segmenter.segmenter8, samples, segments); break;
        case 9:  scores = test_sequence_segmenter(segmenter.segmenter9, samples, segments); break;
        case 10: scores = test_sequence_segmenter(segmenter.segmenter10, samples, segments); break;
        case 11: scores = test_sequence_segmenter(segmenter.segmenter11, samples, segments); break;
        case 12: scores = test_sequence_segmenter(segmenter.segmenter12, samples, segments); break;
        case 13: scores = test_sequence_segmenter(segmenter.segmenter13, samples, segments); break;
        case 14: scores = test_sequence_segmenter(segmenter.segmenter14, samples, segments); break;
        case 15: scores = test_sequence_segmenter(segmenter.segmenter15, samples, segments); break;
        default: throw dlib::error("Invalid mode");
    }

    // Element 0 is stored as precision, 1 as recall and 2 as f1.
    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// ----------------------------------------------------------------------------------------
// Cross validates a sequence segmentation trainer on dense_vect samples and
// repackages the resulting 1x3 matrix as a segmenter_test.
//
// The feature extractor type (segmenter_type::fe0 ... fe7) is chosen from three
// flags in params.  The trainer is set up by configure_trainer() and then handed
// to cross_validate_sequence_segmenter() together with the data and folds.
//
// Raises a Python error (via pyassert) if samples and segments do not form a
// valid sequence segmentation problem, or if folds is not in the range
// 1 < folds <= samples.size().
const segmenter_test cross_validate_sequence_segmenter1 (
    const std::vector<std::vector<dense_vect> >& samples,
    const std::vector<ranges>& segments,
    long folds,
    segmenter_params params
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
    pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range.");

    // Pack the three flags into a 3 bit index:
    //   use_BIO_model -> 4, use_high_order_features -> 2, allow_negative_weights -> 1.
    const int mode = (params.use_BIO_model           ? 4 : 0)
                   + (params.use_high_order_features ? 2 : 0)
                   + (params.allow_negative_weights  ? 1 : 0);

    matrix<double,1,3> scores;
    switch (mode)
    {
        case 0:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe0> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 1:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe1> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 2:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe2> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 3:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe3> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 4:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe4> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 5:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe5> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 6:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe6> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 7:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe7> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        default:
            throw dlib::error("Invalid mode");
    }

    // Element 0 is stored as precision, 1 as recall and 2 as f1.
    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// Cross validates a sequence segmentation trainer on sparse_vect samples and
// repackages the resulting 1x3 matrix as a segmenter_test.
//
// The feature extractor type (segmenter_type::fe8 ... fe15) is chosen from three
// flags in params.  The trainer is set up by configure_trainer() and then handed
// to cross_validate_sequence_segmenter() together with the data and folds.
//
// Raises a Python error (via pyassert) if samples and segments do not form a
// valid sequence segmentation problem, or if folds is not in the range
// 1 < folds <= samples.size().
const segmenter_test cross_validate_sequence_segmenter2 (
    const std::vector<std::vector<sparse_vect> >& samples,
    const std::vector<ranges>& segments,
    long folds,
    segmenter_params params
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
    pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range.");

    // Pack the three flags into a 3 bit index:
    //   use_BIO_model -> 4, use_high_order_features -> 2, allow_negative_weights -> 1,
    // then shift it by 8 because this overload dispatches on modes 8-15.
    const int mode = 8
                   + (params.use_BIO_model           ? 4 : 0)
                   + (params.use_high_order_features ? 2 : 0)
                   + (params.allow_negative_weights  ? 1 : 0);

    matrix<double,1,3> scores;
    switch (mode)
    {
        case 8:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 9:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 10:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 11:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 12:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 13:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 14:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 15:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        default:
            throw dlib::error("Invalid mode");
    }

    // Element 0 is stored as precision, 1 as recall and 2 as f1.
    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// ----------------------------------------------------------------------------------------
void bind_sequence_segmenter()
{
class_<segmenter_params>("segmenter_params",
"This class is used to define all the optional parameters to the \n\
train_sequence_segmenter() routine. ")
train_sequence_segmenter() and cross_validate_sequence_segmenter() routines. ")
.def_readwrite("use_BIO_model", &segmenter_params::use_BIO_model)
.def_readwrite("use_high_order_features", &segmenter_params::use_high_order_features)
.def_readwrite("allow_negative_weights", &segmenter_params::allow_negative_weights)
......@@ -545,6 +789,7 @@ train_sequence_segmenter() routine. ")
.def_readwrite("epsilon", &segmenter_params::epsilon)
.def_readwrite("max_cache_size", &segmenter_params::max_cache_size)
.def_readwrite("C", &segmenter_params::C, "SVM C parameter")
.def_readwrite("be_verbose", &segmenter_params::be_verbose)
.def("__repr__",&segmenter_params__repr__)
.def("__str__",&segmenter_params__str__)
.def_pickle(serialize_pickle<segmenter_params>());
......@@ -555,9 +800,26 @@ train_sequence_segmenter() routine. ")
.def_readonly("weights", &segmenter_type::get_weights)
.def_pickle(serialize_pickle<segmenter_type>());
class_<segmenter_test> ("segmenter_test")
.def_readwrite("precision", &segmenter_test::precision)
.def_readwrite("recall", &segmenter_test::recall)
.def_readwrite("f1", &segmenter_test::f1)
.def("__repr__",&segmenter_test__repr__)
.def("__str__",&segmenter_test__str__)
.def_pickle(serialize_pickle<segmenter_test>());
using boost::python::arg;
def("train_sequence_segmenter", train_dense, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
def("train_sequence_segmenter", train_sparse, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
def("test_sequence_segmenter", test_sequence_segmenter1);
def("test_sequence_segmenter", test_sequence_segmenter2);
def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter1,
(arg("samples"), arg("segments"), arg("folds"), arg("params")=segmenter_params()));
def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter2,
(arg("samples"), arg("segments"), arg("folds"), arg("params")=segmenter_params()));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment