diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs
new file mode 100644
index 000000000..919f61cb8
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests
+{
+
+    [Collection("Spark E2E Tests")]
+    public class DataTypesTests
+    {
+        private readonly SparkSession _spark;
+
+        public DataTypesTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        /// <summary>
+        /// Tests that we can pass a decimal over to Apache Spark and collect it back again, including a check
+        /// for the minimum and maximum decimal that .NET can represent
+        /// </summary>
+        [Fact]
+        public void TestDecimalType()
+        {
+            var df = _spark.CreateDataFrame(
+                new List<GenericRow>
+                {
+                    new GenericRow(
+                        new object[]
+                        {
+                            decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne,
+                            new object[]
+                            {
+                                decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne
+                            }
+                        }),
+                },
+                new StructType(
+                    new List<StructField>()
+                    {
+                        new StructField("min", new DecimalType(38, 0)),
+                        new StructField("max", new DecimalType(38, 0)),
+                        new StructField("zero", new DecimalType(38, 0)),
+                        new StructField("minusOne", new DecimalType(38, 0)),
+                        new StructField("array", new ArrayType(new DecimalType(38,0)))
+                    }));
+
+            Row row = df.Collect().First();
+            Assert.Equal(decimal.MinValue, row[0]);
+            Assert.Equal(decimal.MaxValue, row[1]);
+            Assert.Equal(decimal.Zero, row[2]);
+            Assert.Equal(decimal.MinusOne, row[3]);
+            Assert.Equal(new object[]{decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne},
row[4]); + } + + } +} diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs index f7bd145e3..535991b36 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs @@ -269,6 +269,9 @@ private object CallJavaMethod( case 'd': returnValue = SerDe.ReadDouble(inputStream); break; + case 'm': + returnValue = decimal.Parse(SerDe.ReadString(inputStream)); + break; case 'b': returnValue = Convert.ToBoolean(inputStream.ReadByte()); break; diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs index 3373bca62..ac9914672 100644 --- a/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs +++ b/src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs @@ -32,6 +32,7 @@ internal class PayloadHelper private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' }; private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' }; private static readonly byte[] s_objectArrTypeId = new[] { (byte)'O' }; + private static readonly byte[] s_decimalTypeId = new[] { (byte)'m' }; private static readonly ConcurrentDictionary s_isDictionaryTable = new ConcurrentDictionary(); @@ -109,6 +110,10 @@ internal static void ConvertArgsToBytes( case TypeCode.Double: SerDe.Write(destination, (double)arg); break; + + case TypeCode.Decimal: + SerDe.Write(destination, (decimal)arg); + break; case TypeCode.Object: switch (arg) @@ -321,7 +326,9 @@ internal static byte[] GetTypeId(Type type) case TypeCode.Boolean: return s_boolTypeId; case TypeCode.Double: - return s_doubleTypeId; + return s_doubleTypeId; + case TypeCode.Decimal: + return s_decimalTypeId; case TypeCode.Object: if (typeof(IJvmObjectReferenceProvider).IsAssignableFrom(type)) { diff --git a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs index c2c742e87..a36a293a0 100644 --- 
a/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
+++ b/src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
@@ -322,6 +322,13 @@ public static void Write(Stream s, long value)
         public static void Write(Stream s, double value) =>
             Write(s, BitConverter.DoubleToInt64Bits(value));
 
+        /// <summary>
+        /// Writes a decimal to a stream as an invariant-culture string.
+        /// </summary>
+        /// <param name="s">The stream to write</param>
+        /// <param name="value">The decimal to write</param>
+        public static void Write(Stream s, decimal value) => Write(s, value.ToString(System.Globalization.CultureInfo.InvariantCulture));
+
         /// <summary>
         /// Writes a string to a stream.
         /// </summary>
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
index 44cad97c1..31cf97c12 100644
--- a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala
@@ -19,6 +19,7 @@ import scala.collection.JavaConverters._
  * This implementation of methods is mostly identical to the SerDe implementation in R.
*/ class SerDe(val tracker: JVMObjectTracker) { + def readObjectType(dis: DataInputStream): Char = { dis.readByte().toChar } @@ -35,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'g' => new java.lang.Long(readLong(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'm' => readDecimal(dis) case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) @@ -59,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) { in.readInt() } + private def readDecimal(in: DataInputStream): BigDecimal = { + BigDecimal(readString(in)) + } + private def readLong(in: DataInputStream): Long = { in.readLong() } @@ -110,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) { (0 until len).map(_ => readInt(in)).toArray } + private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = { + val len = readInt(in) + (0 until len).map(_ => readDecimal(in)).toArray + } + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray @@ -156,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) + case 'm' => readDecimalArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -206,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) { case "long" => dos.writeByte('g') case "integer" => dos.writeByte('i') case "logical" => dos.writeByte('b') + case "bigdecimal" => dos.writeByte('m') case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') @@ -238,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) { case "boolean" | "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) + case "BigDecimal" | "java.math.BigDecimal" => + writeType(dos, "bigdecimal") + writeString(dos, value.toString) 
case "java.sql.Date" => writeType(dos, "date") writeDate(dos, value.asInstanceOf[Date]) diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index a3df3788a..31cf97c12 100644 --- a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'g' => new java.lang.Long(readLong(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'm' => readDecimal(dis) case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) @@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) { in.readInt() } + private def readDecimal(in: DataInputStream): BigDecimal = { + BigDecimal(readString(in)) + } + private def readLong(in: DataInputStream): Long = { in.readLong() } @@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) { (0 until len).map(_ => readInt(in)).toArray } + private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = { + val len = readInt(in) + (0 until len).map(_ => readDecimal(in)).toArray + } + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray @@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => readBytesArr(dis) + case 'm' => readDecimalArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) { case "long" => dos.writeByte('g') case "integer" => dos.writeByte('i') case "logical" => dos.writeByte('b') + case "bigdecimal" => dos.writeByte('m') case "date" => 
dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') @@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) { case "boolean" | "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) + case "BigDecimal" | "java.math.BigDecimal" => + writeType(dos, "bigdecimal") + writeString(dos, value.toString) case "java.sql.Date" => writeType(dos, "date") writeDate(dos, value.asInstanceOf[Date]) diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala index a3df3788a..31cf97c12 100644 --- a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala +++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'g' => new java.lang.Long(readLong(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'm' => readDecimal(dis) case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) @@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) { in.readInt() } + private def readDecimal(in: DataInputStream): BigDecimal = { + BigDecimal(readString(in)) + } + private def readLong(in: DataInputStream): Long = { in.readLong() } @@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) { (0 until len).map(_ => readInt(in)).toArray } + private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = { + val len = readInt(in) + (0 until len).map(_ => readDecimal(in)).toArray + } + private def readLongArr(in: DataInputStream): Array[Long] = { val len = readInt(in) (0 until len).map(_ => readLong(in)).toArray @@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) case 'r' => 
readBytesArr(dis) + case 'm' => readDecimalArr(dis) case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) { case "long" => dos.writeByte('g') case "integer" => dos.writeByte('i') case "logical" => dos.writeByte('b') + case "bigdecimal" => dos.writeByte('m') case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') @@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) { case "boolean" | "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) + case "BigDecimal" | "java.math.BigDecimal" => + writeType(dos, "bigdecimal") + writeString(dos, value.toString) case "java.sql.Date" => writeType(dos, "date") writeDate(dos, value.asInstanceOf[Date])