//
// This file is part of ProSMART.
//

#include "restrain_bfgs_optimisation.h"

#include <cmath>

//#define LIKELIHOOD_COUTS
//#define UD

//Refines the parameter(s) of the restraint-sigma model with a BFGS quasi-Newton
//iteration on the objective returned by logL (logL_ when UD is defined).
//
//  k0            starting parameter values; its length fixes the number of parameters.
//  restraints1/2, pdb1/2, res_scores1/2
//                handed unchanged to Diff_Info::initialise to build the observations.
//  dist_param    only used to form unif_1_n = 0.5/dist_param, which is passed on to
//                the line search (and to the UD likelihood terms when UD is defined).
//
//Returns the last accepted parameter vector once ||gradient|| drops below the
//tolerance, or an empty vector if the search direction became NaN.
vector<double> generate_sigmas(vector<double> k0, vector<residue_alignment> &restraints1, vector<residue_alignment> &restraints2, PDBfile &pdb1, PDBfile &pdb2, vector<double> &res_scores1, vector<double> &res_scores2, double dist_param)
{
  int no_params = k0.size();
  double unif_1_n = 0.5/dist_param;   //1/n where n = uniform parameter (n=2*dist_param)

  //settings for the line search and the stopping test
  double c1 = 0.0001;                 //first Wolfe-condition parameter
  double c2 = 0.9;                    //second Wolfe-condition parameter
  double e = 1;                       //iteration stops once ||grad|| < e

  //read the observations
  Diff_Info X;
  X.initialise(restraints1, restraints2, pdb1, pdb2, res_scores1, res_scores2);
  vector<double> r = X.get_r();             //distance between atoms in the prior structure
  vector<double> d = X.get_d();             //distance between atoms in the target structure
  vector<double> d_r_2 = get_x_m_2(d,r);    //(d-r)^2, which recurs throughout the likelihood

  //explanatory variables for the sigma model: currently only r
  vector<vector<double> > x_Brs;
  x_Brs.push_back(r);

  //history of accepted parameter vectors; k[idx] is the current one
  vector<vector<double> > k;
  k.push_back(k0);

  vector<double> sigma2 = get_sigma2(no_params,x_Brs,k[0]);
  vector<vector<double> > grad_sigma2 = get_grad_sigma2(no_params,x_Brs,k[0]);

  //objective value and gradient at the starting point
  double lL;
  vector<double> dlL;
  #ifdef UD
    vector<double> dr2s2 = d_r_2/sigma2;
    vector<double> f_n = get_normal(dr2s2,sigma2);
    vector<double> f_d = get_f_d(k[0][0],unif_1_n,f_n);   //k[0][0]==w
    lL = logL_(f_d);
    dlL = grad_logL_(f_d,f_n,k[0],sigma2,dr2s2,grad_sigma2,unif_1_n);
  #else
    lL = logL(d_r_2, sigma2);
    dlL = grad_logL(d_r_2, sigma2, grad_sigma2);
  #endif

  double alpha0 = 1.0/sqrt(v_prod(dlL,dlL));        //first trial step length
  vector<vector<double> > I = identity(no_params);  //needed by the inverse-Hessian update
  vector<vector<double> > H = diag(k[0]);           //initial inverse-Hessian estimate

  #ifdef LIKELIHOOD_COUTS
  cout << endl << endl << endl << "Refining parameters to estimate restraint sigmas...";
  cout << endl << endl << "H: ";
  M_view(H);
  #endif

  vector<double> k_new;
  vector<double> empty_vector;
  int idx = 0;

  for(;;){
    double mdlL = mod(dlL);

    #ifdef LIKELIHOOD_COUTS
    cout << endl << "Iteration = " << idx;
    cout << endl << "Parameters = " << k[idx];
    cout << endl << "log-likelihood = " << lL;
    cout << endl << "grad-log-likelihood = " << dlL;
    cout << endl << "mod-grad-log-likelihood = " << mdlL;
    #endif

    //converged?
    if(mdlL < e){
      break;
    }

    //search direction; if H is bad enough that p points uphill, reverse it
    vector<double> p = v_negative(v_prod(H,dlL));
    if(v_prod(dlL,p) > 0.0){
      for(unsigned int j=0; j<p.size(); j++){
        p[j] = 0.0 - p[j];
      }
    }
    #ifdef LIKELIHOOD_COUTS
    cout << endl << "\tp = " << p;
    #endif

    //a NaN entry makes p compare unequal to itself: numerical failure
    if(p!=p){
      return empty_vector;
    }

    //line search along p; Wolfe overwrites k_new, lL and dlL
    vector<double> dlL_old = dlL;
    Wolfe(no_params,k_new,k[idx],p,x_Brs,d_r_2,lL,dlL,c1,c2,alpha0,unif_1_n);
    alpha0 = 1;   //after the first iteration, always try alpha0=1 first.

    //accept the new parameters
    k.push_back(k_new);
    idx++;

    //update of the inverse-Hessian estimate: H <- (I - pk*sk*yk^T) H (I - pk*yk*sk^T) + pk*sk*sk^T
    vector<double> sk = v_subtract(k[idx],k[idx-1]);   //change in parameters
    vector<double> yk = v_subtract(dlL,dlL_old);       //change in gradient
    double pk = 1/(v_prod(yk,sk));
    vector<vector<double> > sy = v_prod_T(sk,yk);
    vector<vector<double> > ys = v_prod_T(yk,sk);
    vector<vector<double> > ss = v_prod_T(sk,sk);
    vector<vector<double> > psy = M_mult(pk,sy);
    vector<vector<double> > pys = M_mult(pk,ys);
    vector<vector<double> > pss = M_mult(pk,ss);
    vector<vector<double> > H1 = M_subtract(I,psy);
    vector<vector<double> > H2 = M_subtract(I,pys);
    H = crossprod(H,H2);
    H = crossprod(H1,H);
    H = M_add(H,pss);

    #ifdef LIKELIHOOD_COUTS
    cout << endl << "H: ";
    M_view(H);
    #endif
  }

  cout << endl << "Parameter(s) for sigma estimation: " << k[idx];

  return k[idx];
}

bool Wolfe(int no_params, vector<double> &x_new, vector<double> &x, vector<double> &p, vector<vector<double> > &x_Brs, vector<double> &d_r_2, double &lL, vector<double> &dlL, double c1, double c2, double alpha0, double unif_1_n)
{
  vector<double> sigma2_new;
  vector<vector<double> > grad_sigma2_new;
  vector<double> dlL_new;
  double dlLp = v_prod(dlL,p);
  double cond1 = c1*dlLp;
  double cond2 = c2*abs(dlLp);
  x_new.clear();
  bool success = 0;
  int ai = -1;
  
  //need to pass back more useful information
  
  double a_low = 0.0;
  double dlLp_low = dlLp;
  double lL_low = lL;
  double a_high = 0.0;
  double dlLp_high = 0.0;
  double lL_high = 0.0;
  double a_new = 1.0;
  double dlLp_new = 0.0;
  double lL_new = 0.0;
  
  vector<double> dr2s2;
  vector<double> f_n;
  vector<double> f_d;
    
  while(1==1){
    ai++;	
	//generate new alpha
	a_new = alpha0*exp(ai);

	//calculate required terms
	x_new = v_add(x,s_prod(a_new,p));
	
	sigma2_new = get_sigma2(no_params,x_Brs,x_new);
	grad_sigma2_new = get_grad_sigma2(no_params,x_Brs,x_new);
	
	#ifdef UD
      dr2s2 = d_r_2/sigma2_new;
	  f_n = get_normal(dr2s2,sigma2_new);
      f_d = get_f_d(x_new[0],unif_1_n,f_n);		//x_new[0] == k[0][0] == w
	  lL_new = logL_(f_d);
	  dlL_new = grad_logL_(f_d,f_n,x_new,sigma2_new,dr2s2,grad_sigma2_new,unif_1_n);
	  /*
	  cout << endl;
	  cout << endl << "lL " << lL_new;
	  cout << endl << "dlL " << dlL_new;
	  cout << endl;

	  cin >> tmp;
	  */
    #else
      lL_new = logL(d_r_2, sigma2_new);
	  dlL_new = grad_logL(d_r_2, sigma2_new, grad_sigma2_new);
    #endif
	
	//lL_new = logL(d_r_2, sigma2_new);
	//dlL_new = grad_logL(d_r_2, sigma2_new, grad_sigma2_new);
	
	dlLp_new = v_prod(dlL_new,p);
	#ifdef LIKELIHOOD_COUTS
	cout << endl << endl << "Line Search Find Interval. Iteration: " << ai;
	cout << endl << "a_low = " << a_low << "\tlL_low = " << lL_low << "\tdlLp_low = " << dlLp_low;
	cout << endl << "a_high = " << a_high << "\tlL_high = " << lL_high << "\tdlLp_high = " << dlLp_high;
	cout << endl << "a_new = " << a_new << "\tlL_new = " << lL_new << "\tdlLp_new = " << dlLp_new;
	#endif
	
	
	//test whether this alpha satisfies the Wolfe conditions
	if(lL_new > lL + a_new*cond1){	//condition 1
	  //cout << endl << "condition 1 failed";
	  a_high = a_new;
	  lL_high = lL_new;
	  dlLp_high = dlLp_new;
	  break;
	} else{
	  //cout << endl << "condition 1 passed";
	  if(abs(dlLp_new) <= cond2){	//condition 2
	    //cout << "\tcondition 2 passed";
		success = 1;
	    break;
	  } else {
	    //cout << "\tcondition 2 failed";
		if(dlLp_new < 0){
		  a_low = a_new;
		  lL_low = lL_new;
		  dlLp_low = dlLp_new;
		} else {
		  a_high = a_new;
		  lL_high = lL_new;
		  dlLp_high = dlLp_new;
		  break;
		}
	  }
	}
  }
  
  ////cout << endl << "|";
  ////cout << endl << "Found initial interval. Iterations: " << ai;
  ////cout << endl << "a_low = " << a_low << "\tlL_low = " << lL_low << "\tdlLp_low = " << dlLp_low;
  ////cout << endl << "a_high = " << a_high << "\tlL_high = " << lL_high << "\tdlLp_high = " << dlLp_high;
  ////cout << endl << "a_new = " << a_new << "\tlL_new = " << lL_new << "\tdlLp_new = " << dlLp_new;
  ////cout << endl << "success = " << success;
  ////cout << endl << "|";
  
  if(success==0){
    ai = 0;
    while(1==1){
	  ai++;
	  /*
	  if(lL_high<=lL){
	    //cubic using two derivatives
		cout << endl << "interpolate3";
		a_new = interpolate3(a_low,a_high,lL_low,lL_high,dlLp_low,dlLp_high);
	  } else if(a_low==0){
	    //quadratic
		cout << endl << "interpolate2";
		interpolate2(a_new, a_low, a_high, lL_low, lL_high, dlLp_low);
		if(a_new < a_low || a_new > a_high){	//there is no extremum within the interval.
		  cout << endl << "Interpolation failed - using average.";
		  a_new = (a_low+a_high)/2;
		}
	  } else {
		//cubic only using derivative at a_low
		cout << endl << "interpolate3_1";
		a_new = interpolate3(a_low,a_high,lL_low,lL_high,dlLp_low,lL);
	  }*/
	  
	  interpolate2(a_new, a_low, a_high, lL_low, lL_high, dlLp_low);
		if(a_new < a_low || a_new > a_high){	//there is no extremum within the interval.
		  #ifdef LIKELIHOOD_COUTS
		  cout << endl << "Interpolation failed - using average.";
		  #endif
		  a_new = (a_low+a_high)/2;
		}
	  
	  restrict_value(a_new,a_low,a_high);		//a very small change in a_low or a_high is pointless, so check whether: a_low + (a_high-a_low)/1000 <= a_new <= a_high - (a_high-a_low)/1000. If not, set a_new = (a_high+a_low)/2.
	  x_new = v_add(x,s_prod(a_new,p));
	  sigma2_new = get_sigma2(no_params,x_Brs,x_new);
	  lL_new = logL(d_r_2, sigma2_new);
	  grad_sigma2_new = get_grad_sigma2(no_params,x_Brs,x_new);
	  dlL_new = grad_logL(d_r_2, sigma2_new, grad_sigma2_new);
	  dlLp_new = v_prod(dlL_new,p);
	  #ifdef LIKELIHOOD_COUTS
	  cout << endl << "Line Search Interpolate " << ai;
	  cout << endl << "a_low = " << a_low << "\tlL_low = " << lL_low << "\tdlLp_low = " << dlLp_low;
	  cout << endl << "a_high = " << a_high << "\tlL_high = " << lL_high << "\tdlLp_high = " << dlLp_high;
 	  cout << endl << "a_new = " << a_new << "\tlL_new = " << lL_new << "\tdlLp_new = " << dlLp_new << endl;
	  #endif

	  if(a_new<a_low || a_new>a_high){
	    #ifdef LIKELIHOOD_COUTS
	    cout << endl << "Interval converged before satisfying conditions." << endl;
		#endif
		if(lL_low < lL_high){
		  a_new = a_low;
		  lL_new = lL_low;
		  dlLp_new = dlLp_low;
		} else {
		  a_new = a_high;
		  lL_new = lL_high;
		  dlLp_new = dlLp_high;
		}
		break;
	  }
	  
	  //test whether this alpha satisfies the Wolfe conditions
	  if(lL_new > lL + a_new*cond1){	//condition 1
	    a_high = a_new;
	    lL_high = lL_new;
	    dlLp_high = dlLp_new;
		#ifdef LIKELIHOOD_COUTS
	    cout << "\tcondition 1 failed";
		#endif
	  } else{
	    #ifdef LIKELIHOOD_COUTS
	    cout << "\tcondition 1 passed";
		#endif
	    if(abs(dlLp_new) <= cond2){	//condition 2
		  #ifdef LIKELIHOOD_COUTS
	      cout << "\tcondition 2 passed";
		  #endif
		  success = 1;
	      break;
	    } else {
		  #ifdef LIKELIHOOD_COUTS
	      cout << "\tcondition 2 failed";
		  #endif
		  if(dlLp_new < 0){
		    a_low = a_new;
		    lL_low = lL_new;
		    dlLp_low = dlLp_new;
		  } else {
		    a_high = a_new;
		    lL_high = lL_new;
		    dlLp_high = dlLp_new;
		  }
		  if(ai>4){								//after a few iterations
	        if(dlLp_low<0 && dlLp_high<0){		//if (possibly) stuck in an interval without a minimum
		      break;
		    }
			if(dlLp_low>0 && dlLp_high>0){		//if (possibly) stuck in an interval without a minimum
		      break;
		    }
	      }
	    }
	  }
	}
  }
  #ifdef LIKELIHOOD_COUTS
  cout << endl << "|";
  cout << endl << "Line Search Result:";
  cout << endl << "a_new = " << a_new << "\tlL_new = " << lL_new << "\tdlLp_new = " << dlLp_new;
  cout << endl << "x_new = " << x_new;
  cout << endl << "success = " << success;
  #endif
  lL = lL_new;
  dlL = dlL_new;
  
  return success;
}

//Quadratic interpolation step.
//Fits q(a) = f1 + df1*(a-a1) + (c2/2)*(a-a1)^2 through the value f1 and slope
//df1 at a1 and the value f2 at a2, and writes the stationary point of q to a_new.
//
//  a_new   (out) stationary point a1 - df1/c2; it is NOT clamped to [a1,a2].
//          If the fitted curvature c2 is zero or NaN (straight line, or a1 == a2
//          with f1 == f2) there is no stationary point, and the midpoint
//          (a1+a2)/2 is written instead of the inf/NaN the division would give.
//
//Returns true if the stationary point is a minimum (c2 > 0), false otherwise.
bool interpolate2(double &a_new, double a1, double a2, double f1, double f2, double df1)
{
  double da = a2-a1;
  double dada = da*da;
  double c2 = (2*(f2-f1-(df1*da)))/dada;

  //true only for c2 == 0 or c2 == NaN
  if(!(c2 < 0.0 || c2 > 0.0)){
    a_new = (a1+a2)/2;
    return 0;
  }

  a_new = a1-(df1/c2);
  return c2 > 0;
}

double interpolate3(double a1, double a2, double f1, double f2, double df1, double df2)
//Cubic interpolation step using both end derivatives.
//Fits g(b) = f1 + df1*b + c*b^2 + d*b^3 (b = a - a1) so that g(0)=f1, g'(0)=df1,
//g(da)=f2 and g'(da)=df2, where da = a2-a1.
//If an extremum of g lies strictly between a1 and a2 its position is returned.
//Otherwise a quadratic interpolant (interpolate2) is attempted.
//If that also fails to give a point in [a1,a2], the average (a1+a2)/2 is returned.
//Diagnostic output is only produced when LIKELIHOOD_COUTS is defined.
//NOTE(review): the root tried first, (-c-sqrt(sq))/(3d), has g'' = -2*sqrt(sq) <= 0,
//i.e. it is the maximum of the cubic; confirm whether the minimum should be preferred.
{
  bool extremum_in_interval = 1;
  double a_new = (a1+a2)/2;
  double b_new;
  double da = a2-a1;
  double da2 = da*da;
  double df = f2-f1;
  double ddf = df2-df1;
  //From g(da)=f2 and g'(da)=df2:  c = (3*df - da*(2*df1 + df2))/da^2
  double c = ((3.0*df) - (da*(df2+(2.0*df1))))/da2;
  double d = (ddf-(2.0*c*da))/(3.0*da2);
  //g'(b) = df1 + 2*c*b + 3*d*b^2 has real roots iff c^2 - 3*df1*d >= 0
  double sq = (c*c)-(3.0*df1*d);

  #ifdef LIKELIHOOD_COUTS
  cout << endl << endl << "a1 a2: " << a1 << " " << a2;
  cout << endl << "da: " << da;
  cout << endl << "da2: " << da2;
  cout << endl << "f1 f2: " << f1 << " " << f2;
  cout << endl << "df: " << df;
  cout << endl << "df1 df2: " << df1 << " " << df2;
  cout << endl << "ddf: " << ddf;
  cout << endl << "c: " << c;
  cout << endl << "d: " << d;
  cout << endl << "sq = " << sq;
  #endif

  if(sq >= 0){						//extrema exist
    //If d == 0 the divisions below give inf/NaN, both range tests fail and the
    //quadratic fallback is used.
    b_new = (-c-sqrt(sq))/(3*d);
    #ifdef LIKELIHOOD_COUTS
    cout << endl << "bnew- = " << b_new;
    cout << endl << "fq' : " << df1 + (2.0*c*b_new) + (3.0*d*b_new*b_new);
    cout << endl << "fq'' : " << (2.0*c) + (6.0*d*b_new);
    #endif
    if(b_new > 0.0 && b_new < da){	//this extremum is within the interval.
      a_new = a1 + b_new;
      #ifdef LIKELIHOOD_COUTS
      cout << endl << "interpolate_pass_1\t" << b_new;
      #endif
    } else {
      b_new = (-c+sqrt(sq))/(3*d);
      #ifdef LIKELIHOOD_COUTS
      cout << endl << "bnew+ = " << b_new;
      cout << endl << "fq' : " << df1 + (2.0*c*b_new) + (3.0*d*b_new*b_new);
      cout << endl << "fq'' : " << (2.0*c) + (6.0*d*b_new);
      #endif
      if(b_new > 0.0 && b_new < da){	//this extremum is within the interval.
        a_new = a1 + b_new;
        #ifdef LIKELIHOOD_COUTS
        cout << endl << "interpolate_pass_2\t" << b_new;
        #endif
      } else {						//there is no extremum within the interval.
        extremum_in_interval = 0;
      }
    }
  } else {
    extremum_in_interval = 0;
  }

  if(extremum_in_interval==0){
    #ifdef LIKELIHOOD_COUTS
    cout << endl << "Using quadractic interpolant.";
    #endif
    interpolate2(a_new, a1, a2, f1, f2, df1);
    //Negated form so that a NaN result is also replaced by the average.
    if(!(a_new >= a1 && a_new <= a2)){	//there is no extremum within the interval.
      #ifdef LIKELIHOOD_COUTS
      cout << endl << "Interpolation failed - using average.";
      #endif
      a_new = (a1+a2)/2;
    }
  }

  return a_new;
}

double interpolate3_0(double a1, double a2, double f1, double f2, double df1, double f0)
//Cubic interpolation step using one derivative and an extra function value.
//Fits g(b) = f1 + df1*b + c*b^2 + d*b^3 (b = a - a1) so that g(0)=f1, g'(0)=df1,
//g(a2-a1)=f2 and g(-a1)=f0, i.e. f0 is the function value at a = 0.
//If an extremum of g lies strictly between a1 and a2 its position is returned.
//Otherwise a quadratic interpolant (interpolate2) is attempted.
//If that also fails to give a point in [a1,a2], the average (a1+a2)/2 is returned.
//The cubic needs three distinct abscissae (0, a1, a2); when a1 == 0 or a1 == a2
//the quadratic fallback is used directly.
//Diagnostic output is only produced when LIKELIHOOD_COUTS is defined.
{
  bool extremum_in_interval = 1;
  double a_new = (a1+a2)/2;
  double b_new;
  double da = a2-a1;
  double da2 = da*da;

  #ifdef LIKELIHOOD_COUTS
  cout << endl << endl << "a1 a2: " << a1 << " " << a2;
  cout << endl << "da: " << da;
  cout << endl << "da2: " << da2;
  cout << endl << "f0 f1 f2: " << f0 << " " << f1 << " " << f2;
  cout << endl << "df1: " << df1;
  #endif

  if(a1 == 0.0 || da == 0.0){
    //the coefficient formulae divide by a1^2 and da^2
    extremum_in_interval = 0;
  } else {
    //Eliminating c between g(-a1)=f0 and g(da)=f2 gives
    //  d = (a1^2*f2 - da^2*f0 + a2*(a2-2*a1)*f1 - a1*a2*da*df1) / (a1^2*a2*da^2)
    double d = (((a1*a1*f2)-(da2*f0)+(a2*(a2-(2*a1))*f1)-(a1*a2*da*df1))/(a1*a1*a2*da2));
    double c = ((f0-f1+(df1*a1)+(d*a1*a1*a1))/(a1*a1));
    //g'(b) = df1 + 2*c*b + 3*d*b^2 has real roots iff c^2 - 3*df1*d >= 0
    double sq = (c*c)-(3.0*df1*d);

    #ifdef LIKELIHOOD_COUTS
    cout << endl << "d: " << d;
    cout << endl << "c: " << c;
    cout << endl << "sq = " << sq;
    #endif

    if(sq >= 0){						//extrema exist
      //If d == 0 the divisions below give inf/NaN, both range tests fail and the
      //quadratic fallback is used.
      b_new = (-c-sqrt(sq))/(3*d);
      #ifdef LIKELIHOOD_COUTS
      cout << endl << "bnew- = " << b_new;
      cout << endl << "fq' : " << df1 + (2.0*c*b_new) + (3.0*d*b_new*b_new);
      cout << endl << "fq'' : " << (2.0*c) + (6.0*d*b_new);
      #endif
      if(b_new > 0.0 && b_new < da){	//this extremum is within the interval.
        a_new = a1 + b_new;
        #ifdef LIKELIHOOD_COUTS
        cout << endl << "interpolate_pass_1\t" << b_new;
        #endif
      } else {
        b_new = (-c+sqrt(sq))/(3*d);
        #ifdef LIKELIHOOD_COUTS
        cout << endl << "bnew+ = " << b_new;
        cout << endl << "fq' : " << df1 + (2.0*c*b_new) + (3.0*d*b_new*b_new);
        cout << endl << "fq'' : " << (2.0*c) + (6.0*d*b_new);
        #endif
        if(b_new > 0.0 && b_new < da){	//this extremum is within the interval.
          a_new = a1 + b_new;
          #ifdef LIKELIHOOD_COUTS
          cout << endl << "interpolate_pass_2\t" << b_new;
          #endif
        } else {						//there is no extremum within the interval.
          extremum_in_interval = 0;
        }
      }
    } else {
      extremum_in_interval = 0;
    }
  }

  if(extremum_in_interval==0){
    #ifdef LIKELIHOOD_COUTS
    cout << endl << "Using quadractic interpolant.";
    #endif
    interpolate2(a_new, a1, a2, f1, f2, df1);
    //Negated form so that a NaN result is also replaced by the average.
    if(!(a_new >= a1 && a_new <= a2)){	//there is no extremum within the interval.
      #ifdef LIKELIHOOD_COUTS
      cout << endl << "Interpolation failed - using average.";
      #endif
      a_new = (a1+a2)/2;
    }
  }

  return a_new;
}


//Keeps a trial value away from the ends of the interval [a_low, a_high].
//If a_new lies within 1% of the interval width of either end, lies outside the
//interval, or is NaN, it is replaced by the midpoint (a_low+a_high)/2;
//otherwise it is left unchanged.
void restrict_value(double &a_new, double a_low, double a_high)
{
  double margin = (a_high-a_low)/100;

  //Written as a negated "inside" test: every comparison with NaN is false, so a
  //NaN a_new fails the inner test and is replaced too.
  if(!(a_new >= a_low+margin && a_new <= a_high-margin)){
    a_new = (a_low+a_high)/2;
  }

  return;
}
