From 4bee5dd7593430461420b0b0d03d7b53bd16d389 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 5 Feb 2023 15:26:34 +0000 Subject: [PATCH 01/14] scale sum product --- funsor/sum_product.py | 38 +++++++++++++----- test/test_sum_product.py | 87 +++++++++++++++++++++++++--------------- 2 files changed, 84 insertions(+), 41 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index c59e1b8f..94bd19a5 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict, defaultdict from functools import reduce -from math import gcd +from math import gcd, prod import funsor import funsor.ops as ops @@ -203,7 +203,14 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()): def partial_sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + scales={}, ): """ Performs partial sum-product contraction of a collection of factors. @@ -217,6 +224,7 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) + assert isinstance(scales, dict) if pedantic: var_to_errors = defaultdict(lambda: eliminate) @@ -250,10 +258,15 @@ def partial_sum_product( leaf = max(ordinal_to_factors, key=len) # CHOICE leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] - for (group_factors, group_vars) in _partition( + leaf_scale = reduce( + ops.mul, [scales[plate] for plate in leaf if plate in scales], Number(1.0) + ) + for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars ): # CHOICE f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate) + if pow_op is not None: + f = pow_op(f, leaf_scale) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: results.append(f.reduce(prod_op, leaf & eliminate)) @@ -400,7 +413,7 @@ def dynamic_partial_sum_product( leaf = max(ordinal_to_factors, key=len) leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] - for (group_factors, group_vars) in _partition( + for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars | markov_prod_vars ): # eliminate non markov vars @@ -529,7 +542,7 @@ def modified_partial_sum_product( leaf = max(ordinal_to_factors, key=len) leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] - for (group_factors, group_vars) in _partition( + for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars | markov_prod_vars ): # eliminate non markov vars @@ -571,7 +584,14 @@ def modified_partial_sum_product( def sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + scales={}, ): """ Performs sum-product contraction of a collection of factors. @@ -579,7 +599,9 @@ def sum_product( :return: a single contracted Funsor. 
:rtype: :class:`~funsor.terms.Funsor` """ - factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic) + factors = partial_sum_product( + sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, scales + ) return reduce(prod_op, factors, Number(UNITS[prod_op])) @@ -780,7 +802,6 @@ def _shift_funsor(f, t, global_vars): def naive_sarkka_bilmes_product( sum_op, prod_op, trans, time_var, global_vars=frozenset() ): - assert isinstance(global_vars, frozenset) time = time_var.name @@ -818,7 +839,6 @@ def naive_sarkka_bilmes_product( def sarkka_bilmes_product( sum_op, prod_op, trans, time_var, global_vars=frozenset(), num_periods=1 ): - assert isinstance(global_vars, frozenset) time = time_var.name diff --git a/test/test_sum_product.py b/test/test_sum_product.py index a40ea3ae..d646603d 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -35,7 +35,7 @@ sum_product, ) from funsor.tensor import Tensor, get_default_prototype -from funsor.terms import Variable +from funsor.terms import Cat, Number, Variable from funsor.testing import assert_close, random_gaussian, random_tensor from funsor.util import get_backend @@ -368,7 +368,6 @@ def test_var_in_plate_ok(): def test_modified_partial_sum_product_0( impl, sum_op, prod_op, vars1, vars2, x_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -414,7 +413,6 @@ def test_modified_partial_sum_product_0( def test_modified_partial_sum_product_1( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -471,7 +469,6 @@ def test_modified_partial_sum_product_1( def test_modified_partial_sum_product_2( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -530,7 +527,6 @@ def test_modified_partial_sum_product_2( def test_modified_partial_sum_product_3( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -627,7 +623,6 @@ def test_modified_partial_sum_product_3( def test_modified_partial_sum_product_4( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) @@ -758,7 +753,6 @@ def test_modified_partial_sum_product_4( def test_modified_partial_sum_product_5( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, sequences, days, weeks, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor( @@ -870,7 +864,6 @@ def test_modified_partial_sum_product_5( def test_modified_partial_sum_product_6( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) @@ -986,7 +979,6 @@ def test_modified_partial_sum_product_6( def test_modified_partial_sum_product_7( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) @@ -1125,7 +1117,6 @@ def test_modified_partial_sum_product_7( def test_modified_partial_sum_product_8( impl, sum_op, prod_op, vars1, vars2, w_dim, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = 
random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) @@ -1298,7 +1289,6 @@ def test_modified_partial_sum_product_9( time, tones, ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) @@ -1364,7 +1354,7 @@ def test_modified_partial_sum_product_9( "tones": {}, } - with (lazy if use_lazy else eager): + with lazy if use_lazy else eager: factors1 = impl(sum_op, prod_op, factors, vars1, plate_to_step) factors2 = impl(sum_op, prod_op, factors1, vars2, plate_to_step) actual = reduce(prod_op, factors2) @@ -1454,7 +1444,6 @@ def test_modified_partial_sum_product_9( def test_modified_partial_sum_product_10( impl, sum_op, prod_op, vars1, vars2, w_dim, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) @@ -1640,7 +1629,6 @@ def test_modified_partial_sum_product_11( time, tones, ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"a": Bint[a_dim]})) @@ -1721,7 +1709,7 @@ def test_modified_partial_sum_product_11( "tones": {}, } - with (lazy if use_lazy else eager): + with lazy if use_lazy else eager: factors1 = impl(sum_op, prod_op, factors, vars1, plate_to_step) factors2 = impl(sum_op, prod_op, factors1, vars2, plate_to_step) actual = reduce(prod_op, factors2) @@ -1808,7 +1796,6 @@ def test_modified_partial_sum_product_11( def test_modified_partial_sum_product_12( impl, sum_op, prod_op, vars1, vars2, w_dim, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) @@ -1995,7 +1982,6 @@ def test_modified_partial_sum_product_13( weeks, tones, ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor( @@ -2048,7 +2034,7 @@ def test_modified_partial_sum_product_13( "weeks": frozenset({("y_0", "y_prev", "y_curr")}), } - with (lazy if use_lazy else eager): + with lazy if use_lazy else eager: factors1 = impl(sum_op, prod_op, factors, vars1, plate_to_step) factors2 = impl(sum_op, prod_op, factors1, vars2, plate_to_step) actual = reduce(prod_op, factors2) @@ -2152,7 +2138,6 @@ def test_modified_partial_sum_product_13( def test_modified_partial_sum_product_14( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, sequences, time, tones ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) @@ -2258,7 +2243,6 @@ def test_modified_partial_sum_product_14( def test_modified_partial_sum_product_16( impl, sum_op, prod_op, vars1, vars2, x_dim, y_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -2352,7 +2336,6 @@ def test_modified_partial_sum_product_16( def test_modified_partial_sum_product_17( impl, use_lazy, sum_op, prod_op, vars1, vars2, x_dim, y_dim, z_dim, time ): - f1 = random_tensor(OrderedDict({})) f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) @@ -2425,7 +2408,7 @@ def test_modified_partial_sum_product_17( factors = [f1, f2, f3, f4, f5, f6, f7, f8, f9] plate_to_step = {"time": frozenset({("x_0", "x_prev", "x_curr")})} - with (lazy if use_lazy else eager): + with lazy if use_lazy else eager: factors1 = impl(sum_op, prod_op, factors, vars1, plate_to_step) factors2 = impl(sum_op, prod_op, factors1, vars2, plate_to_step) actual = reduce(prod_op, factors2) @@ -2648,7 +2631,6 @@ def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim): def 
_check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): - sum_op, prod_op = ops.logaddexp, ops.add assert "time" in trans.inputs @@ -2674,7 +2656,6 @@ def _check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_0(duration): - trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3]})) expected_inputs = {"a": Bint[3]} @@ -2684,7 +2665,6 @@ def test_sarkka_bilmes_example_0(duration): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_1(duration): - trans = random_tensor( OrderedDict( {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} @@ -2698,7 +2678,6 @@ def test_sarkka_bilmes_example_1(duration): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6, 7, 8]) def test_sarkka_bilmes_example_2(duration): - trans = random_tensor( OrderedDict( { @@ -2726,7 +2705,6 @@ def test_sarkka_bilmes_example_2(duration): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6, 7, 8]) def test_sarkka_bilmes_example_3(duration): - trans = random_tensor( OrderedDict( { @@ -2750,7 +2728,6 @@ def test_sarkka_bilmes_example_3(duration): @pytest.mark.parametrize("duration", [3, 4, 5, 6, 7, 9]) def test_sarkka_bilmes_example_4(duration): - trans = random_tensor( OrderedDict( { @@ -2774,7 +2751,6 @@ def test_sarkka_bilmes_example_4(duration): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_5(duration): - trans = random_tensor( OrderedDict( {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} @@ -2790,7 +2766,6 @@ def test_sarkka_bilmes_example_5(duration): @pytest.mark.parametrize("duration", [3, 4, 5, 6, 7, 8, 9]) def test_sarkka_bilmes_example_6(duration): - trans = random_tensor( OrderedDict( { @@ -2866,7 +2841,6 @@ def test_sarkka_bilmes_example_6(duration): ) @pytest.mark.parametrize("num_periods", [1, 2]) def test_sarkka_bilmes_generic(time_input, global_inputs, local_inputs, num_periods): - lags = { kk: reduce( max, @@ -2907,7 +2881,6 @@ def test_sarkka_bilmes_generic(time_input, global_inputs, local_inputs, num_peri "duration,num_segments", [(12, 1), (12, 2), (12, 3), (12, 4), (12, 6)] ) def test_mixed_sequential_sum_product(duration, num_segments): - sum_op, prod_op = ops.logaddexp, ops.add time_var = Variable("time", Bint[duration]) step = {"_PREV_x": "x"} @@ -2926,3 +2899,53 @@ def test_mixed_sequential_sum_product(duration, num_segments): ) assert_close(actual, expected) + + +@pytest.mark.parametrize( + "sum_op,prod_op,pow_op", + [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], +) +@pytest.mark.parametrize("scale", [2, 3]) +def test_partial_sum_product_scale_1(sum_op, prod_op, pow_op, scale): + f1 = random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) + f3 = Cat("b", (f2,) * scale) + + eliminate = frozenset("ab") + plates = frozenset("b") + + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2] + scales = {"b": Number(scale)} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op,pow_op", + [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], +) +@pytest.mark.parametrize("scale", [2, 3]) +def test_partial_sum_product_scale_2(sum_op, prod_op, pow_op, scale): + f1 = 
random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) + f3 = Cat("b", (f2,) * scale) + + eliminate = frozenset("ab") + plates = frozenset("b") + + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2] + scales = {"b": Number(scale)} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) From b84ed43ef729e8916b47bd414ae0776d84da9bc4 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 6 Feb 2023 01:06:15 +0000 Subject: [PATCH 02/14] fix --- funsor/sum_product.py | 28 +++++++++----- test/test_sum_product.py | 80 +++++++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 94bd19a5..085c1c78 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -209,7 +209,6 @@ def partial_sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - pow_op=None, scales={}, ): """ @@ -226,6 +225,14 @@ def partial_sum_product( assert isinstance(plates, frozenset) assert isinstance(scales, dict) + if scales: + if sum_op is ops.logaddexp and prod_op is ops.add: + pow_op = ops.mul + elif sum_op is ops.add and prod_op is ops.mul: + pow_op = ops.pow + else: + raise ValueError("should not be here!") + if pedantic: var_to_errors = defaultdict(lambda: eliminate) for f in factors: @@ -258,18 +265,18 @@ def partial_sum_product( leaf = max(ordinal_to_factors, key=len) # CHOICE leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] - leaf_scale = reduce( - ops.mul, [scales[plate] for plate in leaf if plate in scales], Number(1.0) - ) for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars ): # CHOICE f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate) - if pow_op is not None: - f = pow_op(f, leaf_scale) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: - results.append(f.reduce(prod_op, leaf & eliminate)) + f = f.reduce(prod_op, leaf & eliminate) + f_scales = [scales[plate] for plate in leaf & eliminate if plate in scales] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) + results.append(f) else: new_plates = frozenset().union( *(var_to_ordinal[v] for v in remaining_sum_vars) @@ -319,6 +326,10 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) + f_scales = [scales[plate] for plate in reduced_plates if plate in scales] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) ordinal_to_factors[new_plates].append(f) return results @@ -590,7 +601,6 @@ def sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - pow_op=None, scales={}, ): """ @@ -600,7 +610,7 @@ def sum_product( :rtype: :class:`~funsor.terms.Funsor` """ factors = partial_sum_product( - sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, scales + sum_op, prod_op, factors, eliminate, plates, pedantic, scales ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index d646603d..b6ff1840 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -2902,50 +2902,88 @@ def test_mixed_sequential_sum_product(duration, num_segments): @pytest.mark.parametrize( - "sum_op,prod_op,pow_op", - [(ops.logaddexp, 
ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) @pytest.mark.parametrize("scale", [2, 3]) -def test_partial_sum_product_scale_1(sum_op, prod_op, pow_op, scale): +def test_partial_sum_product_scale_1(sum_op, prod_op, scale): f1 = random_tensor(OrderedDict(a=Bint[2])) - f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) - f3 = Cat("b", (f2,) * scale) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + f3 = Cat("i", (f2,) * scale) - eliminate = frozenset("ab") - plates = frozenset("b") + eliminate = frozenset("ai") + plates = frozenset("i") factors = [f1, f3] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) factors = [f1, f2] - scales = {"b": Number(scale)} + scales = {"i": scale} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + sum_op, prod_op, factors, eliminate, plates, scales=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @pytest.mark.parametrize( - "sum_op,prod_op,pow_op", - [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale", [2, 3]) -def test_partial_sum_product_scale_2(sum_op, prod_op, pow_op, scale): +@pytest.mark.parametrize("scale_i", [2, 3]) +@pytest.mark.parametrize("scale_j", [2, 3]) +def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): f1 = random_tensor(OrderedDict(a=Bint[2])) - f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) - f3 = Cat("b", (f2,) * scale) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4])) + f4 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f3,) * scale_j) - eliminate = frozenset("ab") - plates = frozenset("b") + eliminate = frozenset("aij") + plates = frozenset("ij") - factors = [f1, f3] + factors = [f1, f4, f5] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - factors = [f1, f2] - scales = {"b": Number(scale)} + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], +) +@pytest.mark.parametrize("scale_i", [2, 3]) +@pytest.mark.parametrize("scale_j", [2, 3]) +@pytest.mark.parametrize("scale_k", [2, 3]) +def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k): + f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3])) + f4 = Cat("i", (f1,) * scale_i) + # concatenate across multiple dims + f5 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f5,) * scale_j) + # concatenate across multiple dims + f6 = Cat("i", (f3,) * scale_i) + f6 = Cat("j", (f6,) * scale_j) + f6 = Cat("k", (f6,) * scale_k) + + eliminate = frozenset("aijk") + plates = frozenset("ijk") + + factors = [f4, f5, f6] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j, "k": scale_k} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + sum_op, prod_op, factors, eliminate, plates, scales=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) From 
87c10e02d9e152126838d91612eaae271c6fab46 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 6 Feb 2023 01:26:09 +0000 Subject: [PATCH 03/14] plate_to_scale --- funsor/sum_product.py | 16 ++++++++-------- test/test_sum_product.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 085c1c78..cd28c9bc 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict, defaultdict from functools import reduce -from math import gcd, prod +from math import gcd import funsor import funsor.ops as ops @@ -209,7 +209,7 @@ def partial_sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - scales={}, + plate_to_scale={}, ): """ Performs partial sum-product contraction of a collection of factors. @@ -223,9 +223,9 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) - assert isinstance(scales, dict) + assert isinstance(plate_to_scale, dict) - if scales: + if plate_to_scale: if sum_op is ops.logaddexp and prod_op is ops.add: pow_op = ops.mul elif sum_op is ops.add and prod_op is ops.mul: @@ -272,7 +272,7 @@ def partial_sum_product( remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: f = f.reduce(prod_op, leaf & eliminate) - f_scales = [scales[plate] for plate in leaf & eliminate if plate in scales] + f_scales = [plate_to_scale[plate] for plate in leaf & eliminate if plate in plate_to_scale] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -326,7 +326,7 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) - f_scales = [scales[plate] for plate in reduced_plates if plate in scales] + f_scales = [plate_to_scale[plate] for plate in reduced_plates if plate in plate_to_scale] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -601,7 +601,7 @@ def sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - scales={}, + plate_to_scale={}, ): """ Performs sum-product contraction of a collection of factors. 
@@ -610,7 +610,7 @@ def sum_product( :rtype: :class:`~funsor.terms.Funsor` """ factors = partial_sum_product( - sum_op, prod_op, factors, eliminate, plates, pedantic, scales + sum_op, prod_op, factors, eliminate, plates, pedantic, plate_to_scale ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index b6ff1840..09a97316 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -35,7 +35,7 @@ sum_product, ) from funsor.tensor import Tensor, get_default_prototype -from funsor.terms import Cat, Number, Variable +from funsor.terms import Cat, Variable from funsor.testing import assert_close, random_gaussian, random_tensor from funsor.util import get_backend @@ -2920,7 +2920,7 @@ def test_partial_sum_product_scale_1(sum_op, prod_op, scale): factors = [f1, f2] scales = {"i": scale} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @@ -2948,7 +2948,7 @@ def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @@ -2983,7 +2983,7 @@ def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k) factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j, "k": scale_k} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) From 6a3309596002e834764b003d6cbc9ebe19cce717 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 16 Feb 2023 19:09:45 +0000 Subject: [PATCH 04/14] lint --- funsor/sum_product.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index cd28c9bc..c4460238 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -272,7 +272,11 @@ def partial_sum_product( remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: f = f.reduce(prod_op, leaf & eliminate) - f_scales = [plate_to_scale[plate] for plate in leaf & eliminate if plate in plate_to_scale] + f_scales = [ + plate_to_scale[plate] + for plate in leaf & eliminate + if plate in plate_to_scale + ] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -326,7 +330,11 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) - f_scales = [plate_to_scale[plate] for plate in reduced_plates if plate in plate_to_scale] + f_scales = [ + plate_to_scale[plate] + for plate in reduced_plates + if plate in plate_to_scale + ] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) From 6f662eb550ec7d21061ffbaa1973f5ee50b51c1c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 16 Feb 2023 21:21:43 +0000 Subject: [PATCH 05/14] test --- test/test_sum_product.py | 62 ++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 09a97316..ea55d4c5 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -2905,66 +2905,82 @@ def 
test_mixed_sequential_sum_product(duration, num_segments): "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale", [2, 3]) +@pytest.mark.parametrize("scale", [1, 2]) def test_partial_sum_product_scale_1(sum_op, prod_op, scale): f1 = random_tensor(OrderedDict(a=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) - f3 = Cat("i", (f2,) * scale) eliminate = frozenset("ai") plates = frozenset("i") - factors = [f1, f3] - expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - + # Actual result based on applying scaling factors = [f1, f2] scales = {"i": scale} actual = sum_product( sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + # Expected result based on concatenating factors + f3 = Cat("i", (f2,) * scale) + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale_i", [2, 3]) -@pytest.mark.parametrize("scale_j", [2, 3]) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): f1 = random_tensor(OrderedDict(a=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4])) - f4 = Cat("i", (f2,) * scale_i) - f5 = Cat("j", (f3,) * scale_j) eliminate = frozenset("aij") plates = frozenset("ij") - factors = [f1, f4, f5] - expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - + # Actual result based on applying scaling factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j} actual = sum_product( sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + # Expected result based on concatenating factors + f4 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f3,) * scale_j) + factors = [f1, f4, f5] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale_i", [2, 3]) -@pytest.mark.parametrize("scale_j", [2, 3]) -@pytest.mark.parametrize("scale_k", [2, 3]) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) +@pytest.mark.parametrize("scale_k", [1, 4]) def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k): f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3])) f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3])) + + eliminate = frozenset("aijk") + plates = frozenset("ijk") + + # Actual result based on applying scaling + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j, "k": scale_k} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales + ) + + # Expected result based on concatenating factors f4 = Cat("i", (f1,) * scale_i) # concatenate across multiple dims f5 = Cat("i", (f2,) * scale_i) @@ -2973,17 +2989,7 @@ def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k) f6 = Cat("i", (f3,) * scale_i) f6 = Cat("j", (f6,) * scale_j) f6 = Cat("k", (f6,) * scale_k) - - eliminate = frozenset("aijk") - 
plates = frozenset("ijk") - factors = [f4, f5, f6] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - factors = [f1, f2, f3] - scales = {"i": scale_i, "j": scale_j, "k": scale_k} - actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales - ) - - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + assert_close(actual, expected, atol=1e-4, rtol=1e-4) From b66b3ac6b315130d3e5f2832553adcee185029c9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 5 Feb 2023 15:26:34 +0000 Subject: [PATCH 06/14] scale sum product --- funsor/sum_product.py | 30 +++++++++++++++++++---- test/test_sum_product.py | 52 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 751fbc37..94bd19a5 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict, defaultdict from functools import reduce -from math import gcd +from math import gcd, prod import funsor import funsor.ops as ops @@ -203,7 +203,14 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()): def partial_sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + scales={}, ): """ Performs partial sum-product contraction of a collection of factors. @@ -217,6 +224,7 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) + assert isinstance(scales, dict) if pedantic: var_to_errors = defaultdict(lambda: eliminate) @@ -250,10 +258,15 @@ def partial_sum_product( leaf = max(ordinal_to_factors, key=len) # CHOICE leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] + leaf_scale = reduce( + ops.mul, [scales[plate] for plate in leaf if plate in scales], Number(1.0) + ) for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars ): # CHOICE f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate) + if pow_op is not None: + f = pow_op(f, leaf_scale) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: results.append(f.reduce(prod_op, leaf & eliminate)) @@ -571,7 +584,14 @@ def modified_partial_sum_product( def sum_product( - sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False + sum_op, + prod_op, + factors, + eliminate=frozenset(), + plates=frozenset(), + pedantic=False, + pow_op=None, + scales={}, ): """ Performs sum-product contraction of a collection of factors. @@ -579,7 +599,9 @@ def sum_product( :return: a single contracted Funsor. 
:rtype: :class:`~funsor.terms.Funsor` """ - factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic) + factors = partial_sum_product( + sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, scales + ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index b288cb45..d646603d 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -35,7 +35,7 @@ sum_product, ) from funsor.tensor import Tensor, get_default_prototype -from funsor.terms import Variable +from funsor.terms import Cat, Number, Variable from funsor.testing import assert_close, random_gaussian, random_tensor from funsor.util import get_backend @@ -2899,3 +2899,53 @@ def test_mixed_sequential_sum_product(duration, num_segments): ) assert_close(actual, expected) + + +@pytest.mark.parametrize( + "sum_op,prod_op,pow_op", + [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], +) +@pytest.mark.parametrize("scale", [2, 3]) +def test_partial_sum_product_scale_1(sum_op, prod_op, pow_op, scale): + f1 = random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) + f3 = Cat("b", (f2,) * scale) + + eliminate = frozenset("ab") + plates = frozenset("b") + + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2] + scales = {"b": Number(scale)} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op,pow_op", + [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], +) +@pytest.mark.parametrize("scale", [2, 3]) +def test_partial_sum_product_scale_2(sum_op, prod_op, pow_op, scale): + f1 = random_tensor(OrderedDict(a=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) + f3 = Cat("b", (f2,) * scale) + + eliminate = frozenset("ab") + plates = frozenset("b") + + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2] + scales = {"b": Number(scale)} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) From d249c4fdc298b6ecafac3c9af3920077a8b32f3f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 6 Feb 2023 01:06:15 +0000 Subject: [PATCH 07/14] fix --- funsor/sum_product.py | 28 +++++++++----- test/test_sum_product.py | 80 +++++++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 94bd19a5..085c1c78 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -209,7 +209,6 @@ def partial_sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - pow_op=None, scales={}, ): """ @@ -226,6 +225,14 @@ def partial_sum_product( assert isinstance(plates, frozenset) assert isinstance(scales, dict) + if scales: + if sum_op is ops.logaddexp and prod_op is ops.add: + pow_op = ops.mul + elif sum_op is ops.add and prod_op is ops.mul: + pow_op = ops.pow + else: + raise ValueError("should not be here!") + if pedantic: var_to_errors = defaultdict(lambda: eliminate) for f in factors: @@ -258,18 +265,18 @@ def partial_sum_product( leaf = max(ordinal_to_factors, key=len) # CHOICE leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] - leaf_scale = reduce( - 
ops.mul, [scales[plate] for plate in leaf if plate in scales], Number(1.0) - ) for group_factors, group_vars in _partition( leaf_factors, leaf_reduce_vars ): # CHOICE f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate) - if pow_op is not None: - f = pow_op(f, leaf_scale) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: - results.append(f.reduce(prod_op, leaf & eliminate)) + f = f.reduce(prod_op, leaf & eliminate) + f_scales = [scales[plate] for plate in leaf & eliminate if plate in scales] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) + results.append(f) else: new_plates = frozenset().union( *(var_to_ordinal[v] for v in remaining_sum_vars) @@ -319,6 +326,10 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) + f_scales = [scales[plate] for plate in reduced_plates if plate in scales] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) ordinal_to_factors[new_plates].append(f) return results @@ -590,7 +601,6 @@ def sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - pow_op=None, scales={}, ): """ @@ -600,7 +610,7 @@ def sum_product( :rtype: :class:`~funsor.terms.Funsor` """ factors = partial_sum_product( - sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, scales + sum_op, prod_op, factors, eliminate, plates, pedantic, scales ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index d646603d..b6ff1840 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -2902,50 +2902,88 @@ def test_mixed_sequential_sum_product(duration, num_segments): @pytest.mark.parametrize( - "sum_op,prod_op,pow_op", - [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) @pytest.mark.parametrize("scale", [2, 3]) -def test_partial_sum_product_scale_1(sum_op, prod_op, pow_op, scale): +def test_partial_sum_product_scale_1(sum_op, prod_op, scale): f1 = random_tensor(OrderedDict(a=Bint[2])) - f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) - f3 = Cat("b", (f2,) * scale) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + f3 = Cat("i", (f2,) * scale) - eliminate = frozenset("ab") - plates = frozenset("b") + eliminate = frozenset("ai") + plates = frozenset("i") factors = [f1, f3] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) factors = [f1, f2] - scales = {"b": Number(scale)} + scales = {"i": scale} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + sum_op, prod_op, factors, eliminate, plates, scales=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @pytest.mark.parametrize( - "sum_op,prod_op,pow_op", - [(ops.logaddexp, ops.add, ops.mul), (ops.add, ops.mul, ops.pow)], + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale", [2, 3]) -def test_partial_sum_product_scale_2(sum_op, prod_op, pow_op, scale): +@pytest.mark.parametrize("scale_i", [2, 3]) +@pytest.mark.parametrize("scale_j", [2, 3]) +def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): f1 = random_tensor(OrderedDict(a=Bint[2])) - f2 = random_tensor(OrderedDict(a=Bint[2], b=Bint[3])) - f3 = Cat("b", (f2,) * scale) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], 
j=Bint[4])) + f4 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f3,) * scale_j) - eliminate = frozenset("ab") - plates = frozenset("b") + eliminate = frozenset("aij") + plates = frozenset("ij") - factors = [f1, f3] + factors = [f1, f4, f5] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - factors = [f1, f2] - scales = {"b": Number(scale)} + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, scales=scales + ) + + assert_close(actual, expected, atol=5e-4, rtol=5e-4) + + +@pytest.mark.parametrize( + "sum_op,prod_op", + [(ops.logaddexp, ops.add), (ops.add, ops.mul)], +) +@pytest.mark.parametrize("scale_i", [2, 3]) +@pytest.mark.parametrize("scale_j", [2, 3]) +@pytest.mark.parametrize("scale_k", [2, 3]) +def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k): + f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2])) + f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3])) + f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3])) + f4 = Cat("i", (f1,) * scale_i) + # concatenate across multiple dims + f5 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f5,) * scale_j) + # concatenate across multiple dims + f6 = Cat("i", (f3,) * scale_i) + f6 = Cat("j", (f6,) * scale_j) + f6 = Cat("k", (f6,) * scale_k) + + eliminate = frozenset("aijk") + plates = frozenset("ijk") + + factors = [f4, f5, f6] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j, "k": scale_k} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, pow_op=pow_op, scales=scales + sum_op, prod_op, factors, eliminate, plates, scales=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) From 7c5f9f39576f9cd5db886e4ef06b644f268432bb Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 6 Feb 2023 01:26:09 +0000 Subject: [PATCH 08/14] plate_to_scale --- funsor/sum_product.py | 16 ++++++++-------- test/test_sum_product.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 085c1c78..cd28c9bc 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict, defaultdict from functools import reduce -from math import gcd, prod +from math import gcd import funsor import funsor.ops as ops @@ -209,7 +209,7 @@ def partial_sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - scales={}, + plate_to_scale={}, ): """ Performs partial sum-product contraction of a collection of factors. 
@@ -223,9 +223,9 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) - assert isinstance(scales, dict) + assert isinstance(plate_to_scale, dict) - if scales: + if plate_to_scale: if sum_op is ops.logaddexp and prod_op is ops.add: pow_op = ops.mul elif sum_op is ops.add and prod_op is ops.mul: @@ -272,7 +272,7 @@ def partial_sum_product( remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: f = f.reduce(prod_op, leaf & eliminate) - f_scales = [scales[plate] for plate in leaf & eliminate if plate in scales] + f_scales = [plate_to_scale[plate] for plate in leaf & eliminate if plate in plate_to_scale] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -326,7 +326,7 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) - f_scales = [scales[plate] for plate in reduced_plates if plate in scales] + f_scales = [plate_to_scale[plate] for plate in reduced_plates if plate in plate_to_scale] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -601,7 +601,7 @@ def sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - scales={}, + plate_to_scale={}, ): """ Performs sum-product contraction of a collection of factors. @@ -610,7 +610,7 @@ def sum_product( :rtype: :class:`~funsor.terms.Funsor` """ factors = partial_sum_product( - sum_op, prod_op, factors, eliminate, plates, pedantic, scales + sum_op, prod_op, factors, eliminate, plates, pedantic, plate_to_scale ) return reduce(prod_op, factors, Number(UNITS[prod_op])) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index b6ff1840..09a97316 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -35,7 +35,7 @@ sum_product, ) from funsor.tensor import Tensor, get_default_prototype -from funsor.terms import Cat, Number, Variable +from funsor.terms import Cat, Variable from funsor.testing import assert_close, random_gaussian, random_tensor from funsor.util import get_backend @@ -2920,7 +2920,7 @@ def test_partial_sum_product_scale_1(sum_op, prod_op, scale): factors = [f1, f2] scales = {"i": scale} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @@ -2948,7 +2948,7 @@ def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) @@ -2983,7 +2983,7 @@ def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k) factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j, "k": scale_k} actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, scales=scales + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) assert_close(actual, expected, atol=5e-4, rtol=5e-4) From 52e39b8fbbbf16551ff17aea27c70b9ef16a8e14 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 16 Feb 2023 19:09:45 +0000 Subject: [PATCH 09/14] lint --- funsor/sum_product.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 
cd28c9bc..c4460238 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -272,7 +272,11 @@ def partial_sum_product( remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: f = f.reduce(prod_op, leaf & eliminate) - f_scales = [plate_to_scale[plate] for plate in leaf & eliminate if plate in plate_to_scale] + f_scales = [ + plate_to_scale[plate] + for plate in leaf & eliminate + if plate in plate_to_scale + ] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) @@ -326,7 +330,11 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) - f_scales = [plate_to_scale[plate] for plate in reduced_plates if plate in plate_to_scale] + f_scales = [ + plate_to_scale[plate] + for plate in reduced_plates + if plate in plate_to_scale + ] if f_scales: scale = reduce(ops.mul, f_scales) f = pow_op(f, scale) From 0c415003ad0adbf36cbe8fa2802a5e2c90b4f073 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 16 Feb 2023 21:21:43 +0000 Subject: [PATCH 10/14] test --- test/test_sum_product.py | 62 ++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 09a97316..ea55d4c5 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -2905,66 +2905,82 @@ def test_mixed_sequential_sum_product(duration, num_segments): "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale", [2, 3]) +@pytest.mark.parametrize("scale", [1, 2]) def test_partial_sum_product_scale_1(sum_op, prod_op, scale): f1 = random_tensor(OrderedDict(a=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) - f3 = Cat("i", (f2,) * scale) eliminate = frozenset("ai") plates = frozenset("i") - factors = [f1, f3] - expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - + # Actual result based on applying scaling factors = [f1, f2] scales = {"i": scale} actual = sum_product( sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + # Expected result based on concatenating factors + f3 = Cat("i", (f2,) * scale) + factors = [f1, f3] + expected = sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale_i", [2, 3]) -@pytest.mark.parametrize("scale_j", [2, 3]) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j): f1 = random_tensor(OrderedDict(a=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3])) f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4])) - f4 = Cat("i", (f2,) * scale_i) - f5 = Cat("j", (f3,) * scale_j) eliminate = frozenset("aij") plates = frozenset("ij") - factors = [f1, f4, f5] - expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - + # Actual result based on applying scaling factors = [f1, f2, f3] scales = {"i": scale_i, "j": scale_j} actual = sum_product( sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales ) - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + # Expected result based on concatenating factors + f4 = Cat("i", (f2,) * scale_i) + f5 = Cat("j", (f3,) * scale_j) + factors = [f1, f4, f5] + expected = 
sum_product(sum_op, prod_op, factors, eliminate, plates) + + assert_close(actual, expected, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)], ) -@pytest.mark.parametrize("scale_i", [2, 3]) -@pytest.mark.parametrize("scale_j", [2, 3]) -@pytest.mark.parametrize("scale_k", [2, 3]) +@pytest.mark.parametrize("scale_i", [1, 2]) +@pytest.mark.parametrize("scale_j", [1, 3]) +@pytest.mark.parametrize("scale_k", [1, 4]) def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k): f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2])) f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3])) f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3])) + + eliminate = frozenset("aijk") + plates = frozenset("ijk") + + # Actual result based on applying scaling + factors = [f1, f2, f3] + scales = {"i": scale_i, "j": scale_j, "k": scale_k} + actual = sum_product( + sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales + ) + + # Expected result based on concatenating factors f4 = Cat("i", (f1,) * scale_i) # concatenate across multiple dims f5 = Cat("i", (f2,) * scale_i) @@ -2973,17 +2989,7 @@ def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k) f6 = Cat("i", (f3,) * scale_i) f6 = Cat("j", (f6,) * scale_j) f6 = Cat("k", (f6,) * scale_k) - - eliminate = frozenset("aijk") - plates = frozenset("ijk") - factors = [f4, f5, f6] expected = sum_product(sum_op, prod_op, factors, eliminate, plates) - factors = [f1, f2, f3] - scales = {"i": scale_i, "j": scale_j, "k": scale_k} - actual = sum_product( - sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales - ) - - assert_close(actual, expected, atol=5e-4, rtol=5e-4) + assert_close(actual, expected, atol=1e-4, rtol=1e-4) From ef8a0e96c7e62d0bdb4817b550c6cb8a2443e73a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 26 Aug 2023 23:41:24 +0000 Subject: [PATCH 11/14] address comments --- funsor/ops/builtin.py | 4 ++++ funsor/ops/op.py | 2 ++ funsor/sum_product.py | 53 +++++++++++++++++++++---------------------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 0f3d0c03..77fbfe6d 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -8,6 +8,7 @@ from .op import ( BINARY_INVERSES, DISTRIBUTIVE_OPS, + PRODUCT_TO_POWER, SAFE_BINARY_INVERSES, UNARY_INVERSES, UNITS, @@ -287,6 +288,9 @@ def sigmoid_log_abs_det_jacobian(x, y): UNARY_INVERSES[mul] = reciprocal UNARY_INVERSES[add] = neg +PRODUCT_TO_POWER[add] = mul +PRODUCT_TO_POWER[mul] = pow + __all__ = [ "AssociativeOp", "ComparisonOp", diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 5c5312e8..f7540c63 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -421,6 +421,7 @@ def log_abs_det_jacobian(x, y, fn): BINARY_INVERSES = {} # binary op -> inverse binary op SAFE_BINARY_INVERSES = {} # binary op -> numerically safe inverse binary op UNARY_INVERSES = {} # binary op -> inverse unary op +PRODUCT_TO_POWER = {} # product op -> power op __all__ = [ "BINARY_INVERSES", @@ -430,6 +431,7 @@ def log_abs_det_jacobian(x, y, fn): "LogAbsDetJacobianOp", "NullaryOp", "Op", + "PRODUCT_TO_POWER", "SAFE_BINARY_INVERSES", "TernaryOp", "TransformOp", diff --git a/funsor/sum_product.py b/funsor/sum_product.py index c4460238..31a4d2fa 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -11,7 +11,7 @@ from funsor.cnf import Contraction from funsor.domains import Bint, Reals from 
funsor.interpreter import gensym -from funsor.ops import UNITS, AssociativeOp +from funsor.ops import PRODUCT_TO_POWER, UNITS, AssociativeOp from funsor.terms import ( Cat, Funsor, @@ -209,7 +209,8 @@ def partial_sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - plate_to_scale={}, + pow_op=None, + plate_to_scale=None, # dict ): """ Performs partial sum-product contraction of a collection of factors. @@ -223,15 +224,10 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) - assert isinstance(plate_to_scale, dict) if plate_to_scale: - if sum_op is ops.logaddexp and prod_op is ops.add: - pow_op = ops.mul - elif sum_op is ops.add and prod_op is ops.mul: - pow_op = ops.pow - else: - raise ValueError("should not be here!") + if pow_op is None: + pow_op = PRODUCT_TO_POWER[prod_op] if pedantic: var_to_errors = defaultdict(lambda: eliminate) @@ -272,14 +268,15 @@ def partial_sum_product( remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: f = f.reduce(prod_op, leaf & eliminate) - f_scales = [ - plate_to_scale[plate] - for plate in leaf & eliminate - if plate in plate_to_scale - ] - if f_scales: - scale = reduce(ops.mul, f_scales) - f = pow_op(f, scale) + if plate_to_scale: + f_scales = [ + plate_to_scale[plate] + for plate in leaf & eliminate + if plate in plate_to_scale + ] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) results.append(f) else: new_plates = frozenset().union( @@ -330,14 +327,15 @@ def partial_sum_product( reduced_plates = leaf - new_plates assert reduced_plates.issubset(eliminate) f = f.reduce(prod_op, reduced_plates) - f_scales = [ - plate_to_scale[plate] - for plate in reduced_plates - if plate in plate_to_scale - ] - if f_scales: - scale = reduce(ops.mul, f_scales) - f = pow_op(f, scale) + if plate_to_scale: + f_scales = [ + plate_to_scale[plate] + for plate in reduced_plates + if plate in plate_to_scale + ] + if f_scales: + scale = reduce(ops.mul, f_scales) + f = pow_op(f, scale) ordinal_to_factors[new_plates].append(f) return results @@ -609,7 +607,8 @@ def sum_product( eliminate=frozenset(), plates=frozenset(), pedantic=False, - plate_to_scale={}, + pow_op=None, + plate_to_scale=None, # dict ): """ Performs sum-product contraction of a collection of factors. 
@@ -618,7 +617,7 @@ def sum_product( :rtype: :class:`~funsor.terms.Funsor` """ factors = partial_sum_product( - sum_op, prod_op, factors, eliminate, plates, pedantic, plate_to_scale + sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, plate_to_scale ) return reduce(prod_op, factors, Number(UNITS[prod_op])) From 84221f9916d47eeca314571c508bec644c807c73 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 27 Aug 2023 00:06:41 +0000 Subject: [PATCH 12/14] fix E721 linting rule --- funsor/testing.py | 4 ++-- test/test_terms.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/funsor/testing.py b/funsor/testing.py index 91336a52..dbf61245 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -81,7 +81,7 @@ def id_from_inputs(inputs): @dispatch(object, object, Variadic[float]) def allclose(a, b, rtol=1e-05, atol=1e-08): - if type(a) != type(b): + if type(a) is not type(b): return False return ops.abs(a - b) < rtol + atol * ops.abs(b) @@ -125,7 +125,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): elif isinstance(actual, Gaussian): assert isinstance(expected, Gaussian) else: - assert type(actual) == type(expected), msg + assert type(actual) is type(expected), msg if isinstance(actual, Funsor): assert isinstance(expected, Funsor), msg diff --git a/test/test_terms.py b/test/test_terms.py index db7e586b..af18dfee 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -72,7 +72,7 @@ def test_to_funsor_error(x): def test_to_data(): actual = to_data(Number(0.0)) expected = 0.0 - assert type(actual) == type(expected) + assert type(actual) is type(expected) assert actual == expected @@ -569,7 +569,7 @@ def test_stack_slice(start, stop, step): xs = tuple(map(Number, range(10))) actual = Stack("i", xs)(i=Slice("j", start, stop, step, dtype=10)) expected = Stack("j", xs[start:stop:step]) - assert type(actual) == type(expected) + assert type(actual) is type(expected) assert actual.name == expected.name assert actual.parts == expected.parts From bd174f40d036bc4e38fe7e2f7bcdb9d7f0d74f80 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 27 Aug 2023 00:18:00 +0000 Subject: [PATCH 13/14] fix the merge --- funsor/sum_product.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index e8f9eff2..31a4d2fa 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -224,15 +224,6 @@ def partial_sum_product( assert all(isinstance(f, Funsor) for f in factors) assert isinstance(eliminate, frozenset) assert isinstance(plates, frozenset) - assert isinstance(plate_to_scale, dict) - - if plate_to_scale: - if sum_op is ops.logaddexp and prod_op is ops.add: - pow_op = ops.mul - elif sum_op is ops.add and prod_op is ops.mul: - pow_op = ops.pow - else: - raise ValueError("should not be here!") if plate_to_scale: if pow_op is None: From d2ce7f3ff57bb583346bf5016d7dbf8ae380adeb Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 27 Aug 2023 00:26:46 +0000 Subject: [PATCH 14/14] use python 3.9 for jax --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4c39cb1d..a9accdcd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,7 +67,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.9] env: CI: 1 FUNSOR_BACKEND: jax
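Taken together, the series adds a `plate_to_scale` argument to `partial_sum_product` and `sum_product`: instead of physically unrolling or concatenating a plate, each plate's contribution is raised to a power equal to its scale, with the power op inferred from the product op via the new `PRODUCT_TO_POWER` table (`mul` for `add`, `pow` for `mul`). The sketch below mirrors `test_partial_sum_product_scale_1` from the patches above; it is an illustration, not part of the patch, and the import paths are assumptions based on the modules touched in this series. It checks that scaling the "i" plate by 2 matches contracting two concatenated copies of it.

from collections import OrderedDict

import funsor.ops as ops
from funsor.domains import Bint
from funsor.sum_product import sum_product
from funsor.terms import Cat
from funsor.testing import assert_close, random_tensor

f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))
eliminate, plates = frozenset("ai"), frozenset("i")

# Scaled contraction: the "i" plate is counted twice without materializing copies.
actual = sum_product(
    ops.logaddexp, ops.add, [f1, f2], eliminate, plates, plate_to_scale={"i": 2}
)

# Reference: physically concatenate two copies along "i" and contract as usual.
f3 = Cat("i", (f2,) * 2)
expected = sum_product(ops.logaddexp, ops.add, [f1, f3], eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)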