#include "nr3matlab.h"

#include <cmath>
typedef complex<double> dcmplx;

void update_cost(const MatDoub &A1,const MatDoub &A2,const MatDoub &E, const double &baseline_cost, const double& lam_a,VecDoub &cost_vector, int& iter, int& n);
void calc_roots(double[], double[],int&);
void apply_dot(const MatDoub &A1,const MatDoub &A2, int& p,int& q,int& n,double& K4_plus,double& K5_plus);
void apply_rotation(MatDoub &A1,MatDoub &E, MatDoub &Q, double& c_cur,double& s_cur,int& p,int& q,int& n);
void initialize_Q(MatDoub &Q, int &n);
// Helper that turns five coefficients K[0..4] into a rotation pair (c, s).
// find_roots() builds a quartic from K (best_cond), solves it (solve_quartic),
// converts every accepted real root into a candidate pair with c^2 + s^2 = 1
// (check_root), and publishes the pair selected by best_root() through
// c_cur / s_cur.
class roots_class{
  // ---- private scratch state: written before it is read on every call ----
  int i;
  int j;                 // Newton step counter in polish_root()
  int num;               // candidate counter in best_root()
  int best_index;        // index of the selected candidate in c[] / s[]
  int polish_iter;       // number of Newton steps taken by polish_root()

  double p_coef[5];      // quartic coefficients, highest degree first
  double c[5];           // candidate cosines
  double s[5];           // candidate sines

  // Ferrari intermediates used by solve_quartic()
  double y_wik;
  double R;
  double W;
  double P;
  double Q;
  double U;
  double alpha;
  double beta;
  double gamma;

  // aliases into p_coef[0..4], set by solve_quartic()
  double *A;
  double *B;
  double *C;
  double *D;
  double *E;

  double *K;             // caller-owned coefficient array of the current call (not copied)

  double roots_const;    // shift -B/(4A) between depressed and original variable
  double discrim_1;
  double discrim_2;
  double discrim_3;
  double y_try;          // root currently handed to check_root()
  double c_try;
  double max_cost;
  double cost;
  double s_const_term;   // K[2]^2 - K[4]^2
  double c_const_term;   // K[2]^2 - K[3]^2
  double opp_try;        // the "other" component reconstructed in check_root()
  double pol_root;
  double deriv;          // polynomial derivative value in polish_root()
  double f_eval;         // polynomial value in polish_root()

  char flag;             // 'c' or 's': which component the quartic is written in

  void solve_quartic();
  void check_root(double&);
  void best_cond();
  void best_root();
  void polish_root(double&);

public:
  // double c[5];
  // double s[5];
  double c_cur;          // selected cosine after find_roots()
  double s_cur;          // selected sine after find_roots()
  int num_roots;         // number of candidates stored in c[] / s[]

  roots_class();
  //void reset();
  void find_roots(double[]);
};

// main routine; should be a simple function of K[5] 
void roots_class::find_roots(double K_copy[])
{
  K=K_copy;
  num_roots=0;
  // find "most convenient" polynomial; either 'c' or 's'
  best_cond();
  // solve quartic using Ferrari method
  solve_quartic();
  // try all roots, wee which one maximizes the expression
  best_root();
}

// Refines a root estimate of the current quartic in place.
// Performs exactly polish_iter Newton steps x <- x - f(x)/f'(x), where f is
// the polynomial whose coefficients A..E alias p_coef[0..4]. There is no
// convergence test and no guard against f'(x) == 0: in that case the result is
// non-finite, which the caller's |root| < 1 test subsequently rejects.
inline void roots_class::polish_root(double &pol_root)
{
  // read the coefficients once; they do not change during the iteration
  const double a4 = *A;
  const double a3 = *B;
  const double a2 = *C;
  const double a1 = *D;
  const double a0 = *E;

  j = 0;
  while (j < polish_iter)
    {
      // value and slope at the current estimate (nested evaluation)
      f_eval = a0 + pol_root*(a1 + pol_root*(a2 + pol_root*(a3 + pol_root*a4)));
      deriv  = a1 + pol_root*(2.0*a2 + pol_root*(3.0*a3 + pol_root*4.0*a4));
      pol_root -= f_eval/deriv;
      ++j;
    }
}

// Builds the quartic p_coef[0]*x^4 + ... + p_coef[4] whose unknown x is either
// c (flag = 'c') or s (flag = 's'). The variable is chosen by comparing the
// magnitudes of the two possible constant terms; ties go to 'c'.
// The 's' polynomial is the 'c' polynomial with K[0]<->K[1] and K[3]<->K[4]
// exchanged, so both cases share one set of formulas below.
inline void roots_class::best_cond()
{
  const double k2_sq = pow(K[2],2);
  s_const_term = k2_sq - pow(K[4],2);
  c_const_term = k2_sq - pow(K[3],2);

  const bool in_c = (abs(c_const_term) >= abs(s_const_term));
  flag = in_c ? 'c' : 's';

  // the three quantities that differ between the two formulations
  const double diff  = in_c ? (K[0] - K[1]) : (K[1] - K[0]);
  const double cross = in_c ? K[4] : K[3];   // pairs with K[2]
  const double lin   = in_c ? K[3] : K[4];   // pairs with diff

  p_coef[0] = 4*k2_sq + 4*pow(diff,2);
  p_coef[1] = 4*K[2]*cross + 4*lin*diff;
  p_coef[2] = pow(cross,2) - 4*k2_sq + pow(lin,2) - 4*pow(diff,2);
  p_coef[3] = -2*K[2]*cross - 4*lin*diff;
  p_coef[4] = in_c ? c_const_term : s_const_term;
}

// Selects, among the num_roots candidates stored in c[] / s[], the pair that
// maximizes
//   K[0]*c^2 + K[1]*s^2 + K[2]*s*c + K[3]*c + K[4]*s
// and copies it to c_cur / s_cur. Ties keep the earliest candidate.
// When no candidate was recorded the identity rotation (c = 1, s = 0) is
// returned instead of whatever an earlier call left in slot 0.
inline void roots_class::best_root()
{
  best_index = 0;

  if (num_roots <= 0)
    {
      c_cur = 1.0;
      s_cur = 0.0;
      return;
    }

  for (num = 0; num < num_roots; ++num)
    {
      cost = K[0]*pow(c[num],2) + K[1]*pow(s[num],2) + K[2]*s[num]*c[num] + K[3]*c[num] + K[4]*s[num];
      // the running maximum must be updated together with the index,
      // otherwise later candidates are only ever compared against candidate 0
      if (num == 0 || cost > max_cost)
        {
          max_cost = cost;
          best_index = num;
        }
    }

  c_cur = c[best_index];
  s_cur = s[best_index];
}

// Converts one real root of the quartic into a candidate (c, s) pair with
// c^2 + s^2 = 1 and appends it to c[] / s[].
//  - A root within 0.01 of 1 is first refined by polish_root().
//  - A root that is not strictly inside (-1, 1) contributes the identity pair
//    (c = 1, s = 0) instead.
//  - Otherwise the root is taken as s (flag == 's') or as c (any other flag);
//    the other component has magnitude sqrt(1 - root^2) and the sign of the
//    rational expression evaluated below (non-positive or NaN => negative).
inline void roots_class::check_root(double& cur_root)
{
  if (abs(cur_root - 1.0) < 0.01) polish_root(cur_root);

  // out-of-range (or non-finite) root: record the identity pair
  if (!(abs(cur_root) < 1.0))
    {
      s[num_roots] = 0.0;
      c[num_roots] = 1.0;
      ++num_roots;
      return;
    }

  const bool root_is_sine = (flag == 's');

  // only the sign of this quotient is used
  if (root_is_sine)
    opp_try = (2*K[2]*pow(cur_root,2) + K[3]*cur_root - K[2])/(cur_root*(2*(K[1]-K[0])) + K[4]);
  else
    opp_try = (2*K[2]*pow(cur_root,2) + K[4]*cur_root - K[2])/(cur_root*(2*(K[0]-K[1])) + K[3]);

  const double magnitude = sqrt(1 - pow(cur_root,2));
  opp_try = (opp_try > 0) ? magnitude : -magnitude;

  if (root_is_sine)
    {
      s[num_roots] = cur_root;
      c[num_roots] = opp_try;
    }
  else
    {
      s[num_roots] = opp_try;
      c[num_roots] = cur_root;
    }
  ++num_roots;
}
    

 // Finds the real roots of p_coef[0]*x^4 + ... + p_coef[4] with Ferrari's
 // method and passes each of them to check_root().
 //
 // The quartic is first depressed to u^4 + alpha*u^2 + beta*u + gamma with
 // x = u - B/(4A). For beta == 0 it is a quadratic in u^2. Otherwise a real
 // root of the resolvent cubic v^3 + P*v + Q = 0 is needed: Cardano's formula
 // is used when the cubic has one real root, the trigonometric form (largest
 // root) when it has three.
 //
 // A zero leading coefficient is not handled specially: the divisions below
 // then produce non-finite intermediate values.
 void roots_class::solve_quartic()
{
  A=&p_coef[0];
  B=&p_coef[1];
  C=&p_coef[2];
  D=&p_coef[3];
  E=&p_coef[4];

  alpha = (-3.0/8.0)*(pow(*B,2))/(pow(*A,2)) + (*C)/(*A);
  beta = (1.0/8.0)*(pow((*B)/(*A),3.0)) - ((*B)*(*C))/(2.0*pow(*A,2)) + (*D)/(*A);
  gamma = (-3.0/256.0)*(pow((*B)/(*A),4.0)) + (1.0/16.0)*(*C)*pow(*B,2.0)/(pow(*A,3.0)) - (*B)*(*D)/(4.0*pow(*A,2.0)) + (*E)/(*A);

  // shift that maps a root u of the depressed quartic back to x
  roots_const = -(*B)/(4.0*(*A));

  if (beta==0)
    {
      // biquadratic: u^2 = (-alpha -/+ sqrt(alpha^2 - 4*gamma))/2
      discrim_1 = (pow(alpha,2)-4*gamma);
      if (discrim_1 >= 0)
        {
          discrim_2 = (-alpha - sqrt(discrim_1))/2.0;
          // both values of u^2 use the same inner square root
          discrim_3 = (-alpha + sqrt(discrim_1))/2.0;

          if (discrim_2 > 0)
            {
              y_try = roots_const + sqrt(discrim_2);
              check_root(y_try);
              y_try = roots_const - sqrt(discrim_2);
              check_root(y_try);
            }

          if (discrim_3 > 0)
            {
              y_try = roots_const + sqrt(discrim_3);
              check_root(y_try);
              y_try = roots_const - sqrt(discrim_3);
              check_root(y_try);
            }
        }
      return;
    }

  P = -(pow(alpha,2.0))/12.0 - gamma;
  Q = (-1.0/108.0)*pow(alpha,3.0) + alpha*gamma/3.0 - (1.0/8.0)*pow(beta,2.0);
  discrim_3 = (1.0/4.0)*pow(Q,2.0) + (1.0/27.0)*pow(P,3.0);

  if (discrim_3 >= 0)
    {
      // one real root of the resolvent cubic: Cardano
      R = -Q/2.0 + sqrt(discrim_3);
      // real cube root; pow() yields NaN for a negative base and exponent 1/3
      U = (R < 0.0) ? -pow(-R,1.0/3.0) : pow(R,1.0/3.0);

      if (U != 0.00)
        {
          y_wik = (-5.0/6.0)*alpha + U - P/(3.0*U);
        }
      else
        {
          // R == 0 with discrim_3 >= 0 implies Q >= 0, so pow() is safe here
          y_wik = (-5.0/6.0)*alpha + U - pow(Q,1.0/3.0);
        }
    }
  else if (discrim_3 < 0)
    {
      // three real roots of the resolvent cubic (P < 0 here): trigonometric
      // form. k = 0 gives the largest root, for which alpha + 2*y_wik > 0.
      double cos_arg = (3.0*Q/(2.0*P))*sqrt(-3.0/P);
      if (cos_arg > 1.0) cos_arg = 1.0;         // absorb rounding overshoot
      else if (cos_arg < -1.0) cos_arg = -1.0;
      y_wik = (-5.0/6.0)*alpha + 2.0*sqrt(-P/3.0)*cos(acos(cos_arg)/3.0);
    }
  else
    {
      // discrim_3 is NaN: nothing meaningful can be computed
      return;
    }

  W = sqrt(alpha + 2.0*y_wik);
  discrim_1 = -(3.0*alpha + 2.0*y_wik + 2.0*beta/W);
  discrim_2 = -(3.0*alpha + 2.0*y_wik - 2.0*beta/W);

  // each non-negative discriminant contributes a pair of real roots;
  // a NaN discriminant fails both comparisons and contributes none
  if (discrim_1 >= 0)
    {
      y_try = roots_const + ( W + sqrt(discrim_1))/2.0;
      check_root(y_try);
      y_try = roots_const + ( W - sqrt(discrim_1))/2.0;
      check_root(y_try);
    }

  if (discrim_2 >= 0)
    {
      y_try = roots_const + (-W + sqrt(discrim_2))/2.0;
      check_root(y_try);
      y_try = roots_const + (-W - sqrt(discrim_2))/2.0;
      check_root(y_try);
    }
}


// Constructor for roots_class objects.
// Starts from the identity pair (c = 1, s = 0) in every candidate slot and in
// c_cur / s_cur, with no candidates recorded and four Newton polishing steps.
roots_class::roots_class()
  : polish_iter(4),
    c_cur(1),
    s_cur(0),
    num_roots(0)
{
  for (int slot = 4; slot >= 0; --slot)
    {
      c[slot] = 1;
      s[slot] = 0;
    }
}

// void roots_class::reset(){
//   for (int num_ct = 0; num_ct<num_roots+1 ; ++num_ct)
//     {
//       s[num_ct]=0;
//       c[num_ct]=1;
//     }
//   c_cur=1;
//   s_cur=0;
//   num_roots = 0;
// }

// entry to program [Q,f] = f(A1,A2,E,lam_a,iter_max,theta_thresh)
// NOTE: E = C1'C2 
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  // create inputs & outputs 
  MatDoub A1(prhs[0]); 
  MatDoub A2(prhs[1]);
  MatDoub E(prhs[2]);
  double baseline_cost = *(mxGetPr(prhs[3]));
  double lam_a = *(mxGetPr(prhs[4]));
  int iter_max = *(mxGetPr(prhs[5]));
  double theta_thresh = *(mxGetPr(prhs[6]));
  int n=A1.ncols();
  MatDoub Q(n,n,plhs[0]);
  // VecDoub cost_vector(iter_max,plhs[1]);
  VecDoub cost_vector(1,plhs[1]);

  // initialize matrices, counters, etc
  double y[4],p_coef[5],K[5],K4_plus,K5_plus,c_cur_mex,s_cur_mex,*iter_max_pt,c[4],cost,max_cost;
  int num_roots,best_index,num,p,q;
  int stop_flag=0;
  int iter=0;
  roots_class roots_inst;
  initialize_Q(Q,n);
   
  while (iter < iter_max && stop_flag == 0 ){
    //update_cost(A1,A2,E,cost_vector,iter,n);
    stop_flag=1;
    //if (cost_vector[iter]<0.1) break;
    for (p=0 ; p < n-1 ; p++){
      for (q=p+1 ; q < n; q++){
     
	// this function modifies the contents of K4_plus/K5_plus
	// this is done in order to calculate coefficients 
	apply_dot(A1,A2,p,q,n,K4_plus,K5_plus);
	K[0] = lam_a*(A1[p][p]*A2[p][p] + A1[q][p]*A2[q][p] + A1[p][q]*A2[p][q] + A1[q][q]*A2[q][q]);
	K[1] = lam_a*(A1[p][p]*A2[q][q] - A1[p][q]*A2[q][p] - A1[q][p]*A2[p][q] + A1[q][q]*A2[p][p]);
	K[2] = lam_a*(A1[p][p]*(-A2[q][p] - A2[p][q]) + A1[q][p]*(A2[p][p]-A2[q][q]) + A1[p][q]*(A2[p][p] - A2[q][q]) + A1[q][q]*(A2[q][p]+A2[p][q]));
	K[3] = E[p][p]+E[q][q] + lam_a*K4_plus; 
	K[4] =-E[p][q]+E[q][p] + lam_a*K5_plus;

	// this function modifies the contents of y=[root1,root2,root3,root4]
	// it also modifies num_roots[0]
	
	roots_inst.find_roots(K);
    	c_cur_mex = roots_inst.c_cur;
	s_cur_mex = roots_inst.s_cur;
	
	if (abs(s_cur_mex>theta_thresh)) stop_flag=0;
	apply_rotation(A1,E,Q,c_cur_mex,s_cur_mex,p,q,n);
	//mexPrintf("c = %15.8f \n",c_cur_mex);    
	//mexPrintf("s = %15.8f \n",s_cur_mex);    	
      }
    }
    
    ++iter;
  }
  
  // this way, we only update the cost once
  iter=0;
  update_cost(A1,A2,E,baseline_cost,lam_a,cost_vector,iter,n);
}

// Writes 1.0 onto the first n diagonal entries of Q. Off-diagonal entries are
// not touched, so Q becomes the identity only if it was zero-filled beforehand
// (NOTE(review): presumably guaranteed by the MatDoub output constructor -- confirm).
inline void initialize_Q(MatDoub &Q, int &n)
{
  int diag = 0;
  while (diag < n)
    {
      Q[diag][diag] = 1.0;
      ++diag;
    }
}

// Evaluates the cost
//   baseline_cost - 2*sum_i E[i][i] + lam_a * sum_{i,j} (A1[i][j] - A2[i][j])^2
// over the leading n-by-n entries and stores it in cost_vector[iter].
inline void update_cost(const MatDoub &A1,const MatDoub &A2,const MatDoub &E, const double &baseline_cost, const double& lam_a,VecDoub &cost_vector, int& iter, int& n) 
{
  double total = 0.0;

  for (int row = 0; row < n; ++row)
    {
      // trace term for this row first, then the row's squared differences
      total += (-2.0)*E[row][row];
      for (int col = 0; col < n; ++col)
        {
          total += lam_a*(SQR( (A1[row][col] - A2[row][col]) ));
        }
    }

  cost_vector[iter] = total + baseline_cost;
}

// Applies the plane rotation G (G_pp = G_qq = c_cur, G_pq = s_cur,
// G_qp = -s_cur, identity elsewhere) in place:
//   A1 <- G * A1 * G'   (rows and columns p, q)
//   E  <- G * E         (rows p, q)
//   Q  <- G * Q         (rows p, q)
// Only the leading n columns/rows are touched.
inline void apply_rotation(MatDoub &A1,MatDoub &E, MatDoub &Q, double& c_cur,double& s_cur,int& p,int& q,int& n) 
{
  // snapshot of the 2x2 (p,q) block before it is overwritten
  const double a_pp = A1[p][p];
  const double a_pq = A1[p][q];
  const double a_qp = A1[q][p];
  const double a_qq = A1[q][q];

  const double cc = pow(c_cur,2);
  const double ss = pow(s_cur,2);
  const double cs = c_cur*s_cur;

  A1[p][p] = a_pp*cc + cs*(a_qp + a_pq) + a_qq*ss;
  A1[q][q] = a_pp*ss - cs*(a_pq + a_qp) + a_qq*cc;
  A1[p][q] = a_pq*cc + cs*(a_qq - a_pp) - a_qp*ss;
  A1[q][p] = a_qp*cc + cs*(a_qq - a_pp) - a_pq*ss;

  for (int k = 0; k < n; ++k)
    {
      // rows p and q of E and Q are mixed for every column
      const double e_pk = E[p][k];
      const double e_qk = E[q][k];
      const double q_pk = Q[p][k];
      const double q_qk = Q[q][k];
      E[p][k] = c_cur*e_pk + s_cur*e_qk;
      E[q][k] =-s_cur*e_pk + c_cur*e_qk;
      Q[p][k] = c_cur*q_pk + s_cur*q_qk;
      Q[q][k] =-s_cur*q_pk + c_cur*q_qk;

      // the (p,q) block of A1 was already handled above
      if (k == p || k == q) continue;

      // copies of the four entries that are about to be overwritten
      const double a_pk = A1[p][k];
      const double a_qk = A1[q][k];
      const double a_kp = A1[k][p];
      const double a_kq = A1[k][q];
      A1[p][k] = c_cur*a_pk + s_cur*a_qk;
      A1[q][k] = c_cur*a_qk - s_cur*a_pk;
      A1[k][p] = c_cur*a_kp + s_cur*a_kq;
      A1[k][q] = c_cur*a_kq - s_cur*a_kp;
    }
}

// Accumulates, over every index i in [0, n) other than p and q, the products
// of the row/column p and q entries of A1 and A2 that feed K[3] and K[4]:
//   K4_plus = sum  A1[p][i]*A2[p][i] + A1[q][i]*A2[q][i] + A1[i][p]*A2[i][p] + A1[i][q]*A2[i][q]
//   K5_plus = sum  A1[i][q]*A2[i][p] + A1[q][i]*A2[p][i] - A1[i][p]*A2[i][q] - A1[p][i]*A2[q][i]
// Both outputs are written once, after the loop.
inline void apply_dot(const MatDoub &A1,const MatDoub &A2, int &p,int &q,int &n,double &K4_plus,double &K5_plus)
{
  double acc_k4 = 0.0;
  double acc_k5 = 0.0;

  for (int idx = 0; idx < n; ++idx)
    {
      if (idx == p || idx == q) continue;   // the (p,q) block is excluded

      acc_k4 += A1[p][idx]*A2[p][idx] + A1[q][idx]*A2[q][idx] + A1[idx][p]*A2[idx][p] + A1[idx][q]*A2[idx][q];
      acc_k5 += A1[idx][q]*A2[idx][p] + A1[q][idx]*A2[p][idx] - A1[idx][p]*A2[idx][q] - A1[p][idx]*A2[q][idx];
    }

  K4_plus = acc_k4;
  K5_plus = acc_k5;
}

