Commit cc2de0e9 authored by Davis King's avatar Davis King

Improved the ranking example

parent cc708d04
......@@ -9,7 +9,7 @@
In this example, we will create a simple test dataset and show how to learn
a ranking function on it. The purpose of the function will be to give
a ranking function from it. The purpose of the function will be to give
"relevant" objects higher scores than "non-relevant" objects. The idea is
that you use this score to order the objects so that the most relevant
objects come to the top of the ranked list.
......@@ -43,16 +43,17 @@ int main()
// should rank higher than other vectors. So what we do is make
// examples of relevant (i.e. high ranking) and non-relevant (i.e. low
// ranking) vectors and store them into a ranking_pair object like so:
ranking_pair<sample_type> query;
ranking_pair<sample_type> data;
sample_type samp;
// Make one relevant example.
samp = 1, 0;
query.relevant.push_back(samp);
data.relevant.push_back(samp);
// Now make a non-relevant example.
samp = 0, 1;
query.nonrelevant.push_back(samp);
data.nonrelevant.push_back(samp);
// Now that we have some data, we can use a machine learning method to
// learn a function that will give high scores to the relevant vectors
......@@ -66,17 +67,29 @@ int main()
// linear_kernel.
typedef linear_kernel<sample_type> kernel_type;
// Now make a trainer and tell it to learn a ranking function based on
// our data.
svm_rank_trainer<kernel_type> trainer;
decision_function<kernel_type> rank = trainer.train(query);
decision_function<kernel_type> rank = trainer.train(data);
// Now if you call rank on a vector it will output a ranking score. In
// particular, the ranking score for relevant vectors should be larger
// than the score for non-relevant vectors.
cout << "ranking score for a relevant vector: " << rank(data.relevant[0]) << endl;
cout << "ranking score for a non-relevant vector: " << rank(data.nonrelevant[0]) << endl;
// These output the following:
/*
ranking score for a relevant vector: 0.5
ranking score for a non-relevant vector: -0.5
*/
cout << "ranking score for a relevant vector: " << rank(query.relevant[0]) << endl;
cout << "ranking score for a non-relevant vector: " << rank(query.nonrelevant[0]) << endl;
// If we want an overall measure of ranking accuracy, we can find out
// how often a non-relevant vector was ranked ahead of a relevant
// vector like so. This is a number between 0 and 1. A value of 1
// means everything was ranked perfectly.
cout << "accuracy: " << test_ranking_function(rank, query) << endl;
// vector using test_ranking_function(). In this case, it returns a
// value of 1, indicating that the rank function outputs a perfect
// ranking.
cout << "accuracy: " << test_ranking_function(rank, data) << endl;
// We can also see the ranking weights:
cout << "learned ranking weights: \n" << rank.basis_vectors(0) << endl;
......@@ -87,12 +100,42 @@ int main()
// In the above example, our data contains just two sets of objects.
// The relevant set and non-relevant set. The trainer is attempting to
// find a ranking function that gives every relevant vector a higher
// score than every non-relevant vector. Sometimes what you want to do
// is a little more complex than this.
//
// For example, in the web page ranking example we have to rank pages
// based on a user's query. In this case, each query will have its own
// set of relevant and non-relevant documents. What might be relevant
// to one query may well be non-relevant to another. So in this case
// we don't have a single global set of relevant web pages and another
// set of non-relevant web pages.
//
// To handle cases like this, we can simply give multiple ranking_pair
// instances to the trainer. Each ranking_pair representing the
// relevant/non-relevant sets for a particular query. An example is
// shown below (for simplicity, we reuse our data from above to make 4
// identical "queries").
std::vector<ranking_pair<sample_type> > queries;
queries.push_back(query);
queries.push_back(query);
queries.push_back(query);
queries.push_back(query);
queries.push_back(data);
queries.push_back(data);
queries.push_back(data);
queries.push_back(data);
// We train just as before.
rank = trainer.train(queries);
// Now that we have multiple ranking_pair instances, we can also use
// cross_validate_ranking_trainer(). This performs cross-validation by
// splitting the queries up into folds. That is, it lets the trainer
// train on a subset of ranking_pair instances and tests on the rest.
// It does this over 4 different splits and returns the overall ranking
// accuracy based on the held out data.
cout << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, queries, 4) << endl;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment