Skip to content

Commit

Permalink
Merge pull request #96 from romfir/source_gen_open_generics
Browse files Browse the repository at this point in the history
OneOf.SourceGenerator open generics
  • Loading branch information
mcintyre321 authored Sep 21, 2021
2 parents 86c5aaf + 67c1f88 commit 1c21cb4
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 71 deletions.
67 changes: 57 additions & 10 deletions OneOf.SourceGenerator.Tests/SourceGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,26 @@ public void GenerateOneOf_Can_Assign_To_Struct_Type()
}

[Fact]
public void GenerateOneOf_Generates_Correct_Classes_For_Referance_Types()
public void GenerateOneOf_Generates_Correct_Classes_For_Reference_Types()
{
MyClass testclass = new();
MyClass2 testclass2 = new();
MyClass testClass = new();
MyClass2 testClass2 = new();

MyClass2OrMyClass myClass2OrMyClass = testclass;
MyClass2OrMyClass myClass2OrMyClassToCompare = testclass;
MyClass2OrMyClass myClass2OrMyClass = testClass;
MyClass2OrMyClass myClass2OrMyClassToCompare = testClass;

Assert.Equal(testclass, (MyClass)myClass2OrMyClass);
Assert.Equal(testClass, (MyClass)myClass2OrMyClass);
Assert.Equal((MyClass)myClass2OrMyClass, (MyClass)myClass2OrMyClassToCompare);

myClass2OrMyClass = testclass2;
myClass2OrMyClassToCompare = testclass2;
myClass2OrMyClass = testClass2;
myClass2OrMyClassToCompare = testClass2;

Assert.Equal(testclass2, (MyClass2)myClass2OrMyClass);
Assert.Equal(testClass2, (MyClass2)myClass2OrMyClass);
Assert.Equal((MyClass2)myClass2OrMyClass, (MyClass2)myClass2OrMyClassToCompare);
}

[Fact]
public void GenerateOneOf_Can_Assign_To_Referance_Type()
public void GenerateOneOf_Can_Assign_To_Reference_Type()
{
MyClass testClass = new();

Expand Down Expand Up @@ -102,6 +102,33 @@ public void GenerateOneOf_Works_With_Nested_Generics()
NestedGeneric nested2 = new Dictionary<List<string>, string> { { new List<string> { "a", "b", "c" }, "d" } };
Assert.True(nested2.IsT2);
}

[Fact]
public void GenerateOneOf_Works_With_Open_Generics_With_Records()
{
OpenGenericWithRecords<MyClass, MyClass2> open = new Ok<MyClass>(new MyClass());
Assert.True(open.IsT0);

OpenGenericWithRecords<MyClass, MyClass2> open2 = new Error<MyClass2>(new MyClass2());
Assert.True(open2.IsT1);
}

[Fact]
public void GenerateOneOf_Works_With_Open_Generics_And_Nested_Generics()
{
OpenGenericWithRecords<List<int>, MyClass2> open = new Ok<List<int>>(new List<int> { 1, 2, 3 });
Assert.True(open.IsT0);
}

[Fact]
public void GenerateOneOf_Works_With_Open_And_Closed_Generics()
{
OpenGenericWithClosed<MyClass> openWithClosed = new Ok<MyClass>(new MyClass());
Assert.True(openWithClosed.IsT0);

OpenGenericWithClosed<MyClass2> openWithClosed2 = new MyClass();
Assert.True(openWithClosed2.IsT1);
}
}

[GenerateOneOf]
Expand All @@ -128,6 +155,26 @@ public class MyClass2
{

}

public record Error<TError>
(
TError ErrorData
);

public record Ok<TResult>
(
TResult Data
);

[GenerateOneOf]
public partial class OpenGenericWithRecords<TOk, TError> : OneOfBase<Ok<TOk>, Error<TError>>
{
}

[GenerateOneOf]
public partial class OpenGenericWithClosed<TOk> : OneOfBase<Ok<TOk>, MyClass>
{
}
}

namespace NotOneOf
Expand Down
35 changes: 35 additions & 0 deletions OneOf.SourceGenerator/GeneratorDiagnosticDescriptors.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using Microsoft.CodeAnalysis;

namespace OneOf
{
internal class GeneratorDiagnosticDescriptors
{
public static readonly DiagnosticDescriptor TopLevelError = new(id: "ONEOFGEN001",
title: "Class must be top level",
messageFormat: "Class '{0}' using OneOfGenerator must be top level",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static readonly DiagnosticDescriptor WrongBaseType = new(id: "ONEOFGEN002",
title: "Class must inherit from OneOfBase",
messageFormat: "Class '{0}' does not inherit from OneOfBase",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static readonly DiagnosticDescriptor ClassIsNotPublic = new(id: "ONEOFGEN003",
title: "Class must be public",
messageFormat: "Class '{0}' is not public",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static readonly DiagnosticDescriptor ObjectIsOneOfType = new(id: "ONEOFGEN004",
title: "Object is not a valid type parameter",
messageFormat: "Defined conversions to or from a base type are not allowed for class '{0}'",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);
}
}
129 changes: 68 additions & 61 deletions OneOf.SourceGenerator/OneOfGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,23 @@ namespace OneOf.SourceGenerator
[Generator]
public class OneOfGenerator : ISourceGenerator
{
private const string AttributeName = "GenerateOneOfAttribute";
private const string AttributeNamespace = "OneOf";

private static readonly DiagnosticDescriptor _topLevelError = new(id: "ONEOFGEN001",
title: "Class must be top level",
messageFormat: "Class '{0}' using OneOfGenerator must be top level",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

private static readonly DiagnosticDescriptor _wrongBaseType = new(id: "ONEOFGEN002",
title: "Class must inherit from OneOfBase",
messageFormat: "Class '{0}' does not inherit from OneOfBase",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

private static readonly DiagnosticDescriptor _classIsNotPublic = new(id: "ONEOFGEN003",
title: "Class must be public",
messageFormat: "Class '{0}' is not public",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

private static readonly DiagnosticDescriptor _objectIsOneOfType = new(id: "ONEOFGEN004",
title: "Object is not a valid type parameter",
messageFormat: "Defined conversions to or from a base type are not allowed for class '{0}'",
category: "OneOfGenerator",
DiagnosticSeverity.Error,
isEnabledByDefault: true);

private const string _attributeName = "GenerateOneOfAttribute";
private const string _attributeNamespace = "OneOf";
private readonly string _attributeText = $@"using System;
namespace {_attributeNamespace}
namespace {AttributeNamespace}
{{
[AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
public sealed class {_attributeName} : Attribute
public sealed class {AttributeName} : Attribute
{{
}}
}}
";

public void Execute(GeneratorExecutionContext context)
{
context.AddSource(_attributeName, SourceText.From(_attributeText, Encoding.UTF8));
context.AddSource(AttributeName, SourceText.From(_attributeText, Encoding.UTF8));

if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver)
{
Expand All @@ -69,9 +41,12 @@ public void Execute(GeneratorExecutionContext context)
return;
}

Compilation compilation = context.Compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(_attributeText, Encoding.UTF8), options));
Compilation compilation =
context.Compilation.AddSyntaxTrees(
CSharpSyntaxTree.ParseText(SourceText.From(_attributeText, Encoding.UTF8), options));

INamedTypeSymbol? attributeSymbol = compilation.GetTypeByMetadataName($"{_attributeNamespace}.{_attributeName}");
INamedTypeSymbol? attributeSymbol =
compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}");

if (attributeSymbol is null)
{
Expand All @@ -84,11 +59,13 @@ public void Execute(GeneratorExecutionContext context)
SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration);

AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad => ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false);
AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad =>
ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false);

if (attributeData is object)
if (attributeData is not null)
{
namedTypeSymbols.Add((namedTypeSymbol!, attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
namedTypeSymbols.Add((namedTypeSymbol!,
attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
}
}

Expand All @@ -101,59 +78,76 @@ public void Execute(GeneratorExecutionContext context)
continue;
}

context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.generated.cs", SourceText.From(classSource, Encoding.UTF8));
context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.generated.cs",
SourceText.From(classSource, Encoding.UTF8));
}
}

private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation)
private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context,
Location? attributeLocation)
{
attributeLocation ??= Location.None;

if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default))
{
context.ReportDiagnostic(Diagnostic.Create(_topLevelError, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.TopLevelError,
attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
return null;
}

if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf")
if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" ||
classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf")
{
context.ReportDiagnostic(Diagnostic.Create(_wrongBaseType, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.WrongBaseType,
attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
return null;
}

if (classSymbol.DeclaredAccessibility != Accessibility.Public)
{
context.ReportDiagnostic(Diagnostic.Create(_classIsNotPublic, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.ClassIsNotPublic,
attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
return null;
}

ImmutableArray<ITypeParameterSymbol> typeParameters = classSymbol.BaseType.TypeParameters;
ImmutableArray<ITypeSymbol> typeArguments = classSymbol.BaseType.TypeArguments;

if (typeArguments.Any(x => x.Name == nameof(Object)))
{
context.ReportDiagnostic(Diagnostic.Create(_objectIsOneOfType, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
context.ReportDiagnostic(Diagnostic.Create(GeneratorDiagnosticDescriptors.ObjectIsOneOfType,
attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
return null;
}

IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs = typeParameters.Zip(typeArguments, (param, arg) => (param, arg));
return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments);
}

private static string GenerateClassSource(INamedTypeSymbol classSymbol,
ImmutableArray<ITypeParameterSymbol> typeParameters, ImmutableArray<ITypeSymbol> typeArguments)
{
IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
typeParameters.Zip(typeArguments, (param, arg) => (param, arg));

string generics = string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));
string oneOfGenericPart = GetGenericPart(typeArguments);

string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}";

StringBuilder source = new($@"using System;
namespace {classSymbol.ContainingNamespace.ToDisplayString()}
{{
public partial class {classSymbol.Name}
public partial class {classNameWithGenericTypes}");

source.Append($@"
{{
public {classSymbol.Name}(OneOf.OneOf<{generics}> _) : base(_) {{ }}
public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }}
");

foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs)
{
source.Append($@"
public static implicit operator {classSymbol.Name}({arg.ToDisplayString()} _) => new {classSymbol.Name}(_);
public static explicit operator {arg.ToDisplayString()}({classSymbol.Name} _) => _.As{param.Name};
public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_);
public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name};
");
}

Expand All @@ -162,20 +156,33 @@ public partial class {classSymbol.Name}
return source.ToString();
}

private static string GetGenericPart(ImmutableArray<ITypeSymbol> typeArguments) =>
string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));

private static string? GetOpenGenericPart(INamedTypeSymbol classSymbol)
{
if (!classSymbol.TypeArguments.Any())
{
return null;
}

return $"<{GetGenericPart(classSymbol.TypeArguments)}>";
}

public void Initialize(GeneratorInitializationContext context)
=> context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver());
}

internal class OneOfSyntaxReceiver : ISyntaxReceiver
{
public List<ClassDeclarationSyntax> CandidateClasses { get; } = new();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
internal class OneOfSyntaxReceiver : ISyntaxReceiver
{
if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax
&& classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
public List<ClassDeclarationSyntax> CandidateClasses { get; } = new();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
CandidateClasses.Add(classDeclarationSyntax);
if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax
&& classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
{
CandidateClasses.Add(classDeclarationSyntax);
}
}
}
}
Expand Down

0 comments on commit 1c21cb4

Please sign in to comment.