#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#

from __future__ import print_function
import argparse
import numpy as np
import matplotlib.pylab as plt


def check_input(parser, args):
    """Validate the combination of tick-related command line options.

    Tick labels can only be attached to explicitly given ticks, therefore
    '--xticklabels' requires '--xticks' with the same number of entries.

    Args:
        parser: Argument parser; its error() method is called with a
            message describing the problem if the options are inconsistent.
        args: Parsed arguments providing the attributes 'xticks' and
            'xticklabels' (each either None or a list).
    """
    if args.xticklabels is None:
        # Nothing to validate without tick labels
        return
    if args.xticks is None:
        parser.error("To use '--xticklabels' you must also use '--xticks'!")
    elif len(args.xticks) != len(args.xticklabels):
        parser.error(f"The number of xticks ({len(args.xticks)}) must equal "
                     f"the number of xticklabels ({len(args.xticklabels)})!")


# Command line interface: one or more data files plus options controlling
# which columns are plotted and how the plot is decorated.
parser = argparse.ArgumentParser(description='Plot 2D data as ((N)X)Y plot')
parser.add_argument('--version', action='version', version='1.1')

# Switches turning off features which are enabled by default
parser.add_argument('-L', '--no-legend', action='store_false', default=True,
                    dest='legend', help='omit legends')
# Boolean default (not the string 'True'), consistent with the other switches
parser.add_argument('-G', '--no-grid', action='store_false', default=True,
                    dest='grid', help='omit grid')
parser.add_argument('datafile', nargs='+',
                    help='file containing the data')
parser.add_argument('-C', '--no-colors', action='store_false', default=True,
                    dest='colors', help='omit coloring')

# Selection of the data to plot
parser.add_argument('-s', '--skip-rows', type=int, dest='skiprows',
                    metavar='ROWS', default=0,
                    help='nr. of rows to skip (default: 0)')
parser.add_argument('-x', '--x-column', dest='xcol', type=int, default=1,
                    help='column of the x values (default: 1)')
parser.add_argument('-y', '--y-columns', dest='ycols', metavar='YCOL',
                    type=int, nargs='+', default=None,
                    help='list of the columns for the y values'
                    ' (default: all columns)')

# Labels and title
parser.add_argument('-l', '--labels', dest='labels', metavar='LABEL',
                    nargs='+', default=None,
                    help='label for each data column (default: filename:'
                    'column_number)')
parser.add_argument('--xlabel', dest='xlabel', default='',
                    help='label for x-axis')
parser.add_argument('--ylabel', dest='ylabel', default='',
                    help='label for y-axis')
parser.add_argument('--title', dest='title', default='',
                    help='title of the plot')

# Axis ranges, shifts and scaling
parser.add_argument('--xlimits', dest='xlim', nargs=2, type=float,
                    default=None, help='limits for x axis.')
parser.add_argument('--ylimits', dest='ylim', nargs=2, type=float,
                    default=None, help='limits for y axis.')
parser.add_argument('--xshift', type=float, default=0.0,
                    help='shift all x values by this')
parser.add_argument('--yshift', type=float, default=0.0,
                    help='shift all y values by this')
parser.add_argument('--logscale', action='store',
                    choices=['none', 'x', 'y', 'xy'], default='none',
                    dest='logscale', help='use log scaling on selected axis '
                    '(default: none)')
parser.add_argument('--zero-line', action='store_true', default=False,
                    dest='zeroline', help='draw extra line for y=0')

# Custom ticks on the x axis (the labels require the ticks to be given too)
parser.add_argument('--xticks', dest='xticks', metavar='XTICKS',
                    nargs='+', default=None, type=float,
                    help='ticks on x axis')
parser.add_argument('--xticklabels', dest='xticklabels', metavar='XTICKLABELS',
                    nargs='+', default=None,
                    help='label for ticks on x axis')
# Read the command line and reject inconsistent tick options early
args = parser.parse_args()
check_input(parser, args)

axis = plt.gca()

# Plotting method of the current axes for each '--logscale' choice; any
# value not listed falls back to log scaling on both axes.
plot_by_scale = {
    'none': axis.plot,
    'x': axis.semilogx,
    'y': axis.semilogy,
}
plot = plot_by_scale.get(args.logscale, axis.loglog)

# Without colors every curve is drawn in black
if not args.colors:
    axis.set_prop_cycle(color=[(0.0, 0.0, 0.0)])
# Requested y columns as 1-based file column numbers (None: all columns)
requested_ycols = None if args.ycols is None else set(args.ycols)

# Index of the next user supplied label; labels are consumed one per curve
label_index = 0
for datafile in args.datafile:
    data = np.loadtxt(datafile, skiprows=args.skiprows)
    if data.ndim == 1:
        # np.loadtxt returned a 1D array, so there is no separate x column:
        # use the running index (starting at 1) as x and treat the values
        # as file column 1. The '--x-column' option is not used here.
        xx = np.arange(1, len(data) + 1)
        ycandidates = [(1, data)]
    else:
        ncols = data.shape[1]
        if not 1 <= args.xcol <= ncols:
            parser.error(f"x column {args.xcol} is out of range for "
                         f"'{datafile}' with {ncols} columns!")
        # Honor '--x-column' (1-based); every other column may serve as y
        ixcol = args.xcol - 1
        xx = data[:, ixcol]
        ycandidates = [(icol + 1, data[:, icol])
                       for icol in range(ncols) if icol != ixcol]
    xx = xx + args.xshift
    for colnr, yy in ycandidates:
        if requested_ycols is not None and colnr not in requested_ycols:
            continue
        if args.labels is not None and len(args.labels) > label_index:
            label = args.labels[label_index]
            label_index += 1
        else:
            # Default label: file name and 1-based column number
            label = '{}:{:d}'.format(datafile, colnr)
        plot(xx, yy + args.yshift, label=label)

if args.zeroline:
    # Dashed black reference line at y=0 across the full width of the axes,
    # independent of the x range of any individual data file.
    axis.axhline(0.0, color='k', linestyle='--')

# Optional decorations
if args.legend:
    axis.legend()
if args.grid:
    axis.grid()
if args.xlabel:
    axis.set_xlabel(args.xlabel)
if args.ylabel:
    axis.set_ylabel(args.ylabel)
if args.title:
    axis.set_title(args.title)

# Explicit axis ranges override the automatically determined ones
if args.xlim:
    axis.set_xlim(*args.xlim)
if args.ylim:
    axis.set_ylim(*args.ylim)

# Tick positions must be set before their labels
if args.xticks:
    axis.set_xticks(args.xticks)
if args.xticklabels:
    axis.set_xticklabels(args.xticklabels)

plt.show()
