# Functions for the Henon-Heiles project


# Imported modules
import numpy as np
import string as st
import pylab as pl

# Runge-Kutta 4th order integration
# dx/dt = f(a,b,c,d,x,y,px,py)
# dy/dt = g(a,b,c,d,x,y,px,py)
# dpx/dt = p(a,b,c,d,x,y,px,py)
# dpy/dt = q(a,b,c,d,x,y,px,py)

def RK4DIntegrator(a,b,c,d,x,y,px,py,f,g,p,q,dt):
    """Advance the state (x, y, px, py) by one classical Runge-Kutta 4 step.

    f, g, p and q give dx/dt, dy/dt, dpx/dt and dpy/dt respectively; each is
    called as fun(a, b, c, d, x, y, px, py).  dt is the step size.

    Returns the new state as the tuple (x, y, px, py).
    """
    derivatives = (f, g, p, q)

    def increments(sx, sy, spx, spy):
        # dt-scaled derivatives at one trial state, in the order f, g, p, q.
        return [dt * rate(a, b, c, d, sx, sy, spx, spy) for rate in derivatives]

    k1 = increments(x, y, px, py)
    k2 = increments(x + k1[0] / 2.0, y + k1[1] / 2.0,
                    px + k1[2] / 2.0, py + k1[3] / 2.0)
    k3 = increments(x + k2[0] / 2.0, y + k2[1] / 2.0,
                    px + k2[2] / 2.0, py + k2[3] / 2.0)
    k4 = increments(x + k3[0], y + k3[1], px + k3[2], py + k3[3])

    # Weighted average of the four slopes: (k1 + 2*k2 + 2*k3 + k4) / 6.
    return tuple(
        start + (s1 + 2.0 * s2 + 2.0 * s3 + s4) / 6.0
        for start, s1, s2, s3, s4 in zip((x, y, px, py), k1, k2, k3, k4)
    )

# The differential equations
# From the Hamiltonian
# H=1/2*(px^2+py^2+ax^2+by^2)+dy*x^2-cy^3/3
def dxdt(a=0,b=0,c=0,d=0,x=0,y=0,px=0,py=0):
    """Hamilton's equation dx/dt = dH/dpx, which is simply px.

    The other parameters are accepted (and ignored) so that every
    right-hand side shares the signature fun(a,b,c,d,x,y,px,py).
    """
    velocity = px
    return velocity

def dydt(a=0,b=0,c=0,d=0,x=0,y=0,px=0,py=0):
    """Hamilton's equation dy/dt = dH/dpy, which is simply py.

    The other parameters are accepted (and ignored) so that every
    right-hand side shares the signature fun(a,b,c,d,x,y,px,py).
    """
    velocity = py
    return velocity

def dpxdt(a=1.0,b=1.0,c=1.0,d=1.0,x=0,y=0,px=0,py=0):
    """Hamilton's equation dpx/dt = -dH/dx = -a*x - 2*d*x*y."""
    harmonic = a * x
    coupling = 2 * d * x * y
    return -harmonic - coupling

def dpydt(a=1.0,b=1.0,c=1.0,d=1.0,x=0,y=0,px=0,py=0):
    """Hamilton's equation dpy/dt = -dH/dy = -b*y - d*x^2 + c*y^2."""
    restoring = b * y
    coupling = d * (x ** 2)
    cubic = c * (y ** 2)
    return -restoring - coupling + cubic


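# Energy equations built on the Hamiltonian above,
# E = 1/2*(px^2+py^2+a*x^2+b*y^2) + d*y*x^2 - c*y^3/3
# (for a=b=c=d=1: E = 1/2*(px^2+py^2+x^2+y^2) + y*x^2 - y^3/3).
# Each function determines the 4th phase-space variable from energy
# conservation. All of them have the form fun(a,b,c,d,E,u,v,w), where u,v,w
# are the 3 known quantities.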

def Ex(a=1.0,b=1.0,c=1.0,d=1.0,E=0.1667,y=0.0,py=0.0,px=0.0):
    """Return x >= 0 such that H(x, y, px, py) = E.

    From 2E = px^2 + py^2 + a*x^2 + b*y^2 + 2*d*y*x^2 - 2*c*y^3/3 it follows
    that x^2 = (2E - px^2 - py^2 - b*y^2 + 2*c*y^3/3) / (a + 2*d*y).
    If x^2 comes out negative (no real solution) np.sqrt returns nan.
    """
    numerator = 2*E - px**2 - py**2 + c*(y**3)*2.0/3.0 - b*(y**2)
    # The x^2 coefficient collects a*x^2/2 and d*y*x^2 (doubled): a + 2*d*y.
    denominator = a + 2.0*d*y
    return np.sqrt(numerator/denominator)

def Ey(a=1.0,b=1.0,c=1.0,d=1.0,E=0.1667,x=0.0,py=0.0,px=0.0):
    """Return a y such that H(x, y, px, py) = E (scalar inputs).

    Energy conservation gives the cubic in y
        (2c/3)*y^3 - b*y^2 - 2*d*x^2*y + (2E - px^2 - py^2 - a*x^2) = 0,
    which is solved numerically with np.roots.  Following the positive-root
    convention of Ex/Epx/Epy, the smallest non-negative real root is
    returned (for a=b=c=d=1, x=0, px=py=0, E=1/12 this is y = 1/2).
    Returns nan when the cubic has no non-negative real root.
    """
    constant = 2.0*E - px**2 - py**2 - a*(x**2)
    roots = np.roots([2.0*c/3.0, -b, -2.0*d*(x**2), constant])
    # np.roots may return complex values; keep those that are real to
    # within round-off.
    real_roots = [float(np.real(r)) for r in roots
                  if abs(np.imag(r)) <= 1e-9 * max(1.0, abs(r))]
    candidates = [r for r in real_roots if r >= 0.0]
    if not candidates:
        return float('nan')
    return min(candidates)

def Epx(a=1.0,b=1.0,c=1.0,d=1.0,E=0.1667,x=0.0,y=0.0,py=0.0):
    """Return px >= 0 such that H(x, y, px, py) = E.

    px^2 = 2E - py^2 - a*x^2 - b*y^2 - 2*d*y*x^2 + 2*c*y^3/3, i.e. twice the
    energy minus everything except px^2/2, all doubled.
    A negative radicand (state outside the energy surface) is not handled:
    Python 2 raises ValueError, Python 3 returns a complex number.
    """
    # The coupling term d*y*x^2 in H appears doubled (2*d*y*x^2) in 2E.
    tpx = 2*E - py**2 - a*(x**2) - b*(y**2) - 2.0*d*y*(x**2) + c*(y**3)*2.0/3.0
    epx = tpx**0.5
    return epx

def Epy(a=1.0,b=1.0,c=1.0,d=1.0,E=0.1667,x=0.0,y=0.0,px=0.0):
    """Return py >= 0 such that H(x, y, px, py) = E.

    py^2 = 2E - px^2 - a*x^2 - b*y^2 - 2*d*y*x^2 + 2*c*y^3/3.
    A negative radicand (state outside the energy surface) is not handled:
    Python 2 raises ValueError, Python 3 returns a complex number.
    """
    # The coupling term d*y*x^2 in H appears doubled (2*d*y*x^2) in 2E.
    tpy = 2*E - px**2 - a*(x**2) - b*(y**2) - 2.0*d*y*(x**2) + c*(y**3)*2.0/3.0
    epy = tpy**0.5
    return epy

# Jacobian determinant and eigenvalues of the linearised equations of motion.
#
# For H = 1/2*(px^2+py^2+a*x^2+b*y^2) + d*y*x^2 - c*y^3/3 the Jacobian of
# (dx/dt, dy/dt, dpx/dt, dpy/dt) with respect to (x, y, px, py) is
#     |  0     0    1  0 |
#     |  0     0    0  1 |
#     | -Vxx  -Vxy  0  0 |
#     | -Vxy  -Vyy  0  0 |
# where Vxx = a + 2*d*y, Vyy = b - 2*c*y and Vxy = 2*d*x are the second
# derivatives of the potential. Its determinant is Vxx*Vyy - Vxy^2 and its
# eigenvalues satisfy lambda^2 = (-T -/+ S)/2 with T = Vxx + Vyy and
# S = sqrt((Vxx - Vyy)^2 + 4*Vxy^2).

def _potentialHessian(x, y, a, b, c, d):
    """Return (Vxx, Vyy, Vxy), the second derivatives of the potential."""
    return a + 2.0*d*y, b - 2.0*c*y, 2.0*d*x

def _eigenSquares(x, y, a, b, c, d):
    """Return the two values of lambda^2: ((-T - S)/2, (-T + S)/2)."""
    vxx, vyy, vxy = _potentialHessian(x, y, a, b, c, d)
    trace = vxx + vyy
    # (Vxx - Vyy)^2 + 4*Vxy^2 equals T^2 - 4*det but can never be negative.
    spread = np.sqrt((vxx - vyy)**2 + 4.0*vxy**2)
    return (-trace - spread)/2.0, (-trace + spread)/2.0

def jDet(x,y,a=1.0,b=1.0,c=1.0,d=1.0):
    """Determinant of the Jacobian at (x, y): Vxx*Vyy - Vxy^2.

    Expanded: a*b - 4*d^2*x^2 - 2*a*c*y + 2*b*d*y - 4*c*d*y^2.
    """
    vxx, vyy, vxy = _potentialHessian(x, y, a, b, c, d)
    return vxx*vyy - vxy**2

def eg1(x,y,a=1.0,b=1.0,c=1.0,d=1.0):
    """Eigenvalue -sqrt((-T - S)/2); complex when the radicand is negative."""
    return -np.emath.sqrt(_eigenSquares(x, y, a, b, c, d)[0])

def eg2(x,y,a=1.0,b=1.0,c=1.0,d=1.0):
    """Eigenvalue +sqrt((-T - S)/2); complex when the radicand is negative."""
    return np.emath.sqrt(_eigenSquares(x, y, a, b, c, d)[0])

def eg3(x,y,a=1.0,b=1.0,c=1.0,d=1.0):
    """Eigenvalue -sqrt((-T + S)/2); complex when the radicand is negative."""
    return -np.emath.sqrt(_eigenSquares(x, y, a, b, c, d)[1])

def eg4(x,y,a=1.0,b=1.0,c=1.0,d=1.0):
    """Eigenvalue +sqrt((-T + S)/2); complex when the radicand is negative."""
    return np.emath.sqrt(_eigenSquares(x, y, a, b, c, d)[1])

# Equipotential points - using the potential of the Hamiltonian above,
# V = 1/2*(A*x^2 + B*y^2) + D*x^2*y - C*y^3/3
def xV(E,y,a,b,c,d):
    """Return x >= 0 on the equipotential curve V(x, y) = E.

    Solving 2E = A*x^2 + B*y^2 + 2*D*x^2*y - 2*C*y^3/3 for x gives
    x = sqrt(2E - B*y^2 + 2*C*y^3/3) / sqrt(A + 2*D*y).
    For A=B=C=D=1 and E=1/6 this traces the edge of the triangle with
    vertices (0, 1) and (sqrt(3)/2, -1/2).
    Negative radicands are not handled (Python 2 raises ValueError).
    """
    x1 = (2.0*E + 2.0 * c * (y**3) / 3.0 - b * (y**2))**0.5
    x2 = (a + 2 * d * y)**0.5
    return x1/x2

# This function helps in uniformally distirbuting the initial (x,y) 
# in the 1/6 equipotential region
# mapping points that are out of our area into it, using symmetry
def symmXY(x,y):
    # yy = aa * x + 1
    yy = x * ((-0.5)-1.0)/(np.sqrt(3.0)/2.0) + 1.0
    if y > yy:
        x = -x
        y = 0.25 - ( y - 0.25)
    return [x,y]

def EE(x,y,px,py,a=1.0,b=1.0,c=1.0,d=1.0):
    """Return the total energy
    H = 1/2*(px^2+py^2+a*x^2+b*y^2) + d*y*x^2 - c*y^3/3.

    The cubic coefficient c is now honoured (it was previously ignored);
    with the default c=1.0 the result is unchanged.
    """
    quadratic = 0.5*(px**2+py**2+a*(x**2)+b*(y**2))
    return quadratic+d*y*(x**2)-c*(y**3)/3.0

def pType():
    """Ask the user which kind of plot to produce.

    Returns '3D' when the answer is exactly '2'; any other answer (including
    an empty one) selects the Poincare plot, 'PC'.
    """
    for line in ('\n',
                 'Collecting data for Henon-Heiles hamiltonian dynamics exploration',
                 'Choose:',
                 '(1) Poincare plot',
                 '(2) 3D plot'):
        print(line)
    answer = raw_input('[1]: ')
    return '3D' if answer == '2' else 'PC'

def pAxis(PlotType):
    '''
    Ask which axes to plot for the chosen plot type.

    PlotType is '3D' or 'PC' (as returned by pType).
    For '3D' returns ['y','p_y','x'] (default) or ['x','p_x','y'];
    for 'PC' returns ['y','p_y'] (default) or ['x','p_x'].
    Raises ValueError for any other PlotType.
    '''
    if PlotType == '3D':
        print('\n')
        print('Choose 3D plot options:')
        print('(1) Axis - y,py,x')
        print('(2) Axis - x,px,y')
        if raw_input('[1]: ') == '2':
            return ['x','p_x','y']
        return ['y','p_y','x']

    if PlotType == 'PC':
        print('\n')
        print('Choose Poincare plot options:')
        print('(1) Axis - py vs. y')
        print('(2) Axis - px vs. x')
        if raw_input('[1]: ') == '2':
            return ['x','p_x']
        return ['y','p_y']

    # Previously an unknown type fell through to an unbound PlotAxis.
    raise ValueError('Unknown plot type: %r' % (PlotType,))

def HamConst():
    print '\n'
    print 'The Hamiltonian we a re working with is : E=1/2*(px^2+py^2+A*x^2+B*y^2)+D*y*x^2-C*y^3/3'
    stemp = raw_input ('Would you like to modify the Hamiltonian constans A,B,C,D [N]')
    if stemp == ('y' or 'Y'):
        stemp = raw_input('A[1.0]: ')
        if stemp == '':
            a = 1.0
        else:
            a = float(stemp)
        stemp = raw_input('B[1.0]: ')
        if stemp == '':
            b = 1.0
        else:
            b = float(stemp)
        stemp = raw_input('C[1.0]: ')
        if stemp == '':
            c = 1.0
        else:
            c = float(stemp)
        stemp = raw_input('D[1.0]: ')
        if stemp == '':
            d = 1.0
        else:
            d = float(stemp)                    
    else:
        a = 1.0
        b = 1.0
        c = 1.0
        d = 1.0
    return a,b,c,d

def eLevel():
    print '\n'
    print 'Interesting energy range : 0 < E < 1/6'
    print 'E = 1/12 regular orbit'
    print 'E = 1/9 Close to chaos onset'
    print 'E = 1/8 Mixed phase space'
    print 'E = 1/6 Edge of potential '
    stemp = raw_input('E = [1/12] ')
    if stemp=='':
        E = 1.0/12.0    
        Estr = '1/12'
    else:
        sE = st.split(stemp,'/')
        Estr = stemp
        E1 = float(sE[0])
        if len(sE) > 1:
            E2 = float(sE[-1])
        else:
            E2 = 1.0
        try:
            E1/E2
        except ValueError:
            print 'Plase check the E value, %.4f / %.4f' % (E1,E2)
        else:
            E = E1/E2
    return E, Estr

def nIter(PlotType):
    """Prompt for the number of Runge-Kutta iterations.

    The suggested default is 250000 for a '3D' plot and 1000000 otherwise;
    an empty answer accepts it, anything else is converted with int().
    """
    default = 250000 if PlotType == '3D' else 1000000
    print('\n')
    print('The integration using Runge-Kutta mothod')
    answer = raw_input('Iteration number [%g]: ' % default)
    if answer == '':
        return default
    return int(answer)

def setDt(DT):
    """Prompt for the integration step size.

    An empty answer keeps DT; anything else is converted with float().
    """
    answer = raw_input('dt step size [%g]: ' % DT)
    return float(answer) if answer != '' else DT
                      
def yRange(E):
    """Return a preset (yMin, yMax) for sampling initial y at energy E.

    Supported energies are 1/24 <= E <= 1/6, split into bands
    [1/7, 1/6], [1/8, 1/7), [1/10, 1/8), [1/12, 1/10), [1/16, 1/12),
    [1/24, 1/16).  Each band uses the range stored for its lower energy.
    NOTE(review): the presets appear to be the y-extent along x = 0 of the
    a=b=c=d=1 energy surface at that lower energy (e.g. 0.49999 at E=1/12);
    they are not recomputed for other Hamiltonian constants.

    Prints a message and raises ValueError for E outside the supported range.
    """
    # (lower E of band, yMin, yMax), highest band first.  The old tests
    # `E >= hi and E >= lo` reduced to `E >= hi`, shifting every band down.
    bands = (
        (1./7., -0.4778, 0.76),
        (1./8., -0.438, 0.673),
        (1./10., -0.397, 0.567),
        (1./12., -0.366, 0.49999),
        (1./16., -0.32, 0.415),
        (1./24., -0.266, 0.326),
    )
    # E = 1/6 is the edge of the bounded potential well.
    if E <= 1./6.:
        for lower, yMin, yMax in bands:
            if E >= lower:
                return yMin, yMax
    print('Please consider other E level')
    raise ValueError('No preset y range for E = %g (supported: 1/24 <= E <= 1/6)' % E)