Files
Aaru/Aaru.Generators/PluginRegisterGenerator.cs

174 lines
9.4 KiB
C#
Raw Normal View History

2025-08-14 17:38:51 +01:00
#nullable enable
using System.Collections.Generic;
2025-08-14 17:38:51 +01:00
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
2025-08-14 17:38:51 +01:00
using Microsoft.CodeAnalysis.Text;
namespace Aaru.Generators;
[Generator]
2025-08-14 17:38:51 +01:00
public sealed class PluginRegisterGenerator : IIncrementalGenerator
{
// your map of simple names → registration methods
static readonly Dictionary<string, string> PluginInterfaces = new()
2025-08-14 17:38:51 +01:00
{
["IArchive"] = "RegisterArchivePlugins",
["IChecksum"] = "RegisterChecksumPlugins",
["IFilesystem"] = "RegisterFilesystemPlugins",
["IFilter"] = "RegisterFilterPlugins",
["IFloppyImage"] = "RegisterFloppyImagePlugins",
["IMediaImage"] = "RegisterMediaImagePlugins",
["IPartition"] = "RegisterPartitionPlugins",
["IReadOnlyFilesystem"] = "RegisterReadOnlyFilesystemPlugins",
["IWritableFloppyImage"] = "RegisterWritableFloppyImagePlugins",
["IWritableImage"] = "RegisterWritableImagePlugins",
["IByteAddressableImage"] = "RegisterByteAddressablePlugins",
["IFluxImage"] = "RegisterFluxImagePlugins",
["IWritableFluxImage"] = "RegisterWritableFluxImagePlugins"
// …snip…
2025-08-14 17:38:51 +01:00
};
#region IIncrementalGenerator Members
public void Initialize(IncrementalGeneratorInitializationContext ctx)
2025-08-14 17:38:51 +01:00
{
// 1) pick up every class syntax with a baselist (so we only inspect ones that *could* have interfaces)
IncrementalValueProvider<ImmutableArray<ClassDeclarationSyntax>> syntaxProvider = ctx.SyntaxProvider
.CreateSyntaxProvider((node, ct) => node is ClassDeclarationSyntax cds && cds.BaseList != null,
(ctx, ct) => (ClassDeclarationSyntax)ctx.Node)
.Collect(); // gather them all
// 2) combine with the full Compilation so we can do symbol lookups
IncrementalValueProvider<(Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right)>
compilationAndClasses = ctx.CompilationProvider.Combine(syntaxProvider);
// 3) finally generate source
ctx.RegisterSourceOutput(compilationAndClasses,
(spc, pair) =>
{
(Compilation? compilation, ImmutableArray<ClassDeclarationSyntax> classDecls) =
pair;
// locate the interface symbols by metadata name once
(string Name, string Method, INamedTypeSymbol? Symbol)[] interfaceSymbols =
PluginInterfaces
.Select(kvp => (Name: kvp.Key, Method: kvp.Value,
Symbol: compilation
.GetTypeByMetadataName($"Aaru.CommonTypes.Interfaces.{kvp.Key}")))
.Where(x => x.Symbol != null)
.ToArray();
// find the one IPluginRegister type as well
INamedTypeSymbol? registerIf =
compilation
.GetTypeByMetadataName("Aaru.CommonTypes.Interfaces.IPluginRegister");
// collect info
var plugins = new List<PluginInfo>();
foreach(ClassDeclarationSyntax? classDecl in classDecls.Distinct())
{
SemanticModel model = compilation.GetSemanticModel(classDecl.SyntaxTree);
var symbol =
model.GetDeclaredSymbol(classDecl, spc.CancellationToken) as
INamedTypeSymbol;
if(symbol is null) continue;
// which interfaces does it *actually* implement (direct + indirect)?
ImmutableArray<INamedTypeSymbol> allIfaces = symbol.AllInterfaces;
// diagnostics to verify were seeing the right interfaces
foreach(INamedTypeSymbol? iface in allIfaces)
{
spc.ReportDiagnostic(Diagnostic.Create(new DiagnosticDescriptor("PLGN001",
"Found interface",
$"Class {symbol.Name} implements {iface.ToDisplayString()}",
"PluginGen",
DiagnosticSeverity.Info,
true),
classDecl.GetLocation()));
}
var info = new PluginInfo
{
Namespace = symbol.ContainingNamespace.ToDisplayString(),
ClassName = symbol.Name,
IsRegister =
registerIf != null &&
allIfaces.Contains(registerIf, SymbolEqualityComparer.Default)
};
// pick up every plugininterface your map knows about
foreach((string Name, string Method, INamedTypeSymbol? Symbol) in
interfaceSymbols)
{
if(SymbolEqualityComparer.Default.Equals(Symbol, null)) continue;
if(allIfaces.Contains(Symbol, SymbolEqualityComparer.Default))
info.Interfaces.Add(Name);
}
if(info.IsRegister || info.Interfaces.Count > 0) plugins.Add(info);
}
// nothing to do
if(plugins.Count == 0) return;
// find the one class that implements IPluginRegister
PluginInfo? regCls = plugins.FirstOrDefault(p => p.IsRegister);
if(regCls == null) return;
// build the generated file
var sb = new StringBuilder();
sb.AppendLine("using Microsoft.Extensions.DependencyInjection;");
sb.AppendLine("using Aaru.CommonTypes.Interfaces;");
sb.AppendLine($"namespace {regCls.Namespace};");
sb.AppendLine($"public sealed partial class {regCls.ClassName} : IPluginRegister");
sb.AppendLine("{");
foreach(KeyValuePair<string, string> kvp in PluginInterfaces)
{
// grab all classes that implement this interface
IEnumerable<string> implementations = plugins
.Where(pi =>
pi.Interfaces
.Contains(kvp.Key))
.Select(pi => pi.ClassName)
.Distinct();
sb.AppendLine($" public void {kvp.Value}(IServiceCollection services)");
sb.AppendLine(" {");
foreach(string? impl in implementations)
sb.AppendLine($" services.AddTransient<{kvp.Key}, {impl}>();");
sb.AppendLine(" }");
}
sb.AppendLine("}");
spc.AddSource("Register.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8));
});
2025-08-14 17:38:51 +01:00
}
2025-08-14 17:38:51 +01:00
#endregion
2025-08-14 17:38:51 +01:00
#region Nested type: PluginInfo
2023-10-03 23:24:05 +01:00
class PluginInfo
{
public readonly List<string> Interfaces = new();
public string ClassName = "";
public bool IsRegister;
public string Namespace = "";
}
2023-10-03 23:24:05 +01:00
#endregion
}