#!/usr/bin/env python3
#--------------------------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations                                #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                                               #
#                                                                                                  #
#  See the LICENSE file for terms of usage and distribution.                                       #
#--------------------------------------------------------------------------------------------------#

"""Script to evaluate expressions on cube files"""

import argparse
from dataclasses import dataclass, field
import re
import sys
import numpy as np
import numpy.linalg as la

SCRIPT_NAME = "cubemanip"

# Tolerance for positions to be treated as equal
TOLERANCE = 1.0e-5

# Pattern matching the cube file references of the form "($n)" in the expression
VAR_PATTERN = re.compile(r"\((\$(?P<number>\d+))\)")

# Help text (passed to argparse as usage message, must not contain any '%' characters)
USAGE = f"""{SCRIPT_NAME} [-h] [-d] [-o OUTPUT_FILE] expr cubefile [cubefile ...]

Combines the passed cube files according to an expression to a new cube file.
The cube files provided as arguments after the expression must be referenced in
the expression as ($n), with n being the index of the appropriate file. The
expression can be any arbitrary expression which is valid in python using the
functions found in the numpy and numpy.linalg packages and results in a float or
a float array. Unless the '-d' option is passed, the result must have the same
shape as the data arrays in the cube files. The references in the expression are
substituted by the data arrays of the appropriate cube files. The resulting cube
file will have the same header as the first input file, while its data array is
created by evaluating the expression. The number of cube files passed as
arguments must be at least as high as the maximal reference index in the
expression. The result is written to the standard output, unless an output file
is specified with the '-o' option.

Apart from the data arrays, the following variables and functions can be used in
the expression:

  origin:  origin of the grid in the first input file as (3,) shaped array

  gridvecs: grid vectors in the first input file as (3, 3) shaped array

  ngrids: nr. of grid units along the grid vectors as a (3,) shaped array

  posx(), posy(), posz(): x-, y- and z- coordinates of each grid point

Examples:

  Weighted sum of three cube files:
  {SCRIPT_NAME} '0.25 * ($1) + 0.25 * ($2) + 0.5 * ($3)' f1.cube f2.cube f3.cube

  Integral of the values in a cube file:
  {SCRIPT_NAME} -d 'sum(($1)) * det(gridvecs)' file1.cube

  Dipole integral (x-component):
  {SCRIPT_NAME} -d 'sum(($1) * posx()) * det(gridvecs)' file1.cube
"""


@dataclass
class Cube:
    """Contains the data found in a cube file.

    Attributes:
        species: Integer species index for each atom, shape: (natoms,).
        coords: Cartesian coordinates of each atom, shape: (natoms, 3).
        origin: Origin of the grid, shape: (3,).
        gridvecs: Vectors defining shifts between neighbouring grid points, shape: (3, 3).
        ngrids: Nr. of grid points along each grid vector, shape (3,).
        data: Volumetric data for each grid point, with shape of ngrids.
        natoms: Nr. of atoms in the system (derived from coords, not an init argument).
    """

    species: np.ndarray
    coords: np.ndarray
    origin: np.ndarray
    gridvecs: np.ndarray
    ngrids: np.ndarray
    data: np.ndarray
    natoms: int = field(default=0, init=False)


    def __post_init__(self):
        """Checks array compatibilities, fills up uninitialized fields.

        Raises:
            ValueError: If any of the arrays has an unexpected shape.
        """

        if len(self.coords.shape) != 2 or self.coords.shape[1] != 3:
            raise ValueError(_invalid_shape_msg("coords", (-1, 3), self.coords.shape))
        self.natoms = len(self.coords)

        if self.species.shape != (self.natoms,):
            raise ValueError(_invalid_shape_msg("species", (self.natoms,), self.species.shape))

        if self.origin.shape != (3,):
            raise ValueError(_invalid_shape_msg("origin", (3,), self.origin.shape))

        if self.gridvecs.shape != (3, 3):
            raise ValueError(_invalid_shape_msg("gridvecs", (3, 3), self.gridvecs.shape))

        if self.ngrids.shape != (3,):
            raise ValueError(_invalid_shape_msg("ngrids", (3,), self.ngrids.shape))

        if self.data.shape != tuple(self.ngrids):
            raise ValueError(_invalid_shape_msg("data", self.ngrids, self.data.shape))


    @classmethod
    def from_file(cls, fname):
        """Creates a Cube instance from a file.

        The second column of the atom records is skipped while reading, it is not
        stored in the instance.

        Args:
            fname: Name of the file to read the data from.

        Returns:
            Cube instance

        Raises:
            OSError: If the file can not be opened or read.
            ValueError: If the file has a negative atom count (not supported), contains
                non-numerical items or does not contain the expected amount of grid data.
            IndexError: If a header or atom record contains too few items.
        """
        with open(fname, "r", encoding="utf-8") as fp:
            # The first two lines are comment lines, their content is ignored
            fp.readline()
            fp.readline()

            words = fp.readline().split()
            natoms = int(words[0])
            if natoms < 0:
                # Without this check the atom records would be silently parsed as grid data
                raise ValueError(
                    f"File '{fname}' has a negative atom count ({natoms}), which is not supported"
                )
            origin = np.array(words[1:4], dtype=float)

            # Three records, each with the nr. of grid points followed by the grid vector
            words = " ".join([fp.readline() for _ in range(3)]).split()
            ngrids = np.array([words[0], words[4], words[8]], dtype=int)
            gridvecs = np.array([words[ii] for ii in (1, 2, 3, 5, 6, 7, 9, 10, 11)], dtype=float)
            gridvecs = gridvecs.reshape((3, 3))

            # One record per atom: species, (skipped column), x, y, z
            words = " ".join([fp.readline() for _ in range(natoms)]).split()
            species = np.array([words[ii] for ii in range(0, 5 * natoms, 5)], dtype=int)
            tmp = []
            for ii in range(2, 5 * natoms, 5):
                tmp += words[ii : ii + 3]
            coords = np.array(tmp, dtype=float).reshape((-1, 3))

            nvalues = int(np.prod(ngrids))
            data = np.fromfile(fp, sep=" ", count=nvalues)

        if data.size != nvalues:
            raise ValueError(
                f"File '{fname}' contains {data.size} grid data values instead of the expected "
                f"{nvalues}"
            )
        data = data.reshape(tuple(ngrids))
        return cls(species, coords, origin, gridvecs, ngrids, data)


    def to_file(self, fobj):
        """Writes the cube data to a file

        Both comment lines of the written file contain a fixed text and the second
        column of each atom record is written as zero.

        Args:
            fobj: File name (str) or a file like object
        """
        if isinstance(fobj, str):
            # The context manager closes the file even if writing fails
            with open(fobj, "w", encoding="utf-8") as fp:
                self._write(fp)
        else:
            self._write(fobj)


    def _write(self, fp):
        """Writes the header, the atom records and the grid data to an open file object."""
        # A cube file starts with two comment lines
        fp.write("Cube file created by cubemanip\n")
        fp.write("Cube file created by cubemanip\n")
        natoms = self.natoms
        fp.write(
            f"{natoms:10d} {self.origin[0]:18.10E} {self.origin[1]:18.10E} "
            f"{self.origin[2]:18.10E}\n"
        )
        for ngrid, gridvec in zip(self.ngrids, self.gridvecs):
            fp.write(
                f"{ngrid:10d} {gridvec[0]:18.10E} {gridvec[1]:18.10E} {gridvec[2]:18.10E}\n"
            )
        for specie, coord in zip(self.species, self.coords):
            fp.write(
                f"{specie:4d} {0.0:18.10E} {coord[0]:18.10E} "
                f"{coord[1]:18.10E} {coord[2]:18.10E}\n"
            )
        _dump_data(fp, self.data)

def main():
    """Main program

    Parses the command line, reads the cube files, evaluates the expression and
    writes the result either as a cube file or as raw data. Every failure is
    reported via _error(), which terminates the script with exit code 1.
    """

    args = _parse_cmd_line_args()
    expr = args.expr
    cubefiles = args.cubefiles

    maxref = _max_ref_ind(expr)
    if maxref == 0:
        _error("Expression contains no cube file reference indices!")
    if len(cubefiles) < maxref:
        _error(f"Missing files! You must specify at least {maxref} cube file(s).")

    try:
        cubes = [Cube.from_file(fname) for fname in cubefiles]
    except OSError as exc:
        _error(f"IO error occured during processing input files [{exc}]")
    except (ValueError, IndexError) as exc:
        # Malformed file content (e.g. non-numerical items, truncated records)
        _error(f"Invalid content found during processing input files [{exc}]")

    _check_cube_compatibility(cubefiles, cubes)

    try:
        result = _evaluate_expression(expr, cubes)
    except Exception as exp:
        # Top level boundary: the user supplied expression may raise anything
        _error(f"Unable to evaluate the expression:\n{exp}")
    if result.shape != cubes[0].data.shape and not args.data_only:
        _error(
            f"Resulting shape {result.shape} is incompatible with the original data shape "
            f"{cubes[0].data.shape}"
        )

    to_stdout = args.output_file == "-"
    try:
        if args.data_only:
            if to_stdout:
                _dump_data(sys.stdout, result)
            else:
                # The context manager closes the file even if dumping fails
                with open(args.output_file, "w", encoding="utf-8") as fp:
                    _dump_data(fp, result)
        else:
            c0 = cubes[0]
            cube = Cube(c0.species, c0.coords, c0.origin, c0.gridvecs, c0.ngrids, result)
            cube.to_file(sys.stdout if to_stdout else args.output_file)
    except OSError as exc:
        _error(f"IO error during writing the output [{exc}]")


def _parse_cmd_line_args():
    """Builds the command line parser and returns the parsed arguments.

    Returns:
        Namespace with the attributes expr, cubefiles, data_only and output_file.
    """
    parser = argparse.ArgumentParser(usage=USAGE)

    # Positional arguments: the expression followed by one or more cube files
    parser.add_argument("expr", help="expression evaluated with data from cube file(s)")
    parser.add_argument(
        "cubefiles", metavar="cubefile", nargs="+", help="cube file providing input data"
    )

    # Optional arguments
    parser.add_argument(
        "-d",
        "--data-only",
        action="store_true",
        dest="data_only",
        default=False,
        help=(
            "dump data without cube header, allows for dumping data with a different shape as "
            "the volumentric data in the cube file(s)"
        ),
    )
    parser.add_argument(
        "-o", "--output-file", default="-", help="output file, default: - (standard output)"
    )

    return parser.parse_args()


def _max_ref_ind(expr):
    """Returns the highest reference index in the expression.

    Args:
        expr: Expression with cube file references of the form ($n).

    Returns:
        Largest index n found in the expression or 0, if it contains no references.
    """
    indices = (int(ref.group("number")) for ref in VAR_PATTERN.finditer(expr))
    return max(indices, default=0)


def _check_cube_compatibility(cubefiles, cubes):
    """Checks cube file data for compatibility

    Every cube is compared to the first one. Differing grid point numbers stop the
    script with an error, all other differences (atoms, origin, grid vectors) are
    only reported as warnings.

    Args:
        cubefiles: Names of the cube files (used in the messages only).
        cubes: Cube instances corresponding to the file names.
    """

    fname0, cube0 = cubefiles[0], cubes[0]
    for fname, cube in zip(cubefiles[1:], cubes[1:]):
        if np.any(cube.ngrids != cube0.ngrids):
            _error(f"Grid points in '{fname0}' and '{fname}' are incompatible!")

        # Species and coordinates can only be compared, if the nr. of atoms matches
        if cube.natoms != cube0.natoms:
            _warning(f"Nr. of atoms in '{fname0}' and '{fname}' differ")

        elif np.any(cube.species != cube0.species):
            _warning(f"Specie indexes in '{fname0}' and '{fname}' differ")

        # np.any() (in contrast to np.max()) also works for systems without atoms
        elif np.any(np.sqrt(np.sum((cube.coords - cube0.coords) ** 2, -1)) > TOLERANCE):
            _warning(f"Coordinates in '{fname0}' and '{fname}' differ!")

        if np.sqrt(np.sum((cube.origin - cube0.origin) ** 2)) > TOLERANCE:
            _warning(f"Origins in '{fname0}' and '{fname}' differ!")

        if np.any(np.sqrt(np.sum((cube.gridvecs - cube0.gridvecs) ** 2, -1)) > TOLERANCE):
            _warning(f"Grid vectors in '{fname0}' and '{fname}' differ!")


def _evaluate_expression(expr, cubes):
    """Evaluates a cube data expression

    Args:
        expr: Expression with cube file references of the form ($n), where n is the
            one-based index of the cube.
        cubes: Cube instances the references refer to.

    Returns:
        Result of the expression as float array.

    Raises:
        ValueError: If the expression references a cube index outside of 1..len(cubes).
        Exception: Whatever the evaluation of the expression raises.
    """

    data = [cube.data for cube in cubes]

    # A reference like ($0) would be turned into data[-1] and silently pick the last cube
    for match in VAR_PATTERN.finditer(expr):
        index = int(match.group("number"))
        if not 1 <= index <= len(data):
            raise ValueError(
                f"Invalid cube file reference (${index}), valid indices are 1 to {len(data)}"
            )

    pyexpr = VAR_PATTERN.sub(r"(data[\g<number>-1])", expr)

    # The names provided by the script are added last, so that no numpy name can shadow them
    exprscope = {}
    exprscope.update(vars(np))
    exprscope.update(vars(la))
    posx, posy, posz = _get_pos_funcs(cubes[0].origin, cubes[0].ngrids, cubes[0].gridvecs)
    exprscope.update(
        {
            "data": data,
            "origin": cubes[0].origin,
            "gridvecs": cubes[0].gridvecs,
            "ngrids": cubes[0].ngrids,
            "posx": posx,
            "posy": posy,
            "posz": posz,
        }
    )

    # NOTE: eval() executes arbitrary Python code. This is the purpose of the script, but the
    # expression must never originate from an untrusted source.
    return np.array(eval(pyexpr, exprscope), float)


def _error(msg):
    """Writes an error message to the standard error and exits with exit code 1.

    Args:
        msg: Message to print after the "Error: " prefix.
    """
    sys.stderr.write(f"Error: {msg}\n")
    raise SystemExit(1)


def _warning(msg):
    """Writes a warning message to the standard error.

    Args:
        msg: Message to print after the "Warning: " prefix.
    """
    sys.stderr.write(f"Warning: {msg}\n")


def _get_pos_funcs(origin, ngrids, gridvecs):
    """Delivers three functions which return the x, y & z position of the grid points

    Args:
        origin: Origin of the grid, indexed by the Cartesian component.
        ngrids: Nr. of grid points along each of the three grid vectors.
        gridvecs: Grid vectors, indexed as [grid vector, Cartesian component].

    Returns:
        Tuple of the argumentless functions (posx, posy, posz). Each of them returns
        the given Cartesian component of all grid points as an array, which is
        indexed by the three grid point indices.
    """

    def _component(axis):
        # Open mesh of the grid point indices along the three grid vectors
        ind0, ind1, ind2 = np.ix_(*[np.arange(npoints) for npoints in ngrids])
        shifts = ind0 * gridvecs[0, axis] + ind1 * gridvecs[1, axis] + ind2 * gridvecs[2, axis]
        shifts += origin[axis]
        return shifts

    def posx():
        """Returns the x-position of every grid point (including the shift by the origin)"""
        return _component(0)

    def posy():
        """Returns the y-position of every grid point (including the shift by the origin)"""
        return _component(1)

    def posz():
        """Returns the z-position of every grid point (including the shift by the origin)"""
        return _component(2)

    return posx, posy, posz


def _dump_data(fp, data):
    """Dumps an arbitrary shaped numerical array

    Arrays with more than one dimension are written block by block along their
    first axis. One-dimensional arrays are written with at most four values per
    line, a zero-dimensional array is written as a single value on its own line.

    Args:
        fp: File like object to write to.
        data: Numerical array to dump.
    """

    ndim = len(data.shape)
    if ndim == 0:
        fp.write(f"{data:18.10E}\n")
        return

    if ndim > 1:
        for subarray in data:
            _dump_data(fp, subarray)
        return

    # One dimensional data: chunks of (at most) four values per line
    for start in range(0, len(data), 4):
        fp.write(" ".join([f"{value:18.10E}" for value in data[start : start + 4]]))
        fp.write("\n")


def _invalid_shape_msg(array, expected, obtained):
    """Generates the message for an array with an invalid shape.

    Args:
        array: Name of the array.
        expected: Shape the array should have had.
        obtained: Shape the array actually has.

    Returns:
        Message containing the name of the array and the two shapes.
    """
    template = "Array {} has invalid shape (expected {}, obtained {})"
    return template.format(array, expected, obtained)


# Run the main program only when the file is executed as a script (and not when it is imported)
if __name__ == "__main__":
    main()
