#include <math.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * Short integer aliases used throughout this file.
 *
 * NOTE(review): int64/uint64 are spelled as (unsigned) long, so they are only
 * 64 bits wide on targets where long is 64 bits (e.g. LP64) -- confirm this
 * is acceptable for every platform the test is built on.
 */
/* uint8 is not referenced by any code visible in this file. */
typedef unsigned char uint8;
typedef long int64;
typedef unsigned long uint64;
/* unsigned __int128 is a compiler extension; modular_multiply() relies on it. */
typedef unsigned __int128 uint128;
/* Common signature of permute() and permute2(): (data, size, seed) -> value. */
typedef int64 (*permute_fn)(const int64, const int64, const int64);

/* Number of entries in the primes[] multiplier table. */
#define PRP_PRIMES 16

/*
 * Multipliers for the multiply-and-add step of permute().  permute() picks a
 * starting entry from its key and skips any entry that divides the size.
 * All values lie between 2^23 and 2^24.
 *
 * NOTE(review): primality is assumed from the table's name, it is not checked
 * anywhere in this file.
 *
 * The table is read-only, hence const.
 */
static const uint64 primes[PRP_PRIMES] = {
    8388617,
    8912921,
    9437189,
    9961487,
    10485767,
    11010059,
    11534351,
    12058679,
    12582917,
    13107229,
    13631489,
    14155777,
    14680067,
    15204391,
    15728681,
    16252967
};

/* Number of mixing rounds executed by the loop in permute(). */
#define PRP_ROUNDS 4

/*
 * Smear the highest set bit of n into every lower bit position, i.e. return
 * the smallest value of the form 2^k - 1 that is >= n (0 stays 0).
 */
static uint64
compute_mask(uint64 n)
{
    unsigned int shift;

    /* Shifts of 1, 2, 4, 8, 16 and 32 bits cover a 64-bit word. */
    for (shift = 1; shift <= 32; shift <<= 1)
        n |= n >> shift;

    return n;
}

/*
 * Return (x * y) mod m.  The product is formed in 128 bits so it cannot
 * wrap before the reduction.  m must be non-zero.
 */
static uint64
modular_multiply(uint64 x, uint64 y, const uint64 m)
{
    const uint128 product = (uint128) x * y;

    return (uint64) (product % m);
}

/*
 * Multiplier and increment of the linear congruential step that permute()
 * applies to its key (key = key * DK_LCG_MUL + DK_LCG_INC).
 * NOTE(review): the values match the constants usually attributed to Knuth's
 * MMIX generator, which is presumably what the DK prefix refers to -- confirm.
 */
#define DK_LCG_MUL 6364136223846793005L
#define DK_LCG_INC 1442695040888963407L

/* permute() drops this many low-order key bits before using the key. */
#define LCG_SHIFT 13

static int64
permute(const int64 data, const int64 isize, const int64 seed)
{
    uint64      size = (uint64) isize;
    uint64      v = (uint64) data % size;
    uint64      key = (uint64) seed;
    uint64      mask = compute_mask(size - 1) >> 1;

    if (isize == 1)
        return 0;

    for (unsigned int i = 0, p = key % PRP_PRIMES;
         i < PRP_ROUNDS; i++, p = (p + 1) % PRP_PRIMES)
    {
        uint64      t;

        key = key * DK_LCG_MUL + DK_LCG_INC;
        if (v <= mask)
            v ^= (key >> LCG_SHIFT) & mask;

        key = key * DK_LCG_MUL + DK_LCG_INC;
        t = size - 1 - v;
        if (t <= mask)
        {
            t ^= (key >> LCG_SHIFT) & mask;
            v = size - 1 - t;
        }

        while (size % primes[p] == 0)
            p = (p + 1) % PRP_PRIMES;

        key = key * DK_LCG_MUL + DK_LCG_INC;

        if ((v & 0xffffffffffL) == v)
            v = (primes[p] * v + (key >> LCG_SHIFT)) % size;
        else
            v = (modular_multiply(primes[p], v, size) +
                (key >> LCG_SHIFT)) % size;
    }

    return (int64) v;
}

static int64
permute2(const int64 data, const int64 isize, const int64 seed)
{
    unsigned short eseed[] = { (seed >> 32) & 0xffff,
                               (seed >> 16) &0xffff,
                               seed & 0xffff };
    uint64      size = (uint64) isize;
    uint64      v = (uint64) data % size;
    uint64      mask;
    uint64      top_bit;
    int         i;

    if (isize == 1)
        return 0;

    // This choice of mask satisfies size/2 <= mask <= size-1
    mask = compute_mask(size - 1);
    if (mask >= size) mask >>= 1;

    // Most significant bit of mask
    top_bit = (mask + 1) >> 1;

    for (i = 0; i < 4; i++)
    {
        uint64 m;
        uint64 r;
        uint64 t;

        m = (uint64) (erand48(eseed) * (mask + 1)) | 1;
        r = (uint64) (erand48(eseed) * (mask + 1));
        if (v <= mask)
        {
          v = ((v * m) ^ r) & mask;
          v = ((v << 1) & mask) | (v & top_bit ? 1 : 0);
        }

        r = (uint64) (erand48(eseed) * size);
        v = (v + r) % size;

        m = (uint64) (erand48(eseed) * (mask + 1)) | 1;
        r = (uint64) (erand48(eseed) * (mask + 1));
        t = size - 1 - v;
        if (t <= mask)
        {
          t = ((t * m) ^ r) & mask;
          t = ((t << 1) & mask) | (t & top_bit ? 1 : 0);
          v = size - 1 - t;
        }

        r = (uint64) (erand48(eseed) * size);
        v = (v + r) % size;
    }

    return (int64) v;
}

static int int64_cmp(const void *a, const void *b)
{
  int64 x = *((int64 *) a);
  int64 y = *((int64 *) b);
  return x - y;
}

int main()
{
  permute_fn permute_fn = &permute2;
  unsigned short seed[3] = { 1234, 5678, 9012 };
  int s[] = { 1000, 1001, 1002, 1020, 1021, 1022, 1023, 1024, 1025, 1026,
              2000, 3000, 3900, 3950, 4000, 4090, 4092, 4094, 4095, 4096,
              4097, 4098, 4100, 4200, 4300, 4400, 4500, 5000, 6000, 7000,
              8000, 9000, 9973, 10000, 10001, 10005, 10006, 10007,
              (1<<14)-3, (1<<14)-2, (1<<14)-1, 1<<14, (1<<14)+1, (1<<14)+2,
              (1<<15)-(1<<14)-1, (1<<15)-(1<<14), (1<<15)-1, 1<<15,
              (1<<15)+1, (1<<15)+(1<<14)-1, (1<<15)+(1<<14), (1<<15)+(1<<14)+1,
              1<<16, 1<<17, 1<<18, 1<<19, 1<<20, 1<<21, 1<<22, -1 };
  int x;
  int y;
  int64 size;
  int64 seed64;
  int i;

  for (x = 0, size = s[x]; size > 0; size = s[++x])
  {
    for (y = 1; y <= 4; y++)
    {
      uint64 v1 = (uint64) (erand48(seed) * size);
      uint64 v2 = (v1 + y) % size;
      int N = 10000 + size * 2;
      int64 *rvals = malloc(N * sizeof(int64));
      int64 *pvals = malloc(N * sizeof(int64));
      int64 *dvals = malloc(N * sizeof(int64));
      double Dr = 0;
      double Dp = 0;
      double Dd = 0;
      double Kr;
      double Kp;
      double Kd;
      double D_alpha;

      // Check that we really get a permuation
      seed64 = erand48(seed) * (1L<<48);
      for (i = 0; i < size; i++)
      {
        pvals[i] = (*permute_fn)(i, size, seed64);
      }
      qsort(pvals, size, sizeof(int64), int64_cmp);
      for (i = 0; i < size; i++)
      {
        if (pvals[i] != i)
        {
          printf("permute() failed (size=%ld, seed=%ld)\n", size, seed64);
          for (int j = 0; j < size && j < 100; j++) printf(" %ld", pvals[j]);
          printf("\npvals[%ld] = %ld\n", i, pvals[i]);
          return 1;
        }
      }

      // Test how uniformly random it is with a bunch of different seeds
      for (i = 0; i < N; i++)
      {
        rvals[i] = (int64) (erand48(seed) * size);

        seed64 = erand48(seed) * (1L<<48);
        pvals[i] = (*permute_fn)(v1, size, seed64);
        dvals[i] = pvals[i] - (*permute_fn)(v2, size, seed64);
        if (dvals[i] < 0) dvals[i] += size;
      }
      qsort(rvals, N, sizeof(int64), int64_cmp);
      qsort(pvals, N, sizeof(int64), int64_cmp);
      qsort(dvals, N, sizeof(int64), int64_cmp);
/*
if (size == (1<<10) && y == 1) {
  int64 last=pvals[0];
  int c=1;
  printf("dvals:\n");
  for (i = 1; i < N; i++)
    if (dvals[i] != last) {
      printf("%ld,%d\n", last, c);
      last = dvals[i];
      c=1;
    } else c++;
  printf("%ld,%d\n", last, c);
}
*/
      // K-S test for uniformity
      for (i = 1; i <= N; i++)
      {
        double D;

        D = (double) i / N - (double) rvals[i-1] / size;
        if (D > Dr) Dr = D;
        D = (double) rvals[i-1] / size - (double) (i-1) / N;
        if (D > Dr) Dr = D;

        D = (double) i / N - (double) pvals[i-1] / size;
        if (D > Dp) Dp = D;
        D = (double) pvals[i-1] / size - (double) (i-1) / N;
        if (D > Dp) Dp = D;

        D = (double) i / N - (double) dvals[i-1] / size;
        if (D > Dd) Dd = D;
        D = (double) dvals[i-1] / size - (double) (i-1) / N;
        if (D > Dd) Dd = D;
      }
      free(rvals);
      free(pvals);
      free(dvals);

      Kr = Dr * sqrt(N);
      Kp = Dp * sqrt(N);
      Kd = Dd * sqrt(N);

      // Critical value by confidence level
      //   0.001  1.94947
      //   0.01   1.62762
      //   0.02   1.51743
      //   0.05   1.35810
      //   0.1    1.22385
      //   0.15   1.13795
      //   0.2    1.07275
      D_alpha = 1.94947 / sqrt(N);

      printf("size=%ld, v1=%ld, v2=%ld, N=%d:\n"
             "  Dr=%f, Kr=%f %s\n"
             "  Dp=%f, Kp=%f %s\n"
             "  Dd=%f, Kd=%f %s\n",
             size, v1, v2, N,
             Dr, Kr, Dr > D_alpha ? "non-uniform" : "uniform",
             Dp, Kp, Dp > D_alpha ? "non-uniform" : "uniform",
             Dd, Kd, Dd > D_alpha ? "non-uniform" : "uniform");
    }
  }

  return 0;
}
