# OneDMapClasses.py

# import in a try-except statement in case a module is not installed
try:
	import numpy as np
	import matplotlib.pyplot as plt

# Catch any import exception and report it instead of aborting the import of
# this module. The "as" form and the single-argument parenthesised print are
# valid on both Python 2.6+ and Python 3 and print the same text on each.
except ImportError as msg:
	print('Error importing modules: %s' % msg)

class OneDimMap(object):
	"""
	Base class for all map classes. This provides
	methods common to all map classes.

	Attributes set by resetVals():
	ic    -- initial condition used to restart the map
	state -- current value of the iterated variable
	r     -- current value of the control parameter
	rLow  -- lower limit of the control parameter sweep
	rHi   -- upper limit of the control parameter sweep

	Subclasses override func() and replace self.defaults / self.mapStr.
	"""

	def __init__(self):
		"""Constructor for the OneDimMap base class."""

		# list of defaults in this order: [ ic, r low, r high ]
		self.defaults = [0.0, 0.0, 1.0]

		self.resetVals()	# set map values to defaults

		self.mapStr = 'One Dimensional Map f(x) = x'	# place holder for map string

	def __repr__(self):
		"""String representation of map."""

		return self.mapStr	# defined in __init__

	def func(self, x):
		"""Returns f(x). This is a place holder for all map classes."""

		return x	# in this case, f(x) = x

	def iterate(self, num_its):
		"""
		Iterates the map num_its times and saves the state.

		The result of the last iteration is left in self.state; nothing
		is returned. With num_its == 0 the state is left untouched.
		"""

		# range (not xrange) so the loop runs on both Python 2 and Python 3
		for _ in range(num_its):
			self.state = self.func(self.state)	# compute the map's function

	def plotBifn(self):
		"""
		Generate and display a bifurcation plot of the map.

		Sweeps the control parameter over [rLow, rHi) and, for each value,
		restarts from self.ic, discards the transient iterates and plots
		the following iterates against r.

		Side effects: self.r and self.state are overwritten during the
		sweep and keep the values from the last parameter step.
		"""

		plt.figure()
		# Title for plot
		plt.title(self.mapStr + ', bifurcation diagram for r in [%g,%g]'
			% (self.rLow, self.rHi))
		# Label axes
		plt.xlabel('Control parameter r')
		plt.ylabel('{X(n)}')

		# The iterates we'll throw away
		nTransients = 200
		# This sets how much the attractor is filled in
		nIterates = 250
		# This sets how dense the bifurcation diagram will be
		nSteps = 400

		# Sweep the control parameter over the desired range. linspace with
		# endpoint=False yields exactly nSteps evenly spaced values starting
		# at rLow and excluding rHi; arange with a float step can produce one
		# sample too many and fails outright when rLow == rHi (zero step).
		rValues = np.linspace(self.rLow, self.rHi, nSteps, endpoint=False)

		for r in rValues:
			self.r = r	# func() of the subclasses reads self.r

			# Set the initial condition to the reference value
			self.state = self.ic
			# Throw away the transient iterations
			self.iterate(nTransients)
			# Now store the next batch of iterates
			rsweep = np.repeat(self.r, nIterates)     # array of parameter values
			x = [ ]                                   # The iterates
			for _ in range(nIterates):
				self.iterate(1)
				x.append( self.state )
			plt.plot(rsweep, x, 'k,') # Plot the list of (r,x) pairs as pixels

		plt.show()

	def resetVals(self):
		"""Set map values to their respective defaults."""

		# self.defaults is ordered [ ic, r low, r high ]
		ic, rLow, rHi = self.defaults

		# both the initial condition and the current state restart at ic
		self.ic = ic
		self.state = ic
		# the control parameter starts at the low end of its range
		self.r = rLow
		self.rLow = rLow
		self.rHi = rHi

## end class OneDimMap ##

class LogisticMap(OneDimMap):
	"""The logistic map class: f(x) = r*x*(1-x)."""

	def __init__(self):
		"""Constructor for a LogisticMap object."""

		OneDimMap.__init__(self)	# invoke constructor of base class

		# list of defaults in this order: [ ic, r low, r high ]
		self.defaults = [0.3, 0.0, 4.0]

		# the base constructor already applied the base-class defaults;
		# apply the logistic-map defaults set just above
		self.resetVals()

		# descriptive string of map
		self.mapStr = 'Logistic Map: f(x) = r*x*(1-x)'

	def func(self, x):
		"""
		Computes the logistic map.

		Returns r*x*(1-x) for the given x, using the current control
		parameter self.r. The value is computed from the argument rather
		than from self.state, so func() can be evaluated at any point;
		iterate() passes self.state as x.
		"""

		return self.r * x * (1.0 - x)

## end class LogisticMap ##

class CosineMap(OneDimMap):
	"""The cosine map class: f(x) = r*cos(x)."""

	def __init__(self):
		"""Constructor for a CosineMap object."""

		OneDimMap.__init__(self)	# invoke constructor of base class

		# list of defaults in this order: [ ic, r low, r high ]
		self.defaults = [0.3, -5.0, 5.0]

		# the base constructor already applied the base-class defaults;
		# apply the cosine-map defaults set just above
		self.resetVals()

		# descriptive string of map
		self.mapStr = 'Cosine Map: f(x)=r*cos(x)'

	def func(self, x):
		"""
		Computes the cosine map.

		Returns r*cos(x) for the given x, using the current control
		parameter self.r. The value is computed from the argument rather
		than from self.state, so func() can be evaluated at any point;
		iterate() passes self.state as x.
		"""

		return self.r * np.cos(x)

## end class CosineMap ##
