//
// This file is part of ProSMART.
//

#include "restrain_align_handler.h"

// Stream a residue_alignment as "res1 res2 dist" followed by a newline.
// The text is written to the supplied stream `out` (the previous version
// always wrote to cout, so output meant for a file/stringstream ended up on
// stdout and `out` received nothing).
ostream& operator<<(ostream& out, residue_alignment &v)
{
  out << v.res1 << " " << v.res2 << " " << v.dist << endl;
  return out;
}

// Build the trivial alignment of a structure against itself.
//
// Scans the atoms of `pdb` in index order and emits one record each time a
// residue number larger than the last recorded one is seen (the scan starts
// from -1). Each record pairs that residue number with itself
// (res1 == res2) and has dist and sideAV set to 0.0. Atoms that repeat the
// previous residue number, or whose number goes backwards, add nothing.
vector<residue_alignment> align_to_self(PDBfile &pdb)
{
    vector<residue_alignment> identity;
    residue_alignment self_pair;
    int last_recorded = -1;

    for(int atom = 0; atom < pdb.size(); ++atom){
        const int resnum = pdb.get_resnum(atom);
        if(resnum <= last_recorded)
            continue;   // same residue as before, or numbering went backwards

        last_recorded    = resnum;
        self_pair.res1   = resnum;
        self_pair.res2   = resnum;
        self_pair.dist   = 0.0;
        self_pair.sideAV = 0.0;
        identity.push_back(self_pair);
    }
    return identity;
}

vector<residue_alignment> file_read_alignment(string &filein_string, vector<res_corresp> &orig_res1, vector<res_corresp> &orig_res2, int NONINCREASING_RESNUM2)
{
  const char *filein = filein_string.c_str();
  ifstream infile(filein, ios::in);
  string line;
  
  string align_ins1;
  string align_ins2;
  int align_resnum1;
  int align_resnum2;
  
  int idx1=0;
  int idx2=0;
  vector<residue_alignment> result;
  residue_alignment record;

  vector<string> line_v;
  char last_char;
  
  if(infile.is_open()){
    while(!infile.eof()){
	  line.clear();
	  getline(infile,line);
	  
	  line_v = string_to_vector(line);
        if(line[0]=='#')continue;
        if(line_v.size()<11)continue;
        if(line_v[0]=="Res1")continue;
	  
		last_char = line_v[0][line_v[0].size()-1];
		if(last_char=='0' || last_char=='1' || last_char=='2' || last_char=='3' || last_char=='4' || last_char=='5' || last_char=='6' || last_char=='7' || last_char=='8' || last_char=='9'){
		  align_ins1 = " ";
		} else {
		  align_ins1 = last_char;
		  line_v[0].erase(line_v[0].size()-1);
		}
		
		last_char = line_v[1][line_v[1].size()-1];
		if(last_char=='0' || last_char=='1' || last_char=='2' || last_char=='3' || last_char=='4' || last_char=='5' || last_char=='6' || last_char=='7' || last_char=='8' || last_char=='9'){
		  align_ins2 = " ";
		} else {
		  align_ins2 = last_char;
		  line_v[1].erase(line_v[1].size()-1);
		}
		
		align_resnum1 = str_to_int(line_v[0]);
		align_resnum2 = str_to_int(line_v[1]);
		
		while(orig_res1[idx1].res < align_resnum1){
		  idx1++;
		}
		while(orig_res1[idx1].res == align_resnum1 && orig_res1[idx1].ins != align_ins1[0]){
		  idx1++;
		}
		//cout << endl << ":" << orig_res1[idx1].res << ":" << align_resnum1 << ":" << orig_res1[idx1].ins << ":" << align_ins1[0] << ":";
		if(orig_res1[idx1].res != align_resnum1 || orig_res1[idx1].ins != align_ins1[0]){
		  continue;
		}
		
		if(NONINCREASING_RESNUM2==1){
		  idx2 = 0;
		}
		
		while(orig_res2[idx2].res < align_resnum2){
		  idx2++;
		}
		while(orig_res2[idx2].res == align_resnum2 && orig_res2[idx2].ins != align_ins2[0]){
		  idx2++;
		}
		if(orig_res2[idx2].res != align_resnum2 || orig_res2[idx2].ins != align_ins2[0]){
		  continue;
		}
		
		record.res1 = idx1;
		record.res2 = idx2;
		record.dist = str_to_double(line_v[6]);
		record.sideAV = str_to_double(line_v[10]);
		result.push_back(record);
		
		//cout << endl << line_v;
		//cout << record.res1 << " " << record.res2 << " " << record.dist;
    }
	cout << endl << "Successfully read file: " << filein << endl;
    infile.close();
  } 
  else cout << "Unable to open " << filein << " for reading" << endl; 
  
  return result;
}

vector<residue_alignment> filter_alignment(vector<residue_alignment> &old_align, double cutoff, double side_cutoff)
{
  vector<residue_alignment> result;
  result.clear();
  
  for(unsigned int i=0; i<old_align.size(); i++){
    if(old_align[i].dist <= cutoff && old_align[i].sideAV <= side_cutoff){
	  result.push_back(old_align[i]);
	}
  }

  return result;
}

// Split one bond-file line into fields separated by runs of spaces. Leading
// and trailing spaces are ignored. Only ' ' is a separator (other whitespace
// such as tabs stays inside a field), matching the original column scanner.
static vector<string> split_bond_record_fields(const string &line)
{
  vector<string> fields;
  string::size_type pos = 0;
  while(pos < line.size()){
    while(pos < line.size() && line[pos] == ' ') pos++;
    if(pos >= line.size()) break;
    string::size_type end = line.find(' ', pos);
    if(end == string::npos) end = line.size();
    fields.push_back(line.substr(pos, end - pos));
    pos = end;
  }
  return fields;
}

// Read a file of bonded atom pairs and return those belonging to `chain`.
//
// Each data line (longer than 4 characters, not starting with '#') is
// split into space-separated fields. The fields used (0-based) are:
//   atom 1: chain 4, resnum 6, ins code 8, atom 10, alt code 12
//   atom 2: chain 15, resnum 17, ins code 19, atom 21, alt code 23
// A pair is kept when both alt codes are "." and the first character of both
// chain fields equals `chain`. Both residue numbers, converted with
// str_to_int, must also be <= MAX_RESNUM.
// Each kept pair is returned as the 8 strings
//   [chain1, resnum1, ins1, atom1, chain2, resnum2, ins2, atom2].
//
// Lines with fewer than 24 fields are skipped. The previous scanner read past
// the end of such lines, which is undefined behaviour. A leading space
// previously shifted every column by one and no longer does. If the file
// cannot be opened, a message is printed and an empty result is returned.
vector<vector<string> > file_read_bonds(string &filein_string, char chain, int MAX_RESNUM)
{
  const char *filein = filein_string.c_str();
  ifstream infile(filein, ios::in);
  vector<vector<string> > args;

  if(!infile.is_open()){
    cout << endl << "Bonded atom-pairs not removed - file " << filein << " could not be found." << endl;
    return args;
  }

  const size_t FIELD_CHAIN1  = 4;
  const size_t FIELD_RESNUM1 = 6;
  const size_t FIELD_INS1    = 8;
  const size_t FIELD_ATOM1   = 10;
  const size_t FIELD_ALT1    = 12;
  const size_t FIELD_CHAIN2  = 15;
  const size_t FIELD_RESNUM2 = 17;
  const size_t FIELD_INS2    = 19;
  const size_t FIELD_ATOM2   = 21;
  const size_t FIELD_ALT2    = 23;

  string line;
  while(getline(infile,line)){
    // Tolerate CRLF files: a trailing '\r' would otherwise stick to the last field.
    if(!line.empty() && line[line.size()-1] == '\r') line.erase(line.size()-1);
    if(line.size() <= 4 || line[0] == '#') continue;

    const vector<string> fields = split_bond_record_fields(line);
    if(fields.size() <= FIELD_ALT2) continue;      // too short to describe both atoms

    // Skip pairs involving alternate conformers ("." means no alt code).
    if(fields[FIELD_ALT1] != "." || fields[FIELD_ALT2] != ".") continue;
    if(fields[FIELD_CHAIN1][0] != chain || fields[FIELD_CHAIN2][0] != chain) continue;
    if(str_to_int(fields[FIELD_RESNUM1]) > MAX_RESNUM || str_to_int(fields[FIELD_RESNUM2]) > MAX_RESNUM) continue;

    vector<string> arg;
    arg.push_back(fields[FIELD_CHAIN1]);
    arg.push_back(fields[FIELD_RESNUM1]);
    arg.push_back(fields[FIELD_INS1]);
    arg.push_back(fields[FIELD_ATOM1]);
    arg.push_back(fields[FIELD_CHAIN2]);
    arg.push_back(fields[FIELD_RESNUM2]);
    arg.push_back(fields[FIELD_INS2]);
    arg.push_back(fields[FIELD_ATOM2]);
    args.push_back(arg);
  }
  cout << endl << "Successfully read file: " << filein << endl;

  return args;
}

vector<residue_alignment> get_atomic_alignment(vector<residue_alignment> &res_align, Residues &res1, Residues &res2, int HELIX)
{
  vector<residue_alignment> result;
  residue_alignment current;
  int resnum1;
  int resnum2;
  string resid1;
  string resid2;
  vector<int> atoms1;
  vector<int> atoms2;
  vector<string> amino1;
  vector<string> amino2;
  
  for(unsigned int i=0; i<res_align.size(); i++){
    resnum1 = res_align[i].res1;
	resnum2 = res_align[i].res2;
	//current.dist = res_align[i].dist;
    resid1 = res1.get_resid(resnum1);
	resid2 = res2.get_resid(resnum2);
	atoms1 = res1.get_atoms(resnum1);
	atoms2 = res2.get_atoms(resnum2);
	amino1 = res1.get_amino(resnum1);
	amino2 = res2.get_amino(resnum2);
     
     //this hack avoids nomenclature errors (e.g. position of nucleic acid code)
     //should probably be replaced by command line flag... and appropriate warning message should be displayed to prompt user to enable option! Low sequence identity reported by prosmart_align would be indicator.
     //plus need to allow backbone restraints if D/RNA type is different
     resid1 = delete_spaces(resid1);
     resid2 = delete_spaces(resid2);
     
     /*cout << endl << i << "\t" << resnum1 << " " << resid1 << "\t";
      cout << resnum2 << " " << resid2 << " ";
      cout << endl;*/
	if(resid1 == resid2 && HELIX == 0){		//if the aligned residues are the same
	  for(unsigned int k=0; k<atoms1.size(); k++){	//search for atomic correspondences
	    if(atoms2.size()>k){
	      if(amino1[k] == amino2[k]){
		    current.res1 = atoms1[k];
		    current.res2 = atoms2[k];
			result.push_back(current);
			//cout << "\t" << current.res1 << " " << current.res2;
			continue;
		  }
		}
	    for(unsigned int j=0; j<amino2.size(); j++){
		  if(amino1[k] == amino2[j]){
		    current.res1 = atoms1[k];
		    current.res2 = atoms2[j];
			result.push_back(current);
			//cout << "\t" << current.res1 << " " << current.res2;
			break;
		  }
		}
	  }
	} else {								//if the aligned residues are different
	  for(unsigned int j=0; j<amino1.size(); j++){	//only get the N, CA, C and O atoms
	    if(amino1[j] == " N  "){
	      for(unsigned int k=0; k<amino2.size(); k++){
			if(amino2[k] == " N  "){
	          current.res1 = atoms1[j];
		      current.res2 = atoms2[k];
			  result.push_back(current);
			  //cout << "\t" << current.res1 << " " << current.res2;
			  break;
	        }
	      }
		  break;
	    }
	  }
	  for(unsigned int j=0; j<amino1.size(); j++){
	    if(amino1[j] == " CA "){
	      for(unsigned int k=0; k<amino2.size(); k++){
			if(amino2[k] == " CA "){
	          current.res1 = atoms1[j];
		      current.res2 = atoms2[k];
			  result.push_back(current);
			  //cout << "\t" << current.res1 << " " << current.res2;
			  break;
	        }
	      }
		  break;
	    }
	  }
	  for(unsigned int j=0; j<amino1.size(); j++){
	    if(amino1[j] == " C  "){
	      for(unsigned int k=0; k<amino2.size(); k++){
			if(amino2[k] == " C  "){
	          current.res1 = atoms1[j];
		      current.res2 = atoms2[k];
			  result.push_back(current);
			  //cout << "\t" << current.res1 << " " << current.res2;
			  break;
	        }
	      }
		  break;
	    }
	  }
	  for(unsigned int j=0; j<amino1.size(); j++){
	    if(amino1[j] == " O  "){
	      for(unsigned int k=0; k<amino2.size(); k++){
			if(amino2[k] == " O  "){
	          current.res1 = atoms1[j];
		      current.res2 = atoms2[k];
			  result.push_back(current);
			  //cout << "\t" << current.res1 << " " << current.res2;
			  break;
	        }
	      }
		  break;
	    }
	  }
	}
  }

  return result;
}

// Build a lookup table of length N that maps an atom index (the res1 field of
// an atomic alignment record) to that record's position in `atom_align`.
// Atoms without a record map to -1. If an atom index occurs more than once,
// the last record wins. Records whose res1 lies outside [0, N) are skipped;
// previously they were written out of bounds.
Array1D<int> get_atom_indexes(vector<residue_alignment> &atom_align, int N)
{
  Array1D<int> result(N,-1);

  for(size_t i=0; i<atom_align.size(); i++){
    const int atom = atom_align[i].res1;
    if(atom < 0 || atom >= N) continue;
    result[atom] = static_cast<int>(i);
  }

  return result;
}
