From 67c1f88bb2b3af336da4ccfd78a6def6bfcdb08e Mon Sep 17 00:00:00 2001 From: Damian Romanowski Date: Fri, 10 Sep 2021 10:20:41 +0200 Subject: [PATCH] added support for open generics with OneOf.SourceGenerator + some minor refactoring --- .../SourceGeneratorTests.cs | 67 +++++++-- .../GeneratorDiagnosticDescriptors.cs | 35 +++++ OneOf.SourceGenerator/OneOfGenerator.cs | 129 +++++++++--------- 3 files changed, 160 insertions(+), 71 deletions(-) create mode 100644 OneOf.SourceGenerator/GeneratorDiagnosticDescriptors.cs diff --git a/OneOf.SourceGenerator.Tests/SourceGeneratorTests.cs b/OneOf.SourceGenerator.Tests/SourceGeneratorTests.cs index 93a4a4d..fa0b73e 100644 --- a/OneOf.SourceGenerator.Tests/SourceGeneratorTests.cs +++ b/OneOf.SourceGenerator.Tests/SourceGeneratorTests.cs @@ -38,26 +38,26 @@ public void GenerateOneOf_Can_Assign_To_Struct_Type() } [Fact] - public void GenerateOneOf_Generates_Correct_Classes_For_Referance_Types() + public void GenerateOneOf_Generates_Correct_Classes_For_Reference_Types() { - MyClass testclass = new(); - MyClass2 testclass2 = new(); + MyClass testClass = new(); + MyClass2 testClass2 = new(); - MyClass2OrMyClass myClass2OrMyClass = testclass; - MyClass2OrMyClass myClass2OrMyClassToCompare = testclass; + MyClass2OrMyClass myClass2OrMyClass = testClass; + MyClass2OrMyClass myClass2OrMyClassToCompare = testClass; - Assert.Equal(testclass, (MyClass)myClass2OrMyClass); + Assert.Equal(testClass, (MyClass)myClass2OrMyClass); Assert.Equal((MyClass)myClass2OrMyClass, (MyClass)myClass2OrMyClassToCompare); - myClass2OrMyClass = testclass2; - myClass2OrMyClassToCompare = testclass2; + myClass2OrMyClass = testClass2; + myClass2OrMyClassToCompare = testClass2; - Assert.Equal(testclass2, (MyClass2)myClass2OrMyClass); + Assert.Equal(testClass2, (MyClass2)myClass2OrMyClass); Assert.Equal((MyClass2)myClass2OrMyClass, (MyClass2)myClass2OrMyClassToCompare); } [Fact] - public void GenerateOneOf_Can_Assign_To_Referance_Type() + public void GenerateOneOf_Can_Assign_To_Reference_Type() { MyClass testClass = new(); @@ -102,6 +102,33 @@ public void GenerateOneOf_Works_With_Nested_Generics() NestedGeneric nested2 = new Dictionary, string> { { new List { "a", "b", "c" }, "d" } }; Assert.True(nested2.IsT2); } + + [Fact] + public void GenerateOneOf_Works_With_Open_Generics_With_Records() + { + OpenGenericWithRecords open = new Ok(new MyClass()); + Assert.True(open.IsT0); + + OpenGenericWithRecords open2 = new Error(new MyClass2()); + Assert.True(open2.IsT1); + } + + [Fact] + public void GenerateOneOf_Works_With_Open_Generics_And_Nested_Generics() + { + OpenGenericWithRecords, MyClass2> open = new Ok>(new List { 1, 2, 3 }); + Assert.True(open.IsT0); + } + + [Fact] + public void GenerateOneOf_Works_With_Open_And_Closed_Generics() + { + OpenGenericWithClosed openWithClosed = new Ok(new MyClass()); + Assert.True(openWithClosed.IsT0); + + OpenGenericWithClosed openWithClosed2 = new MyClass(); + Assert.True(openWithClosed2.IsT1); + } } [GenerateOneOf] @@ -128,6 +155,26 @@ public class MyClass2 { } + + public record Error + ( + TError ErrorData + ); + + public record Ok + ( + TResult Data + ); + + [GenerateOneOf] + public partial class OpenGenericWithRecords : OneOfBase, Error> + { + } + + [GenerateOneOf] + public partial class OpenGenericWithClosed : OneOfBase, MyClass> + { + } } namespace NotOneOf diff --git a/OneOf.SourceGenerator/GeneratorDiagnosticDescriptors.cs b/OneOf.SourceGenerator/GeneratorDiagnosticDescriptors.cs new file mode 100644 index 0000000..b33fd64 --- /dev/null +++ b/OneOf.SourceGenerator/GeneratorDiagnosticDescriptors.cs @@ -0,0 +1,35 @@ +using Microsoft.CodeAnalysis; + +namespace OneOf +{ + internal class GeneratorDiagnosticDescriptors + { + public static readonly DiagnosticDescriptor TopLevelError = new(id: "ONEOFGEN001", + title: "Class must be top level", + messageFormat: "Class '{0}' using OneOfGenerator must be top level", + category: "OneOfGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor WrongBaseType = new(id: "ONEOFGEN002", + title: "Class must inherit from OneOfBase", + messageFormat: "Class '{0}' does not inherit from OneOfBase", + category: "OneOfGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor ClassIsNotPublic = new(id: "ONEOFGEN003", + title: "Class must be public", + messageFormat: "Class '{0}' is not public", + category: "OneOfGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public static readonly DiagnosticDescriptor ObjectIsOneOfType = new(id: "ONEOFGEN004", + title: "Object is not a valid type parameter", + messageFormat: "Defined conversions to or from a base type are not allowed for class '{0}'", + category: "OneOfGenerator", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + } +} diff --git a/OneOf.SourceGenerator/OneOfGenerator.cs b/OneOf.SourceGenerator/OneOfGenerator.cs index fe4ba21..f9d5362 100644 --- a/OneOf.SourceGenerator/OneOfGenerator.cs +++ b/OneOf.SourceGenerator/OneOfGenerator.cs @@ -13,43 +13,15 @@ namespace OneOf.SourceGenerator [Generator] public class OneOfGenerator : ISourceGenerator { + private const string AttributeName = "GenerateOneOfAttribute"; + private const string AttributeNamespace = "OneOf"; - private static readonly DiagnosticDescriptor _topLevelError = new(id: "ONEOFGEN001", - title: "Class must be top level", - messageFormat: "Class '{0}' using OneOfGenerator must be top level", - category: "OneOfGenerator", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private static readonly DiagnosticDescriptor _wrongBaseType = new(id: "ONEOFGEN002", - title: "Class must inherit from OneOfBase", - messageFormat: "Class '{0}' does not inherit from OneOfBase", - category: "OneOfGenerator", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private static readonly DiagnosticDescriptor _classIsNotPublic = new(id: "ONEOFGEN003", - title: "Class must be public", - messageFormat: "Class '{0}' is not public", - category: "OneOfGenerator", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private static readonly DiagnosticDescriptor _objectIsOneOfType = new(id: "ONEOFGEN004", - title: "Object is not a valid type parameter", - messageFormat: "Defined conversions to or from a base type are not allowed for class '{0}'", - category: "OneOfGenerator", - DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private const string _attributeName = "GenerateOneOfAttribute"; - private const string _attributeNamespace = "OneOf"; private readonly string _attributeText = $@"using System; -namespace {_attributeNamespace} +namespace {AttributeNamespace} {{ [AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)] - public sealed class {_attributeName} : Attribute + public sealed class {AttributeName} : Attribute {{ }} }} @@ -57,7 +29,7 @@ public sealed class {_attributeName} : Attribute public void Execute(GeneratorExecutionContext context) { - context.AddSource(_attributeName, SourceText.From(_attributeText, Encoding.UTF8)); + context.AddSource(AttributeName, SourceText.From(_attributeText, Encoding.UTF8)); if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver) { @@ -69,9 +41,12 @@ public void Execute(GeneratorExecutionContext context) return; } - Compilation compilation = context.Compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(_attributeText, Encoding.UTF8), options)); + Compilation compilation = + context.Compilation.AddSyntaxTrees( + CSharpSyntaxTree.ParseText(SourceText.From(_attributeText, Encoding.UTF8), options)); - INamedTypeSymbol? attributeSymbol = compilation.GetTypeByMetadataName($"{_attributeNamespace}.{_attributeName}"); + INamedTypeSymbol? attributeSymbol = + compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}"); if (attributeSymbol is null) { @@ -84,11 +59,13 @@ public void Execute(GeneratorExecutionContext context) SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree); INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration); - AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad => ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false); + AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad => + ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false); - if (attributeData is object) + if (attributeData is not null) { - namedTypeSymbols.Add((namedTypeSymbol!, attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation())); + namedTypeSymbols.Add((namedTypeSymbol!, + attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation())); } } @@ -101,59 +78,76 @@ public void Execute(GeneratorExecutionContext context) continue; } - context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.generated.cs", SourceText.From(classSource, Encoding.UTF8)); + context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.generated.cs", + SourceText.From(classSource, Encoding.UTF8)); } } - private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation) + private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, + Location? attributeLocation) { attributeLocation ??= Location.None; if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default)) { - context.ReportDiagnostic(Diagnostic.Create(_topLevelError, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); + context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.TopLevelError, + attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); return null; } - if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf") + if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || + classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf") { - context.ReportDiagnostic(Diagnostic.Create(_wrongBaseType, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); + context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.WrongBaseType, + attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); return null; } if (classSymbol.DeclaredAccessibility != Accessibility.Public) { - context.ReportDiagnostic(Diagnostic.Create(_classIsNotPublic, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); + context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.ClassIsNotPublic, + attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); return null; } - ImmutableArray typeParameters = classSymbol.BaseType.TypeParameters; ImmutableArray typeArguments = classSymbol.BaseType.TypeArguments; if (typeArguments.Any(x => x.Name == nameof(Object))) { - context.ReportDiagnostic(Diagnostic.Create(_objectIsOneOfType, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); + context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.ObjectIsOneOfType, + attributeLocation, classSymbol.Name, DiagnosticSeverity.Error)); return null; } - IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs = typeParameters.Zip(typeArguments, (param, arg) => (param, arg)); + return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments); + } + + private static string GenerateClassSource(INamedTypeSymbol classSymbol, + ImmutableArray typeParameters, ImmutableArray typeArguments) + { + IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs = + typeParameters.Zip(typeArguments, (param, arg) => (param, arg)); - string generics = string.Join(", ", typeArguments.Select(x => x.ToDisplayString())); + string oneOfGenericPart = GetGenericPart(typeArguments); + + string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}"; StringBuilder source = new($@"using System; namespace {classSymbol.ContainingNamespace.ToDisplayString()} {{ - public partial class {classSymbol.Name} + public partial class {classNameWithGenericTypes}"); + + source.Append($@" {{ - public {classSymbol.Name}(OneOf.OneOf<{generics}> _) : base(_) {{ }} + public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }} "); foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs) { source.Append($@" - public static implicit operator {classSymbol.Name}({arg.ToDisplayString()} _) => new {classSymbol.Name}(_); - public static explicit operator {arg.ToDisplayString()}({classSymbol.Name} _) => _.As{param.Name}; + public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_); + public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name}; "); } @@ -162,20 +156,33 @@ public partial class {classSymbol.Name} return source.ToString(); } + private static string GetGenericPart(ImmutableArray typeArguments) => + string.Join(", ", typeArguments.Select(x => x.ToDisplayString())); + + private static string? GetOpenGenericPart(INamedTypeSymbol classSymbol) + { + if (!classSymbol.TypeArguments.Any()) + { + return null; + } + + return $"<{GetGenericPart(classSymbol.TypeArguments)}>"; + } + public void Initialize(GeneratorInitializationContext context) => context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver()); - } - internal class OneOfSyntaxReceiver : ISyntaxReceiver - { - public List CandidateClasses { get; } = new(); - - public void OnVisitSyntaxNode(SyntaxNode syntaxNode) + internal class OneOfSyntaxReceiver : ISyntaxReceiver { - if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax - && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword)) + public List CandidateClasses { get; } = new(); + + public void OnVisitSyntaxNode(SyntaxNode syntaxNode) { - CandidateClasses.Add(classDeclarationSyntax); + if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax + && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + CandidateClasses.Add(classDeclarationSyntax); + } } } }