diff --git a/Aaru.Generators/PluginRegisterGenerator.cs b/Aaru.Generators/PluginRegisterGenerator.cs index d14be89f3..68b024ad5 100644 --- a/Aaru.Generators/PluginRegisterGenerator.cs +++ b/Aaru.Generators/PluginRegisterGenerator.cs @@ -12,7 +12,8 @@ namespace Aaru.Generators; [Generator] public sealed class PluginRegisterGenerator : IIncrementalGenerator { - private static readonly Dictionary PluginInterfaces = new() + // your map of simple names → registration methods + static readonly Dictionary PluginInterfaces = new() { ["IArchive"] = "RegisterArchivePlugins", ["IChecksum"] = "RegisterChecksumPlugins", @@ -27,142 +28,146 @@ public sealed class PluginRegisterGenerator : IIncrementalGenerator ["IByteAddressableImage"] = "RegisterByteAddressablePlugins", ["IFluxImage"] = "RegisterFluxImagePlugins", ["IWritableFluxImage"] = "RegisterWritableFluxImagePlugins" + + // …snip… }; #region IIncrementalGenerator Members - public void Initialize(IncrementalGeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext ctx) { - IncrementalValueProvider> pluginClasses = context.SyntaxProvider - .CreateSyntaxProvider(static (node, _) => node is ClassDeclarationSyntax, - static (ctx, _) => GetPluginInfo(ctx)) - .Where(static info => info is not null) - .Collect(); + // 1) pick up every class syntax with a base‐list (so we only inspect ones that *could* have interfaces) + IncrementalValueProvider> syntaxProvider = ctx.SyntaxProvider + .CreateSyntaxProvider((node, ct) => node is ClassDeclarationSyntax cds && cds.BaseList != null, + (ctx, ct) => (ClassDeclarationSyntax)ctx.Node) + .Collect(); // gather them all - context.RegisterSourceOutput(pluginClasses, (ctx, pluginInfos) => GeneratePluginRegister(ctx, pluginInfos!)); + // 2) combine with the full Compilation so we can do symbol lookups + IncrementalValueProvider<(Compilation Left, ImmutableArray Right)> + compilationAndClasses = ctx.CompilationProvider.Combine(syntaxProvider); + + // 3) finally generate source + ctx.RegisterSourceOutput(compilationAndClasses, + (spc, pair) => + { + (Compilation? compilation, ImmutableArray classDecls) = + pair; + + // locate the interface symbols by metadata name once + (string Name, string Method, INamedTypeSymbol? Symbol)[] interfaceSymbols = + PluginInterfaces + .Select(kvp => (Name: kvp.Key, Method: kvp.Value, + Symbol: compilation + .GetTypeByMetadataName($"Aaru.CommonTypes.Interfaces.{kvp.Key}"))) + .Where(x => x.Symbol != null) + .ToArray(); + + // find the one IPluginRegister type as well + INamedTypeSymbol? registerIf = + compilation + .GetTypeByMetadataName("Aaru.CommonTypes.Interfaces.IPluginRegister"); + + // collect info + var plugins = new List(); + + foreach(ClassDeclarationSyntax? classDecl in classDecls.Distinct()) + { + SemanticModel model = compilation.GetSemanticModel(classDecl.SyntaxTree); + + var symbol = + model.GetDeclaredSymbol(classDecl, spc.CancellationToken) as + INamedTypeSymbol; + + if(symbol is null) continue; + + // which interfaces does it *actually* implement (direct + indirect)? + ImmutableArray allIfaces = symbol.AllInterfaces; + + // diagnostics to verify we’re seeing the right interfaces + foreach(INamedTypeSymbol? iface in allIfaces) + { + spc.ReportDiagnostic(Diagnostic.Create(new DiagnosticDescriptor("PLGN001", + "Found interface", + $"Class {symbol.Name} implements {iface.ToDisplayString()}", + "PluginGen", + DiagnosticSeverity.Info, + true), + classDecl.GetLocation())); + } + + var info = new PluginInfo + { + Namespace = symbol.ContainingNamespace.ToDisplayString(), + ClassName = symbol.Name, + IsRegister = + registerIf != null && + allIfaces.Contains(registerIf, SymbolEqualityComparer.Default) + }; + + // pick up every plugin‐interface your map knows about + foreach((string Name, string Method, INamedTypeSymbol? Symbol) in + interfaceSymbols) + { + if(SymbolEqualityComparer.Default.Equals(Symbol, null)) continue; + + if(allIfaces.Contains(Symbol, SymbolEqualityComparer.Default)) + info.Interfaces.Add(Name); + } + + if(info.IsRegister || info.Interfaces.Count > 0) plugins.Add(info); + } + + // nothing to do + if(plugins.Count == 0) return; + + // find the one class that implements IPluginRegister + PluginInfo? regCls = plugins.FirstOrDefault(p => p.IsRegister); + + if(regCls == null) return; + + // build the generated file + var sb = new StringBuilder(); + sb.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + sb.AppendLine("using Aaru.CommonTypes.Interfaces;"); + sb.AppendLine($"namespace {regCls.Namespace};"); + sb.AppendLine($"public sealed partial class {regCls.ClassName} : IPluginRegister"); + sb.AppendLine("{"); + + foreach(KeyValuePair kvp in PluginInterfaces) + { + // grab all classes that implement this interface + IEnumerable implementations = plugins + .Where(pi => + pi.Interfaces + .Contains(kvp.Key)) + .Select(pi => pi.ClassName) + .Distinct(); + + sb.AppendLine($" public void {kvp.Value}(IServiceCollection services)"); + sb.AppendLine(" {"); + + foreach(string? impl in implementations) + sb.AppendLine($" services.AddTransient<{kvp.Key}, {impl}>();"); + + sb.AppendLine(" }"); + } + + sb.AppendLine("}"); + + spc.AddSource("Register.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + }); } #endregion - private static PluginInfo? GetPluginInfo(GeneratorSyntaxContext context) - { - if(context.Node is not ClassDeclarationSyntax classDecl) return null; - - var info = new PluginInfo - { - ClassName = classDecl.Identifier.Text, - Namespace = GetNamespace(classDecl), - IsRegister = ImplementsInterface(classDecl, "IPluginRegister") - }; - - foreach(string? iface in PluginInterfaces.Keys) - { - if(ImplementsInterface(classDecl, iface)) info.Interfaces.Add(iface); - } - - if(info is { IsRegister: false, Interfaces.Count: 0 }) return null; - - return info; - } - - private static bool ImplementsInterface(ClassDeclarationSyntax classDecl, string interfaceName) - { - return classDecl.BaseList?.Types.Any(t => (t.Type as IdentifierNameSyntax)?.Identifier.ValueText == - interfaceName) == - true; - } - - private static string? GetNamespace(SyntaxNode node) => - node.Ancestors().OfType().FirstOrDefault()?.Name.ToString(); - - private static void GeneratePluginRegister(SourceProductionContext context, IReadOnlyList pluginInfos) - { - PluginInfo? registerClass = pluginInfos.FirstOrDefault(p => p.IsRegister); - - if(registerClass is null) return; - - var sb = new StringBuilder(); - - sb.AppendLine(""" - // /*************************************************************************** - // Aaru Data Preservation Suite - // ---------------------------------------------------------------------------- - // - // Filename : Register.g.cs - // Author(s) : Natalia Portillo - // - // --[ Description ] ---------------------------------------------------------- - // - // Registers all plugins in this assembly. - // - // --[ License ] -------------------------------------------------------------- - // - // Permission is hereby granted, free of charge, to any person obtaining a - // copy of this software and associated documentation files (the - // "Software"), to deal in the Software without restriction, including - // without limitation the rights to use, copy, modify, merge, publish, - // distribute, sublicense, and/or sell copies of the Software, and to - // permit persons to whom the Software is furnished to do so, subject to - // the following conditions: - // - // The above copyright notice and this permission notice shall be included - // in all copies or substantial portions of the Software. - // - // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - // TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - // SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - // - // ---------------------------------------------------------------------------- - // Copyright © 2011-2025 Natalia Portillo - // ****************************************************************************/ - """); - - sb.AppendLine("using System;"); - sb.AppendLine("using System.Collections.Generic;"); - sb.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - sb.AppendLine("using Aaru.CommonTypes.Interfaces;"); - sb.AppendLine(); - sb.AppendLine($"namespace {registerClass.Namespace};"); - sb.AppendLine(); - sb.AppendLine($"public sealed partial class {registerClass.ClassName} : IPluginRegister"); - sb.AppendLine("{"); - - foreach(KeyValuePair kvp in PluginInterfaces) - { - string? interfaceName = kvp.Key; - string? methodName = kvp.Value; - - var plugins = pluginInfos.Where(p => p.Interfaces.Contains(interfaceName)) - .Select(p => p.ClassName) - .Distinct() - .ToList(); - - sb.AppendLine($" public void {methodName}(IServiceCollection services)"); - sb.AppendLine(" {"); - - foreach(string? plugin in plugins) - sb.AppendLine($" services.AddTransient<{interfaceName}, {plugin}>();"); - - sb.AppendLine(" }"); - } - - sb.AppendLine("}"); - - context.AddSource("Register.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); - } - #region Nested type: PluginInfo - private sealed class PluginInfo + class PluginInfo { - public string? Namespace { get; set; } - public string ClassName { get; set; } = ""; - public bool IsRegister { get; set; } - public List Interfaces { get; } = []; + public readonly List Interfaces = new(); + public string ClassName = ""; + public bool IsRegister; + public string Namespace = ""; } #endregion