#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

/*
 * Update a CRC-32C register over a byte buffer using the SSE4.2 CRC32
 * instructions (8 bytes at a time, then 4, then single bytes).
 *
 * data: first byte to process; no alignment is required.
 * len:  number of bytes to process; a value <= 0 leaves crc untouched.
 * crc:  running CRC register. No inversion is applied on entry or exit,
 *       so callers do any pre/post conditioning themselves.
 *
 * Returns the updated CRC register.
 */
__attribute__((target("sse4.2")))
static uint32_t
crc32c_scalar(const unsigned char *data, ssize_t len, uint32_t crc)
{
    const unsigned char *p = data;
    size_t remaining;

    if (len <= 0)
    {
        return crc;
    }
    remaining = (size_t) len;

    /*
     * Words are loaded with memcpy rather than by dereferencing a cast
     * pointer: the buffer may be unaligned, and reading bytes through a
     * wider integer lvalue violates the aliasing rules.
     */
    while (remaining >= 8)
    {
        uint64_t chunk;

        memcpy(&chunk, p, sizeof chunk);
        crc = (uint32_t) _mm_crc32_u64(crc, chunk);
        p += 8;
        remaining -= 8;
    }

    /* Process one remaining full four-byte word, if any. */
    if (remaining >= 4)
    {
        uint32_t chunk;

        memcpy(&chunk, p, sizeof chunk);
        crc = _mm_crc32_u32(crc, chunk);
        p += 4;
        remaining -= 4;
    }

    /* Finish byte by byte. */
    while (remaining > 0)
    {
        crc = _mm_crc32_u8(crc, *p);
        p++;
        remaining--;
    }

    return crc;
}

/*
 * Carry-less multiply within each 128-bit lane: immediate 0 multiplies the
 * low 64-bit halves of both operands, immediate 17 (0x11) the high halves.
 */
#define clmul_lo_(a, b) (_mm512_clmulepi64_epi128((a), (b), 0))
#define clmul_hi_(a, b) (_mm512_clmulepi64_epi128((a), (b), 17))

/*
 * Update a CRC-32C register over a byte buffer, folding 64 bytes per
 * iteration with AVX-512 carry-less multiplies and finishing any remainder
 * with the CRC32 instructions.
 *
 * data:   first byte to process; unaligned loads are used throughout.
 * length: number of bytes to process; a value <= 0 leaves crc untouched.
 * crc:    running CRC register. No inversion is applied on entry or exit.
 *
 * Returns the updated CRC register. The caller must ensure the CPU
 * supports AVX-512VL and VPCLMULQDQ before calling.
 */
__attribute__((target("avx512vl,vpclmulqdq")))
static uint32_t
crc32c_avx512(const unsigned char* data, ssize_t length, uint32_t crc)
{
	/* adjust names to match generated code */
	uint32_t	crc0 = crc;
	size_t		len;
	const unsigned char *buf = data;

	/* A negative length must not wrap around to a huge size_t. */
	if (length <= 0)
		return crc;
	len = (size_t) length;

	if (len >= 64)
	{
		const unsigned char *end = buf + len;
		const unsigned char *limit = buf + len - 64;
		__m128i		z0;

		/* First vector chunk. */
		__m512i		x0 = _mm512_loadu_si512((const void *) buf),
					y0;
		__m512i		k;

		/* Folding constants are taken verbatim from the generated code. */
		k = _mm512_broadcast_i32x4(_mm_setr_epi32(0x740eef02, 0, 0x9e4addf8, 0));

		/*
		 * Inject the incoming CRC into the lowest 32 bits.  Zero-extend
		 * explicitly: a plain cast leaves the upper 384 bits undefined and
		 * they would be XORed into the data.
		 */
		x0 = _mm512_xor_si512(_mm512_zextsi128_si512(_mm_cvtsi32_si128((int) crc0)), x0);
		buf += 64;

		/* Main loop: fold the accumulator into each further 64-byte chunk. */
		while (buf <= limit)
		{
			y0 = clmul_lo_(x0, k), x0 = clmul_hi_(x0, k);
			/* 0x96 is the truth table of a three-way XOR. */
			x0 = _mm512_ternarylogic_epi64(x0, y0, _mm512_loadu_si512((const void *) buf), 0x96);
			buf += 64;
		}

		/* Reduce 512 bits to 128 bits. */
		k = _mm512_setr_epi32(0x1c291d04, 0, 0xddc0152b, 0, 0x3da6d0cb, 0, 0xba4fc28e, 0, 0xf20c0dfe, 0, 0x493c7d27, 0, 0, 0, 0, 0);
		y0 = clmul_lo_(x0, k), k = clmul_hi_(x0, k);
		y0 = _mm512_xor_si512(y0, k);
		z0 = _mm_ternarylogic_epi64(_mm512_castsi512_si128(y0), _mm512_extracti32x4_epi32(y0, 1), _mm512_extracti32x4_epi32(y0, 2), 0x96);
		z0 = _mm_xor_si128(z0, _mm512_extracti32x4_epi32(x0, 3));

		/* Reduce 128 bits to 32 bits, and multiply by x^32. */
		crc0 = (uint32_t) _mm_crc32_u64(0, (uint64_t) _mm_extract_epi64(z0, 0));
		crc0 = (uint32_t) _mm_crc32_u64(crc0, (uint64_t) _mm_extract_epi64(z0, 1));
		len = (size_t) (end - buf);
	}

	/*
	 * Tail: whatever is left after the last full 64-byte chunk, or the whole
	 * input when it is shorter than 64 bytes.  memcpy keeps the loads valid
	 * for unaligned buffers.
	 */
	for (; len >= 8; buf += 8, len -= 8)
	{
		uint64_t	chunk;

		memcpy(&chunk, buf, sizeof chunk);
		crc0 = (uint32_t) _mm_crc32_u64(crc0, chunk);
	}
	for (; len > 0; --len)
		crc0 = _mm_crc32_u8(crc0, *buf++);

	return crc0;
}


/*
 * Return one pseudo-random byte drawn from rand().
 *
 * The modulus is 256 so that every byte value, including 0xFF, can occur.
 */
static uint8_t randomval(void)
{
    return (uint8_t) (rand() % 256);
}

/*
 * Self-check: fill a 64-byte buffer with pseudo-random bytes, compute its
 * CRC with both the SSE4.2 and the AVX-512 routine, and report whether the
 * two results agree.
 *
 * Returns EXIT_SUCCESS when the values match, and EXIT_FAILURE when they
 * differ or when the CPU cannot run the required instructions.
 */
int main(void) {
    enum { size = 64 };     /* a true constant, so arr is not a VLA */
    unsigned char arr[size];
    uint32_t avxcrc;
    uint32_t scalar_crc;
    int matched;

    /* Refuse to run rather than die with an illegal-instruction fault. */
    __builtin_cpu_init();
    if (!__builtin_cpu_supports("sse4.2") ||
        !__builtin_cpu_supports("avx512vl") ||
        !__builtin_cpu_supports("vpclmulqdq")) {
        fprintf(stderr, "Cannot run: CPU lacks SSE4.2, AVX-512VL or VPCLMULQDQ\n");
        return EXIT_FAILURE;
    }

    /* Initialize to random values (fixed seed for reproducibility) */
    srand(42);
    for (size_t ii = 0; ii < sizeof arr; ++ii) {
        arr[ii] = randomval();
    }

    /* Compute crc32c using simple scalar methods and SIMD method */
    avxcrc = crc32c_avx512(arr, size, 0xFFFFFFFF);
    scalar_crc = crc32c_scalar(arr, size, 0xFFFFFFFF);

    /* Report whether the two values are the same */
    matched = (scalar_crc == avxcrc);
    if (!matched) {
        printf("Failed! ");
    }
    else {
        printf("Success! ");
    }
    printf("0x%x, 0x%x\n", (unsigned int) scalar_crc, (unsigned int) avxcrc);

    /* Propagate the verdict so scripts and CI can detect a mismatch. */
    return matched ? EXIT_SUCCESS : EXIT_FAILURE;
}

