From a5e560429177f38716453bcf38e45fe0e96093f9 Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:21:33 +0200 Subject: [PATCH 1/6] Add separate updates for counts and rewards It could be used in applications, where reward is naturally separated from the arm selection --- src/lib.rs | 6 +++++ src/softmax.rs | 59 +++++++++++++++++--------------------------------- 2 files changed, 26 insertions(+), 39 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3cf79ff..9e139ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,8 @@ use std::hash::{Hash}; use std::io; pub mod softmax; +pub mod ucb; +mod utils; #[derive(Debug, PartialEq, Clone)] pub struct BanditConfig { @@ -24,6 +26,10 @@ pub static DEFAULT_BANDIT_CONFIG : BanditConfig = BanditConfig{log_file: Option: pub trait MultiArmedBandit { fn select_arm(&self) -> A; fn update(&mut self, arm: A, reward: f64); + /// additional function to update counts of selected arm in multi-threaded applications + fn update_counts(&mut self, arm: &A); + /// additional function to update rewards of selected arm in multi-threaded applications + fn update_rewards(&mut self, arm: &A, reward: f64); /// stores the current state of the bandit algorithm in /// the supplied file. Every implementation has a corresponding diff --git a/src/softmax.rs b/src/softmax.rs index f4e92e7..5a29491 100644 --- a/src/softmax.rs +++ b/src/softmax.rs @@ -3,6 +3,7 @@ extern crate serde; extern crate serde_json; use super::{MultiArmedBandit, Identifiable, BanditConfig}; +use super::utils::{find_arm, log, log_command}; use std::collections::{HashMap}; use std::hash::{Hash}; use std::cmp::{Eq}; @@ -83,7 +84,6 @@ impl AnnealingSoftmax { } impl MultiArmedBandit for AnnealingSoftmax { - fn select_arm(&self) -> A { let mut t : u64 = 1; @@ -148,6 +148,25 @@ impl MultiArmedBandit for AnnealingSoftm self.log_update(&arm, val_norm); } + fn update_counts(&mut self, arm: &A) { + { + let n_ = self.counts.entry(arm.clone()).or_insert(0); + *n_ += 1; + } + log_command("update counts", arm); + } + + fn update_rewards(&mut self, arm: &A, reward: f64) { + let val_norm; + { + let n = self.counts.get(arm).copied().unwrap_or_default() as f64; + let val = self.values.entry(arm.clone()).or_insert(0.0); + *val = ((n - 1.0) / n) * *val + (1.0 / n) * reward; + val_norm = *val; + } + self.log_update(arm, val_norm); + } + fn save_bandit(&self, path: &Path) -> io::Result<()> { let mut counts = HashMap::new(); @@ -177,44 +196,6 @@ impl MultiArmedBandit for AnnealingSoftm } } -fn log_command(cmd: &str, arm: &A) -> String { - format!("{};{};{}", cmd, arm.ident(), timestamp()) -} - -fn timestamp() -> u64 { - let timestamp_result = time::SystemTime::now().duration_since(time::UNIX_EPOCH); - let timestamp = timestamp_result.expect("system time"); - timestamp.as_secs() * 1_000 + u64::from(timestamp.subsec_millis()) -} - -fn log(line : &str, path : &Option) { - if path.is_none() { - return; - } - - let file = OpenOptions::new() - .append(true) - .create(true) - .open(path.as_ref().unwrap()); - if file.is_ok() { - let write_result = writeln!(file.unwrap(), "{}", line); - if write_result.is_err() { - println!("writing log failed {}", line); - } - } else { - println!("logging failed: {}", line); - } -} - -fn find_arm<'a, A: Identifiable>(arms : &'a [A], ident: &str) -> io::Result<&'a A> { - for arm in arms { - if arm.ident() == ident { - return Ok(arm); - } - } - Err(Error::new(ErrorKind::NotFound, format!("arm {} not found", ident))) -} - #[derive(Serialize, Deserialize)] struct ExternalFormat { config: AnnealingSoftmaxConfig, From 20fbec49c59421267f7af9cdf6226f528e73d179 Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:24:40 +0200 Subject: [PATCH 2/6] Move common functions to utils --- src/utils.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 src/utils.rs diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..081a3e7 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,82 @@ +use std::fs::OpenOptions; +use std::io; +use std::io::{Error, ErrorKind, Write}; +use std::path::PathBuf; +use std::time; +use Identifiable; + +pub(crate) fn select_argmax(collection: &[f64]) -> Option { + let mut current_max_value = None; + let mut current_max_position = None; + for (i, x) in collection.iter().enumerate() { + if current_max_value.unwrap_or(f64::MIN) < *x { + current_max_value = Some(*x); + current_max_position = Some(i); + } + } + current_max_position +} + +pub(crate) fn log_command(cmd: &str, arm: &A) -> String { + format!("{};{};{}", cmd, arm.ident(), timestamp()) +} + +pub(crate) fn timestamp() -> u64 { + let timestamp_result = time::SystemTime::now().duration_since(time::UNIX_EPOCH); + let timestamp = timestamp_result.expect("system time"); + timestamp.as_secs() * 1_000 + u64::from(timestamp.subsec_millis()) +} + +pub(crate) fn log(line: &str, path: &Option) { + if path.is_none() { + return; + } + + let file = OpenOptions::new() + .append(true) + .create(true) + .open(path.as_ref().unwrap()); + if file.is_ok() { + let write_result = writeln!(file.unwrap(), "{line}"); + if write_result.is_err() { + println!("writing log failed {line}"); + } + } else { + println!("logging failed: {line}"); + } +} + +pub(crate) fn find_arm<'a, A: Identifiable>(arms: &'a [A], ident: &str) -> io::Result<&'a A> { + for arm in arms { + if arm.ident() == ident { + return Ok(arm); + } + } + Err(Error::new( + ErrorKind::NotFound, + format!("arm {ident} not found"), + )) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn select_first_as_max_works() { + let values = [10., 4., 3., 2.]; + assert_eq!(select_argmax(&values), Some(0)) + } + + #[test] + fn select_last_as_max_works() { + let values = [4., 3., 2., 10.]; + assert_eq!(select_argmax(&values), Some(3)) + } + + #[test] + fn select_works() { + let values = [0.56, 0.73, 1.67, 0.57]; + assert_eq!(select_argmax(&values), Some(2)) + } +} From af9e795bfdf8ed1fcc3f65e4ecd68c46d41b010c Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:37:44 +0200 Subject: [PATCH 3/6] Add UCB --- src/ucb.rs | 289 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 289 insertions(+) create mode 100644 src/ucb.rs diff --git a/src/ucb.rs b/src/ucb.rs new file mode 100644 index 0000000..4b79cc3 --- /dev/null +++ b/src/ucb.rs @@ -0,0 +1,289 @@ +use super::utils::{find_arm, log, log_command, select_argmax}; +use super::{BanditConfig, Identifiable, MultiArmedBandit}; +use std::cmp::Eq; +use std::collections::HashMap; +use std::fs::File; +use std::hash::Hash; +use std::io; +use std::io::{Read, Write}; +use std::path::Path; + +pub static DEFAULT_CONFIG: UcbConfig = UcbConfig { alpha: 0.5 }; + +#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)] +pub struct UcbConfig { + /// The higher the value the faster the algorithms tends toward selecting + /// the arm with highest reward. Should be a number between [0, 1.0) + pub alpha: f64, +} + +#[derive(Debug, PartialEq)] +pub struct UCB { + config: UcbConfig, + bandit_config: BanditConfig, + pub arms: Vec, + counts: HashMap, + rewards: HashMap, + all_counts: u64, + all_arms_played_at_least_once: bool, +} + +impl UCB { + pub fn new(arms: Vec, bandit_config: BanditConfig, config: UcbConfig) -> UCB { + assert!(!arms.is_empty(), "Arms vector cannot be empty!"); + let mut rewards = HashMap::new(); + for arm in &arms { + rewards.insert(arm.clone(), 0.); + } + + let mut counts = HashMap::new(); + for arm in &arms { + counts.insert(arm.clone(), 0); + } + Self::new_with_values(arms, bandit_config, config, counts, rewards) + } + + pub fn new_with_values( + arms: Vec, + bandit_config: BanditConfig, + config: UcbConfig, + counts: HashMap, + rewards: HashMap, + ) -> UCB { + let all_counts: u64 = counts.values().sum(); + let all_arms_played_at_least_once = + all_counts > 0 && counts.values().filter(|value| **value == 0).count() == 0; + UCB { + config, + bandit_config, + arms, + counts, + rewards, + all_counts, + all_arms_played_at_least_once, + } + } + + pub fn load_bandit( + arms: Vec, + bandit_config: BanditConfig, + path: &Path, + ) -> io::Result> { + let mut file = File::open(path)?; + let mut content = String::new(); + file.read_to_string(&mut content)?; + + let deser: ExternalFormat = serde_json::from_str(&content)?; + + let mut counts = HashMap::new(); + for (arm_ident, count) in deser.counts { + let arm = find_arm(&arms, &arm_ident)?; + counts.insert(arm.clone(), count); + } + let all_counts: u64 = counts.values().sum(); + let mut values = HashMap::new(); + for (arm_ident, val) in deser.rewards { + let arm = find_arm(&arms, &arm_ident)?; + values.insert(arm.clone(), val); + } + + let all_arms_played = counts.values().filter(|c| **c == 0).count() == 0; + Ok(UCB { + config: deser.config, + bandit_config, + arms, + counts, + rewards: values, + all_counts, + all_arms_played_at_least_once: all_arms_played, + }) + } + + fn check_if_all_played(&self) -> bool { + self.counts.values().filter(|c| **c == 0).count() == 0 + } + + fn log_update(&self, arm: &A) { + log(&log_command("UPDATE", arm), &self.bandit_config.log_file); + } + + fn log_select(&self, arm: &A) { + log(&log_command("SELECT", arm), &self.bandit_config.log_file); + } + + fn exploration(&self, arm_counts: f64) -> f64 { + ((self.all_counts as f64).ln() / arm_counts).sqrt() + } + + fn calculate_best_arm(&self) -> Option { + let mut arms_estimations = vec![]; + for arm in self.arms.iter() { + let rewards = self.rewards.get(arm)?; + let n_counts = *self.counts.get(arm)? as f64; + let exploratory_factor = self.exploration(n_counts); + let est = *rewards / n_counts + exploratory_factor; + arms_estimations.push(est); + } + let argmax = select_argmax(&arms_estimations)?; + Some(self.arms[argmax].clone()) + } + + fn get_next_unexplored(&self) -> Option { + let mut unexplored: Vec<_> = self + .counts + .iter() + .filter(|(_, cnt)| **cnt == 0) + .map(|(arm, _)| arm.clone()) + .collect(); + unexplored.pop() + } +} + +impl MultiArmedBandit for UCB { + fn select_arm(&self) -> A { + let possible_arm_to_play = if self.all_arms_played_at_least_once { + self.calculate_best_arm() + } else { + self.get_next_unexplored() + }; + match possible_arm_to_play { + Some(arm) => { + self.log_select(&arm); + arm + } + None => { + let fallback_arm = self.arms[self.arms.len() - 1].clone(); + self.log_select(&fallback_arm); + fallback_arm + } + } + } + + fn update(&mut self, arm: A, reward: f64) { + self.all_counts += 1; + let n_ = self.counts.entry(arm.clone()).or_insert(0); + *n_ += 1; + self.all_arms_played_at_least_once = self.check_if_all_played(); + let val = self.rewards.entry(arm.clone()).or_insert(0.0); + *val += reward; + self.log_update(&arm); + } + + fn update_counts(&mut self, arm: &A) { + self.all_counts += 1; + let n_ = self.counts.entry(arm.clone()).or_insert(0); + *n_ += 1; + self.all_arms_played_at_least_once = self.check_if_all_played(); + self.log_update(arm); + } + + fn update_rewards(&mut self, arm: &A, reward: f64) { + let val = self.rewards.entry(arm.clone()).or_insert(0.0); + *val += reward; + self.log_update(arm); + } + + fn save_bandit(&self, path: &Path) -> io::Result<()> { + let mut counts = HashMap::new(); + for (arm, count) in &self.counts { + counts.insert(arm.ident(), *count); + } + + let mut arms = Vec::with_capacity(self.arms.len()); + let mut values = HashMap::new(); + for (arm, value) in &self.rewards { + let arm_ident = arm.ident(); + arms.push(arm_ident.clone()); + values.insert(arm_ident, *value); + } + + let external_format = ExternalFormat { + arms, + counts, + rewards: values, + config: self.config, + }; + let ser = serde_json::to_string(&external_format)?; + + let mut file = File::create(path)?; + file.write_all(&ser.into_bytes())?; + file.flush() + } +} + +#[derive(Serialize, Deserialize)] +struct ExternalFormat { + arms: Vec, + counts: HashMap, + rewards: HashMap, + config: UcbConfig, +} + +#[cfg(test)] +mod test { + use super::*; + use crate::DEFAULT_BANDIT_CONFIG; + + #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] + struct TestArm { + num: u32, + } + + impl Identifiable for TestArm { + fn ident(&self) -> String { + format!("arm:{}", self.num) + } + } + + #[test] + fn creating_bandit_works() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let _bandit = UCB::new( + arms.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + DEFAULT_CONFIG.clone(), + ); + } + + #[test] + #[should_panic] + fn creating_bandit_fails_with_empty_arm_vector() { + let arms: Vec = vec![]; + UCB::new(arms, DEFAULT_BANDIT_CONFIG.clone(), DEFAULT_CONFIG.clone()); + } + + #[test] + fn select_next_unexplored_arm() { + let arms = vec![TestArm{num: 0}, TestArm{num: 1}, TestArm{num: 2}, TestArm{num: 3}]; + let mut bandit = UCB::new(arms.clone(), DEFAULT_BANDIT_CONFIG.clone(), DEFAULT_CONFIG.clone()); + assert!(!bandit.all_arms_played_at_least_once); + + let n_arms = 3; + for _ in 0..=n_arms { + let arm = bandit.select_arm(); + bandit.update_counts(&arm); + } + assert!(bandit.all_arms_played_at_least_once); + let expected_counts = vec![ + (TestArm{num: 0}, 1), (TestArm{num: 1}, 1), + (TestArm{num: 2}, 1), (TestArm{num: 3}, 1), + ].into_iter().collect::>(); + assert_eq!(bandit.counts, expected_counts) + } +} + +#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] +struct TestArm { + num: u32, +} + +impl Identifiable for TestArm { + fn ident(&self) -> String { + format!("arm:{}", self.num) + } +} From b2c70f23cbbf0dc0aeca42f8bff0ae0aa2d5e13b Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:52:43 +0200 Subject: [PATCH 4/6] Move utility functions for tests into separate module --- tests/common/mod.rs | 45 +++++++++++++++++++++++++++++++++++++++ tests/softmax.rs | 52 +++++++++++---------------------------------- 2 files changed, 57 insertions(+), 40 deletions(-) create mode 100644 tests/common/mod.rs diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..f6418c8 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,45 @@ +extern crate bandit; +extern crate regex; + +use bandit::Identifiable; +use std::fs::File; +use std::io::Read; +use std::path::Path; + +pub const NUM_SELECTS: u32 = 100_000; +pub static LOG_UPDATE_FILE: &str = "./tmp_log_update.csv"; +pub static LOG_SELECT_FILE: &str = "./tmp_log_select.csv"; +const EPSILON: u32 = (NUM_SELECTS as f64 * 0.005) as u32; + +#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] +pub struct TestArm { + pub num: u32, +} + +impl Identifiable for TestArm { + fn ident(&self) -> String { + format!("arm:{}", self.num) + } +} + +pub fn abs_select(prop: f64) -> u32 { + (f64::from(NUM_SELECTS) * prop) as u32 +} + +pub fn read_file_content(path: &str) -> String { + let mut file = File::open(Path::new(path)).unwrap(); + let mut log_content = String::new(); + file.read_to_string(&mut log_content).unwrap(); + log_content +} + +pub fn assert_prop(expected_count: u32, v: u32, arm: TestArm) { + assert!( + expected_count - EPSILON < v && v < expected_count + EPSILON, + "expected {}+-{}, got {} arm {:?}", + expected_count, + EPSILON, + v, + arm + ); +} diff --git a/tests/softmax.rs b/tests/softmax.rs index 9d9f204..7c825c2 100644 --- a/tests/softmax.rs +++ b/tests/softmax.rs @@ -1,6 +1,8 @@ extern crate bandit; extern crate regex; +mod common; + use bandit::{MultiArmedBandit, Identifiable, BanditConfig, DEFAULT_BANDIT_CONFIG}; use bandit::softmax::{AnnealingSoftmax, AnnealingSoftmaxConfig, DEFAULT_CONFIG}; use std::collections::{HashMap}; @@ -9,7 +11,8 @@ use std::fs::{File, remove_file}; use std::io::{Read}; use regex::{Regex}; -const NUM_SELECTS : u32 = 100_000; +use common::{TestArm, NUM_SELECTS}; + const EPSILON : u32 = (NUM_SELECTS as f64 * 0.005) as u32; #[test] @@ -23,9 +26,9 @@ pub fn test_select_arm() { *selects.entry(arm_selected).or_insert(0) += 1; } - let expected_count = abs_select(0.25); + let expected_count = common::abs_select(0.25); for (arm, v) in selects { - assert_prop(expected_count, v, arm); + common::assert_prop(expected_count, v, arm); } } @@ -144,13 +147,13 @@ fn test_save_and_load_bandit_with_missing_arm() { #[test] fn test_logging_update() { - let test_file = Path::new(LOG_UPDATE_FILE); + let test_file = Path::new(common::LOG_UPDATE_FILE); if test_file.exists() { remove_file(test_file).unwrap(); } let arms = vec![TestArm{num: 0}, TestArm{num: 1}, TestArm{num: 2}, TestArm{num: 3}]; - let bandit_config = BanditConfig{log_file: Some(PathBuf::from(LOG_UPDATE_FILE))}; + let bandit_config = BanditConfig{log_file: Some(PathBuf::from(common::LOG_UPDATE_FILE))}; let mut sm = AnnealingSoftmax::new(arms.clone(), bandit_config, AnnealingSoftmaxConfig{cooldown_factor: 1.0}); sm.update(arms[0], 10.0); @@ -158,7 +161,7 @@ fn test_logging_update() { sm.update(arms[2], 30.0); sm.update(arms[3], 40.0); - let log_content = read_file_content(LOG_UPDATE_FILE); + let log_content = common::read_file_content(common::LOG_UPDATE_FILE); let re = Regex::new( r#"^UPDATE;arm:0;\d{13};10 @@ -173,20 +176,20 @@ $"#).expect("compiled regex"); #[test] fn test_logging_select() { - let test_file = Path::new(LOG_SELECT_FILE); + let test_file = Path::new(common::LOG_SELECT_FILE); if test_file.exists() { remove_file(test_file).unwrap(); } let arms = vec![TestArm{num: 0}, TestArm{num: 1}, TestArm{num: 2}, TestArm{num: 3}]; - let bandit_config = BanditConfig{log_file: Some(PathBuf::from(LOG_SELECT_FILE))}; + let bandit_config = BanditConfig{log_file: Some(PathBuf::from(common::LOG_SELECT_FILE))}; let sm = AnnealingSoftmax::new(arms.clone(), bandit_config, AnnealingSoftmaxConfig{cooldown_factor: 1.0}); let select1 = sm.select_arm(); let select2 = sm.select_arm(); let select3 = sm.select_arm(); - let log_content = read_file_content(LOG_SELECT_FILE); + let log_content = common::read_file_content(common::LOG_SELECT_FILE); let re = Regex::new(&format!( r#"^SELECT;{};\d{{13}} @@ -196,34 +199,3 @@ $"#, select1.ident(), select2.ident(), select3.ident())).expect("compiled regex" assert!(re.is_match(&log_content), "log file did not match expected, was {}", &log_content); } - -//Helper - -static LOG_UPDATE_FILE : &str = "./tmp_log_update.csv"; -static LOG_SELECT_FILE : &str = "./tmp_log_select.csv"; - -fn read_file_content(path : &str) -> String { - let mut file = File::open(Path::new(path)).unwrap(); - let mut log_content = String::new(); - file.read_to_string(&mut log_content).unwrap(); - log_content -} - -fn abs_select(prop: f64) -> u32 { - (f64::from(NUM_SELECTS) * prop) as u32 -} - -fn assert_prop(expected_count: u32, v: u32, arm: TestArm) { - assert!(expected_count - EPSILON < v && v < expected_count + EPSILON, "expected {}+-{}, got {} arm {:?}", expected_count, EPSILON, v, arm); -} - -#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] -struct TestArm { - num: u32 -} - -impl Identifiable for TestArm { - fn ident(&self) -> String { - format!("arm:{}", self.num) - } -} From 52ece8d50459e9baa55ebad6a128de7a3e10f348 Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 21:53:10 +0200 Subject: [PATCH 5/6] Add tests to UCB --- tests/ucb.rs | 334 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 tests/ucb.rs diff --git a/tests/ucb.rs b/tests/ucb.rs new file mode 100644 index 0000000..c080c30 --- /dev/null +++ b/tests/ucb.rs @@ -0,0 +1,334 @@ +extern crate bandit; +extern crate regex; + +mod common; + +use bandit::ucb::{UcbConfig, DEFAULT_CONFIG, UCB}; +use bandit::{BanditConfig, Identifiable, MultiArmedBandit, DEFAULT_BANDIT_CONFIG}; +use regex::Regex; +use std::collections::HashMap; +use std::fs::remove_file; +use std::path::{Path, PathBuf}; + +use common::{TestArm, NUM_SELECTS}; + +#[test] +pub fn test_select_arm() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let mut ucb = UCB::new(arms, DEFAULT_BANDIT_CONFIG.clone(), DEFAULT_CONFIG); + + let mut selects: HashMap = HashMap::new(); + for _ in 0..NUM_SELECTS { + let arm_selected = ucb.select_arm(); + *selects.entry(arm_selected).or_default() += 1; + ucb.update_counts(&arm_selected); + } + + let expected_count = common::abs_select(0.25); + for (arm, v) in selects { + common::assert_prop(expected_count, v, arm); + } +} + +#[test] +fn test_moves_towards_arm_with_highest_reward_with_low_alpha() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let arm_test_rewards = vec![98.0, 100.0, 99.0, 98.5]; + let mut sm = UCB::new( + arms.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 0.1 }, + ); + + let num_iterations = 500; + + let mut selects = Vec::<[u64; 4]>::with_capacity(num_iterations); + for _ in 0..num_iterations { + for i in 0..arms.len() { + sm.update_counts(&arms[i]); + sm.update(arms[i], arm_test_rewards[i]) + } + + let mut draws = [0; 4]; + for _ in 0..1000 { + let selected_arm = sm.select_arm(); + draws[selected_arm.num as usize] += 1; + } + selects.push(draws); + } + + assert!( + selects[num_iterations - 1][1] >= 996, + "last round should favour highest reward, but did not {}", + selects[num_iterations - 1][1] + ); +} + +#[test] +fn test_eq() { + let arms0 = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let ucb0 = UCB::new( + arms0.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + + let arms0_2 = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let ucb0_2 = UCB::new( + arms0_2.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + ucb0_2.select_arm(); //arm select does not change state + ucb0_2.select_arm(); + + let arms1 = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + TestArm { num: 4 }, + ]; + let ucb1 = UCB::new( + arms1.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + + let arms2 = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + TestArm { num: 4 }, + ]; + let mut ucb2 = UCB::new( + arms2.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + ucb2.update(arms2[0], 1.); + + let arms3 = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let mut ucb3 = UCB::new( + arms3.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + ucb3.update(arms3[0], 34.32); + ucb3.update(arms3[2], 1.); + ucb3.update(arms3[3], 1.); + + assert_eq!(ucb0, ucb0_2); + assert_ne!(ucb0, ucb1); + assert_ne!(ucb1, ucb2); + assert_ne!(ucb1, ucb3); + assert_ne!(ucb2, ucb3); +} + +#[test] +fn test_more_often_selects_highest_reward_if_alpha_is_zero() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let reward_values = vec![10., 9000., 5., 1.]; + let mut counts = HashMap::new(); + let mut rewards = HashMap::new(); + for (id, arm) in arms.iter().enumerate() { + counts.insert(arm.clone(), 10_000); + rewards.insert(arm.clone(), reward_values[id]); + } + let ucb = UCB::new_with_values( + arms.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 0.0 }, + counts, + rewards, + ); + + let num_iterations = 1_000; + let mut selects = Vec::<[u64; 4]>::with_capacity(num_iterations); + for _ in 0..num_iterations { + let mut draws = [0; 4]; + for _ in 0..num_iterations { + let selected_arm = ucb.select_arm(); + draws[selected_arm.num as usize] += 1; + } + selects.push(draws); + } + + assert_eq!(selects[num_iterations - 1][1] as usize, num_iterations); +} + +#[test] +fn test_save_and_load_bandit() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let mut ucb = UCB::new( + arms.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 0.5 }, + ); + ucb.update(arms[0], 1.); + ucb.update(arms[1], 1.); + //no update on arms[2] + ucb.update(arms[3], 1.); + + let save_result = ucb.save_bandit(Path::new("./tmp_bandit.json")); + assert!(save_result.is_ok(), "save failed {:?}", save_result); + + let load_result = UCB::load_bandit( + arms, + DEFAULT_BANDIT_CONFIG.clone(), + Path::new("./tmp_bandit.json"), + ); + assert!(load_result.is_ok(), "load failed {:?}", load_result); + let ucb_loaded: UCB = load_result.unwrap(); + + assert_eq!(ucb, ucb_loaded); +} + +#[test] +fn test_save_and_load_bandit_with_missing_arm() { + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let ucb = UCB::new( + arms.clone(), + DEFAULT_BANDIT_CONFIG.clone(), + UcbConfig { alpha: 1.0 }, + ); + + let save_result = ucb.save_bandit(Path::new("./tmp_bandit_err.json")); + assert!(save_result.is_ok(), "save failed {:?}", save_result); + + let arms_last_one_missing = vec![TestArm { num: 0 }, TestArm { num: 1 }, TestArm { num: 2 }]; + let load_result = UCB::load_bandit( + arms_last_one_missing, + DEFAULT_BANDIT_CONFIG.clone(), + Path::new("./tmp_bandit.json"), + ); + assert!( + load_result.is_err(), + "load should fail, since TestArm{{num: 3}} could not be found, but was {:?}", + load_result + ); +} + +#[test] +fn test_logging_update() { + let test_file = Path::new(common::LOG_UPDATE_FILE); + if test_file.exists() { + remove_file(test_file).unwrap(); + } + + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let bandit_config = BanditConfig { + log_file: Some(PathBuf::from(common::LOG_UPDATE_FILE)), + }; + let mut ucb = UCB::new(arms.clone(), bandit_config, UcbConfig { alpha: 1.0 }); + + ucb.update(arms[0], 1.); + ucb.update(arms[1], 1.); + ucb.update(arms[2], 1.); + ucb.update(arms[3], 1.); + + let log_content = common::read_file_content(common::LOG_UPDATE_FILE); + + let re = Regex::new( + r#"^UPDATE;arm:0;\d{13} +UPDATE;arm:1;\d{13} +UPDATE;arm:2;\d{13} +UPDATE;arm:3;\d{13} +$"#, + ) + .expect("compiled regex"); + + assert!( + re.is_match(&log_content), + "log file did not match expected, was {}", + &log_content + ); +} + +#[test] +fn test_logging_select() { + let test_file = Path::new(common::LOG_SELECT_FILE); + if test_file.exists() { + remove_file(test_file).unwrap(); + } + + let arms = vec![ + TestArm { num: 0 }, + TestArm { num: 1 }, + TestArm { num: 2 }, + TestArm { num: 3 }, + ]; + let bandit_config = BanditConfig { + log_file: Some(PathBuf::from(common::LOG_SELECT_FILE)), + }; + let ucb = UCB::new(arms.clone(), bandit_config, UcbConfig { alpha: 1.0 }); + + let select1 = ucb.select_arm(); + let select2 = ucb.select_arm(); + let select3 = ucb.select_arm(); + + let log_content = common::read_file_content(common::LOG_SELECT_FILE); + + let re = Regex::new(&format!( + r#"^SELECT;{};\d{{13}} +SELECT;{};\d{{13}} +SELECT;{};\d{{13}} +$"#, + select1.ident(), + select2.ident(), + select3.ident() + )) + .expect("compiled regex"); + + assert!( + re.is_match(&log_content), + "log file did not match expected, was {}", + &log_content + ); +} From 8983f3d0538ff867e1e4d4d36e3dd8f68bfcd1ec Mon Sep 17 00:00:00 2001 From: MilaKyr <11293949+MilaKyr@users.noreply.github.com> Date: Mon, 11 Sep 2023 22:04:50 +0200 Subject: [PATCH 6/6] Add additional test --- src/utils.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/utils.rs b/src/utils.rs index 081a3e7..b498b6a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -79,4 +79,10 @@ mod test { let values = [0.56, 0.73, 1.67, 0.57]; assert_eq!(select_argmax(&values), Some(2)) } + + #[test] + fn select_none_works() { + let values = []; + assert_eq!(select_argmax(&values), None) + } }