Fix writable images not being recognized.

This commit is contained in:
2025-08-17 01:08:19 +01:00
parent 8279a1524b
commit a7f415c54a
3 changed files with 79 additions and 91 deletions

View File

@@ -1,7 +1,7 @@
<component name="ProjectRunConfigurationManager"> <component name="ProjectRunConfigurationManager">
<configuration default="false" name="Aaru" type="DotNetProject" factoryName=".NET Project"> <configuration default="false" name="Aaru" type="DotNetProject" factoryName=".NET Project">
<option name="EXE_PATH" value="$PROJECT_DIR$/Aaru/bin/Debug/net10.0/aaru" /> <option name="EXE_PATH" value="$PROJECT_DIR$/Aaru/bin/Debug/net10.0/aaru" />
<option name="PROGRAM_PARAMETERS" value="fs --pause --logfile mylog.log ls qcow.qc.lz" /> <option name="PROGRAM_PARAMETERS" value="formats" />
<option name="WORKING_DIRECTORY" value="/mnt/AaruTests/Media image formats/QEMU/QEMU Copy On Write" /> <option name="WORKING_DIRECTORY" value="/mnt/AaruTests/Media image formats/QEMU/QEMU Copy On Write" />
<option name="PASS_PARENT_ENVS" value="1" /> <option name="PASS_PARENT_ENVS" value="1" />
<option name="USE_EXTERNAL_CONSOLE" value="0" /> <option name="USE_EXTERNAL_CONSOLE" value="0" />

View File

@@ -64,13 +64,13 @@ public class PluginRegister
} }
/// <summary>List of writable media image plugins</summary> /// <summary>List of writable media image plugins</summary>
public SortedDictionary<string, IBaseWritableImage> WritableImages public SortedDictionary<string, IWritableImage> WritableImages
{ {
get get
{ {
SortedDictionary<string, IBaseWritableImage> mediaImages = new(); SortedDictionary<string, IWritableImage> mediaImages = new();
foreach(IBaseWritableImage plugin in _serviceProvider.GetServices<IBaseWritableImage>()) foreach(IWritableImage plugin in _serviceProvider.GetServices<IWritableImage>())
mediaImages[plugin.Name.ToLower()] = plugin; mediaImages[plugin.Name.ToLower()] = plugin;
return mediaImages; return mediaImages;

View File

@@ -12,148 +12,136 @@ namespace Aaru.Generators;
[Generator] [Generator]
public sealed class PluginRegisterGenerator : IIncrementalGenerator public sealed class PluginRegisterGenerator : IIncrementalGenerator
{ {
// your map of simple names → registration methods // name → (registration method, directOnly)
static readonly Dictionary<string, string> PluginInterfaces = new() static readonly (string Name, string Method, bool DirectOnly)[] PluginMap = new[]
{ {
["IArchive"] = "RegisterArchivePlugins", ("IArchive", "RegisterArchivePlugins", true), ("IChecksum", "RegisterChecksumPlugins", true),
["IChecksum"] = "RegisterChecksumPlugins", ("IFilesystem", "RegisterFilesystemPlugins", true), ("IFilter", "RegisterFilterPlugins", true),
["IFilesystem"] = "RegisterFilesystemPlugins", ("IFloppyImage", "RegisterFloppyImagePlugins", true),
["IFilter"] = "RegisterFilterPlugins", ("IMediaImage", "RegisterMediaImagePlugins", true), // direct only
["IFloppyImage"] = "RegisterFloppyImagePlugins", ("IPartition", "RegisterPartitionPlugins", true),
["IMediaImage"] = "RegisterMediaImagePlugins", ("IReadOnlyFilesystem", "RegisterReadOnlyFilesystemPlugins", true),
["IPartition"] = "RegisterPartitionPlugins", ("IWritableFloppyImage", "RegisterWritableFloppyImagePlugins", true),
["IReadOnlyFilesystem"] = "RegisterReadOnlyFilesystemPlugins", ("IWritableImage", "RegisterWritableImagePlugins", false), // inherited OK
["IWritableFloppyImage"] = "RegisterWritableFloppyImagePlugins", ("IByteAddressableImage", "RegisterByteAddressablePlugins", false),
["IWritableImage"] = "RegisterWritableImagePlugins", ("IFluxImage", "RegisterFluxImagePlugins", true),
["IByteAddressableImage"] = "RegisterByteAddressablePlugins", ("IWritableFluxImage", "RegisterWritableFluxImagePlugins", false)
["IFluxImage"] = "RegisterFluxImagePlugins",
["IWritableFluxImage"] = "RegisterWritableFluxImagePlugins"
// …snip // …add more as needed
}; };
#region IIncrementalGenerator Members #region IIncrementalGenerator Members
public void Initialize(IncrementalGeneratorInitializationContext ctx) public void Initialize(IncrementalGeneratorInitializationContext ctx)
{ {
// 1) pick up every class syntax with a baselist (so we only inspect ones that *could* have interfaces) // 1) grab every ClassDeclarationSyntax that has a base list
IncrementalValueProvider<ImmutableArray<ClassDeclarationSyntax>> syntaxProvider = ctx.SyntaxProvider IncrementalValueProvider<ImmutableArray<ClassDeclarationSyntax>> classSyntaxes = ctx.SyntaxProvider
.CreateSyntaxProvider((node, ct) => node is ClassDeclarationSyntax cds && cds.BaseList != null, .CreateSyntaxProvider((node, ct) => node is ClassDeclarationSyntax cds && cds.BaseList != null,
(ctx, ct) => (ClassDeclarationSyntax)ctx.Node) (ctx, ct) => (ClassDeclarationSyntax)ctx.Node)
.Collect(); // gather them all .Collect();
// 2) combine with the full Compilation so we can do symbol lookups // 2) combine with the compilation for symbol lookups
IncrementalValueProvider<(Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right)> IncrementalValueProvider<(Compilation Left, ImmutableArray<ClassDeclarationSyntax> Right)>
compilationAndClasses = ctx.CompilationProvider.Combine(syntaxProvider); compilationAndClasses = ctx.CompilationProvider.Combine(classSyntaxes);
// 3) finally generate source // 3) register our source output
ctx.RegisterSourceOutput(compilationAndClasses, ctx.RegisterSourceOutput(compilationAndClasses,
(spc, pair) => (spc, source) =>
{ {
(Compilation? compilation, ImmutableArray<ClassDeclarationSyntax> classDecls) = (Compilation? compilation, ImmutableArray<ClassDeclarationSyntax> classDecls) =
pair; source;
// locate the interface symbols by metadata name once if(compilation is null) return;
(string Name, string Method, INamedTypeSymbol? Symbol)[] interfaceSymbols =
PluginInterfaces
.Select(kvp => (Name: kvp.Key, Method: kvp.Value,
Symbol: compilation
.GetTypeByMetadataName($"Aaru.CommonTypes.Interfaces.{kvp.Key}")))
.Where(x => x.Symbol != null)
.ToArray();
// find the one IPluginRegister type as well // load all plugininterface symbols
INamedTypeSymbol? registerIf = (string Name, string Method, bool DirectOnly, INamedTypeSymbol? Symbol)[]
ifaceDefs = PluginMap.Select(x =>
{
INamedTypeSymbol? sym =
compilation
.GetTypeByMetadataName($"Aaru.CommonTypes.Interfaces.{x.Name}");
return (x.Name, x.Method, x.DirectOnly, Symbol: sym);
})
.Where(x => x.Symbol is not null)
.ToArray();
// load IPluginRegister
INamedTypeSymbol? registerSym =
compilation compilation
.GetTypeByMetadataName("Aaru.CommonTypes.Interfaces.IPluginRegister"); .GetTypeByMetadataName("Aaru.CommonTypes.Interfaces.IPluginRegister");
// collect info
var plugins = new List<PluginInfo>(); var plugins = new List<PluginInfo>();
foreach(ClassDeclarationSyntax? classDecl in classDecls.Distinct()) foreach(ClassDeclarationSyntax? decl in classDecls.Distinct())
{ {
SemanticModel model = compilation.GetSemanticModel(classDecl.SyntaxTree); SemanticModel model = compilation.GetSemanticModel(decl.SyntaxTree);
var symbol = var cls = model.GetDeclaredSymbol(decl, spc.CancellationToken)
model.GetDeclaredSymbol(classDecl, spc.CancellationToken) as as INamedTypeSymbol;
INamedTypeSymbol;
if(symbol is null) continue; if(cls is null) continue;
// which interfaces does it *actually* implement (direct + indirect)? // direct vs. all (transitive) interfaces
ImmutableArray<INamedTypeSymbol> allIfaces = symbol.AllInterfaces; ImmutableArray<INamedTypeSymbol> directIfaces = cls.Interfaces;
ImmutableArray<INamedTypeSymbol> allIfaces = cls.AllInterfaces;
// diagnostics to verify were seeing the right interfaces
foreach(INamedTypeSymbol? iface in allIfaces)
{
spc.ReportDiagnostic(Diagnostic.Create(new DiagnosticDescriptor("PLGN001",
"Found interface",
$"Class {symbol.Name} implements {iface.ToDisplayString()}",
"PluginGen",
DiagnosticSeverity.Info,
true),
classDecl.GetLocation()));
}
var info = new PluginInfo var info = new PluginInfo
{ {
Namespace = symbol.ContainingNamespace.ToDisplayString(), Namespace = cls.ContainingNamespace.ToDisplayString(),
ClassName = symbol.Name, ClassName = cls.Name,
IsRegister = IsRegister =
registerIf != null && registerSym != null &&
allIfaces.Contains(registerIf, SymbolEqualityComparer.Default) allIfaces.Contains(registerSym, SymbolEqualityComparer.Default)
}; };
// pick up every plugininterface your map knows about // for each plugin interface, choose direct or inherited match
foreach((string Name, string Method, INamedTypeSymbol? Symbol) in foreach((string Name, string Method, bool DirectOnly,
interfaceSymbols) INamedTypeSymbol? Symbol) in ifaceDefs)
{ {
if(SymbolEqualityComparer.Default.Equals(Symbol, null)) continue; bool matches = DirectOnly
? directIfaces.Contains(Symbol!,
SymbolEqualityComparer.Default)
: allIfaces.Contains(Symbol!,
SymbolEqualityComparer.Default);
if(allIfaces.Contains(Symbol, SymbolEqualityComparer.Default)) if(matches) info.Interfaces.Add((Name, Method));
info.Interfaces.Add(Name);
} }
if(info.IsRegister || info.Interfaces.Count > 0) plugins.Add(info); if(info.IsRegister || info.Interfaces.Count > 0) plugins.Add(info);
} }
// nothing to do
if(plugins.Count == 0) return; if(plugins.Count == 0) return;
// find the one class that implements IPluginRegister // find the one class that implements IPluginRegister
PluginInfo? regCls = plugins.FirstOrDefault(p => p.IsRegister); PluginInfo? registrar = plugins.FirstOrDefault(p => p.IsRegister);
if(regCls == null) return; if(registrar is null) return;
// build the generated file // build the generated file
var sb = new StringBuilder(); var sb = new StringBuilder();
sb.AppendLine("using Microsoft.Extensions.DependencyInjection;"); sb.AppendLine("using Microsoft.Extensions.DependencyInjection;");
sb.AppendLine("using Aaru.CommonTypes.Interfaces;"); sb.AppendLine("using Aaru.CommonTypes.Interfaces;");
sb.AppendLine($"namespace {regCls.Namespace};"); sb.AppendLine($"namespace {registrar.Namespace};");
sb.AppendLine($"public sealed partial class {regCls.ClassName} : IPluginRegister"); sb.AppendLine($"public sealed partial class {registrar.ClassName} : IPluginRegister");
sb.AppendLine("{"); sb.AppendLine("{");
foreach(KeyValuePair<string, string> kvp in PluginInterfaces) // emit one registration method per plugininterface
foreach((string Name, string Method, bool _) in PluginMap)
{ {
// grab all classes that implement this interface sb.AppendLine($" public void {Method}(IServiceCollection services)");
IEnumerable<string> implementations = plugins
.Where(pi =>
pi.Interfaces
.Contains(kvp.Key))
.Select(pi => pi.ClassName)
.Distinct();
sb.AppendLine($" public void {kvp.Value}(IServiceCollection services)");
sb.AppendLine(" {"); sb.AppendLine(" {");
foreach(string? impl in implementations) foreach(string? impl in plugins
sb.AppendLine($" services.AddTransient<{kvp.Key}, {impl}>();"); .Where(pi => pi.Interfaces.Any(i => i.Name == Name))
.Select(pi => pi.ClassName)
.Distinct())
sb.AppendLine($" services.AddTransient<{Name}, {impl}>();");
sb.AppendLine(" }"); sb.AppendLine(" }");
} }
sb.AppendLine("}"); sb.AppendLine("}");
spc.AddSource("Register.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); spc.AddSource("Register.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8));
}); });
} }
@@ -164,10 +152,10 @@ public sealed class PluginRegisterGenerator : IIncrementalGenerator
class PluginInfo class PluginInfo
{ {
public readonly List<string> Interfaces = new(); public string ClassName = "";
public string ClassName = ""; public readonly List<(string Name, string Method)> Interfaces = new();
public bool IsRegister; public bool IsRegister;
public string Namespace = ""; public string Namespace = "";
} }
#endregion #endregion