diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index e40247089d..004d37334b 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -842,7 +842,7 @@ private List BuildDeserializePropertiesStatements(ScopedApi if (_additionalBinaryDataProperty != null) { - var binaryDataDeserializationValue = GetValueTypeDeserializationExpression( + var binaryDataDeserializationValue = ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression( _additionalBinaryDataProperty.Type.ElementType.FrameworkType, jsonProperty.Value(), SerializationFormat.Default); propertyDeserializationStatements.Add( _additionalBinaryDataProperty.AsVariableExpression.AsDictionary(_additionalBinaryDataProperty.Type).Add(jsonProperty.Name(), binaryDataDeserializationValue)); @@ -850,7 +850,7 @@ private List BuildDeserializePropertiesStatements(ScopedApi else if (rawBinaryData != null) { var elementType = rawBinaryData.Type.Arguments[1].FrameworkType; - var rawDataDeserializationValue = GetValueTypeDeserializationExpression(elementType, jsonProperty.Value(), SerializationFormat.Default); + var rawDataDeserializationValue = ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(elementType, jsonProperty.Value(), SerializationFormat.Default); propertyDeserializationStatements.Add(new IfStatement(_isNotEqualToWireConditionSnippet) { rawBinaryData.AsVariableExpression.AsDictionary(rawBinaryData.Type).Add(jsonProperty.Name(), rawDataDeserializationValue) @@ -1267,11 +1267,11 @@ private ValueExpression CreateDeserializeValueExpression(CSharpType valueType, S valueType switch { { IsFrameworkType: true } when valueType.FrameworkType == typeof(Nullable<>) => - GetValueTypeDeserializationExpression(valueType.Arguments[0].FrameworkType, jsonElement, serializationFormat), + ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.Arguments[0].FrameworkType, jsonElement, serializationFormat), { IsFrameworkType: true } => - GetValueTypeDeserializationExpression(valueType.FrameworkType, jsonElement, serializationFormat), + ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.FrameworkType, jsonElement, serializationFormat), { IsEnum: true } => - valueType.ToEnum(GetValueTypeDeserializationExpression(valueType.UnderlyingEnumType!, jsonElement, serializationFormat)), + valueType.ToEnum(ClientModelPlugin.Instance.TypeFactory.GetValueTypeDeserializationExpression(valueType.UnderlyingEnumType!, jsonElement, serializationFormat)), _ => valueType.Deserialize(jsonElement, _mrwOptionsParameterSnippet) }; @@ -1585,7 +1585,7 @@ private MethodBodyStatement CreateValueSerializationStatement( ValueExpression value) { if (type.IsFrameworkType) - return SerializeValueType(type, serializationFormat, value, type.FrameworkType); + return ClientModelPlugin.Instance.TypeFactory.SerializeValueType(type, serializationFormat, value, type.FrameworkType, _utf8JsonWriterSnippet, _mrwOptionsParameterSnippet); if (!type.IsEnum) return _utf8JsonWriterSnippet.WriteObjectValue(value.As(type), options: _mrwOptionsParameterSnippet); @@ -1611,11 +1611,13 @@ private MethodBodyStatement CreateValueSerializationStatement( } } - private MethodBodyStatement SerializeValueType( + internal static MethodBodyStatement SerializeValueTypeCore( CSharpType type, SerializationFormat serializationFormat, ValueExpression value, - Type valueType) + Type valueType, + ScopedApi utf8JsonWriter, + ScopedApi mrwOptionsParameter) { if (valueType == typeof(Nullable<>)) { @@ -1627,34 +1629,34 @@ private MethodBodyStatement SerializeValueType( return valueType switch { var t when t == typeof(JsonElement) => - value.As().WriteTo(_utf8JsonWriterSnippet), + value.As().WriteTo(utf8JsonWriter), var t when ValueTypeIsInt(t) && serializationFormat == SerializationFormat.Int_String => - _utf8JsonWriterSnippet.WriteStringValue(value.InvokeToString()), + utf8JsonWriter.WriteStringValue(value.InvokeToString()), var t when ValueTypeIsNumber(t) => - _utf8JsonWriterSnippet.WriteNumberValue(value), + utf8JsonWriter.WriteNumberValue(value), var t when t == typeof(object) => - _utf8JsonWriterSnippet.WriteObjectValue(value.As(valueType), _mrwOptionsParameterSnippet), + utf8JsonWriter.WriteObjectValue(value.As(valueType), mrwOptionsParameter), var t when t == typeof(string) || t == typeof(char) || t == typeof(Guid) => - _utf8JsonWriterSnippet.WriteStringValue(value), + utf8JsonWriter.WriteStringValue(value), var t when t == typeof(bool) => - _utf8JsonWriterSnippet.WriteBooleanValue(value), + utf8JsonWriter.WriteBooleanValue(value), var t when t == typeof(byte[]) => - _utf8JsonWriterSnippet.WriteBase64StringValue(value, serializationFormat.ToFormatSpecifier()), + utf8JsonWriter.WriteBase64StringValue(value, serializationFormat.ToFormatSpecifier()), var t when t == typeof(DateTimeOffset) || t == typeof(DateTime) || t == typeof(TimeSpan) => - SerializeDateTimeRelatedTypes(valueType, serializationFormat, value), + SerializeDateTimeRelatedTypes(valueType, serializationFormat, value, utf8JsonWriter, mrwOptionsParameter), var t when t == typeof(IPAddress) => - _utf8JsonWriterSnippet.WriteStringValue(value.InvokeToString()), + utf8JsonWriter.WriteStringValue(value.InvokeToString()), var t when t == typeof(Uri) => - _utf8JsonWriterSnippet.WriteStringValue(new MemberExpression(value, nameof(Uri.AbsoluteUri))), + utf8JsonWriter.WriteStringValue(new MemberExpression(value, nameof(Uri.AbsoluteUri))), var t when t == typeof(BinaryData) => - SerializeBinaryData(valueType, serializationFormat, value), + SerializeBinaryData(valueType, serializationFormat, value, utf8JsonWriter), var t when t == typeof(Stream) => - _utf8JsonWriterSnippet.WriteBinaryData(BinaryDataSnippets.FromStream(value, false)), + utf8JsonWriter.WriteBinaryData(BinaryDataSnippets.FromStream(value, false)), _ => throw new NotSupportedException($"Type {valueType} serialization is not supported.") }; } - public static ValueExpression GetValueTypeDeserializationExpression( + internal static ValueExpression GetValueTypeDeserializationExpressionCore( Type valueType, ScopedApi element, SerializationFormat format) @@ -1738,25 +1740,25 @@ private static bool ValueTypeIsNumber(Type valueType) => } }; - private MethodBodyStatement SerializeDateTimeRelatedTypes(Type valueType, SerializationFormat serializationFormat, ValueExpression value) + private static MethodBodyStatement SerializeDateTimeRelatedTypes(Type valueType, SerializationFormat serializationFormat, ValueExpression value, ScopedApi utf8JsonWriter, ScopedApi mrwOptionsParameter) { var format = serializationFormat.ToFormatSpecifier(); return serializationFormat switch { - SerializationFormat.Duration_Seconds => _utf8JsonWriterSnippet.WriteNumberValue(ConvertSnippets.InvokeToInt32(value.As().InvokeToString(format))), - SerializationFormat.Duration_Seconds_Float or SerializationFormat.Duration_Seconds_Double => _utf8JsonWriterSnippet.WriteNumberValue(ConvertSnippets.InvokeToDouble(value.As().InvokeToString(format))), - SerializationFormat.DateTime_Unix => _utf8JsonWriterSnippet.WriteNumberValue(value, format), - _ => format is not null ? _utf8JsonWriterSnippet.WriteStringValue(value, format) : _utf8JsonWriterSnippet.WriteStringValue(value) + SerializationFormat.Duration_Seconds => utf8JsonWriter.WriteNumberValue(ConvertSnippets.InvokeToInt32(value.As().InvokeToString(format))), + SerializationFormat.Duration_Seconds_Float or SerializationFormat.Duration_Seconds_Double => utf8JsonWriter.WriteNumberValue(ConvertSnippets.InvokeToDouble(value.As().InvokeToString(format))), + SerializationFormat.DateTime_Unix => utf8JsonWriter.WriteNumberValue(value, format), + _ => format is not null ? utf8JsonWriter.WriteStringValue(value, format) : utf8JsonWriter.WriteStringValue(value) }; } - private MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFormat serializationFormat, ValueExpression value) + private static MethodBodyStatement SerializeBinaryData(Type valueType, SerializationFormat serializationFormat, ValueExpression value, ScopedApi utf8JsonWriter) { if (serializationFormat is SerializationFormat.Bytes_Base64 or SerializationFormat.Bytes_Base64Url) { - return _utf8JsonWriterSnippet.WriteBase64StringValue(value.As().ToArray(), serializationFormat.ToFormatSpecifier()); + return utf8JsonWriter.WriteBase64StringValue(value.As().ToArray(), serializationFormat.ToFormatSpecifier()); } - return _utf8JsonWriterSnippet.WriteBinaryData(value); + return utf8JsonWriter.WriteBinaryData(value); } private static ScopedApi GetEnumerableExpression(ValueExpression expression, CSharpType enumerableType) diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs index 687fa0c8ad..be85bdce5e 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs @@ -36,12 +36,14 @@ public class RestClientProvider : TypeProvider private FieldProvider _pipelineMessageClassifier200; private FieldProvider _pipelineMessageClassifier201; + private FieldProvider _pipelineMessageClassifier202; private FieldProvider _pipelineMessageClassifier204; private FieldProvider _pipelineMessageClassifier2xxAnd4xx; private TypeProvider _classifier2xxAnd4xxDefinition; private PropertyProvider _classifier201Property; private PropertyProvider _classifier200Property; + private PropertyProvider _classifier202Property; private PropertyProvider _classifier204Property; private PropertyProvider _classifier2xxAnd4xxProperty; @@ -51,11 +53,13 @@ public RestClientProvider(InputClient inputClient, ClientProvider clientProvider ClientProvider = clientProvider; _pipelineMessageClassifier200 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier200", this); _pipelineMessageClassifier201 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier201", this); + _pipelineMessageClassifier202 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier202", this); _pipelineMessageClassifier204 = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, ClientModelPlugin.Instance.TypeFactory.StatusCodeClassifierApi.ResponseClassifierType, "_pipelineMessageClassifier204", this); _classifier2xxAnd4xxDefinition = new Classifier2xxAnd4xxDefinition(this); _pipelineMessageClassifier2xxAnd4xx = new FieldProvider(FieldModifiers.Private | FieldModifiers.Static, _classifier2xxAnd4xxDefinition.Type, "_pipelineMessageClassifier2xxAnd4xx", this); _classifier200Property = GetResponseClassifierProperty(_pipelineMessageClassifier200, 200); _classifier201Property = GetResponseClassifierProperty(_pipelineMessageClassifier201, 201); + _classifier202Property = GetResponseClassifierProperty(_pipelineMessageClassifier202, 202); _classifier204Property = GetResponseClassifierProperty(_pipelineMessageClassifier204, 204); _classifier2xxAnd4xxProperty = new PropertyProvider( $"Gets the PipelineMessageClassifier2xxAnd4xx", @@ -76,6 +80,7 @@ protected override PropertyProvider[] BuildProperties() [ _classifier200Property, _classifier201Property, + _classifier202Property, _classifier204Property, _classifier2xxAnd4xxProperty ]; @@ -99,6 +104,7 @@ protected override FieldProvider[] BuildFields() [ _pipelineMessageClassifier200, _pipelineMessageClassifier201, + _pipelineMessageClassifier202, _pipelineMessageClassifier204, _pipelineMessageClassifier2xxAnd4xx ]; @@ -193,6 +199,7 @@ private PropertyProvider GetClassifier(InputOperation operation) { 200 => _classifier200Property, 201 => _classifier201Property, + 202 => _classifier202Property, 204 => _classifier204Property, _ => throw new InvalidOperationException($"Unexpected status code {response.StatusCodes[0]}") }; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs index 371cc00c69..18364b2d46 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/ScmTypeFactory.cs @@ -5,10 +5,18 @@ using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Text.Json; using Microsoft.Generator.CSharp.ClientModel.Providers; +using Microsoft.Generator.CSharp.ClientModel.Snippets; +using Microsoft.Generator.CSharp.Expressions; using Microsoft.Generator.CSharp.Input; using Microsoft.Generator.CSharp.Primitives; using Microsoft.Generator.CSharp.Providers; +using Microsoft.Generator.CSharp.Snippets; +using Microsoft.Generator.CSharp.Statements; +using static Microsoft.Generator.CSharp.Snippets.Snippet; namespace Microsoft.Generator.CSharp.ClientModel { @@ -118,5 +126,17 @@ public ClientProvider CreateClient(InputClient inputClient) } return methods; } + + public virtual ValueExpression GetValueTypeDeserializationExpression(Type valueType, ScopedApi element, SerializationFormat format) + => MrwSerializationTypeDefinition.GetValueTypeDeserializationExpressionCore(valueType, element, format); + + public virtual MethodBodyStatement SerializeValueType( + CSharpType type, + SerializationFormat serializationFormat, + ValueExpression value, + Type valueType, + ScopedApi utf8JsonWriter, + ScopedApi mrwOptionsParameter) + => MrwSerializationTypeDefinition.SerializeValueTypeCore(type, serializationFormat, value, valueType, utf8JsonWriter, mrwOptionsParameter); } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/JsonElementSnippets.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/JsonElementSnippets.cs index b6755a4096..69d604c62a 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/JsonElementSnippets.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/JsonElementSnippets.cs @@ -11,7 +11,7 @@ namespace Microsoft.Generator.CSharp.ClientModel.Snippets { - internal static class JsonElementSnippets + public static class JsonElementSnippets { private const string GetRequiredStringMethodName = "GetRequiredString"; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/Utf8JsonWriterSnippets.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/Utf8JsonWriterSnippets.cs index f97e5b9416..032ed9becb 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/Utf8JsonWriterSnippets.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Snippets/Utf8JsonWriterSnippets.cs @@ -10,7 +10,7 @@ namespace Microsoft.Generator.CSharp.ClientModel.Snippets { - internal static class Utf8JsonWriterSnippets + public static class Utf8JsonWriterSnippets { public static ScopedApi BytesCommitted(this ScopedApi writer) => writer.Property(nameof(Utf8JsonWriter.BytesCommitted)).As(); public static ScopedApi BytesPending(this ScopedApi writer) => writer.Property(nameof(Utf8JsonWriter.BytesPending)).As(); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs index e8bdabeb71..e328c3f819 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/MrwSerializationTypeDefinitionTests.cs @@ -712,7 +712,7 @@ public void TestIntSerializationStatement( [TestCase(typeof(sbyte), SerializationFormat.Default, ExpectedResult = "foo.GetSByte()")] public string TestIntDeserializeExpression(Type type, SerializationFormat format) { - var expr = MrwSerializationTypeDefinition.GetValueTypeDeserializationExpression(type, new ScopedApi(new VariableExpression(typeof(JsonElement), "foo")), format); + var expr = MrwSerializationTypeDefinition.GetValueTypeDeserializationExpressionCore(type, new ScopedApi(new VariableExpression(typeof(JsonElement), "foo")), format); return expr.ToDisplayString(); } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderCustomizationTests/CanChangeClientNamespace.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderCustomizationTests/CanChangeClientNamespace.cs index 6025c67f49..3edc8fd499 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderCustomizationTests/CanChangeClientNamespace.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderCustomizationTests/CanChangeClientNamespace.cs @@ -12,6 +12,7 @@ public partial class TestClient { private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier200; private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier201; + private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier202; private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier204; private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx _pipelineMessageClassifier2xxAnd4xx; @@ -19,6 +20,8 @@ public partial class TestClient private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier201 => _pipelineMessageClassifier201 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 201 }); + private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier202 => _pipelineMessageClassifier202 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 202 }); + private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier204 => _pipelineMessageClassifier204 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 204 }); private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx PipelineMessageClassifier2xxAnd4xx => _pipelineMessageClassifier2xxAnd4xx ??= new global::Sample.Custom.TestClient.Classifier2xxAnd4xx(); diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.RestClient.cs b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.RestClient.cs index c58521c03a..5f6ddb7250 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.RestClient.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Unbranded-TypeSpec/src/Generated/UnbrandedTypeSpecClient.RestClient.cs @@ -13,6 +13,7 @@ public partial class UnbrandedTypeSpecClient { private static PipelineMessageClassifier _pipelineMessageClassifier200; private static PipelineMessageClassifier _pipelineMessageClassifier201; + private static PipelineMessageClassifier _pipelineMessageClassifier202; private static PipelineMessageClassifier _pipelineMessageClassifier204; private static Classifier2xxAnd4xx _pipelineMessageClassifier2xxAnd4xx; @@ -20,6 +21,8 @@ public partial class UnbrandedTypeSpecClient private static PipelineMessageClassifier PipelineMessageClassifier201 => _pipelineMessageClassifier201 = PipelineMessageClassifier.Create(stackalloc ushort[] { 201 }); + private static PipelineMessageClassifier PipelineMessageClassifier202 => _pipelineMessageClassifier202 = PipelineMessageClassifier.Create(stackalloc ushort[] { 202 }); + private static PipelineMessageClassifier PipelineMessageClassifier204 => _pipelineMessageClassifier204 = PipelineMessageClassifier.Create(stackalloc ushort[] { 204 }); private static Classifier2xxAnd4xx PipelineMessageClassifier2xxAnd4xx => _pipelineMessageClassifier2xxAnd4xx ??= new Classifier2xxAnd4xx();