From 2154b6de95bc97e3412b0800d3e2809bd2a1e544 Mon Sep 17 00:00:00 2001 From: Jesse Plamondon-Willard Date: Sat, 26 Nov 2016 16:14:10 -0500 Subject: use simpler, non-broken approach for rewriting mod type references (#166) --- .../AssemblyRewriting/AssemblyTypeRewriter.cs | 255 ++------------------- 1 file changed, 24 insertions(+), 231 deletions(-) (limited to 'src') diff --git a/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs index 93003a64..7a339266 100644 --- a/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs +++ b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs @@ -2,8 +2,6 @@ using System.Linq; using System.Reflection; using Mono.Cecil; -using Mono.Cecil.Cil; -using CallSite = Mono.Cecil.CallSite; namespace StardewModdingAPI.Framework.AssemblyRewriting { @@ -25,9 +23,6 @@ namespace StardewModdingAPI.Framework.AssemblyRewriting /// An assembly => reference cache. private readonly IDictionary AssemblyNameReferences; - /// An assembly => module cache. - private readonly IDictionary AssemblyModules; - /********* ** Public methods @@ -43,20 +38,22 @@ namespace StardewModdingAPI.Framework.AssemblyRewriting // cache assembly metadata this.AssemblyNameReferences = targetAssemblies.ToDictionary(assembly => assembly, assembly => AssemblyNameReference.Parse(assembly.FullName)); - this.AssemblyModules = targetAssemblies.ToDictionary(assembly => assembly, assembly => ModuleDefinition.ReadModule(assembly.Modules.Single().FullyQualifiedName)); // technically an assembly can contain multiple modules, but none of the build tools (including MSBuild itself) support it // collect type => assembly lookup this.TypeAssemblies = new Dictionary(); foreach (Assembly assembly in targetAssemblies) { - ModuleDefinition module = this.AssemblyModules[assembly]; - foreach (TypeDefinition type in module.GetTypes()) + foreach (Module assemblyModule in assembly.Modules) { - if (!type.IsPublic) - continue; // no need to rewrite - if (type.Namespace.Contains("<")) - continue; // ignore C++ stuff - this.TypeAssemblies[type.FullName] = assembly; + ModuleDefinition module = ModuleDefinition.ReadModule(assemblyModule.FullyQualifiedName); + foreach (TypeDefinition type in module.GetTypes()) + { + if (!type.IsPublic) + continue; // no need to rewrite + if (type.Namespace.Contains("<")) + continue; // ignore assembly metadata + this.TypeAssemblies[type.FullName] = assembly; + } } } } @@ -67,36 +64,25 @@ namespace StardewModdingAPI.Framework.AssemblyRewriting { foreach (ModuleDefinition module in assembly.Modules) { - // rewrite assembly references - bool shouldRewriteTypes = false; + // remove old assembly references for (int i = 0; i < module.AssemblyReferences.Count; i++) { bool shouldRemove = this.RemoveAssemblyNames.Any(name => module.AssemblyReferences[i].Name == name) || this.TargetAssemblies.Any(a => module.AssemblyReferences[i].Name == a.GetName().Name); if (shouldRemove) { - shouldRewriteTypes = true; module.AssemblyReferences.RemoveAt(i); i--; } } + + // add target assembly references foreach (AssemblyNameReference target in this.AssemblyNameReferences.Values) - { module.AssemblyReferences.Add(target); - shouldRewriteTypes = true; - } - - // rewrite references - if (shouldRewriteTypes) - { - // rewrite types - foreach (TypeDefinition type in module.GetTypes()) - this.RewriteReferences(type, module); - // rewrite type references - TypeReference[] refs = (TypeReference[])module.GetTypeReferences(); - for (int i = 0; i < refs.Length; ++i) - refs[i] = this.GetTypeReference(refs[i], module); - } + // rewrite type scopes to use target assemblies + TypeReference[] refs = (TypeReference[])module.GetTypeReferences(); + foreach (TypeReference type in refs) + this.ChangeTypeScope(type); } } @@ -104,215 +90,22 @@ namespace StardewModdingAPI.Framework.AssemblyRewriting /********* ** Private methods *********/ - /// Rewrite the references for a code object. - /// The type to rewrite. - /// The module being rewritten. - private void RewriteReferences(TypeDefinition type, ModuleDefinition module) - { - // rewrite base type - type.BaseType = this.GetTypeReference(type.BaseType, module); - - // rewrite interfaces - for (int i = 0; i < type.Interfaces.Count; i++) - type.Interfaces[i] = this.GetTypeReference(type.Interfaces[i], module); - - // rewrite events - foreach (EventDefinition @event in type.Events) - { - this.RewriteReferences(@event.AddMethod, module); - this.RewriteReferences(@event.RemoveMethod, module); - this.RewriteReferences(@event.InvokeMethod, module); - } - - // rewrite properties - foreach (PropertyDefinition property in type.Properties) - { - this.RewriteReferences(property.GetMethod, module); - this.RewriteReferences(property.SetMethod, module); - } - - // rewrite methods - foreach (MethodDefinition method in type.Methods) - this.RewriteReferences(method, module); - - // rewrite fields - foreach (FieldDefinition field in type.Fields) - this.RewriteReferences(field, module); - - // rewrite nested types - foreach (TypeDefinition nestedType in type.NestedTypes) - this.RewriteReferences(nestedType, module); - - // rewrite generic parameters - foreach (GenericParameter parameter in type.GenericParameters) - this.RewriteReferences(parameter, module); - - module.Import(type); - } - - /// Rewrite the references for a code object. - /// The method to rewrite. - /// The module being rewritten. - private void RewriteReferences(MethodReference method, ModuleDefinition module) - { - // parameter types - if (method.HasParameters) - { - foreach (ParameterDefinition parameter in method.Parameters) - parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); - } - - // return type - method.MethodReturnType.ReturnType = this.GetTypeReference(method.MethodReturnType.ReturnType, module); - - module.Import(method); - } - - /// Rewrite the references for a code object. - /// The method to rewrite. - /// The module being rewritten. - private void RewriteReferences(MethodDefinition method, ModuleDefinition module) - { - if (method == null) - return; - - this.RewriteReferences((MethodReference)method, module); - - // overrides - foreach (MethodReference @override in method.Overrides) - this.RewriteReferences(@override, module); - - // body - if (method.HasBody) - { - // this - if (method.Body.ThisParameter != null) - method.Body.ThisParameter.ParameterType = this.GetTypeReference(method.Body.ThisParameter.ParameterType, module); - - // variables - if (method.Body.HasVariables) - { - foreach (VariableDefinition variable in method.Body.Variables) - variable.VariableType = this.GetTypeReference(variable.VariableType, module); - } - - // instructions - foreach (Instruction instruction in method.Body.Instructions) - { - object operand = instruction.Operand; - - // type - { - TypeReference type = operand as TypeReference; - if (type != null) - { - instruction.Operand = this.GetTypeReference(type, module); - continue; - } - } - - // method - { - MethodReference methodRef = operand as MethodReference; - if (methodRef != null) - { - this.RewriteReferences(methodRef, module); - continue; - } - } - - // field - { - FieldReference field = operand as FieldReference; - if (field != null) - { - this.RewriteReferences(field, module); - continue; - } - } - - // variable - { - VariableDefinition variable = operand as VariableDefinition; - if (variable != null) - { - variable.VariableType = this.GetTypeReference(variable.VariableType, module); - continue; - } - } - - // parameter - { - ParameterDefinition parameter = operand as ParameterDefinition; - if (parameter != null) - { - parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); - continue; - } - } - - // call site - { - CallSite call = operand as CallSite; - if (call != null) - { - foreach (ParameterDefinition parameter in call.Parameters) - parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module); - call.ReturnType = this.GetTypeReference(call.ReturnType, module); - } - } - } - } - - module.Import(method); - } - - /// Rewrite the references for a code object. - /// The generic parameter to rewrite. - /// The module being rewritten. - private void RewriteReferences(GenericParameter parameter, ModuleDefinition module) - { - // constraints - for (int i = 0; i < parameter.Constraints.Count; i++) - parameter.Constraints[i] = this.GetTypeReference(parameter.Constraints[i], module); - - // generic parameters - foreach (GenericParameter genericParam in parameter.GenericParameters) - this.RewriteReferences(genericParam, module); - } - - /// Rewrite the references for a code object. - /// The field to rewrite. - /// The module being rewritten. - private void RewriteReferences(FieldReference field, ModuleDefinition module) - { - field.DeclaringType = this.GetTypeReference(field.DeclaringType, module); - field.FieldType = this.GetTypeReference(field.FieldType, module); - module.Import(field); - } - /// Get the correct reference to use for compatibility with the current platform. /// The type reference to rewrite. - /// The module being rewritten. - private TypeReference GetTypeReference(TypeReference type, ModuleDefinition module) + private void ChangeTypeScope(TypeReference type) { // check skip conditions - if (type == null) - return null; - if (type.FullName.StartsWith("System.")) - return type; + if (type == null || type.FullName.StartsWith("System.")) + return; // get assembly Assembly assembly; if (!this.TypeAssemblies.TryGetValue(type.FullName, out assembly)) - return type; - - // replace type - AssemblyNameReference newAssembly = this.AssemblyNameReferences[assembly]; - ModuleDefinition newModule = this.AssemblyModules[assembly]; - type = new TypeReference(type.Namespace, type.Name, newModule, newAssembly); + return; - return module.Import(type); + // replace scope + AssemblyNameReference assemblyRef = this.AssemblyNameReferences[assembly]; + type.Scope = assemblyRef; } } } -- cgit