//
// lmath.cxx
//
// some basic number theoretical algorithms
//
// Copyright (C) 1996-7 by Leonard Janke (janke@unixg.ubc.ca)

#include <linteger/lmath.hxx>
#include <linteger/montyrep.hxx>
#include <linteger/residue.hxx>
#include <linteger/lexp.hxx>
#include <cassert>

// Divides dividend by divisor and stores the results in Quotient and
// Remainder.  The divisor must be non-zero.
//
// The magnitudes are divided by BMath::Divide and the signs are patched up
// afterwards.  Whenever the dividend is negative and the magnitude division
// leaves a non-zero remainder, the quotient magnitude is increased by one
// and the Remainder becomes |divisor| minus that remainder, so the stored
// Remainder is not negative (presuming BMath::Divide returns a remainder
// smaller than |divisor|).
void LMath::LongDivide(const LInteger& dividend, const LInteger& divisor,
                       LInteger& Quotient, LInteger& Remainder)
{
  assert( divisor.IsNonZero() );

  if ( divisor._digits > dividend._digits )
    {
      // The divisor has more digits than the dividend: BMath::Divide is not
      // called and the answer is decided by the signs alone.
      if ( dividend.IsNonNegative() )
        {
          Quotient=0u;
          Remainder=dividend;
          return;
        }

      // negative dividend: step the quotient one away from zero so that
      // the remainder |divisor|+dividend can be used
      if ( divisor.IsPositive() )
        Quotient=-1;
      else
        Quotient=1u;

      Remainder=LInteger::AbsoluteValue(divisor)+dividend;
      return;
    }

  unsigned int* quotMag;
  unsigned int* remMag;

  BMath::Divide(dividend._magnitude, dividend._digits,
                divisor._magnitude, divisor._digits, quotMag, remMag);

  const int quotLen=
    LInteger::compress(quotMag,dividend._digits-divisor._digits+1);
  const int remLen=LInteger::compress(remMag,dividend._digits+1);

  if ( dividend.IsNonNegative() )
    {
      if ( divisor.IsNonNegative() )
        {
          // both operands non-negative: results are used unchanged
          Quotient=LInteger(quotMag,quotLen,0,0);
          Remainder=LInteger(remMag,remLen,0,0);
          return;
        }

      if ( divisor.IsNegative() )
        {
          // only the quotient changes sign
          Quotient=LInteger::Negative(LInteger(quotMag,quotLen,0,0));
          Remainder=LInteger(remMag,remLen,0,0);
          return;
        }
    }

  // From here on the dividend is treated as negative.  The quotient is
  // negated only when the divisor is non-negative; when both operands are
  // negative the quotient stays positive.
  const int negateQuotient=
    ( dividend.IsNegative() && divisor.IsNonNegative() );

  Remainder=LInteger(remMag,remLen,0,0);

  if ( Remainder.IsZero() )
    {
      // exact division: no adjustment needed
      if ( negateQuotient )
        Quotient=LInteger::Negative(LInteger(quotMag,quotLen,0,0));
      else
        Quotient=LInteger(quotMag,quotLen,0,0);

      return;
    }

  // inexact division: bump the quotient magnitude and flip the remainder
  if ( negateQuotient )
    Quotient=-(LInteger(quotMag,quotLen,0,0)+LInteger::One);
  else
    Quotient=LInteger(quotMag,quotLen,0,0)+LInteger::One;

  Remainder=LInteger::AbsoluteValue(divisor)-Remainder;
}

// Returns the integer square root of n, which must be positive.
LInteger LMath::Sqrt(const LInteger& n)
{
  // Integer Newton iteration, following Henri Cohen's
  // _A_Course_In_Computational_Algebraic_Number_Theory_

  assert( n.IsPositive() );

  // Starting value: a power of two chosen from the bit position reported
  // by BMath::BSR, meant to be an (over)estimate of the root.
  const int topBit=BMath::BSR(n._magnitude,n._digits);
  LInteger root=LInteger::TwoToThe((topBit+2)/2);

  // Keep replacing the estimate by (root + n/root)/2 for as long as that
  // makes it strictly smaller; the first non-decreasing step ends the search.
  for (;;)
    {
      LInteger next=LInteger(root+n/root).DivByTwo();

      if ( !( next<root ) )
        return root;

      root=next;
    }
}

// Computes g^x modulo p*q for the special case where p and q are prime,
// by exponentiating modulo p and modulo q separately and recombining the
// two results with ChineseRemainderTheorem.  (Algorithm from Knuth.)
//
// The current MontyRep ring is saved on entry and restored before returning.
LInteger LMath::CRTModExp(const LInteger& g, const LInteger& x, 
			  const LInteger& p, const LInteger& q)
{
  LInteger gModP(g%p);
  LInteger gModQ(g%q);

  // exponents reduced modulo p-1 and q-1 (Fermat's little theorem)
  const LInteger pPow(x%(p-1));
  const LInteger qPow(x%(q-1));

  const MontyRing saveRing=MontyRep::Ring();
  const MontyRing pRing(p);
  const MontyRing qRing(q);

  // Reducing the exponent is only valid for a non-zero residue.  When g is
  // a multiple of the prime and x is non-zero, the answer modulo that prime
  // is 0; raising 0 to a reduced exponent of 0 would give the wrong result.
  // So a zero residue is left as it is unless x itself is zero, in which
  // case the exponentiation is performed exactly as before.
  if ( gModP.IsNonZero() || x.IsZero() )
    {
      MontyRep::SetRing(pRing);
      const MontyRep gModPRep=MontyRep(gModP);
      gModP=LC_Exp(gModPRep,pPow).ToLInteger();
    }

  if ( gModQ.IsNonZero() || x.IsZero() )
    {
      MontyRep::SetRing(qRing);
      const MontyRep gModQRep=MontyRep(gModQ);
      gModQ=LC_Exp(gModQRep,qPow).ToLInteger();
    }

  MontyRep::SetRing(saveRing);

  return ChineseRemainderTheorem(gModP,p,gModQ,q);
}

// Returns g^x modulo n, where n must be greater than one.
//
// For odd n the exponentiation is carried out with MontyRep; for even n
// MontyRep cannot be used, so ResidueClass is used instead.  In both cases
// the ring that was active on entry is restored before returning.
LInteger LMath::ModExp(const LInteger& g, const LInteger& x, const LInteger& n)
{
  assert ( n > LInteger::One );

  LInteger gModN=g%n;

  if ( n.IsEven() ) // can't use MontyRep
    {
      const LInteger saveRing=ResidueClass::Ring();

      ResidueClass::SetRing(n);

      // start from the reduced base, exactly as the odd branch does
      const ResidueClass gModNRep(gModN);

      gModN=LC_Exp(gModNRep,x).ToLInteger();

      ResidueClass::SetRing(saveRing);
    }
  else
    {
      const MontyRing saveRing=MontyRep::Ring();
      const MontyRing nRing(n);

      MontyRep::SetRing(nRing);

      const MontyRep gModNRep(gModN);

      gModN=LC_Exp(gModNRep,x).ToLInteger();

      MontyRep::SetRing(saveRing);
    }
 
  return gModN;
}

// Returns the greatest common divisor of x and y.  Both arguments must be
// non-negative; if either one is zero the other is returned.
LInteger LMath::GCD(const LInteger& x, const LInteger& y)
{
  // Binary GCD, following Henri Cohen's
  // _A_Course_in_Computational_Algebraic_Number_Theory_

  assert (  x.IsNonNegative() &&  y.IsNonNegative() );

  if ( y.IsZero() )
    return x;

  if ( x.IsZero() )
    return y;

  LInteger hi(x);
  LInteger lo(y);

  if ( hi < lo )
    LC_Swap(hi,lo);

  // a single ordinary Euclidean step before the shift/subtract phase
  LInteger rest;
  rest=hi%lo;
  hi=lo;
  lo=rest;

  if ( lo.IsZero() )
    return hi;

  // remove the power of two shared by both operands, remembering it
  int commonTwos=0;
  if ( hi.IsEven() && lo.IsEven() )
    {
      commonTwos=LC_Min( BMath::BSF(hi._magnitude,hi._digits),
                         BMath::BSF(lo._magnitude,lo._digits) );
      hi>>=commonTwos;
      lo>>=commonTwos;
    }

  // at most one operand is still even; shift it until it is odd
  if ( hi.IsEven() )
    {
      const int twos=BMath::BSF(hi._magnitude,hi._digits);
      hi>>=twos;
    }
  else if ( lo.IsEven() )
    {
      const int twos=BMath::BSF(lo._magnitude,lo._digits);
      lo>>=twos;
    }

  for (;;)
    {
      // half the difference of the two operands
      LInteger diff(hi-lo);
      diff>>=1;

      // equal operands: put the shared power of two back and finish
      if ( diff.IsZero() )
        return LInteger::TwoToThe(commonTwos)*hi;

      if ( diff.IsEven() )
        {
          const int twos=BMath::BSF(diff._magnitude,diff._digits);
          diff>>=twos;
        }

      // the magnitude of the difference replaces whichever operand was larger
      if ( diff.IsNegative() )
        lo=-diff;
      else
        hi=diff;
    }
}

// Extended Euclidean algorithm.  Returns a value d and sets u and v such
// that x*u + y*v == d, where d is the last non-zero remainder of the
// Euclidean remainder sequence of x and y.  When y is zero the result is
// simply d == x with u == 1 and v == 0.
LInteger LMath::ExtendedEuclid(const LInteger& x, const LInteger& y,
                               LInteger& u, LInteger& v)
{
  // Following Henri Cohen's
  // _A_Course_in_Computational_Algebraic_Number_Theory_

  u=1u;
  LInteger g(x);

  if ( y.IsZero() )
    {
      v=0u;
      return g;
    }

  LInteger nextU;   // coefficient of x that goes with `rest`
  nextU=0u;
  LInteger rest(y); // current remainder of the Euclidean sequence

  while ( rest.IsNonZero() )
    {
      LInteger quot, newRest;
      LongDivide(g,rest,quot,newRest);

      // advance both the remainder sequence and the coefficient sequence
      const LInteger newU(u-quot*nextU);
      u=nextU;
      g=rest;
      nextU=newU;
      rest=newRest;
    }

  // recover the second coefficient from  x*u + y*v == g
  v=(g-x*u)/y;
  return g;
}


// Returns an inverse of x modulo n.  n must be positive and x must be
// coprime to n (both conditions are asserted).  A negative coefficient
// from ExtendedEuclid is shifted up by n before being returned.
LInteger LMath::InvertUnit(const LInteger& x, const LInteger& n)
{
  assert( n.IsPositive() );

  // ExtendedEuclid yields  x*coeffX + n*coeffN == gcd
  LInteger coeffX, coeffN;
  LInteger gcd=ExtendedEuclid(x,n,coeffX,coeffN);

  assert( gcd.IsOne() );

  if ( coeffX.IsNonNegative() )
    return coeffX;

  return coeffX+n;
}

// Combines a residue x_p modulo p and a residue x_q modulo q into a single
// residue modulo p*q.  p and q must be positive and coprime (both
// conditions are asserted).
LInteger LMath::ChineseRemainderTheorem(const LInteger& x_p, const LInteger& p,
                                        const LInteger& x_q, const LInteger& q)
{
  assert ( p.IsPositive() && q.IsPositive() );

  // ExtendedEuclid yields  p*u + q*v == gcd
  LInteger u,v;
  LInteger gcd(ExtendedEuclid(p,q,u,v));
  assert ( gcd.IsOne() );

  const LInteger n(p*q);

  // u*p is 0 modulo p and 1 modulo q, so it carries x_q;
  // v*q is 1 modulo p and 0 modulo q, so it carries x_p.
  const LInteger qPart((((u*p)%n)*x_q)%n);
  const LInteger pPart((((v*q)%n)*x_p)%n);
  const LInteger sum(qPart+pPart);

  // one conditional subtraction brings the sum of the two parts below n
  if ( sum < n )
    return sum;

  return sum-n;
}

// Runs one round of the Rabin Miller test on x, which must be positive.
//
// Returns 1 if x fails the round, meaning x is definitely composite, and
// 0 if x passes, meaning x is likely prime.  The inputs 1 and 2 are
// answered directly with 0, without running a round.
int LMath::Composite(const LInteger& x, PRNG& prng)
{
  // see Schneier's _Applied_Cryptography_ for references to research
  // regarding the probability that a number which is not prime passes
  // the Rabin Miller compositeness test
  //
  // This algorithm is from Henri Cohen's 
  // _A_Course_in_Computational_Algebraic_Number_Theory_ 
 
  assert( x.IsPositive() );

  if ( x.IsOne() )
    return 0;

  const LInteger xMinusOne(x-1);

  // x == 2 is prime; it must be caught before the even test below,
  // which would otherwise report it as composite
  if ( xMinusOne.IsOne() )
    return 0;

  if ( x.IsEven() ) 
    return 1;

  // write x-1 as 2^t * q
  int t=BMath::BSF(xMinusOne._magnitude,xMinusOne._digits);

  const LInteger q(xMinusOne>>t);

  // random base, constructed from the generator and the bounds 2 and x-1
  LInteger a(prng,LInteger(2u),xMinusOne);
  LInteger b=ModExp(a,q,x);

  // a^q == 1 (mod x): x passes this round
  if ( b.IsOne() )
    return 0;

  // square repeatedly, at most t-1 times, stopping early at 1 or x-1
  for (int e=0; ! b.IsOne()  && b != xMinusOne && e <= t-2; e++)
    {
      b.Square();
      b%=x;
    }

  // only reaching x-1 counts as a pass
  if ( b != xMinusOne )
    return 1;

  return 0;
}

// Generates random candidates until one survives all the checks below and
// returns it.
//
//   numBits, fewerBits     passed through to the random LInteger constructor
//   prng                   source of randomness for the candidates and for
//                          the Rabin Miller rounds
//   PreProcess             optional hook; when non-null a candidate is
//                          discarded and redrawn while the hook returns 0
//   RabinMillerIterations  number of Rabin Miller rounds a candidate must
//                          pass (a failed round rejects it immediately)
LInteger LMath::RandomProbablePrime(const int numBits, PRNG& prng,
				    int (*PreProcess)(LInteger&)=NULL,
				    const int fewerBits=0, 
				    const int RabinMillerIterations=20)
{
  LInteger c;
  int composite=1;

  do 
    {
      // draw candidates until one is accepted by PreProcess (if given)
      // and has no small prime factor
      do
	{
	  do
	    c=LInteger(numBits,prng,fewerBits,1);
	  while ( PreProcess && !PreProcess(c) );

	  composite=c.HasSmallPrimeFactor();
	}
      while ( composite );

      if ( c >=3 )
	// run the full requested number of rounds, stopping at the
	// first one that proves the candidate composite
	for (int i=0; i<RabinMillerIterations && !composite ; i++)
	  composite=Composite(c,prng);
      else
	// candidates below 3 are settled directly: only 2 is accepted
	if ( c==2 )
	  composite=0;
	else
	  composite=1;
    }
  while ( composite );

  return c;
}
