
/*
 * atts_in_list
 *		Report whether some candidate key in cKeys is fully covered by atts.
 *
 * cKeys is a List of CandidateKey *; atts is a List of Var *.  A key is
 * covered when, for every one of its columns i, at least one Var in that
 * column's equivalence list (ck->vars[i]) also appears in atts, comparing on
 * varno/varattno.  Returns true as soon as one covered key is found, and
 * false otherwise (always false when atts is NIL).
 *
 * Keys with no columns are skipped: there is nothing to test for them.
 */
bool
atts_in_list(List *cKeys, List *atts)
{
	ListCell   *ilist;

	if (atts == NIL)
		return false;

	foreach(ilist, cKeys)
	{
		CandidateKey *ck = (CandidateKey *) lfirst(ilist);
		bool		covered = true;
		int			i;

		/* don't read a coverage flag that no column ever set */
		if (ck->nVars <= 0)
			continue;

		/* stop at the first key column that has no match in atts */
		for (i = 0; i < ck->nVars && covered; i++)
		{
			ListCell   *keylist;

			covered = false;
			foreach(keylist, ck->vars[i])	/* each equivalent var of column i */
			{
				Var		   *keyvar = (Var *) lfirst(keylist);
				ListCell   *attlist;

				foreach(attlist, atts)	/* check each attr passed in */
				{
					Var		   *attvar = (Var *) lfirst(attlist);

					if (keyvar->varno == attvar->varno &&
						keyvar->varattno == attvar->varattno)
					{
						covered = true;
						break;
					}
				}
				if (covered)	/* this column is covered, exit early */
					break;
			}
		}

		if (covered)
			return true;
	}
	return false;
}

/*
 * convertUniqueIndexesToCandidateKeys
 *		Build one CandidateKey for each usable unique index of rel.
 *
 * Each key column gets a single-element equivalence list holding a new Var
 * with varno = rel->relid and varattno = the indexed attribute number.
 *
 * Partial indexes (indpred != NIL) are ignored, because they don't allow us
 * to conclude that all attribute values are distinct.  Expression indexes
 * are ignored too: an expression column has indexkeys[i] == 0 in
 * IndexOptInfo and cannot be represented by a plain column Var.
 *
 * The caller is responsible for cleaning up the returned list, keys and Vars.
 */
List *
convertUniqueIndexesToCandidateKeys(RelOptInfo *rel)
{
	ListCell   *ilist;
	List	   *result = NIL;

	foreach(ilist, rel->indexlist)
	{
		IndexOptInfo *index = (IndexOptInfo *) lfirst(ilist);
		CandidateKey *ckey;
		int			i;

		if (!index->unique || index->indpred != NIL)
			continue;

		/* skip indexes that contain any expression column */
		for (i = 0; i < index->ncolumns; i++)
		{
			if (index->indexkeys[i] == 0)
				break;
		}
		if (i < index->ncolumns)
			continue;

		ckey = palloc(sizeof(CandidateKey));
		ckey->nVars = index->ncolumns;
		ckey->vars = palloc(ckey->nVars * sizeof(List *));
		for (i = 0; i < index->ncolumns; i++)
		{
			/* makeNode zero-fills the Var and sets its node tag, so IsA() works */
			Var		   *var = makeNode(Var);

			var->varno = rel->relid;
			var->varattno = index->indexkeys[i];
			ckey->vars[i] = lcons(var, NIL);
		}
		result = lappend(result, ckey);
	}
	return result;
}

List *
unionCandidateKeys(List *inner, List *outer, JoinAttrs *eq)
{
	ListCell *inAtts, *outAtts;
	ListCell *ilist;

	if (inner == NIL)
		return outer;
	if (outer == NIL)
		return inner;

	inAtts = list_head(eq->innerVars);
	outAtts = list_head(eq->outerVars);

	/* build the equivalence and concat */
	while (inAtts != NULL)
	{
		Var *ivar = lfirst(inAtts);
		Var *ovar = lfirst(outAtts);
		foreach (ilist, inner)
		{
			CandidateKey *ckey = (CandidateKey *)lfirst(ilist);
			int i;
			for (i = 0; i < ckey->nVars; ++i)
			{
				List *vars = ckey->vars[i];
				List *appLst = NIL;
				ListCell *vlist;
				foreach (vlist, vars)
				{
					Var *v = lfirst(vlist);
					if (v->varno == ivar->varno && v->varattno == ivar->varattno)
					{
						appLst = lappend(appLst, ovar);
						break;
					}
					else if (v->varno == ovar->varno && v->varattno == ovar->varattno)
					{
						appLst = lappend(appLst, ivar);
						break;
					}
				}
				list_concat(vars, appLst);
			}
		}

		foreach (ilist, outer)
		{
			CandidateKey *ckey = (CandidateKey *)lfirst(ilist);
			int i;
			for (i = 0; i < ckey->nVars; ++i)
			{
				List *vars = ckey->vars[i];
				List *appLst = NIL;
				ListCell *vlist;
				foreach (vlist, vars)
				{
					Var *v = lfirst(vlist);
					if (v->varno == ivar->varno && v->varattno == ivar->varattno)
					{
						appLst = lappend(appLst, ovar);
						break;
					}
					else if (v->varno == ovar->varno && v->varattno == ovar->varattno)
					{
						appLst = lappend(appLst, ivar);
						break;
					}
				}
				list_concat(vars, appLst);
			}
		}

		inAtts = lnext(inAtts);
		outAtts = lnext(outAtts);
	}
	return list_concat(inner, outer);
}

/*
 * unionMNCandidateKeys
 *		Candidate keys for the output of a many-to-many join.
 *
 * For every (inner key, outer key) pair, a new key is built whose columns
 * are the inner key's columns followed by the outer key's columns.  The
 * per-column equivalence lists are shared with the input keys, not copied.
 *
 * If either input has no known keys, the join output has none either.  A key
 * of only one side is not unique in an M:N join, because each of its rows may
 * pair with several rows of the other side.  NIL is returned in that case.
 */
List *
unionMNCandidateKeys(List *inner, List *outer)
{
	ListCell   *ilc;
	ListCell   *olc;
	List	   *result = NIL;

	if (inner == NIL || outer == NIL)
		return NIL;

	foreach(ilc, inner)
	{
		CandidateKey *ikey = (CandidateKey *) lfirst(ilc);

		foreach(olc, outer)
		{
			CandidateKey *okey = (CandidateKey *) lfirst(olc);
			CandidateKey *newkey = palloc(sizeof(CandidateKey));
			int			i;
			int			j;

			newkey->nVars = ikey->nVars + okey->nVars;
			newkey->vars = palloc(newkey->nVars * sizeof(List *));

			for (i = 0; i < ikey->nVars; i++)
				newkey->vars[i] = ikey->vars[i];
			for (j = 0; j < okey->nVars; j++)
				newkey->vars[ikey->nVars + j] = okey->vars[j];

			result = lappend(result, newkey);
		}
	}
	return result;
}

List *
subtractCandidateKeys(List *keys, List *attrs)
{
	ListCell *klist;
	List *result = NIL;
	foreach (klist, keys)
	{
		CandidateKey *key = lfirst(klist);
		int i, nvars = key->nVars;
		for (i = 0; i < key->nVars; i++)
		{
			List *vars = key->vars[i];
			ListCell *vlist, *alist;
			foreach (vlist, vars)
			{
				Var *keyvar = lfirst(vlist);
				foreach (alist, attrs)
				{
					Var *attvar = lfirst(alist);
					if (keyvar->varno == attvar->varno && keyvar->varattno == attvar->varattno)
					{
						/* remove this attr */
						nvars--;
						//list_free(key->vars[i]);
						key->vars[i] = NIL;
						break;
					}
				}
				if (key->vars[i] == NIL)
					break;
			}
		}
		/* the adjusted key is now ready */
		if (nvars == 0)
		{
			// key is dead
			pfree(key->vars);
			//pfree(key);
			key->nVars = 0;
			key->vars = NULL;
		}
		else if(nvars == key->nVars)
		{
			/* key was unmodified */
			result = lappend(result, key);
		}
		else
		{
			int j;
			//CandidateKey *newKey = palloc(sizeof(CandidateKey));
			//newKey->nVars = nvars;
			//newKey->vars = palloc(sizeof(List *) * nvars);
			/* need to make a new key */
			for (j = 0, i = 0; j < key->nVars; j++)
			{
				if (key->vars[j] != NIL)
				{
					key->vars[i++] = key->vars[j];
				}
			}
			key->nVars = nvars;
			result = lappend(result, key);
		}
	}
	return result;
}


/*
 * getJoinCard
 *		Classify a join path by how many rows of each input can match.
 *
 * An input counts as "one" when one of its candidate keys is fully covered by
 * that input's equijoin Vars (atts_in_list), and as "many" otherwise.  The
 * result is named <inner>_to_<outer>: e.g. Card_One_to_Many means the inner
 * side is "one" and the outer side is "many".
 *
 * The candidate key lists are released with list_free.  The CandidateKey
 * structs themselves and the join attribute lists are not freed here.
 */
Join_Cardinality
getJoinCard(JoinPath *jpath, PlannerInfo *root)
{
	JoinAttrs	jattrs = getJoinAttrs(jpath, root);
	List	   *innerKeys = getCandidateKeys(jpath->innerjoinpath, root);
	List	   *outerKeys = getCandidateKeys(jpath->outerjoinpath, root);
	bool		innerIsOne = false;
	bool		outerIsOne = false;

	if (innerKeys != NIL)
	{
		innerIsOne = atts_in_list(innerKeys, jattrs.innerVars);
		list_free(innerKeys);	/* TODO: clean up the keys properly */
	}

	if (outerKeys != NIL)
	{
		outerIsOne = atts_in_list(outerKeys, jattrs.outerVars);
		list_free(outerKeys);	/* TODO: clean up the keys properly */
	}

	/* TODO: clean up join attrs */

	if (innerIsOne)
		return outerIsOne ? Card_One_to_One : Card_One_to_Many;
	return outerIsOne ? Card_Many_to_One : Card_Many_to_Many;
}

List *
getCandidateKeys(Path *path, PlannerInfo *root)
{
	List *result = NIL;

	switch (path->pathtype)
	{
	case T_Scan:
	case T_SeqScan:
	case T_IndexScan:
	case T_BitmapIndexScan:
	case T_BitmapHeapScan:
	case T_TidScan:
	case T_FunctionScan:
	case T_ValuesScan:
		//ereport(NOTICE, (errmsg("Base relation scan on %d", path->parent->relid)));
		result = convertUniqueIndexesToCandidateKeys(path->parent);
		//	ereport(NOTICE, (errmsg("  Attr: %d",index->indexkeys[i])));
		break;

	case T_Join:
	case T_NestLoop:
	case T_MergeJoin:
	case T_HashJoin:
	case T_DynHashJoin:
	case T_EarlyHashJoin:
		{
			Path	   *inner_path = ((JoinPath*)path)->innerjoinpath;
			Path	   *outer_path = ((JoinPath*)path)->outerjoinpath;
			List	   *inner;
			List	   *outer;
			JoinAttrs	jattrs;
			bool		inMJoin = true;
			bool		outMJoin = true;
			bool		equijoin = true;

			jattrs = getJoinAttrs((JoinPath*)path, root);
			inner = getCandidateKeys(inner_path, root);
			outer = getCandidateKeys(outer_path, root);

			if (atts_in_list(inner, jattrs.innerVars))
				inMJoin = false;

			if (atts_in_list(outer, jattrs.outerVars))
				outMJoin = false;

			if (!outMJoin && !inMJoin)
				result = unionCandidateKeys(inner, outer, &jattrs);
			else if (!inMJoin)
				result = outer;
			else if (!outMJoin)
				result = inner;
			else
			{
				/* M:N - special case */
				inner = subtractCandidateKeys(inner, jattrs.innerVars);
				outer = subtractCandidateKeys(outer, jattrs.outerVars);
				if (inner != NIL && outer != NIL) {
					result = unionMNCandidateKeys(inner, outer);
				}
			}

			//ereport(NOTICE, (errmsg("Join")));
			break;
		}
	case T_UniquePath:
	case T_Unique:
		//ereport(NOTICE, (errmsg("Distinct")));
		{
			UniquePath	   *upath = (UniquePath *)path;
			CandidateKey   *ckey;
			ListCell	   *tlist;
			List		   *atts = NIL;
			/* just take the target list atts as unique! */
			foreach (tlist, path->parent->reltargetlist)
			{
				Var *var = lfirst(tlist);
				if (IsA(var, Var))
				{
					atts = lappend(atts, var);
				}
			}
			if (atts != NIL)
			{
				int i = 0;
				ckey = palloc(sizeof(CandidateKey));
				ckey->nVars = list_length(atts);
				ckey->vars = palloc(sizeof(List *) * ckey->nVars);
				foreach(tlist, atts)
				{
					ckey->vars[i++] = lcons((Var*)lfirst(tlist),NIL);
				}
				result = lcons(ckey, result);
			}
		}
		break;
	case T_SubqueryScan:
		switch (path->parent->subplan->type)
		{
		case T_Agg:
			{
				Agg *subplan = (Agg *)path->parent->subplan;
				//ereport(NOTICE, (errmsg("Subquery with aggregation")));
				switch (subplan->aggstrategy)
				{
				case AGG_HASHED:
				case AGG_SORTED:
					{
						/* GroupAggregation */
						ListCell *rlist;
						ListCell *tlist;
						List *rVars = NIL;
						CandidateKey *ckey = NULL;
						rlist = list_head(path->parent->reltargetlist);
						/* take any Vars from the parent reltargetlist with a corresponding plain var in the subplan
						 * target list. This should be all the grouping attributes.
						 */
						foreach (tlist, path->parent->subplan->targetlist)
						{
							Expr *var = ((TargetEntry *)lfirst(tlist))->expr;
							if (IsA(var, Var))
							{
								Var *rvar = lfirst(rlist);
								rVars = lappend(rVars, rvar);
							}
							rlist = lnext(rlist);
							if (rlist == NULL)
								break;
						}
						if (rVars != NIL)
						{
							int i = 0;
							ckey = palloc(sizeof(CandidateKey));
							ckey->nVars = list_length(rVars);
							ckey->vars = palloc0(sizeof(List *) * ckey->nVars);
							foreach (rlist, rVars)
							{
								Var *var = (Var *)lfirst(rlist);
								ckey->vars[i] = lappend(ckey->vars[i],var);
								i++;
							}
							result = lappend(result, ckey);
						}
					}
					break;
				}
			}
			break;
		case T_Unique:
			/* TODO: We may not be able to do simply follow the tree down as the Vars do not match up properly
			*		but this problem may be moot since it seems the Unique path is set above the subquery
			*/
			ereport(NOTICE, (errmsg("Subquery with unique attributes")));
			break;
		default:
			ereport(NOTICE, (errmsg("Subquery with nodetag: %d", path->parent->subplan->type)));
			break;
		}
		break;
	default:
		ereport(NOTICE, (errmsg("Unsupported path with nodetag: %d", path->pathtype)));
		break;
	}

	return result;
}

// Given a list that contains hashClauses (lists of RestrictInfo*) determine the hash attribute set for each path (input).
void getHashAttrs(int npath, Path** paths, List **hashClauses, List **hashAttrs, PlannerInfo *root)
{
	int		i,j;
	ListCell   *hcl;
	Relids* relids;

	// Initialize all lists and determine relids for each path
	relids = palloc(npath * sizeof(Relids));
	for (i=0; i < npath; i++)
	{	hashAttrs[i] = NIL;
		relids[i] = paths[i]->parent->relids;
	}

	// Currently processing each list and making no differentiation between what attributes are in each hash clause list (is that acceptable?)
	for (i=0; i < npath-1; i++)
	{
		foreach(hcl, hashClauses[i])
		{	// hashClauses[i] is a list of RestrictInfo* nodes which constitute a single hash clause
			RestrictInfo *restrictinfo = (RestrictInfo *) lfirst(hcl);
			VariableStatData Attr;
			Node   *op;

			Assert(IsA(restrictinfo, RestrictInfo));

			if (!IsA(restrictinfo->clause, OpExpr) ||	!op_iseqjoin(((OpExpr*)restrictinfo->clause)->opno))
				// Equi-join clauses only
				continue;

			// Examine the variables on each side of the restrictinfo clause
			for (j=0; j < npath; j++)
			{	if (bms_is_subset(restrictinfo->right_relids, relids[j]))
				{	// This is the path that it belongs to
					op = get_rightop(restrictinfo->clause);
					examine_variable(root, op, 0, &Attr);
					if (IsA(Attr.var,Var))		// Basic variable = variable comparison
						hashAttrs[j] = lappend(hashAttrs[j], Attr.var);
					if (HeapTupleIsValid(Attr.statsTuple))
						ReleaseVariableStats(Attr);
				}
			}
			for (j=0; j < npath; j++)
			{	if (bms_is_subset(restrictinfo->left_relids, relids[j]))
				{	// This is the path that it belongs to
					op = get_rightop(restrictinfo->clause);
					examine_variable(root, op, 0, &Attr);
					if (IsA(Attr.var,Var))		// Basic variable = variable comparison
						hashAttrs[j] = lappend(hashAttrs[j], Attr.var);
					if (HeapTupleIsValid(Attr.statsTuple))
						ReleaseVariableStats(Attr);
				}
			}
		}
	}
}

/*
 * findIndex
 *		Return the position of the first entry among relids[0 .. n-1] that
 *		contains ids (bms_is_subset), or -1 when none does.
 */
int
findIndex(Relids *relids, int n, Relids ids)
{
	int			idx = 0;

	while (idx < n && !bms_is_subset(ids, relids[idx]))
		idx++;

	return (idx < n) ? idx : -1;
}


/*
 * findIndexPath
 *		Return the position of the first path among paths[0 .. n-1] whose
 *		parent rel's relids contain ids, or -1 when none does.
 */
int
findIndexPath(Path **paths, int n, Relids ids)
{
	int			idx = 0;

	while (idx < n && !bms_is_subset(ids, paths[idx]->parent->relids))
		idx++;

	return (idx < n) ? idx : -1;
}

/*
 * getJoinAttrs
 *		Collect the pairs of plain Vars equated by a join path's equijoin
 *		clauses.
 *
 * For each clause in jpath->joinrestrictinfo that is an OpExpr accepted by
 * op_iseqjoin(), the operand from the inner input is appended to
 * result.innerVars and the operand from the outer input to result.outerVars.
 * The two lists therefore stay positionally aligned.
 *
 * A clause is used only when one operand's relids lie within the inner rel
 * and the other's within the outer rel, and only when both operands resolve
 * (via examine_variable) to a Var.
 */
JoinAttrs
getJoinAttrs(JoinPath *jpath, PlannerInfo *root)
{
	ListCell   *hcl;
	Relids		innerRelids = jpath->innerjoinpath->parent->relids;
	Relids		outerRelids = jpath->outerjoinpath->parent->relids;
	JoinAttrs	result;

	result.innerVars = NIL;
	result.outerVars = NIL;

	foreach(hcl, jpath->joinrestrictinfo)
	{
		RestrictInfo *restrictinfo = (RestrictInfo *) lfirst(hcl);
		VariableStatData inJoinAttr;
		VariableStatData outJoinAttr;
		Node	   *inop;
		Node	   *outop;

		Assert(IsA(restrictinfo, RestrictInfo));

		/* we can't use non-equijoin clauses to make assumptions about the cardinality */
		if (!IsA(restrictinfo->clause, OpExpr) ||
			!op_iseqjoin(((OpExpr *) restrictinfo->clause)->opno))
			continue;

		/*
		 * Work out which operand comes from which input.  Skip a clause whose
		 * operands don't straddle the join (e.g. an outer-join ON clause that
		 * compares two columns of the same side) rather than assuming its
		 * left operand is on the inner side.
		 */
		if (bms_is_subset(restrictinfo->right_relids, innerRelids) &&
			bms_is_subset(restrictinfo->left_relids, outerRelids))
		{
			inop = get_rightop(restrictinfo->clause);
			outop = get_leftop(restrictinfo->clause);
		}
		else if (bms_is_subset(restrictinfo->left_relids, innerRelids) &&
				 bms_is_subset(restrictinfo->right_relids, outerRelids))
		{
			inop = get_leftop(restrictinfo->clause);
			outop = get_rightop(restrictinfo->clause);
		}
		else
			continue;

		examine_variable(root, inop, 0, &inJoinAttr);
		examine_variable(root, outop, 0, &outJoinAttr);

		/* we are only interested in equijoins on basic Vars */
		if (IsA(inJoinAttr.var, Var) && IsA(outJoinAttr.var, Var))
		{
			result.innerVars = lappend(result.innerVars, inJoinAttr.var);
			result.outerVars = lappend(result.outerVars, outJoinAttr.var);
		}

		if (HeapTupleIsValid(inJoinAttr.statsTuple))
			ReleaseVariableStats(inJoinAttr);

		if (HeapTupleIsValid(outJoinAttr.statsTuple))
			ReleaseVariableStats(outJoinAttr);
	}
	return result;
}

/*
 * subset_list
 *		Return true when every Var in varList1 has a varno/varattno match in
 *		varList2.  An empty varList1 is a subset of anything; a non-empty
 *		varList1 is never a subset of an empty varList2.
 */
bool
subset_list(List *varList1, List *varList2)
{
	ListCell   *outer;

	if (varList1 == NIL)
		return true;
	if (varList2 == NIL)
		return false;

	foreach(outer, varList1)
	{
		Var		   *needle = (Var *) lfirst(outer);
		ListCell   *inner;
		bool		present = false;

		foreach(inner, varList2)
		{
			Var		   *candidate = (Var *) lfirst(inner);

			if (candidate->varno == needle->varno &&
				candidate->varattno == needle->varattno)
			{
				present = true;
				break;
			}
		}

		if (!present)
			return false;
	}
	return true;
}

/*
 * intersect_list
 *		Return a new List of the Vars of varList1 that have a varno/varattno
 *		match in varList2, in varList1's order.  The Var pointers come from
 *		varList1 and are not copied.  Returns NIL if either input is NIL.
 */
List *
intersect_list(List *varList1, List *varList2)
{
	List	   *common = NIL;
	ListCell   *lc1;

	if (varList1 == NIL || varList2 == NIL)
		return NIL;

	foreach(lc1, varList1)
	{
		Var		   *v = (Var *) lfirst(lc1);
		ListCell   *lc2;

		foreach(lc2, varList2)
		{
			Var		   *w = (Var *) lfirst(lc2);

			if (w->varno == v->varno && w->varattno == v->varattno)
			{
				common = lappend(common, v);
				break;
			}
		}
	}
	return common;
}

/*
 * intersect_lists
 *		Return the intersection of the N Var lists lists[0 .. N-1], comparing
 *		Vars by varno/varattno.
 *
 * - Returns NIL when N <= 0 or lists is NULL.
 * - When N == 1, the caller's lists[0] itself is returned (not a copy).
 * - Otherwise a newly built list is returned, following lists[0]'s element
 *   order.  Intermediate lists built along the way are freed; the caller's
 *   lists are never freed.
 */
List *
intersect_lists(List **lists, int N)
{
	int			i;
	List	   *result;

	if (lists == NULL || N <= 0)
		return NIL;

	result = lists[0];
	for (i = 1; i < N; i++)
	{
		List	   *next = intersect_list(result, lists[i]);

		/* from i == 2 on, result is our own intermediate list, not lists[0] */
		if (i > 1)
			list_free(result);
		result = next;
	}
	return result;
}