from scipy import arange
import matplotlib.pyplot as plt

# This plots a bifurcation diagram of a discretized map as a function of either the control parameter or the grid scaling parameter

def _orbit(dmap, ic, r, N, iterations, transients):
    """Return ``iterations`` post-transient iterates of ``dmap`` at (r, N).

    One call with mode 'reset' advances ``transients`` steps from ``ic`` and
    its result is discarded (the transient iterates are 'skipped'); then each
    of ``iterations`` calls with mode 'hold' advances one step and its return
    value is recorded.
    NOTE(review): presumably 'reset' restarts from ``ic`` and 'hold' continues
    from the map's stored state -- confirm against the map class.
    """
    dmap.iterate(ic, r, N, transients, 'reset')
    return [dmap.iterate(ic, r, N, 1, 'hold') for _ in range(iterations)]


def BifnPlot(map, r, N, ic, iterations, transients):
    """Plot a bifurcation (orbit) diagram of a discretized map.

    The diagram is drawn versus the control parameter ``r`` or the grid
    scaling parameter ``N``, depending on which one is given as a range.

    Parameters
    ----------
    map : object
        Must provide ``iterate(ic, r, N, steps, mode)`` (its return value is
        the plotted x) and a string attribute ``titlestring`` used in titles.
    r : real number or list/tuple
        A number fixes r.  ``[rmin, rmax]`` sweeps r over
        ``arange(rmin, rmax, (rmax - rmin) / 128)``; ``[rmin, rmax, rstep]``
        uses ``rstep`` instead.
    N : real number or list/tuple
        A number fixes N.  ``[Nmin, Nmax]`` sweeps N over the rounded bounds in
        steps of 1; ``[Nmin, Nmax, Nstep]`` uses the rounded integer ``Nstep``
        (at least 1).
    ic :
        Initial condition, passed through to ``map.iterate``.
    iterations : int
        Number of iterates plotted per parameter value.
    transients :
        Number of iterates discarded (passed to ``map.iterate`` in 'reset'
        mode) before plotting.

    Draws onto the current pyplot figure and sets title/labels; it does not
    call ``plt.show()`` -- the caller is responsible for displaying it.

    Raises
    ------
    NotImplementedError
        If both ``r`` and ``N`` are ranges (3D diagrams are not supported).
    TypeError
        If ``r`` or ``N`` is neither a real number nor a list/tuple.
    """
    import numbers
    # numpy.arange instead of scipy's deprecated numpy alias.
    import numpy as np

    r_is_range = isinstance(r, (list, tuple))
    N_is_range = isinstance(N, (list, tuple))
    r_is_num = isinstance(r, numbers.Real)
    N_is_num = isinstance(N, numbers.Real)

    if r_is_range and N_is_range:
        raise NotImplementedError('3D bifurcation plots not yet supported.')

    if r_is_num and N_is_num:
        # Orbit diagram for fixed (r, N): every point shares abscissa r, so
        # x must be expanded to match the length of the y values.
        xvals = _orbit(map, ic, r, N, iterations, transients)
        plt.plot([r] * len(xvals), xvals, 'b.', ms=.01)
        plt.title(map.titlestring + ' r=' + repr(r) + ', N=' + repr(N))
        plt.xlabel('Control Parameter r')

    elif r_is_range and N_is_num:
        # Use the supplied increment if present, otherwise 128 steps.
        if len(r) <= 2:
            rinc = float(r[1] - r[0]) / 128.
        else:
            rinc = r[2]

        rvals, xvals = [], []
        for rval in np.arange(r[0], r[1], rinc):
            orbit = _orbit(map, ic, rval, N, iterations, transients)
            rvals.extend([rval] * len(orbit))
            xvals.extend(orbit)

        # Single plot call for the whole diagram (same visual result as one
        # call per parameter value, far fewer artists).
        plt.plot(rvals, xvals, 'b.', ms=.01)
        plt.title(map.titlestring + ' N=' + repr(N))
        plt.xlabel('Control Parameter r')

    elif r_is_num and N_is_range:
        # N counts grid points, so the step is an integer (default 1); a
        # user-supplied third element is honoured, rounded and kept >= 1.
        if len(N) <= 2:
            Nstep = 1
        else:
            Nstep = max(1, int(round(N[2])))

        Nvals, xvals = [], []
        for Nval in np.arange(round(N[0]), round(N[1]), Nstep):
            orbit = _orbit(map, ic, r, Nval, iterations, transients)
            Nvals.extend([Nval] * len(orbit))
            xvals.extend(orbit)

        plt.plot(Nvals, xvals, 'b.', ms=.01)
        plt.title(map.titlestring + ' r=' + repr(r))
        plt.xlabel('Grid points N')

    else:
        raise TypeError(
            'r and N must each be a real number or a list/tuple range; got '
            + type(r).__name__ + ' and ' + type(N).__name__)

    plt.ylabel('x')