Skip to content

Commit

Permalink
Add bounded unique count aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrooks-stripe committed Jul 2, 2024
1 parent ed7d514 commit 3c368e8
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,44 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
}
}

class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[T], Long] {
override def prepare(input: T): util.Set[T] = {
val result = new util.HashSet[T](k)
result.add(input)
result
}

override def update(ir: util.Set[T], input: T): util.Set[T] = {
if (ir.size() >= k) {
return ir
}

ir.add(input)
ir
}

override def outputType: DataType = LongType

override def irType: DataType = ListType(inputType)

override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
ir2.asScala.foreach(v =>
if (ir1.size() < k) {
ir1.add(v)
})

ir1
}

override def finalize(ir: util.Set[T]): Long = ir.size()

override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)

override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)

override def denormalize(ir: Any): util.Set[T] = new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
}

// Based on CPC sketch (a faster, smaller and more accurate version of HLL)
// See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
// https://arxiv.org/abs/1708.06839
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,19 @@ object ColumnAggregator {
case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
case _ => mismatchException
}
case Operation.BOUNDED_UNIQUE_COUNT =>
val k = aggregationPart.getInt("k", Some(8))

inputType match {
case IntType => simple(new BoundedUniqueCount[Int](inputType, k))
case LongType => simple(new BoundedUniqueCount[Long](inputType, k))
case ShortType => simple(new BoundedUniqueCount[Short](inputType, k))
case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
case FloatType => simple(new BoundedUniqueCount[Float](inputType, k))
case StringType => simple(new BoundedUniqueCount[String](inputType, k))
case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
case _ => mismatchException
}
case Operation.APPROX_PERCENTILE =>
val k = aggregationPart.getInt("k", Some(128))
val mapper = new ObjectMapper()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ai.chronon.aggregator.test

import ai.chronon.aggregator.base.BoundedUniqueCount
import ai.chronon.api.StringType
import junit.framework.TestCase
import org.junit.Assert._

import java.util
import scala.jdk.CollectionConverters._

class BoundedUniqueCountTest extends TestCase {
def testHappyCase(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
var ir = boundedDistinctCount.prepare("1")
ir = boundedDistinctCount.update(ir, "1")
ir = boundedDistinctCount.update(ir, "2")

val result = boundedDistinctCount.finalize(ir)
assertEquals(2, result)
}

def testExceedSize(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
var ir = boundedDistinctCount.prepare("1")
ir = boundedDistinctCount.update(ir, "2")
ir = boundedDistinctCount.update(ir, "3")
ir = boundedDistinctCount.update(ir, "4")
ir = boundedDistinctCount.update(ir, "5")
ir = boundedDistinctCount.update(ir, "6")
ir = boundedDistinctCount.update(ir, "7")

val result = boundedDistinctCount.finalize(ir)
assertEquals(5, result)
}

def testMerge(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)

val merged = boundedDistinctCount.merge(ir1, ir2)
assertEquals(merged.size(), 5)
}
}
1 change: 1 addition & 0 deletions api/py/ai/chronon/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Operation:
# https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
BOUNDED_UNIQUE_COUNT = ttypes.Operation.BOUNDED_UNIQUE_COUNT
COUNT = ttypes.Operation.COUNT
SUM = ttypes.Operation.SUM
AVERAGE = ttypes.Operation.AVERAGE
Expand Down
3 changes: 2 additions & 1 deletion api/thrift/api.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ enum Operation {
BOTTOM_K = 16

HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
APPROX_HISTOGRAM_K = 18
APPROX_HISTOGRAM_K = 18,
BOUNDED_UNIQUE_COUNT = 19
}

// integers map to milliseconds in the timeunit
Expand Down
1 change: 1 addition & 0 deletions docs/source/authoring_features/GroupBy.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Limitations:
| approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
| approx_percentile | primitive types | list, map | list<input,> | no | k=128, percentiles | yes |
| unique_count | primitive types | list, map | long | no | | no |
| bounded_unique_count | primitive types | list, map | long | no | k=inf | yes |


## Accuracy
Expand Down
12 changes: 10 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,11 @@ class FetcherTest extends TestCase {
Builders.Aggregation(operation = Operation.LAST_K,
argMap = Map("k" -> "300"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))
),
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))),
Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
argMap = Map("k" -> "5"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))),
metaData = Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace),
accuracy = Accuracy.SNAPSHOT
)
Expand Down Expand Up @@ -547,6 +550,11 @@ class FetcherTest extends TestCase {
operation = Operation.APPROX_HISTOGRAM_K,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
),
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
)
),
accuracy = Accuracy.TEMPORAL,
Expand Down
42 changes: 42 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -656,4 +656,46 @@ class GroupByTest {
tableUtils = tableUtils,
additionalAgg = aggs)
}

@Test
def testBoundedUniqueCounts(): Unit = {
val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
val tableUtils = TableUtils(spark)
val namespace = "test_bounded_counts"
val aggs = Seq(
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "item",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
),
argMap = Map("k" -> "5")
),
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "price",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
),
argMap = Map("k" -> "5")
),
)
backfill(name = "unit_test_group_by_bounded_counts",
source = source,
endPartition = endPartition,
namespace = namespace,
tableUtils = tableUtils,
additionalAgg = aggs)

val result = spark.sql(
"""
|select *
|from test_bounded_counts.unit_test_group_by_bounded_counts
|where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
|""".stripMargin)

assertTrue(result.isEmpty)
}
}

0 comments on commit 3c368e8

Please sign in to comment.