#include <inttypes.h>
#include <immintrin.h>

// Result type of the hash functions in this file.
typedef uint64_t uint64;

// Used by the `while (true)` loops below.
// NOTE(review): clashes with <stdbool.h>'s `true` if that header is ever
// included; consider including <stdbool.h> instead of defining it here.
#define true 1

// 32-bit FNV prime (0x01000193); per-lane multiplier of the non-AES mix().
#define FNV_PRIME 16777619

// Byte indices {0, 1, ..., 15}; used to build pshufb control vectors and
// per-byte "index < end" masks.
// NOTE(review): a lowercase macro with such a generic name can silently
// collide with other identifiers; consider a more specific name.
#define seq _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

// short aliases for intrinsics for readability
//
// ISA requirements: pshufb needs SSSE3, pmulld needs SSE4.1, tzcnt needs
// BMI1 and aesenc needs AES-NI; the remaining aliases are SSE2.
// NOTE(review): the latency/throughput notes below are the original
// author's; verify against current instruction tables before relying on them.

// AVX2  Latency 1, TP: 1 on Intel, 2-3 on Zen
#define broadcastb _mm_set1_epi8
#define broadcastd _mm_set1_epi32
// AVX Latency 1, TP: 2-4
#define cmpeqb _mm_cmpeq_epi8
#define cmpgtb _mm_cmpgt_epi8
// AVX Latency: 3 on Intel, 5 on AMD, TP: 1
#define movmskb _mm_movemask_epi8
// AVX Latency: 1 (2 for Zen4), TP: 2 (1 pre Ice Lake)
#define pshufb _mm_shuffle_epi8
// AVX Latency: 1, TP: 3 (4 for Zen4)
#define paddb _mm_add_epi8
// AVX Latency: 10 on Intel, 3-4 on AMD TP: 1 on Intel, 2 on >=Zen3
#define pmulld _mm_mullo_epi32
// AVX Latency: 1 TP: 2 on >=Ice Lake
#define psrldq _mm_srli_si128
#define psrld _mm_srli_epi32
// BMI1 Latency: 3, 2 on AMD (1 on Zen4), TP: 1 on Intel 2 on AMD
#define tzcnt _tzcnt_u32
// AES-NI: one AES encryption round (ShiftRows, SubBytes, MixColumns, key XOR)
#define aesenc _mm_aesenc_si128

static inline __m128i
align_vec_single(__m128i a, int offset)
{
    /*
     * Shift bytes to start of vector, replace shifted in bytes with 0
     * by having 0x70+seq+offset overflow into high bit, which makes pshufb
     * replace that byte with 0.
     *
     * Relying on compiler to lift the loop invariant paddb out of loops.
     **/ 
    return pshufb(a, paddb(seq, broadcastb(0x70+offset)));
}

static inline __m128i
align_vec(__m128i a, __m128i b, int offset)
{
    /*
     * Extract the 16 bytes starting at byte `offset` of the 32-byte
     * concatenation a:b, where a supplies bytes 0..15 and b bytes 16..31.
     * The top 16 - offset bytes of a move to the start of the result. The
     * first `offset` bytes of b fill the remaining top bytes. This is what
     * _mm_alignr_epi8(b, a, offset) computes, but that intrinsic needs a
     * compile-time count.
     *
     * Each half is a pshufb. Lanes that half must not supply get a control
     * byte with the sign bit set, which pshufb turns into 0, so the halves
     * can simply be OR-ed:
     *   - from a: control 0x70 + offset + i reaches >= 0x80 once
     *     i + offset >= 16;
     *   - from b: control offset + i - 16 is negative while i + offset < 16.
     * Valid for offset in 0..15.
     */
    const __m128i from_a = pshufb(a, paddb(seq, broadcastb(0x70 + offset)));
    const __m128i from_b = pshufb(b, paddb(seq, broadcastb(-16 + offset)));

    return _mm_or_si128(from_a, from_b);
}

static inline int
find_zeroes(__m128i a)
{
    return movmskb(cmpeqb(a, broadcastb(0)));
}

#ifdef HAS_AES
// lifted from https://github.com/tildeleb/aeshash/blob/master/aeshash.go
// Fixed AES round keys: key1 is used by both mix() and finalize(), key2 only
// by finalize().
// NOTE(review): 0xFF does not fit a signed char argument of _mm_setr_epi8;
// the implementation-defined conversion yields the intended byte on x86
// compilers but may warn under -Wconversion.
#define key1 _mm_setr_epi8(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,\
                           0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10)
#define key2 _mm_setr_epi8(0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,\
                           0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0xFF)

static inline __m128i
mix(__m128i hash, __m128i data)
{
    /*
     * Two AES encryption rounds. The first uses the fixed key1 as the round
     * key. The second uses the incoming data block as the round key, which
     * folds the data into the state.
     */
    return aesenc(aesenc(hash, key1), data);
}

static inline uint64
finalize(__m128i hash)
{
    hash = pshufb(hash, _mm_setr_epi8(0, 5, 10, 15,
                                      4, 9, 14, 3,
                                      8, 13, 2, 7,
                                      12, 1, 6, 11));
    hash = aesenc(hash, key1);
    hash = aesenc(hash, key2);
    hash = aesenc(hash, key1);
        
    return (hash ^ psrldq(hash, 8))[0];
}
    

#else
static inline __m128i
mix(__m128i hash, __m128i data)
{
    __m128i tmp = hash ^ data;
    return pmulld(tmp, broadcastd(FNV_PRIME)) ^ psrld(tmp, 17);
}

static inline uint64
finalize(__m128i hash)
{
    // Mix vector rotated by a single word so high word gets mixed in with
    // low word in 64bit result
    hash = mix(hash, pshufb(hash, paddb(seq, broadcastb(4))));
    return (hash ^ psrldq(hash, 8))[0];
}
#endif


/*
 * Fast vector hash calculates 4 parallel 32bit hashes across the data. The
 * finalizer is responsible for mixing the parallel hashes to a single value.
 * The string is consumed in 16-byte pieces, and a final partial piece is
 * padded with zeroes. A string whose length is a nonzero multiple of 16 gets
 * no extra all-zero piece. The empty string hashes one all-zero piece.
 *
 * It would be possible to pad to 4 byte alignment and skip the final hash
 * iteration for the hashes with a missing value, but hashing in zeroes is
 * faster and only improves mixing.
 *
 * fast_vec_hash_cstring_aligned() requires buf to be 16-byte aligned. It
 * loads whole aligned 16-byte blocks, so it reads up to 15 bytes past the
 * terminating NUL. Those bytes stay within the aligned block, and an aligned
 * block never crosses a page boundary. Because fast_vec_hash_cstring() may
 * dispatch here, both must return the same value for the same string.
 */
uint64
fast_vec_hash_cstring_aligned(__m128i hash, char *buf)
{
    const char *cur = buf;

    while (true)
    {
        __m128i chunk = _mm_load_si128((const __m128i *) cur);
        int mask = find_zeroes(chunk);

        if (mask) {
            int end = tzcnt(mask);

            /*
             * The terminator starts a block after the first one. The previous
             * block was entirely string bytes and has been mixed, so nothing
             * is left. Mixing this all-zero block would make lengths that are
             * multiples of 16 hash differently from fast_vec_hash_cstring().
             */
            if (end == 0 && cur != buf)
                return finalize(hash);

            // Mask out everything past the end
            chunk &= cmpgtb(broadcastb(end), seq);
            return finalize(mix(hash, chunk));
        }
        hash = mix(hash, chunk);
        cur += sizeof(chunk);
    }
}

/*
 * Unaligned version of vectorized hash performs alignment in SIMD vectors.
 * x86 supports unaligned loads, but we don't want to use it for two reasons.
 * Loads that straddle cache line boundaries are significantly slower than 
 * loads within a cacheline, even more so for loads across page boundaries.
 * More importantly loads crossing page boundaries might segfault if the next
 * page happens to be unallocated.
 *
 * Only aligned 16-byte loads are issued, beginning with the block that
 * contains buf. As a result, up to 15 bytes before buf and up to 15 bytes
 * after the terminating NUL are read, but never outside the aligned blocks
 * that hold the string. The string is hashed as consecutive 16-byte pieces
 * starting at buf. A final partial piece is zero-padded, and the empty
 * string hashes one all-zero piece. The result therefore does not depend on
 * buf's alignment.
 */
uint64
fast_vec_hash_cstring(__m128i hash, char *buf)
{
    int offset = ((uintptr_t) buf) & (sizeof(__m128i) - 1);

#ifdef SPECIAL_CASE_ALIGNED
    /*
     * Instruction analysis shows that inner loop is mixing latency bound
     * so alignment overhead should not matter.
     **/
    if (SPECIAL_CASE_ALIGNED && !offset)
        return fast_vec_hash_cstring_aligned(hash, buf);
#endif

    /*
     * Round down to the containing aligned block via uintptr_t. Computing
     * buf - offset would form a pointer before the start of the object.
     */
    const char *cur = (const char *)
        ((uintptr_t) buf & ~(uintptr_t) (sizeof(__m128i) - 1));
    __m128i chunk = _mm_load_si128((const __m128i *) cur);

    /*
     * Mask out the first offset bytes, which precede the string, so that a
     * zero byte there is not mistaken for the string end. The mask is built
     * from an unsigned value because left-shifting a negative value such as
     * ~0L is undefined behavior.
     */
    unsigned mask = (unsigned) find_zeroes(chunk) & (~0u << offset);
    // If string ends in first chunk can return immediately
    if (mask)
    {
        int end = tzcnt(mask);
        // Mask out everything past the end
        chunk &= cmpgtb(broadcastb(end), seq);
        return finalize(mix(hash, align_vec_single(chunk, offset)));
    }

    /*
     * Need to keep track of 2 vectors to perform alignment.
     *
     * _ <- already hashed, or before string
     * # <- data for next iteration if string did not end
     * prev                     chunk
     * [____|_012|3456|789A]    [BCDE|F###|####|####]
     *        ^offset
     */
    cur += sizeof(chunk);

    while (true)
    {
        __m128i prev = chunk;
        chunk = _mm_load_si128((const __m128i *) cur);
        mask = find_zeroes(chunk);
        // Found end of string
        if (mask) {
            int end = tzcnt(mask);
            // Mask everything past end of string with 0
            chunk &= cmpgtb(broadcastb(end), seq);
            // prev[offset..15] followed by chunk[0..offset-1]
            hash = mix(hash, align_vec(prev, chunk, offset));
            // chunk[offset..end-1] still holds string bytes only if end > offset
            if (end > offset) {
                hash = mix(hash, align_vec_single(chunk, offset));
            }
            return finalize(hash);
        }
        hash = mix(hash, align_vec(prev, chunk, offset));
        cur += sizeof(chunk);
    }
}

