//********************************************************************************
//*
//*  C++ finite element method for heat equation
//*  James sandham
//*  15 April 2015
//*
//********************************************************************************

//********************************************************************************
//
// HeatFE is free software; you can redistribute it and/or modify it under the
// terms of the GNU Lesser General Public License (as published by the Free
// Software Foundation) version 2.1 dated February 1999.
//
//********************************************************************************

#include<iostream>
#include<fstream>
#include"Element.h"
#include"Solver.h"
#include"FeModel.h"
#include"FeConst.h"
#include"math.h"
#include"PCG.h"
#include"AMG.h"


//********************************************************************************
//
//  Solve AT^(n+1) = b where  A = (M+dt*K) and b = (M*T^n+Q)
//
//********************************************************************************



//----------------------------------------------------------------------------
// Constructor for Solver
//----------------------------------------------------------------------------
Solver::Solver(FeModel fem) : fem(fem)
{
  // One equation per mesh node.
  neq = fem.N;
  const int capacity = neq*MAX_NNZ;

  // Row-pointer array, matrix values and the two vectors start out zeroed;
  // the trailing "()" value-initialises each new[] array.
  nrow = new int[neq+1]();
  ncol = new int[capacity];
  A    = new double[capacity]();
  x    = new double[neq]();
  b    = new double[neq]();

  // Column slots are padded per row; -1 marks a slot that is still unused.
  for(int slot=0; slot<capacity; ++slot)
    ncol[slot] = -1;
}




//----------------------------------------------------------------------------
// Constructor for Preconditioned Conjugate Gradient Solver
//----------------------------------------------------------------------------
SolverPCG::SolverPCG(FeModel fem)
  : Solver(fem)
{
  // Storage is allocated by the Solver base constructor; this constructor
  // adds nothing of its own.
}



//----------------------------------------------------------------------------
// Constructor for Algebraic Multigrid Solver
//----------------------------------------------------------------------------
SolverAMG::SolverAMG(FeModel fem)
  : Solver(fem)
{
  // Storage is allocated by the Solver base constructor; this constructor
  // adds nothing of its own.
}





//----------------------------------------------------------------------------
// Assemble Global A Matrix
//----------------------------------------------------------------------------
void Solver::assembleGAM()
{
  // Start from empty padded storage. After a previous call the arrays hold
  // compressed CSR data, which the padded scatter below would misread.
  for(int i=0;i<neq+1;i++){nrow[i] = 0;}
  for(int i=0;i<neq*MAX_NNZ;i++){
    ncol[i] = -1;
    A[i] = 0.0;
  }

  // Number of element contributions that found no free slot in their row.
  int dropped = 0;

  //--- 1. scatter element matrices into the padded layout -----------------
  // Row r (1-based node number) owns slots [(r-1)*MAX_NNZ, r*MAX_NNZ). Used
  // slots are packed at the front of the row; unused slots hold ncol == -1.
  for(int k=0;k<fem.Ne;k++){
    Element e(fem.Type);
    for(int l=0;l<e.Npe;l++){
      e.nodes[l] = fem.connect[k][l];
      e.xpts[l] = fem.xpoints[e.nodes[l]-1];
      e.ypts[l] = fem.ypoints[e.nodes[l]-1];
      e.zpts[l] = fem.zpoints[e.nodes[l]-1];
    }

    e.stiffnessMatrix();
    if(fem.time){e.massMatrix();}

    for(int i=0;i<e.Npe;i++){
      int r = e.nodes[i];
      for(int j=0;j<e.Npe;j++){
        int c = e.nodes[j];

        // A = M + dt*K for the transient problem, A = K otherwise
        double v = e.kmat[i][j];
        if(fem.time){v = e.mmat[i][j] + fem.dt*v;}

        // add to an existing (r,c) entry, or claim the first free slot
        bool placed = false;
        for(int p=(r-1)*MAX_NNZ;p<r*MAX_NNZ;p++){
          if(ncol[p]==-1){
            ncol[p] = c-1;
            A[p] = v;
            placed = true;
            break;
          }
          if(ncol[p]==c-1){
            A[p] += v;
            placed = true;
            break;
          }
        }
        if(!placed){dropped++;}
      }
    }
  }

  // A full row used to lose entries silently, giving a wrong matrix.
  if(dropped>0){
    std::cerr<<"assembleGAM: "<<dropped<<" matrix contributions did not fit in "
             <<MAX_NNZ<<" slots per row; increase MAX_NNZ"<<std::endl;
  }

  //--- 2. row pointers ------------------------------------------------------
  for(int i=0;i<neq;i++){
    // A completely full row has no -1 terminator, so default to MAX_NNZ
    // rather than reusing the previous row's length.
    int len = MAX_NNZ;
    for(int j=0;j<MAX_NNZ;j++){
      if(ncol[i*MAX_NNZ+j]==-1){len = j; break;}
    }
    nrow[i+1] = nrow[i] + len;
  }

  //--- 3. sort each row by column index (insertion sort) -------------------
  for(int i=0;i<neq;i++){
    int *cols = ncol + i*MAX_NNZ;
    double *vals = A + i*MAX_NNZ;
    int len = nrow[i+1]-nrow[i];
    for(int s=1;s<len;s++){
      int col = cols[s];
      double val = vals[s];
      int t = s-1;
      while(t>=0 && cols[t]>col){
        cols[t+1] = cols[t];
        vals[t+1] = vals[t];
        t--;
      }
      cols[t+1] = col;
      vals[t+1] = val;
    }
  }

  //--- 4. compress the rows into contiguous CSR storage --------------------
  // nrow[i] <= i*MAX_NNZ and nrow[i+1] <= (i+1)*MAX_NNZ, so a front-to-back
  // copy never overwrites data that has not been read yet.
  for(int i=0;i<neq;i++){
    int len = nrow[i+1]-nrow[i];
    for(int s=0;s<len;s++){
      ncol[nrow[i]+s] = ncol[i*MAX_NNZ+s];
      A[nrow[i]+s] = A[i*MAX_NNZ+s];
    }
  }

  //--- 5. clear the unused tail ---------------------------------------------
  for(int i=nrow[neq];i<neq*MAX_NNZ;i++){
    ncol[i] = -1;
    A[i] = 0.0;
  }
}



//----------------------------------------------------------------------------
// Assemble Global RHS Vector
//----------------------------------------------------------------------------
void Solver::assembleGRHSV()
{
  // Internal heat generation; currently fixed at zero.
  const double Q = 0.0;

  // Backward Euler: (M + dt*K) T^(n+1) = M*T^n + dt*F. assembleGAM puts the
  // dt factor on K, so the load vector F (surface heat flux and internal
  // heat generation) must carry it as well. The stationary problem K*T = F
  // uses F unscaled.
  const double loadScale = fem.time ? fem.dt : 1.0;

  for(int i=0;i<neq;i++){b[i] = 0.0;}

  //add surface heat flux boundary conditions to global RHS vector
  if(fem.nHFGrps>0){
    for(int i=0;i<fem.Nb;i++){
      int grp = fem.bconnect[i][0];
      for(int j=0;j<fem.nHFGrps;j++){
        if(grp!=fem.heatFluxGrps[j]){continue;}

        Element e(fem.bType);
        for(int k=0;k<e.Npe;k++){
          e.nodes[k] = fem.bconnect[i][k+1];
          e.xpts[k] = fem.xpoints[e.nodes[k]-1];
          e.ypts[k] = fem.ypoints[e.nodes[k]-1];
          e.zpts[k] = fem.zpoints[e.nodes[k]-1];
        }

        e.elementVector();
        for(int k=0;k<e.Npe;k++){
          b[e.nodes[k]-1] += loadScale*fem.heatFluxVals[j]*e.evec[k];
        }
      }
    }
  }

  //TODO: surface convection boundary conditions (fem.nCGrps) are not applied

  // Volume terms: internal heat generation and, for the transient problem,
  // the M*T^n history term (x holds the current solution T^n). Nothing to
  // do for a stationary problem without heat generation.
  if(Q==0.0 && !fem.time){return;}

  for(int k=0;k<fem.Ne;k++){
    Element e(fem.Type);
    for(int l=0;l<e.Npe;l++){
      e.nodes[l] = fem.connect[k][l];
      e.xpts[l] = fem.xpoints[e.nodes[l]-1];
      e.ypts[l] = fem.ypoints[e.nodes[l]-1];
      e.zpts[l] = fem.zpoints[e.nodes[l]-1];
    }

    //internal heat generation vector Q
    if(Q!=0.0){
      e.elementVector();
      for(int i=0;i<e.Npe;i++){
        b[e.nodes[i]-1] += loadScale*Q*e.evec[i];
      }
    }

    //RHS += M*T^n (transient only; the stationary system has no mass term)
    if(fem.time){
      e.massMatrix();
      for(int i=0;i<e.Npe;i++){
        int r = e.nodes[i];
        for(int j=0;j<e.Npe;j++){
          int c = e.nodes[j];
          b[r-1] += e.mmat[i][j]*x[c-1];
        }
      }
    }
  }
}




//-------------------------------------------------------------------
// Dirichlet Boundary Conditions
//-------------------------------------------------------------------
void Solver::dirichletBC()
{
  //apply dirichlet BC to A and b using method of large numbers
  for(int i=0;i<fem.Nb;i++){
    int grp = fem.bconnect[i][0];
    for(int j=0;j<fem.nDirGrps;j++){   
      if(grp==fem.dirichletGrps[j]){
        int l=fem.bconnect[i][1]-1;   
        for(int p=nrow[l];p<nrow[l+1];p++){
          if(ncol[p]==l){A[p] = LARGE_NUM;}
        }
        b[l] = fem.dirichletVals[j]*LARGE_NUM;
        break;
      }
    }
  }
}



//-------------------------------------------------------------------
// Solve (PCG)
//-------------------------------------------------------------------
int SolverPCG::solve()
{
  // Iteration count reported by the last pcg() call (0 if none ran).
  int iter = 0;

  // The global matrix does not change between steps, so build it once:
  // A = M+dt*K (transient) or A = K (stationary).
  assembleGAM();

  // A stationary problem is a single solve. A transient one repeats the
  // RHS / boundary-condition / solve cycle fem.steps times; assembleGRHSV
  // reads the current x as T^n.
  for(int step=0; step < (fem.time ? fem.steps : 1); ++step){
    assembleGRHSV();
    dirichletBC();
    // preconditioned conjugate gradient; note 10e-8 == 1e-7
    iter = pcg(nrow,ncol,A,x,b,neq,10e-8,10000);
  }

  // only the final solution vector is written
  writeResults();

  return iter;
}



//-------------------------------------------------------------------
// Solve (AMG)
//-------------------------------------------------------------------
int SolverAMG::solve()
{
  // amg()'s return value is not captured, so this always reports 0.
  int iter = 0;

  // The global matrix does not change between steps, so build it once:
  // A = M+dt*K (transient) or A = K (stationary).
  assembleGAM();

  // A stationary problem is a single solve. A transient one repeats the
  // RHS / boundary-condition / solve cycle fem.steps times; assembleGRHSV
  // reads the current x as T^n.
  for(int step=0; step < (fem.time ? fem.steps : 1); ++step){
    assembleGRHSV();
    dirichletBC();
    // algebraic multigrid; 0.25 and 10e-8 (== 1e-7) are passed to amg()
    amg(nrow,ncol,A,x,b,neq,0.25,10e-8);
  }

  // only the final solution vector is written
  writeResults();

  return iter;
}



//----------------------------------------------------------------------------
// Delete solver
//----------------------------------------------------------------------------
void Solver::deleteSolver()
{
  // Release the arrays allocated in the Solver constructor. Each pointer is
  // reset to 0 so that a second call is a harmless delete[] of a null
  // pointer instead of a double free.
  delete [] nrow; nrow = 0;
  delete [] ncol; ncol = 0;
  delete [] A;    A = 0;
  delete [] x;    x = 0;
  delete [] b;    b = 0;
}




//----------------------------------------------------------------------
// Write Results to Output File
//----------------------------------------------------------------------
void Solver::writeResults()
{
  std::ofstream myfile;
  myfile.open("output.pos");

  char typ[3];
  int npe;

  switch(fem.Type)
  {
    case 1:
      typ[0] = 'S'; typ[1] = 'L'; typ[2] = ' '; npe=2;
      break;
    case 2:
      typ[0] = 'S'; typ[1] = 'T'; typ[2] = ' '; npe=3;
      break;
    case 3:
      typ[0] = 'S'; typ[1] = 'Q'; typ[2] = ' '; npe=4;
      break;
    case 4:
      typ[0] = 'S'; typ[1] = 'S'; typ[2] = ' '; npe=4;
      break;
    case 8:
      typ[0] = 'S'; typ[1] = 'L'; typ[2] = '2'; npe=3;
      break;
    case 9:
      typ[0] = 'S'; typ[1] = 'T'; typ[2] = '2'; npe=6;
      break;
    case 10:
      typ[0] = 'S'; typ[1] = 'Q'; typ[2] = '2'; npe=8;
      break;
    case 11:
      typ[0] = 'S'; typ[1] = 'S'; typ[2] = '2'; npe=10;
      break;
  }

  myfile<<"View \"Temperature\" {\n";
  for(int i=0;i<fem.Ne;i++){
    myfile<<typ[0]<<typ[1]<<typ[2]<<" (";
    for(int j=0;j<npe;j++){
      myfile<<fem.xpoints[fem.connect[i][j]-1]<<","
            <<fem.ypoints[fem.connect[i][j]-1]<<","
            <<fem.zpoints[fem.connect[i][j]-1];
      if(j<npe-1){
        myfile<<",";
      }
      else{
        myfile<<"){";
      }
    }

    for(int j=0;j<npe;j++){ 
      myfile<<x[fem.connect[i][j]-1];
      if(j<npe-1){
        myfile<<",";
      }
      else{
        myfile<<"};\n";
      }
    }
  }
  myfile<<"};\n";
  myfile.close();
}
