Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests to cover partition range pruning, where clauses and gen scan query #887

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/DataRangeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

package ai.chronon.spark.test

import ai.chronon.aggregator.test.Column
import ai.chronon.api
import ai.chronon.api.{Builders, Constants, QueryUtils, Source}
import ai.chronon.spark.{PartitionRange, SparkSessionBuilder, TableUtils}
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.junit.Test
Expand All @@ -25,10 +29,77 @@ class DataRangeTest {
// Spark session for this suite, built in local mode under the app name "DataRangeTest".
val spark: SparkSession = SparkSessionBuilder.build("DataRangeTest", local = true)
// TableUtils bound to the session above; the tests pass it explicitly as the
// second argument list of PartitionRange(...).
private val tableUtils = TableUtils(spark)

@Test
def testGenScanQuery(): Unit = {
  // Materialise a table of generated events, saved without the `ds` column
  // and with no partition columns.
  val dbName = "date_range_test_namespace"
  spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName")
  val tableName = s"$dbName.test_gen_scan_query"
  val columns = List(
    Column("col_1", api.StringType, 1),
    Column("col_2", api.StringType, 1)
  )
  val generatedEvents = DataFrameGen.events(spark, columns, count = 1000, partitions = 200)
  generatedEvents.drop("ds").save(tableName, partitionColumns = Seq())

  // Events source selecting both columns, with one extra filter and an explicit time column.
  val eventsSource: Source = Builders.Source.events(
    query = Builders.Query(
      selects = Builders.Selects("col_1", "col_2"),
      wheres = Seq("col_1 = 'TEST'"),
      timeColumn = "ts"
    ),
    table = tableName
  )
  val eventsQuery = eventsSource.getEvents.query

  // Supplied as the fill-in for the time column select (null when the query is absent).
  val fillIfAbsent = Map(Constants.TimeColumn -> Option(eventsQuery).map(_.timeColumn).orNull)

  val range = PartitionRange("2024-03-01", "2024-04-01")(tableUtils)
  val actual: String = range.genScanQuery(eventsQuery, tableName, fillIfAbsent)

  // The range bounds are expected as `ds` predicates ahead of the user-supplied where clause.
  val expected: String =
    """SELECT
      | ts as `ts`,
      | col_1 as `col_1`,
      | col_2 as `col_2`
      |FROM date_range_test_namespace.test_gen_scan_query
      |WHERE
      | (ds >= '2024-03-01') AND (ds <= '2024-04-01') AND (col_1 = 'TEST')""".stripMargin

  assertEquals(expected, actual.stripMargin)
}

@Test
def testIntersect(): Unit = {
  // Intersecting a range with null start and end against a bounded range
  // is expected to yield the bounded range.
  val bounded = PartitionRange("2023-01-01", "2023-01-02")(tableUtils)
  val unbounded = PartitionRange(null, null)(tableUtils)

  val intersection = unbounded.intersect(bounded)

  assertEquals(bounded, intersection)
}

@Test
def testWhereClauses(): Unit = {
  // With both bounds set, a lower-bound and an upper-bound predicate are expected, in that order.
  val expectedClauses = Seq("ds >= '2023-01-01'", "ds <= '2023-01-02'")

  val actualClauses = PartitionRange("2023-01-01", "2023-01-02")(tableUtils).whereClauses("ds")

  assertEquals(expectedClauses, actualClauses)
}

@Test
def testWhereClausesNullStart(): Unit = {
  // A null start is expected to leave only the upper-bound predicate.
  val expectedClauses = Seq("ds <= '2023-01-02'")

  val actualClauses = PartitionRange(null, "2023-01-02")(tableUtils).whereClauses("ds")

  assertEquals(expectedClauses, actualClauses)
}

@Test
def testWhereClausesNullEnd(): Unit = {
  // A null end is expected to leave only the lower-bound predicate.
  val expectedClauses = Seq("ds >= '2023-01-01'")

  val actualClauses = PartitionRange("2023-01-01", null)(tableUtils).whereClauses("ds")

  assertEquals(expectedClauses, actualClauses)
}
}
46 changes: 46 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package ai.chronon.spark.test

import ai.chronon.spark.Extensions._
import ai.chronon.spark.{Comparison, PartitionRange, SparkSessionBuilder, TableUtils}
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.junit.Test

/** Tests for the DataFrame helpers brought into scope by `ai.chronon.spark.Extensions._`. */
class ExtensionsTest {

  // Local-mode Spark session, created on first use.
  lazy val spark: SparkSession = SparkSessionBuilder.build("ExtensionsTest", local = true)

  // Needed for `.toDF` on local Seqs.
  import spark.implicits._

  // Implicit definitions carry an explicit type so resolution does not depend on
  // inference order (and so the declaration stays valid under Scala 3 rules).
  // Picked up implicitly by PartitionRange(...) below.
  private implicit val tableUtils: TableUtils = TableUtils(spark)

  /**
    * Checks that `prunePartition` keeps only the rows whose `ds` lies inside the
    * given range; the expected frame below treats both bounds as inclusive.
    */
  @Test
  def testPrunePartitionTest(): Unit = {
    val df = Seq(
      (1, "2024-01-03"),
      (2, "2024-01-04"),
      (3, "2024-01-04"),
      (4, "2024-01-05"),
      (5, "2024-01-05"),
      (6, "2024-01-06"),
      (7, "2024-01-07"),
      (8, "2024-01-08"),
      (9, "2024-01-08"),
      (10, "2024-01-09")
    ).toDF("key", "ds")

    val prunedDf = df.prunePartition(PartitionRange("2024-01-05", "2024-01-07"))

    val expectedDf = Seq(
      (4, "2024-01-05"),
      (5, "2024-01-05"),
      (6, "2024-01-06"),
      (7, "2024-01-07")
    ).toDF("key", "ds")

    // Rows returned by the comparison are treated as mismatches; none are expected.
    val diff = Comparison.sideBySide(expectedDf, prunedDf, List("key"))

    // count() is a Spark action: run it once and reuse the result for both
    // the diagnostic output and the assertion.
    val mismatchCount = diff.count()
    if (mismatchCount != 0) {
      diff.show()
    }
    assertEquals(0, mismatchCount)
  }
}