From 530b120014c0ae7fc2994b21fc388ea36ddb4ce8 Mon Sep 17 00:00:00 2001
From: Jesse Plamondon-Willard <github@jplamondonw.com>
Date: Sun, 8 Jul 2018 15:48:32 -0400
Subject: rewrite TypeReference comparison to handle more edge cases, exit
 earlier if possible, and encapsulate a bit more

---
 .../Framework/ModLoading/TypeReferenceComparer.cs  | 304 ++++++++++-----------
 1 file changed, 148 insertions(+), 156 deletions(-)

diff --git a/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs b/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
index 8d128b37..f7497789 100644
--- a/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
+++ b/src/SMAPI/Framework/ModLoading/TypeReferenceComparer.cs
@@ -1,21 +1,23 @@
+using System;
 using System.Collections.Generic;
-using System.Text.RegularExpressions;
+using System.Linq;
 using Mono.Cecil;
 
 namespace StardewModdingAPI.Framework.ModLoading
 {
     /// <summary>Performs heuristic equality checks for <see cref="TypeReference"/> instances.</summary>
+    /// <remarks>
+    /// This implementation compares <see cref="TypeReference"/> instances to see if they likely
+    /// refer to the same type. While the implementation is obvious for types like <c>System.Bool</c>,
+    /// this class mainly exists to handle cases like <c>System.Collections.Generic.Dictionary`2&lt;!0,Netcode.NetRoot`1&lt;!1&gt;&gt;</c>
+    /// and <c>System.Collections.Generic.Dictionary`2&lt;TKey,Netcode.NetRoot`1&lt;TValue&gt;&gt;</c>
+    /// which are compatible, but not directly comparable. It does this by splitting each type name
+    /// into its component token types, and performing placeholder substitution (e.g. <c>!0</c> to
+    /// <c>TKey</c> in the above example). If all components are equal after substitution, and the
+    /// tokens can all be mapped to the same generic type, the types are considered equal.
+    /// </remarks>
     internal class TypeReferenceComparer : IEqualityComparer<TypeReference>
     {
-        /*********
-        ** Properties
-        *********/
-        /// <summary>A pattern matching type name substrings to strip for display.</summary>
-        private readonly Regex StripTypeNamePattern = new Regex(@"`\d+(?=<)", RegexOptions.Compiled);
-
-        private List<char> symbolBoundaries = new List<char> { '<', '>', ',' };
-
-
         /*********
         ** Public methods
         *********/
@@ -24,25 +26,13 @@ namespace StardewModdingAPI.Framework.ModLoading
         /// <param name="b">The second object to compare.</param>
         public bool Equals(TypeReference a, TypeReference b)
         {
-            string typeA = this.GetComparableTypeID(a);
-            string typeB = this.GetComparableTypeID(b);
-
-            string placeholderType = "", actualType = "";
+            if (a == null || b == null)
+                return a == b;
 
-            if (this.HasPlaceholder(typeA))
-            {
-                placeholderType = typeA;
-                actualType = typeB;
-            }
-            else if (this.HasPlaceholder(typeB))
-            {
-                placeholderType = typeB;
-                actualType = typeA;
-            }
-            else
-                return typeA == typeB;
-
-            return this.PlaceholderTypeValidates(placeholderType, actualType);
+            return
+                a == b
+                || a.FullName == b.FullName
+                || this.HeuristicallyEquals(a, b);
         }
 
         /// <summary>Get a hash code for the specified object.</summary>
@@ -57,153 +47,155 @@ namespace StardewModdingAPI.Framework.ModLoading
         /*********
         ** Private methods
         *********/
-        /// <summary>Get a unique string representation of a type.</summary>
-        /// <param name="type">The type reference.</param>
-        private string GetComparableTypeID(TypeReference type)
+        /// <summary>Get whether two types are heuristically equal based on generic type token substitution.</summary>
+        /// <param name="typeA">The first type to compare.</param>
+        /// <param name="typeB">The second type to compare.</param>
+        private bool HeuristicallyEquals(TypeReference typeA, TypeReference typeB)
         {
-            return this.StripTypeNamePattern.Replace(type.FullName, "");
-        }
-
-        /// <summary>Determine whether this type ID has a placeholder such as !0.</summary>
-        /// <param name="typeID">The type to check.</param>
-        /// <returns>true if the type ID contains a placeholder, false if not.</returns>
-        private bool HasPlaceholder(string typeID)
-        {
-            return typeID.Contains("!0");
-        }
-
-        /// <summary> returns whether this type ID is a placeholder, i.e., it begins with "!".</summary>
-        /// <param name="symbol">The symbol to validate.</param>
-        /// <returns>true if the symbol is a placeholder, false if not</returns>
-        private bool IsPlaceholder(string symbol)
-        {
-            return symbol.StartsWith("!");
-        }
-
-        /// <summary> Traverses and parses out symbols from a type which does not contain placeholder values.</summary>
-        /// <param name="type">The type to traverse.</param>
-        /// <param name="typeSymbols">A List in which to store the parsed symbols.</param>
-        private void TraverseActualType(string type, List<SymbolLocation> typeSymbols)
-        {
-            int depth = 0;
-            string symbol = "";
-
-            foreach (char c in type)
+            bool HeuristicallyEquals(string typeNameA, string typeNameB, IDictionary<string, string> tokenMap)
             {
-                if (this.symbolBoundaries.Contains(c))
+                // analyse type names
+                bool hasTokensA = typeNameA.Contains("!");
+                bool hasTokensB = typeNameB.Contains("!");
+                bool isTokenA = hasTokensA && typeNameA[0] == '!';
+                bool isTokenB = hasTokensB && typeNameB[0] == '!';
+
+                // validate
+                if (!hasTokensA && !hasTokensB)
+                    return typeNameA == typeNameB; // no substitution needed
+                if (hasTokensA && hasTokensB)
+                    throw new InvalidOperationException("Can't compare two type names when both contain generic type tokens.");
+
+                // perform substitution if applicable
+                if (isTokenA)
+                    typeNameA = this.MapPlaceholder(placeholder: typeNameA, type: typeNameB, map: tokenMap);
+                if (isTokenB)
+                    typeNameB = this.MapPlaceholder(placeholder: typeNameB, type: typeNameA, map: tokenMap);
+
+                // compare inner tokens
+                string[] symbolsA = this.GetTypeSymbols(typeNameA).ToArray();
+                string[] symbolsB = this.GetTypeSymbols(typeNameB).ToArray();
+                if (symbolsA.Length != symbolsB.Length)
+                    return false;
+
+                for (int i = 0; i < symbolsA.Length; i++)
                 {
-                    typeSymbols.Add(new SymbolLocation(symbol, depth));
-                    symbol = "";
-                    switch (c)
-                    {
-                        case '<':
-                            depth++;
-                            break;
-                        case '>':
-                            depth--;
-                            break;
-                        default:
-                            break;
-                    }
+                    if (!HeuristicallyEquals(symbolsA[i], symbolsB[i], tokenMap))
+                        return false;
                 }
-                else
-                    symbol += c;
-            }
-        }
 
-        /// <summary> Determines whether two symbols in a type ID match, accounting for placeholders such as !0.</summary>
-        /// <param name="symbolA">A symbol in a typename which contains placeholders.</param>
-        /// <param name="symbolB">A symbol in a typename which does not contain placeholders.</param>
-        /// <param name="placeholderMap">A dictionary containing a mapping of placeholders to concrete types.</param>
-        /// <returns>true if the symbols match, false if not.</returns>
-        private bool SymbolsMatch(SymbolLocation symbolA, SymbolLocation symbolB, Dictionary<string, string> placeholderMap)
-        {
-            if (symbolA.depth != symbolB.depth)
-                return false;
-
-            if (!this.IsPlaceholder(symbolA.symbol))
-            {
-                return symbolA.symbol == symbolB.symbol;
+                return true;
             }
 
-            if (placeholderMap.ContainsKey(symbolA.symbol))
-            {
-                return placeholderMap[symbolA.symbol] == symbolB.symbol;
-            }
+            return HeuristicallyEquals(typeA.FullName, typeB.FullName, new Dictionary<string, string>());
+        }
 
-            placeholderMap[symbolA.symbol] = symbolB.symbol;
+        /// <summary>Map a generic type placeholder (like <c>!0</c>) to its actual type.</summary>
+        /// <param name="placeholder">The token placeholder.</param>
+        /// <param name="type">The actual type.</param>
+        /// <param name="map">The map of token to map substitutions.</param>
+        /// <returns>Returns the previously-mapped type if applicable, else the <paramref name="type"/>.</returns>
+        private string MapPlaceholder(string placeholder, string type, IDictionary<string, string> map)
+        {
+            if (map.TryGetValue(placeholder, out string result))
+                return result;
 
-            return true;
+            map[placeholder] = type;
+            return type;
         }
 
-        /// <summary> Determines whether a type which has placeholders correctly resolves to the concrete type provided. </summary>
-        /// <param name="type">A type containing placeholders such as !0.</param>
-        /// <param name="typeSymbols">The list of symbols extracted from the concrete type.</param>
-        /// <returns>true if the type resolves correctly, false if not.</returns>
-        private bool PlaceholderTypeResolvesToActualType(string type, List<SymbolLocation> typeSymbols)
+        /// <summary>Get the top-level type symbols in a type name (e.g. <code>List</code> and <code>NetRef&lt;T&gt;</code> in <code>List&lt;NetRef&lt;T&gt;&gt;</code>)</summary>
+        /// <param name="typeName">The full type name.</param>
+        private IEnumerable<string> GetTypeSymbols(string typeName)
         {
-            Dictionary<string, string> placeholderMap = new Dictionary<string, string>();
+            int openGenerics = 0;
 
-            int depth = 0, symbolCount = 0;
+            Queue<char> queue = new Queue<char>(typeName);
             string symbol = "";
-
-            foreach (char c in type)
+            while (queue.Any())
             {
-                if (this.symbolBoundaries.Contains(c))
+                char ch = queue.Dequeue();
+                switch (ch)
                 {
-                    bool match = this.SymbolsMatch(new SymbolLocation(symbol, depth), typeSymbols[symbolCount], placeholderMap);
-                    if (typeSymbols.Count <= symbolCount ||
-                        !match)
-                        return false;
-
-                    symbolCount++;
-                    symbol = "";
-                    switch (c)
-                    {
-                        case '<':
-                            depth++;
-                            break;
-                        case '>':
-                            depth--;
-                            break;
-                        default:
-                            break;
-                    }
+                    // skip `1 generic type identifiers
+                    case '`':
+                        while (int.TryParse(queue.Peek().ToString(), out int _))
+                            queue.Dequeue();
+                        break;
+
+                    // start generic args
+                    case '<':
+                        switch (openGenerics)
+                        {
+                            // start new generic symbol
+                            case 0:
+                                yield return symbol;
+                                symbol = "";
+                                openGenerics++;
+                                break;
+
+                            // continue accumulating nested type symbol
+                            default:
+                                symbol += ch;
+                                openGenerics++;
+                                break;
+                        }
+                        break;
+
+                    // generic args delimiter
+                    case ',':
+                        switch (openGenerics)
+                        {
+                            // invalid
+                            case 0:
+                                throw new InvalidOperationException($"Encountered unexpected comma in type name: {typeName}.");
+
+                            // start next generic symbol
+                            case 1:
+                                yield return symbol;
+                                symbol = "";
+                                break;
+
+                            // continue accumulating nested type symbol
+                            default:
+                                symbol += ch;
+                                break;
+                        }
+                        break;
+
+
+                    // end generic args
+                    case '>':
+                        switch (openGenerics)
+                        {
+                            // invalid
+                            case 0:
+                                throw new InvalidOperationException($"Encountered unexpected closing generic in type name: {typeName}.");
+
+                            // end generic symbol
+                            case 1:
+                                yield return symbol;
+                                symbol = "";
+                                openGenerics--;
+                                break;
+
+                            // continue accumulating nested type symbol
+                            default:
+                                symbol += ch;
+                                openGenerics--;
+                                break;
+                        }
+                        break;
+
+                    // continue symbol
+                    default:
+                        symbol += ch;
+                        break;
                 }
-                else
-                    symbol += c;
             }
 
-            return true;
-        }
-
-        /// <summary>Determines whether a type with placeholders in it matches a type without placeholders.</summary>
-        /// <param name="placeholderType">The type with placeholders in it.</param>
-        /// <param name="actualType">The type without placeholders.</param>
-        /// <returns>true if the placeholder type can resolve to the actual type, false if not.</returns>
-        private bool PlaceholderTypeValidates(string placeholderType, string actualType)
-        {
-            List<SymbolLocation> typeSymbols = new List<SymbolLocation>();
-
-            this.TraverseActualType(actualType, typeSymbols);
-            return PlaceholderTypeResolvesToActualType(placeholderType, typeSymbols);
-        }
-
-
-
-        /*********
-        ** Inner classes
-        *********/
-        protected class SymbolLocation
-        {
-            public string symbol;
-            public int depth;
-
-            public SymbolLocation(string symbol, int depth)
-            {
-                this.symbol = symbol;
-                this.depth = depth;
-            }
+            if (symbol != "")
+                yield return symbol;
         }
     }
 }
-- 
cgit