diff --git a/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala b/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala
index 313360a74..0d37e7af2 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala
@@ -16,7 +16,11 @@
 
 package ai.chronon.spark.test
 
+import ai.chronon.aggregator.test.Column
+import ai.chronon.api
+import ai.chronon.api.{Builders, Constants, QueryUtils, Source}
 import ai.chronon.spark.{PartitionRange, SparkSessionBuilder, TableUtils}
+import ai.chronon.spark.Extensions._
 import org.apache.spark.sql.SparkSession
 import org.junit.Assert.assertEquals
 import org.junit.Test
@@ -25,10 +29,84 @@ class DataRangeTest {
   val spark: SparkSession = SparkSessionBuilder.build("DataRangeTest", local = true)
   private val tableUtils = TableUtils(spark)
 
+  /** genScanQuery should emit the selects, partition bounds and where clause as one SQL string. */
+  @Test
+  def testGenScanQuery(): Unit = {
+    val namespace = "date_range_test_namespace"
+    spark.sql(s"CREATE DATABASE IF NOT EXISTS $namespace")
+    val testTable = s"$namespace.test_gen_scan_query"
+    val viewsSchema = List(
+      Column("col_1", api.StringType, 1),
+      Column("col_2", api.StringType, 1),
+    )
+    // NOTE(review): "ds" is dropped and the table is saved unpartitioned, yet the expected query
+    // filters on ds; presumably only the generated SQL text is under test - confirm.
+    DataFrameGen
+      .events(spark, viewsSchema, count = 1000, partitions = 200)
+      .drop("ds")
+      .save(testTable, partitionColumns = Seq())
+    val partitionRange: PartitionRange = PartitionRange("2024-03-01", "2024-04-01")(tableUtils)
+    val source: Source = Builders.Source.events(
+      query = Builders.Query(
+        selects = Builders.Selects("col_1", "col_2"),
+        wheres = Seq("col_1 = 'TEST'"),
+        timeColumn = "ts"
+      ),
+      table = testTable,
+    )
+
+    val result: String = partitionRange.genScanQuery(
+      source.getEvents.query,
+      testTable,
+      Seq(Constants.TimeColumn -> Option(source.getEvents.query).map(_.timeColumn).orNull).toMap
+    )
+
+    // NOTE(review): the text after each '|' is compared verbatim; verify its spacing against genScanQuery.
+    val expected: String =
+      """SELECT
+        | ts as `ts`,
+        | col_1 as `col_1`,
+        | col_2 as `col_2`
+        |FROM date_range_test_namespace.test_gen_scan_query
+        |WHERE
+        | (ds >= '2024-03-01') AND (ds <= '2024-04-01') AND (col_1 = 'TEST')"""
+    assertEquals(expected.stripMargin, result.stripMargin)
+  }
+
   @Test
   def testIntersect(): Unit = {
     val range1 = PartitionRange(null, null)(tableUtils)
     val range2 = PartitionRange("2023-01-01", "2023-01-02")(tableUtils)
     assertEquals(range2, range1.intersect(range2))
   }
+
+  /** A range with both bounds yields a `>=` clause for the start and a `<=` clause for the end. */
+  @Test
+  def testWhereClauses(): Unit = {
+    val range = PartitionRange("2023-01-01", "2023-01-02")(tableUtils)
+
+    val clauses = range.whereClauses("ds")
+
+    assertEquals(Seq("ds >= '2023-01-01'", "ds <= '2023-01-02'"), clauses)
+  }
+
+  /** A null start yields only the upper-bound clause. */
+  @Test
+  def testWhereClausesNullStart(): Unit = {
+    val range = PartitionRange(null, "2023-01-02")(tableUtils)
+
+    val clauses = range.whereClauses("ds")
+
+    assertEquals(Seq("ds <= '2023-01-02'"), clauses)
+  }
+
+  /** A null end yields only the lower-bound clause. */
+  @Test
+  def testWhereClausesNullEnd(): Unit = {
+    val range = PartitionRange("2023-01-01", null)(tableUtils)
+
+    val clauses = range.whereClauses("ds")
+
+    assertEquals(Seq("ds >= '2023-01-01'"), clauses)
+  }
 }
diff --git a/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala b/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala
new file mode 100644
index 000000000..3a33dc90c
--- /dev/null
+++ b/spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala
@@ -0,0 +1,51 @@
+package ai.chronon.spark.test
+
+import ai.chronon.spark.Extensions._
+import ai.chronon.spark.{Comparison, PartitionRange, SparkSessionBuilder, TableUtils}
+import org.apache.spark.sql.SparkSession
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+/** Spark-backed test of the `prunePartition` call made on a DataFrame. */
+class ExtensionsTest {
+
+  lazy val spark: SparkSession = SparkSessionBuilder.build("ExtensionsTest", local = true)
+
+  import spark.implicits._
+
+  // Implicit so PartitionRange(...) below can be built without passing TableUtils explicitly.
+  private implicit val tableUtils: TableUtils = TableUtils(spark)
+
+  /** Expects prunePartition to keep only rows whose ds falls within the range, bounds included. */
+  @Test
+  def testPrunePartitionTest(): Unit = {
+    val df = Seq(
+      (1, "2024-01-03"),
+      (2, "2024-01-04"),
+      (3, "2024-01-04"),
+      (4, "2024-01-05"),
+      (5, "2024-01-05"),
+      (6, "2024-01-06"),
+      (7, "2024-01-07"),
+      (8, "2024-01-08"),
+      (9, "2024-01-08"),
+      (10, "2024-01-09"),
+    ).toDF("key", "ds")
+
+    val prunedDf = df.prunePartition(PartitionRange("2024-01-05", "2024-01-07"))
+
+    val expectedDf = Seq(
+      (4, "2024-01-05"),
+      (5, "2024-01-05"),
+      (6, "2024-01-06"),
+      (7, "2024-01-07"),
+    ).toDF("key", "ds")
+    val diff = Comparison.sideBySide(expectedDf, prunedDf, List("key"))
+    // Count once: each count() call runs its own Spark job.
+    val diffCount = diff.count()
+    if (diffCount != 0) {
+      diff.show()
+    }
+    assertEquals(0, diffCount)
+  }
+}