Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/decimal support #982

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataTypesTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests
{

[Collection("Spark E2E Tests")]
public class DataTypesTests
{
private readonly SparkSession _spark;

public DataTypesTests(SparkFixture fixture)
{
_spark = fixture.Spark;
}

/// <summary>
/// Tests that we can pass a decimal over to Apache Spark and collect it back again, include a check
/// for the minimum and maximum decimal that .NET can represent
/// </summary>
[Fact]
public void TestDecimalType()
{
var df = _spark.CreateDataFrame(
new List<GenericRow>
{
new GenericRow(
new object[]
{
decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne,
new object[]
{
decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne
}
}),
},
new StructType(
new List<StructField>()
{
new StructField("min", new DecimalType(38, 0)),
new StructField("max", new DecimalType(38, 0)),
new StructField("zero", new DecimalType(38, 0)),
new StructField("minusOne", new DecimalType(38, 0)),
new StructField("array", new ArrayType(new DecimalType(38,0)))
}));

Row row = df.Collect().First();
Assert.Equal(decimal.MinValue, row[0]);
Assert.Equal(decimal.MaxValue, row[1]);
Assert.Equal(decimal.Zero, row[2]);
Copy link
Contributor

@cutecycle cutecycle May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't gotten to dive deep into whether this is an issue yet, but want to bring it to attention just in case:

There was a time when we were comparing SQL Server output to Spark SQL output trying to migrate a pipeline to Synapse, and when attempting to diff two tables, found an issue with a double.

SQL Server uses, presumably, C#'s (and JavaScript, which the Python Notebook table preview in Synapse uses)'s conception of floats: -0.0 == 0.0, but the JVM/Spark in some cases compares by bit and differentiates because of the signed bit: -0.0 != 0.0.
image
image
image

It's resolved in later versions of Spark's DataFrames, and may not apply in the case of [decimal]String, so it may not be problematic.

Copy link
Contributor

@cutecycle cutecycle Jun 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because internally, BigDecimal uses BigInteger, and BigInteger also only has a single concept of zero. A BigInteger behaves as a two's-complement integer, and two's-complement only has a single zero.

Assert.Equal(decimal.MinusOne, row[3]);
Assert.Equal(new object[]{decimal.MinValue, decimal.MaxValue, decimal.Zero, decimal.MinusOne},
row[4]);
}

}
}
3 changes: 3 additions & 0 deletions src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ private object CallJavaMethod(
case 'd':
returnValue = SerDe.ReadDouble(inputStream);
break;
case 'm':
returnValue = decimal.Parse(SerDe.ReadString(inputStream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use decimal.Parse(SerDe.ReadString(inputStream), CultureInfo.InvariantCulture) to ensure we are using invariant culture on the wire.

break;
case 'b':
returnValue = Convert.ToBoolean(inputStream.ReadByte());
break;
Expand Down
9 changes: 8 additions & 1 deletion src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal class PayloadHelper
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
private static readonly byte[] s_objectArrTypeId = new[] { (byte)'O' };
private static readonly byte[] s_decimalTypeId = new[] { (byte)'m' };

private static readonly ConcurrentDictionary<Type, bool> s_isDictionaryTable =
new ConcurrentDictionary<Type, bool>();
Expand Down Expand Up @@ -109,6 +110,10 @@ internal static void ConvertArgsToBytes(
case TypeCode.Double:
SerDe.Write(destination, (double)arg);
break;

case TypeCode.Decimal:
SerDe.Write(destination, (decimal)arg);
break;

case TypeCode.Object:
switch (arg)
Expand Down Expand Up @@ -321,7 +326,9 @@ internal static byte[] GetTypeId(Type type)
case TypeCode.Boolean:
return s_boolTypeId;
case TypeCode.Double:
return s_doubleTypeId;
return s_doubleTypeId;
case TypeCode.Decimal:
return s_decimalTypeId;
case TypeCode.Object:
if (typeof(IJvmObjectReferenceProvider).IsAssignableFrom(type))
{
Expand Down
7 changes: 7 additions & 0 deletions src/csharp/Microsoft.Spark/Interop/Ipc/SerDe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ public static void Write(Stream s, long value)
public static void Write(Stream s, double value) =>
Write(s, BitConverter.DoubleToInt64Bits(value));

/// <summary>
/// Writes a decimal to a stream as a string.
/// </summary>
/// <param name="s">The stream to write</param>
/// <param name="value">The decimal to write</param>
public static void Write(Stream s, decimal value) => Write(s, value.ToString());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use ToString(CultureInfo.InvariantCulture) if we are using a string on the wire.


/// <summary>
/// Writes a string to a stream.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import scala.collection.JavaConverters._
* This implementation of methods is mostly identical to the SerDe implementation in R.
*/
class SerDe(val tracker: JVMObjectTracker) {

def readObjectType(dis: DataInputStream): Char = {
dis.readByte().toChar
}
Expand All @@ -35,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
Expand All @@ -59,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}

private def readDecimal(in: DataInputStream): BigDecimal = {
BigDecimal(readString(in))
}

private def readLong(in: DataInputStream): Long = {
in.readLong()
}
Expand Down Expand Up @@ -110,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}

private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
val len = readInt(in)
(0 until len).map(_ => readDecimal(in)).toArray
}

private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
Expand Down Expand Up @@ -156,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
Expand Down Expand Up @@ -206,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
Expand Down Expand Up @@ -238,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
case "BigDecimal" | "java.math.BigDecimal" =>
writeType(dos, "bigdecimal")
writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
Expand All @@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}

private def readDecimal(in: DataInputStream): BigDecimal = {
BigDecimal(readString(in))
}

private def readLong(in: DataInputStream): Long = {
in.readLong()
}
Expand Down Expand Up @@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}

private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
val len = readInt(in)
(0 until len).map(_ => readDecimal(in)).toArray
}

private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
Expand Down Expand Up @@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
Expand Down Expand Up @@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
Expand Down Expand Up @@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
case "BigDecimal" | "java.math.BigDecimal" =>
writeType(dos, "bigdecimal")
writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
case 'm' => readDecimal(dis)
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
Expand All @@ -60,6 +61,10 @@ class SerDe(val tracker: JVMObjectTracker) {
in.readInt()
}

private def readDecimal(in: DataInputStream): BigDecimal = {
BigDecimal(readString(in))
}

private def readLong(in: DataInputStream): Long = {
in.readLong()
}
Expand Down Expand Up @@ -111,6 +116,11 @@ class SerDe(val tracker: JVMObjectTracker) {
(0 until len).map(_ => readInt(in)).toArray
}

private def readDecimalArr(in: DataInputStream): Array[BigDecimal] = {
val len = readInt(in)
(0 until len).map(_ => readDecimal(in)).toArray
}

private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
Expand Down Expand Up @@ -157,6 +167,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
case 'm' => readDecimalArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
Expand Down Expand Up @@ -207,6 +218,7 @@ class SerDe(val tracker: JVMObjectTracker) {
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
case "bigdecimal" => dos.writeByte('m')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
Expand Down Expand Up @@ -239,6 +251,9 @@ class SerDe(val tracker: JVMObjectTracker) {
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
case "BigDecimal" | "java.math.BigDecimal" =>
writeType(dos, "bigdecimal")
writeString(dos, value.toString)
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
Expand Down