summaryrefslogtreecommitdiff
path: root/src/StardewModdingAPI/Framework/AssemblyRewriting
diff options
context:
space:
mode:
authorJesse Plamondon-Willard <github@jplamondonw.com>2016-11-26 16:12:21 -0500
committerJesse Plamondon-Willard <github@jplamondonw.com>2016-11-26 16:12:21 -0500
commitb06aed66c47d093585600ca0f7ee1e247507e6b8 (patch)
tree8459efc0986e4172b7696e5e9e5b717df3120404 /src/StardewModdingAPI/Framework/AssemblyRewriting
parent4df199985558caee3275fdd65762bdc3e1d040c6 (diff)
downloadSMAPI-b06aed66c47d093585600ca0f7ee1e247507e6b8.tar.gz
SMAPI-b06aed66c47d093585600ca0f7ee1e247507e6b8.tar.bz2
SMAPI-b06aed66c47d093585600ca0f7ee1e247507e6b8.zip
rewrite type references in mod assemblies to match target platform (#166)
Diffstat (limited to 'src/StardewModdingAPI/Framework/AssemblyRewriting')
-rw-r--r--src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs318
1 files changed, 318 insertions, 0 deletions
diff --git a/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs
new file mode 100644
index 00000000..93003a64
--- /dev/null
+++ b/src/StardewModdingAPI/Framework/AssemblyRewriting/AssemblyTypeRewriter.cs
@@ -0,0 +1,318 @@
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using Mono.Cecil;
+using Mono.Cecil.Cil;
+using CallSite = Mono.Cecil.CallSite;
+
+namespace StardewModdingAPI.Framework.AssemblyRewriting
+{
+ /// <summary>Rewrites type references.</summary>
+ internal class AssemblyTypeRewriter
+ {
+ /*********
+ ** Properties
+ *********/
+ /// <summary>The assemblies to target. Equivalent types will be rewritten to use these assemblies.</summary>
+ private readonly Assembly[] TargetAssemblies;
+
+ /// <summary>>The short assembly names to remove as assembly reference, and replace with the <see cref="TargetAssemblies"/>.</summary>
+ private readonly string[] RemoveAssemblyNames;
+
+ /// <summary>A type => assembly lookup for types which should be rewritten.</summary>
+ private readonly IDictionary<string, Assembly> TypeAssemblies;
+
+ /// <summary>An assembly => reference cache.</summary>
+ private readonly IDictionary<Assembly, AssemblyNameReference> AssemblyNameReferences;
+
+ /// <summary>An assembly => module cache.</summary>
+ private readonly IDictionary<Assembly, ModuleDefinition> AssemblyModules;
+
+
+ /*********
+ ** Public methods
+ *********/
+ /// <summary>Construct an instance.</summary>
+ /// <param name="targetAssemblies">The assembly filenames to target. Equivalent types will be rewritten to use these assemblies.</param>
+ /// <param name="removeAssemblyNames">The short assembly names to remove as assembly reference, and replace with the <paramref name="targetAssemblies"/>.</param>
+ public AssemblyTypeRewriter(Assembly[] targetAssemblies, string[] removeAssemblyNames)
+ {
+ // save config
+ this.TargetAssemblies = targetAssemblies;
+ this.RemoveAssemblyNames = removeAssemblyNames;
+
+ // cache assembly metadata
+ this.AssemblyNameReferences = targetAssemblies.ToDictionary(assembly => assembly, assembly => AssemblyNameReference.Parse(assembly.FullName));
+ this.AssemblyModules = targetAssemblies.ToDictionary(assembly => assembly, assembly => ModuleDefinition.ReadModule(assembly.Modules.Single().FullyQualifiedName)); // technically an assembly can contain multiple modules, but none of the build tools (including MSBuild itself) support it
+
+ // collect type => assembly lookup
+ this.TypeAssemblies = new Dictionary<string, Assembly>();
+ foreach (Assembly assembly in targetAssemblies)
+ {
+ ModuleDefinition module = this.AssemblyModules[assembly];
+ foreach (TypeDefinition type in module.GetTypes())
+ {
+ if (!type.IsPublic)
+ continue; // no need to rewrite
+ if (type.Namespace.Contains("<"))
+ continue; // ignore C++ stuff
+ this.TypeAssemblies[type.FullName] = assembly;
+ }
+ }
+ }
+
+ /// <summary>Rewrite the types referenced by an assembly.</summary>
+ /// <param name="assembly">The assembly to rewrite.</param>
+ public void RewriteAssembly(AssemblyDefinition assembly)
+ {
+ foreach (ModuleDefinition module in assembly.Modules)
+ {
+ // rewrite assembly references
+ bool shouldRewriteTypes = false;
+ for (int i = 0; i < module.AssemblyReferences.Count; i++)
+ {
+ bool shouldRemove = this.RemoveAssemblyNames.Any(name => module.AssemblyReferences[i].Name == name) || this.TargetAssemblies.Any(a => module.AssemblyReferences[i].Name == a.GetName().Name);
+ if (shouldRemove)
+ {
+ shouldRewriteTypes = true;
+ module.AssemblyReferences.RemoveAt(i);
+ i--;
+ }
+ }
+ foreach (AssemblyNameReference target in this.AssemblyNameReferences.Values)
+ {
+ module.AssemblyReferences.Add(target);
+ shouldRewriteTypes = true;
+ }
+
+ // rewrite references
+ if (shouldRewriteTypes)
+ {
+ // rewrite types
+ foreach (TypeDefinition type in module.GetTypes())
+ this.RewriteReferences(type, module);
+
+ // rewrite type references
+ TypeReference[] refs = (TypeReference[])module.GetTypeReferences();
+ for (int i = 0; i < refs.Length; ++i)
+ refs[i] = this.GetTypeReference(refs[i], module);
+ }
+ }
+ }
+
+
+ /*********
+ ** Private methods
+ *********/
+ /// <summary>Rewrite the references for a code object.</summary>
+ /// <param name="type">The type to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private void RewriteReferences(TypeDefinition type, ModuleDefinition module)
+ {
+ // rewrite base type
+ type.BaseType = this.GetTypeReference(type.BaseType, module);
+
+ // rewrite interfaces
+ for (int i = 0; i < type.Interfaces.Count; i++)
+ type.Interfaces[i] = this.GetTypeReference(type.Interfaces[i], module);
+
+ // rewrite events
+ foreach (EventDefinition @event in type.Events)
+ {
+ this.RewriteReferences(@event.AddMethod, module);
+ this.RewriteReferences(@event.RemoveMethod, module);
+ this.RewriteReferences(@event.InvokeMethod, module);
+ }
+
+ // rewrite properties
+ foreach (PropertyDefinition property in type.Properties)
+ {
+ this.RewriteReferences(property.GetMethod, module);
+ this.RewriteReferences(property.SetMethod, module);
+ }
+
+ // rewrite methods
+ foreach (MethodDefinition method in type.Methods)
+ this.RewriteReferences(method, module);
+
+ // rewrite fields
+ foreach (FieldDefinition field in type.Fields)
+ this.RewriteReferences(field, module);
+
+ // rewrite nested types
+ foreach (TypeDefinition nestedType in type.NestedTypes)
+ this.RewriteReferences(nestedType, module);
+
+ // rewrite generic parameters
+ foreach (GenericParameter parameter in type.GenericParameters)
+ this.RewriteReferences(parameter, module);
+
+ module.Import(type);
+ }
+
+ /// <summary>Rewrite the references for a code object.</summary>
+ /// <param name="method">The method to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private void RewriteReferences(MethodReference method, ModuleDefinition module)
+ {
+ // parameter types
+ if (method.HasParameters)
+ {
+ foreach (ParameterDefinition parameter in method.Parameters)
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ }
+
+ // return type
+ method.MethodReturnType.ReturnType = this.GetTypeReference(method.MethodReturnType.ReturnType, module);
+
+ module.Import(method);
+ }
+
+ /// <summary>Rewrite the references for a code object.</summary>
+ /// <param name="method">The method to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private void RewriteReferences(MethodDefinition method, ModuleDefinition module)
+ {
+ if (method == null)
+ return;
+
+ this.RewriteReferences((MethodReference)method, module);
+
+ // overrides
+ foreach (MethodReference @override in method.Overrides)
+ this.RewriteReferences(@override, module);
+
+ // body
+ if (method.HasBody)
+ {
+ // this
+ if (method.Body.ThisParameter != null)
+ method.Body.ThisParameter.ParameterType = this.GetTypeReference(method.Body.ThisParameter.ParameterType, module);
+
+ // variables
+ if (method.Body.HasVariables)
+ {
+ foreach (VariableDefinition variable in method.Body.Variables)
+ variable.VariableType = this.GetTypeReference(variable.VariableType, module);
+ }
+
+ // instructions
+ foreach (Instruction instruction in method.Body.Instructions)
+ {
+ object operand = instruction.Operand;
+
+ // type
+ {
+ TypeReference type = operand as TypeReference;
+ if (type != null)
+ {
+ instruction.Operand = this.GetTypeReference(type, module);
+ continue;
+ }
+ }
+
+ // method
+ {
+ MethodReference methodRef = operand as MethodReference;
+ if (methodRef != null)
+ {
+ this.RewriteReferences(methodRef, module);
+ continue;
+ }
+ }
+
+ // field
+ {
+ FieldReference field = operand as FieldReference;
+ if (field != null)
+ {
+ this.RewriteReferences(field, module);
+ continue;
+ }
+ }
+
+ // variable
+ {
+ VariableDefinition variable = operand as VariableDefinition;
+ if (variable != null)
+ {
+ variable.VariableType = this.GetTypeReference(variable.VariableType, module);
+ continue;
+ }
+ }
+
+ // parameter
+ {
+ ParameterDefinition parameter = operand as ParameterDefinition;
+ if (parameter != null)
+ {
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ continue;
+ }
+ }
+
+ // call site
+ {
+ CallSite call = operand as CallSite;
+ if (call != null)
+ {
+ foreach (ParameterDefinition parameter in call.Parameters)
+ parameter.ParameterType = this.GetTypeReference(parameter.ParameterType, module);
+ call.ReturnType = this.GetTypeReference(call.ReturnType, module);
+ }
+ }
+ }
+ }
+
+ module.Import(method);
+ }
+
+ /// <summary>Rewrite the references for a code object.</summary>
+ /// <param name="parameter">The generic parameter to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private void RewriteReferences(GenericParameter parameter, ModuleDefinition module)
+ {
+ // constraints
+ for (int i = 0; i < parameter.Constraints.Count; i++)
+ parameter.Constraints[i] = this.GetTypeReference(parameter.Constraints[i], module);
+
+ // generic parameters
+ foreach (GenericParameter genericParam in parameter.GenericParameters)
+ this.RewriteReferences(genericParam, module);
+ }
+
+ /// <summary>Rewrite the references for a code object.</summary>
+ /// <param name="field">The field to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private void RewriteReferences(FieldReference field, ModuleDefinition module)
+ {
+ field.DeclaringType = this.GetTypeReference(field.DeclaringType, module);
+ field.FieldType = this.GetTypeReference(field.FieldType, module);
+ module.Import(field);
+ }
+
+ /// <summary>Get the correct reference to use for compatibility with the current platform.</summary>
+ /// <param name="type">The type reference to rewrite.</param>
+ /// <param name="module">The module being rewritten.</param>
+ private TypeReference GetTypeReference(TypeReference type, ModuleDefinition module)
+ {
+ // check skip conditions
+ if (type == null)
+ return null;
+ if (type.FullName.StartsWith("System."))
+ return type;
+
+ // get assembly
+ Assembly assembly;
+ if (!this.TypeAssemblies.TryGetValue(type.FullName, out assembly))
+ return type;
+
+ // replace type
+ AssemblyNameReference newAssembly = this.AssemblyNameReferences[assembly];
+ ModuleDefinition newModule = this.AssemblyModules[assembly];
+ type = new TypeReference(type.Namespace, type.Name, newModule, newAssembly);
+
+ return module.Import(type);
+ }
+ }
+}