From 3c368e80d1e8858eb46ce631e168dfc089dd6c84 Mon Sep 17 00:00:00 2001
From: Jeffrey Brooks
Date: Tue, 2 Jul 2024 13:52:00 -0700
Subject: [PATCH] Add bounded unique count aggregation

---
 .../aggregator/base/SimpleAggregators.scala   | 38 ++++++++++++++++
 .../aggregator/row/ColumnAggregator.scala     | 12 +++++
 .../test/BoundedUniqueCountTest.scala         | 44 +++++++++++++++++++
 api/py/ai/chronon/group_by.py                 |  1 +
 api/thrift/api.thrift                         |  3 +-
 docs/source/authoring_features/GroupBy.md     |  1 +
 .../ai/chronon/spark/test/FetcherTest.scala   | 12 ++++-
 .../ai/chronon/spark/test/GroupByTest.scala   | 42 ++++++++++++++++++
 8 files changed, 150 insertions(+), 3 deletions(-)
 create mode 100644 aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index b120d29e7..83fd8df64 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -599,6 +599,44 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
   }
 }
 
+class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[T], Long] {
+  override def prepare(input: T): util.Set[T] = {
+    val result = new util.HashSet[T](k)
+    result.add(input)
+    result
+  }
+
+  override def update(ir: util.Set[T], input: T): util.Set[T] = {
+    if (ir.size() >= k) {
+      return ir
+    }
+
+    ir.add(input)
+    ir
+  }
+
+  override def outputType: DataType = LongType
+
+  override def irType: DataType = ListType(inputType)
+
+  override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
+    ir2.asScala.foreach(v =>
+      if (ir1.size() < k) {
+        ir1.add(v)
+      })
+
+    ir1
+  }
+
+  override def finalize(ir: util.Set[T]): Long = ir.size()
+
+  override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)
+
+  override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)
+
+  override def denormalize(ir: Any): util.Set[T] = new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
+}
+
 // Based on CPC sketch (a faster, smaller and more accurate version of HLL)
 // See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
 // https://arxiv.org/abs/1708.06839
diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
index 81ed14337..5d7793199 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
@@ -304,7 +304,19 @@ object ColumnAggregator {
          case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
          case _          => mismatchException
        }
+      case Operation.BOUNDED_UNIQUE_COUNT =>
+        val k = aggregationPart.getInt("k", Some(8))
+        inputType match {
+          case IntType    => simple(new BoundedUniqueCount[Int](inputType, k))
+          case LongType   => simple(new BoundedUniqueCount[Long](inputType, k))
+          case ShortType  => simple(new BoundedUniqueCount[Short](inputType, k))
+          case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
+          case FloatType  => simple(new BoundedUniqueCount[Float](inputType, k))
+          case StringType => simple(new BoundedUniqueCount[String](inputType, k))
+          case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
+          case _          => mismatchException
+        }
      case Operation.APPROX_PERCENTILE =>
        val k = aggregationPart.getInt("k", Some(128))
        val mapper = new ObjectMapper()
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
new file mode 100644
index 000000000..d780aae08
--- /dev/null
+++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
@@ -0,0 +1,44 @@
+package ai.chronon.aggregator.test
+
+import ai.chronon.aggregator.base.BoundedUniqueCount
+import ai.chronon.api.StringType
+import junit.framework.TestCase
+import org.junit.Assert._
+
+import java.util
+import scala.jdk.CollectionConverters._
+
+class BoundedUniqueCountTest extends TestCase {
+  def testHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    var ir = boundedDistinctCount.prepare("1")
+    ir = boundedDistinctCount.update(ir, "1")
+    ir = boundedDistinctCount.update(ir, "2")
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    var ir = boundedDistinctCount.prepare("1")
+    ir = boundedDistinctCount.update(ir, "2")
+    ir = boundedDistinctCount.update(ir, "3")
+    ir = boundedDistinctCount.update(ir, "4")
+    ir = boundedDistinctCount.update(ir, "5")
+    ir = boundedDistinctCount.update(ir, "6")
+    ir = boundedDistinctCount.update(ir, "7")
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testMerge(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
+    val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)
+
+    val merged = boundedDistinctCount.merge(ir1, ir2)
+    assertEquals(5, merged.size())
+  }
+}
\ No newline at end of file
diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py
index b74ff39ae..ca57f1327 100644
--- a/api/py/ai/chronon/group_by.py
+++ b/api/py/ai/chronon/group_by.py
@@ -64,6 +64,7 @@ class Operation:
     # https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
     APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
     UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
+    BOUNDED_UNIQUE_COUNT = ttypes.Operation.BOUNDED_UNIQUE_COUNT
     COUNT = ttypes.Operation.COUNT
     SUM = ttypes.Operation.SUM
     AVERAGE = ttypes.Operation.AVERAGE
diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift
index 883fc2ce8..30566db3e 100644
--- a/api/thrift/api.thrift
+++ b/api/thrift/api.thrift
@@ -160,7 +160,8 @@ enum Operation {
     BOTTOM_K = 16
     HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
-    APPROX_HISTOGRAM_K = 18
+    APPROX_HISTOGRAM_K = 18,
+    BOUNDED_UNIQUE_COUNT = 19
 }
 
 // integers map to milliseconds in the timeunit
diff --git a/docs/source/authoring_features/GroupBy.md b/docs/source/authoring_features/GroupBy.md
index c9e5cb5f4..cdaa844a1 100644
--- a/docs/source/authoring_features/GroupBy.md
+++ b/docs/source/authoring_features/GroupBy.md
@@ -147,6 +147,7 @@ Limitations:
 | approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
 | approx_percentile | primitive types | list, map | list | no | k=128, percentiles | yes |
 | unique_count | primitive types | list, map | long | no | | no |
+| bounded_unique_count | primitive types | list, map | long | no | k=8 | yes |
 
 ## Accuracy
 
diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
index 9091c9ee4..d54932123 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -322,8 +322,11 @@ class FetcherTest extends TestCase {
         Builders.Aggregation(operation = Operation.LAST_K,
                              argMap = Map("k" -> "300"),
                              inputColumn = "user",
-                             windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))
-      ),
+                             windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))),
+        Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
+                             argMap = Map("k" -> "5"),
+                             inputColumn = "user",
+                             windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))),
       metaData = Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace),
       accuracy = Accuracy.SNAPSHOT
     )
@@ -547,6 +550,11 @@
           operation = Operation.APPROX_HISTOGRAM_K,
           inputColumn = "rating",
           windows = Seq(new Window(1, TimeUnit.DAYS))
+        ),
+        Builders.Aggregation(
+          operation = Operation.BOUNDED_UNIQUE_COUNT,
+          inputColumn = "rating",
+          windows = Seq(new Window(1, TimeUnit.DAYS))
         )
       ),
       accuracy = Accuracy.TEMPORAL,
diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
index ba6f41dfe..16d3c3e1a 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -656,4 +656,46 @@ class GroupByTest {
              tableUtils = tableUtils,
              additionalAgg = aggs)
   }
+
+  @Test
+  def testBoundedUniqueCounts(): Unit = {
+    val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_bounded_counts"
+    val aggs = Seq(
+      Builders.Aggregation(
+        operation = Operation.BOUNDED_UNIQUE_COUNT,
+        inputColumn = "item",
+        windows = Seq(
+          new Window(15, TimeUnit.DAYS),
+          new Window(60, TimeUnit.DAYS)
+        ),
+        argMap = Map("k" -> "5")
+      ),
+      Builders.Aggregation(
+        operation = Operation.BOUNDED_UNIQUE_COUNT,
+        inputColumn = "price",
+        windows = Seq(
+          new Window(15, TimeUnit.DAYS),
+          new Window(60, TimeUnit.DAYS)
+        ),
+        argMap = Map("k" -> "5")
+      )
+    )
+    backfill(name = "unit_test_group_by_bounded_counts",
+             source = source,
+             endPartition = endPartition,
+             namespace = namespace,
+             tableUtils = tableUtils,
+             additionalAgg = aggs)
+
+    val result = spark.sql(
+      """
+        |select *
+        |from test_bounded_counts.unit_test_group_by_bounded_counts
+        |where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
+        |""".stripMargin)
+
+    assertTrue(result.isEmpty)
+  }
 }
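
The sketch below is not part of the patch; it is a minimal, illustrative use of the new aggregator, assuming only classes the patch already references (BoundedUniqueCount, StringType), and the demo object name is made up. It shows the property the operation is named for: the intermediate representation is the raw set of observed values, so the count is exact while fewer than k distinct values have been seen and saturates at k afterwards.

```scala
import ai.chronon.aggregator.base.BoundedUniqueCount
import ai.chronon.api.StringType

// Illustrative sketch of BOUNDED_UNIQUE_COUNT semantics (not part of the patch).
object BoundedUniqueCountDemo {
  def main(args: Array[String]): Unit = {
    val agg = new BoundedUniqueCount[String](StringType, k = 3)

    // Below the bound the count is exact: two distinct values -> 2.
    var ir = agg.prepare("a")
    ir = agg.update(ir, "b")
    println(agg.finalize(ir)) // 2

    // Past the bound the count saturates: five distinct values with k = 3 -> 3.
    Seq("c", "d", "e").foreach(v => ir = agg.update(ir, v))
    println(agg.finalize(ir)) // 3
  }
}
```

The same property is what testBoundedUniqueCounts checks end to end: with k = 5, neither item_bounded_unique_count_60d nor price_bounded_unique_count_60d can exceed 5, so the filtering query is expected to return no rows. Unlike APPROX_UNIQUE_COUNT, which keeps a fixed-size CPC sketch, this intermediate state grows linearly with k, so small bounds are the intended use case.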