Commit cc2de0e9 authored by Davis King's avatar Davis King

Improved the ranking example

parent cc708d04
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
In this example, we will create a simple test dataset and show how to learn In this example, we will create a simple test dataset and show how to learn
a ranking function on it. The purpose of the function will be to give a ranking function from it. The purpose of the function will be to give
"relevant" objects higher scores than "non-relevant" objects. The idea is "relevant" objects higher scores than "non-relevant" objects. The idea is
that you use this score to order the objects so that the most relevant that you use this score to order the objects so that the most relevant
objects come to the top of the ranked list. objects come to the top of the ranked list.
...@@ -43,16 +43,17 @@ int main() ...@@ -43,16 +43,17 @@ int main()
// should rank higher than other vectors. So what we do is make // should rank higher than other vectors. So what we do is make
// examples of relevant (i.e. high ranking) and non-relevant (i.e. low // examples of relevant (i.e. high ranking) and non-relevant (i.e. low
// ranking) vectors and store them into a ranking_pair object like so: // ranking) vectors and store them into a ranking_pair object like so:
ranking_pair<sample_type> query; ranking_pair<sample_type> data;
sample_type samp; sample_type samp;
// Make one relevant example. // Make one relevant example.
samp = 1, 0; samp = 1, 0;
query.relevant.push_back(samp); data.relevant.push_back(samp);
// Now make a non-relevant example. // Now make a non-relevant example.
samp = 0, 1; samp = 0, 1;
query.nonrelevant.push_back(samp); data.nonrelevant.push_back(samp);
// Now that we have some data, we can use a machine learning method to // Now that we have some data, we can use a machine learning method to
// learn a function that will give high scores to the relevant vectors // learn a function that will give high scores to the relevant vectors
...@@ -66,17 +67,29 @@ int main() ...@@ -66,17 +67,29 @@ int main()
// linear_kernel. // linear_kernel.
typedef linear_kernel<sample_type> kernel_type; typedef linear_kernel<sample_type> kernel_type;
// Now make a trainer and tell it to learn a ranking function based on
// our data.
svm_rank_trainer<kernel_type> trainer; svm_rank_trainer<kernel_type> trainer;
decision_function<kernel_type> rank = trainer.train(query); decision_function<kernel_type> rank = trainer.train(data);
// Now if you call rank on a vector it will output a ranking score. In
// particular, the ranking score for relevant vectors should be larger
// than the score for non-relevant vectors.
cout << "ranking score for a relevant vector: " << rank(data.relevant[0]) << endl;
cout << "ranking score for a non-relevant vector: " << rank(data.nonrelevant[0]) << endl;
// These output the following:
/*
ranking score for a relevant vector: 0.5
ranking score for a non-relevant vector: -0.5
*/
cout << "ranking score for a relevant vector: " << rank(query.relevant[0]) << endl;
cout << "ranking score for a non-relevant vector: " << rank(query.nonrelevant[0]) << endl;
// If we want an overall measure of ranking accuracy, we can find out // If we want an overall measure of ranking accuracy, we can find out
// how often a non-relevant vector was ranked ahead of a relevant // how often a non-relevant vector was ranked ahead of a relevant
// vector like so. This is a number between 0 and 1. A value of 1 // vector using test_ranking_function(). In this case, it returns a
// means everything was ranked perfectly. // value of 1, indicating that the rank function outputs a perfect
cout << "accuracy: " << test_ranking_function(rank, query) << endl; // ranking.
cout << "accuracy: " << test_ranking_function(rank, data) << endl;
// We can also see the ranking weights: // We can also see the ranking weights:
cout << "learned ranking weights: \n" << rank.basis_vectors(0) << endl; cout << "learned ranking weights: \n" << rank.basis_vectors(0) << endl;
...@@ -87,12 +100,42 @@ int main() ...@@ -87,12 +100,42 @@ int main()
// In the above example, our data contains just two sets of objects.
// The relevant set and non-relevant set. The trainer is attempting to
// find a ranking function that gives every relevant vector a higher
// score than every non-relevant vector. Sometimes what you want to do
// is a little more complex than this.
//
// For example, in the web page ranking example we have to rank pages
// based on a user's query. In this case, each query will have its own
// set of relevant and non-relevant documents. What might be relevant
// to one query may well be non-relevant to another. So in this case
// we don't have a single global set of relevant web pages and another
// set of non-relevant web pages.
//
// To handle cases like this, we can simply give multiple ranking_pair
// instances to the trainer. Each ranking_pair representing the
// relevant/non-relevant sets for a particular query. An example is
// shown below (for simplicity, we reuse our data from above to make 4
// identical "queries").
std::vector<ranking_pair<sample_type> > queries; std::vector<ranking_pair<sample_type> > queries;
queries.push_back(query); queries.push_back(data);
queries.push_back(query); queries.push_back(data);
queries.push_back(query); queries.push_back(data);
queries.push_back(query); queries.push_back(data);
// We train just as before.
rank = trainer.train(queries);
// Now that we have multiple ranking_pair instances, we can also use
// cross_validate_ranking_trainer(). This performs cross-validation by
// splitting the queries up into folds. That is, it lets the trainer
// train on a subset of ranking_pair instances and tests on the rest.
// It does this over 4 different splits and returns the overall ranking
// accuracy based on the held out data.
cout << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, queries, 4) << endl; cout << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, queries, 4) << endl;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment