-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add functionality to evaluate a test dataset
- Loading branch information
1 parent
ea9b2fe
commit 90a2137
Showing
5 changed files
with
78 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,87 @@ | ||
use std::path::Path; | ||
use serenade_optimized::{io, vmisknn}; | ||
|
||
use serenade_optimized::metrics::mrr::Mrr; | ||
use serenade_optimized::metrics::SessionMetric; | ||
use serenade_optimized::vmisknn::vmis_index::VMISIndex; | ||
use serenade_optimized::config::AppConfig; | ||
use serenade_optimized::metrics::evaluation_reporter::EvaluationReporter; | ||
use serenade_optimized::stopwatch::Stopwatch; | ||
|
||
fn main() { | ||
// hyper-parameters | ||
let n_most_recent_sessions = 1500; | ||
let neighborhood_size_k = 500; | ||
let last_items_in_session = 3; | ||
let idf_weighting = 1.0; | ||
let enable_business_logic = false; | ||
let config_path = std::env::args().nth(1).unwrap_or_default(); | ||
let config = AppConfig::new(config_path); | ||
|
||
let path_to_training = std::env::args() | ||
.nth(1) | ||
.expect("Training data file not specified!"); | ||
let m_most_recent_sessions = config.model.m_most_recent_sessions; | ||
let neighborhood_size_k = config.model.neighborhood_size_k; | ||
let num_items_to_recommend = config.model.num_items_to_recommend; | ||
let max_items_in_session = config.model.max_items_in_session; | ||
let enable_business_logic = config.logic.enable_business_logic; | ||
|
||
println!("training_data_file:{}", path_to_training); | ||
let training_data_path = Path::new(&config.data.training_data_path); | ||
let vmis_index = if training_data_path.is_dir() { | ||
// By default we use an index that is computed offline on billions of user-item interactions. | ||
VMISIndex::new(&config.data.training_data_path) | ||
} else if training_data_path.is_file() { | ||
// The following line creates an index directly from a csv file as input. | ||
VMISIndex::new_from_csv( | ||
&config.data.training_data_path, | ||
config.model.m_most_recent_sessions, | ||
config.model.idf_weighting as f64, | ||
) | ||
} else { | ||
panic!( | ||
"Training data file does not exist: {}", | ||
&config.data.training_data_path | ||
) | ||
}; | ||
|
||
let test_data_file = std::env::args() | ||
.nth(2) | ||
.expect("Test data file not specified!"); | ||
let test_data_file = config.hyperparam.test_data_path; | ||
println!("test_data_file:{}", test_data_file); | ||
|
||
let vmis_index = VMISIndex::new_from_csv(&*path_to_training, n_most_recent_sessions, idf_weighting); | ||
|
||
let ordered_test_sessions = io::read_test_data_evolving(&*test_data_file); | ||
|
||
let qty_max_reco_results = 20; | ||
let mut mymetric = Mrr::new(qty_max_reco_results); | ||
let mut reporter = EvaluationReporter::new(&io::read_training_data(&*config.data.training_data_path), num_items_to_recommend); | ||
|
||
let mut stopwatch = Stopwatch::new(); | ||
|
||
ordered_test_sessions | ||
.iter() | ||
.for_each(|(_session_id, evolving_session_items)| { | ||
for session_state in 1..evolving_session_items.len() { | ||
// use last x items of evolving session | ||
let start_index = if session_state > last_items_in_session { | ||
session_state - last_items_in_session | ||
let start_index = if session_state > max_items_in_session { | ||
session_state - max_items_in_session | ||
} else { | ||
0 | ||
}; | ||
let session: &[u64] = &evolving_session_items[start_index..session_state]; | ||
stopwatch.start(); | ||
let recommendations = vmisknn::predict( | ||
&vmis_index, | ||
&session, | ||
neighborhood_size_k, | ||
n_most_recent_sessions, | ||
qty_max_reco_results, | ||
m_most_recent_sessions, | ||
num_items_to_recommend, | ||
enable_business_logic, | ||
); | ||
|
||
stopwatch.stop(&start_index); | ||
let recommended_items = recommendations | ||
.into_sorted_vec() | ||
.iter() | ||
.map(|scored| scored.id) | ||
.collect::<Vec<u64>>(); | ||
|
||
let actual_next_items = Vec::from(&evolving_session_items[session_state..]); | ||
mymetric.add(&recommended_items, &actual_next_items); | ||
reporter.add(&recommended_items, &actual_next_items); | ||
} | ||
}); | ||
|
||
println!("{}: {}", mymetric.get_name(), mymetric.result()); | ||
println!("==============================================================="); | ||
println!("=== START EVALUATING TEST FILE ===="); | ||
println!("==============================================================="); | ||
println!("{}", reporter.get_name()); | ||
println!("{}", reporter.result()); | ||
println!("Qty test evaluations: {}", stopwatch.get_n()); | ||
println!("Prediction latency"); | ||
println!("p90 (microseconds): {}", stopwatch.get_percentile_in_micros(90.0)); | ||
println!("p95 (microseconds): {}", stopwatch.get_percentile_in_micros(95.0)); | ||
println!("p99.5 (microseconds): {}", stopwatch.get_percentile_in_micros(99.5)); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters