diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 1edc090a0..a5276ea40 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -875,7 +875,6 @@ object Driver {
 
   def main(baseArgs: Array[String]): Unit = {
     val args = new Args(baseArgs)
-    var shouldExit = true
     args.subcommand match {
       case Some(x) =>
         x match {
@@ -883,26 +882,21 @@ object Driver {
           case args.GroupByBackfillArgs => GroupByBackfill.run(args.GroupByBackfillArgs)
           case args.StagingQueryBackfillArgs => StagingQueryBackfill.run(args.StagingQueryBackfillArgs)
           case args.GroupByUploadArgs => GroupByUploader.run(args.GroupByUploadArgs)
-          case args.GroupByStreamingArgs =>
-            shouldExit = false
-            GroupByStreaming.run(args.GroupByStreamingArgs)
-
-          case args.MetadataUploaderArgs => MetadataUploader.run(args.MetadataUploaderArgs)
-          case args.FetcherCliArgs => FetcherCli.run(args.FetcherCliArgs)
-          case args.LogFlattenerArgs => LogFlattener.run(args.LogFlattenerArgs)
-          case args.ConsistencyMetricsArgs => ConsistencyMetricsCompute.run(args.ConsistencyMetricsArgs)
-          case args.CompareJoinQueryArgs => CompareJoinQuery.run(args.CompareJoinQueryArgs)
-          case args.AnalyzerArgs => Analyzer.run(args.AnalyzerArgs)
-          case args.DailyStatsArgs => DailyStats.run(args.DailyStatsArgs)
-          case args.LogStatsArgs => LogStats.run(args.LogStatsArgs)
-          case args.MetadataExportArgs => MetadataExport.run(args.MetadataExportArgs)
-          case args.LabelJoinArgs => LabelJoin.run(args.LabelJoinArgs)
-          case _ => logger.info(s"Unknown subcommand: $x")
+          case args.GroupByStreamingArgs => GroupByStreaming.run(args.GroupByStreamingArgs)
+          case args.MetadataUploaderArgs => MetadataUploader.run(args.MetadataUploaderArgs)
+          case args.FetcherCliArgs => FetcherCli.run(args.FetcherCliArgs)
+          case args.LogFlattenerArgs => LogFlattener.run(args.LogFlattenerArgs)
+          case args.ConsistencyMetricsArgs => ConsistencyMetricsCompute.run(args.ConsistencyMetricsArgs)
+          case args.CompareJoinQueryArgs => CompareJoinQuery.run(args.CompareJoinQueryArgs)
+          case args.AnalyzerArgs => Analyzer.run(args.AnalyzerArgs)
+          case args.DailyStatsArgs => DailyStats.run(args.DailyStatsArgs)
+          case args.LogStatsArgs => LogStats.run(args.LogStatsArgs)
+          case args.MetadataExportArgs => MetadataExport.run(args.MetadataExportArgs)
+          case args.LabelJoinArgs => LabelJoin.run(args.LabelJoinArgs)
+          case _ => logger.info(s"Unknown subcommand: $x")
         }
       case None => logger.info(s"specify a subcommand please")
     }
-    if (shouldExit) {
-      System.exit(0)
-    }
+    System.exit(0)
   }
 }
diff --git a/spark/src/main/scala/ai/chronon/spark/streaming/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/streaming/GroupBy.scala
index cfa5fe48c..2775a4216 100644
--- a/spark/src/main/scala/ai/chronon/spark/streaming/GroupBy.scala
+++ b/spark/src/main/scala/ai/chronon/spark/streaming/GroupBy.scala
@@ -34,6 +34,7 @@ import java.time.{Instant, ZoneId, ZoneOffset}
 import java.util.Base64
 import scala.collection.JavaConverters._
 import scala.concurrent.duration.{DurationInt}
+import scala.util.{Failure, Success}
 
 class GroupBy(inputStream: DataFrame,
               session: SparkSession,
@@ -82,8 +83,13 @@ class GroupBy(inputStream: DataFrame,
   def buildDataStream(local: Boolean = false): DataStreamWriter[KVStore.PutRequest] = {
     val streamingTable = groupByConf.metaData.cleanName + "_stream"
     val fetcher = onlineImpl.buildFetcher(local)
-    val groupByServingInfo = fetcher.getGroupByServingInfo(groupByConf.getMetaData.getName).get
-
+    val groupByServingInfo = fetcher.getGroupByServingInfo(groupByConf.getMetaData.getName) match {
+      case Success(info) => info
+      case Failure(e) =>
+        logger.error(s"Failed to retrieve groupByServingInfo: ${e.getMessage}")
+        session.stop()
+        sys.exit(1)
+    }
     val streamDecoder = onlineImpl.streamDecoder(groupByServingInfo)
 
     assert(groupByConf.streamingSource.isDefined, "No streaming source defined in GroupBy. Please set a topic/mutationTopic.")