///////////////////////////////////////////////////////////////////////////////
//                                                                           //
//  Copyright (C) 1995-2002 by the Board of Trustees of Leland Stanford      //
//  Junior University.  See LICENSE for details.                             //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


#include <ac.h>
#include <flags.h>
#include "ineq.h"


///////////////////////////////////////////////////////////////////////////////
// Initialize statics and constants                                          //
///////////////////////////////////////////////////////////////////////////////


Hash_Table<PExpr, PII> LE_Expr::IneqTable(Expr::Hash, Expr::Match);
const Symbol LE_Expr::LE_SYM = "<=";
const Symbol LE_Expr::GE_SYM = ">=";
const Symbol LE_Expr::LT_SYM = "<";
const Symbol LE_Expr::GT_SYM = ">";
int LE_Expr::LE_SORT;


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: init_ineq_library						     //
// Description: Initialize ineq library                                      //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void init_ineq_library()
{
  cerr << "initializing libineq..." << endl;
  LE_Expr::Init();
}


///////////////////////////////////////////////////////////////////////////////
// Ineq_Info class methods                                                   //
///////////////////////////////////////////////////////////////////////////////


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::Ineq_Info					     //
// Description: Constructor for Ineq_Info class                              //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
Ineq_Info::Ineq_Info(PExpr pexpr, PExpr ineq)
  : _pexpr(pexpr), _hasmin(FALSE), _hasmax(FALSE), _use()
{
  if (ineq)
    _use.Append(ineq);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::Ineq_Info					     //
// Description: Copy constructor for Ineq_Info class                         //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
Ineq_Info::Ineq_Info(const Ineq_Info& ii)
  : Scoped_Obj(ii), _pexpr(ii._pexpr), _hasmin(ii._hasmin),  _min(ii._min),
    _hasmax(ii._hasmax), _max(ii._max), _use(ii._use) {}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::RestoreData					     //
// Description: Restore previously saved version of Ineq_Info                //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void Ineq_Info::RestoreData(void)
{
  const PII &pii = (PII)GetRestore();
  _pexpr = pii->_pexpr;
  _hasmin = pii->_hasmin;
  _min = pii->_min;
  _hasmax = pii->_hasmax;
  _max = pii->_max;
  _use.Restore(pii->_use);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::DeleteSelf					     //
// Description: Remove Ineq_Info from the table and then delete it           //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void Ineq_Info::DeleteSelf()
{
  _use.Destroy();
  LE_Expr::DeleteFromTable(_pexpr);
  delete this;
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::SetMin						     //
// Description: Set a new minimum value for a variable                       //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void Ineq_Info::SetMin(rat minimum) 
{
  minimum = _hasmin ? max(_min, minimum) : minimum;
  if (!_hasmin || minimum != _min) {
    MakeCurrent();
    _min = minimum;
    _hasmin = 1;
    if (_hasmax) Check();
  }
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::SetMax						     //
// Description: Set a new maximum value for a variable                       //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void Ineq_Info::SetMax(rat maximum)
{
  maximum = _hasmax ? min(_max, maximum) : maximum;
  if (!_hasmax || maximum != _max) {
    MakeCurrent();
    _max = maximum;
    _hasmax = 1;
    if (_hasmin) Check();
  }
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: Ineq_Info::Check						     //
// Description: Check for consequences of maximum and minimum values of a    //
// variable.                                                                 //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void Ineq_Info::Check()
{
  if (_max == _min) {
    AC::CurrentContext()->AssertEq(Num_Expr::New(_min),_pexpr);
  }
  else if (_max < _min) {
    AC::CurrentContext()->MakeInconsistent();
  }
  // TODO: add approximate integer reasoning once we have integer types
}


///////////////////////////////////////////////////////////////////////////////
// End of Ineq_Info class methods                                            //
///////////////////////////////////////////////////////////////////////////////


///////////////////////////////////////////////////////////////////////////////
// LE_Expr class methods                                                     //
///////////////////////////////////////////////////////////////////////////////


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::LE_Expr						     //
// Description: Constructor for LE_Expr                                      //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
LE_Expr::LE_Expr(PExpr right)
  : Gen_Expr(LE_SORT), _right(right)
{
  SetFlag(INTERP);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::TypeCheck						     //
// Description: Typecheck a LE_Expr                                          //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
Bool LE_Expr::TypeCheck()
{
  if (!UnifyType(Expr_Type::TYPE_BOOL) ||
      !_right->UnifyType(Expr_Type::TYPE_RAT))
    return FALSE;
  SetFlag(TYPECHECKED);
  return TRUE;
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::Print1						     //
// Description: Print a LE_Expr                                              //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void LE_Expr::Print1(ostream& os) const
{
  FLAGS::indentcount++;
  os << "\n" << indent 
     << '$' << ExprNum() << ':' << StartInterp << LE_SYM << " 0 ";
  Print1Ptr(os, _right);
  os << FinishInterp;
  FLAGS::indentcount--;
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::IsSimplerThan1					     //
// Description: Compare two inequality expressions                           //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
Bool LE_Expr::IsSimplerThan1(PExpr pexpr)
{
  const PLE_Expr &ple = PLE_Expr(pexpr);
  return _right->IsSimplerThan(ple->_right);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::Hash1						     //
// Description: Hash function for LE_Expr's                                  //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
int LE_Expr::Hash1() const
{
  return abs(int(_right) * 23);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::Match1						     //
// Description: Match function for LE_Expr's                                 //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
int LE_Expr::Match1(PCExpr pexpr) const
{
  const PLE_Expr &ple = PLE_Expr(pexpr);
  return _right == ple->_right;
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::NotifyAssertEq					     //
// Description: Function which is called whenever a LE_Expr is asserted      //
// equal to a distinct.  Since LE_Expr's are Boolean, this will be called    //
// whenever a LE_Expr becomes true or false.  This is the main procedure for //
// determining the consequences of new inequalities.                         //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void LE_Expr::NotifyAssertEq(PExpr distinct)
{
  if (_right->Simp() != _right) return;
  if (distinct == TrueVal()) {
    PExpr e,u,pexpr;
    rat c,c1,*pc1;

    if (_right->Sort() != Add_Expr::ADD_SORT) {
      UpdateMin(_right,0);
      e = _right;
      c = 1;
    } else {
      const PAdd_Expr &padd = (PAdd_Expr)_right;
      if (padd->terms()->NumEntries() == 1) {          // Fast case: one variable and a constant
	Hash_Ptr<PExpr, rat> ptr(padd->terms());
	e = ptr->Key();
	c = ptr->Data();
	if (c.sign() == -1)
	  UpdateMax(e, padd->constant() * (-1)/c);
	else UpdateMin(e, padd->constant() * (-1)/c);
	ASSERT_MSG(e->Simp() == e,("Should be simplest"));
      }
      // Solve for most complex
      else padd->Isolate(e, c);                        // 0 <= other terms + c*e
    }

    PII *ppii = IneqTable.Fetch(e);
    if (!ppii) IneqTable.Insert(e, new Ineq_Info(e, this));
    else {
      const PII &pii = *ppii;
      Hash_Table<PExpr, rat> terms(Expr::Hash, Expr::Match, Add_Expr::SIZE);
      terms.Insert(_right, c.sign()/c);                  // 0 <= |1/c|*(other terms) + c.sign()*e

      if (c.sign() == 1 && pii->Hasmax() && e != _right) {     // e <= Max
	terms.Insert(e, -1);
	pexpr = New(Add_Expr::New(pii->Max(), terms));
	AC::CurrentContext()->Assert(pexpr);
	terms.Delete(e);
      }
	
      else if (c.sign() == -1 && pii->Hasmin()) {              // Min <= e
	ASSERT_MSG(e != _right, ("Invariant violated"));
	terms.Insert(e, 1);
	pexpr = New(Add_Expr::New(-pii->Min(), terms));
	AC::CurrentContext()->Assert(pexpr);
	terms.Delete(e);
      }

      Link_List<PExpr> *pll = (pii)->GetUse();
      Link_Ptr<PExpr> ptr(*pll);

      for(; !ptr.EOL(); ++ptr) {

	PLE_Expr ple = ((PLE_Expr)(*ptr));

	u = ple->_right;
	if (ple != ple->Sig() || ple != NewSig(u->Simp1(),NULL)) continue;
	if (u->Sort() != Add_Expr::ADD_SORT)
	  c1 = 1;
	else {
	  pc1 = ((PAdd_Expr)u)->terms()->Fetch(e);
	  assert_msg(pc1,("Couldn't find term"));
	  c1 = *pc1;
	}
	assert_msg(u != _right,("Sanity check"));
	if (c.sign() != c1.sign()) {             // u: 0 <= u.(other terms) + c1*e
	  terms.Insert(u, c1.sign()/c1);         // 0 <= |1/c|*(other terms) + |1/c1|*u.(other terms),
	                                         // c.sign()*e + c1.sign()*e cancel
	  pexpr = Add_Expr::New(0, terms);
	  if (pexpr == Num_Expr::New(0))
	    pexpr = Eq_Expr::New(pexpr,_right);
	  else pexpr = New(pexpr);
	  AC::CurrentContext()->Assert(pexpr);
	  terms.Delete(u);
	}
      }
      (pii)->AddUse(this);
    }
  }
  else {
    ASSERT_MSG(distinct == FalseVal(),("Bool expected"));
    Hash_Table<PExpr, rat> terms(Expr::Hash, Expr::Match, Add_Expr::SIZE);
    terms.Insert(_right, -1);
    PExpr pexpr = New(Add_Expr::New(0, terms));
    AC::CurrentContext()->Assert(pexpr);
    pexpr = Eq_Expr::New(Num_Expr::New(0), _right);
    AC::CurrentContext()->Deny(pexpr);
  }
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::UpdateMin						     //
// Description: Update the minimum value of a variable                       //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void LE_Expr::UpdateMin(PExpr pexpr, rat min)
{
  PII *ppii = IneqTable.Fetch(pexpr);
  if (ppii)
    (*ppii)->SetMin(min);
  else {
    PII pii = new Ineq_Info(pexpr);
    pii->SetMin(min);
    IneqTable.Insert(pexpr, pii);
  }
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::UpdateMax						     //
// Description: Update the maximum value of a variable                       //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void LE_Expr::UpdateMax(PExpr pexpr, rat max)
{
  PII *ppii = IneqTable.Fetch(pexpr);
  if (ppii)
    (*ppii)->SetMax(max);
  else {
    PII pii = new Ineq_Info(pexpr);
    pii->SetMax(max);
    IneqTable.Insert(pexpr, pii);
  }
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::Init						     //
// Description: Install inequality symbols                                   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
void LE_Expr::Init()
{
  LE_SORT = Install_Operator(LE_SYM, Parse);
  Install_Operator_Abbreviation(GE_SYM, Parse);
  Install_Operator_Abbreviation(LT_SYM, Parse);
  Install_Operator_Abbreviation(GT_SYM, Parse);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::NewSig						     //
// Description: Create a new LE_Expr                                         //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
PExpr LE_Expr::NewSig(PExpr right, PExpr sigx)
{
  if (right->IsAtomic()) {
    ASSERT_MSG(right == right->Simp(),
	       ("Expected simplest"));

    if (right->Sort() == NUM_SORT) {
      return 0 <= ((PNum_Expr)right)->value() ? 
	(PExpr)TrueVal() : (PExpr)FalseVal();
    }
  }

  LE_Expr expr(right);
  return NewExpr(&expr, sigx);
}


///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// Function: LE_Expr::Parse						     //
// Description: Parse a LE_Expr                                              //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
PExpr LE_Expr::Parse(Symbol sym, int numargs, PExpr *args)
{
  if (numargs != 2) goto bad_num_args;
  else {
    Bool less = (sym == LE_SYM || sym == GT_SYM);
    Bool eql = (sym == LE_SYM || sym == GE_SYM);

    if (args[0] == args[1])
      return eql ? (PExpr)TrueVal() : (PExpr)FalseVal();

    Hash_Table<PExpr, rat> t(Expr::Hash, Expr::Match, Add_Expr::SIZE);
    t.Insert(args[0], less ? -1 :  1);
    t.Insert(args[1], less ?  1 : -1);

    PExpr lepart = New(Add_Expr::New(0,t));
    return eql ? lepart : ITE_Expr::New(lepart, FalseVal(), TrueVal());
  }

 bad_num_args:
  cout << "Two arguments required for inequality" << endl;
  return NULL;
}


///////////////////////////////////////////////////////////////////////////////
// End of LE_Expr class methods                                              //
///////////////////////////////////////////////////////////////////////////////
