#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# File:   formula.py
# Date:   04-May-12 (creation)
# Author: I. Chuang <ichuang@mit.edu>
#
# flexible python representation of a symbolic mathematical formula.
# Acceptes Presentation MathML, Content MathML (and could also do OpenMath)
# Provides sympy representation.

import os, sys, string, re
import logging
import operator
import sympy
from sympy.printing.latex import LatexPrinter
from sympy.printing.str import StrPrinter
from sympy import latex, sympify
from sympy.physics.quantum.qubit import *
from sympy.physics.quantum.state import *
# from sympy import exp, pi, I
# from sympy.core.operations import LatticeOp
# import sympy.physics.quantum.qubit

import urllib
from xml.sax.saxutils import escape, unescape
import sympy
import unicodedata
from lxml import etree
#import subprocess
import requests
from copy import deepcopy

log = logging.getLogger(__name__)

log.warning("Dark code. Needs review before enabling in prod.")

os.environ['PYTHONIOENCODING'] = 'utf-8'

#-----------------------------------------------------------------------------

class dot(sympy.operations.LatticeOp):	# my dot product
    zero = sympy.Symbol('dotzero')
    identity = sympy.Symbol('dotidentity')

#class dot(sympy.Mul):	# my dot product
#    is_Mul = False

def _print_dot(self,expr):
    return '{((%s) \cdot (%s))}' % (expr.args[0],expr.args[1])

LatexPrinter._print_dot = _print_dot

#-----------------------------------------------------------------------------
# unit vectors (for 8.02)

def _print_hat(self,expr): return '\\hat{%s}' % str(expr.args[0]).lower()

LatexPrinter._print_hat = _print_hat
StrPrinter._print_hat = _print_hat

#-----------------------------------------------------------------------------
# helper routines

def to_latex(x):
    if x==None: return ''
    # LatexPrinter._print_dot = _print_dot
    xs = latex(x)
    xs = xs.replace(r'\XI','XI')	# workaround for strange greek
    #return '<math>%s{}{}</math>' % (xs[1:-1])
    if xs[0]=='$':
        return '[mathjax]%s[/mathjax]<br>' % (xs[1:-1])	# for sympy v6
    return '[mathjax]%s[/mathjax]<br>' % (xs)		# for sympy v7

def my_evalf(expr,chop=False):
    if type(expr)==list:
        try:
            return [x.evalf(chop=chop) for x in expr]
        except:
            return expr
    try:
        return expr.evalf(chop=chop)
    except:
        return expr

#-----------------------------------------------------------------------------
# my version of sympify to import expression into sympy

def my_sympify(expr,normphase=False,matrix=False,abcsym=False,do_qubit=False,symtab=None):
    # make all lowercase real?
    if symtab:
        varset = symtab
    else:
        varset = {'p':sympy.Symbol('p'),
                  'g':sympy.Symbol('g'),
                  'e':sympy.E,			# for exp
                  'i':sympy.I,			# lowercase i is also sqrt(-1)
                  'Q':sympy.Symbol('Q'),	# otherwise it is a sympy "ask key"
                  #'X':sympy.sympify('Matrix([[0,1],[1,0]])'),
                  #'Y':sympy.sympify('Matrix([[0,-I],[I,0]])'),
                  #'Z':sympy.sympify('Matrix([[1,0],[0,-1]])'),
                  'ZZ':sympy.Symbol('ZZ'),	# otherwise it is the PythonIntegerRing
                  'XI':sympy.Symbol('XI'),	# otherwise it is the capital \XI
                  'hat':sympy.Function('hat'),	# for unit vectors (8.02)
                  }
    if do_qubit:		# turn qubit(...) into Qubit instance
        varset.update({'qubit':sympy.physics.quantum.qubit.Qubit,
                       'Ket':sympy.physics.quantum.state.Ket,
                       'dot':dot,
                       'bit':sympy.Function('bit'),
                       })
    if abcsym:			# consider all lowercase letters as real symbols, in the parsing
        for letter in string.lowercase:
            if letter in varset:	# exclude those already done
                continue
            varset.update({letter:sympy.Symbol(letter,real=True)})

    sexpr = sympify(expr,locals=varset)
    if normphase:	# remove overall phase if sexpr is a list
        if type(sexpr)==list:
            if sexpr[0].is_number:
                ophase = sympy.sympify('exp(-I*arg(%s))' % sexpr[0])
                sexpr = [ sympy.Mul(x,ophase) for x in sexpr ]

    def to_matrix(x):		# if x is a list of lists, and is rectangular, then return Matrix(x)
        if not type(x)==list:
            return x
        for row in x:
            if (not type(row)==list):
                return x
        rdim = len(x[0])
        for row in x:
            if not len(row)==rdim:
                return x
        return sympy.Matrix(x)

    if matrix:
        sexpr = to_matrix(sexpr)
    return sexpr

#-----------------------------------------------------------------------------
# class for symbolic mathematical formulas

class formula(object):
    '''
    Representation of a mathematical formula object.  Accepts mathml math expression for constructing,
    and can produce sympy translation.  The formula may or may not include an assignment (=).
    '''
    def __init__(self,expr,asciimath='',options=None):
        self.expr = expr.strip()
        self.asciimath = asciimath
        self.the_cmathml = None
        self.the_sympy = None
        self.options = options

    def is_presentation_mathml(self):
        return '<mstyle' in self.expr

    def is_mathml(self):
        return '<math ' in self.expr

    def fix_greek_in_mathml(self,xml):
        def gettag(x):
            return re.sub('{http://[^}]+}','',x.tag)

        for k in xml:
            tag = gettag(k)
            if tag=='mi' or tag=='ci':
                usym = unicode(k.text)
                try:
                    udata = unicodedata.name(usym)
                except Exception,err:
                    udata = None 
                #print "usym = %s, udata=%s" % (usym,udata)
                if udata:			# eg "GREEK SMALL LETTER BETA"
                    if 'GREEK' in udata:
                        usym = udata.split(' ')[-1]	
                        if 'SMALL' in udata: usym = usym.lower()
                        #print "greek: ",usym
                k.text = usym
            self.fix_greek_in_mathml(k)
        return xml

    def preprocess_pmathml(self,xml):
        '''
        Pre-process presentation MathML from ASCIIMathML to make it more acceptable for SnuggleTeX, and also
        to accomodate some sympy conventions (eg hat(i) for \hat{i}).
        '''

        if type(xml)==str or type(xml)==unicode:
            xml = etree.fromstring(xml)		# TODO: wrap in try

        xml = self.fix_greek_in_mathml(xml)	# convert greek utf letters to greek spelled out in ascii

        def gettag(x):
            return re.sub('{http://[^}]+}','',x.tag)

        # f and g are processed as functions by asciimathml, eg  "f-2" turns into "<mrow><mi>f</mi><mo>-</mo></mrow><mn>2</mn>"
        # this is really terrible for turning into cmathml.
        # undo this here.
        def fix_pmathml(xml):
            for k in xml:
                tag = gettag(k)
                if tag=='mrow':
                    if len(k)==2:
                        if gettag(k[0])=='mi' and k[0].text in ['f','g'] and gettag(k[1])=='mo':
                            idx = xml.index(k)
                            xml.insert(idx,deepcopy(k[0]))	# drop the <mrow> container
                            xml.insert(idx+1,deepcopy(k[1]))
                            xml.remove(k)
                fix_pmathml(k)

        fix_pmathml(xml)

        # hat i is turned into <mover><mi>i</mi><mo>^</mo></mover> ; mangle this into <mi>hat(f)</mi>
        # hat i also somtimes turned into <mover><mrow> <mi>j</mi> </mrow><mo>^</mo></mover>

        def fix_hat(xml):
            for k in xml:
                tag = gettag(k)
                if tag=='mover':
                    if len(k)==2:
                        if gettag(k[0])=='mi' and gettag(k[1])=='mo' and str(k[1].text)=='^':
                            newk = etree.Element('mi')
                            newk.text = 'hat(%s)' % k[0].text
                            xml.replace(k,newk)
                        if gettag(k[0])=='mrow' and gettag(k[0][0])=='mi' and gettag(k[1])=='mo' and str(k[1].text)=='^':
                            newk = etree.Element('mi')
                            newk.text = 'hat(%s)' % k[0][0].text
                            xml.replace(k,newk)
                fix_hat(k)
        fix_hat(xml)

        self.xml = xml
        return self.xml

    def get_content_mathml(self):
        if self.the_cmathml: return self.the_cmathml

        # pre-process the presentation mathml before sending it to snuggletex to convert to content mathml
        try:
            xml = self.preprocess_pmathml(self.expr)
        except Exception,err:
            return "<html>Error! Cannot process pmathml</html>"
        pmathml = etree.tostring(xml,pretty_print=True)
        self.the_pmathml = pmathml

        # convert to cmathml
        self.the_cmathml = self.GetContentMathML(self.asciimath,pmathml)
        return self.the_cmathml

    cmathml = property(get_content_mathml,None,None,'content MathML representation')

    def make_sympy(self,xml=None):
        '''
        Return sympy expression for the math formula.
        The math formula is converted to Content MathML then that is parsed.
        '''

        if self.the_sympy: return self.the_sympy

        if xml==None:	# root
            if not self.is_mathml():
                return my_sympify(self.expr)
            if self.is_presentation_mathml():
                try:
                    cmml = self.cmathml
                    xml = etree.fromstring(str(cmml))
                except Exception,err:
                    raise Exception,'Err %s while converting cmathml to xml; cmml=%s' % (err,cmml)
                xml = self.fix_greek_in_mathml(xml)
                self.the_sympy = self.make_sympy(xml[0])
            else:
                xml = etree.fromstring(self.expr)
                xml = self.fix_greek_in_mathml(xml)
                self.the_sympy = self.make_sympy(xml[0])
            return self.the_sympy

        def gettag(x):
            return re.sub('{http://[^}]+}','',x.tag)

        # simple math
        def op_divide(*args):
            if not len(args)==2:
                raise Exception,'divide given wrong number of arguments!'
            # print "divide: arg0=%s, arg1=%s" % (args[0],args[1])
            return sympy.Mul(args[0],sympy.Pow(args[1],-1))

        def op_plus(*args): return args[0] if len(args)==1 else op_plus(*args[:-1])+args[-1]
        def op_times(*args): return reduce(operator.mul,args)

        def op_minus(*args):
            if len(args)==1:
                return -args[0]
            if not len(args)==2:
                raise Exception,'minus given wrong number of arguments!'
            #return sympy.Add(args[0],-args[1])
            return args[0]-args[1]

        opdict = {'plus': op_plus,            
                  'divide' : operator.div,
                  'times' : op_times,
                  'minus' : op_minus,
                  #'plus': sympy.Add,
                  #'divide' : op_divide,
                  #'times' : sympy.Mul,
                  'minus' : op_minus,
                  'root' : sympy.sqrt,
                  'power' : sympy.Pow,
                  'sin': sympy.sin,
                  'cos': sympy.cos,
                   }

        # simple sumbols
        nums1dict = {'pi': sympy.pi,
                     }

        def parsePresentationMathMLSymbol(xml):
            '''
            Parse <msub>, <msup>, <mi>, and <mn>
            '''
            tag = gettag(xml)
            if tag=='mn': return xml.text
            elif tag=='mi': return xml.text
            elif tag=='msub': return '_'.join([parsePresentationMathMLSymbol(y) for y in xml])
            elif tag=='msup': return '^'.join([parsePresentationMathMLSymbol(y) for y in xml])
            raise Exception,'[parsePresentationMathMLSymbol] unknown tag %s' % tag

        # parser tree for Content MathML
        tag = gettag(xml)
        # print "tag = ",tag

        # first do compound objects

        if tag=='apply':		# apply operator
            opstr = gettag(xml[0])
            if opstr in opdict:
                op = opdict[opstr]
                args = [ self.make_sympy(x) for x in xml[1:]]
                try:
                    res = op(*args)
                except Exception,err:
                    self.args = args
                    self.op = op
                    raise Exception,'[formula] error=%s failed to apply %s to args=%s' % (err,opstr,args)
                return res
            else:
                raise Exception,'[formula]: unknown operator tag %s' % (opstr)

        elif tag=='list':		# square bracket list
            if gettag(xml[0])=='matrix':
                return self.make_sympy(xml[0])
            else:
                return [ self.make_sympy(x) for x in xml ]

        elif tag=='matrix':
            return sympy.Matrix([ self.make_sympy(x) for x in xml ])

        elif tag=='vector':
            return [ self.make_sympy(x) for x in xml ]

        # atoms are below

        elif tag=='cn':			# number
            return sympy.sympify(xml.text)
            return float(xml.text)

        elif tag=='ci':			# variable (symbol)
            if len(xml)>0 and (gettag(xml[0])=='msub' or gettag(xml[0])=='msup'):	# subscript or superscript
                usym = parsePresentationMathMLSymbol(xml[0])
                sym = sympy.Symbol(str(usym))
            else:
                usym = unicode(xml.text)
                if 'hat' in usym:
                    sym = my_sympify(usym)
                else:
                    if usym=='i': print "options=",self.options
                    if usym=='i' and 'imaginary' in self.options:	# i = sqrt(-1)
                        sym = sympy.I	
                    else:
                        sym = sympy.Symbol(str(usym))
            return sym

        else:				# unknown tag
            raise Exception,'[formula] unknown tag %s' % tag

    sympy = property(make_sympy,None,None,'sympy representation')

    def GetContentMathML(self,asciimath,mathml):
        # URL = 'http://192.168.1.2:8080/snuggletex-webapp-1.2.2/ASCIIMathMLUpConversionDemo'
        URL = 'http://127.0.0.1:8080/snuggletex-webapp-1.2.2/ASCIIMathMLUpConversionDemo'

        if 1:
            payload = {'asciiMathInput':asciimath,
                       'asciiMathML':mathml,
                       #'asciiMathML':unicode(mathml).encode('utf-8'),
                       }
            headers = {'User-Agent':"Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.13) Gecko/20080311 Firefox/2.0.0.13"}
            r = requests.post(URL,data=payload,headers=headers)
            r.encoding = 'utf-8'
            ret = r.text
            #print "encoding: ",r.encoding

        # return ret
    
        mode = 0
        cmathml = []
        for k in ret.split('\n'):
            if 'conversion to Content MathML' in k:
                mode = 1
                continue
            if mode==1:
                if '<h3>Maxima Input Form</h3>' in k:
                    mode = 0
                    continue
                cmathml.append(k)
        # return '\n'.join(cmathml)
        cmathml = '\n'.join(cmathml[2:])
        cmathml = '<math xmlns="http://www.w3.org/1998/Math/MathML">\n' + unescape(cmathml) + '\n</math>'
        # print cmathml
        #return unicode(cmathml)
        return cmathml
    
#-----------------------------------------------------------------------------

def test1():
    xmlstr = '''
<math xmlns="http://www.w3.org/1998/Math/MathML">
   <apply>
      <plus/>
      <cn>1</cn>
      <cn>2</cn>
   </apply>
</math>
    '''
    return formula(xmlstr)

def test2():
    xmlstr = u'''
<math xmlns="http://www.w3.org/1998/Math/MathML">
   <apply>
      <plus/>
      <cn>1</cn>
      <apply>
         <times/>
         <cn>2</cn>
	 <ci>α</ci>
      </apply>
   </apply>
</math>
    '''
    return formula(xmlstr)

def test3():
    xmlstr = '''
<math xmlns="http://www.w3.org/1998/Math/MathML">
   <apply>
      <divide/>
      <cn>1</cn>
      <apply>
         <plus/>
         <cn>2</cn>
         <ci>γ</ci>
      </apply>
   </apply>
</math>
    '''
    return formula(xmlstr)

def test4():
    xmlstr = u'''
<math xmlns="http://www.w3.org/1998/Math/MathML">
  <mstyle displaystyle="true">
    <mn>1</mn>
    <mo>+</mo>
    <mfrac>
      <mn>2</mn>
      <mi>α</mi>
    </mfrac>
  </mstyle>
</math>
'''
    return formula(xmlstr)

def test5():		# sum of two matrices
    xmlstr = u'''
<math xmlns="http://www.w3.org/1998/Math/MathML">
  <mstyle displaystyle="true">
    <mrow>
      <mi>cos</mi>
      <mrow>
        <mo>(</mo>
        <mi>&#x3B8;</mi>
        <mo>)</mo>
      </mrow>
    </mrow>
    <mo>&#x22C5;</mo>
    <mrow>
      <mo>[</mo>
      <mtable>
        <mtr>
          <mtd>
            <mn>1</mn>
          </mtd>
          <mtd>
            <mn>0</mn>
          </mtd>
        </mtr>
        <mtr>
          <mtd>
            <mn>0</mn>
          </mtd>
          <mtd>
            <mn>1</mn>
          </mtd>
        </mtr>
      </mtable>
      <mo>]</mo>
    </mrow>
    <mo>+</mo>
    <mrow>
      <mo>[</mo>
      <mtable>
        <mtr>
          <mtd>
            <mn>0</mn>
          </mtd>
          <mtd>
            <mn>1</mn>
          </mtd>
        </mtr>
        <mtr>
          <mtd>
            <mn>1</mn>
          </mtd>
          <mtd>
            <mn>0</mn>
          </mtd>
        </mtr>
      </mtable>
      <mo>]</mo>
    </mrow>
  </mstyle>
</math>
'''
    return formula(xmlstr)

def test6():		# imaginary numbers
    xmlstr = u'''
<math xmlns="http://www.w3.org/1998/Math/MathML">
  <mstyle displaystyle="true">
    <mn>1</mn>
    <mo>+</mo>
    <mi>i</mi>
  </mstyle>
</math>
'''
    return formula(xmlstr,options='imaginaryi')