From 215c014733a930272c458127e157f0fa5593a004 Mon Sep 17 00:00:00 2001 From: "Caio Camatta (Stripe)" <108533014+caiocamatta-stripe@users.noreply.github.com> Date: Thu, 27 Jun 2024 12:43:04 -0400 Subject: [PATCH] [CHIP-1] Cache batch IRs in the Fetcher (#682) * Add LRU Cache * Double gauge * Revert accidental build sbt change * Move GroupByRequestMeta to object * Add FetcherCache and tests * Update FetcherBase to use cache * Scala 2.13 support? * Add getServingInfo unit tests * Refactor getServingInfo tests * Fix stub * Fix Mockito "ambiguous reference to overloaded definition" error * Fix "Both batch and streaming data are null" check * Address PR review: add comments and use logger * Fewer comments for constructGroupByResponse * Use the FlagStore to determine if caching is enabled * Apply suggestions from code review Co-authored-by: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com> Signed-off-by: Caio Camatta (Stripe) <108533014+caiocamatta-stripe@users.noreply.github.com> * Address review, add comments, rename tests * Change test names * CamelCase FetcherBaseTest * Update online/src/main/scala/ai/chronon/online/FetcherBase.scala Co-authored-by: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com> Signed-off-by: Caio Camatta (Stripe) <108533014+caiocamatta-stripe@users.noreply.github.com> * fmt --------- Signed-off-by: Caio Camatta (Stripe) <108533014+caiocamatta-stripe@users.noreply.github.com> Co-authored-by: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com> --- build.sbt | 6 +- .../java/ai/chronon/online/FlagStore.java | 9 +- .../scala/ai/chronon/online/FetcherBase.scala | 288 ++++++++++----- .../ai/chronon/online/FetcherCache.scala | 216 +++++++++++ .../scala/ai/chronon/online/LRUCache.scala | 59 ++++ .../scala/ai/chronon/online/Metrics.scala | 2 + .../ai/chronon/online/FetcherBaseTest.scala | 91 ++++- .../ai/chronon/online/FetcherCacheTest.scala | 334 ++++++++++++++++++ .../ai/chronon/online/LRUCacheTest.scala | 37 ++ 9 files changed, 942 insertions(+), 100 deletions(-) create mode 100644 online/src/main/scala/ai/chronon/online/FetcherCache.scala create mode 100644 online/src/main/scala/ai/chronon/online/LRUCache.scala create mode 100644 online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala create mode 100644 online/src/test/scala/ai/chronon/online/LRUCacheTest.scala diff --git a/build.sbt b/build.sbt index 7a88c0e34..ca926e678 100644 --- a/build.sbt +++ b/build.sbt @@ -309,7 +309,8 @@ lazy val online = project // statsd 3.0 has local aggregation - TODO: upgrade "com.datadoghq" % "java-dogstatsd-client" % "2.7", "org.rogach" %% "scallop" % "4.0.1", - "net.jodah" % "typetools" % "0.4.1" + "net.jodah" % "typetools" % "0.4.1", + "com.github.ben-manes.caffeine" % "caffeine" % "2.9.3" ), libraryDependencies ++= fromMatrix(scalaVersion.value, "spark-all", "scala-parallel-collections", "netty-buffer"), version := git.versionProperty.value @@ -326,7 +327,8 @@ lazy val online_unshaded = (project in file("online")) // statsd 3.0 has local aggregation - TODO: upgrade "com.datadoghq" % "java-dogstatsd-client" % "2.7", "org.rogach" %% "scallop" % "4.0.1", - "net.jodah" % "typetools" % "0.4.1" + "net.jodah" % "typetools" % "0.4.1", + "com.github.ben-manes.caffeine" % "caffeine" % "2.9.3" ), libraryDependencies ++= fromMatrix(scalaVersion.value, "jackson", diff --git a/online/src/main/java/ai/chronon/online/FlagStore.java b/online/src/main/java/ai/chronon/online/FlagStore.java index cb39668b3..bdc0c1bd3 100644 --- a/online/src/main/java/ai/chronon/online/FlagStore.java 
+++ b/online/src/main/java/ai/chronon/online/FlagStore.java @@ -3,7 +3,14 @@ import java.io.Serializable; import java.util.Map; -// Interface to allow rolling out features/infrastructure changes in a safe & controlled manner +/** + * Interface to allow rolling out features/infrastructure changes in a safe & controlled manner. + * + * The "Flag"s in FlagStore refer to 'feature flags', a technique that allows enabling or disabling features at + * runtime. + * + * Chronon users can provide their own implementation in the Api. + */ public interface FlagStore extends Serializable { Boolean isSet(String flagName, Map attributes); } diff --git a/online/src/main/scala/ai/chronon/online/FetcherBase.scala b/online/src/main/scala/ai/chronon/online/FetcherBase.scala index 3f2051168..71e43f64f 100644 --- a/online/src/main/scala/ai/chronon/online/FetcherBase.scala +++ b/online/src/main/scala/ai/chronon/online/FetcherBase.scala @@ -22,6 +22,7 @@ import ai.chronon.aggregator.windowing.{FinalBatchIr, SawtoothOnlineAggregator, import ai.chronon.api.Constants.ChrononMetadataKey import ai.chronon.api._ import ai.chronon.online.Fetcher.{ColumnSpec, PrefixedRequest, Request, Response} +import ai.chronon.online.FetcherCache.{BatchResponses, CachedBatchResponse, KvStoreBatchResponse} import ai.chronon.online.KVStore.{GetRequest, GetResponse, TimedValue} import ai.chronon.online.Metrics.Name import ai.chronon.api.Extensions.{DerivationOps, GroupByOps, JoinOps, ThrowableOps} @@ -50,43 +51,46 @@ class FetcherBase(kvStore: KVStore, timeoutMillis: Long = 10000, debug: Boolean = false, flagStore: FlagStore = null) - extends MetadataStore(kvStore, metaDataSet, timeoutMillis) { + extends MetadataStore(kvStore, metaDataSet, timeoutMillis) + with FetcherCache { + import FetcherBase._ - private case class GroupByRequestMeta( - groupByServingInfoParsed: GroupByServingInfoParsed, - batchRequest: GetRequest, - streamingRequestOpt: Option[GetRequest], - endTs: Option[Long], - context: Metrics.Context - ) - - // a groupBy request is split into batchRequest and optionally a streamingRequest - // this method decodes bytes (of the appropriate avro schema) into chronon rows aggregates further if necessary - private def constructGroupByResponse(batchResponsesTry: Try[Seq[TimedValue]], + /** + * A groupBy request is split into batchRequest and optionally a streamingRequest. This method decodes bytes + * (of the appropriate avro schema) into chronon rows and aggregates further if necessary. + */ + private def constructGroupByResponse(batchResponses: BatchResponses, streamingResponsesOpt: Option[Seq[TimedValue]], oldServingInfo: GroupByServingInfoParsed, - queryTimeMs: Long, - startTimeMs: Long, - overallLatency: Long, + queryTimeMs: Long, // the timestamp of the Request being served. + startTimeMs: Long, // timestamp right before the KV store fetch.
+ overallLatency: Long, // the time it took to get the values from the KV store context: Metrics.Context, - totalResponseValueBytes: Int): Map[String, AnyRef] = { - val latestBatchValue = batchResponsesTry.map(_.maxBy(_.millis)) - val servingInfo = - latestBatchValue.map(timedVal => updateServingInfo(timedVal.millis, oldServingInfo)).getOrElse(oldServingInfo) - batchResponsesTry.map { - reportKvResponse(context.withSuffix("batch"), _, queryTimeMs, overallLatency, totalResponseValueBytes) + totalResponseValueBytes: Int, + keys: Map[String, Any] // The keys are used only for caching + ): Map[String, AnyRef] = { + val servingInfo = getServingInfo(oldServingInfo, batchResponses) + + // Batch metrics + batchResponses match { + case kvStoreResponse: KvStoreBatchResponse => + kvStoreResponse.response.map( + reportKvResponse(context.withSuffix("batch"), _, queryTimeMs, overallLatency, totalResponseValueBytes) + ) + case _: CachedBatchResponse => // no-op } - // bulk upload didn't remove an older batch value - so we manually discard - val batchBytes: Array[Byte] = batchResponsesTry - .map(_.maxBy(_.millis)) - .filter(_.millis >= servingInfo.batchEndTsMillis) - .map(_.bytes) - .getOrElse(null) + // The bulk upload may not have removed older batch values. We manually discard all but the latest one. + val batchBytes: Array[Byte] = batchResponses.getBatchBytes(servingInfo.batchEndTsMillis) + val responseMap: Map[String, AnyRef] = if (servingInfo.groupBy.aggregations == null) { // no-agg - servingInfo.selectedCodec.decodeMap(batchBytes) + getMapResponseFromBatchResponse(batchResponses, + batchBytes, + servingInfo.selectedCodec.decodeMap, + servingInfo, + keys) } else if (streamingResponsesOpt.isEmpty) { // snapshot accurate - servingInfo.outputCodec.decodeMap(batchBytes) + getMapResponseFromBatchResponse(batchResponses, batchBytes, servingInfo.outputCodec.decodeMap, servingInfo, keys) } else { // temporal accurate val streamingResponses = streamingResponsesOpt.get val mutations: Boolean = servingInfo.groupByOps.dataModel == DataModel.Entities @@ -97,62 +101,74 @@ class FetcherBase(kvStore: KVStore, s"Request time of $queryTimeMs is less than batch time ${aggregator.batchEndTs}" + s" for groupBy ${servingInfo.groupByOps.metaData.getName}")) null - } else if (batchBytes == null && (streamingResponses == null || streamingResponses.isEmpty)) { + } else if ( + // Check if there's no streaming data. + (streamingResponses == null || streamingResponses.isEmpty) && + // Check if there's no batch data. This is only possible if the batch response is from a KV Store request + // (KvStoreBatchResponse) that returned null bytes. It's not possible to have null batch data with cached batch + // responses as we only cache non-null data.
+ (batchResponses.isInstanceOf[KvStoreBatchResponse] && batchBytes == null) + ) { if (debug) logger.info("Both batch and streaming data are null") - null - } else { - reportKvResponse(context.withSuffix("streaming"), - streamingResponses, - queryTimeMs, - overallLatency, - totalResponseValueBytes) - - val batchIr = toBatchIr(batchBytes, servingInfo) - val output: Array[Any] = if (servingInfo.isTilingEnabled) { - val streamingIrs: Iterator[TiledIr] = streamingResponses.iterator - .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis) - .map { tVal => - val (tile, _) = servingInfo.tiledCodec.decodeTileIr(tVal.bytes) - TiledIr(tVal.millis, tile) - } + return null + } - if (debug) { - val gson = new Gson() - logger.info(s""" - |batch ir: ${gson.toJson(batchIr)} - |streamingIrs: ${gson.toJson(streamingIrs)} - |batchEnd in millis: ${servingInfo.batchEndTsMillis} - |queryTime in millis: $queryTimeMs - |""".stripMargin) - } + // Streaming metrics + reportKvResponse(context.withSuffix("streaming"), + streamingResponses, + queryTimeMs, + overallLatency, + totalResponseValueBytes) - aggregator.lambdaAggregateFinalizedTiled(batchIr, streamingIrs, queryTimeMs) - } else { - val selectedCodec = servingInfo.groupByOps.dataModel match { - case DataModel.Events => servingInfo.valueAvroCodec - case DataModel.Entities => servingInfo.mutationValueAvroCodec - } + // If caching is enabled, we try to fetch the batch IR from the cache so we avoid the work of decoding it. + val batchIr: FinalBatchIr = + getBatchIrFromBatchResponse(batchResponses, batchBytes, servingInfo, toBatchIr, keys) - val streamingRows: Array[Row] = streamingResponses.iterator - .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis) - .map(tVal => selectedCodec.decodeRow(tVal.bytes, tVal.millis, mutations)) - .toArray - - if (debug) { - val gson = new Gson() - logger.info(s""" - |batch ir: ${gson.toJson(batchIr)} - |streamingRows: ${gson.toJson(streamingRows)} - |batchEnd in millis: ${servingInfo.batchEndTsMillis} - |queryTime in millis: $queryTimeMs - |""".stripMargin) + val output: Array[Any] = if (servingInfo.isTilingEnabled) { + val streamingIrs: Iterator[TiledIr] = streamingResponses.iterator + .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis) + .map { tVal => + val (tile, _) = servingInfo.tiledCodec.decodeTileIr(tVal.bytes) + TiledIr(tVal.millis, tile) } - aggregator.lambdaAggregateFinalized(batchIr, streamingRows.iterator, queryTimeMs, mutations) + if (debug) { + val gson = new Gson() + logger.info(s""" + |batch ir: ${gson.toJson(batchIr)} + |streamingIrs: ${gson.toJson(streamingIrs)} + |batchEnd in millis: ${servingInfo.batchEndTsMillis} + |queryTime in millis: $queryTimeMs + |""".stripMargin) + } + + aggregator.lambdaAggregateFinalizedTiled(batchIr, streamingIrs, queryTimeMs) + } else { + val selectedCodec = servingInfo.groupByOps.dataModel match { + case DataModel.Events => servingInfo.valueAvroCodec + case DataModel.Entities => servingInfo.mutationValueAvroCodec + } + + val streamingRows: Array[Row] = streamingResponses.iterator + .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis) + .map(tVal => selectedCodec.decodeRow(tVal.bytes, tVal.millis, mutations)) + .toArray + + if (debug) { + val gson = new Gson() + logger.info(s""" + |batch ir: ${gson.toJson(batchIr)} + |streamingRows: ${gson.toJson(streamingRows)} + |batchEnd in millis: ${servingInfo.batchEndTsMillis} + |queryTime in millis: $queryTimeMs + |""".stripMargin) } - 
servingInfo.outputCodec.fieldNames.iterator.zip(output.iterator.map(_.asInstanceOf[AnyRef])).toMap + + aggregator.lambdaAggregateFinalized(batchIr, streamingRows.iterator, queryTimeMs, mutations) } + servingInfo.outputCodec.fieldNames.iterator.zip(output.iterator.map(_.asInstanceOf[AnyRef])).toMap } + context.distribution("group_by.latency.millis", System.currentTimeMillis() - startTimeMs) responseMap } @@ -175,8 +191,45 @@ class FetcherBase(kvStore: KVStore, ((responseBytes.toDouble / totalResponseBytes.toDouble) * latencyMillis).toLong) } - private def updateServingInfo(batchEndTs: Long, - groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = { + /** + * Get the latest serving information based on a batch response. + * + * The underlying metadata store used to store the latest GroupByServingInfoParsed will be updated if needed. + * + * @param oldServingInfo The previous serving information before fetching the latest KV store data. + * @param batchResponses the latest batch responses (either a fresh KV store response or a cached batch ir). + * @return the GroupByServingInfoParsed containing the latest serving information. + */ + private[online] def getServingInfo(oldServingInfo: GroupByServingInfoParsed, + batchResponses: BatchResponses): GroupByServingInfoParsed = { + batchResponses match { + case batchTimedValuesTry: KvStoreBatchResponse => { + val latestBatchValue: Try[TimedValue] = batchTimedValuesTry.response.map(_.maxBy(_.millis)) + latestBatchValue.map(timedVal => updateServingInfo(timedVal.millis, oldServingInfo)).getOrElse(oldServingInfo) + } + case _: CachedBatchResponse => { + // If there was cached batch data, there's no point in trying to update the serving info; it would be the same. + // However, there's one edge case to be handled. If all batch requests are cached and we never hit the kv store, + // we will never try to update the serving info. In that case, if new batch data were to land, we would never + // know about it. So, we force a refresh here to ensure that we are still periodically asynchronously hitting the + // KV store to update the serving info. (See CHIP-1) + getGroupByServingInfo.refresh(oldServingInfo.groupByOps.metaData.name) + + oldServingInfo + } + } + } + + /** + * If `batchEndTs` is ahead of `groupByServingInfo.batchEndTsMillis`, update the MetadataStore with the new + * timestamp. In practice, this means that new batch data has landed, so future kvstore requests should fetch + * streaming data after the new batchEndTsMillis.
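+   * For example (illustrative timestamps): if the current batchEndTsMillis corresponds to 2024-01-01 00:00 UTC and
+   * a new batch upload lands whose batchEndTs corresponds to 2024-01-02 00:00 UTC, the MetadataStore is updated and
+   * subsequent requests fetch streaming data from the new boundary onward.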
+ * + * @param batchEndTs the new batchEndTs from the latest batch data + * @param groupByServingInfo the current GroupByServingInfo + */ + private[online] def updateServingInfo(batchEndTs: Long, + groupByServingInfo: GroupByServingInfoParsed): GroupByServingInfoParsed = { val name = groupByServingInfo.groupBy.metaData.name if (batchEndTs > groupByServingInfo.batchEndTsMillis) { logger.info(s"""$name's value's batch timestamp of $batchEndTs is @@ -196,6 +249,22 @@ class FetcherBase(kvStore: KVStore, } } + override def isCachingEnabled(groupBy: GroupBy): Boolean = { + if (!isCacheSizeConfigured || groupBy.getMetaData == null || groupBy.getMetaData.getName == null) return false + + val isCachingFlagEnabled = + Option(flagStore) + .exists( + _.isSet("enable_fetcher_batch_ir_cache", + Map("groupby_streaming_dataset" -> groupBy.getMetaData.getName).asJava)) + + if (debug) + logger.info( + s"Online IR caching is ${if (isCachingFlagEnabled) "enabled" else "disabled"} for ${groupBy.getMetaData.getName}") + + isCachingFlagEnabled + } + // 1. fetches GroupByServingInfo // 2. encodes keys as keyAvroSchema // 3. Based on accuracy, fetches streaming + batch data and aggregates further. @@ -253,15 +322,25 @@ class FetcherBase(kvStore: KVStore, } request -> groupByRequestMetaTry }.toSeq - val allRequests: Seq[GetRequest] = groupByRequestToKvRequest.flatMap { - case (_, Success(GroupByRequestMeta(_, batchRequest, streamingRequestOpt, _, _))) => - Some(batchRequest) ++ streamingRequestOpt + + // If caching is enabled, we check if any of the GetRequests are already cached. If so, we store them in a Map + // and avoid the work of re-fetching them. This mainly applies to batch data requests. + val cachedRequests: Map[GetRequest, CachedBatchResponse] = getCachedRequests(groupByRequestToKvRequest) + // Collect cache metrics once per fetchGroupBys call; Caffeine metrics aren't tagged by groupBy + maybeBatchIrCache.foreach(cache => + LRUCache.collectCaffeineCacheMetrics(caffeineMetricsContext, cache.cache, cache.cacheName)) + + val allRequestsToFetch: Seq[GetRequest] = groupByRequestToKvRequest.flatMap { + case (_, Success(GroupByRequestMeta(_, batchRequest, streamingRequestOpt, _, _))) => { + // If a batch request is cached, don't include it in the list of requests to fetch because its batch IR is already cached + if (cachedRequests.contains(batchRequest)) streamingRequestOpt else Some(batchRequest) ++ streamingRequestOpt + } case _ => Seq.empty } val startTimeMs = System.currentTimeMillis() - val kvResponseFuture: Future[Seq[GetResponse]] = if (allRequests.nonEmpty) { - kvStore.multiGet(allRequests) + val kvResponseFuture: Future[Seq[GetResponse]] = if (allRequestsToFetch.nonEmpty) { + kvStore.multiGet(allRequestsToFetch) } else { Future(Seq.empty[GetResponse]) } @@ -278,20 +357,33 @@ class FetcherBase(kvStore: KVStore, .filter(_.isSuccess) .flatMap(_.get.map(v => Option(v.bytes).map(_.length).getOrElse(0))) .sum + val responses: Seq[Response] = groupByRequestToKvRequest.iterator.map { case (request, requestMetaTry) => val responseMapTry: Try[Map[String, AnyRef]] = requestMetaTry.map { requestMeta => val GroupByRequestMeta(groupByServingInfo, batchRequest, streamingRequestOpt, _, context) = requestMeta - context.count("multi_get.batch.size", allRequests.length) + + context.count("multi_get.batch.size", allRequestsToFetch.length) context.distribution("multi_get.bytes", totalResponseValueBytes) context.distribution("multi_get.response.length", kvResponses.length) context.distribution("multi_get.latency.millis",
multiGetMillis) + // pick the batch version with highest timestamp - val batchResponseTryAll = responsesMap - .getOrElse(batchRequest, - Failure( - new IllegalStateException( - s"Couldn't find corresponding response for $batchRequest in responseMap"))) + val batchResponses: BatchResponses = + // Check if the get request was cached. If so, use the cache. Otherwise, try to get it from the response. + cachedRequests.get(batchRequest) match { + case None => + BatchResponses( + responsesMap + .getOrElse( + batchRequest, + // Fail if response is neither in responsesMap nor in cache + Failure(new IllegalStateException( + s"Couldn't find corresponding response for $batchRequest in responseMap or cache")) + )) + case Some(cachedResponse: CachedBatchResponse) => cachedResponse + } + val streamingResponsesOpt = streamingRequestOpt.map(responsesMap.getOrElse(_, Success(Seq.empty)).getOrElse(Seq.empty)) val queryTs = request.atMillis.getOrElse(System.currentTimeMillis()) @@ -301,14 +393,15 @@ class FetcherBase(kvStore: KVStore, logger.info( s"Constructing response for groupBy: ${groupByServingInfo.groupByOps.metaData.getName} " + s"for keys: ${request.keys}") - constructGroupByResponse(batchResponseTryAll, + constructGroupByResponse(batchResponses, streamingResponsesOpt, groupByServingInfo, queryTs, startTimeMs, multiGetMillis, context, - totalResponseValueBytes) + totalResponseValueBytes, + request.keys) } catch { case ex: Exception => // not all exceptions are due to stale schema, so we want to control how often we hit kv store @@ -358,6 +451,9 @@ class FetcherBase(kvStore: KVStore, } } + /** + * Convert an array of bytes to a FinalBatchIr. + */ def toBatchIr(bytes: Array[Byte], gbInfo: GroupByServingInfoParsed): FinalBatchIr = { if (bytes == null) return null val batchRecord = @@ -542,3 +638,13 @@ class FetcherBase(kvStore: KVStore, } } } + +object FetcherBase { + private[online] case class GroupByRequestMeta( + groupByServingInfoParsed: GroupByServingInfoParsed, + batchRequest: GetRequest, + streamingRequestOpt: Option[GetRequest], + endTs: Option[Long], + context: Metrics.Context + ) +} diff --git a/online/src/main/scala/ai/chronon/online/FetcherCache.scala b/online/src/main/scala/ai/chronon/online/FetcherCache.scala new file mode 100644 index 000000000..41739ee4e --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/FetcherCache.scala @@ -0,0 +1,216 @@ +package ai.chronon.online + +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.api.GroupBy +import ai.chronon.online.FetcherBase.GroupByRequestMeta +import ai.chronon.online.Fetcher.Request +import ai.chronon.online.FetcherCache.{ + BatchIrCache, + BatchResponses, + CachedBatchResponse, + CachedFinalIrBatchResponse, + CachedMapBatchResponse, + KvStoreBatchResponse +} +import ai.chronon.online.KVStore.{GetRequest, TimedValue} +import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache} + +import scala.util.{Success, Try} +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters.mapAsScalaConcurrentMapConverter +import scala.collection.Seq +import org.slf4j.{Logger, LoggerFactory} + +/* + * FetcherCache is an extension to FetcherBase that provides caching functionality. It caches KV store + * requests to decrease feature serving latency. + * + * To use it, + * 1. Set the system property `ai.chronon.fetcher.batch_ir_cache_size_elements` to the desired cache size + * in terms of elements. This will create a cache shared across all GroupBys. 
To determine a size, start with a + * small number (e.g. 1,000) and measure how much memory it uses, then adjust accordingly. + * 2. Enable caching for a specific GroupBy by overriding `isCachingEnabled` and returning `true` for that GroupBy. + * FetcherBase already provides an implementation of `isCachingEnabled` that uses the FlagStore. + * */ +trait FetcherCache { + @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass) + + val batchIrCacheName = "batch_cache" + val maybeBatchIrCache: Option[BatchIrCache] = + Option(System.getProperty("ai.chronon.fetcher.batch_ir_cache_size_elements")) + .map(size => new BatchIrCache(batchIrCacheName, size.toInt)) + + // Caching needs to be configured globally + def isCacheSizeConfigured: Boolean = maybeBatchIrCache.isDefined + // Caching needs to be enabled for the specific groupBy + def isCachingEnabled(groupBy: GroupBy): Boolean = false + + protected val caffeineMetricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.JoinFetching) + + /** + * Obtain the Map[String, AnyRef] response from a batch response. + * + * If batch IR caching is enabled, this method will try to fetch the decoded response from the cache. If it's not + * in the cache, it will decode it from the batch bytes and store it. + * + * @param batchResponses the batch responses + * @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`. + * @param decodingFunction the function to decode bytes into Map[String, AnyRef] + * @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes + * @param keys the keys used to fetch this particular batch response, for caching purposes + */ + private[online] def getMapResponseFromBatchResponse(batchResponses: BatchResponses, + batchBytes: Array[Byte], + decodingFunction: Array[Byte] => Map[String, AnyRef], + servingInfo: GroupByServingInfoParsed, + keys: Map[String, Any]): Map[String, AnyRef] = { + if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes) + + batchResponses match { + case _: KvStoreBatchResponse => + val batchRequestCacheKey = + BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + val decodedMap = decodingFunction(batchBytes) + if (decodedMap != null) + maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedMapBatchResponse(decodedMap)) + decodedMap + case cachedResponse: CachedBatchResponse => + cachedResponse match { + case CachedFinalIrBatchResponse(_: FinalBatchIr) => decodingFunction(batchBytes) + case CachedMapBatchResponse(mapResponse: Map[String, AnyRef]) => mapResponse + } + } + } + + /** + * Obtain the FinalBatchIr from a batch response. + * + * If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache, + * it will decode it from the batch bytes and store it. + * + * @param batchResponses the batch responses + * @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`. 
+ * @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes + * @param decodingFunction the function to decode bytes into FinalBatchIr + * @param keys the keys used to fetch this particular batch response, for caching purposes + */ + private[online] def getBatchIrFromBatchResponse( + batchResponses: BatchResponses, + batchBytes: Array[Byte], + servingInfo: GroupByServingInfoParsed, + decodingFunction: (Array[Byte], GroupByServingInfoParsed) => FinalBatchIr, + keys: Map[String, Any]): FinalBatchIr = { + if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes, servingInfo) + + batchResponses match { + case _: KvStoreBatchResponse => + val batchRequestCacheKey = + BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + val decodedIr = decodingFunction(batchBytes, servingInfo) + if (decodedIr != null) + maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedFinalIrBatchResponse(decodedIr)) + decodedIr + case cachedResponse: CachedBatchResponse => + cachedResponse match { + case CachedFinalIrBatchResponse(finalBatchIr: FinalBatchIr) => finalBatchIr + case CachedMapBatchResponse(_: Map[String, AnyRef]) => decodingFunction(batchBytes, servingInfo) + } + } + } + + /** + * Given a list of GetRequests, return a map of GetRequests to cached FinalBatchIrs. + */ + def getCachedRequests( + groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])]): Map[GetRequest, CachedBatchResponse] = { + if (!isCacheSizeConfigured) return Map.empty + + groupByRequestToKvRequest + .map { + case (request, Success(GroupByRequestMeta(servingInfo, batchRequest, _, _, _))) => + if (!isCachingEnabled(servingInfo.groupBy)) { Map.empty } + else { + val batchRequestCacheKey = + BatchIrCache.Key(batchRequest.dataset, request.keys, servingInfo.batchEndTsMillis) + + // Metrics so we can get per-groupby cache metrics + val metricsContext = + request.context.getOrElse(Metrics.Context(Metrics.Environment.JoinFetching, servingInfo.groupBy)) + + maybeBatchIrCache.get.cache.getIfPresent(batchRequestCacheKey) match { + case null => + metricsContext.increment(s"${batchIrCacheName}_gb_misses") + val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty + emptyMap + case cachedIr: CachedBatchResponse => + metricsContext.increment(s"${batchIrCacheName}_gb_hits") + Map(batchRequest -> cachedIr) + } + } + case _ => + val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty + emptyMap + } + .foldLeft(Map.empty[GetRequest, CachedBatchResponse])(_ ++ _) + } +} + +object FetcherCache { + private[online] class BatchIrCache(val cacheName: String, val maximumSize: Int = 10000) { + import BatchIrCache._ + + val cache: CaffeineCache[Key, Value] = + LRUCache[Key, Value](cacheName = cacheName, maximumSize = maximumSize) + } + + private[online] object BatchIrCache { + // We use the dataset, keys, and batchEndTsMillis to identify a batch request. + // There's one edge case to be aware of: if a batch job is re-run on the same day, the batchEndTsMillis will + // be the same but the underlying data may have changed. If that new batch data is needed immediately, the + // Fetcher service should be restarted. + case class Key(dataset: String, keys: Map[String, Any], batchEndTsMillis: Long) + + // FinalBatchIr is for GroupBys using temporally accurate aggregation. + // Map[String, Any] is for GroupBys using snapshot accurate aggregation or no aggregation. 
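+    // For example (illustrative; `ir` and `ts` stand in for a FinalBatchIr and a timestamp), both cached
+    // shapes can coexist in one cache:
+    //   cache.put(Key("MY_DATASET", Map("user" -> 1), ts), CachedFinalIrBatchResponse(ir))
+    //   cache.put(Key("MY_DATASET", Map("user" -> 2), ts), CachedMapBatchResponse(Map("count" -> Integer.valueOf(3))))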
+ type Value = BatchResponses + } + + /** + * Encapsulates the response for a GetRequest for batch data. This response could be the values received from + * a KV Store request, or cached values. + * + * (The fetcher uses these batch values to construct the response for a request for feature values.) + * */ + sealed abstract class BatchResponses { + def getBatchBytes(batchEndTsMillis: Long): Array[Byte] + } + object BatchResponses { + def apply(kvStoreResponse: Try[Seq[TimedValue]]): KvStoreBatchResponse = KvStoreBatchResponse(kvStoreResponse) + def apply(cachedResponse: FinalBatchIr): CachedFinalIrBatchResponse = CachedFinalIrBatchResponse(cachedResponse) + def apply(cachedResponse: Map[String, AnyRef]): CachedMapBatchResponse = CachedMapBatchResponse(cachedResponse) + } + + /** Encapsulates batch response values received from a KV Store request. */ + case class KvStoreBatchResponse(response: Try[Seq[TimedValue]]) extends BatchResponses { + def getBatchBytes(batchEndTsMillis: Long): Array[Byte] = + response + .map(_.maxBy(_.millis)) + .filter(_.millis >= batchEndTsMillis) + .map(_.bytes) + .getOrElse(null) + } + + /** Encapsulates a batch response that was found in the Fetcher's internal IR cache. */ + sealed abstract class CachedBatchResponse extends BatchResponses { + // This is the case where we don't have bytes because the decoded IR was cached so we didn't hit the KV store again. + def getBatchBytes(batchEndTsMillis: Long): Null = null + } + + /** Encapsulates a decoded batch response that was found in the Fetcher's internal IR cache. */ + case class CachedFinalIrBatchResponse(response: FinalBatchIr) extends CachedBatchResponse + + /** Encapsulates a decoded batch response that was found in the Fetcher's internal IR cache. */ + case class CachedMapBatchResponse(response: Map[String, AnyRef]) extends CachedBatchResponse +} diff --git a/online/src/main/scala/ai/chronon/online/LRUCache.scala b/online/src/main/scala/ai/chronon/online/LRUCache.scala new file mode 100644 index 000000000..5bdc643f8 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/LRUCache.scala @@ -0,0 +1,59 @@ +package ai.chronon.online + +import com.github.benmanes.caffeine.cache.{Caffeine, Cache => CaffeineCache} +import org.slf4j.{Logger, LoggerFactory} + +/** + * Utility to create a cache with LRU semantics. + * + * The original purpose of having an LRU cache in Chronon is to cache KVStore calls and decoded IRs + * in the Fetcher. This helps decrease feature serving latency. + */ +object LRUCache { + @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass) + + /** + * Build a bounded, thread-safe Caffeine cache that stores KEY-VALUE pairs. + * + * @param cacheName Name of the cache + * @param maximumSize Maximum number of entries in the cache + * @tparam KEY The type of the key used to access the cache + * @tparam VALUE The type of the value stored in the cache + * @return Caffeine cache + */ + def apply[KEY <: Object, VALUE <: Object](cacheName: String, maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = { + buildCaffeineCache[KEY, VALUE](cacheName, maximumSize) + } + + private def buildCaffeineCache[KEY <: Object, VALUE <: Object]( + cacheName: String, + maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = { + logger.info(s"Chronon Cache build started. cacheName=$cacheName") + val cache: CaffeineCache[KEY, VALUE] = Caffeine + .newBuilder() + .maximumSize(maximumSize) + .recordStats() + .build[KEY, VALUE]() + logger.info(s"Chronon Cache build finished. 
cacheName=$cacheName") + cache + } + + /** + * Report metrics for a Caffeine cache. The "cache" tag is added to all metrics. + * + * @param metricsContext Metrics.Context for recording metrics + * @param cache Caffeine cache to get metrics from + * @param cacheName Cache name for tagging + */ + def collectCaffeineCacheMetrics(metricsContext: Metrics.Context, + cache: CaffeineCache[_, _], + cacheName: String): Unit = { + val stats = cache.stats() + metricsContext.gauge(s"$cacheName.hits", stats.hitCount()) + metricsContext.gauge(s"$cacheName.misses", stats.missCount()) + metricsContext.gauge(s"$cacheName.evictions", stats.evictionCount()) + metricsContext.gauge(s"$cacheName.loads", stats.loadCount()) + metricsContext.gauge(s"$cacheName.hit_rate", stats.hitRate()) + metricsContext.gauge(s"$cacheName.average_load_penalty", stats.averageLoadPenalty()) + } +} diff --git a/online/src/main/scala/ai/chronon/online/Metrics.scala b/online/src/main/scala/ai/chronon/online/Metrics.scala index 452a6caa6..b2f079c55 100644 --- a/online/src/main/scala/ai/chronon/online/Metrics.scala +++ b/online/src/main/scala/ai/chronon/online/Metrics.scala @@ -196,6 +196,8 @@ object Metrics { def gauge(metric: String, value: Long): Unit = stats.gauge(prefix(metric), value, tags) + def gauge(metric: String, value: Double): Unit = stats.gauge(prefix(metric), value, tags) + def recordEvent(metric: String, event: Event): Unit = stats.recordEvent(event, prefix(metric), tags) def toTags: Array[String] = { diff --git a/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala b/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala index 51b599ce7..d2960d292 100644 --- a/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala +++ b/online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala @@ -16,7 +16,13 @@ package ai.chronon.online +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.GroupByOps +import ai.chronon.api.{Builders, GroupBy, MetaData} import ai.chronon.online.Fetcher.{ColumnSpec, Request, Response} +import ai.chronon.online.FetcherCache.BatchResponses +import ai.chronon.online.KVStore.TimedValue +import org.junit.Assert.{assertFalse, assertTrue, fail} import org.junit.{Before, Test} import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ @@ -29,8 +35,9 @@ import org.scalatestplus.mockito.MockitoSugar import scala.concurrent.duration.DurationInt import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.{Failure, Success} +import scala.util.Try -class FetcherBaseTest extends MockitoSugar with Matchers { +class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper { val GroupBy = "relevance.short_term_user_features" val Column = "pdp_view_count_14d" val GuestKey = "guest" @@ -51,7 +58,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers { } @Test - def testFetchColumns_SingleQuery(): Unit = { + def testFetchColumnsSingleQuery(): Unit = { // Fetch a single query val keyMap = Map(GuestKey -> GuestId) val query = ColumnSpec(GroupBy, Column, None, Some(keyMap)) @@ -80,7 +87,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers { } @Test - def testFetchColumns_Batch(): Unit = { + def testFetchColumnsBatch(): Unit = { // Fetch a batch of queries val guestKeyMap = Map(GuestKey -> GuestId) val guestQuery = ColumnSpec(GroupBy, Column, Some(GuestKey), Some(guestKeyMap)) @@ -114,11 +121,11 @@ class FetcherBaseTest extends MockitoSugar with Matchers { } @Test - def testFetchColumns_MissingResponse(): Unit = { + 
def testFetchColumnsMissingResponse(): Unit = { // Fetch a single query val keyMap = Map(GuestKey -> GuestId) val query = ColumnSpec(GroupBy, Column, None, Some(keyMap)) - + doAnswer(new Answer[Future[Seq[Fetcher.Response]]] { def answer(invocation: InvocationOnMock): Future[Seq[Response]] = { Future.successful(Seq()) @@ -130,7 +137,7 @@ class FetcherBaseTest extends MockitoSugar with Matchers { queryResults.contains(query) shouldBe true queryResults.get(query).map(_.values) match { case Some(Failure(ex: IllegalStateException)) => succeed - case _ => fail() + case _ => fail() } // GroupBy request sent to KV store for the query @@ -141,4 +148,76 @@ class FetcherBaseTest extends MockitoSugar with Matchers { actualRequest.get.name shouldBe query.groupByName + "." + query.columnName actualRequest.get.keys shouldBe query.keyMapping.get } + + // updateServingInfo() is called when the batch response is from the KV store. + @Test + def testGetServingInfoShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = { + val oldServingInfo = mock[GroupByServingInfoParsed] + val updatedServingInfo = mock[GroupByServingInfoParsed] + doReturn(updatedServingInfo).when(fetcherBase).updateServingInfo(any(), any()) + + val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L))) + val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess) + + val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses) + + // updateServingInfo is called + result shouldEqual updatedServingInfo + verify(fetcherBase).updateServingInfo(any(), any()) + } + + // If a batch response is cached, the serving info should be refreshed. This is needed to prevent + // the serving info from becoming stale if all the requests are cached. + @Test + def testGetServingInfoShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = { + val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]] + doReturn(ttlCache).when(fetcherBase).getGroupByServingInfo + + val oldServingInfo = mock[GroupByServingInfoParsed] + doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String]) + + val metaDataMock = mock[MetaData] + val groupByOpsMock = mock[GroupByOps] + metaDataMock.name = "test" + groupByOpsMock.metaData = metaDataMock + doReturn(groupByOpsMock).when(oldServingInfo).groupByOps + + val cachedBatchResponses = BatchResponses(mock[FinalBatchIr]) + val result = fetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses) + + // FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is. 
+ result shouldEqual oldServingInfo + verify(ttlCache).refresh(any()) + verify(fetcherBase, never()).updateServingInfo(any(), any()) + } + + @Test + def testIsCachingEnabledCorrectlyDeterminesIfCacheIsEnabled(): Unit = { + val flagStore: FlagStore = (flagName: String, attributes: java.util.Map[String, String]) => { + flagName match { + case "enable_fetcher_batch_ir_cache" => + attributes.get("groupby_streaming_dataset") match { + case "test_groupby_2" => false + case "test_groupby_3" => true + case other => + fail(s"Unexpected groupby_streaming_dataset: $other") + false + } + case _ => false + } + } + + kvStore = mock[KVStore](Answers.RETURNS_DEEP_STUBS) + when(kvStore.executionContext).thenReturn(ExecutionContext.global) + val fetcherBaseWithFlagStore = spy(new FetcherBase(kvStore, flagStore = flagStore)) + when(fetcherBaseWithFlagStore.isCacheSizeConfigured).thenReturn(true) + + def buildGroupByWithCustomJson(name: String): GroupBy = Builders.GroupBy(metaData = Builders.MetaData(name = name)) + + // no name set + assertFalse(fetcherBaseWithFlagStore.isCachingEnabled(Builders.GroupBy())) + + assertFalse(fetcherBaseWithFlagStore.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_2"))) + assertTrue(fetcherBaseWithFlagStore.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_3"))) + } } diff --git a/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala b/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala new file mode 100644 index 000000000..6719fa3e3 --- /dev/null +++ b/online/src/test/scala/ai/chronon/online/FetcherCacheTest.scala @@ -0,0 +1,334 @@ +package ai.chronon.online + +import ai.chronon.aggregator.windowing.FinalBatchIr +import ai.chronon.api.Extensions.GroupByOps +import ai.chronon.api.GroupBy +import ai.chronon.online.FetcherBase._ +import ai.chronon.online.Fetcher.Request +import ai.chronon.online.FetcherCache.{BatchIrCache, BatchResponses, CachedMapBatchResponse} +import ai.chronon.online.KVStore.TimedValue +import ai.chronon.online.Metrics.Context +import org.junit.Assert.{assertArrayEquals, assertEquals, assertNull, fail} +import org.junit.Test +import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito +import org.mockito.stubbing.Stubber +import org.scalatestplus.mockito.MockitoSugar + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + +trait MockitoHelper extends MockitoSugar { + // We override doReturn to fix a known Java/Scala interoperability issue. Without this fix, we see "doReturn: + // ambiguous reference to overloaded definition" errors. An alternative would be to use the 'mockito-scala' library. 
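+  // For illustration, the ambiguity comes from Mockito's two Java overloads, roughly:
+  //   public static Stubber doReturn(Object toBeReturned)
+  //   public static Stubber doReturn(Object toBeReturned, Object... toBeReturnedNext)
+  // Scala cannot pick between them for a single argument, so this helper pins the varargs overload by passing an
+  // explicit empty vararg list (`Nil: _*`).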
+ def doReturn(toBeReturned: Any): Stubber = { + Mockito.doReturn(toBeReturned, Nil: _*) + } +} + +class FetcherCacheTest extends MockitoHelper { + class TestableFetcherCache(cache: Option[BatchIrCache]) extends FetcherCache { + override val maybeBatchIrCache: Option[BatchIrCache] = cache + } + val batchIrCacheMaximumSize = 50 + + @Test + def testBatchIrCacheCorrectlyCachesBatchIrs(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createBatchir(i: Int) = + BatchResponses(FinalBatchIr(collapsed = Array(i), tailHops = Array(Array(Array(i)), Array(Array(i))))) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + // Create a bunch of test batchIrs and store them in cache + val batchIrs: Map[BatchIrCache.Key, BatchIrCache.Value] = + (0 until batchIrCacheMaximumSize).map(i => createCacheKey(i) -> createBatchir(i)).toMap + batchIrCache.cache.putAll(batchIrs.asJava) + + // Check that the cache contains all the batchIrs we created + batchIrs.foreach(entry => { + val cachedBatchIr = batchIrCache.cache.getIfPresent(entry._1) + assertEquals(cachedBatchIr, entry._2) + }) + } + + @Test + def testBatchIrCacheCorrectlyCachesMapResponse(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createMapResponse(i: Int) = + BatchResponses(Map("group_by_key" -> i.asInstanceOf[AnyRef])) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + // Create a bunch of test mapResponses and store them in cache + val mapResponses: Map[BatchIrCache.Key, BatchIrCache.Value] = + (0 until batchIrCacheMaximumSize).map(i => createCacheKey(i) -> createMapResponse(i)).toMap + batchIrCache.cache.putAll(mapResponses.asJava) + + // Check that the cache contains all the mapResponses we created + mapResponses.foreach(entry => { + val cachedBatchIr = batchIrCache.cache.getIfPresent(entry._1) + assertEquals(cachedBatchIr, entry._2) + }) + } + + // Test that the cache keys are compared by equality, not by reference. In practice, this means that if two keys + // have the same (dataset, keys, batchEndTsMillis), they will only be stored once in the cache. 
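+  // E.g., BatchIrCache.Key("d", Map("k" -> 1), 1000L) == BatchIrCache.Key("d", Map("k" -> 1), 1000L) is true
+  // because Key is a case class with structural equals/hashCode, which Caffeine uses for lookups.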
+ @Test + def testBatchIrCacheKeysAreComparedByEquality(): Unit = { + val cacheName = "test" + val batchIrCache = new BatchIrCache(cacheName, batchIrCacheMaximumSize) + + val dataset = "TEST_GROUPBY_BATCH" + val batchEndTsMillis = 1000L + + def createCacheValue(i: Int) = + BatchResponses(FinalBatchIr(collapsed = Array(i), tailHops = Array(Array(Array(i)), Array(Array(i))))) + def createCacheKey(i: Int) = BatchIrCache.Key(dataset, Map("key" -> i), batchEndTsMillis) + + assert(batchIrCache.cache.estimatedSize() == 0) + batchIrCache.cache.put(createCacheKey(1), createCacheValue(1)) + assert(batchIrCache.cache.estimatedSize() == 1) + // Create a second key object with the same values as the first key, make sure it's not stored separately + batchIrCache.cache.put(createCacheKey(1), createCacheValue(1)) + assert(batchIrCache.cache.estimatedSize() == 1) + } + + @Test + def testGetCachedRequestsReturnsCorrectCachedDataWhenCacheIsEnabled(): Unit = { + val cacheName = "test" + val testCache = Some(new BatchIrCache(cacheName, batchIrCacheMaximumSize)) + val fetcherCache = new TestableFetcherCache(testCache) { + override def isCachingEnabled(groupBy: GroupBy) = true + } + + // Prepare groupByRequestToKvRequest + val batchEndTsMillis = 0L + val keys = Map("key" -> "value") + val eventTs = 1000L + val dataset = "TEST_GROUPBY_BATCH" + val mockGroupByServingInfoParsed = mock[GroupByServingInfoParsed] + val mockContext = mock[Metrics.Context] + val request = Request("req_name", keys, Some(eventTs), Some(mock[Context])) + val getRequest = KVStore.GetRequest("key".getBytes, dataset, Some(eventTs)) + val requestMeta = + GroupByRequestMeta(mockGroupByServingInfoParsed, getRequest, Some(getRequest), Some(eventTs), mockContext) + val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = Seq((request, Success(requestMeta))) + + // getCachedRequests should return an empty list when the cache is empty + val cachedRequestBeforePopulating = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequestBeforePopulating.isEmpty) + + // Add a GetRequest and a FinalBatchIr + val key = BatchIrCache.Key(getRequest.dataset, keys, batchEndTsMillis) + val finalBatchIr = BatchResponses(FinalBatchIr(Array(1), Array(Array(Array(1)), Array(Array(1))))) + testCache.get.cache.put(key, finalBatchIr) + + // getCachedRequests should return the GetRequest and FinalBatchIr we cached + val cachedRequestsAfterAddingItem = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequestsAfterAddingItem.head._1 == getRequest) + assert(cachedRequestsAfterAddingItem.head._2 == finalBatchIr) + } + + @Test + def testGetCachedRequestsDoesNotCacheWhenCacheIsDisabledForGroupBy(): Unit = { + val testCache = new BatchIrCache("test", batchIrCacheMaximumSize) + val spiedTestCache = spy(testCache) + val fetcherCache = new TestableFetcherCache(Some(testCache)) { + // Cache is enabled globally, but disabled for a specific groupBy + override def isCachingEnabled(groupBy: GroupBy) = false + } + + // Prepare groupByRequestToKvRequest + val keys = Map("key" -> "value") + val eventTs = 1000L + val dataset = "TEST_GROUPBY_BATCH" + val mockGroupByServingInfoParsed = mock[GroupByServingInfoParsed] + val mockContext = mock[Metrics.Context] + val request = Request("req_name", keys, Some(eventTs)) + val getRequest = KVStore.GetRequest("key".getBytes, dataset, Some(eventTs)) + val requestMeta = + GroupByRequestMeta(mockGroupByServingInfoParsed, getRequest, Some(getRequest), Some(eventTs), mockContext) + val 
groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = Seq((request, Success(requestMeta))) + + val cachedRequests = fetcherCache.getCachedRequests(groupByRequestToKvRequest) + assert(cachedRequests.isEmpty) + // Cache was never called + verify(spiedTestCache, never()).cache + } + + @Test + def testGetBatchBytesReturnsLatestTimedValueBytesIfGreaterThanBatchEnd(): Unit = { + val kvStoreResponse = Success( + Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 2000L)) + ) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(1500L) + assertArrayEquals(Array(2.toByte), batchBytes) + } + + @Test + def testGetBatchBytesReturnsNullIfLatestTimedValueTimestampIsLessThanBatchEnd(): Unit = { + val kvStoreResponse = Success( + Seq(TimedValue(Array(1.toByte), 1000L), TimedValue(Array(2.toByte), 1500L)) + ) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(2000L) + assertNull(batchBytes) + } + + @Test + def testGetBatchBytesReturnsNullWhenCachedBatchResponse(): Unit = { + val finalBatchIr = mock[FinalBatchIr] + val batchResponses = BatchResponses(finalBatchIr) + val batchBytes = batchResponses.getBatchBytes(1000L) + assertNull(batchBytes) + } + + @Test + def testGetBatchBytesReturnsNullWhenKvStoreBatchResponseFails(): Unit = { + val kvStoreResponse = Failure(new RuntimeException("KV Store error")) + val batchResponses = BatchResponses(kvStoreResponse) + val batchBytes = batchResponses.getBatchBytes(1000L) + assertNull(batchBytes) + } + + @Test + def testGetBatchIrFromBatchResponseReturnsCorrectIRsWithCacheEnabled(): Unit = { + // Use a real cache + val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) + + // Create all necessary mocks + val servingInfo = mock[GroupByServingInfoParsed] + val groupByOps = mock[GroupByOps] + val toBatchIr = mock[(Array[Byte], GroupByServingInfoParsed) => FinalBatchIr] + when(servingInfo.groupByOps).thenReturn(groupByOps) + when(groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.batchEndTsMillis).thenReturn(1000L) + + // Dummy data + val batchBytes = Array[Byte](1, 1) + val keys = Map("key" -> "value") + val cacheKey = BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + + val fetcherCache = new TestableFetcherCache(Some(batchIrCache)) + val spiedFetcherCache = Mockito.spy(fetcherCache) + doReturn(true).when(spiedFetcherCache).isCachingEnabled(any()) + + // 1. Cached BatchResponse returns the same IRs passed in + val finalBatchIr1 = mock[FinalBatchIr] + val cachedBatchResponse = BatchResponses(finalBatchIr1) + val cachedIr = + spiedFetcherCache.getBatchIrFromBatchResponse(cachedBatchResponse, batchBytes, servingInfo, toBatchIr, keys) + assertEquals(finalBatchIr1, cachedIr) + verify(toBatchIr, never())(any(classOf[Array[Byte]]), any()) // no decoding needed + + // 2. 
Un-cached BatchResponse has IRs added to cache + val finalBatchIr2 = mock[FinalBatchIr] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + when(toBatchIr(any(), any())).thenReturn(finalBatchIr2) + val uncachedIr = + spiedFetcherCache.getBatchIrFromBatchResponse(kvStoreBatchResponses, batchBytes, servingInfo, toBatchIr, keys) + assertEquals(finalBatchIr2, uncachedIr) + assertEquals(batchIrCache.cache.getIfPresent(cacheKey), BatchResponses(finalBatchIr2)) // key was added + verify(toBatchIr, times(1))(any(), any()) // decoding did happen + } + + @Test + def testGetBatchIrFromBatchResponseDecodesBatchBytesIfCacheDisabled(): Unit = { + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val finalBatchIr = mock[FinalBatchIr] + val toBatchIr = mock[(Array[Byte], GroupByServingInfoParsed) => FinalBatchIr] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(None)) + when(toBatchIr(any(), any())).thenReturn(finalBatchIr) + + // When getBatchIrFromBatchResponse is called, it decodes the bytes and doesn't hit the cache + val ir = + spiedFetcherCache.getBatchIrFromBatchResponse(kvStoreBatchResponses, batchBytes, servingInfo, toBatchIr, keys) + verify(toBatchIr, times(1))(batchBytes, servingInfo) // decoding did happen + assertEquals(finalBatchIr, ir) + } + + @Test + def testGetBatchIrFromBatchResponseReturnsCorrectMapResponseWithCacheEnabled(): Unit = { + // Use a real cache + val batchIrCache = new BatchIrCache("test_cache", batchIrCacheMaximumSize) + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val groupByOps = mock[GroupByOps] + val outputCodec = mock[AvroCodec] + when(servingInfo.groupByOps).thenReturn(groupByOps) + when(groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.groupByOps.batchDataset).thenReturn("test_dataset") + when(servingInfo.batchEndTsMillis).thenReturn(1000L) + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val cacheKey = BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(Some(batchIrCache))) + doReturn(true).when(spiedFetcherCache).isCachingEnabled(any()) + + // 1. Cached BatchResponse returns the same Map responses passed in + val mapResponse1 = mock[Map[String, AnyRef]] + val cachedBatchResponse = BatchResponses(mapResponse1) + val decodingFunction1 = (bytes: Array[Byte]) => { + fail("Decoding function should not be called when batch response is cached") + mapResponse1 + } + val cachedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(cachedBatchResponse, + batchBytes, + decodingFunction1, + servingInfo, + keys) + assertEquals(mapResponse1, cachedMapResponse) + + // 2. 
Un-cached BatchResponse has Map responses added to cache + val mapResponse2 = mock[Map[String, AnyRef]] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + def decodingFunction2 = (bytes: Array[Byte]) => mapResponse2 + val decodedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(kvStoreBatchResponses, + batchBytes, + decodingFunction2, + servingInfo, + keys) + assertEquals(mapResponse2, decodedMapResponse) + assertEquals(batchIrCache.cache.getIfPresent(cacheKey), CachedMapBatchResponse(mapResponse2)) // key was added + } + + @Test + def testGetMapResponseFromBatchResponseDecodesBatchBytesIfCacheDisabled(): Unit = { + // Set up mocks and dummy data + val servingInfo = mock[GroupByServingInfoParsed] + val batchBytes = Array[Byte](1, 2, 3) + val keys = Map("key" -> "value") + val mapResponse = mock[Map[String, AnyRef]] + val outputCodec = mock[AvroCodec] + val kvStoreBatchResponses = BatchResponses(Success(Seq(TimedValue(batchBytes, 1000L)))) + when(servingInfo.outputCodec).thenReturn(outputCodec) + when(outputCodec.decodeMap(any())).thenReturn(mapResponse) + + val spiedFetcherCache = Mockito.spy(new TestableFetcherCache(None)) + + // When getMapResponseFromBatchResponse is called, it decodes the bytes and doesn't hit the cache + val decodedMapResponse = spiedFetcherCache.getMapResponseFromBatchResponse(kvStoreBatchResponses, + batchBytes, + servingInfo.outputCodec.decodeMap, + servingInfo, + keys) + verify(outputCodec, times(1)).decodeMap(any()) // decoding did happen + assertEquals(mapResponse, decodedMapResponse) + } +} diff --git a/online/src/test/scala/ai/chronon/online/LRUCacheTest.scala b/online/src/test/scala/ai/chronon/online/LRUCacheTest.scala new file mode 100644 index 000000000..bcb5060a6 --- /dev/null +++ b/online/src/test/scala/ai/chronon/online/LRUCacheTest.scala @@ -0,0 +1,37 @@ +package ai.chronon.online + +import org.junit.Test +import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache} +import ai.chronon.online.LRUCache + +import scala.collection.JavaConverters._ + +class LRUCacheTest { + val testCache: CaffeineCache[String, String] = LRUCache[String, String]("testCache") + + @Test + def testGetsNothingWhenThereIsNothing(): Unit = { + assert(testCache.getIfPresent("key") == null) + assert(testCache.estimatedSize() == 0) + } + + @Test + def testGetsSomethingWhenThereIsSomething(): Unit = { + assert(testCache.getIfPresent("key") == null) + testCache.put("key", "value") + assert(testCache.getIfPresent("key") == "value") + assert(testCache.estimatedSize() == 1) + } + + @Test + def testEvictsWhenSomethingIsSet(): Unit = { + assert(testCache.estimatedSize() == 0) + assert(testCache.getIfPresent("key") == null) + testCache.put("key", "value") + assert(testCache.estimatedSize() == 1) + assert(testCache.getIfPresent("key") == "value") + testCache.invalidate("key") + assert(testCache.estimatedSize() == 0) + assert(testCache.getIfPresent("key") == null) + } +}
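For reference, a minimal sketch of how a deployment might wire the cache up end to end. It assumes only the system property and flag name introduced above; the `FetcherCacheExample` object, GroupBy name, and allowlist are hypothetical.

import java.util
import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache}
import ai.chronon.online.{FlagStore, LRUCache}

object FetcherCacheExample {
  def main(args: Array[String]): Unit = {
    // 1. Size the shared batch IR cache (in elements) before the Fetcher is constructed.
    System.setProperty("ai.chronon.fetcher.batch_ir_cache_size_elements", "1000")

    // 2. A toy FlagStore that enables the cache for an allowlist of GroupBys; a real deployment
    //    would back this with a feature-flagging system and pass it to the FetcherBase constructor.
    val allowlist = Set("my_team.my_groupby")
    val flagStore: FlagStore = (flagName: String, attributes: util.Map[String, String]) =>
      flagName == "enable_fetcher_batch_ir_cache" &&
        allowlist.contains(attributes.get("groupby_streaming_dataset"))

    // 3. The LRUCache utility is also usable standalone.
    val cache: CaffeineCache[String, String] = LRUCache[String, String]("example_cache", maximumSize = 100)
    cache.put("k", "v")
    println(cache.getIfPresent("k")) // prints "v"
  }
}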