Skip to content

Commit

Permalink
Add abstraction for serialization/deserialization (#5193)
Browse files Browse the repository at this point in the history
Resolves #4433

- Add abstraction for serialization/deserialization to handle more types
- handle 202 classifier
- Make `JsonElementSnippets` and `Utf8JsonWriterSnippets` public, so
that sub-plugin don't need to duplicate the implementation
  • Loading branch information
live1206 authored Nov 27, 2024
1 parent 98b8ba4 commit 8713c2c
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,15 +842,15 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi

if (_additionalBinaryDataProperty != null)
{
var binaryDataDeserializationValue = GetValueTypeDeserializationExpression(
var binaryDataDeserializationValue = ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(
_additionalBinaryDataProperty.Type.ElementType.FrameworkType, jsonProperty.Value(), SerializationFormat.Default);
propertyDeserializationStatements.Add(
_additionalBinaryDataProperty.AsVariableExpression.AsDictionary(_additionalBinaryDataProperty.Type).Add(jsonProperty.Name(), binaryDataDeserializationValue));
}
else if (rawBinaryData != null)
{
var elementType = rawBinaryData.Type.Arguments[1].FrameworkType;
var rawDataDeserializationValue = GetValueTypeDeserializationExpression(elementType, jsonProperty.Value(), SerializationFormat.Default);
var rawDataDeserializationValue = ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(elementType, jsonProperty.Value(), SerializationFormat.Default);
propertyDeserializationStatements.Add(new IfStatement(_isNotEqualToWireConditionSnippet)
{
rawBinaryData.AsVariableExpression.AsDictionary(rawBinaryData.Type).Add(jsonProperty.Name(), rawDataDeserializationValue)
Expand Down Expand Up @@ -1267,11 +1267,11 @@ private ValueExpression CreateDeserializeValueExpression(CSharpType valueType, S
valueType switch
{
{ IsFrameworkType: true } when valueType.FrameworkType == typeof(Nullable<>) =>
GetValueTypeDeserializationExpression(valueType.Arguments[0].FrameworkType, jsonElement, serializationFormat),
ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.Arguments[0].FrameworkType, jsonElement, serializationFormat),
{ IsFrameworkType: true } =>
GetValueTypeDeserializationExpression(valueType.FrameworkType, jsonElement, serializationFormat),
ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.FrameworkType, jsonElement, serializationFormat),
{ IsEnum: true } =>
valueType.ToEnum(GetValueTypeDeserializationExpression(valueType.UnderlyingEnumType!, jsonElement, serializationFormat)),
valueType.ToEnum(ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.UnderlyingEnumType!, jsonElement, serializationFormat)),
_ => valueType.Deserialize(jsonElement, _mrwOptionsParameterSnippet)
};

Expand Down Expand Up @@ -1585,7 +1585,7 @@ private MethodBodyStatement CreateValueSerializationStatement(
ValueExpression value)
{
if (type.IsFrameworkType)
return SerializeValueType(type, serializationFormat, value, type.FrameworkType);
return ClientModelPlugin.Instance.TypeFactory.SerializeValueType(type, serializationFormat, value, type.FrameworkType, _utf8JsonWriterSnippet, _mrwOptionsParameterSnippet);

if (!type.IsEnum)
return _utf8JsonWriterSnippet.WriteObjectValue(value.As(type), options: _mrwOptionsParameterSnippet);
Expand All @@ -1611,11 +1611,13 @@ private MethodBodyStatement CreateValueSerializationStatement(
}
}

private MethodBodyStatement SerializeValueType(
internal static MethodBodyStatement SerializeValueTypeCore(
CSharpType type,
SerializationFormat serializationFormat,
ValueExpression value,
Type valueType)
Type valueType,
ScopedApi<Utf8JsonWriter> utf8JsonWriter,
ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter)
{
if (valueType == typeof(Nullable<>))
{
Expand All @@ -1627,34 +1629,34 @@ private MethodBodyStatement SerializeValueType(
return valueType switch
{
var t when t == typeof(JsonElement) =>
value.As<JsonElement>().WriteTo(_utf8JsonWriterSnippet),
value.As<JsonElement>().WriteTo(utf8JsonWriter),
var t when ValueTypeIsInt(t) && serializationFormat == SerializationFormat.Int_String =>
_utf8JsonWriterSnippet.WriteStringValue(value.InvokeToString()),
utf8JsonWriter.WriteStringValue(value.InvokeToString()),
var t when ValueTypeIsNumber(t) =>
_utf8JsonWriterSnippet.WriteNumberValue(value),
utf8JsonWriter.WriteNumberValue(value),
var t when t == typeof(object) =>
_utf8JsonWriterSnippet.WriteObjectValue(value.As(valueType), _mrwOptionsParameterSnippet),
utf8JsonWriter.WriteObjectValue(value.As(valueType), mrwOptionsParameter),
var t when t == typeof(string) || t == typeof(char) || t == typeof(Guid) =>
_utf8JsonWriterSnippet.WriteStringValue(value),
utf8JsonWriter.WriteStringValue(value),
var t when t == typeof(bool) =>
_utf8JsonWriterSnippet.WriteBooleanValue(value),
utf8JsonWriter.WriteBooleanValue(value),
var t when t == typeof(byte[]) =>
_utf8JsonWriterSnippet.WriteBase64StringValue(value, serializationFormat.ToFormatSpecifier()),
utf8JsonWriter.WriteBase64StringValue(value, serializationFormat.ToFormatSpecifier()),
var t when t == typeof(DateTimeOffset) || t == typeof(DateTime) || t == typeof(TimeSpan) =>
SerializeDateTimeRelatedTypes(valueType, serializationFormat, value),
SerializeDateTimeRelatedTypes(valueType, serializationFormat, value, utf8JsonWriter, mrwOptionsParameter),
var t when t == typeof(IPAddress) =>
_utf8JsonWriterSnippet.WriteStringValue(value.InvokeToString()),
utf8JsonWriter.WriteStringValue(value.InvokeToString()),
var t when t == typeof(Uri) =>
_utf8JsonWriterSnippet.WriteStringValue(new MemberExpression(value, nameof(Uri.AbsoluteUri))),
utf8JsonWriter.WriteStringValue(new MemberExpression(value, nameof(Uri.AbsoluteUri))),
var t when t == typeof(BinaryData) =>
SerializeBinaryData(valueType, serializationFormat, value),
SerializeBinaryData(valueType, serializationFormat, value, utf8JsonWriter),
var t when t == typeof(Stream) =>
_utf8JsonWriterSnippet.WriteBinaryData(BinaryDataSnippets.FromStream(value, false)),
utf8JsonWriter.WriteBinaryData(BinaryDataSnippets.FromStream(value, false)),
_ => throw new NotSupportedException($"Type {valueType} serialization is not supported.")
};
}

public static ValueExpression GetValueTypeDeserializationExpression(
internal static ValueExpression GetValueTypeDeserializationExpressionCore(
Type valueType,
ScopedApi<JsonElement> element,
SerializationFormat format)
Expand Down Expand Up @@ -1738,25 +1740,25 @@ private static bool ValueTypeIsNumber(Type valueType) =>
}
};

private MethodBodyStatement SerializeDateTimeRelatedTypes(Type valueType, SerializationFormat serializationFormat, ValueExpression value)
private static MethodBodyStatement SerializeDateTimeRelatedTypes(Type valueType, SerializationFormat serializationFormat, ValueExpression value, ScopedApi<Utf8JsonWriter> utf8JsonWriter, ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter)
{
var format = serializationFormat.ToFormatSpecifier();
return serializationFormat switch
{
SerializationFormat.Duration_Seconds => _utf8JsonWriterSnippet.WriteNumberValue(ConvertSnippets.InvokeToInt32(value.As<TimeSpan>().InvokeToString(format))),
SerializationFormat.Duration_Seconds_Float or SerializationFormat.Duration_Seconds_Double => _utf8JsonWriterSnippet.WriteNumberValue(ConvertSnippets.InvokeToDouble(value.As<TimeSpan>().InvokeToString(format))),
SerializationFormat.DateTime_Unix => _utf8JsonWriterSnippet.WriteNumberValue(value, format),
_ => format is not null ? _utf8JsonWriterSnippet.WriteStringValue(value, format) : _utf8JsonWriterSnippet.WriteStringValue(value)
SerializationFormat.Duration_Seconds => utf8JsonWriter.WriteNumberValue(ConvertSnippets.InvokeToInt32(value.As<TimeSpan>().InvokeToString(format))),
SerializationFormat.Duration_Seconds_Float or SerializationFormat.Duration_Seconds_Double => utf8JsonWriter.WriteNumberValue(ConvertSnippets.InvokeToDouble(value.As<TimeSpan>().InvokeToString(format))),
SerializationFormat.DateTime_Unix => utf8JsonWriter.WriteNumberValue(value, format),
_ => format is not null ? utf8JsonWriter.WriteStringValue(value, format) : utf8JsonWriter.WriteStringValue(value)
};
}

private MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFormat serializationFormat, ValueExpression value)
private static MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFormat serializationFormat, ValueExpression value, ScopedApi<Utf8JsonWriter> utf8JsonWriter)
{
if (serializationFormat is SerializationFormat.Bytes_Base64 or SerializationFormat.Bytes_Base64Url)
{
return _utf8JsonWriterSnippet.WriteBase64StringValue(value.As<BinaryData>().ToArray(), serializationFormat.ToFormatSpecifier());
return utf8JsonWriter.WriteBase64StringValue(value.As<BinaryData>().ToArray(), serializationFormat.ToFormatSpecifier());
}
return _utf8JsonWriterSnippet.WriteBinaryData(value);
return utf8JsonWriter.WriteBinaryData(value);
}

private static ScopedApi GetEnumerableExpression(ValueExpression expression, CSharpType enumerableType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ public class RestClientProvider : TypeProvider

private FieldProvider _pipelineMessageClassifier200;
private FieldProvider _pipelineMessageClassifier201;
private FieldProvider _pipelineMessageClassifier202;
private FieldProvider _pipelineMessageClassifier204;
private FieldProvider _pipelineMessageClassifier2xxAnd4xx;
private TypeProvider _classifier2xxAnd4xxDefinition;

private PropertyProvider _classifier201Property;
private PropertyProvider _classifier200Property;
private PropertyProvider _classifier202Property;
private PropertyProvider _classifier204Property;
private PropertyProvider _classifier2xxAnd4xxProperty;

Expand All @@ -51,11 +53,13 @@ public RestClientProvider(InputClient inputClient, ClientProvider clientProvider
ClientProvider = clientProvider;
_pipelineMessageClassifier200 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier200", this);
_pipelineMessageClassifier201 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier201", this);
_pipelineMessageClassifier202 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier202", this);
_pipelineMessageClassifier204 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier204", this);
_classifier2xxAnd4xxDefinition = new Classifier2xxAnd4xxDefinition(this);
_pipelineMessageClassifier2xxAnd4xx = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, _classifier2xxAnd4xxDefinition.Type, "_pipelineMessageClassifier2xxAnd4xx", this);
_classifier200Property = GetResponseClassifierProperty(_pipelineMessageClassifier200, 200);
_classifier201Property = GetResponseClassifierProperty(_pipelineMessageClassifier201, 201);
_classifier202Property = GetResponseClassifierProperty(_pipelineMessageClassifier202, 202);
_classifier204Property = GetResponseClassifierProperty(_pipelineMessageClassifier204, 204);
_classifier2xxAnd4xxProperty = new PropertyProvider(
$"Gets the PipelineMessageClassifier2xxAnd4xx",
Expand All @@ -76,6 +80,7 @@ protected override PropertyProvider[] BuildProperties()
[
_classifier200Property,
_classifier201Property,
_classifier202Property,
_classifier204Property,
_classifier2xxAnd4xxProperty
];
Expand All @@ -99,6 +104,7 @@ protected override FieldProvider[] BuildFields()
[
_pipelineMessageClassifier200,
_pipelineMessageClassifier201,
_pipelineMessageClassifier202,
_pipelineMessageClassifier204,
_pipelineMessageClassifier2xxAnd4xx
];
Expand Down Expand Up @@ -193,6 +199,7 @@ private PropertyProvider GetClassifier(InputOperation operation)
{
200 => _classifier200Property,
201 => _classifier201Property,
202 => _classifier202Property,
204 => _classifier204Property,
_ => throw new InvalidOperationException($"Unexpected status code {response.StatusCodes[0]}")
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Text.Json;
using Microsoft.Generator.CSharp.ClientModel.Providers;
using Microsoft.Generator.CSharp.ClientModel.Snippets;
using Microsoft.Generator.CSharp.Expressions;
using Microsoft.Generator.CSharp.Input;
using Microsoft.Generator.CSharp.Primitives;
using Microsoft.Generator.CSharp.Providers;
using Microsoft.Generator.CSharp.Snippets;
using Microsoft.Generator.CSharp.Statements;
using static Microsoft.Generator.CSharp.Snippets.Snippet;

namespace Microsoft.Generator.CSharp.ClientModel
{
Expand Down Expand Up @@ -118,5 +126,17 @@ public ClientProvider CreateClient(InputClient inputClient)
}
return methods;
}

public virtual ValueExpression GetValueTypeDeserializationExpression(Type valueType, ScopedApi<JsonElement> element, SerializationFormat format)
=> MrwSerializationTypeDefinition.GetValueTypeDeserializationExpressionCore(valueType, element, format);

public virtual MethodBodyStatement SerializeValueType(
CSharpType type,
SerializationFormat serializationFormat,
ValueExpression value,
Type valueType,
ScopedApi<Utf8JsonWriter> utf8JsonWriter,
ScopedApi<ModelReaderWriterOptions> mrwOptionsParameter)
=> MrwSerializationTypeDefinition.SerializeValueTypeCore(type, serializationFormat, value, valueType, utf8JsonWriter, mrwOptionsParameter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace Microsoft.Generator.CSharp.ClientModel.Snippets
{
internal static class JsonElementSnippets
public static class JsonElementSnippets
{
private const string GetRequiredStringMethodName = "GetRequiredString";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Microsoft.Generator.CSharp.ClientModel.Snippets
{
internal static class Utf8JsonWriterSnippets
public static class Utf8JsonWriterSnippets
{
public static ScopedApi<long> BytesCommitted(this ScopedApi<Utf8JsonWriter> writer) => writer.Property(nameof(Utf8JsonWriter.BytesCommitted)).As<long>();
public static ScopedApi<long> BytesPending(this ScopedApi<Utf8JsonWriter> writer) => writer.Property(nameof(Utf8JsonWriter.BytesPending)).As<long>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ public void TestIntSerializationStatement(
[TestCase(typeof(sbyte), SerializationFormat.Default, ExpectedResult = "foo.GetSByte()")]
public string TestIntDeserializeExpression(Type type, SerializationFormat format)
{
var expr = MrwSerializationTypeDefinition.GetValueTypeDeserializationExpression(type, new ScopedApi<JsonElement>(new VariableExpression(typeof(JsonElement), "foo")), format);
var expr = MrwSerializationTypeDefinition.GetValueTypeDeserializationExpressionCore(type, new ScopedApi<JsonElement>(new VariableExpression(typeof(JsonElement), "foo")), format);
return expr.ToDisplayString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@ public partial class TestClient
{
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier200;
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier201;
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier202;
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier204;
private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx _pipelineMessageClassifier2xxAnd4xx;

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 200 });

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier201 => _pipelineMessageClassifier201 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 201 });

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier202 => _pipelineMessageClassifier202 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 202 });

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier204 => _pipelineMessageClassifier204 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 204 });

private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx PipelineMessageClassifier2xxAnd4xx => _pipelineMessageClassifier2xxAnd4xx ??= new global::Sample.Custom.TestClient.Classifier2xxAnd4xx();
Expand Down
Loading

0 comments on commit 8713c2c

Please sign in to comment.