[CHIP-1] Cache batch IRs in the Fetcher (#682)
* Add LRU Cache

* Double gauge

* Revert accidental build sbt change

* Move GroupByRequestMeta to object

* Add FetcherCache and tests

* Update FetcherBase to use cache

* Scala 2.13 support?

* Add getServingInfo unit tests

* Refactor getServingInfo tests

* Fix stub

* Fix Mockito "ambiguous reference to overloaded definition" error

* Fix "Both batch and streaming data are null" check

* Address PR review: add comments and use logger

* Fewer comments for constructGroupByResponse

* Use the FlagStore to determine if caching is enabled

* Apply suggestions from code review

Co-authored-by: Pengyu Hou <[email protected]>
Signed-off-by: Caio Camatta (Stripe) <[email protected]>

* Address review, add comments, rename tests

* Change test names

* CamelCase FetcherBaseTest

* Update online/src/main/scala/ai/chronon/online/FetcherBase.scala

Co-authored-by: Pengyu Hou <[email protected]>
Signed-off-by: Caio Camatta (Stripe) <[email protected]>

* fmt

---------

Signed-off-by: Caio Camatta (Stripe) <[email protected]>
Co-authored-by: Pengyu Hou <[email protected]>
caiocamatta-stripe and pengyu-hou authored Jun 27, 2024
1 parent caa0d1e commit 215c014
Showing 9 changed files with 942 additions and 100 deletions.
6 changes: 4 additions & 2 deletions build.sbt
@@ -309,7 +309,8 @@ lazy val online = project
   // statsd 3.0 has local aggregation - TODO: upgrade
   "com.datadoghq" % "java-dogstatsd-client" % "2.7",
   "org.rogach" %% "scallop" % "4.0.1",
-  "net.jodah" % "typetools" % "0.4.1"
+  "net.jodah" % "typetools" % "0.4.1",
+  "com.github.ben-manes.caffeine" % "caffeine" % "2.9.3"
 ),
 libraryDependencies ++= fromMatrix(scalaVersion.value, "spark-all", "scala-parallel-collections", "netty-buffer"),
 version := git.versionProperty.value
@@ -326,7 +327,8 @@ lazy val online_unshaded = (project in file("online"))
   // statsd 3.0 has local aggregation - TODO: upgrade
   "com.datadoghq" % "java-dogstatsd-client" % "2.7",
   "org.rogach" %% "scallop" % "4.0.1",
-  "net.jodah" % "typetools" % "0.4.1"
+  "net.jodah" % "typetools" % "0.4.1",
+  "com.github.ben-manes.caffeine" % "caffeine" % "2.9.3"
 ),
 libraryDependencies ++= fromMatrix(scalaVersion.value,
   "jackson",
9 changes: 8 additions & 1 deletion online/src/main/java/ai/chronon/online/FlagStore.java
@@ -3,7 +3,14 @@
 import java.io.Serializable;
 import java.util.Map;
 
-// Interface to allow rolling out features/infrastructure changes in a safe & controlled manner
+/**
+ * Interface to allow rolling out features/infrastructure changes in a safe & controlled manner.
+ *
+ * The "flags" in FlagStore refer to 'feature flags', a technique that allows enabling or disabling features at
+ * runtime.
+ *
+ * Chronon users can provide their own implementation in the Api.
+ */
 public interface FlagStore extends Serializable {
     Boolean isSet(String flagName, Map<String, String> attributes);
 }
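
To make the contract concrete, here is a minimal sketch of a user-supplied FlagStore. The flag name and the attribute key are illustrative assumptions, not part of this commit:

import java.util.{Map => JMap}

import ai.chronon.online.FlagStore

// A minimal, hypothetical FlagStore backed by a static set of GroupBy names.
// The flag name and "groupby_streaming_dataset" attribute key are assumptions for illustration.
class StaticFlagStore(cachingEnabledGroupBys: Set[String]) extends FlagStore {
  override def isSet(flagName: String, attributes: JMap[String, String]): java.lang.Boolean =
    flagName match {
      case "enable_fetcher_batch_ir_cache" =>
        cachingEnabledGroupBys.contains(attributes.get("groupby_streaming_dataset"))
      case _ => false
    }
}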
288 changes: 197 additions & 91 deletions online/src/main/scala/ai/chronon/online/FetcherBase.scala

Large diffs are not rendered by default.
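
Per the commit message, this unrendered diff wires `isCachingEnabled` up to the FlagStore. A plausible sketch of that override follows; the flag name, attribute key, and `flagStore` field are assumptions based on the surrounding code, not the actual diff:

import scala.collection.JavaConverters.mapAsJavaMapConverter

// Hypothetical sketch only -- the real FetcherBase change is in the unrendered diff above.
override def isCachingEnabled(groupBy: GroupBy): Boolean = {
  if (!isCacheSizeConfigured || groupBy.getMetaData == null) return false
  val attributes = Map("groupby_streaming_dataset" -> groupBy.getMetaData.getName)
  flagStore != null &&
    Option(flagStore.isSet("enable_fetcher_batch_ir_cache", attributes.asJava)).exists(_.booleanValue())
}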

216 changes: 216 additions & 0 deletions online/src/main/scala/ai/chronon/online/FetcherCache.scala
@@ -0,0 +1,216 @@
package ai.chronon.online

import ai.chronon.aggregator.windowing.FinalBatchIr
import ai.chronon.api.Extensions.MetadataOps
import ai.chronon.api.GroupBy
import ai.chronon.online.FetcherBase.GroupByRequestMeta
import ai.chronon.online.Fetcher.Request
import ai.chronon.online.FetcherCache.{
BatchIrCache,
BatchResponses,
CachedBatchResponse,
CachedFinalIrBatchResponse,
CachedMapBatchResponse,
KvStoreBatchResponse
}
import ai.chronon.online.KVStore.{GetRequest, TimedValue}
import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache}

import scala.util.{Success, Try}
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters.mapAsScalaConcurrentMapConverter
import scala.collection.Seq
import org.slf4j.{Logger, LoggerFactory}

/*
 * FetcherCache is an extension to FetcherBase that provides caching functionality. It caches the results of KV store
 * requests (decoded batch IRs) to decrease feature serving latency.
 *
 * To use it,
 *   1. Set the system property `ai.chronon.fetcher.batch_ir_cache_size_elements` to the desired cache size
 *      in terms of elements. This creates a single cache shared across all GroupBys. To determine a size, start with a
 *      small number (e.g. 1,000), measure how much memory it uses, then adjust accordingly.
 *   2. Enable caching for a specific GroupBy by overriding `isCachingEnabled` and returning `true` for that GroupBy.
 *      FetcherBase already provides an implementation of `isCachingEnabled` that uses the FlagStore.
 */
trait FetcherCache {
@transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)

val batchIrCacheName = "batch_cache"
  val maybeBatchIrCache: Option[BatchIrCache] =
    Option(System.getProperty("ai.chronon.fetcher.batch_ir_cache_size_elements"))
      .map(size => new BatchIrCache(batchIrCacheName, size.toInt))

// Caching needs to be configured globally
def isCacheSizeConfigured: Boolean = maybeBatchIrCache.isDefined
// Caching needs to be enabled for the specific groupBy
def isCachingEnabled(groupBy: GroupBy): Boolean = false

protected val caffeineMetricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.JoinFetching)

/**
* Obtain the Map[String, AnyRef] response from a batch response.
*
* If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache,
* it will decode it from the batch bytes and store it.
*
* @param batchResponses the batch responses
* @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`.
* @param decodingFunction the function to decode bytes into Map[String, AnyRef]
* @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes
* @param keys the keys used to fetch this particular batch response, for caching purposes
*/
private[online] def getMapResponseFromBatchResponse(batchResponses: BatchResponses,
batchBytes: Array[Byte],
decodingFunction: Array[Byte] => Map[String, AnyRef],
servingInfo: GroupByServingInfoParsed,
keys: Map[String, Any]): Map[String, AnyRef] = {
if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes)

batchResponses match {
case _: KvStoreBatchResponse =>
val batchRequestCacheKey =
BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis)
        val decodedMap = decodingFunction(batchBytes)
        if (decodedMap != null)
          maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedMapBatchResponse(decodedMap))
        decodedMap
case cachedResponse: CachedBatchResponse =>
cachedResponse match {
case CachedFinalIrBatchResponse(_: FinalBatchIr) => decodingFunction(batchBytes)
case CachedMapBatchResponse(mapResponse: Map[String, AnyRef]) => mapResponse
}
}
}

/**
* Obtain the FinalBatchIr from a batch response.
*
* If batch IR caching is enabled, this method will try to fetch the IR from the cache. If it's not in the cache,
* it will decode it from the batch bytes and store it.
*
* @param batchResponses the batch responses
* @param batchBytes the batch bytes corresponding to the batchResponses. Can be `null`.
* @param servingInfo the GroupByServingInfoParsed that contains the info to decode the bytes
* @param decodingFunction the function to decode bytes into FinalBatchIr
* @param keys the keys used to fetch this particular batch response, for caching purposes
*/
private[online] def getBatchIrFromBatchResponse(
batchResponses: BatchResponses,
batchBytes: Array[Byte],
servingInfo: GroupByServingInfoParsed,
decodingFunction: (Array[Byte], GroupByServingInfoParsed) => FinalBatchIr,
keys: Map[String, Any]): FinalBatchIr = {
if (!isCachingEnabled(servingInfo.groupBy)) return decodingFunction(batchBytes, servingInfo)

batchResponses match {
case _: KvStoreBatchResponse =>
val batchRequestCacheKey =
BatchIrCache.Key(servingInfo.groupByOps.batchDataset, keys, servingInfo.batchEndTsMillis)
        val decodedIr = decodingFunction(batchBytes, servingInfo)
        if (decodedIr != null)
          maybeBatchIrCache.get.cache.put(batchRequestCacheKey, CachedFinalIrBatchResponse(decodedIr))
        decodedIr
case cachedResponse: CachedBatchResponse =>
cachedResponse match {
case CachedFinalIrBatchResponse(finalBatchIr: FinalBatchIr) => finalBatchIr
case CachedMapBatchResponse(_: Map[String, AnyRef]) => decodingFunction(batchBytes, servingInfo)
}
}
}

/**
* Given a list of GetRequests, return a map of GetRequests to cached FinalBatchIrs.
*/
def getCachedRequests(
groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])]): Map[GetRequest, CachedBatchResponse] = {
if (!isCacheSizeConfigured) return Map.empty

groupByRequestToKvRequest
.map {
case (request, Success(GroupByRequestMeta(servingInfo, batchRequest, _, _, _))) =>
if (!isCachingEnabled(servingInfo.groupBy)) { Map.empty }
else {
val batchRequestCacheKey =
BatchIrCache.Key(batchRequest.dataset, request.keys, servingInfo.batchEndTsMillis)

// Metrics so we can get per-groupby cache metrics
val metricsContext =
request.context.getOrElse(Metrics.Context(Metrics.Environment.JoinFetching, servingInfo.groupBy))

maybeBatchIrCache.get.cache.getIfPresent(batchRequestCacheKey) match {
case null =>
metricsContext.increment(s"${batchIrCacheName}_gb_misses")
val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty
emptyMap
case cachedIr: CachedBatchResponse =>
metricsContext.increment(s"${batchIrCacheName}_gb_hits")
Map(batchRequest -> cachedIr)
}
}
case _ =>
val emptyMap: Map[GetRequest, CachedBatchResponse] = Map.empty
emptyMap
}
.foldLeft(Map.empty[GetRequest, CachedBatchResponse])(_ ++ _)
}
}

object FetcherCache {
private[online] class BatchIrCache(val cacheName: String, val maximumSize: Int = 10000) {
import BatchIrCache._

val cache: CaffeineCache[Key, Value] =
LRUCache[Key, Value](cacheName = cacheName, maximumSize = maximumSize)
}

private[online] object BatchIrCache {
    // We use the dataset, keys, and batchEndTsMillis to identify a batch request.
    // There's one edge case to be aware of: if a batch job is re-run on the same day, the batchEndTsMillis will
    // be the same but the underlying data may have changed. If that new batch data is needed immediately, the
    // Fetcher service should be restarted.
case class Key(dataset: String, keys: Map[String, Any], batchEndTsMillis: Long)

// FinalBatchIr is for GroupBys using temporally accurate aggregation.
// Map[String, Any] is for GroupBys using snapshot accurate aggregation or no aggregation.
type Value = BatchResponses
}

/**
* Encapsulates the response for a GetRequest for batch data. This response could be the values received from
* a KV Store request, or cached values.
*
* (The fetcher uses these batch values to construct the response for a request for feature values.)
*/
sealed abstract class BatchResponses {
def getBatchBytes(batchEndTsMillis: Long): Array[Byte]
}
object BatchResponses {
def apply(kvStoreResponse: Try[Seq[TimedValue]]): KvStoreBatchResponse = KvStoreBatchResponse(kvStoreResponse)
def apply(cachedResponse: FinalBatchIr): CachedFinalIrBatchResponse = CachedFinalIrBatchResponse(cachedResponse)
def apply(cachedResponse: Map[String, AnyRef]): CachedMapBatchResponse = CachedMapBatchResponse(cachedResponse)
}

/** Encapsulates batch response values received from a KV Store request. */
case class KvStoreBatchResponse(response: Try[Seq[TimedValue]]) extends BatchResponses {
def getBatchBytes(batchEndTsMillis: Long): Array[Byte] =
response
.map(_.maxBy(_.millis))
.filter(_.millis >= batchEndTsMillis)
.map(_.bytes)
.getOrElse(null)
}

/** Encapsulates a batch response that was found in the Fetcher's internal IR cache. */
sealed abstract class CachedBatchResponse extends BatchResponses {
// This is the case where we don't have bytes because the decoded IR was cached so we didn't hit the KV store again.
def getBatchBytes(batchEndTsMillis: Long): Null = null
}

/** Encapsulates a decoded batch response that was found in the Fetcher's internal IR cache. */
case class CachedFinalIrBatchResponse(response: FinalBatchIr) extends CachedBatchResponse

/** Encapsulates a decoded batch response, stored as a Map, that was found in the Fetcher's internal IR cache. */
case class CachedMapBatchResponse(response: Map[String, AnyRef]) extends CachedBatchResponse
}
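
The `getBatchBytes` contract defined above is worth spelling out: a KV store response only surfaces bytes that are at least as fresh as `batchEndTsMillis`, while a cache hit never surfaces bytes at all, which forces callers down the decoded-value path. A small illustration with made-up keys and timestamps:

import scala.util.Success

import ai.chronon.online.FetcherCache.BatchResponses
import ai.chronon.online.KVStore.TimedValue

// A KV store response yields the newest value's bytes, but only if that value is at
// least as fresh as the batch end timestamp; older data is treated as absent (null).
val kvResponse = BatchResponses(Success(Seq(TimedValue(Array[Byte](1, 2, 3), millis = 2000L))))
assert(kvResponse.getBatchBytes(batchEndTsMillis = 1000L) != null) // fresh enough
assert(kvResponse.getBatchBytes(batchEndTsMillis = 3000L) == null) // stale

// A cache hit already holds the decoded value, so there are no bytes to return.
val cacheHit = BatchResponses(Map[String, AnyRef]("feature_sum_7d" -> java.lang.Long.valueOf(42L)))
assert(cacheHit.getBatchBytes(batchEndTsMillis = 1000L) == null)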
59 changes: 59 additions & 0 deletions online/src/main/scala/ai/chronon/online/LRUCache.scala
@@ -0,0 +1,59 @@
package ai.chronon.online

import com.github.benmanes.caffeine.cache.{Caffeine, Cache => CaffeineCache}
import org.slf4j.{Logger, LoggerFactory}

/**
* Utility to create a cache with LRU semantics.
*
* The original purpose of having an LRU cache in Chronon is to cache KVStore calls and decoded IRs
 * in the Fetcher. This helps decrease feature serving latency.
*/
object LRUCache {
@transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass)

/**
* Build a bounded, thread-safe Caffeine cache that stores KEY-VALUE pairs.
*
* @param cacheName Name of the cache
* @param maximumSize Maximum number of entries in the cache
* @tparam KEY The type of the key used to access the cache
* @tparam VALUE The type of the value stored in the cache
* @return Caffeine cache
*/
def apply[KEY <: Object, VALUE <: Object](cacheName: String, maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = {
buildCaffeineCache[KEY, VALUE](cacheName, maximumSize)
}

private def buildCaffeineCache[KEY <: Object, VALUE <: Object](
cacheName: String,
maximumSize: Int = 10000): CaffeineCache[KEY, VALUE] = {
logger.info(s"Chronon Cache build started. cacheName=$cacheName")
val cache: CaffeineCache[KEY, VALUE] = Caffeine
.newBuilder()
.maximumSize(maximumSize)
.recordStats()
.build[KEY, VALUE]()
logger.info(s"Chronon Cache build finished. cacheName=$cacheName")
cache
}

/**
 * Report metrics for a Caffeine cache. The cache name is prefixed to all metric names.
*
* @param metricsContext Metrics.Context for recording metrics
* @param cache Caffeine cache to get metrics from
* @param cacheName Cache name for tagging
*/
def collectCaffeineCacheMetrics(metricsContext: Metrics.Context,
cache: CaffeineCache[_, _],
cacheName: String): Unit = {
val stats = cache.stats()
metricsContext.gauge(s"$cacheName.hits", stats.hitCount())
metricsContext.gauge(s"$cacheName.misses", stats.missCount())
metricsContext.gauge(s"$cacheName.evictions", stats.evictionCount())
metricsContext.gauge(s"$cacheName.loads", stats.loadCount())
metricsContext.gauge(s"$cacheName.hit_rate", stats.hitRate())
metricsContext.gauge(s"$cacheName.average_load_penalty", stats.averageLoadPenalty())
}
}
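
A short usage sketch of the utility; the cache name, key/value types, and metrics context below are illustrative:

import com.github.benmanes.caffeine.cache.{Cache => CaffeineCache}

import ai.chronon.online.{LRUCache, Metrics}

// Build a bounded, stats-recording cache and exercise it.
val cache: CaffeineCache[String, java.lang.Long] =
  LRUCache[String, java.lang.Long](cacheName = "example_cache", maximumSize = 100)

cache.put("user_123", java.lang.Long.valueOf(42L))
assert(cache.getIfPresent("user_123") == 42L)  // hit
assert(cache.getIfPresent("user_456") == null) // miss

// Periodically export hit/miss/eviction stats, e.g. from a scheduled reporter task.
val context = Metrics.Context(Metrics.Environment.JoinFetching)
LRUCache.collectCaffeineCacheMetrics(context, cache, cacheName = "example_cache")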
2 changes: 2 additions & 0 deletions online/src/main/scala/ai/chronon/online/Metrics.scala
@@ -196,6 +196,8 @@ object Metrics {
 
   def gauge(metric: String, value: Long): Unit = stats.gauge(prefix(metric), value, tags)
 
+  def gauge(metric: String, value: Double): Unit = stats.gauge(prefix(metric), value, tags)
+
   def recordEvent(metric: String, event: Event): Unit = stats.recordEvent(event, prefix(metric), tags)
 
   def toTags: Array[String] = {
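
The new `Double` overload lets `LRUCache.collectCaffeineCacheMetrics` report fractional statistics such as `hit_rate` and `average_load_penalty` without truncation. A tiny illustration with made-up values:

val context = Metrics.Context(Metrics.Environment.JoinFetching)
context.gauge("batch_cache.hits", 1234L)    // existing Long overload
context.gauge("batch_cache.hit_rate", 0.97) // new Double overload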