"""
Ryan James
Physics 250
Non-Linear Dynamics and Chaos
Final Project

Phase-modulated logistic map.  Perhaps of interest: I wrote this code while
listening to the Battlehooch EP -- does the genre of music influence the way I
code?
"""

from __future__ import division, with_statement

from random import uniform

from enthought.traits.api import HasTraits, Range, Function, Property, Type

from numpy import sign, log, abs, sum, linspace, nan, inf, zeros, array
from numpy import histogram, hstack, mean, vectorize

from pylab import cla, title, xlabel, ylabel, grid, plot, show, imshow, gray

class PMLM(HasTraits):
	"""
	Simple map class for the phase modulated logistic map.  Keeps track of
	its state (x0, x1) and can evaluate its derivative at any state.
	Implements the iterator protocol (``next`` for Python 2, ``__next__`` for
	Python 3) for simplicity.

	One step maps (x0, x1) -> (a*x0*(1 - x0), x0), where

		a = r0 + (4 - r0)*r1*sign(x1 - x0)

	x1 holds the previous x0, so sign(x1 - x0) is +1 after x0 has just
	decreased, -1 after it has just increased and 0 when it is unchanged.
	With r0 in [2, 4] and r1 in [0, 1], a stays within [2*r0 - 4, 4], a
	subset of [0, 4], so x0 never leaves [0, 1].
	"""

	x0 = Range(0.0, 1.0)
	x1 = Range(0.0, 1.0)

	r0 = Range(2.0, 4.0, 4.0)
	r1 = Range(0.0, 1.0, 0.0)

	def _x0_default(self):
		"""
		Default to a random value in the domain.
		"""
		return uniform(0, 1)

	def _x1_default(self):
		"""
		Default to the same value as x0 -- note, this ensures that the first
		iteration behaves exactly like the standard logistic map since
		sign(x1 - x0) will be 0.
		"""
		return self.x0

	def _growth_rate(self, x0, x1):
		"""
		Effective logistic parameter a = r0 + (4 - r0)*r1*sign(x1 - x0) at
		state (x0, x1).  Shared by f0 and dx so the map and its derivative
		always use the same modulation.
		"""
		return self.r0 + (4 - self.r0)*self.r1*sign(x1 - x0)

	def f0(self, x0, x1):
		"""
		Iteration function for x0: a logistic step with the modulated rate.
		"""
		return self._growth_rate(x0, x1)*x0*(1 - x0)

	def f1(self, x0, x1):
		"""
		Iteration function for x1: remember the current x0 (a one-step delay).
		"""
		return x0

	def dx(self, point):
		"""
		d f0 / d x0 at point = (x0, x1).

		Away from x0 == x1 (where sign() is piecewise constant) the Jacobian
		of (f0, f1) is [[a*(1 - 2*x0), 0], [1, 0]].  It is lower triangular,
		so this diagonal entry is the only one that drives the Lyapunov
		exponent.
		"""
		# Explicit unpacking replaces the tuple parameter that Python 3
		# removed (PEP 3113).  Callers still pass a single (x0, x1) tuple.
		x0, x1 = point
		return self._growth_rate(x0, x1)*(1 - 2*x0)

	def __iter__(self):
		"""
		Support that iterator protocol.
		"""
		return self

	def next(self):
		"""
		Advance the map one step and return the new state (x0, x1).
		"""
		# both right-hand sides are evaluated before either trait is assigned
		self.x0, self.x1 = self.f0(self.x0, self.x1), self.f1(self.x0, self.x1)
		return self.x0, self.x1

	# Python 3 name for the same iterator method.
	__next__ = next

	def trajectory(self, length):
		"""
		Advance the map `length` steps and return the visited states as a
		list of (x0, x1) tuples.  The starting state is not included.
		"""
		return [ self.next() for _ in range(length) ]

class Project(HasTraits):
	"""
	A class to hold all the nice little tools needed to analyse the pmlm:
	Lyapunov exponents along a line or over a rectangle of parameter space,
	and bifurcation diagrams along either control parameter.  Plots are
	drawn on the current pylab axes.
	"""

	# keep the constructor for the map at hand.
	map = Type(PMLM)

	def lyapunov_exp(self, r0 = 4.0, r1 = 0.0, accuracy = 500, transients = 200):
		"""
		Return the lyapunov exponent for particular values of r0, r1.

		A fresh map from self.map is iterated `transients` times.  Those
		states are discarded, and the exponent is the mean of log|dx| over the
		next `accuracy` states.  Return nan instead of -inf to aid with
		matplotlib since it will silently drop nans when plotting, whereas
		-inf will screw the scale of any plot.
		"""
		pmlm = self.map(r0 = r0, r1 = r1)
		pmlm.trajectory(transients)  # burn-in only; the states are unused
		traj = pmlm.trajectory(accuracy)
		# Build a real list rather than using the builtin map(), which is a
		# lazy iterator on Python 3 that numpy cannot take abs() of.
		rates = array([ pmlm.dx(point) for point in traj ])
		l = mean(log(abs(rates)))
		# `l > -inf` is False for both -inf and nan, so either becomes nan.
		return l if l > -inf else nan

	def _lyapunov(self, low = 3.0, high = 4.0, res = 2**10, r1 = 0.0):
		"""
		Return a line of lyapunov exponents -- purely a utility for the next
		method.  Returns (rs, ls): `res` evenly spaced r0 values in
		[low, high] and the exponent at each of them.
		"""
		rs = linspace(low, high, res)
		ls = array([ self.lyapunov_exp(r0 = r, r1 = r1) for r in rs ])
		return rs, ls

	def lyapunov(self, low = 3.0, high = 4.0, res = 2**10, r1 = 0.0):
		"""
		Plot the lyapunov exponent dynamics for a particular range of r0 and at
		a particular value of r1.
		"""
		rs, ls = self._lyapunov(low, high, res, r1)
		cla()
		title(r'Liapunov Exponent Dynamics for $r_1 = %1.3f$' % r1)
		xlabel(r'Control Parameter $r_0$')
		ylabel(r'Lyapunov Exponent $\lambda$')
		grid()
		plot(rs, ls)
		show()

	def chaos_map(self, r0 = (3.0, 4.0), r1 = (0.0, 1.0), res = (2**9, 2**9)):
		"""
		Create a full image of the lyapunov exponents in parameter space.

		r0 and r1 are (low, high) ranges.  res = (rows, columns): rows step
		through r1 and columns step through r0.  The image is drawn with
		origin='lower' so that row 0 (r1 = r1[0]) sits at the bottom of the
		y axis.  With the default origin='upper' it would land at the top,
		flipping the image against its extent and labels.
		"""
		img = zeros(res)
		for i, r_1 in enumerate(linspace(r1[0], r1[1], res[0])):
			for j, r_0 in enumerate(linspace(r0[0], r0[1], res[1])):
				img[i, j] = self.lyapunov_exp(r_0, r_1)
		cla()
		title(r'Lyapunov Exponent Space')
		xlabel(r'Control Parameter $r_0$')
		ylabel(r'Control Parameter $r_1$')
		imshow(img, extent = tuple(r0) + tuple(r1), origin = 'lower')

	def _bifurcation(self, maps, bins, extent, caption, label, transients = 100):
		"""
		Shared body of bifurcation0 / bifurcation1: render one image column
		per map in `maps`.

		Each column is a histogram of that map's x0 orbit (10*bins samples
		after `transients` discarded steps) over `bins` equal bins on [0, 1].
		It is scaled so its fullest bin is 1, then inverted (1 - value) so
		that frequently visited values come out dark under the gray colormap.
		The image is drawn with origin='lower' so that row 0 (x0 near 0) sits
		at the bottom of `extent`.
		"""
		edges = linspace(0, 1, bins + 1)
		columns = []
		for pmlm in maps:
			orbit = pmlm.trajectory(10*bins + transients)[transients:]
			# x1 is just x0 delayed by one step, so histogram x0 alone rather
			# than counting (almost) every orbit value twice.
			counts = histogram([ x0 for x0, x1 in orbit ], edges)[0]
			# Density normalisation would be undone by dividing by the peak,
			# so plain counts are used (current numpy rejects `new`/`normed`).
			columns.append(counts.reshape((bins, 1))/float(counts.max()))
		cla()
		title(caption)
		xlabel(label)
		ylabel(r'Orbit Values')
		gray()
		imshow(1 - hstack(columns), extent = extent, origin = 'lower')

	def bifurcation0(self, low = 3.0, high = 4.0, res = (1024, 768), r1 = 0.0):
		"""
		Create a bifurcation diagram along r0 at fixed r1.
		res = (number of r0 values, number of orbit-value bins).
		"""
		maps = [ self.map(r0 = r, r1 = r1) for r in linspace(low, high, res[0]) ]
		self._bifurcation(maps, res[1], (low, high, 0, 1),
			r'Bifurcation Diagram for $r_1 = %1.3f$' % r1,
			r'Control Parameter $r_0$')

	def bifurcation1(self, low = 0.0, high = 1.0, res = (1024, 768), r0 = 3.0):
		"""
		Create a bifurcation diagram along r1 at fixed r0.
		res = (number of r1 values, number of orbit-value bins).
		"""
		maps = [ self.map(r0 = r0, r1 = r) for r in linspace(low, high, res[0]) ]
		self._bifurcation(maps, res[1], (low, high, 0, 1),
			r'Bifurcation Diagram for $r_0 = %1.3f$' % r0,
			r'Control Parameter $r_1$')