/*******************************************************************************
* Copyright 2014-2020 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file CheckProblem.cpp

 HPCG routine
 */

#ifndef HPCG_NO_MPI
#include <mpi.h>
#endif

#if defined(HPCG_DEBUG) || defined(HPCG_DETAILED_DEBUG)
#include <fstream>
using std::endl;
#include "hpcg.hpp"
#endif
#include <cassert>

#include "CheckProblem.hpp"


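/*!
  Check the contents of the generated sparse matrix (and, optionally, the generated vectors)
  to see if values match the expected contents.

  The checks are compiled only when HPCG_DEBUG is defined; otherwise this function does nothing.
  Failures are reported through assert().

  @param[in] A      The known system matrix
  @param[in] b      The generated right hand side vector; each entry is expected to be
                    26.0 - (nonzeros in that row - 1). Skipped if b is null.
  @param[in] x      The initial solution vector; every entry is expected to be 0.0. Skipped if x is null.
  @param[in] xexact The exact solution vector; every entry is expected to be 1.0. Skipped if xexact is null.
  @param[in] queue  SYCL queue used to allocate temporary USM memory and run the check kernel

  @return A default-constructed event; all submitted work has completed before the function returns.

  @see GenerateGeometry
*/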

sycl::event CheckProblem(const SparseMatrix & A, Vector * b, Vector * x, Vector * xexact,
                         sycl::queue& queue) {

#if defined(HPCG_DEBUG)
  // Local subgrid dimensions, global grid dimensions, and this subgrid's global offset.
  global_int_t nx = A.geom->nx;
  global_int_t ny = A.geom->ny;
  global_int_t nz = A.geom->nz;
  global_int_t gnx = A.geom->gnx;
  global_int_t gny = A.geom->gny;
  global_int_t gnz = A.geom->gnz;
  global_int_t gix0 = A.geom->gix0;
  global_int_t giy0 = A.geom->giy0;
  global_int_t giz0 = A.geom->giz0;

  local_int_t localNumberOfRows = nx*ny*nz; // This is the size of our subblock
  global_int_t totalNumberOfRows = gnx*gny*gnz; // Total number of grid points in mesh

  // Raw value pointers captured by the kernel; a null pointer means "skip this vector".
  double * bv = nullptr;
  double * xv = nullptr;
  double * xexactv = nullptr;
  if (b!=nullptr) bv = b->values; // Only check b if it was provided
  if (x!=nullptr) xv = x->values; // Only check x if it was provided
  if (xexact!=nullptr) xexactv = xexact->values; // Only check xexact if it was provided

  local_int_t localNumberOfNonzeros = 0;

  // Set to 1 by the kernel if any check fails. Many work-items may report a failure
  // concurrently; plain concurrent stores to one location are a data race even when
  // they all store the same value, so the device side publishes through an atomic_ref.
  int * error_flag = sycl::malloc_shared<int>(1, queue);
  *error_flag = 0;

  // Copy the members to locals so the [=] kernel captures only these pointers
  // rather than a copy of A.
  double ** A_matrixValues = A.matrixValues;
  global_int_t ** A_mtxIndG = A.mtxIndG;
  global_int_t * A_localToGlobalMap = A.localToGlobalMap;
  char * A_nonzerosInRow = A.nonzerosInRow;

  // Device-side accumulator for the sum of per-row nonzero counts.
  local_int_t *nnz_ptr_dev = sycl::malloc_device<local_int_t>(1, queue);

  auto ev = queue.submit([&](sycl::handler& cgh) {
    auto reductionSum = sycl::reduction(nnz_ptr_dev, sycl::plus<>(),
                                        sycl::property::reduction::initialize_to_identity());

    // One work-item per local grid point (matrix row); the id is ordered (z, y, x).
    auto kernel = [=](sycl::id<3> item, auto& sum) {
        bool rowFailed = false;

        const local_int_t iz = item.get(0);
        const global_int_t giz = giz0 + iz;
        const local_int_t iy = item.get(1);
        const global_int_t giy = giy0 + iy;
        const local_int_t ix = item.get(2);
        const global_int_t gix = gix0 + ix;
        const local_int_t currentLocalRow = iz*nx*ny + iy*nx + ix;
        const global_int_t currentGlobalRow = giz*gnx*gny + giy*gnx + gix;
        if (currentGlobalRow != A_localToGlobalMap[currentLocalRow]) { rowFailed = true; }

        // Walk the 27-point stencil in the same order as the stored row entries and
        // compare each stored value/column index with the expected one.
        char numberOfNonzerosInRow = 0;
        double * currentValuePointer = A_matrixValues[currentLocalRow];
        global_int_t * currentIndexPointerG = A_mtxIndG[currentLocalRow];
        for (int sz=-1; sz<=1; sz++) {
          if (giz+sz>-1 && giz+sz<gnz) {
            for (int sy=-1; sy<=1; sy++) {
              if (giy+sy>-1 && giy+sy<gny) {
                for (int sx=-1; sx<=1; sx++) {
                  if (gix+sx>-1 && gix+sx<gnx) {
                    global_int_t curcol = currentGlobalRow+sz*gnx*gny+sy*gnx+sx;
                    if (curcol == currentGlobalRow) {
                      if (*currentValuePointer != 26.0) { rowFailed = true; }
                    }
                    else {
                      if (*currentValuePointer != -1.0) { rowFailed = true; }
                    }
                    currentValuePointer++;
                    if (*currentIndexPointerG != curcol) { rowFailed = true; }
                    currentIndexPointerG++;
                    numberOfNonzerosInRow++;
                  } // end x bounds test
                } // end sx loop
              } // end y bounds test
            } // end sy loop
          } // end z bounds test
        } // end sz loop
        if (A_nonzerosInRow[currentLocalRow] != numberOfNonzerosInRow) { rowFailed = true; }
        sum += numberOfNonzerosInRow;
        if (bv != nullptr) {
          if (bv[currentLocalRow] != 26.0 - ((double) (numberOfNonzerosInRow-1))) { rowFailed = true; }
        }
        if (xv != nullptr) {
          if (xv[currentLocalRow] != 0.0) { rowFailed = true; }
        }
        if (xexactv != nullptr) {
          if (xexactv[currentLocalRow] != 1.0) { rowFailed = true; }
        }

        // Publish a failure with a single race-free store.
        if (rowFailed) {
          sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                           sycl::access::address_space::global_space> flag(*error_flag);
          flag.store(1);
        }
      };
    cgh.parallel_for(sycl::range<3>(nz, ny, nx), reductionSum, kernel);
  });
  ev.wait();

  // queue.memcpy accepts an ordinary host pointer as destination, so the reduced
  // count is copied straight into the local without a host-side USM staging buffer.
  queue.memcpy(&localNumberOfNonzeros, nnz_ptr_dev, sizeof(local_int_t)).wait();
  sycl::free(nnz_ptr_dev, queue);

  // The kernel has completed (ev.wait() above), so the host may read the shared flag.
  const int checkFailed = *error_flag;
  sycl::free(error_flag, queue);
  assert(checkFailed == 0);
  (void) checkFailed; // unused when NDEBUG disables assert

  global_int_t totalNumberOfNonzeros = 0;
#ifndef HPCG_NO_MPI
  // Use MPI's reduce function to sum all nonzeros
  long long lnnz = localNumberOfNonzeros, gnnz = 0; // convert to 64 bit for MPI call
  MPI_Allreduce(&lnnz, &gnnz, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
  totalNumberOfNonzeros = gnnz; // Copy back
#else
  totalNumberOfNonzeros = localNumberOfNonzeros;
#endif

  assert(A.totalNumberOfRows == totalNumberOfRows);
  assert(A.totalNumberOfNonzeros == totalNumberOfNonzeros);
  assert(A.localNumberOfRows == localNumberOfRows);
  assert(A.localNumberOfNonzeros == localNumberOfNonzeros);
#else
  // Checks are compiled only in HPCG_DEBUG builds; silence unused-parameter warnings.
  (void) A;
  (void) b;
  (void) x;
  (void) xexact;
  (void) queue;
#endif

  // All work above has been waited on, so a default-constructed event is returned.
  return sycl::event();
}
