Skip to content

Commit

Permalink
Simplify trueskill calculations
Browse files Browse the repository at this point in the history
Actually improves the accuracy slightly, compared to the other trueskill libraries like ts-trueskill.
And by slightly it means ~0.0000000000007 difference instead of ~0.0000000000009
  • Loading branch information
atomflunder committed Oct 10, 2024
1 parent 5256e1b commit 6bba1cc
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 78 deletions.
16 changes: 8 additions & 8 deletions src/trueskill/factor_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,17 @@ impl SumFactor {
pub struct TruncateFactor {
id: usize,
variable: Rc<RefCell<Variable>>,
v_func: Box<dyn Fn(f64, f64) -> f64>,
w_func: Box<dyn Fn(f64, f64) -> f64>,
v_func: Box<dyn Fn(f64, f64, f64) -> f64>,
w_func: Box<dyn Fn(f64, f64, f64) -> f64>,
draw_margin: f64,
}

impl TruncateFactor {
pub fn new(
id: usize,
variable: Rc<RefCell<Variable>>,
v_func: Box<dyn Fn(f64, f64) -> f64>,
w_func: Box<dyn Fn(f64, f64) -> f64>,
v_func: Box<dyn Fn(f64, f64, f64) -> f64>,
w_func: Box<dyn Fn(f64, f64, f64) -> f64>,
draw_margin: f64,
) -> Self {
variable.borrow_mut().messages.entry(id).or_default();
Expand All @@ -260,10 +260,10 @@ impl TruncateFactor {
variable.gaussian / variable.messages[&self.id]
};
let pi_sqrt = div.pi.sqrt();
let arg_1 = div.tau / pi_sqrt;
let arg_2 = self.draw_margin * pi_sqrt;
let v = (self.v_func)(arg_1, arg_2);
let w = (self.w_func)(arg_1, arg_2);
let arg_1 = div.tau;
let arg_2 = self.draw_margin * div.pi;
let v = (self.v_func)(arg_1, arg_2, pi_sqrt);
let w = (self.w_func)(arg_1, arg_2, pi_sqrt);
let denom = 1.0 - w;

let pi = div.pi / denom;
Expand Down
92 changes: 22 additions & 70 deletions src/trueskill/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1809,54 +1809,6 @@ fn build_trunc_layer(
beta: f64,
starting_id: usize,
) -> Vec<TruncateFactor> {
fn v_w(diff: f64, draw_margin: f64) -> f64 {
let x = diff - draw_margin;
let denom = cdf(x, 0.0, 1.0);

if denom == 0.0 {
-x
} else {
pdf(x, 0.0, 1.0) / denom
}
}

fn v_d(diff: f64, draw_margin: f64) -> f64 {
let abs_diff = diff.abs();
let a = draw_margin - abs_diff;
let b = -draw_margin - abs_diff;
let denom = cdf(a, 0.0, 1.0) - cdf(b, 0.0, 1.0);
let numer = pdf(b, 0.0, 1.0) - pdf(a, 0.0, 1.0);

let lhs = if denom == 0.0 { a } else { numer / denom };
let rhs = if diff < 0.0 { -1.0 } else { 1.0 };

lhs * rhs
}

fn w_w(diff: f64, draw_margin: f64) -> f64 {
let x = diff - draw_margin;
let v = v_w(diff, draw_margin);
let w = v * (v + x);
if 0.0 < w && w < 1.0 {
return w;
}

panic!("floating point error");
}

#[allow(clippy::suboptimal_flops)]
fn w_d(diff: f64, draw_margin: f64) -> f64 {
let abs_diff = diff.abs();
let a = draw_margin - abs_diff;
let b = -draw_margin - abs_diff;
let denom = cdf(a, 0.0, 1.0) - cdf(b, 0.0, 1.0);

assert!(!(denom == 0.0), "floating point error");

let v = v_d(abs_diff, draw_margin);
v.mul_add(v, (a * pdf(a, 0.0, 1.0) - b * pdf(b, 0.0, 1.0)) / denom)
}

let mut v = Vec::with_capacity(team_diff_vars.len());
let mut i = starting_id;
for (x, team_diff_var) in team_diff_vars.iter().enumerate() {
Expand All @@ -1865,14 +1817,14 @@ fn build_trunc_layer(
.map(|v| v.0.len() as f64)
.sum();
let draw_margin = draw_margin(draw_probability, beta, size);
let v_func: Box<dyn Fn(f64, f64) -> f64>;
let w_func: Box<dyn Fn(f64, f64) -> f64>;
let v_func: Box<dyn Fn(f64, f64, f64) -> f64>;
let w_func: Box<dyn Fn(f64, f64, f64) -> f64>;
if sorted_teams_and_ranks[x].1 == sorted_teams_and_ranks[x + 1].1 {
v_func = Box::new(v_d);
w_func = Box::new(w_d);
v_func = Box::new(v_draw);
w_func = Box::new(w_draw);
} else {
v_func = Box::new(v_w);
w_func = Box::new(w_w);
v_func = Box::new(v_non_draw);
w_func = Box::new(w_non_draw);
};

v.push(TruncateFactor::new(
Expand Down Expand Up @@ -2668,26 +2620,26 @@ mod tests {
let results = trueskill_multi_team(&teams_and_ranks, &TrueSkillConfig::new());

assert!((results[0][0].rating - 40.876_849_177_315_655).abs() < f64::EPSILON);
assert!((results[0][1].rating - 45.493_394_092_398_44).abs() < f64::EPSILON);
assert!((results[0][1].rating - 45.493_394_092_398_45).abs() < f64::EPSILON);

assert!((results[1][0].rating - 19.608_650_920_845_236).abs() < f64::EPSILON);
assert!((results[1][0].rating - 19.608_650_920_845_23).abs() < f64::EPSILON);
assert!((results[1][1].rating - 18.712_463_514_890_54).abs() < f64::EPSILON);
assert!((results[1][2].rating - 29.353_112_227_810_637).abs() < f64::EPSILON);
assert!((results[1][3].rating - 9.872_175_198_037_164).abs() < f64::EPSILON);

assert!((results[2][0].rating - 48.829_832_201_455_32).abs() < f64::EPSILON);
assert!((results[2][1].rating - 29.812_500_188_903_005).abs() < f64::EPSILON);
assert!((results[2][0].rating - 48.829_832_201_455_31).abs() < f64::EPSILON);
assert!((results[2][1].rating - 29.812_500_188_902_998).abs() < f64::EPSILON);

assert!((results[0][0].uncertainty - 3.839_527_589_355_37).abs() < f64::EPSILON);
assert!((results[0][0].uncertainty - 3.839_527_589_355_369_8).abs() < f64::EPSILON);
assert!((results[0][1].uncertainty - 2.933_671_613_522_051).abs() < f64::EPSILON);

assert!((results[1][0].uncertainty - 6.396_044_310_523_897).abs() < f64::EPSILON);
assert!((results[1][0].uncertainty - 6.396_044_310_523_896).abs() < f64::EPSILON);
assert!((results[1][1].uncertainty - 5.624_556_429_622_889).abs() < f64::EPSILON);
assert!((results[1][2].uncertainty - 7.673_456_361_986_594).abs() < f64::EPSILON);
assert!((results[1][2].uncertainty - 7.673_456_361_986_593).abs() < f64::EPSILON);
assert!((results[1][3].uncertainty - 3.891_408_425_994_520_3).abs() < f64::EPSILON);

assert!((results[2][0].uncertainty - 4.590_018_525_151_38).abs() < f64::EPSILON);
assert!((results[2][1].uncertainty - 1.976_314_792_712_798_2).abs() < f64::EPSILON);
assert!((results[2][0].uncertainty - 4.590_018_525_151_379).abs() < f64::EPSILON);
assert!((results[2][1].uncertainty - 1.976_314_792_712_798).abs() < f64::EPSILON);
}

#[test]
Expand All @@ -2714,7 +2666,7 @@ mod tests {

let results = trueskill_multi_team(teams_and_ranks, &TrueSkillConfig::new());

assert!((results[0][0].rating - 41.720_925_460_665).abs() < f64::EPSILON);
assert!((results[0][0].rating - 41.720_925_460_665_01).abs() < f64::EPSILON);
assert!((results[1][0].rating - 20.997_268_045_415_94).abs() < f64::EPSILON);
assert!((results[2][0].rating - 41.771_076_420_914_83).abs() < f64::EPSILON);

Expand Down Expand Up @@ -2753,15 +2705,15 @@ mod tests {

let results = trueskill_multi_team(teams_and_ranks, &TrueSkillConfig::new());

assert!((results[0][0].rating - 46.844_398_641_974_195).abs() < f64::EPSILON);
assert!((results[0][0].rating - 46.844_398_641_974_97).abs() < f64::EPSILON);
assert!((results[1][0].rating - -21.0).abs() < f64::EPSILON);
assert!((results[2][0].rating - 121.973_594_228_967_41).abs() < f64::EPSILON);
assert!((results[3][0].rating - 3.577_783_039_440_541_7).abs() < f64::EPSILON);
assert!((results[2][0].rating - 121.973_594_228_967_43).abs() < f64::EPSILON);
assert!((results[3][0].rating - 3.577_783_039_440_43).abs() < f64::EPSILON);

assert!((results[0][0].uncertainty - 4.453_979_220_473_841).abs() < f64::EPSILON);
assert!((results[0][0].uncertainty - 4.453_979_220_477_661).abs() < f64::EPSILON);
assert!((results[1][0].uncertainty - 1.871_855_882_391_709_3).abs() < f64::EPSILON);
assert!((results[2][0].uncertainty - 0.083_922_196_135_183_51).abs() < f64::EPSILON);
assert!((results[3][0].uncertainty - 1.197_926_990_096_083_4).abs() < f64::EPSILON);
assert!((results[2][0].uncertainty - 0.083_922_196_135_183_55).abs() < f64::EPSILON);
assert!((results[3][0].uncertainty - 1.197_926_990_096_302_3).abs() < f64::EPSILON);
}

#[test]
Expand Down

0 comments on commit 6bba1cc

Please sign in to comment.