from pylab import *
from Numeric import *
from visual import vector
from Agent import *

def TwoDProjection(x):
    """Map a 3-component point onto 2-D plotting coordinates.

    The linear map sends the unit vectors (1,0,0), (0,1,0), (0,0,1) to the
    corners (-1/2, 0), (1/2, 0), (0, sqrt(3)/2) of an equilateral triangle
    with unit side. A point on the probability simplex (a convex combination
    of the unit vectors) therefore lands inside that triangle.

    Returns ``dot(M, x)`` for the 2x3 projection matrix ``M``; for a
    length-3 ``x`` the result is a length-2 array ``(horizontal, vertical)``.
    """
    height = sqrt(3.) / 2.
    rows = [[-0.5, 0.5, 0.0],
            [0.0, 0.0, height]]
    return dot(array(rows), x)

# Step size passed to each agent's RK4 integrator.
dt = .1

# Number of initial iterates to throw away (transients) before plotting.
nTransients = 40000

# Total number of iterates to integrate.
nIterates = 50000

# Fail fast: a bad setting would otherwise only surface as an IndexError on
# an empty array after the whole (long) integration has already run.
if not 0 <= nTransients < nIterates:
    raise ValueError("need 0 <= nTransients < nIterates, got %d and %d"
                     % (nTransients, nIterates))

# Tiebreak parameter for each agent's reward matrix (between -1 and 1).
e = [0.5, -0.5]

# Memory loss rate for each agent (0 = perfect memory).
alpha = [0.0198, .01]

# Adaptation rate for each agent.
beta = [1.0, 1.0]


def RewardMatrix(eps):
    """Return the 3x3 cyclic reward matrix for tiebreak parameter ``eps``.

    The diagonal (tie) entries come from payoff ``eps``, the off-diagonal
    entries from payoffs +1 and -1 arranged cyclically. ``eps/3`` is
    subtracted from every entry, so each row and each column sums to zero.
    The arithmetic is written exactly as in the original literal matrices,
    so the results are bit-identical.
    """
    diag = 2./3 * eps          # eps - eps/3
    win = 1 - 1./3 * eps       # +1 payoff, shifted
    loss = -1 - 1./3 * eps     # -1 payoff, shifted
    return array([[diag, win, loss],
                  [loss, diag, win],
                  [win, loss, diag]])


rewardX = RewardMatrix(e[0])
rewardY = RewardMatrix(e[1])

# One learning agent per player (Agent is provided by the project's Agent
# module; constructor arguments are memory loss, adaptation rate, rewards).
X = Agent(alpha[0], beta[0], rewardX)
Y = Agent(alpha[1], beta[1], rewardY)


def MixedStrategy(q):
    """Return the normalized distribution exp(-q) / sum(exp(-q)).

    This inverts the ``q = -log(p)`` encoding used for the initial
    conditions below: MixedStrategy(-log(p)) recovers p (up to rounding)
    for any probability vector p.
    """
    w = exp(-q)
    return w / sum(w)


# Initial conditions, stored as u = -log(x) and v = -log(y).
u = -log(vector(0.5, 0.2, 0.3))
v = -log(vector(0.1, 0.5, 0.4))

# Trajectories of mixed strategies. Slot 0 must hold the initial strategies:
# step i feeds each agent the opponent's strategy from step i-1, and the
# former placeholder vector(0,0,0) made the first step see an "opponent
# strategy" that is not a probability distribution at all.
xVals = [MixedStrategy(u)]
yVals = [MixedStrategy(v)]

for _ in xrange(1, nIterates):
    # Simultaneous update: both agents react to the other's previous
    # strategy (neither list is extended until both RK4 calls are done).
    # NOTE(review): presumably RK4(opponent, state, dt) advances the agent's
    # state by one step of size dt -- confirm against Agent.RK4.
    u = X.RK4(yVals[-1], u, dt)
    v = Y.RK4(xVals[-1], v, dt)
    xVals.append(MixedStrategy(u))
    yVals.append(MixedStrategy(v))

# Project the post-transient part of each trajectory onto the plane. Each
# row of XTrans / YTrans is the 2-D point TwoDProjection returns for one
# iterate, so column 0 is the horizontal and column 1 the vertical coordinate.
XTrans = array([TwoDProjection(p) for p in xVals[nTransients:nIterates]])
YTrans = array([TwoDProjection(p) for p in yVals[nTransients:nIterates]])

# Horizontal shift that draws Y's panel to the right of X's panel; used for
# both Y's trajectory and Y's triangle so the two always stay aligned.
offset = 1.

# Plot the post-transient trajectories: X on the left, Y shifted right.
plot(XTrans[:,0], XTrans[:,1])
plot(YTrans[:,0] + offset, YTrans[:,1])

# Outline the triangular phase space under each trajectory. The vertices are
# the images of the unit vectors under TwoDProjection, closed back to the start.
triangleX = [-.5, 0., .5, -.5]
triangleY = [0., .5*sqrt(3.), .0, .0]
plot(triangleX, triangleY)
plot([t + offset for t in triangleX], triangleY)

show()
