Skip to content

Commit

Permalink
fix: make logging codec thread safe (#857)
Browse files: browse the repository at this point in the history
* make logging codec thread safe

* cherry-pick #775

---------

Co-authored-by: Haozhen Ding <[email protected]>
  • Loading branch information
hzding621 and Haozhen Ding authored Oct 7, 2024
1 parent b9653c4 commit 2a81f64
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
10 changes: 7 additions & 3 deletions online/src/main/scala/ai/chronon/online/Fetcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ class Fetcher(val kvStore: KVStore,

val joinName = joinConf.metaData.nameToFilePath
val keySchema = StructType(s"${joinName.sanitize}_key", keyFields.toArray)
val keyCodec = AvroCodec.of(AvroConversions.fromChrononSchema(keySchema).toString)
val baseValueSchema = StructType(s"${joinName.sanitize}_value", valueFields.toArray)
val baseValueCodec = AvroCodec.of(AvroConversions.fromChrononSchema(baseValueSchema).toString)
val joinCodec = JoinCodec(joinConf, keySchema, baseValueSchema, keyCodec, baseValueCodec)
val joinCodec = JoinCodec(joinConf, keySchema, baseValueSchema)
logControlEvent(joinCodec)
joinCodec
}
Expand Down Expand Up @@ -316,6 +314,12 @@ class Fetcher(val kvStore: KVStore,
val loggingTry: Try[Unit] = joinCodecTry.map(codec => {
val metaData = codec.conf.join.metaData
val samplePercent = if (metaData.isSetSamplePercent) metaData.getSamplePercent else 0

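// Exit early if sample percent is 0: skip key/value encoding and return the derived values as-is.
// NOTE(review): this `return` sits inside the `joinCodecTry.map` lambda, so it is a non-local return — confirm that is intended.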
if (samplePercent == 0) {
return Response(resp.request, Success(resp.derivedValues))
}

val keyBytesTry: Try[Array[Byte]] = encode(loggingContext.map(_.withSuffix("encode_key")),
codec.keySchema,
codec.keyCodec,
Expand Down
25 changes: 11 additions & 14 deletions online/src/main/scala/ai/chronon/online/JoinCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package ai.chronon.online

import ai.chronon.api.Extensions.{DerivationOps, JoinOps, MetadataOps}
import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
import ai.chronon.api.{DataType, HashUtils, StructField, StructType}
import com.google.gson.Gson
import scala.collection.Seq
Expand All @@ -26,16 +26,10 @@ import ai.chronon.online.OnlineDerivationUtil.{
DerivationFunc,
buildDerivationFunction,
buildDerivedFields,
buildRenameOnlyDerivationFunction,
timeFields
buildRenameOnlyDerivationFunction
}

case class JoinCodec(conf: JoinOps,
keySchema: StructType,
baseValueSchema: StructType,
keyCodec: AvroCodec,
baseValueCodec: AvroCodec)
extends Serializable {
case class JoinCodec(conf: JoinOps, keySchema: StructType, baseValueSchema: StructType) extends Serializable {

@transient lazy val valueSchema: StructType = {
val fields = if (conf.join == null || conf.join.derivations == null || baseValueSchema.fields.isEmpty) {
Expand Down Expand Up @@ -65,7 +59,10 @@ case class JoinCodec(conf: JoinOps,
@transient lazy val renameOnlyDeriveFunc: (Map[String, Any], Map[String, Any]) => Map[String, Any] =
buildRenameOnlyDerivationFunction(conf.derivationsScala)

@transient lazy val valueCodec: AvroCodec = AvroCodec.of(AvroConversions.fromChrononSchema(valueSchema).toString)
@transient lazy val keySchemaStr: String = AvroConversions.fromChrononSchema(keySchema).toString
@transient lazy val valueSchemaStr: String = AvroConversions.fromChrononSchema(valueSchema).toString
def keyCodec: AvroCodec = AvroCodec.of(keySchemaStr)
def valueCodec: AvroCodec = AvroCodec.of(valueSchemaStr)

/*
* Get the serialized string repr. of the logging schema.
Expand All @@ -74,7 +71,7 @@ case class JoinCodec(conf: JoinOps,
* Example:
* {"join_name":"unit_test/test_join","key_schema":"{\"type\":\"record\",\"name\":\"unit_test_test_join_key\",\"namespace\":\"ai.chronon.data\",\"doc\":\"\",\"fields\":[{\"name\":\"listing\",\"type\":[\"null\",\"long\"],\"doc\":\"\"}]}","value_schema":"{\"type\":\"record\",\"name\":\"unit_test_test_join_value\",\"namespace\":\"ai.chronon.data\",\"doc\":\"\",\"fields\":[{\"name\":\"unit_test_listing_views_v1_m_guests_sum\",\"type\":[\"null\",\"long\"],\"doc\":\"\"},{\"name\":\"unit_test_listing_views_v1_m_views_sum\",\"type\":[\"null\",\"long\"],\"doc\":\"\"}]}"}
*/
lazy val loggingSchema: String = JoinCodec.buildLoggingSchema(conf.join.metaData.name, keyCodec, valueCodec)
lazy val loggingSchema: String = JoinCodec.buildLoggingSchema(conf.join.metaData.name, keySchemaStr, valueSchemaStr)
lazy val loggingSchemaHash: String = HashUtils.md5Base64(loggingSchema)

val keys: Array[String] = keySchema.fields.iterator.map(_.name).toArray
Expand All @@ -88,11 +85,11 @@ case class JoinCodec(conf: JoinOps,

object JoinCodec {

def buildLoggingSchema(joinName: String, keyCodec: AvroCodec, valueCodec: AvroCodec): String = {
def buildLoggingSchema(joinName: String, keySchemaStr: String, valueSchemaStr: String): String = {
val schemaMap = Map(
"join_name" -> joinName,
"key_schema" -> keyCodec.schemaStr,
"value_schema" -> valueCodec.schemaStr
"key_schema" -> keySchemaStr,
"value_schema" -> valueSchemaStr
)
new Gson().toJson(schemaMap.toJava)
}
Expand Down
3 changes: 2 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/LoggingSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ case class LoggingSchema(keyCodec: AvroCodec, valueCodec: AvroCodec) {
lazy val keyIndices: Map[StructField, Int] = keyFields.zipWithIndex.toMap
lazy val valueIndices: Map[StructField, Int] = valueFields.zipWithIndex.toMap

def hash(joinName: String): String = HashUtils.md5Base64(JoinCodec.buildLoggingSchema(joinName, keyCodec, valueCodec))
def hash(joinName: String): String =
HashUtils.md5Base64(JoinCodec.buildLoggingSchema(joinName, keyCodec.schemaStr, valueCodec.schemaStr))
}

object LoggingSchema {
Expand Down

0 comments on commit 2a81f64

Please sign in to comment.