From 43ad219b75740ef71ad9bad496b00c076182619b Mon Sep 17 00:00:00 2001
From: Shockah <me@shockah.pl>
Date: Wed, 9 Feb 2022 20:07:01 +0100
Subject: support proxying return values in API proxies

---
 .../Framework/Reflection/InterfaceProxyBuilder.cs  | 183 ++++++++++++++++-----
 .../Framework/Reflection/InterfaceProxyFactory.cs  |  30 +++-
 .../Framework/Reflection/InterfaceProxyGlue.cs     |  18 ++
 3 files changed, 186 insertions(+), 45 deletions(-)
 create mode 100644 src/SMAPI/Framework/Reflection/InterfaceProxyGlue.cs

(limited to 'src')

diff --git a/src/SMAPI/Framework/Reflection/InterfaceProxyBuilder.cs b/src/SMAPI/Framework/Reflection/InterfaceProxyBuilder.cs
index 5ae96dff..d8b066bd 100644
--- a/src/SMAPI/Framework/Reflection/InterfaceProxyBuilder.cs
+++ b/src/SMAPI/Framework/Reflection/InterfaceProxyBuilder.cs
@@ -1,13 +1,22 @@
 using System;
+using System.Collections.Generic;
 using System.Linq;
 using System.Reflection;
 using System.Reflection.Emit;
+using HarmonyLib;
 
 namespace StardewModdingAPI.Framework.Reflection
 {
     /// <summary>Generates a proxy class to access a mod API through an arbitrary interface.</summary>
     internal class InterfaceProxyBuilder
     {
+        /*********
+        ** Consts
+        *********/
+        private static readonly string TargetFieldName = "__Target";
+        private static readonly string GlueFieldName = "__Glue";
+        private static readonly MethodInfo CreateInstanceForProxyTypeNameMethod = typeof(InterfaceProxyGlue).GetMethod(nameof(InterfaceProxyGlue.CreateInstanceForProxyTypeName), new Type[] { typeof(string), typeof(object) });
+
         /*********
         ** Fields
         *********/
@@ -22,11 +31,14 @@ namespace StardewModdingAPI.Framework.Reflection
         ** Public methods
         *********/
         /// <summary>Construct an instance.</summary>
+        /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
         /// <param name="name">The type name to generate.</param>
         /// <param name="moduleBuilder">The CLR module in which to create proxy classes.</param>
         /// <param name="interfaceType">The interface type to implement.</param>
         /// <param name="targetType">The target type.</param>
-        public InterfaceProxyBuilder(string name, ModuleBuilder moduleBuilder, Type interfaceType, Type targetType)
+        /// <param name="sourceModID">The unique ID of the mod consuming the API.</param>
+        /// <param name="targetModID">The unique ID of the mod providing the API.</param>
+        public InterfaceProxyBuilder(InterfaceProxyFactory factory, string name, ModuleBuilder moduleBuilder, Type interfaceType, Type targetType, string sourceModID, string targetModID)
         {
             // validate
             if (name == null)
@@ -38,12 +50,13 @@ namespace StardewModdingAPI.Framework.Reflection
             TypeBuilder proxyBuilder = moduleBuilder.DefineType(name, TypeAttributes.Public | TypeAttributes.Class);
             proxyBuilder.AddInterfaceImplementation(interfaceType);
 
-            // create field to store target instance
-            FieldBuilder targetField = proxyBuilder.DefineField("__Target", targetType, FieldAttributes.Private);
+            // create fields to store target instance and proxy factory
+            FieldBuilder targetField = proxyBuilder.DefineField(TargetFieldName, targetType, FieldAttributes.Private);
+            FieldBuilder glueField = proxyBuilder.DefineField(GlueFieldName, typeof(InterfaceProxyGlue), FieldAttributes.Private);
 
-            // create constructor which accepts target instance and sets field
+            // create constructor which accepts target instance + factory, and sets fields
             {
-                ConstructorBuilder constructor = proxyBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard | CallingConventions.HasThis, new[] { targetType });
+                ConstructorBuilder constructor = proxyBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard | CallingConventions.HasThis, new[] { targetType, typeof(InterfaceProxyGlue) });
                 ILGenerator il = constructor.GetILGenerator();
 
                 il.Emit(OpCodes.Ldarg_0); // this
@@ -52,6 +65,9 @@ namespace StardewModdingAPI.Framework.Reflection
                 il.Emit(OpCodes.Ldarg_0);      // this
                 il.Emit(OpCodes.Ldarg_1);      // load argument
                 il.Emit(OpCodes.Stfld, targetField); // set field to loaded argument
+                il.Emit(OpCodes.Ldarg_0);      // this
+                il.Emit(OpCodes.Ldarg_2);      // load argument
+                il.Emit(OpCodes.Stfld, glueField); // set field to loaded argument
                 il.Emit(OpCodes.Ret);
             }
 
@@ -65,15 +81,20 @@ namespace StardewModdingAPI.Framework.Reflection
                 }
             }
 
-            bool AreTypesMatching(Type targetType, Type proxyType, MethodTypeMatchingPart part)
+            MatchingTypesResult AreTypesMatching(Type targetType, Type proxyType, MethodTypeMatchingPart part)
             {
                 var typeA = part == MethodTypeMatchingPart.Parameter ? targetType : proxyType;
                 var typeB = part == MethodTypeMatchingPart.Parameter ? proxyType : targetType;
 
                 if (typeA.IsGenericMethodParameter != typeB.IsGenericMethodParameter)
-                    return false;
+                    return MatchingTypesResult.False;
                 // TODO: decide if "assignable" checking is desired (instead of just 1:1 type equality)
-                return typeA.IsGenericMethodParameter ? typeA.GenericParameterPosition == typeB.GenericParameterPosition : typeA.IsAssignableFrom(typeB);
+                if (typeA.IsGenericMethodParameter ? typeA.GenericParameterPosition == typeB.GenericParameterPosition : typeA.IsAssignableFrom(typeB))
+                    return MatchingTypesResult.True;
+
+                if (!proxyType.IsGenericMethodParameter && proxyType.IsInterface && proxyType.Assembly == interfaceType.Assembly)
+                    return MatchingTypesResult.IfProxied;
+                return MatchingTypesResult.False;
             }
 
             // proxy methods
@@ -81,30 +102,61 @@ namespace StardewModdingAPI.Framework.Reflection
             {
                 var proxyMethodParameters = proxyMethod.GetParameters();
                 var proxyMethodGenericArguments = proxyMethod.GetGenericArguments();
-                var targetMethod = allTargetMethods.Where(m =>
+
+                foreach (MethodInfo targetMethod in allTargetMethods)
                 {
-                    if (m.Name != proxyMethod.Name)
-                        return false;
+                    // checking if `targetMethod` matches `proxyMethod`
 
-                    if (m.GetGenericArguments().Length != proxyMethodGenericArguments.Length)
-                        return false;
-                    if (!AreTypesMatching(m.ReturnType, proxyMethod.ReturnType, MethodTypeMatchingPart.ReturnType))
-                        return false;
+                    if (targetMethod.Name != proxyMethod.Name)
+                        continue;
+                    if (targetMethod.GetGenericArguments().Length != proxyMethodGenericArguments.Length)
+                        continue;
+                    var positionsToProxy = new HashSet<int?>(); // null = return type; anything else = parameter position
+
+                    switch (AreTypesMatching(targetMethod.ReturnType, proxyMethod.ReturnType, MethodTypeMatchingPart.ReturnType))
+                    {
+                        case MatchingTypesResult.False:
+                            continue;
+                        case MatchingTypesResult.True:
+                            break;
+                        case MatchingTypesResult.IfProxied:
+                            positionsToProxy.Add(null);
+                            break;
+                    }
 
-                    var mParameters = m.GetParameters();
-                    if (m.GetParameters().Length != proxyMethodParameters.Length)
-                        return false;
+                    var mParameters = targetMethod.GetParameters();
+                    if (mParameters.Length != proxyMethodParameters.Length)
+                        continue;
                     for (int i = 0; i < mParameters.Length; i++)
                     {
-                        if (!AreTypesMatching(mParameters[i].ParameterType, proxyMethodParameters[i].ParameterType, MethodTypeMatchingPart.Parameter))
-                            return false;
+                        switch (AreTypesMatching(mParameters[i].ParameterType, proxyMethodParameters[i].ParameterType, MethodTypeMatchingPart.Parameter))
+                        {
+                            case MatchingTypesResult.False:
+                                goto targetMethodLoopContinue;
+                            case MatchingTypesResult.True:
+                                break;
+                            case MatchingTypesResult.IfProxied:
+                                if (proxyMethodParameters[i].IsOut)
+                                {
+                                    positionsToProxy.Add(i);
+                                    break;
+                                }
+                                else
+                                {
+                                    goto targetMethodLoopContinue;
+                                }
+                        }
                     }
-                    return true;
-                }).FirstOrDefault();
-                if (targetMethod == null)
-                    throw new InvalidOperationException($"The {interfaceType.FullName} interface defines method {proxyMethod.Name} which doesn't exist in the API.");
 
-                this.ProxyMethod(proxyBuilder, proxyMethod, targetMethod, targetField);
+                    // method matched; proxying
+
+                    this.ProxyMethod(factory, proxyBuilder, proxyMethod, targetMethod, targetField, glueField, positionsToProxy, sourceModID, targetModID);
+                    goto proxyMethodLoopContinue;
+                    targetMethodLoopContinue:;
+                }
+
+                throw new InvalidOperationException($"The {interfaceType.FullName} interface defines method {proxyMethod.Name} which doesn't exist in the API.");
+                proxyMethodLoopContinue:;
             }
 
             // save info
@@ -114,12 +166,13 @@ namespace StardewModdingAPI.Framework.Reflection
 
         /// <summary>Create an instance of the proxy for a target instance.</summary>
         /// <param name="targetInstance">The target instance.</param>
-        public object CreateInstance(object targetInstance)
+        /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
+        public object CreateInstance(object targetInstance, InterfaceProxyFactory factory)
         {
-            ConstructorInfo constructor = this.ProxyType.GetConstructor(new[] { this.TargetType });
+            ConstructorInfo constructor = this.ProxyType.GetConstructor(new[] { this.TargetType, typeof(InterfaceProxyGlue) });
             if (constructor == null)
                 throw new InvalidOperationException($"Couldn't find the constructor for generated proxy type '{this.ProxyType.Name}'."); // should never happen
-            return constructor.Invoke(new[] { targetInstance });
+            return constructor.Invoke(new[] { targetInstance, new InterfaceProxyGlue(factory) });
         }
 
 
@@ -127,11 +180,16 @@ namespace StardewModdingAPI.Framework.Reflection
         ** Private methods
         *********/
         /// <summary>Define a method which proxies access to a method on the target.</summary>
+        /// <param name="factory">The <see cref="InterfaceProxyFactory"/> that requested to build a proxy.</param>
         /// <param name="proxyBuilder">The proxy type being generated.</param>
         /// <param name="proxy">The proxy method.</param>
         /// <param name="target">The target method.</param>
         /// <param name="instanceField">The proxy field containing the API instance.</param>
-        private void ProxyMethod(TypeBuilder proxyBuilder, MethodInfo proxy, MethodInfo target, FieldBuilder instanceField)
+        /// <param name="glueField">The proxy field containing an <see cref="InterfaceProxyGlue"/>.</param>
+        /// <param name="positionsToProxy">Parameter type positions (or null for the return type) for which types should also be proxied.</param>
+        /// <param name="sourceModID">The unique ID of the mod consuming the API.</param>
+        /// <param name="targetModID">The unique ID of the mod providing the API.</param>
+        private void ProxyMethod(InterfaceProxyFactory factory, TypeBuilder proxyBuilder, MethodInfo proxy, MethodInfo target, FieldBuilder instanceField, FieldBuilder glueField, ISet<int?> positionsToProxy, string sourceModID, string targetModID)
         {
             MethodBuilder methodBuilder = proxyBuilder.DefineMethod(proxy.Name, MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.Virtual);
 
@@ -143,7 +201,8 @@ namespace StardewModdingAPI.Framework.Reflection
                 genericTypeParameterBuilders[i].SetGenericParameterAttributes(proxyGenericArguments[i].GenericParameterAttributes);
 
             // set up return type
-            methodBuilder.SetReturnType(proxy.ReturnType.IsGenericMethodParameter ? genericTypeParameterBuilders[proxy.ReturnType.GenericParameterPosition] : proxy.ReturnType);
+            Type returnType = proxy.ReturnType.IsGenericMethodParameter ? genericTypeParameterBuilders[proxy.ReturnType.GenericParameterPosition] : proxy.ReturnType;
+            methodBuilder.SetReturnType(returnType);
 
             // set up parameters
             Type[] argTypes = proxy.GetParameters()
@@ -152,18 +211,62 @@ namespace StardewModdingAPI.Framework.Reflection
                 .ToArray();
             methodBuilder.SetParameters(argTypes);
 
+            // proxy additional types
+            string returnValueProxyTypeName = null;
+            string[] parameterProxyTypeNames = new string[argTypes.Length];
+            if (positionsToProxy.Count > 0)
+            {
+                var targetParameters = target.GetParameters();
+                foreach (int? position in positionsToProxy)
+                {
+                    // we don't check for generics here, because earlier code does and generic positions won't end up here
+                    if (position == null) // it's the return type
+                    {
+                        var builder = factory.ObtainBuilder(target.ReturnType, proxy.ReturnType, sourceModID, targetModID);
+                        returnType = proxy.ReturnType;
+                        returnValueProxyTypeName = builder.ProxyType.FullName;
+                    }
+                    else // it's one of the parameters
+                    {
+                        var builder = factory.ObtainBuilder(targetParameters[position.Value].ParameterType, argTypes[position.Value], sourceModID, targetModID);
+                        argTypes[position.Value] = proxy.ReturnType;
+                        parameterProxyTypeNames[position.Value] = builder.ProxyType.FullName;
+                    }
+                }
+
+                methodBuilder.SetReturnType(returnType);
+                methodBuilder.SetParameters(argTypes);
+            }
+
             // create method body
             {
                 ILGenerator il = methodBuilder.GetILGenerator();
 
-                // load target instance
-                il.Emit(OpCodes.Ldarg_0);
-                il.Emit(OpCodes.Ldfld, instanceField);
+                void EmitCallInstance()
+                {
+                    // load target instance
+                    il.Emit(OpCodes.Ldarg_0);
+                    il.Emit(OpCodes.Ldfld, instanceField);
+
+                    // invoke target method on instance
+                    for (int i = 0; i < argTypes.Length; i++)
+                        il.Emit(OpCodes.Ldarg, i + 1);
+                    il.Emit(OpCodes.Callvirt, target);
+                }
 
-                // invoke target method on instance
-                for (int i = 0; i < argTypes.Length; i++)
-                    il.Emit(OpCodes.Ldarg, i + 1);
-                il.Emit(OpCodes.Call, target);
+                if (returnValueProxyTypeName == null)
+                {
+                    EmitCallInstance();
+                }
+                else
+                {
+                    // this.Glue.CreateInstanceForProxyTypeName(proxyTypeName, this.Instance.Call(args))
+                    il.Emit(OpCodes.Ldarg_0);
+                    il.Emit(OpCodes.Ldfld, glueField);
+                    il.Emit(OpCodes.Ldstr, returnValueProxyTypeName);
+                    EmitCallInstance();
+                    il.Emit(OpCodes.Call, CreateInstanceForProxyTypeNameMethod);
+                }
 
                 // return result
                 il.Emit(OpCodes.Ret);
@@ -175,5 +278,11 @@ namespace StardewModdingAPI.Framework.Reflection
         {
             ReturnType, Parameter
         }
+
+        /// <summary>The result of matching a target and a proxy type.</summary>
+        private enum MatchingTypesResult
+        {
+            False, IfProxied, True
+        }
     }
 }
diff --git a/src/SMAPI/Framework/Reflection/InterfaceProxyFactory.cs b/src/SMAPI/Framework/Reflection/InterfaceProxyFactory.cs
index 5acba569..8ce187bf 100644
--- a/src/SMAPI/Framework/Reflection/InterfaceProxyFactory.cs
+++ b/src/SMAPI/Framework/Reflection/InterfaceProxyFactory.cs
@@ -35,26 +35,40 @@ namespace StardewModdingAPI.Framework.Reflection
         /// <param name="targetModID">The unique ID of the mod providing the API.</param>
         public TInterface CreateProxy<TInterface>(object instance, string sourceModID, string targetModID)
             where TInterface : class
+        {
+            // validate
+            if (instance == null)
+                throw new InvalidOperationException("Can't proxy access to a null API.");
+
+            // create instance
+            InterfaceProxyBuilder builder = this.ObtainBuilder(instance.GetType(), typeof(TInterface), sourceModID, targetModID);
+            return (TInterface)builder.CreateInstance(instance, this);
+        }
+
+        internal InterfaceProxyBuilder ObtainBuilder(Type targetType, Type interfaceType, string sourceModID, string targetModID)
         {
             lock (this.Builders)
             {
                 // validate
-                if (instance == null)
-                    throw new InvalidOperationException("Can't proxy access to a null API.");
-                if (!typeof(TInterface).IsInterface)
+                if (!interfaceType.IsInterface)
                     throw new InvalidOperationException("The proxy type must be an interface, not a class.");
 
                 // get proxy type
-                Type targetType = instance.GetType();
-                string proxyTypeName = $"StardewModdingAPI.Proxies.From<{sourceModID}_{typeof(TInterface).FullName}>_To<{targetModID}_{targetType.FullName}>";
+                string proxyTypeName = $"StardewModdingAPI.Proxies.From<{sourceModID}_{interfaceType.FullName}>_To<{targetModID}_{targetType.FullName}>";
                 if (!this.Builders.TryGetValue(proxyTypeName, out InterfaceProxyBuilder builder))
                 {
-                    builder = new InterfaceProxyBuilder(proxyTypeName, this.ModuleBuilder, typeof(TInterface), targetType);
+                    builder = new InterfaceProxyBuilder(this, proxyTypeName, this.ModuleBuilder, interfaceType, targetType, sourceModID, targetModID);
                     this.Builders[proxyTypeName] = builder;
                 }
+                return builder;
+            }
+        }
 
-                // create instance
-                return (TInterface)builder.CreateInstance(instance);
+        internal InterfaceProxyBuilder GetBuilderByProxyTypeName(string proxyTypeName)
+        {
+            lock (this.Builders)
+            {
+                return this.Builders.TryGetValue(proxyTypeName, out InterfaceProxyBuilder builder) ? builder : null;
             }
         }
     }
diff --git a/src/SMAPI/Framework/Reflection/InterfaceProxyGlue.cs b/src/SMAPI/Framework/Reflection/InterfaceProxyGlue.cs
new file mode 100644
index 00000000..4e027252
--- /dev/null
+++ b/src/SMAPI/Framework/Reflection/InterfaceProxyGlue.cs
@@ -0,0 +1,18 @@
+namespace StardewModdingAPI.Framework.Reflection
+{
+    public sealed class InterfaceProxyGlue
+    {
+        private readonly InterfaceProxyFactory Factory;
+
+        internal InterfaceProxyGlue(InterfaceProxyFactory factory)
+        {
+            this.Factory = factory;
+        }
+
+        public object CreateInstanceForProxyTypeName(string proxyTypeName, object toProxy)
+        {
+            var builder = this.Factory.GetBuilderByProxyTypeName(proxyTypeName);
+            return builder.CreateInstance(toProxy, this.Factory);
+        }
+    }
+}
-- 
cgit