# CNN_lattice_simulator.py
# This program leads the user through the creation, manipulation, and simulation of SETJ array dynamics
# This program depends on many other programs and files, but the real work-horse is TPL_CNN.py
# Import modules
	#
#import getopt,sys
import time
from random import random
from numpy import cos, sin, pi, array, mod, ones, zeros, remainder	
###################
## ** WARNING ** ## 	Do NOT use "from numpy import *" while running this program.  It will mess it up and give you a headache trying to figure out what is going on... trust me.
###################

from TPL_CNN import *	# This is the key module for the SETJ dynamics
try:
		# Disable error checking for increased performance
	from pyglet import options
	options['debug_gl'] = False
	from pyglet import window
	from pyglet.gl import *
except ImportError:
	raise ImportError, "PyGlet package required."
from MatrixIO import *	#includes txt2array(txtfile) & array2txt(TwoD_array,txt_outfile,a_or_w)
from PIL import Image

class parameters:
	"""Simulation settings shared between this script and the SETJ array.

	A single instance (the global ``param``) is handed to SETJarray and is
	edited interactively by change_param().
	"""

	# (attribute name, default value) pairs installed on every new instance.
	_DEFAULTS = (
		('binary', 0),			# phase output can either be binary (1) or grey scale (0)
		('randomIC', 0),		# initial tunneling voltage can be == u_distr (0), or random (1)
		('b', 2.),
		('gamma', 0.1),
		('epsilon', 0.1),
		('k', 0.1),
		('dt', 0.1),
		('capture_images', 0),		# 1 -> the main loop saves screenshots
		('capture_image_name', 'none'),	# file-name prefix for those screenshots
	)

	def __init__(self):
		for attr_name, default in self._DEFAULTS:
			setattr(self, attr_name, default)

	def refresh_param(self,M,N):
		"""Forward the current settings to the global MyArray.

		Must be called after any change so that MyArray responds to it; it
		passes M, N and this object to MyArray.new_param.
		"""
		MyArray.new_param(M,N,self)
		# TODO(review): unclear whether u / a and M / N also need refreshing here.

def change_param(M,N):
	"""Interactively show and edit the global simulation parameters.

	Lists the user-editable fields of the global ``param`` object with their
	current values, then keeps offering to change one of them until the user
	answers anything other than 'y'.  After every successful change
	``param.refresh_param(M,N)`` is called so that MyArray picks it up.

	M, N -- lattice size forwarded to param.refresh_param.
	Returns None.
	"""
	paramlist = ['binary','randomIC','b','gamma','epsilon','k','dt', 'capture_images']
	int_params = ('binary', 'randomIC', 'capture_images')	# stored as whole numbers
	while True:
		print('Here is a list of current parameters for the array:\n')
		for ref, name in enumerate(paramlist):		# display list of parameters, 1-based
			print('\t[%i] %s = %s\n' % (ref + 1, name, repr(getattr(param, name))))	# e.g. '\t[1] binary = 0\n'

		if raw_input('Do you want to change the current parameters? y/[n]  ') != 'y':
			return

		# raw_input + explicit conversion: Python 2's input() would eval the typed text.
		try:
			choice = int(raw_input('Please choose one of the previous parameters by reference number:  '))
		except ValueError:
			print('\nYou must choose an integer reference number from 1 to %i next time.\n' % len(paramlist))
			continue
		if not 1 <= choice <= len(paramlist):	# if choice is not in range
			print('\n"%i" is an invalid choice. You must choose an integer reference number from 1 to %i next time.\n' % (choice, len(paramlist)))
			continue

		name = paramlist[choice-1]
		try:
			new_value = float(raw_input('Please enter a new value for %s:  ' % name))
		except ValueError:
			print('\nThat is not a number; %s was left unchanged.\n' % name)
			continue
		if name in int_params:
			new_value = int(new_value)	# truncates, e.g. 1.7 -> 1
		setattr(param, name, new_value)
		if name == 'capture_images':
			param.capture_image_name = raw_input('Please enter the name for the new images (no spaces):  ')
		param.refresh_param(M,N)

		
param = parameters()		# global settings shared with MyArray and change_param

def _ask_lattice_dimension(prompt):
	"""Prompt until the user types a positive whole number, and return it.

	Uses raw_input + int() rather than Python 2's input(), which would
	evaluate the typed text as an arbitrary expression.
	"""
	while True:
		try:
			value = int(raw_input(prompt))
		except ValueError:
			print('\t\tPlease enter a whole number.')
			continue
		if value >= 1:
			return value
		print('\t\tThe array needs at least one row and one column.')

# Initialize array from TPL_CNN.py after asking the user for M, N.
print('Okay!  Looks like its time to create a new capacitively coupled M x N SETJ array!\n')
M = _ask_lattice_dimension('\tPlease enter the number of rows of SETJ cells desired:  ')
print('\n\t\tYou have set M =  %i' % M)
N = _ask_lattice_dimension('\tPlease enter the number of columns of SETJ cells desired:  ')
print('\n\t\tYou have set N =  %i' % N)
MyArray = SETJarray(M,N,param)	# create array
change_param(M,N)			# possibly change parameters

def choose_u():
	"""Let the user pick a bias-voltage (u) distribution and apply it.

	Shows the available distributions, reads a 1-based reference number
	(re-prompting until it is valid) and calls the matching
	``<name>_u_distr()`` function defined in this module.

	Returns whatever that distribution function returns.
	"""
	u_list = ['constant','random','checkerboard','square1','square2','image_of_Poincare','image_from_computer']	# add 'custom' and 'time varying'
	while True:
		print("\nNow let's choose from this list of available bias voltage distributions:\n")
		for ref, u_type in enumerate(u_list):		# display list of types available
			print('\t[%i] %s\n' % (ref + 1, u_type))		# e.g.  '\t[1] constant\n'
		# raw_input + int(): Python 2's input() would eval the typed text.
		try:
			uChoice = int(raw_input('\n'))
		except ValueError:
			print('\nYou must choose an integer reference number from 1 to %i next time.\n' % len(u_list))
			continue
		if 1 <= uChoice <= len(u_list):
			break
		print('\n"%i" is an invalid choice. You must choose an integer reference number from 1 to %i next time.\n' % (uChoice, len(u_list)))

	# Look the handler up by name instead of eval()-ing a constructed call string.
	distribution = globals()[u_list[uChoice-1] + '_u_distr']	# e.g. image_of_Poincare_u_distr
	return distribution()

# define bias voltage distributions

def image_of_Poincare_u_distr():
	"""Bias the lattice from the bundled Poincare image.

	Delegates to ImageProcessor without arguments, so its default file name
	('Poincare_image_medium.jpg') is used.
	"""
	ImageProcessor()	# no argument -> ImageProcessor's default image
	return None

def image_from_computer_u_distr():
	"""Bias the lattice from an image file chosen by the user.

	Prompts for the file path and hands the answer to ImageProcessor as-is.
	"""
	prompt = 'Please enter the name of the image file (with directory path, if not same as this program) exactly as it is saved (with extension):\t'
	ImageProcessor(raw_input(prompt))

def constant_u_distr():
	"""Bias every SETJ with the same user-supplied voltage u.

	Builds a fresh M x N array filled with the entered value (the same
	approach checkerboard_u_distr uses).  The result is therefore uniform
	regardless of what MyArray.u_current held before.  Re-prompts until the
	input parses as a number.
	"""
	while True:
		# raw_input + float(): Python 2's input() would eval the typed text.
		try:
			u_value = float(raw_input('Please input a constant at which you would like to bias each SETJ (probably use |u| < ~5):\t'))
			break
		except ValueError:
			print('Please enter a number.')
	MyArray.u_current = u_value*ones((M,N),float)
	
def random_u_distr():
	"""
	Bias each SETJ with an independent uniform random voltage in [0, 2).

	Overwrites the existing MyArray.u_current in place, cell by cell in
	row-major order, using the standard-library random() generator.
	"""
	print('u[i][j] will be distributed randomly from 0 to 2\n')
	bias = MyArray.u_current	# same array object, updated in place
	for row in xrange(M):
		for col in xrange(N):
			bias[row][col] = 2.0*random()

def checkerboard_u_distr():
	"""
	Bias neighbouring SETJs alternately with +2 and -2.

	Cell (i, j) of the global M x N lattice gets +2.0 when i+j is even and
	-2.0 when i+j is odd.
	"""
	u_scale_factor = 2.0
	board = u_scale_factor*ones((M,N),float)
	for i in xrange(M):
		for j in xrange(N):
			if (i + j) % 2 == 1:	# odd parity -> opposite sign
				board[i][j] = -u_scale_factor
	MyArray.u_current = board

def square1_u_distr():
	"""
	Re-size the array to 60 x 60 and bias it with a square "island".

	u[i][j] = 1.0 for 20 <= i <= 40 and 19 <= j <= 40
	u[i][j] = 1.77 otherwise
	i,j range from 0 to 59.
	NOTE(review): the island is 21 rows x 22 columns (rows start at 20,
	columns at 19); kept unchanged so existing runs stay reproducible.

	note that u = 1.0 is 8 pi periodic (4 stable cycles), u = 1.77 is 4 pi periodic (bistable)

	x_current starts as an independent copy of u_current; y_current = u_current / 1.78.
	"""
	MyArray.__init__(60,60,param)	# re-initialize array with correct size
	MyArray.u_current = 1.77*ones((60,60),float)	# background bias
	MyArray.u_current[20:41,19:41] = 1.0	# inner block of u = 1.0
	# copy(): a plain assignment would make x_current and u_current the same
	# ndarray, so any in-place update of x_current would also change the bias.
	MyArray.x_current = MyArray.u_current.copy()
	MyArray.y_current = MyArray.u_current / 1.78	# values in (0, 1)
	
def square2_u_distr():
	"""
	Re-size the array to 60 x 60 and bias it with a square "island".

	u[i][j] = 1.2 for 20 <= i <= 40 and 19 <= j <= 40
	u[i][j] = 1.0 otherwise
	i,j range from 0 to 59.
	NOTE(review): the island is 21 rows x 22 columns (rows start at 20,
	columns at 19); kept unchanged so existing runs stay reproducible.

	note that u = 1.0 has 4 stable cycles, u = 1.2 is 3 stable cycles

	x_current starts as an independent copy of u_current; y_current = u_current / 1.21.
	"""
	MyArray.__init__(60,60,param)	# re-initialize array with correct size
	MyArray.u_current = ones((60,60),float)	# background bias u = 1.0
	MyArray.u_current[20:41,19:41] = 1.2	# inner block of u = 1.2
	# copy(): a plain assignment would make x_current and u_current the same
	# ndarray, so any in-place update of x_current would also change the bias.
	MyArray.x_current = MyArray.u_current.copy()
	MyArray.y_current = MyArray.u_current / 1.21	# values in (0, 1)

def choose_output():
	"""Ask which state variable to display and return it with its colormap.

	Returns (name, colormap).  name is the MyArray attribute to display
	('x_current' or 'y_current') and colormap is the module-level function
	``<name>_ColorMap``.  Re-prompts until a valid reference number is entered.
	"""
	output_list = ['x_current','y_current',]
	while True:
		print("\nPlease choose one of the output types by reference number:\n")
		for ref, output_type in enumerate(output_list):		# display list of types available
			print('\t[%i] %s\n' % (ref + 1, output_type))		# e.g.  '\t[1] x_current\n'
		# raw_input + int(): Python 2's input() would eval the typed text.
		try:
			OutputChoice = int(raw_input('\n'))
		except ValueError:
			print('\nYou must choose an integer reference number from 1 to %i next time.\n' % len(output_list))
			continue
		if 1 <= OutputChoice <= len(output_list):
			break
		print('\n"%i" is an invalid choice. You must choose an integer reference number from 1 to %i next time.\n' % (OutputChoice, len(output_list)))

	OutputChoice = output_list[OutputChoice-1]
	output_colormap = globals()[OutputChoice + '_ColorMap']	# e.g. y_current_ColorMap
	return OutputChoice, output_colormap


def ImageProcessor(image_filename='Poincare_image_medium.jpg'):	# image_filename should be a string, in filepath
	"""Re-size MyArray to an image and derive its bias voltages from it.

	image_filename -- path of the image file to load.

	The grey level g of each pixel (band 0 scaled into [0, 1), see
	_grey_levels) is written to the text file 'greydata' and read back.
	MyArray is then re-initialised with the image's width x height and given:
	    u_current = 1.77 * g
	    x_current = a copy of u_current   (only when param.randomIC != 1)
	    y_current = g
	Keeping the current array size is not supported yet: answering 'n' reroutes
	to the default image.
	"""
	print('The SETJ array currently has the shape, %i x %i\n'  % (M,N))
	asker = raw_input('Do you want to use the original image size as array size? [y]/n  ')
	if asker == 'n':
		print("Sorry, can't do that until we fix a bug!  \nI'll reroute you to ImageProcessor(image_filename='Poincare_image_medium.jpg')\n")
		return ImageProcessor()

	im = Image.open(image_filename)
	m, n = im.size	# (width, height), e.g. (200,303) for Poincare_image.jpg
	image_array = _grey_levels(im)
	array2txt(image_array,'greydata','w')	# write color data to file
	MyArray.__init__(m,n,param)	# re-initialize array with correct size
	greydata = txt2array('greydata')
	scaleU = 1.77
	MyArray.u_current = scaleU * greydata	# give the array the correct u values
	if param.randomIC != 1:
		# copy(): keep x_current independent of u_current instead of aliasing it
		MyArray.x_current = MyArray.u_current.copy()
	# ranges from 0 to 1; the cheating way to see if the image is correct at the
	# beginning... though the current setup is no more meaningful for y0
	MyArray.y_current = greydata

def _grey_levels(im):
	"""Return band 0 of PIL image *im* as a float array of shape (width, height).

	Each value is pixel / (max_pixel + 1), so values lie in [0, 1) and a
	full-scale pixel is not wrapped to 0 by the mod() in the colormaps.  The
	flat pixel sequence is reversed before reshaping (a 180-degree rotation of
	the picture) and then transposed so that axis 0 runs over image columns.
	Works for single-band ('L') as well as multi-band (e.g. RGB) images.
	"""
	band = im.split()[0]	# first band; a single-band image has only this one
	maxpixels = band.getextrema()[1]	# (min, max) for a single band
	width, height = im.size
	pixels = array(list(band.getdata()))[::-1]	# reversed pixel order, O(n)
	return (pixels.reshape((height, width)) / float(maxpixels+1.)).transpose()
#	

choose_u()	# choose the bias voltage distribution (some choices re-size MyArray)
# Always re-read the lattice size: the square and image distributions
# re-initialise MyArray with a new shape.
M, N = MyArray.shape

    # Colormap: Site value in [0,1] to RGB color
	#
def y_current_ColorMap(x):
	"""Linear blue-to-yellow colormap for y_current values.

	x is wrapped into [0, 1) with numpy's mod, which also guards against
	round-off just outside the unit interval.  Returns (r, g, b) = (x, x, 1-x):
	0 maps to blue and values near 1 map to yellow.  A linear map (rather than
	a circular one) shows values at opposite ends of the range as clearly
	different colours.
	"""
	level = mod(x, 1.0)
	red_green = level
	blue = 1.0 - level
	return red_green, red_green, blue

def x_current_ColorMap(x):
	"""Linear blue-to-yellow colormap for x_current values.

	Rescales x via (x + pi) / (2*pi), so that [-pi, pi] maps onto [0, 1].
	The assumption that x lies in [-pi, pi] comes from the original comment
	and should be confirmed.  The result is wrapped into [0, 1) with numpy's
	mod and returned as (r, g, b) = (v, v, 1-v).
	"""
	fraction = mod((x + pi) / (2*pi), 1.0)
	return fraction, fraction, 1.0 - fraction

    # Circular Colormap: Site value in [0,1] to RGB color periodically
	#
def CircularColorMap(x):
	"""Map a site value to an RGB colour periodically, with period 1 in x.

	Red follows cos(2*pi*x) and blue follows sin(2*pi*x), each rescaled into
	[0, 1].  The (red, blue) pair therefore goes once round a circle per unit
	of x: every value in one period gets its own colour, and x and x+1 share
	a colour.  The previous sin-only mapping gave x and 0.5-x the same colour.
	Green is fixed at 0.6.

	Returns (r, g, b).
	"""
	angle = 2.0*pi*x
	red = 0.5*(1.0 + cos(angle))
	blue = 0.5*(1.0 + sin(angle))
	return red, 0.6, blue

def SpDisplayState(state,M,N,CellSize):	# changed nSites to M,N -- will have to change elsewhere
	"""Draw an M x N grid of filled squares coloured by the global ColorMap, then flip.

	state    -- nested sequence indexed state[i][j].
	CellSize -- edge length of one square in pixels; cell (i, j) covers
	            x in [i*CellSize, (i+1)*CellSize) and y in [j*CellSize, (j+1)*CellSize).
	"""
	for row in range(M):
		left, right = row*CellSize, (row+1)*CellSize
		values = state[row]
		for col in range(N):
			red, green, blue = ColorMap(values[col])
			glColor3f(red, green, blue)
			glRecti(left, col*CellSize, right, (col+1)*CellSize)
	screen.flip()
	#screen.flip()
## Choose output type and corresponding color map
state, ColorMap = choose_output()

# Pixel size of one SETJ cell: 30 px for lattices with fewer than 10 cells,
# 3 px when either side exceeds 200 cells, 12 px otherwise.
if M*N < 10:
	CellSize = 30
elif M > 200 or N > 200:
	CellSize = 3
else:
	CellSize = 12

# Iterations between speed reports in the main loop.
if M*N > 100:
	nSteps = 10
else:
	nSteps = 100
#nSites   = 47
#randomIC = 1	# binary 0 or 1 to indicate whether the ICs should be random or not
    # Default system: Dripping Handrail
#LDSIterate = DHRIterate
#ColorMap = SimpleColorMap	# chosen above, will depend on output choice

	# Get command line arguments, if any
	#
# getopt/sys are imported here because the file's own "import getopt,sys"
# line is commented out; re-importing is harmless if they are already bound.
import getopt
import sys
try:
	opts,args = getopt.getopt(sys.argv[1:],'t:n:s:ICLr:c:')
except getopt.GetoptError:
	# Don't throw away the interactive setup above over a mistyped flag.
	print('Unrecognised command-line option; ignoring all command-line arguments.')
	opts,args = [],[]
for key,val in opts:
	if key == '-t': nSteps   = max(1, int(val))	# iterations between speed reports (>= 1: used with %)
	if key == '-s': CellSize = int(val)	# pixel size of one SETJ cell
	if key == '-I':
		randomIC = 0
		# NOTE(review): choose_u() already ran above, so this only affects a later re-initialisation.
		param.randomIC = 0
	if key == '-C': ColorMap = CircularColorMap
	# -n, -L, -r and -c are accepted by getopt but currently ignored.
	
def get_screen_shot(run_name,iteration,capture_at=(1,100,250,500,1000,10000,20000)):
	"""Save the window contents as '<run_name><iteration>.png' at selected iterations.

	run_name   -- file-name prefix (the main loop passes param.capture_image_name).
	iteration  -- current iteration number.
	capture_at -- iterations at which a screenshot is written; the default is
	              the fixed schedule 1, 100, 250, 500, 1000, 10000, 20000.
	Any other iteration is ignored.  Returns None.
	"""
	if iteration not in capture_at:
		return
	# Bind the pyglet package explicitly: this file never does "import pyglet",
	# only imports names from it.
	import pyglet.image
	pyglet.image.get_buffer_manager().get_color_buffer().save('%s%d.png' % (run_name, iteration))
		
size = (CellSize*M,CellSize*N)	# window size in pixels

# Get display surface
screen = window.Window(CellSize*M,CellSize*N,
	caption = '2D SETJ Lattice Dynamical System Simulator',resizable = False)
screen.set_location(100,50)
screen.clear()

iteration_counter = 0
# time.time() (wall clock) instead of time.clock(): on Unix time.clock() reports
# CPU time, which excludes time spent waiting on the display and inflates the rate.
t0 = time.time()
while not screen.has_exit:
	screen.dispatch_events()
	# Display the chosen state variable, then advance the lattice one step.
	SpDisplayState(getattr(MyArray, state).tolist(),M,N,CellSize)
	MyArray.evolvearray(M,N)
	iteration_counter += 1
	if param.capture_images == 1:
		get_screen_shot(param.capture_image_name,iteration_counter)
	if (iteration_counter % nSteps) == 0:
		t1 = time.time()
		elapsed = t1 - t0
		if elapsed > 0:	# a coarse clock can report zero elapsed time
			print('Iterations per second:  %s' % (float(nSteps) / elapsed))
		print('Current SETJ time (normalized to pump cycles):  %s' % MyArray.t)
		t0 = t1
