#!/usr/bin/env python3
############################################################################
#
#  Produces bindings based on Fortran 95 procedures
#
############################################################################
import re
import sys
import os
from optparse import OptionParser

# File name component of the path the script was started with; it is
# inserted into the usage message of the command line parser.
scriptName = os.path.basename(sys.argv[0])


############################################################################
#  Compiled regular expressions
############################################################################

# Flags shared by all patterns below: case-insensitive matching, verbose
# pattern syntax (blanks and "#" comments inside the pattern are ignored),
# multi-line anchors, and "." also matching newlines.
REFLAGS = re.IGNORECASE | re.VERBOSE | re.MULTILINE | re.DOTALL

# Header of a subroutine ("subroutine name(args)", no blank allowed between
# the name and the opening parenthesis) followed by its definition block.
# The definition block is everything up to the first line which is empty or
# contains only blanks/tabs.
PAT_SUBROUTINE = re.compile(r"""
  subroutine\s+                    # subroutine
  (?P<name>\w+)                    # name of the subroutine
  \((?P<args>[^\)]*)\)             # arguments of the subroutine
  (?P<defs>.*?)\n[ \t]*\n          # definition block of the subroutine
  """, REFLAGS)

# Interface annotation of the form "if:<type>:<info>"; the parser searches
# for it in the comment part of a definition line.
PAT_IFACE_INFO = re.compile(r"""if:(?P<type>\w+):(?P<info>\w+)""", REFLAGS)

# Continuation marker: "&", any whitespace (including the line break) and an
# optional leading "&" of the next line. The main program deletes every
# match so that continued statements end up on a single line.
PAT_CONTLINE = re.compile(r"""&\s*&?""", REFLAGS)

# Separator between the entries of an argument list: a comma together with
# the whitespace around it.
PAT_ARG_SEP = re.compile(r"""\s*,\s*""", REFLAGS)

# "use <name>" statement; the group "name" holds the module name.
PAT_MODULE_DEF = re.compile(r"""use\s+(?P<name>\w+)\s*""", REFLAGS)

# Identifier (group "obj") optionally followed by a parenthesised text
# without closing parenthesis inside (group "arg", None if absent).
PAT_PARENTHESIS = re.compile(r"""(?P<obj>\w+)(?:\((?P<arg>[^\)]*)\))?""",
                             REFLAGS)



############################################################################
#  Constants for the parser
############################################################################

# Codes stored by the parser for the intent of a dummy argument
# (in = 0, out = 1, inout = 2).
(INTENT_IN, INTENT_OUT, INTENT_INOUT) = range(3)


############################################################################
#  Exceptions
############################################################################

class ParserError(Exception):
  """Signals that the parser met an interface annotation it cannot
  interpret."""

class FormatterError(Exception):
  """Signals that a formatter is unable to render a parsed subroutine."""


############################################################################
#  Dummy class for containing data about a subroutine
############################################################################

class Subroutine(object):
  """Plain attribute container describing one parsed subroutine.

  The class defines nothing itself. The parser attaches these attributes:
  name  --  Name of the subroutine.
  argNames  --  List with the names of the dummy arguments.
  modules  --  List with the names of the used modules.
  varDefs  --  Dictionary mapping a variable name to the tuple
      (array arguments or None, type, intent code, list of attributes).
  """
    

############################################################################
#  Parser
############################################################################

def parse(txt):
  """Parses Fortran source text and collects the annotated subroutines.

  Every match of PAT_SUBROUTINE yields one Subroutine object. Within the
  definition block of a subroutine only lines carrying an interface
  annotation in their comment are interpreted:
    use <module>              !if:mod:use
    <type>[, attribs] :: var  !if:var:in | out | inout
  All other lines are ignored.

  txt  --  Text to parse (continuation lines must be joined already).
  return  --  List of Subroutine objects in the order of their appearance.
      Each object has the attributes name, argNames, modules and varDefs,
      where varDefs maps a variable name to the tuple
      (array arguments or None, type, intent code, list of attributes).
  raises ParserError  --  If an annotation is unknown or the annotated line
      does not have the form required by the annotation.
  """
  intentCodes = {"in": INTENT_IN, "out": INTENT_OUT, "inout": INTENT_INOUT}
  tree = []
  for match in PAT_SUBROUTINE.finditer(txt):
    newSub = Subroutine()
    newSub.name = match.group("name")
    # Blanks next to the parentheses are removed before splitting and empty
    # entries are dropped, so that "()" gives [] instead of [""].
    argNames = PAT_ARG_SEP.split(match.group("args").strip())
    newSub.argNames = [argName for argName in argNames if argName]
    newSub.modules = []
    newSub.varDefs = {}
    defLines = match.group("defs").strip().split("\n")
    # Strip away bind(c) stuff (Fortran keywords are case-insensitive)
    if "bind(" in defLines[0].lower():
      defLines.pop(0)
    for defLine in defLines:
      (code, marker, comment) = defLine.partition("!")
      if not marker:
        # Line without comment can not carry an annotation
        continue
      info = PAT_IFACE_INFO.search(comment)
      if info is None:
        continue
      mtype = info.group("type")
      minfo = info.group("info")
      if mtype == "mod":
        # Module definition, only module inclusion is known
        if minfo != "use":
          raise ParserError("Unknown module info '%s'" % minfo)
        moduleMatch = PAT_MODULE_DEF.search(code)
        if moduleMatch is None:
          raise ParserError("No use statement found in '%s'"
                            % defLine.strip())
        newSub.modules.append(moduleMatch.group("name"))
      elif mtype == "var":
        # Variable definition
        if minfo not in intentCodes:
          raise ParserError("Invalid variable information '%s'" % minfo)
        (typePart, marker, namePart) = code.partition("::")
        if not marker:
          raise ParserError("Missing '::' in variable definition '%s'"
                            % defLine.strip())
        nameMatch = PAT_PARENTHESIS.search(namePart.strip())
        if nameMatch is None:
          raise ParserError("Missing variable name in '%s'"
                            % defLine.strip())
        varName = nameMatch.group("obj")
        varArgStr = nameMatch.group("arg")
        if varArgStr:
          varArgs = [ s.strip() for s in varArgStr.split(",") ]
        else:
          varArgs = None
        # First entry is the type, the remaining ones are attributes.
        # Commas inside parentheses (e.g. "dimension(3,3)") do not separate.
        typeAndAttribs = _splitTopLevelCommas(typePart)
        varType = typeAndAttribs[0]
        attribs = typeAndAttribs[1:]
        newSub.varDefs[varName] = (varArgs, varType, intentCodes[minfo],
                                   attribs)
    tree.append(newSub)

  return tree


def _splitTopLevelCommas(text):
  """Splits a text at the commas which are not enclosed in parentheses.
  text  --  Text to split.
  return  --  List of the pieces with surrounding whitespace removed. It
      always contains at least one (possibly empty) entry.
  """
  pieces = []
  current = []
  depth = 0
  for char in text:
    if char == "(":
      depth += 1
    elif char == ")":
      # Unbalanced closing parenthesis must not make the depth negative
      depth = max(depth - 1, 0)
    if char == "," and depth == 0:
      pieces.append("".join(current).strip())
      current = []
    else:
      current.append(char)
  pieces.append("".join(current).strip())
  return pieces


############################################################################
#  Formatters
############################################################################

class Formatter(object):
  """Common interface of the formatters.

  Both methods are placeholders returning None; they are meant to be
  overridden by the concrete formatters.
  """

  def format(self, tree):
    """Creates the string representation of a parsed tree.
    tree  --  List of the objects (class Subroutine) to be formatted.
    return  --  String containing the formatted output (None here).
    """
    return None

  def fold(self, line, indentLevel):
    """Folds a line at a given indentation level.
    line  --  Line to fold.
    indentLevel  --  Indentation level of the line.
    return  --  String containing the folded line (None here).
    """
    return None



class FormatterF95(Formatter):
  """Formatter for Fortran 95 output"""

  def __init__(self, startText, endText, indentInit, indentBlock, indentCont,
               maxLineLength, userMap, functionSuffix=None, functionPrefix=None,
               functionUpperCase=False):
    """Initialises a FormatterF95 object
    startText  --  Text to print before the interfaces
    endText  --  Text to print after the interfaces
    indentInit  --  Indentation (in columns) of the interfaces
    indentBlock  --  Indentation increment for a nested block
    indentCont  --  Indentation increment for a continuation line
    maxLineLength  --  Maximal length of a line (incl. indentation)
    userMap  --  User defined dictionary for replacing variable types.
    functionSuffix  --  Suffix to append to procedure names.
    functionPrefix  --  String to prepend to every function name.
    functionUpperCase  --  If names should be turned into uppercase.
    """
    self.startText = startText
    self.endText = endText
    self.indentInit = indentInit
    self.indentBlock = indentBlock
    self.indentCont = indentCont
    self.maxLineLength = maxLineLength
    self.userTypeMap = userMap
    self.funcSuffix = "" if functionSuffix is None else functionSuffix
    self.funcPrefix = "" if functionPrefix is None else functionPrefix
    self.upperCase = functionUpperCase


  def fold(self, line, indentLevel):
    """F95 implementation of line folding

    Lines which do not fit are broken after a blank where possible. Every
    broken line ends with "&" and every continuation line starts with "&"
    at an indentation increased by indentCont.
    line  --  Line to fold.
    indentLevel  --  Indentation level for the line.
    return  --  String containing the folded line
    """
    curIndent = self.indentInit + indentLevel * self.indentBlock
    curIndentStr = " " * curIndent
    # One column is kept free for the trailing "&". Allowing at least one
    # character per line guarantees that the loop below always advances.
    width = max(self.maxLineLength - curIndent - 1, 1)
    if len(line) <= width:
      return "%s%s" % (curIndentStr, line)

    pieces = []
    start = 0
    end = width
    while end < len(line):
      ind = line.rfind(" ", start, end)
      if ind == -1:
        # No blank to break at: hard break at the last column that fits
        ind = end - 1
      if not pieces:
        pieces.append("%s%s&" % (curIndentStr, line[start:ind+1]))
        # Continuation lines are indented deeper and lose one more column
        # to their leading "&"
        width = max(width - self.indentCont - 1, 1)
        curIndentStr = " " * (curIndent + self.indentCont)
      else:
        pieces.append("%s&%s&" % (curIndentStr, line[start:ind+1]))
      start = ind + 1
      end = start + width
    pieces.append("%s&%s" % (curIndentStr, line[start:]))
    return "\n".join(pieces)


  def format(self, tree):
    """F95 implementation of formatting
    tree  --  List containing the objects (class Subroutine) to be formatted.
    return  --  String containing startText, one interface body per
        subroutine (each followed by an empty line) and endText.
    raises FormatterError  --  If an argument has no variable definition or
        its intent code is unknown.
    """
    intentNames = { INTENT_IN: "intent(in)",
                    INTENT_OUT: "intent(out)",
                    INTENT_INOUT: "intent(inout)" }
    out = [ self.startText, ]
    for sub in tree:
      name = sub.name.upper() if self.upperCase else sub.name
      # The same decorated name must close the subroutine, otherwise the
      # "end subroutine" statement does not match the header.
      fullName = "%s%s%s" % (self.funcPrefix, name, self.funcSuffix)
      out.append(self.fold("subroutine %s(%s)"
                           % (fullName, ", ".join(sub.argNames)), 0))
      for module in sub.modules:
        out.append(self.fold("use %s" % (module,), 1))
      out.append(self.fold("implicit none", 1))
      for argName in sub.argNames:
        try:
          (varArgs, varType, intent, attribs) = sub.varDefs[argName]
        except KeyError:
          raise FormatterError("No definition found for argument '%s' of "
                               "subroutine '%s'" % (argName, sub.name))
        mapped = self.userTypeMap.get(varType, varType)
        if isinstance(mapped, (tuple, list)):
          # The main program stores (type, flag) tuples in the user map;
          # only the type name is meaningful here.
          mapped = mapped[0]
        if intent not in intentNames:
          raise FormatterError("Invalid intent type %s." % (intent,))
        argStr = "(%s)" % ", ".join(varArgs) if varArgs else ""
        attribsStr = ", %s" % ", ".join(attribs) if attribs else ""
        out.append(self.fold("%s%s, %s :: %s%s"
                             % (mapped, attribsStr, intentNames[intent],
                                argName, argStr), 1))
      out.append(self.fold("end subroutine %s" % fullName, 0))
      out.append("")

    out.append(self.endText)
    return "\n".join(out)



class FormatterC(Formatter):
  """Formatter for C output"""

  # Substitution of fortran types into C types. Every value is a tuple
  # (C type, flag); the flag is unpacked but not used by this formatter.
  typeMap = { "integer(c_int)": ("int", True),
              "integer": ("int", True),
              "character(kind=c_char)": ("char", True),
              "character": ("char", True),
              "real(c_double)" : ("double", True),
              "real" : ("double", True),
              "complex(c_double_complex)" : ("double _Complex", True),
              "complex" : ("double _Complex", True),
              "type(lnparams)" : ("struct lnparams", True),
              "type(c_ptr)" : ("double *", False)
              }


  def __init__(self, startText, endText, indentInit, indentBlock, indentCont,
               maxLineLength, userTypeMap, printVarNames=False,
               printOrigTypes=False, functionSuffix=None, functionPrefix=None,
               functionUpperCase=False):
    """Initialises a FormatterC object
    startText  --  Text to print before the interfaces
    endText  --  Text to print after the interfaces
    indentInit  --  Indentation (in columns) of the interfaces
    indentBlock  --  Indentation increment for a nested block
    indentCont  --  Indentation increment for a continuation line
    maxLineLength  --  Maximal length of a line (incl. indentation)
    userTypeMap  --  Additional type map for mapping fortran types to c types.
    printVarNames  --  If variable names in the declaration should be printed
    printOrigTypes  --  If original types should be listed as comment
    functionSuffix  --  String to append to every function name.
    functionPrefix  --  String to prepend to every function name.
    functionUpperCase  --  If names should be turned into uppercase.
    """
    self.startText = startText
    self.endText = endText
    self.indentInit = indentInit
    self.indentBlock = indentBlock
    self.indentCont = indentCont
    self.maxLineLength = maxLineLength
    # Work on a private copy: updating the class attribute in place would
    # leak the user mappings into every other instance of the class.
    self.typeMap = dict(self.typeMap)
    self.typeMap.update(userTypeMap)
    self.printVarNames = printVarNames
    self.printOrigTypes = printOrigTypes
    self.funcSuffix = "" if functionSuffix is None else functionSuffix
    self.funcPrefix = "" if functionPrefix is None else functionPrefix
    self.upperCase = functionUpperCase


  def fold(self, line, indentLevel):
    """C implementation of line folding

    Lines which do not fit are broken after a blank where possible; the
    continuation lines get an indentation increased by indentCont.
    line  --  Line to fold.
    indentLevel  --  Indentation level for the line.
    return  --  String containing the folded line
    """
    curIndent = self.indentInit + indentLevel * self.indentBlock
    curIndentStr = " " * curIndent
    # Allowing at least one character per line guarantees that the loop
    # below always advances.
    width = max(self.maxLineLength - curIndent, 1)
    if len(line) <= width:
      return "%s%s" % (curIndentStr, line)

    pieces = []
    start = 0
    end = width
    while end < len(line):
      ind = line.rfind(" ", start, end)
      if ind == -1:
        # No blank to break at: hard break at the last column that fits
        ind = end - 1
      pieces.append("%s%s" % (curIndentStr, line[start:ind+1]))
      if len(pieces) == 1:
        # Everything after the first line is a continuation line
        width = max(width - self.indentCont, 1)
        curIndentStr = " " * (curIndent + self.indentCont)
      start = ind + 1
      end = start + width
    pieces.append("%s%s" % (curIndentStr, line[start:]))
    return "\n".join(pieces)


  def format(self, tree):
    """C implementation of formatting
    tree  --  List containing the objects (class Subroutine) to be formatted.
    return  --  String containing startText, one prototype per subroutine
        (optionally followed by a comment with the original types) and
        endText.
    raises FormatterError  --  If an argument has no variable definition,
        its type has no C counterpart or (when the original types are
        printed) its intent code is unknown.
    """
    intentNames = { INTENT_IN: "intent(in)",
                    INTENT_OUT: "intent(out)",
                    INTENT_INOUT: "intent(inout)" }
    out = [ self.startText, ]
    for sub in tree:
      name = sub.name.upper() if self.upperCase else sub.name
      cargs = []
      comments = []
      for argName in sub.argNames:
        try:
          (varArgs, varType, varIntent, varAttrs) = sub.varDefs[argName]
        except KeyError:
          raise FormatterError("No definition found for argument '%s' of "
                               "subroutine '%s'" % (argName, sub.name))
        if varType not in self.typeMap:
          raise FormatterError("No corresponding C-type for '%s'" % varType)
        (cVarType, isScalar) = self.typeMap[varType]
        # Arguments are passed by pointer unless declared with "value"
        ref = "" if "value" in varAttrs else "*"
        strConst = "const " if varIntent == INTENT_IN else ""
        if self.printVarNames:
          cargs.append("%s%s %s%s" % (strConst, cVarType, ref, argName))
        else:
          cargs.append("%s%s %s" % (strConst, cVarType, ref))
        if self.printOrigTypes:
          if varIntent not in intentNames:
            raise FormatterError("Invalid intent type %s." % (varIntent,))
          strArgs = "(%s)" % ", ".join(varArgs) if varArgs else ""
          strAttrs = ", %s" % ", ".join(varAttrs) if varAttrs else ""
          comments.append("%s%s, %s :: %s%s"
                          % (varType, strAttrs, intentNames[varIntent],
                             argName, strArgs))

      prototype = "void %s%s%s(%s);" % (self.funcPrefix, name,
                                        self.funcSuffix, ", ".join(cargs))
      out.append(self.fold(prototype, 0))
      if self.printOrigTypes:
        indent = self.indentInit * " "
        out.append("%s/* %s  */"
                   % (indent, ("\n%s   " % indent).join(comments),))
      out.append("")

    out.append(self.endText)
    return "\n".join(out)



class FormatterCpp(FormatterC):
  """Formatter for C++ output. Its output relies on the definitions in the
  header file 'fortran.h'."""

  # Substitution of fortran types into C++ types. Boolean flag signals
  # if the resulting C++-type is a scalar (True) or an array (False)
  typeMap = { "integer(c_int)": ("int", True),
              "integer": ("int", True),
              "character(kind=c_char)": ("char", True),
              "character": ("char", True),
              "real(c_double)" : ("double", True),
              "real" : ("double", True),
              "complex(c_double_complex)" : ("std::complex<double>", True),
              "complex" : ("std::complex<double>", True),
              "type(lnparams)" : ("struct lnparams", True),
              "type(c_ptr)" : ("double *", False)
              }

  def format(self, tree):
    """C++ implementation of formatting

    For every subroutine an extern "C" declaration and an inline wrapper
    function "f77_<name>" calling the external routine are created.
    tree  --  List containing the objects (class Subroutine) to be formatted.
    return  --  String containing startText, the declarations and wrappers
        and endText.
    raises FormatterError  --  If an argument has no variable definition,
        its type has no C++ counterpart or (when the original types are
        printed) its intent code is unknown.
    """
    intentNames = { INTENT_IN: "intent(in)",
                    INTENT_OUT: "intent(out)",
                    INTENT_INOUT: "intent(inout)" }
    out = [ self.startText, ]
    for sub in tree:
      name = sub.name.upper() if self.upperCase else sub.name
      externName = "%s%s%s" % (self.funcPrefix, name, self.funcSuffix)
      externDeclArgs = []       # Arguments of the external function
      wrapDefArgs = []          # Arguments of the wrapper function
      wrapCallArgs = []         # Arguments for calling external from wrapper
      comments = []
      for argName in sub.argNames:
        try:
          # Array dimensions (or None), type, intent and attributes
          (varArgs, varType, varIntent, varAttrs) = sub.varDefs[argName]
        except KeyError:
          raise FormatterError("No definition found for argument '%s' of "
                               "subroutine '%s'" % (argName, sub.name))
        if varType not in self.typeMap:
          raise FormatterError("No corresponding C++-type for '%s'" % varType)
        (cVarType, isScalar) = self.typeMap[varType]

        plainScalar = varArgs is None and isScalar
        if varIntent == INTENT_IN:
          if not plainScalar:
            # intent(in) array variables mapped to const*
            strVarType = "%s const*" % cVarType
          elif "value" in varAttrs:
            # intent(in), value scalar variables passed by value,
            strVarType = "%s" % cVarType
          else:
            # intent(in) scalar variables mapped to const&,
            strVarType = "%s const&" % cVarType
        elif plainScalar:
          # intent(out), intent(inout) scalar variables mapped to &
          strVarType = "%s&" % cVarType
        else:
          # intent(out), intent(inout) array variables mapped to *
          strVarType = "%s*" % cVarType
        wrapDefArgs.append("%s %s" % (strVarType, argName))
        if self.printVarNames:
          externDeclArgs.append("%s %s" % (strVarType, argName))
        else:
          externDeclArgs.append("%s" % (strVarType,))
        wrapCallArgs.append(argName)

        # Creating comment
        if self.printOrigTypes:
          if varIntent not in intentNames:
            raise FormatterError("Invalid intent type %s." % (varIntent,))
          strArgs = "(%s)" % ", ".join(varArgs) if varArgs else ""
          strAttrs = ", %s" % ", ".join(varAttrs) if varAttrs else ""
          comments.append("%s%s, %s :: %s%s"
                          % (varType, strAttrs, intentNames[varIntent],
                             argName, strArgs))

      # Printing out the external declaration
      externDecl = "extern \"C\" void %s (%s);" % (externName,
                                                   ", ".join(externDeclArgs))
      out.append(self.fold(externDecl, 0))
      out.append("")

      # Printing out the corresponding Fortran 77 types, if needed
      if self.printOrigTypes:
        indent = self.indentInit * " "
        out.append("%s// Corresponding F77 arguments for f77_%s:"
                   % (indent, name))
        out.append("%s// %s" % (indent, ("\n%s// " % indent).join(comments),))

      # Printing out the wrapper functions
      wrapDef = "inline void f77_%s (%s)" % (name, ", ".join(wrapDefArgs))
      out.append(self.fold(wrapDef, 0))
      out.append("{")
      wrapCall = "%s (%s);" % (externName, ", ".join(wrapCallArgs))
      out.append(self.fold(wrapCall, 1))
      out.append("}")
      out.append("\n")

    out.append(self.endText)
    return "\n".join(out)


############################################################################
#  I/O functions
############################################################################

def error(msg):
  """Prints an error message and stops the script.

  The message is written to stderr, since stdout receives the generated
  interface definitions, and the script exits with status 1 so that callers
  can detect the failure.
  msg  --  Error message to print.
  """
  print("ERROR: %s" % msg, file=sys.stderr)
  sys.exit(1)
  

############################################################################
#  Main program
############################################################################

if __name__ == "__main__":

  # Fill up input line parser
  usage = """%s [ options ] [ fname ]
Reads a specified Fortran file (or stdin if no files are specified), parses it
and dumps out the interface definitions to stdout.""" % scriptName

  optParser = OptionParser(usage=usage)
  optParser.add_option("-f", "--formatter", type="string",
                       dest="formatter", default="f95",
                       help="Specifies the output formatter to use. "
                       "(Available: f95, c, c++. Default: f95.)")
  optParser.add_option("-b", "--begin-file", dest="beginfile",
                       help="File containing the text to print before "
                       "the interfaces are printed.")
  optParser.add_option("-e", "--end-file", dest="endfile",
                       help="File containing the text to print after "
                       "the interfaces are printed.")
  optParser.add_option("-m", "--map-file", dest="mapfile",
                       help="File containing user defined type mappings.")
  optParser.add_option("-n", "--variable-names", action="store_true",
                       dest="printVarNames", default=False,
                       help="Print variable names in the declaration "
                       "(only for C).")
  optParser.add_option("-t", "--original-types", action="store_true",
                       dest="printOrigTypes", default=False,
                       help="Print the original types as comments "
                       "(only for C).")
  optParser.add_option("-s", "--suffix", dest="suffix", default="",
                       help="Appends the specified string to every function "
                       "name.")
  optParser.add_option("-p", "--prefix", dest="prefix", default="",
                       help="Prepends the specified string to every function "
                       "name.")
  optParser.add_option("-u", "--upper-case", dest="upperCase",
                       action="store_true", default=False,
                       help="Converts every function name to upper case.")

  #  Parse and interpret arguments
  (options, args) = optParser.parse_args(sys.argv[1:])
  if options.beginfile:
    with open(options.beginfile, "r") as fp:
      beginText = fp.read()
  else:
    beginText = ""
  if options.endfile:
    with open(options.endfile, "r") as fp:
      endText = fp.read()
  else:
    endText = ""

  # Map file: one "fortran type | target type | flag" entry per line, "#"
  # starts a comment. Lines without three fields are skipped.
  userMap = {}
  if options.mapfile:
    # Flag values read as False; any other text keeps meaning True. A plain
    # bool() of the text would turn "False" or "0" into True.
    falseWords = ("", "0", "false", "f", "no", "n", ".false.")
    with open(options.mapfile, "r") as fp:
      for line in fp:
        words = line.split("#", 1)[0].split("|", 2)
        if len(words) != 3:
          continue
        flag = words[2].strip().lower() not in falseWords
        userMap[words[0].strip()] = (words[1].strip(), flag)

  if options.formatter == "f95":
    formatter = FormatterF95(beginText, endText, 4, 2, 4, 80, userMap,
                             functionSuffix=options.suffix,
                             functionPrefix=options.prefix,
                             functionUpperCase=options.upperCase)
  elif options.formatter == "c":
    formatter = FormatterC(beginText, endText, 0, 2, 4, 80, userMap,
                           printVarNames=options.printVarNames,
                           printOrigTypes=options.printOrigTypes,
                           functionSuffix=options.suffix,
                           functionPrefix=options.prefix,
                           functionUpperCase=options.upperCase)
  elif options.formatter == "c++":
    formatter = FormatterCpp(beginText, endText, 0, 2, 4, 80, userMap,
                             printVarNames=options.printVarNames,
                             printOrigTypes=options.printOrigTypes,
                             functionSuffix=options.suffix,
                             functionPrefix=options.prefix,
                             functionUpperCase=options.upperCase)
  else:
    error("Unknown formatter type '%s'" % options.formatter)

  # Read and remove continuation lines
  if args:
    with open(args[0], "r") as fp:
      txt = fp.read()
  else:
    txt = sys.stdin.read()
  txt = PAT_CONTLINE.sub("", txt)

  # Parse text and convert it into the required format. Problems found in
  # the input are reported as error messages instead of tracebacks.
  try:
    tree = parse(txt)
    output = formatter.format(tree)
  except (ParserError, FormatterError) as exc:
    error(str(exc))

  # Dump out the result
  print(output)


### Local Variables:
### mode:python
### End:
