package gregtech.api.objects;

import java.util.Arrays;

/**
 * Object-2-object map based on IntIntMap4a
 */
public class ObjMap<K, V> {

    private static final Object FREE_KEY = new Object();
    private static final Object REMOVED_KEY = new Object();

    /** Keys and values */
    private Object[] m_data;

    /** Value for the null key (if inserted into a map) */
    private Object m_nullValue;

    private boolean m_hasNull;

    /** Fill factor, must be between (0 and 1) */
    private final float m_fillFactor;
    /** We will resize a map once it reaches this size */
    private int m_threshold;
    /** Current map size */
    private int m_size;
    /** Mask to calculate the original position */
    private int m_mask;
    /** Mask to wrap the actual array pointer */
    private int m_mask2;

    public ObjMap(final int size, final float fillFactor) {
        if (fillFactor <= 0 || fillFactor >= 1) throw new IllegalArgumentException("FillFactor must be in (0, 1)");
        if (size <= 0) throw new IllegalArgumentException("Size must be positive!");
        final int capacity = arraySize(size, fillFactor);
        m_mask = capacity - 1;
        m_mask2 = capacity * 2 - 1;
        m_fillFactor = fillFactor;

        m_data = new Object[capacity * 2];
        Arrays.fill(m_data, FREE_KEY);

        m_threshold = (int) (capacity * fillFactor);
    }

    @SuppressWarnings("unchecked")
    public V get(final K key) {
        if (key == null) return (V) m_nullValue; // we null it on remove, so safe not to check a flag here

        int ptr = (key.hashCode() & m_mask) << 1;
        Object k = m_data[ptr];

        if (k == FREE_KEY) return null; // end of chain already
        if (k.equals(key)) // we check FREE and REMOVED prior to this call
            return (V) m_data[ptr + 1];
        while (true) {
            ptr = (ptr + 2) & m_mask2; // that's next index
            k = m_data[ptr];
            if (k == FREE_KEY) return null;
            if (k.equals(key)) return (V) m_data[ptr + 1];
        }
    }

    @SuppressWarnings("unchecked")
    public V put(final K key, final V value) {
        if (key == null) return insertNullKey(value);

        int ptr = getStartIndex(key) << 1;
        Object k = m_data[ptr];

        if (k == FREE_KEY) // end of chain already
        {
            m_data[ptr] = key;
            m_data[ptr + 1] = value;
            if (m_size >= m_threshold) rehash(m_data.length * 2); // size is set inside
            else++m_size;
            return null;
        } else if (k.equals(key)) // we check FREE and REMOVED prior to this call
        {
            final Object ret = m_data[ptr + 1];
            m_data[ptr + 1] = value;
            return (V) ret;
        }

        int firstRemoved = -1;
        if (k == REMOVED_KEY) firstRemoved = ptr; // we may find a key later

        while (true) {
            ptr = (ptr + 2) & m_mask2; // that's next index calculation
            k = m_data[ptr];
            if (k == FREE_KEY) {
                if (firstRemoved != -1) ptr = firstRemoved;
                m_data[ptr] = key;
                m_data[ptr + 1] = value;
                if (m_size >= m_threshold) rehash(m_data.length * 2); // size is set inside
                else++m_size;
                return null;
            } else if (k.equals(key)) {
                final Object ret = m_data[ptr + 1];
                m_data[ptr + 1] = value;
                return (V) ret;
            } else if (k == REMOVED_KEY) {
                if (firstRemoved == -1) firstRemoved = ptr;
            }
        }
    }

    @SuppressWarnings("unchecked")
    public V remove(final K key) {
        if (key == null) return removeNullKey();

        int ptr = getStartIndex(key) << 1;
        Object k = m_data[ptr];
        if (k == FREE_KEY) return null; // end of chain already
        else if (k.equals(key)) // we check FREE and REMOVED prior to this call
        {
            --m_size;
            if (m_data[(ptr + 2) & m_mask2] == FREE_KEY) m_data[ptr] = FREE_KEY;
            else m_data[ptr] = REMOVED_KEY;
            final V ret = (V) m_data[ptr + 1];
            m_data[ptr + 1] = null;
            return ret;
        }
        while (true) {
            ptr = (ptr + 2) & m_mask2; // that's next index calculation
            k = m_data[ptr];
            if (k == FREE_KEY) return null;
            else if (k.equals(key)) {
                --m_size;
                if (m_data[(ptr + 2) & m_mask2] == FREE_KEY) m_data[ptr] = FREE_KEY;
                else m_data[ptr] = REMOVED_KEY;
                final V ret = (V) m_data[ptr + 1];
                m_data[ptr + 1] = null;
                return ret;
            }
        }
    }

    @SuppressWarnings("unchecked")
    private V insertNullKey(final V value) {
        if (m_hasNull) {
            final Object ret = m_nullValue;
            m_nullValue = value;
            return (V) ret;
        } else {
            m_nullValue = value;
            ++m_size;
            return null;
        }
    }

    @SuppressWarnings("unchecked")
    private V removeNullKey() {
        if (m_hasNull) {
            final Object ret = m_nullValue;
            m_nullValue = null;
            m_hasNull = false;
            --m_size;
            return (V) ret;
        } else {
            return null;
        }
    }

    public int size() {
        return m_size;
    }

    @SuppressWarnings("unchecked")
    private void rehash(final int newCapacity) {
        m_threshold = (int) (newCapacity / 2 * m_fillFactor);
        m_mask = newCapacity / 2 - 1;
        m_mask2 = newCapacity - 1;

        final int oldCapacity = m_data.length;
        final Object[] oldData = m_data;

        m_data = new Object[newCapacity];
        Arrays.fill(m_data, FREE_KEY);

        m_size = m_hasNull ? 1 : 0;

        for (int i = 0; i < oldCapacity; i += 2) {
            final Object oldKey = oldData[i];
            if (oldKey != FREE_KEY && oldKey != REMOVED_KEY) put((K) oldKey, (V) oldData[i + 1]);
        }
    }

    public int getStartIndex(final Object key) {
        // key is not null here
        return key.hashCode() & m_mask;
    }

    /* Taken from FastUtil implementation */

    /**
     * Return the least power of two greater than or equal to the specified value.
     *
     * <p>
     * Note that this function will return 1 when the argument is 0.
     *
     * @param x a long integer smaller than or equal to 2<sup>62</sup>.
     * @return the least power of two greater than or equal to the specified value.
     */
    public static long nextPowerOfTwo(long x) {
        if (x == 0) return 1;
        x--;
        x |= x >> 1;
        x |= x >> 2;
        x |= x >> 4;
        x |= x >> 8;
        x |= x >> 16;
        return (x | x >> 32) + 1;
    }

    /**
     * Returns the least power of two smaller than or equal to 2<sup>30</sup> and larger than or equal to
     * <code>Math.ceil( expected / f )</code>.
     *
     * @param expected the expected number of elements in a hash table.
     * @param f        the load factor.
     * @return the minimum possible size for a backing array.
     * @throws IllegalArgumentException if the necessary size is larger than 2<sup>30</sup>.
     */
    public static int arraySize(final int expected, final float f) {
        final long s = Math.max(2, nextPowerOfTwo((long) Math.ceil(expected / f)));
        if (s > (1 << 30)) throw new IllegalArgumentException(
            "Too large (" + expected + " expected elements with load factor " + f + ")");
        return (int) s;
    }

    // taken from FastUtil
    private static final int INT_PHI = 0x9E3779B9;

    public static int phiMix(final int x) {
        final int h = x * INT_PHI;
        return h ^ (h >> 16);
    }
}