//********************************************************************************
//
// IterSolvers: A collection of Iterative Solvers
// Written by James Sandham
// 3 March 2015
//
//********************************************************************************

//********************************************************************************
//
// IterSolvers is free software; you can redistribute it and/or modify it under the
// terms of the GNU Lesser General Public License (as published by the Free
// Software Foundation) version 2.1 dated February 1999.
//
//********************************************************************************

#include<stdlib.h>
#include<stdio.h>
#include"pSLAF.h"
#include"math.h"
#include<mpi.h>


//********************************************************************************
//
// Parallel Sparse linear algebra functions
//
//********************************************************************************


//-------------------------------------------------------------------------------
// sparse parallel matrix-vector product y = A*x
//
// Every process contributes its m local entries of x; MPI_Allgather assembles
// them into one work vector, and the m local rows of A are then multiplied
// against that gathered vector.
//
//   r, c, v : local rows of A in compressed-row form. r[i]..r[i+1]-1 are the
//             positions in c/v belonging to local row i; c[j] indexes the
//             gathered vector.
//   x       : local slice of the input vector (m entries)
//   y       : output, local slice of the product (m entries)
//   m       : number of local rows / local vector entries
//   n       : number of doubles allocated for the gathered vector. The gather
//             writes m entries per process, so n must cover all of them.
//   id, np  : not used by this routine
//
// Collective over MPI_COMM_WORLD. Aborts the MPI job if the work vector
// cannot be allocated.
//-------------------------------------------------------------------------------
void pmatrixVectorProduct(const int r[], const int c[], const double v[], const double x[], 
                          double y[], const int m, const int n, const int id, const int np)
{
  // size by the element (double), not by the pointer: sizeof(total_x) would
  // under-allocate wherever a pointer is narrower than a double
  double *total_x = malloc(n*sizeof *total_x);
  if(total_x == NULL && n > 0){
    // returning here would leave the other ranks blocked in the collective
    // below, so bring the whole job down instead
    fprintf(stderr,"pmatrixVectorProduct: unable to allocate %d doubles\n",n);
    MPI_Abort(MPI_COMM_WORLD,EXIT_FAILURE);
    exit(EXIT_FAILURE);  // fallback in case MPI_Abort returns
  }

  MPI_Allgather(x,m,MPI_DOUBLE,total_x,m,MPI_DOUBLE,MPI_COMM_WORLD);

  for(int i=0;i<m;i++){
    double s = 0;
    for(int j=r[i];j<r[i+1];j++){
      s += v[j]*total_x[c[j]];
    }
    y[i] = s;
  }

  free(total_x);
}





//-------------------------------------------------------------------------------
// sparse parallel matrix-vector product y = A*x using point-to-point exchange
//
// Instead of gathering x from every process, this variant only swaps x slices
// with the processes listed in pmap. Slices belonging to processes that are
// not listed stay zero in the work vector.
//
//   r, c, v : local rows of A in compressed-row form. r[i]..r[i+1]-1 are the
//             positions in c/v belonging to local row i; c[j] indexes the
//             length-n work vector.
//   x       : local slice of the input vector (m entries)
//   y       : output, local slice of the product (m entries)
//   m       : number of local rows / local vector entries; the slice of
//             process p is placed at offset p*m, so m is assumed to be the
//             same on all processes
//   n       : length of the work vector
//   id      : rank of the calling process
//   np      : not used by this routine
//   pmap    : ranks to exchange slices with (entries equal to id are skipped)
//   msize   : number of entries in pmap
//
// Uses blocking MPI_Send/MPI_Recv on MPI_COMM_WORLD with tag 1.
// NOTE(review): presumably each listed peer must also list this rank in its
// own pmap, otherwise the blocking calls cannot complete - confirm with callers.
// Aborts the MPI job if the work vector cannot be allocated.
//-------------------------------------------------------------------------------
void pmatrixVectorProduct2(const int r[], const int c[], const double v[], const double x[], 
                           double y[], const int m, const int n, const int id, const int np, 
                           const int pmap[], const int msize)
{
  MPI_Status status;

  // size by the element (double), not by the pointer: sizeof(total_x) would
  // under-allocate wherever a pointer is narrower than a double
  double *total_x = malloc(n*sizeof *total_x);
  if(total_x == NULL && n > 0){
    // a rank that silently returned would leave its peers blocked in
    // MPI_Send/MPI_Recv, so bring the whole job down instead
    fprintf(stderr,"pmatrixVectorProduct2: unable to allocate %d doubles\n",n);
    MPI_Abort(MPI_COMM_WORLD,EXIT_FAILURE);
    exit(EXIT_FAILURE);  // fallback in case MPI_Abort returns
  }
  for(int i=0;i<n;i++){total_x[i] = 0.0;}

  // place the local slice at this rank's offset
  // (assumes m is the same across all processes - fix this later)
  for(int i=0;i<m;i++){
    total_x[id*m+i] = x[i];
  }

  for(int i=0;i<msize;i++){
    const int peer = pmap[i];
    if(peer == id){continue;}  // nothing to exchange with ourselves

    // receive straight into the peer's slot of the work vector rather than
    // staging the data in y and copying it over
    double *peer_slot = total_x + peer*m;

    if(peer < id){  // lower-ranked peer: receive its slice first, then send ours
      MPI_Recv(peer_slot,m,MPI_DOUBLE,peer,1,MPI_COMM_WORLD,&status);
      MPI_Send(x,m,MPI_DOUBLE,peer,1,MPI_COMM_WORLD);
    }
    else{           // higher-ranked peer: send our slice first, then receive its slice
      MPI_Send(x,m,MPI_DOUBLE,peer,1,MPI_COMM_WORLD);
      MPI_Recv(peer_slot,m,MPI_DOUBLE,peer,1,MPI_COMM_WORLD,&status);
    }
  }

  for(int i=0;i<m;i++){
    double s = 0;
    for(int j=r[i];j<r[i+1];j++){
      s += v[j]*total_x[c[j]];
    }
    y[i] = s;
  }

  free(total_x);
}





//-------------------------------------------------------------------------------
// parallel dot product z = x*y
//
// Each process multiplies and accumulates its m local entries of x and y; the
// per-process partial sums are then added together with MPI_Allreduce, so
// every process in MPI_COMM_WORLD gets the same global result back.
//
//   x, y      : local slices of the two vectors (m entries each)
//   m         : number of local entries
//   n, id, np : not used by this routine
//
// Returns the sum of the partial products over all processes.
//-------------------------------------------------------------------------------
double pdotProduct(const double x[], const double y[], const int m, const int n, 
                   const int id, const int np)
{
  double partial = 0.0;
  double result = 0.0;

  // accumulate front to back over the local slice
  int k = 0;
  while(k < m){
    partial += x[k]*y[k];
    k++;
  }

  // combine the partial sums of all processes
  MPI_Allreduce(&partial,&result,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);

  return result;
}




//-------------------------------------------------------------------------------
// parallel error e = |b-A*x|
//
// Gathers x from all processes, forms the residual b - A*x for the m local
// rows, sums the squared residual entries over all processes and returns the
// square root of that sum (the same value on every process).
//
//   r, c, v : local rows of A in compressed-row form. r[j]..r[j+1]-1 are the
//             positions in c/v belonging to local row j; c[i] indexes the
//             gathered vector.
//   x       : local slice of the solution vector (m entries)
//   b       : local slice of the right-hand side (m entries)
//   m       : number of local rows / local vector entries
//   n       : number of doubles allocated for the gathered vector. The gather
//             writes m entries per process, so n must cover all of them.
//   id, np  : not used by this routine
//
// Collective over MPI_COMM_WORLD. Aborts the MPI job if the work vector
// cannot be allocated.
//-------------------------------------------------------------------------------
double error(const int r[], const int c[], const double v[], const double x[],
             const double b[], const int m, const int n, const int id, const int np)
{
  // size by the element (double), not by the pointer: sizeof(total_x) would
  // under-allocate wherever a pointer is narrower than a double
  double *total_x = malloc(n*sizeof *total_x);
  if(total_x == NULL && n > 0){
    // returning here would leave the other ranks blocked in the collectives
    // below, so bring the whole job down instead
    fprintf(stderr,"error: unable to allocate %d doubles\n",n);
    MPI_Abort(MPI_COMM_WORLD,EXIT_FAILURE);
    exit(EXIT_FAILURE);  // fallback in case MPI_Abort returns
  }

  MPI_Allgather(x,m,MPI_DOUBLE,total_x,m,MPI_DOUBLE,MPI_COMM_WORLD);

  double local_e = 0.0;
  double global_e = 0.0;
  for(int j=0;j<m;j++){
    // s = (A*x)[j] for local row j
    double s = 0.0;
    for(int i=r[j];i<r[j+1];i++){
      s += v[i]*total_x[c[i]];
    }
    const double resid = b[j] - s;
    local_e = local_e + resid*resid;
  }

  free(total_x);

  // add up the squared residuals of all processes
  MPI_Allreduce(&local_e,&global_e,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD);

  return sqrt(global_e);
}
