#include <immintrin.h>
#include <inttypes.h>

/*
 * Granularity used to decide whether a wide load may cross into memory the
 * string does not occupy. 4 KiB is the smallest x86 page size and larger
 * pages are multiples of it, so checking 4 KiB boundaries is conservative.
 */
#define PAGE_SIZE 0x1000

/*
 * Hash a NUL-terminated string using AVX2 loads and AES-NI rounds.
 *
 * The string is consumed in 32-byte chunks measured from buf. Each chunk is
 * mixed into two 128-bit AES states: the low 16 bytes into hash0 and the high
 * 16 bytes into hash1. The chunk that holds the terminator is zero-filled from
 * the terminator onward. The result therefore depends only on the bytes of the
 * string, not on its address or on whatever memory follows it.
 *
 * Loads may read up to 31 bytes past the terminator, but never across a
 * PAGE_SIZE boundary that the string itself does not reach, so they cannot
 * fault. Memory checkers such as AddressSanitizer will still report them as
 * over-reads.
 *
 * The target attribute allows this file to be compiled without -mavx2/-maes.
 * Callers must only invoke the function on CPUs that support AVX2 and AES-NI.
 *
 * @param buf  NUL-terminated string; it is only read.
 * @return     64-bit hash of the string's bytes.
 */
__attribute__((target("avx2,aes")))
uint64_t
fast_vec_hash_cstring_avx2(char *buf)
{
    /* Round constants; brace order is {low qword, high qword}. */
    const __m128i k0 = {0x0807060504030201, 0x100F0E0D0C0B0A09};
    const __m128i k1 = {0x1117161514131211, 0x201F1E1D1C1B1A19};
    const __m256i zero = _mm256_setzero_si256();

    __m128i hash0 = {0, 0};
    __m128i hash1 = {0, 0};

    const char *cur = buf;
    __m256i chunk;
    unsigned endpos; /* index of the terminator within the final chunk */

    for (;;) {
        uintptr_t to_page_end = PAGE_SIZE - ((uintptr_t) cur & (PAGE_SIZE - 1));

        if (to_page_end < sizeof(chunk)) {
            /*
             * A 32-byte load from cur would cross a page boundary. Look for
             * the terminator among the bytes that remain on this page first.
             */
            unsigned n = 0;
            while (n < to_page_end && cur[n] != '\0')
                n++;
            if (n < to_page_end) {
                /* Terminator is on this page: hash a zero-padded copy. */
                char tail[32] = {0};
                for (unsigned i = 0; i < n; i++)
                    tail[i] = cur[i];
                chunk = _mm256_loadu_si256((const __m256i *) tail);
                endpos = n;
                break;
            }
            /*
             * No terminator before the boundary, so the string continues onto
             * the next page. That page is therefore mapped, and the full
             * unaligned load below touches only these two pages.
             */
        }

        chunk = _mm256_loadu_si256((const __m256i *) cur);
        unsigned mask = (unsigned) _mm256_movemask_epi8(_mm256_cmpeq_epi8(chunk, zero));
        if (mask != 0) {
            /* mask is non-zero, so __builtin_ctz is well defined. */
            endpos = (unsigned) __builtin_ctz(mask);
            break;
        }

        hash0 = _mm_aesenc_si128(hash0, k0);
        hash1 = _mm_aesenc_si128(hash1, k1);
        hash0 = _mm_aesenc_si128(hash0, _mm256_extracti128_si256(chunk, 0));
        hash1 = _mm_aesenc_si128(hash1, _mm256_extracti128_si256(chunk, 1));
        cur += sizeof(chunk);
    }

    /*
     * Clear the terminator and every byte after it. Those bytes are not part
     * of the string and must not influence the hash.
     */
    const __m256i index = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
                                           8, 9, 10, 11, 12, 13, 14, 15,
                                           16, 17, 18, 19, 20, 21, 22, 23,
                                           24, 25, 26, 27, 28, 29, 30, 31);
    const __m256i keep = _mm256_cmpgt_epi8(_mm256_set1_epi8((char) endpos), index);
    chunk = _mm256_and_si256(chunk, keep);

    /* Absorb the final (masked) chunk exactly like a full one. */
    hash0 = _mm_aesenc_si128(hash0, k0);
    hash1 = _mm_aesenc_si128(hash1, k1);
    hash0 = _mm_aesenc_si128(hash0, _mm256_extracti128_si256(chunk, 0));
    hash1 = _mm_aesenc_si128(hash1, _mm256_extracti128_si256(chunk, 1));

    /* Finalization rounds, then fold 2 x 128 bits down to 64 bits. */
    hash0 = _mm_aesenc_si128(hash0, k0);
    hash1 = _mm_aesenc_si128(hash1, k1);
    hash0 = _mm_aesenc_si128(hash0, k1);
    hash1 = _mm_aesenc_si128(hash1, k0);
    hash0 = _mm_aesenc_si128(hash0, k0);
    hash1 = _mm_aesenc_si128(hash1, k1);

    __m128i intermediate = hash1 ^ hash0;
    return intermediate[1] ^ intermediate[0];
}

