diff options
Diffstat (limited to 'src/StardewModdingAPI/Framework/AssemblyLoader.cs')
-rw-r--r-- | src/StardewModdingAPI/Framework/AssemblyLoader.cs | 134 |
1 files changed, 91 insertions, 43 deletions
diff --git a/src/StardewModdingAPI/Framework/AssemblyLoader.cs b/src/StardewModdingAPI/Framework/AssemblyLoader.cs index 123211b9..f6fe89f5 100644 --- a/src/StardewModdingAPI/Framework/AssemblyLoader.cs +++ b/src/StardewModdingAPI/Framework/AssemblyLoader.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Reflection; using Mono.Cecil; using Mono.Cecil.Cil; -using Mono.Cecil.Rocks; using StardewModdingAPI.AssemblyRewriters; namespace StardewModdingAPI.Framework @@ -55,14 +54,17 @@ namespace StardewModdingAPI.Framework /// <summary>Preprocess and load an assembly.</summary> /// <param name="assemblyPath">The assembly file path.</param> + /// <param name="assumeCompatible">Assume the mod is compatible, even if incompatible code is detected.</param> /// <returns>Returns the rewrite metadata for the preprocessed assembly.</returns> - public Assembly Load(string assemblyPath) + /// <exception cref="IncompatibleInstructionException">An incompatible CIL instruction was found while rewriting the assembly.</exception> + public Assembly Load(string assemblyPath, bool assumeCompatible) { // get referenced local assemblies AssemblyParseResult[] assemblies; { AssemblyDefinitionResolver resolver = new AssemblyDefinitionResolver(); - assemblies = this.GetReferencedLocalAssemblies(new FileInfo(assemblyPath), new HashSet<string>(), resolver).ToArray(); + HashSet<string> visitedAssemblyNames = new HashSet<string>(AppDomain.CurrentDomain.GetAssemblies().Select(p => p.GetName().Name)); // don't try loading assemblies that are already loaded + assemblies = this.GetReferencedLocalAssemblies(new FileInfo(assemblyPath), visitedAssemblyNames, resolver).ToArray(); if (!assemblies.Any()) throw new InvalidOperationException($"Could not load '{assemblyPath}' because it doesn't exist."); resolver.Add(assemblies.Select(p => p.Definition).ToArray()); @@ -72,10 +74,10 @@ namespace StardewModdingAPI.Framework Assembly lastAssembly = null; foreach (AssemblyParseResult assembly in assemblies) { - this.Monitor.Log($"Loading {assembly.File.Name}...", LogLevel.Trace); - bool changed = this.RewriteAssembly(assembly.Definition); + bool changed = this.RewriteAssembly(assembly.Definition, assumeCompatible); if (changed) { + this.Monitor.Log($"Loading {assembly.File.Name} (rewritten in memory)...", LogLevel.Trace); using (MemoryStream outStream = new MemoryStream()) { assembly.Definition.Write(outStream); @@ -84,7 +86,10 @@ namespace StardewModdingAPI.Framework } } else + { + this.Monitor.Log($"Loading {assembly.File.Name}...", LogLevel.Trace); lastAssembly = Assembly.UnsafeLoadFrom(assembly.File.FullName); + } } // last assembly loaded is the root @@ -116,18 +121,16 @@ namespace StardewModdingAPI.Framework ****/ /// <summary>Get a list of referenced local assemblies starting from the mod assembly, ordered from leaf to root.</summary> /// <param name="file">The assembly file to load.</param> - /// <param name="visitedAssemblyPaths">The assembly paths that should be skipped.</param> + /// <param name="visitedAssemblyNames">The assembly names that should be skipped.</param> + /// <param name="assemblyResolver">A resolver which resolves references to known assemblies.</param> /// <returns>Returns the rewrite metadata for the preprocessed assembly.</returns> - private IEnumerable<AssemblyParseResult> GetReferencedLocalAssemblies(FileInfo file, HashSet<string> visitedAssemblyPaths, IAssemblyResolver assemblyResolver) + private IEnumerable<AssemblyParseResult> GetReferencedLocalAssemblies(FileInfo file, HashSet<string> visitedAssemblyNames, IAssemblyResolver assemblyResolver) { // validate if (file.Directory == null) throw new InvalidOperationException($"Could not get directory from file path '{file.FullName}'."); - if (visitedAssemblyPaths.Contains(file.FullName)) - yield break; // already visited if (!file.Exists) yield break; // not a local assembly - visitedAssemblyPaths.Add(file.FullName); // read assembly byte[] assemblyBytes = File.ReadAllBytes(file.FullName); @@ -135,11 +138,16 @@ namespace StardewModdingAPI.Framework using (Stream readStream = new MemoryStream(assemblyBytes)) assembly = AssemblyDefinition.ReadAssembly(readStream, new ReaderParameters(ReadingMode.Deferred) { AssemblyResolver = assemblyResolver }); + // skip if already visited + if (visitedAssemblyNames.Contains(assembly.Name.Name)) + yield break; + visitedAssemblyNames.Add(assembly.Name.Name); + // yield referenced assemblies foreach (AssemblyNameReference dependency in assembly.MainModule.AssemblyReferences) { FileInfo dependencyFile = new FileInfo(Path.Combine(file.Directory.FullName, $"{dependency.Name}.dll")); - foreach (AssemblyParseResult result in this.GetReferencedLocalAssemblies(dependencyFile, visitedAssemblyPaths, assemblyResolver)) + foreach (AssemblyParseResult result in this.GetReferencedLocalAssemblies(dependencyFile, visitedAssemblyNames, assemblyResolver)) yield return result; } @@ -152,62 +160,88 @@ namespace StardewModdingAPI.Framework ****/ /// <summary>Rewrite the types referenced by an assembly.</summary> /// <param name="assembly">The assembly to rewrite.</param> + /// <param name="assumeCompatible">Assume the mod is compatible, even if incompatible code is detected.</param> /// <returns>Returns whether the assembly was modified.</returns> - private bool RewriteAssembly(AssemblyDefinition assembly) + /// <exception cref="IncompatibleInstructionException">An incompatible CIL instruction was found while rewriting the assembly.</exception> + private bool RewriteAssembly(AssemblyDefinition assembly, bool assumeCompatible) { - ModuleDefinition module = assembly.Modules.Single(); // technically an assembly can have multiple modules, but none of the build tools (including MSBuild) support it; simplify by assuming one module + ModuleDefinition module = assembly.MainModule; + HashSet<string> loggedMessages = new HashSet<string>(); - // remove old assembly references - bool shouldRewrite = false; + // swap assembly references if needed (e.g. XNA => MonoGame) + bool platformChanged = false; for (int i = 0; i < module.AssemblyReferences.Count; i++) { + // remove old assembly reference if (this.AssemblyMap.RemoveNames.Any(name => module.AssemblyReferences[i].Name == name)) { - shouldRewrite = true; + this.LogOnce(this.Monitor, loggedMessages, $"Rewriting {assembly.Name.Name} for OS..."); + platformChanged = true; module.AssemblyReferences.RemoveAt(i); i--; } } - if (!shouldRewrite) - return false; - - // add target assembly references - foreach (AssemblyNameReference target in this.AssemblyMap.TargetReferences.Values) - module.AssemblyReferences.Add(target); + if (platformChanged) + { + // add target assembly references + foreach (AssemblyNameReference target in this.AssemblyMap.TargetReferences.Values) + module.AssemblyReferences.Add(target); - // rewrite type scopes to use target assemblies - IEnumerable<TypeReference> typeReferences = module.GetTypeReferences().OrderBy(p => p.FullName); - foreach (TypeReference type in typeReferences) - this.ChangeTypeScope(type); + // rewrite type scopes to use target assemblies + IEnumerable<TypeReference> typeReferences = module.GetTypeReferences().OrderBy(p => p.FullName); + foreach (TypeReference type in typeReferences) + this.ChangeTypeScope(type); + } - // rewrite incompatible methods - IMethodRewriter[] methodRewriters = Constants.GetMethodRewriters().ToArray(); + // find (and optionally rewrite) incompatible instructions + bool anyRewritten = false; + IInstructionRewriter[] rewriters = Constants.GetRewriters().ToArray(); foreach (MethodDefinition method in this.GetMethods(module)) { - // skip methods with no rewritable method - bool hasMethodToRewrite = method.Body.Instructions.Any(op => (op.OpCode == OpCodes.Call || op.OpCode == OpCodes.Callvirt) && methodRewriters.Any(rewriter => rewriter.ShouldRewrite((MethodReference)op.Operand))); - if (!hasMethodToRewrite) - continue; + // check method definition + foreach (IInstructionRewriter rewriter in rewriters) + { + try + { + if (rewriter.Rewrite(module, method, this.AssemblyMap, platformChanged)) + { + this.LogOnce(this.Monitor, loggedMessages, $"Rewrote {assembly.Name.Name} to fix {rewriter.NounPhrase}..."); + anyRewritten = true; + } + } + catch (IncompatibleInstructionException) + { + if (!assumeCompatible) + throw new IncompatibleInstructionException(rewriter.NounPhrase, $"Found an incompatible CIL instruction ({rewriter.NounPhrase}) while loading assembly {assembly.Name.Name}."); + this.LogOnce(this.Monitor, loggedMessages, $"Found an incompatible CIL instruction ({rewriter.NounPhrase}) while loading assembly {assembly.Name.Name}, but SMAPI is configured to allow it anyway. The mod may crash or behave unexpectedly.", LogLevel.Warn); + } + } - // rewrite method references - method.Body.SimplifyMacros(); + // check CIL instructions ILProcessor cil = method.Body.GetILProcessor(); - Instruction[] instructions = cil.Body.Instructions.ToArray(); - foreach (Instruction op in instructions) + foreach (Instruction instruction in cil.Body.Instructions.ToArray()) { - if (op.OpCode == OpCodes.Call || op.OpCode == OpCodes.Callvirt) + foreach (IInstructionRewriter rewriter in rewriters) { - IMethodRewriter rewriter = methodRewriters.FirstOrDefault(p => p.ShouldRewrite((MethodReference)op.Operand)); - if (rewriter != null) + try { - MethodReference methodRef = (MethodReference)op.Operand; - rewriter.Rewrite(module, cil, op, methodRef, this.AssemblyMap); + if (rewriter.Rewrite(module, cil, instruction, this.AssemblyMap, platformChanged)) + { + this.LogOnce(this.Monitor, loggedMessages, $"Rewrote {assembly.Name.Name} to fix {rewriter.NounPhrase}..."); + anyRewritten = true; + } + } + catch (IncompatibleInstructionException) + { + if (!assumeCompatible) + throw new IncompatibleInstructionException(rewriter.NounPhrase, $"Found an incompatible CIL instruction ({rewriter.NounPhrase}) while loading assembly {assembly.Name.Name}."); + this.LogOnce(this.Monitor, loggedMessages, $"Found an incompatible CIL instruction ({rewriter.NounPhrase}) while loading assembly {assembly.Name.Name}, but SMAPI is configured to allow it anyway. The mod may crash or behave unexpectedly.", LogLevel.Warn); } } } - method.Body.OptimizeMacros(); } - return true; + + return platformChanged || anyRewritten; } /// <summary>Get the correct reference to use for compatibility with the current platform.</summary> @@ -240,5 +274,19 @@ namespace StardewModdingAPI.Framework select method ); } + + /// <summary>Log a message for the player or developer the first time it occurs.</summary> + /// <param name="monitor">The monitor through which to log the message.</param> + /// <param name="hash">The hash of logged messages.</param> + /// <param name="message">The message to log.</param> + /// <param name="level">The log severity level.</param> + private void LogOnce(IMonitor monitor, HashSet<string> hash, string message, LogLevel level = LogLevel.Trace) + { + if (!hash.Contains(message)) + { + this.Monitor.Log(message, level); + hash.Add(message); + } + } } } |