#include <nmmintrin.h>

#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>
#include <string.h>

/*
 * Debug helper: print the 16 bytes of a register to stdout as unpadded
 * hexadecimal, lowest lane first, in groups of four, prefixed by "msg: ".
 */
void p128_hex_u8(const __m128i in, const char *msg) {
    uint8_t bytes[16];
    unsigned int v[16];

    _mm_storeu_si128((__m128i *)bytes, in);

    /*
     * %x consumes an unsigned int, whereas a uint8_t argument would be
     * promoted to (signed) int. Widen explicitly so that the argument
     * types match the format string.
     */
    for (int i = 0; i < 16; i++) {
        v[i] = bytes[i];
    }

    printf("%s: %x %x %x %x | %x %x %x %x | %x %x %x %x | %x %x %x %x\n",
           msg,
           v[0], v[1],  v[2],  v[3],  v[4],  v[5],  v[6],  v[7],
           v[8], v[9], v[10], v[11], v[12], v[13], v[14], v[15]);
}

/*
 * Error flags, one bit each, stored in the lookup tables below.
 *
 * The patterns beside each flag describe the pair of adjacent bytes
 * (byte 1 = earlier byte, byte 2 = later byte) that the flag stands for;
 * '_' marks a don't-care bit. A flag only survives classify() when the
 * entries selected from all three lookup tables contain it.
 */
#define TOO_SHORT   (1 << 0)	/* 11______ 0_______ */
								/* 11______ 11______ */
#define TOO_LONG    (1 << 1)	/* 0_______ 10______ */
#define OVERLONG_3  (1 << 2)	/* 11100000 100_____ */
#define SURROGATE   (1 << 4)	/* 11101101 101_____ */
#define OVERLONG_2  (1 << 5)	/* 1100000_ 10______ */
#define TWO_CONTS   (1 << 7)	/* 10______ 10______ */
#define TOO_LARGE   (1 << 3)	/* 11110100 1001____ */
								/* 11110100 101_____ */
								/* 11110101 1001____ */
								/* 11110101 101_____ */
								/* 1111011_ 1001____ */
								/* 1111011_ 101_____ */
								/* 11111___ 1001____ */
								/* 11111___ 101_____ */
#define TOO_LARGE_1000 (1 << 6)
								/* 11110101 1000____ */
								/* 1111011_ 1000____ */
								/* 11111___ 1000____ */
/*
 * OVERLONG_4 uses the same bit as TOO_LARGE_1000. Both appear together in
 * byte_1_high_table() and byte_2_high_table(); only byte_1_low_table()
 * tells them apart (OVERLONG_4 at ____0000, TOO_LARGE_1000 at ____0101
 * and above).
 * NOTE(review): the shared bit looks intentional -- confirm.
 */
#define OVERLONG_4  (1 << 6)	/* 11110000 1000____ */

/*
 * The patterns for these flags leave the low nibble of byte 1 as ____
 * (unconstrained), so every entry of byte_1_low_table() includes them.
 */
#define CARRY (TOO_SHORT | TOO_LONG | TWO_CONTS)

/*
 * Lookup table indexed by the high nibble of byte 1 (the earlier byte of
 * each adjacent pair). Each entry holds the error flags that are still
 * possible given that nibble alone.
 *
 * Declared with a (void) prototype so that calls with arguments are
 * rejected, and without a const-qualified return type, which has no
 * effect on a value returned by copy.
 */
static inline __m128i
byte_1_high_table(void)
{
	return _mm_setr_epi8(
		// 0_______ ________ <ASCII in byte 1>
		TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
		TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
		// 10______ ________ <continuation in byte 1>
		TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
		// 1100____ ________ <two byte lead in byte 1>
		TOO_SHORT | OVERLONG_2,
		// 1101____ ________ <two byte lead in byte 1>
		TOO_SHORT,
		// 1110____ ________ <three byte lead in byte 1>
		TOO_SHORT | OVERLONG_3 | SURROGATE,
		// 1111____ ________ <four+ byte lead in byte 1>
		TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4
	);
}

/*
 * Lookup table indexed by the low nibble of byte 1 (the earlier byte of
 * each adjacent pair). Every entry contains CARRY because the CARRY flags
 * do not depend on this nibble; the remaining flags narrow down the
 * overlong, surrogate and too-large cases.
 *
 * Declared with a (void) prototype so that calls with arguments are
 * rejected, and without a const-qualified return type, which has no
 * effect on a value returned by copy.
 */
static inline __m128i
byte_1_low_table(void)
{
	return _mm_setr_epi8(
		// ____0000 ________
		CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
		// ____0001 ________
		CARRY | OVERLONG_2,
		// ____001_ ________
		CARRY,
		CARRY,

		// ____0100 ________
		CARRY | TOO_LARGE,
		// ____0101 ________
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		// ____011_ ________
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000,

		// ____1___ ________ (1000 through 1100)
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		// ____1101 ________
		CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
		// ____111_ ________
		CARRY | TOO_LARGE | TOO_LARGE_1000,
		CARRY | TOO_LARGE | TOO_LARGE_1000
	);
}

/*
 * Lookup table indexed by the high nibble of byte 2 (the later byte of
 * each adjacent pair). Each entry holds the error flags that are still
 * possible given that nibble alone.
 *
 * Declared with a (void) prototype so that calls with arguments are
 * rejected, and without a const-qualified return type, which has no
 * effect on a value returned by copy.
 */
static inline __m128i
byte_2_high_table(void)
{
	return _mm_setr_epi8(
		// ________ 0_______ <ASCII in byte 2>
		TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
		TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,

		// ________ 1000____
		TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
		// ________ 1001____
		TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
		// ________ 101_____
		TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,
		TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,

		// ________ 11______
		TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT
	);
}


/*
 * Return a vector with every bit cleared.
 *
 * Declared with a (void) prototype so that calls with arguments are
 * rejected, and without a const-qualified return type, which has no
 * effect on a value returned by copy.
 */
static inline __m128i
vzero(void)
{
    return _mm_setzero_si128();
}

/* Broadcast one byte value into all sixteen 8-bit lanes of a vector. */
static inline __m128i
splat(uint8_t byte)
{
    /* The intrinsic takes a plain char; only the bit pattern matters. */
    const char lane = (char)byte;

    return _mm_set1_epi8(lane);
}

/*
 * Lane-wise signed comparison of 8-bit values: a lane of the result has
 * all bits set where v1 > v2 and is zero otherwise.
 */
static inline __m128i
greater_than(const __m128i v1, const __m128i v2)
{
    /* "v2 < v1" is the same predicate as "v1 > v2". */
    const __m128i is_greater = _mm_cmplt_epi8(v2, v1);

    return is_greater;
}

/*
 * Logical right shift of every 8-bit lane by n bits.
 *
 * There is no per-byte shift intrinsic, so the shift is done on 16-bit
 * lanes. That lets the low bits of each odd byte spill into the top of
 * the even byte below it; masking every byte with (0xFF >> n) clears
 * those spilled bits again.
 */
static inline __m128i
shift_right(const __m128i v, const int n)
{
    /* Bits that may legitimately remain in a byte after shifting by n. */
    const __m128i keep = splat(0xFF >> n);
    const __m128i shifted_words = _mm_srli_epi16(v, n);

    return _mm_and_si128(shifted_words, keep);
}

/* Bitwise AND of two 128-bit vectors. */
static inline __m128i
bitwise_and(const __m128i v1, const __m128i v2)
{
    const __m128i both = _mm_and_si128(v1, v2);

    return both;
}

/* Bitwise OR of two 128-bit vectors. */
static inline __m128i
bitwise_or(const __m128i v1, const __m128i v2)
{
    const __m128i either = _mm_or_si128(v1, v2);

    return either;
}

/* Bitwise exclusive OR of two 128-bit vectors. */
static inline __m128i
bitwise_xor(const __m128i v1, const __m128i v2)
{
    const __m128i differing = _mm_xor_si128(v1, v2);

    return differing;
}

/*
 * Lane-wise unsigned 8-bit subtraction (v1 - v2) that clamps at zero
 * instead of wrapping around when v2 is larger than v1.
 */
static inline __m128i
saturating_sub(const __m128i v1, const __m128i v2)
{
    const __m128i clamped_diff = _mm_subs_epu8(v1, v2);

    return clamped_diff;
}

/*
 * Build the vector of bytes located one position before each input byte:
 * lane 0 receives the last byte of "prev", and every other lane i
 * receives input lane i - 1.
 *
 * In terms of the intrinsic this is the low 128 bits of
 *
 *   ((input << 128) | prev) >> ((16 - 1) * 8)
 *
 * The byte count passed to the intrinsic must be a compile-time constant,
 * so each distance gets its own function (see prev2 and prev3).
 */
static inline __m128i
prev1(__m128i prev, __m128i input)
{
    /* 15 == 16 - 1: keep one byte of prev, fifteen bytes of input. */
    const __m128i one_back = _mm_alignr_epi8(input, prev, 15);

    return one_back;
}

/*
 * Build the vector of bytes located two positions before each input byte:
 * lanes 0-1 receive the last two bytes of "prev", and every other lane i
 * receives input lane i - 2.
 */
static inline __m128i
prev2(__m128i prev, __m128i input)
{
    /* 14 == 16 - 2: keep two bytes of prev, fourteen bytes of input. */
    const __m128i two_back = _mm_alignr_epi8(input, prev, 14);

    return two_back;
}

/*
 * Build the vector of bytes located three positions before each input
 * byte: lanes 0-2 receive the last three bytes of "prev", and every other
 * lane i receives input lane i - 3.
 */
static inline __m128i
prev3(__m128i prev, __m128i input)
{
    /* 13 == 16 - 3: keep three bytes of prev, thirteen bytes of input. */
    const __m128i three_back = _mm_alignr_epi8(input, prev, 13);

    return three_back;
}

/*
 * Table lookup on sixteen bytes at once: each 8-bit lane of "input" is
 * used as an index into "lookup", which is treated as a 16-entry byte
 * array. Callers in this file pass indices in the range 0-15.
 */
static inline __m128i
lookup_16(const __m128i input, __m128i lookup)
{
    /* Note the argument order: the table comes first, the indices second. */
    const __m128i selected = _mm_shuffle_epi8(lookup, input);

    return selected;
}

/*
 * Collapse a vector to a boolean: true if any bit is set in any lane,
 * false only when the whole vector is zero.
 *
 * The previous version returned true as soon as a single lane was zero,
 * which reports neither "something is set" nor "nothing is set".
 */
static inline bool
to_bool(const __m128i v)
{
	/* One mask bit per lane, set where that lane compares equal to zero. */
	const int zero_lanes = _mm_movemask_epi8(_mm_cmpeq_epi8(v, _mm_setzero_si128()));

	/* The vector is non-zero unless all sixteen lanes were zero. */
	return zero_lanes != 0xFFFF;
}


/*
 * Classify every pair of adjacent bytes (byte 1 = the byte before, byte 2
 * = the byte itself) by looking up three nibbles in the error tables and
 * AND-ing the results. A flag remains set in a lane only if all three
 * tables agree that the pair matches that flag's pattern.
 *
 * prev:  the register preceding "input"; its last byte is paired with the
 *        first byte of "input".
 * input: the sixteen bytes to classify.
 *
 * Returns the per-lane flag bytes. Intermediate values are printed to
 * stdout for debugging.
 */
static __m128i
classify(const __m128i prev, const __m128i input)
{
	p128_hex_u8(input, "input");

	/*
	 * Byte 1 of every pair. Lane 0 must come from the tail of "prev" so
	 * that a pair straddling two registers is checked as well; shifting
	 * in zeros here would silently treat the previous byte as ASCII NUL.
	 */
	const __m128i input_shift1 = prev1(prev, input);

	const __m128i byte_1_high = shift_right(input_shift1, 4);
	const __m128i byte_1_low  = bitwise_and(input_shift1, splat(0x0F));
	const __m128i byte_2_high = shift_right(input, 4);

	p128_hex_u8(byte_1_high, "byte_1_high");
	p128_hex_u8(byte_1_low, "byte_1_low");
	p128_hex_u8(byte_2_high, "byte_2_high");

	const __m128i lookup_1_high = lookup_16(byte_1_high, byte_1_high_table());
	const __m128i lookup_1_low  = lookup_16(byte_1_low, byte_1_low_table());
	const __m128i lookup_2_high = lookup_16(byte_2_high, byte_2_high_table());

	p128_hex_u8(lookup_1_high, "lookup_1_high");
	p128_hex_u8(lookup_1_low, "lookup_1_low");
	p128_hex_u8(lookup_2_high, "lookup_2_high");

	/* Only flags present in all three lookups describe this byte pair. */
	const __m128i ret = bitwise_and(bitwise_and(lookup_1_high, lookup_1_low), lookup_2_high);
	p128_hex_u8(ret, "multiple conts.");

	return ret;
}

static __m128i
get_lead_byte_mask(const __m128i prev, const __m128i input, const __m128i special_cases)
{
	const __m128i input_shift2 = prev2(prev, input);
	const __m128i input_shift3 = prev3(prev, input);

	/*
	 * There is no unsigned comparison, so we use saturating subtraction
	 * followed by signed comparison with zero. 
	 */
	const __m128i is_third_byte  = saturating_sub(input_shift2, splat(0b11100000u-1));
	const __m128i is_fourth_byte = saturating_sub(input_shift3, splat(0b11110000u-1));

	// p128_hex_u8(is_third_byte, "is_third_byte");
	// p128_hex_u8(is_fourth_byte, "is_fourth_byte");

	const __m128i temp = bitwise_or(is_third_byte, is_fourth_byte);

	/*
	 * If the continuation matches to a valid header, set all bits for that byte
	 */
	const __m128i must23 = greater_than(temp, vzero());
	p128_hex_u8(must23, "is_3rd_or_4th");

	/* greater_than() sets all bits in the result when true. We want to compare 
	 * with the result of the classifier so apply a mask to set only the high bit. */
	const __m128i must23_80 =  bitwise_and(must23, splat(0x80));
	return must23_80;

}

/*
 * Register handed to classify() and get_lead_byte_mask() as the chunk
 * preceding the current input. main() sets it to all zeros before
 * processing its single input register.
 */
static __m128i Previous;

/*
 * Demo driver: load one 16-byte register from a sample string, run the
 * classifier and the 3rd/4th-byte check on it, and print the resulting
 * per-lane error bytes (plus the intermediate values printed by the
 * helpers).
 */
int main(void)
{
	Previous = vzero();

	/*
	 * Sample input mixing sequence lengths.
	 * NOTE(review): assumes this source file is saved as UTF-8, in which
	 * case the literal is 10 bytes long (1 + 2 + 3 + 4) -- confirm.
	 */
	const char *raw = "9¢鏡🙂";
	char buffer[16];

	/*
	 * Zero-pad the register, and never copy more than it can hold even
	 * if the sample string is later replaced by a longer one.
	 */
	size_t len = strlen(raw);
	if (len > sizeof buffer) {
		len = sizeof buffer;
	}
	memset(buffer, 0, sizeof buffer);
	memcpy(buffer, raw, len);

	const __m128i input = _mm_loadu_si128((const __m128i *)buffer);

	/* Flags for every adjacent byte pair. */
	const __m128i special_cases = classify(Previous, input);

	/* 0x80 wherever a 3rd or 4th byte of a sequence is required. */
	const __m128i mask = get_lead_byte_mask(Previous, input, special_cases);

	/*
	 * A continuation after a continuation sets 0x80 (TWO_CONTS) in
	 * special_cases, which is legitimate exactly where the mask also has
	 * 0x80. XOR cancels those, so all valid bytes give zero.
	 */
	const __m128i error = bitwise_xor(mask, special_cases);
	p128_hex_u8(error, "error");

	return 0;
}