# Function Definitions for Walker.py
# Filename: WalkerDefs.py


#from numpy import array, arange, size, shape, cos, sin, zeros, sqrt
from numpy import array, arange, size, shape, cos, sin, zeros, sqrt, \
        dot, transpose, resize
import numpy.linalg as LA
from scipy.optimize import fsolve

# t(end) 1st step = 3.843435
# t(end) fixed point = 3.88246
def pause():
    '''
    Block until the user presses Enter; whatever was typed is discarded.

    raw_input()/input() only return after a newline, so the prompt asks
    for Enter rather than "any key".
    '''
    try:
        read_line = raw_input      # Python 2
    except NameError:
        read_line = input          # Python 3 (raw_input was renamed)
    read_line('Paused.  Press Enter to continue')




def RK3D2(f, X, params, dt):
    '''
    Advance the state X by one fixed step dt using the classical
    fourth-order Runge-Kutta scheme.

    f(X, params) is expected to return dX/dt for state X; params is passed
    through to f unchanged.  Returns the new state (X itself is rebound
    locally, not modified in place).
    '''
    s1 = dt * f(X, params)
    s2 = dt * f(X + 0.5 * s1, params)
    s3 = dt * f(X + 0.5 * s2, params)
    s4 = dt * f(X + s3, params)
    weighted_sum = s1 + 2.0 * s2 + 2.0 * s3 + s4
    return X + weighted_sum / 6.0





def EOM(X, params):
    '''
    Right-hand side of the equations of motion (EOM) of Garcia's simplest
    passive dynamic walker (PDW).

    X      : state [theta, phi, thetadot, phidot]
    params : used directly as gamma, the angle subtracted from theta
             (the slope in Garcia's model)

    Returns array([thetadot, phidot, thetaddot, phiddot]) with
        thetaddot = sin(theta - gamma)
        phiddot   = thetaddot + thetadot**2*sin(phi)
                    - cos(theta - gamma)*sin(phi)
    '''
    gamma = params
    theta, phi, thetadot, phidot = X
    tilt = theta - gamma
    sin_phi = sin(phi)
    thetaddot = sin(tilt)
    phiddot = thetaddot + thetadot**2 * sin_phi - cos(tilt) * sin_phi
    return array([thetadot, phidot, thetaddot, phiddot])




def FootTrack(theta, phi, L):
    '''
    Planar (x, y) position of the swing foot for leg angles theta, phi and
    leg length L:

        x = -L*(sin(theta) + sin(phi - theta))
        y =  L*(cos(theta) - cos(phi - theta))

    Both coordinates are 0 at theta = phi = 0.  Event() uses this to find
    the moment y crosses zero, i.e. foot-strike.
    '''
    swing_angle = phi - theta
    x = -L * (sin(theta) + sin(swing_angle))
    y = L * (cos(theta) - cos(swing_angle))
    return x, y




def Event(X, params, dt, Xk, Time):
    '''
    Detect foot-strike and keep the time record of the swing phase.

    Foot-strike is the integration step on which the swing foot's height
    y (from FootTrack) goes from >= 0 at the previous state Xk to < 0 at
    the candidate state X.  Because y can change sign more than once during
    a step, a crossing only counts once the stance leg has rotated past
    vertical (theta < 0) AND the swing foot is at least 0.10*L in front of
    the stance foot (x >= 0.10*L).

    Parameters
    ----------
    X      : candidate state [theta, phi, thetadot, phidot] after the step
    params : sequence; params[1] is the leg length L given to FootTrack
    dt     : step size that produced X from Xk
    Xk     : state before the step
    Time   : list of sample times (must be non-empty); extended in place

    Returns (X, contact, dt, Time):
      * crossing found  -> (Xk, 1, dt/10.0, Time): the step is rejected,
        the state rewinds to Xk and the caller retries with a 10x smaller
        dt; Time is not extended.
      * otherwise       -> (X, 0, dt, Time) with Time[-1] + dt appended.
    '''
    theta, phi = X[0], X[1]
    L = params[1]
    x_now, y_now = FootTrack(theta, phi, L)
    t_last = Time[-1]

    if theta < 0 and x_now >= 0.10 * L:
        y_prev = FootTrack(Xk[0], Xk[1], L)[1]
        if y_prev >= 0 and y_now < 0:
            # Overshot the ground: rewind one step and refine dt.
            return Xk, 1, dt / 10.0, Time

    Time.append(t_last + dt)
    return X, 0, dt, Time




def Transition(Xpre):
    '''
    Map the pre-impact state to the post-impact state at foot-strike.

    A perfectly plastic collision is assumed when the swing foot strikes
    the ground; the impulse on the swing foot changes the angular
    velocities instantaneously.  With c = cos(2*theta_pre):

        theta+    = -theta_pre
        phi+      = -2*theta_pre
        thetadot+ = c * thetadot_pre
        phidot+   = c*(1 - c) * thetadot_pre

    The pre-impact phi and phidot do not enter the map.
    Xpre is [theta, phi, thetadot, phidot]; returns the post-impact state.
    '''
    c2 = cos(2.0 * Xpre[0])
    impact = array([
        [-1.0, 0.0, 0.0, 0.0],
        [-2.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, c2, 0.0],
        [0.0, 0.0, c2 * (1.0 - c2), 0.0],
    ])
    return dot(impact, Xpre)




def StrideMapError(X, *pargs):
    '''
    Residual of the stride map: StrideMap(X) - X.

    The stride map (implemented by StrideMap) integrates the EOM through
    the swing phase, detects foot-strike, and applies the impact map to get
    the state at the start of the next step.  A limit cycle is a state X
    for which this residual is zero.

    pargs is the same 6-tuple (Xto, Xk, params, dt, err, Time) that
    StrideMap takes, received positionally so the function fits the
    f(x, *args) calling convention of root finders such as
    scipy.optimize.fsolve.

    Delegating to StrideMap keeps the integration loop (and its 1e-12
    stopping tolerance) in one place instead of a copied loop.
    '''
    return StrideMap(X, pargs) - X




def StrideMap(X, pargs):
    '''
    One application of the stride map: from the toe-off state X, integrate
    the swing phase (RK3D2 on EOM, Event for foot-strike handling) and
    return the post-impact state given by Transition.

    pargs = (Xto, Xk, params, dt, err, Time):
      Xto, Xk -- ignored on input (Xk is reassigned every step)
      params  -- params[0] goes to EOM as gamma; Event reads params[1]
      dt      -- initial step size; Event divides it by 10 at each
                 detected ground crossing
      err     -- initial stopping measure; no integration happens unless
                 err >= 1e-12
      Time    -- list of sample times, extended in place by Event

    Integration stops once |phi - 2*theta| < 1e-12; phi = 2*theta is a
    configuration where FootTrack's y is zero.
    '''
    tolerance = 1.0e-12
    _, _, params, dt, err, Time = pargs
    state = X
    while err >= tolerance:
        before = state
        state = RK3D2(EOM, state, params[0], dt)
        state, contact, dt, Time = Event(state, params, dt, before, Time)
        err = abs(state[1] - 2.0 * state[0])
    return Transition(state)




def Jacob_StrideMap(f, x, pargs, h=1.0e-6):
    '''
    Forward-difference Jacobian of f at x.

    Column i is (f(x + h*e_i, pargs) - f(x, pargs)) / h.

    Parameters
    ----------
    f     : callable f(x, pargs) returning a 1-D array (e.g. StrideMap)
    x     : point at which to differentiate; never modified
    pargs : extra arguments, forwarded to f as a single tuple
    h     : finite-difference step (default 1e-6)

    Returns
    -------
    (jacob, f0) where f0 = f(x, pargs) and jacob has shape
    (len(f0), len(x)); for the 4-state walker this is 4x4.
    '''
    args = tuple(pargs)
    f0 = f(x, args)
    # Perturb float copies: writing x[i] + h back into the caller's array
    # would mutate it (and leave it perturbed if f raised), and for an
    # integer array the +h would be truncated away, giving a zero column.
    base = array(x, dtype=float)
    n = len(base)
    jacob = zeros((len(f0), n), float)
    for i in range(n):
        shifted = base.copy()
        shifted[i] += h
        jacob[:, i] = (f(shifted, args) - f0) / h
    return jacob, f0

def SwingPhase(X, pargs, TrajTheta=None, TrajPhi=None, TrajThetadot=None,
               TrajPhidot=None):
    '''
    Integrate one swing phase from the toe-off state X and record the
    trajectory.  Unlike StrideMap, the impact map (Transition) is not
    applied.

    pargs = (Xto, Xk, params, dt, err, Time), as for StrideMap: params[0]
    goes to EOM, Event reads params[1], dt is the initial step, err the
    initial stopping measure, and Time is extended in place by Event.

    TrajTheta, TrajPhi, TrajThetadot, TrajPhidot are optional lists to
    extend; fresh lists are used when omitted.  (The body previously
    appended to these names without ever defining them, which raised
    NameError.)

    Returns (Time, TrajTheta, TrajPhi, TrajThetadot, TrajPhidot), where
    the Traj* lists hold theta, phi, thetadot, phidot after every accepted
    (non-contact) step -- one entry per time Event appended to Time.
    Integration stops once |phi - 2*theta| < 1e-12.
    '''
    err_tol = 1.0e-12
    _, _, params, dt, err, Time = pargs
    if TrajTheta is None:
        TrajTheta = []
    if TrajPhi is None:
        TrajPhi = []
    if TrajThetadot is None:
        TrajThetadot = []
    if TrajPhidot is None:
        TrajPhidot = []
    while err >= err_tol:
        Xk = X
        X = RK3D2(EOM, X, params[0], dt)
        X, contact, dt, Time = Event(X, params, dt, Xk, Time)
        err = abs(X[1] - 2.0 * X[0])
        if contact == 0:
            # Step accepted (a rejected step rewinds X to Xk).
            TrajTheta.append(X[0])
            TrajPhi.append(X[1])
            TrajThetadot.append(X[2])
            TrajPhidot.append(X[3])
    return Time, TrajTheta, TrajPhi, TrajThetadot, TrajPhidot





def NRA(f, x, pargs, err_tol=1.0e-6, max_iter=30, h=1.0e-6):
    '''
    Newton-Raphson solve of f(x, pargs) = 0 with a finite-difference
    Jacobian.

    Parameters
    ----------
    f        : callable f(x, pargs) returning a 1-D array the length of x
    x        : initial guess; never modified
    pargs    : extra arguments, forwarded to f as a single tuple
    err_tol  : convergence tolerance (default 1e-6)
    max_iter : maximum number of Newton iterations (default 30)
    h        : forward-difference step for the Jacobian (default 1e-6)

    Returns
    -------
    (x, jac) on convergence, where convergence means either
      * RMS(f(x)) < err_tol  -- jac is evaluated at the returned x, or
      * the Newton step length <= err_tol -- jac is evaluated at the point
        before that final step.
    If neither happens within max_iter iterations, prints a message and
    returns None.

    Raises numpy.linalg.LinAlgError if the Jacobian is singular.
    '''
    args = tuple(pargs)

    def jacobian(xc):
        # Forward differences on copies, so xc is never mutated.
        f0 = f(xc, args)
        n = len(xc)
        jac = zeros((n, n), float)
        for i in range(n):
            shifted = xc.copy()
            shifted[i] += h
            jac[:, i] = (f(shifted, args) - f0) / h
        return jac, f0

    # Float copy: an integer guess would otherwise truncate the +h
    # perturbation to zero and make the Jacobian singular.
    x = array(x, dtype=float)
    for _ in range(max_iter):
        jac, f0 = jacobian(x)
        if sqrt(dot(f0, f0) / len(x)) < err_tol:
            return x, jac
        # Solving jac*dx = -f0 directly is cheaper and better conditioned
        # than forming inv(jac).
        dx = -LA.solve(jac, f0)
        x = x + dx
        if sqrt(dot(dx, dx)) <= err_tol:
            return x, jac
    print('Too many iterations required to converge')
    return None



