/* From the TODO list:
 *    Allow INET + INT4 to increment the host part of an address, raising
 *    an error if the result would overflow (leave) the address's network.
 */

#include "postgres.h"

#include <sys/socket.h>

#include "fmgr.h"
#include "utils/inet.h"

PG_FUNCTION_INFO_V1(inet_inc);

Datum
inet_inc(PG_FUNCTION_ARGS)
{
	inet	*in  = PG_GETARG_INET_P(0), *out;
	int32	 inc = PG_GETARG_INT32(1);
	inet_struct *src, *dst;
	uint32	 netmask, host, newhost;
	int		 i;

	src = (inet_struct *)VARDATA(in);
	if (src->family != PGSQL_AF_INET)
		ereport(ERROR,
			(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
			 errmsg("Function \"inet_inc\" only supports AF_INET "
			                                                 "addresses")));

	/* avoid int32 overflow when bits == 0 */
	netmask = (src->bits == 0) ? 0 : (~((1 << (32 - src->bits)) - 1));

	/* if (inc doesn't fit in src->bits) overflow */
	if ((abs(inc) & ~netmask) != abs(inc))
		ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("Increment (%d) too big for network (/%d)",
				                                         inc, src->bits)));

	/* can do this with htonl/ntohl */
	host = 0;
	for (i=0; i<4; ++i)
		host |= src->ipaddr[i] << (8 * (3-i));

	if ((host & ~netmask) == 0)
		ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("Trying to increment a network (%d.%d.%d.%d/%d) rather "
			        "than a host", src->ipaddr[0], src->ipaddr[1],
				    src->ipaddr[2], src->ipaddr[3], src->bits)));

	newhost = host + inc;

	if (((host & netmask) != (newhost & netmask))
		|| (inc>0 && newhost<host)
		|| (inc<0 && newhost>host))
		ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("Increment (%d) takes address (%d.%d.%d.%d) out of its "
			        "network (/%d)", inc,
			        src->ipaddr[0], src->ipaddr[1], src->ipaddr[2],
			        src->ipaddr[3], src->bits)));

	out = (inet *)palloc0(VARHDRSZ + sizeof(inet_struct));

	dst = (inet_struct *)VARDATA(out);

	dst->family = src->family;
	dst->bits   = src->bits;
	dst->type   = src->type;
	for (i=0; i<4; ++i)
		dst->ipaddr[i] = (newhost >> (8 * (3-i))) & 0xff;
	for (i=4; i<16; ++i)
		dst->ipaddr[i] = 0;

	if ((inc < 0) && (newhost & ~netmask) == 0)
		ereport(ERROR,
			(errcode(ERRCODE_DATA_EXCEPTION),
			 errmsg("Increment returns a network (%d.%d.%d.%d/%d) rather "
			        "than a host", dst->ipaddr[0], dst->ipaddr[1],
				    dst->ipaddr[2], dst->ipaddr[3], dst->bits)));

	VARATT_SIZEP(out) = VARHDRSZ + sizeof(dst->family) + sizeof(dst->bits)
		+ sizeof(dst->type) + 4;

	PG_RETURN_INET_P(out);
}
