using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
namespace StardewModdingAPI.Framework.Reflection
{
/// <summary>Generates a proxy class to access a mod API through an arbitrary interface.</summary>
internal class InterfaceProxyBuilder
{
    /*********
    ** Consts
    *********/
    /// <summary>The name of the generated field which stores the target instance.</summary>
    private static readonly string TargetFieldName = "__Target";

    /// <summary>The name of the generated field which stores the <see cref="InterfaceProxyGlue"/> instance.</summary>
    private static readonly string GlueFieldName = "__Glue";

    /// <summary>The glue method which generated code calls with a proxy type name and a value, then casts the result to the proxy type.</summary>
    private static readonly MethodInfo ObtainInstanceForProxyTypeNameMethod = typeof(InterfaceProxyGlue).GetMethod(nameof(InterfaceProxyGlue.ObtainInstanceForProxyTypeName), new Type[] { typeof(string), typeof(object) });

    /// <summary>The glue method which generated code calls with a proxy type name, an unproxy type name, and a value, then casts the result to the expected type.</summary>
    /// <remarks>Judging by its name, this presumably unwraps the value if it's already a proxy, else wraps it in a new proxy — confirm in <see cref="InterfaceProxyGlue"/>.</remarks>
    private static readonly MethodInfo UnproxyOrObtainInstanceForProxyTypeNameMethod = typeof(InterfaceProxyGlue).GetMethod(nameof(InterfaceProxyGlue.UnproxyOrObtainInstanceForProxyTypeName), new Type[] { typeof(string), typeof(string), typeof(object) });


    /*********
    ** Fields
    *********/
    /// <summary>The target class type.</summary>
    private readonly Type TargetType;

    /// <summary>The interface type.</summary>
    private readonly Type InterfaceType;

    /// <summary>The full name of the generated proxy type.</summary>
    private readonly string ProxyTypeName;

    /// <summary>The generated proxy type, set by <see cref="SetupProxyType"/>.</summary>
    private Type ProxyType;

    /// <summary>A cache of all proxies generated by this builder, keyed by target instance.</summary>
    /// <remarks>A <see cref="ConditionalWeakTable{TKey,TValue}"/> doesn't keep its keys alive, so cached proxies don't prevent their targets from being collected.</remarks>
    private readonly ConditionalWeakTable<object, object> ProxyCache = new();


    /*********
    ** Public methods
    *********/
    /// <summary>Construct an instance.</summary>
    /// <param name="targetType">The target type.</param>
    /// <param name="interfaceType">The interface type to implement.</param>
    /// <param name="proxyTypeName">The type name to generate.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="interfaceType"/> isn't an interface.</exception>
    public InterfaceProxyBuilder(Type targetType, Type interfaceType, string proxyTypeName)
    {
        // validate and store
        this.TargetType = targetType ?? throw new ArgumentNullException(nameof(targetType));
        this.InterfaceType = interfaceType ?? throw new ArgumentNullException(nameof(interfaceType));
        this.ProxyTypeName = proxyTypeName ?? throw new ArgumentNullException(nameof(proxyTypeName));
        if (!interfaceType.IsInterface)
            throw new ArgumentException($"{nameof(interfaceType)} is not an interface.");
    }

    /// <summary>Create and set up the proxy type.</summary>
    /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
    /// <param name="moduleBuilder">The CLR module in which to create proxy classes.</param>
    /// <param name="sourceModID">The unique ID of the mod consuming the API.</param>
    /// <param name="targetModID">The unique ID of the mod providing the API.</param>
    /// <exception cref="InvalidOperationException">The interface defines a method which has no matching method on the target type.</exception>
    public void SetupProxyType(InterfaceProxyFactory factory, ModuleBuilder moduleBuilder, string sourceModID, string targetModID)
    {
        // define proxy type
        TypeBuilder proxyBuilder = moduleBuilder.DefineType(this.ProxyTypeName, TypeAttributes.Public | TypeAttributes.Class);
        proxyBuilder.AddInterfaceImplementation(this.InterfaceType);

        // create fields to store target instance and proxy glue
        FieldBuilder targetField = proxyBuilder.DefineField(TargetFieldName, this.TargetType, FieldAttributes.Private);
        FieldBuilder glueField = proxyBuilder.DefineField(GlueFieldName, typeof(InterfaceProxyGlue), FieldAttributes.Private);

        // create constructor which accepts target instance + glue, and sets fields
        {
            ConstructorBuilder constructor = proxyBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard | CallingConventions.HasThis, new[] { this.TargetType, typeof(InterfaceProxyGlue) });
            ILGenerator il = constructor.GetILGenerator();

            il.Emit(OpCodes.Ldarg_0); // this
            // ReSharper disable once AssignNullToNotNullAttribute -- never null
            il.Emit(OpCodes.Call, typeof(object).GetConstructor(Array.Empty<Type>())); // call base constructor
            il.Emit(OpCodes.Ldarg_0); // this
            il.Emit(OpCodes.Ldarg_1); // load argument
            il.Emit(OpCodes.Stfld, targetField); // set field to loaded argument
            il.Emit(OpCodes.Ldarg_0); // this
            il.Emit(OpCodes.Ldarg_2); // load argument
            il.Emit(OpCodes.Stfld, glueField); // set field to loaded argument
            il.Emit(OpCodes.Ret);
        }

        // collect candidate target methods: the target's public methods, plus non-abstract (default) methods from its interfaces
        List<MethodInfo> allTargetMethods = this.TargetType.GetMethods().ToList();
        foreach (Type targetInterface in this.TargetType.GetInterfaces())
        {
            foreach (MethodInfo targetMethod in targetInterface.GetMethods())
            {
                if (!targetMethod.IsAbstract)
                    allTargetMethods.Add(targetMethod);
            }
        }

        // proxy each interface method to the first matching target method
        foreach (MethodInfo proxyMethod in this.InterfaceType.GetMethods())
        {
            bool matched = false;
            foreach (MethodInfo targetMethod in allTargetMethods)
            {
                if (TryMatchMethod(proxyMethod, targetMethod, out ISet<int?> positionsToProxy))
                {
                    this.ProxyMethod(factory, proxyBuilder, proxyMethod, targetMethod, targetField, glueField, positionsToProxy, sourceModID, targetModID);
                    matched = true;
                    break;
                }
            }

            if (!matched)
                throw new InvalidOperationException($"The {this.InterfaceType.FullName} interface defines method {proxyMethod.Name} which doesn't exist in the API.");
        }

        // save info
        this.ProxyType = proxyBuilder.CreateType();
    }

    /// <summary>Get an existing or create a new instance of the proxy for a target instance.</summary>
    /// <param name="targetInstance">The target instance.</param>
    /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
    /// <returns>The cached proxy for <paramref name="targetInstance"/> if there is one, else a newly created (and cached) proxy.</returns>
    /// <exception cref="InvalidOperationException">The proxy type hasn't been created yet via <see cref="SetupProxyType"/>.</exception>
    public object ObtainInstance(object targetInstance, InterfaceProxyFactory factory)
    {
        if (this.ProxyCache.TryGetValue(targetInstance, out object proxyInstance))
            return proxyInstance;

        if (this.ProxyType == null)
            throw new InvalidOperationException($"Can't create an instance of proxy type '{this.ProxyTypeName}' before {nameof(SetupProxyType)} is called.");

        ConstructorInfo constructor = this.ProxyType.GetConstructor(new[] { this.TargetType, typeof(InterfaceProxyGlue) });
        if (constructor == null)
            throw new InvalidOperationException($"Couldn't find the constructor for generated proxy type '{this.ProxyType.Name}'."); // should never happen

        proxyInstance = constructor.Invoke(new[] { targetInstance, new InterfaceProxyGlue(factory) });
        this.ProxyCache.Add(targetInstance, proxyInstance);
        return proxyInstance;
    }

    /// <summary>Try to get the target instance for a given proxy instance created by this builder.</summary>
    /// <param name="potentialProxyInstance">The proxy instance to look for.</param>
    /// <param name="targetInstance">The target instance wrapped by the proxy, if found; else <c>null</c>.</param>
    /// <returns>Whether <paramref name="potentialProxyInstance"/> is a cached proxy from this builder.</returns>
    public bool TryUnproxy(object potentialProxyInstance, out object targetInstance)
    {
        foreach ((object cachedTargetInstance, object cachedProxyInstance) in this.ProxyCache)
        {
            if (object.ReferenceEquals(potentialProxyInstance, cachedProxyInstance))
            {
                targetInstance = cachedTargetInstance;
                return true;
            }
        }

        targetInstance = null;
        return false;
    }


    /*********
    ** Private methods
    *********/
    /// <summary>Get whether a target method can implement an interface method, and which of its values need to be proxied.</summary>
    /// <param name="proxyMethod">The interface method to implement.</param>
    /// <param name="targetMethod">The candidate method on the target type.</param>
    /// <param name="positionsToProxy">If the methods match, the parameter positions (or <c>null</c> for the return type) whose values must be proxied; else <c>null</c>.</param>
    private static bool TryMatchMethod(MethodInfo proxyMethod, MethodInfo targetMethod, out ISet<int?> positionsToProxy)
    {
        positionsToProxy = null;

        // the generated code always calls the target through the stored instance, so a static method can't implement it
        if (targetMethod.IsStatic || targetMethod.Name != proxyMethod.Name)
            return false;
        if (targetMethod.GetGenericArguments().Length != proxyMethod.GetGenericArguments().Length)
            return false;

        var positions = new HashSet<int?>(); // null = return type; anything else = parameter position

        // match return type
        switch (AreTypesMatching(targetMethod.ReturnType, proxyMethod.ReturnType, MethodTypeMatchingPart.ReturnType))
        {
            case MatchingTypesResult.False:
                return false;
            case MatchingTypesResult.IfProxied:
                positions.Add(null);
                break;
        }

        // match parameters
        ParameterInfo[] targetParameters = targetMethod.GetParameters();
        ParameterInfo[] proxyParameters = proxyMethod.GetParameters();
        if (targetParameters.Length != proxyParameters.Length)
            return false;
        for (int i = 0; i < targetParameters.Length; i++)
        {
            switch (AreTypesMatching(targetParameters[i].ParameterType, proxyParameters[i].ParameterType, MethodTypeMatchingPart.Parameter))
            {
                case MatchingTypesResult.False:
                    return false;
                case MatchingTypesResult.IfProxied:
                    positions.Add(i);
                    break;
            }
        }

        positionsToProxy = positions;
        return true;
    }

    /// <summary>Get whether a target type and a proxy type are compatible in a method signature.</summary>
    /// <param name="targetType">The type in the target method's signature.</param>
    /// <param name="proxyType">The type in the interface method's signature.</param>
    /// <param name="part">Which part of the signature is being matched.</param>
    private static MatchingTypesResult AreTypesMatching(Type targetType, Type proxyType, MethodTypeMatchingPart part)
    {
        // values flow proxy -> target for parameters and target -> proxy for return values, so typeA must accept typeB
        Type typeA = part == MethodTypeMatchingPart.Parameter ? targetType : proxyType;
        Type typeB = part == MethodTypeMatchingPart.Parameter ? proxyType : targetType;

        if (typeA.IsGenericMethodParameter != typeB.IsGenericMethodParameter)
            return MatchingTypesResult.False;

        // the generated IL can't bridge void <-> non-void or by-value <-> by-ref
        if ((targetType == typeof(void)) != (proxyType == typeof(void)) || targetType.IsByRef != proxyType.IsByRef)
            return MatchingTypesResult.False;

        // TODO: decide if "assignable" checking is desired (instead of just 1:1 type equality)
        if (typeA.IsGenericMethodParameter ? typeA.GenericParameterPosition == typeB.GenericParameterPosition : typeA.IsAssignableFrom(typeB))
            return MatchingTypesResult.True;

        // NOTE(review): the constructor rejects non-interface types, so a match where only the target type is an
        // interface may still fail later when a builder is obtained for the proxy side — confirm in InterfaceProxyFactory.
        if (!proxyType.IsGenericMethodParameter)
        {
            if (proxyType.GetNonRefType().IsInterface)
                return MatchingTypesResult.IfProxied;
            if (targetType.GetNonRefType().IsInterface)
                return MatchingTypesResult.IfProxied;
        }

        return MatchingTypesResult.False;
    }

    /// <summary>Define a method which proxies access to a method on the target.</summary>
    /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
    /// <param name="proxyBuilder">The proxy type being generated.</param>
    /// <param name="proxy">The proxy method.</param>
    /// <param name="target">The target method.</param>
    /// <param name="instanceField">The proxy field containing the API instance.</param>
    /// <param name="glueField">The proxy field containing an <see cref="InterfaceProxyGlue"/>.</param>
    /// <param name="positionsToProxy">Parameter type positions (or null for the return type) for which types should also be proxied.</param>
    /// <param name="sourceModID">The unique ID of the mod consuming the API.</param>
    /// <param name="targetModID">The unique ID of the mod providing the API.</param>
    private void ProxyMethod(InterfaceProxyFactory factory, TypeBuilder proxyBuilder, MethodInfo proxy, MethodInfo target, FieldBuilder instanceField, FieldBuilder glueField, ISet<int?> positionsToProxy, string sourceModID, string targetModID)
    {
        MethodBuilder methodBuilder = proxyBuilder.DefineMethod(proxy.Name, MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual);

        // set up generic arguments
        Type[] proxyGenericArguments = proxy.GetGenericArguments();
        string[] genericArgNames = proxyGenericArguments.Select(a => a.Name).ToArray();
        GenericTypeParameterBuilder[] genericTypeParameterBuilders = proxyGenericArguments.Length == 0 ? null : methodBuilder.DefineGenericParameters(genericArgNames);
        for (int i = 0; i < proxyGenericArguments.Length; i++)
            genericTypeParameterBuilders[i].SetGenericParameterAttributes(proxyGenericArguments[i].GenericParameterAttributes);

        // set up return type (the interface's generic method parameters are swapped for this method's own)
        Type returnType = proxy.ReturnType.IsGenericMethodParameter ? genericTypeParameterBuilders[proxy.ReturnType.GenericParameterPosition] : proxy.ReturnType;
        methodBuilder.SetReturnType(returnType);

        // set up parameter types
        ParameterInfo[] targetParameters = target.GetParameters();
        Type[] argTypes = proxy.GetParameters()
            .Select(a => a.ParameterType)
            .Select(t => t.IsGenericMethodParameter ? genericTypeParameterBuilders[t.GenericParameterPosition] : t)
            .ToArray();

        // get the proxy type names needed to convert values between target and proxy types
        // (generic positions never need proxying, so the types here are the original ones)
        string returnValueProxyTypeName = null;
        string[] parameterTargetToArgProxyTypeNames = new string[argTypes.Length];
        string[] parameterArgToTargetUnproxyTypeNames = new string[argTypes.Length];
        foreach (int? position in positionsToProxy)
        {
            if (position == null) // it's the return type
            {
                var builder = factory.ObtainBuilder(target.ReturnType, proxy.ReturnType, sourceModID, targetModID);
                returnValueProxyTypeName = builder.ProxyTypeName;
            }
            else // it's one of the parameters
            {
                int index = position.Value;
                Type targetType = targetParameters[index].ParameterType.GetNonRefType();
                Type argType = argTypes[index].GetNonRefType();

                var builder = factory.ObtainBuilder(targetType, argType, sourceModID, targetModID);
                parameterTargetToArgProxyTypeNames[index] = builder.ProxyTypeName;

                // `out` values only flow back from the target, so they never need converting on the way in
                if (!targetParameters[index].IsOut)
                {
                    var argToTargetBuilder = factory.ObtainBuilder(argType, targetType, sourceModID, targetModID);
                    parameterArgToTargetUnproxyTypeNames[index] = argToTargetBuilder.ProxyTypeName;
                }
            }
        }

        methodBuilder.SetParameters(argTypes);
        for (int i = 0; i < argTypes.Length; i++)
        {
            // positions are 1-based (0 is the return value); HasDefault is cleared since no default constant is copied
            methodBuilder.DefineParameter(i + 1, targetParameters[i].Attributes & ~ParameterAttributes.HasDefault, targetParameters[i].Name);
        }

        // create method body
        ILGenerator il = methodBuilder.GetILGenerator();
        LocalBuilder[] inputLocals = new LocalBuilder[argTypes.Length];
        LocalBuilder[] outputLocals = new LocalBuilder[argTypes.Length];

        // Emit IL which copies inputLocal into outputLocal, converting it through the glue when proxyTypeName is set.
        // A null input is skipped, so outputLocal keeps its current value (null for a fresh local, since locals are zero-initialized).
        void ProxyIfNeededAndStore(LocalBuilder inputLocal, LocalBuilder outputLocal, string proxyTypeName, string unproxyTypeName)
        {
            if (proxyTypeName == null)
            {
                il.Emit(OpCodes.Ldloc, inputLocal);
                il.Emit(OpCodes.Stloc, outputLocal);
                return;
            }

            var isNullLabel = il.DefineLabel();
            il.Emit(OpCodes.Ldloc, inputLocal);
            il.Emit(OpCodes.Brfalse, isNullLabel);

            il.Emit(OpCodes.Ldarg_0);
            il.Emit(OpCodes.Ldfld, glueField);
            if (unproxyTypeName == null)
            {
                il.Emit(OpCodes.Ldstr, proxyTypeName);
                il.Emit(OpCodes.Ldloc, inputLocal);
                il.Emit(OpCodes.Call, ObtainInstanceForProxyTypeNameMethod);
            }
            else
            {
                il.Emit(OpCodes.Ldstr, proxyTypeName);
                il.Emit(OpCodes.Ldstr, unproxyTypeName);
                il.Emit(OpCodes.Ldloc, inputLocal);
                il.Emit(OpCodes.Call, UnproxyOrObtainInstanceForProxyTypeNameMethod);
            }
            il.Emit(OpCodes.Castclass, outputLocal.LocalType);
            il.Emit(OpCodes.Stloc, outputLocal);

            il.MarkLabel(isNullLabel);
        }

        // calling the proxied method
        LocalBuilder resultInputLocal = target.ReturnType == typeof(void) ? null : il.DeclareLocal(target.ReturnType);
        LocalBuilder resultOutputLocal = returnType == typeof(void) ? null : il.DeclareLocal(returnType);
        il.Emit(OpCodes.Ldarg_0);
        il.Emit(OpCodes.Ldfld, instanceField);
        for (int i = 0; i < argTypes.Length; i++)
        {
            ParameterInfo targetParameter = targetParameters[i];
            bool isByRef = argTypes[i].IsByRef;
            short argIndex = (short)(i + 1); // ldarg takes a 16-bit operand; arg 0 is `this`

            if (targetParameter.IsOut && parameterTargetToArgProxyTypeNames[i] != null) // out parameter, proxy on the way back
            {
                inputLocals[i] = il.DeclareLocal(targetParameter.ParameterType.GetNonRefType());
                outputLocals[i] = il.DeclareLocal(argTypes[i].GetNonRefType());
                il.Emit(OpCodes.Ldloca, inputLocals[i]);
            }
            else if (parameterArgToTargetUnproxyTypeNames[i] != null) // normal/ref/in parameter, proxy on the way in
            {
                inputLocals[i] = il.DeclareLocal(argTypes[i].GetNonRefType());
                outputLocals[i] = il.DeclareLocal(targetParameter.ParameterType.GetNonRefType());
                il.Emit(OpCodes.Ldarg, argIndex);
                if (isByRef)
                    il.Emit(OpCodes.Ldind_Ref); // read the value behind the caller's reference
                il.Emit(OpCodes.Stloc, inputLocals[i]);
                ProxyIfNeededAndStore(inputLocals[i], outputLocals[i], parameterArgToTargetUnproxyTypeNames[i], parameterTargetToArgProxyTypeNames[i]);
                il.Emit(isByRef ? OpCodes.Ldloca : OpCodes.Ldloc, outputLocals[i]); // a by-ref target gets a reference to the converted value
            }
            else // no proxying
            {
                il.Emit(OpCodes.Ldarg, argIndex);
            }
        }
        il.Emit(target.IsVirtual ? OpCodes.Callvirt : OpCodes.Call, target);
        if (target.ReturnType != typeof(void))
            il.Emit(OpCodes.Stloc, resultInputLocal);

        // proxying `out` and `ref` parameters on the way back
        for (int i = 0; i < argTypes.Length; i++)
        {
            string targetToArgProxyTypeName = parameterTargetToArgProxyTypeNames[i];
            if (targetToArgProxyTypeName == null)
                continue;

            ParameterInfo targetParameter = targetParameters[i];
            short argIndex = (short)(i + 1);

            if (targetParameter.IsOut)
            {
                // the target wrote into inputLocals[i]; convert it and store it into the caller's reference
                ProxyIfNeededAndStore(inputLocals[i], outputLocals[i], targetToArgProxyTypeName, null);
                il.Emit(OpCodes.Ldarg, argIndex);
                il.Emit(OpCodes.Ldloc, outputLocals[i]);
                il.Emit(OpCodes.Stind_Ref);
            }
            else if (argTypes[i].IsByRef && !targetParameter.IsIn)
            {
                // a `ref` target may have replaced outputLocals[i]; convert it back into the caller's reference
                LocalBuilder writeBackLocal = il.DeclareLocal(argTypes[i].GetNonRefType());
                ProxyIfNeededAndStore(outputLocals[i], writeBackLocal, targetToArgProxyTypeName, parameterArgToTargetUnproxyTypeNames[i]);
                il.Emit(OpCodes.Ldarg, argIndex);
                il.Emit(OpCodes.Ldloc, writeBackLocal);
                il.Emit(OpCodes.Stind_Ref);
            }
        }

        // proxying and returning the result
        if (target.ReturnType != typeof(void))
        {
            ProxyIfNeededAndStore(resultInputLocal, resultOutputLocal, returnValueProxyTypeName, null);
            il.Emit(OpCodes.Ldloc, resultOutputLocal);
        }
        il.Emit(OpCodes.Ret);
    }

    /// <summary>The part of a method that is being matched.</summary>
    private enum MethodTypeMatchingPart
    {
        ReturnType, Parameter
    }

    /// <summary>The result of matching a target and a proxy type.</summary>
    private enum MatchingTypesResult
    {
        False, IfProxied, True
    }
}
/// <summary>Provides extension methods for <see cref="Type"/>.</summary>
internal static class TypeExtensions
{
    /// <summary>Unwrap a by-reference type (as used for <c>ref</c>, <c>in</c>, and <c>out</c> parameters) into the type it refers to.</summary>
    /// <param name="type">The type to unwrap.</param>
    /// <returns>The element type if <paramref name="type"/> is by-reference, else <paramref name="type"/> unchanged.</returns>
    internal static Type GetNonRefType(this Type type)
    {
        if (!type.IsByRef)
            return type;

        return type.GetElementType();
    }
}
}