#ANALYSIS TOOLS
from numpy import *
from pylab import *
from Scientific.Statistics.Histogram import Histogram
from Scientific.Functions.LeastSquares import leastSquaresFit


def A_highlow(nSites,massHigh,massLow,state):
	"""Report what share of lattice sites lie above massHigh / below massLow.

	nSites   -- side length of the square lattice; state is read as
	            state[i][j] for 0 <= i, j < nSites.
	massHigh -- sites with mass strictly greater than this are counted "high".
	massLow  -- sites with mass strictly less than this are counted "low"
	            (only tested for sites that are not already "high").
	state    -- 2-D indexable of per-site masses.

	Prints the threshold values and the percentage of all nSites**2 sites in
	each category. Returns (numHigh, numLow); the original returned None, so
	existing callers are unaffected.
	"""
	numHigh = 0
	numLow = 0
	for i in range(nSites):
		for j in range(nSites):
			mass = state[i][j]
			if mass > massHigh:
				numHigh += 1
			elif mass < massLow:
				numLow += 1

	# The original printed massHigh/nSites**2 as the threshold and the raw
	# count under a "Percent" label; print the threshold and a real
	# percentage instead. 100.0 forces true division under Python 2.
	totalSites = nSites ** 2
	if totalSites > 0:
		pctHigh = 100.0 * numHigh / totalSites
		pctLow = 100.0 * numLow / totalSites
	else:
		pctHigh = pctLow = 0.0

	print("Final state results:")
	print("Percent of sites with mass > " + str(massHigh) + ":  " + str(pctHigh))
	print("Percent of sites with mass < " + str(massLow) + ":  " + str(pctLow))
	return numHigh, numLow

def A_hist():
	"""Draw the mass-density histogram of every saved state as a pygame image.

	Row k of the window shows the histogram of saved_states[k], so time runs
	down the y axis. Each column is one of numBins histogram bins (x axis).
	Each cell's gray level is int((1 - hg[i][1]/(nSites*nSites))*255),
	clamped to [0, 255], so fuller bins are drawn darker.

	NOTE(review): this reads names that this file's visible imports do not
	provide: saved_states, numBins, CellSize, OffColor, nSites, pygame and
	surfarray. Presumably the caller injects them as globals; confirm before
	reuse.
	"""
	num_hgs = len(saved_states)
	# float() forces true division under Python 2. If the histogram counts
	# are integers, integer division would make almost every cell white.
	totalSites = float(nSites * nSites)

	width = numBins*CellSize
	height = num_hgs*CellSize

	#Display histogram data...
	#time going down along y axis,
	#different bins along x axis,
	#grayscale to indicate fullness of bins
	screen = pygame.display.set_mode( (width,height) )
	pygame.display.set_caption('mass density (x) over time (y)')

	# Create RGB array whose elements refer to screen pixels
	# Strangeness: sptmdiag is no longer used, but if this line
	# is removed, then display updates slow down considerably!
	sptmdiag = surfarray.pixels3d(screen)

	ConfigRect2 = pygame.Rect(0,0,width,height)
	screen.fill(OffColor,ConfigRect2)

	for k in range(num_hgs):
		hg = Histogram(ravel(saved_states[k]), numBins)
		for i in range(numBins):
			# hg[i][1] is used as the occupancy of bin i (it is normalised by
			# the total site count). Presumably it is the bin count; verify
			# against the Histogram API.
			fc = int( ( 1 - hg[i][1]/totalSites )*255 )	#calculate grayscale value
			fc = max(0, min(255, fc))
			fill_Color = fc,fc,fc
			screen.fill(fill_Color,[i*CellSize,k*CellSize,CellSize,CellSize])
	pygame.display.flip()


def A_make_xlist(nIter,nSites,add_total,settle_val=0.001):
	"""Build an x-axis index list for add_total and find its settle time.

	nIter      -- iteration count, used only in the "not reached" message.
	nSites     -- lattice side length; add_total values are normalised by
	              nSites**2 before comparing against settle_val.
	add_total  -- per-iteration totals.
	settle_val -- per-site threshold (default 0.001, as before).

	Returns (xlist, settle_time):
	  xlist       -- [0, 1, ..., len(add_total)-1]
	  settle_time -- the first index i with add_total[i]/nSites**2 < settle_val,
	                 or the string "not reached in <nIter>" if there is none.
	"""
	# Use float so the per-site ratio is not truncated by Python 2 integer
	# division. Otherwise an int total below nSites**2 reads as 0 and counts
	# as settled immediately.
	nSitesSq = float(nSites) ** 2
	xlist = list(range(len(add_total)))
	for i, total in enumerate(add_total):
		if total / nSitesSq < settle_val:
			return xlist, i
	return xlist, "not reached in " + str(nIter)

def A_settle(xlist,add_total,settle_time):
	"""Plot add_total against xlist, show the figure, then print the settle time.

	The message is printed only after show() returns, as in the original.
	"""
	plot(xlist, add_total)
	show()

	report = "Settle time was " + str(settle_time) + " iterations"
	print(report)


def A_fit_exponential(xlist,add_total):
	"""Least-squares fit add_total ~ a*exp(-b*x), then plot the fit and the data.

	xlist     -- x values, paired element-wise with add_total.
	add_total -- y values to fit; must be non-empty.

	Prints the fitted parameters (fit[0]) and the fit error (fit[1]). Plots
	the fitted curve and the raw data on the current pylab figure, calls
	show(), and returns the value returned by leastSquaresFit.
	"""
	def exponential(params, x):
		a, b = params[0], params[1]
		return a*exp(-b*x)

	data = list(zip(xlist, add_total))

	# Initial parameter guess: the (x, y) pair at the midpoint of the data is
	# passed as the starting (a, b).
	# NOTE(review): this is an unusual guess. a would more naturally start
	# near add_total[0], and b near a small positive rate. It is kept
	# unchanged so fits converge exactly as before.
	mid = int(len(add_total)/2.0)
	initial = (xlist[mid], add_total[mid])

	fit = leastSquaresFit( exponential, initial, data)

	print("Fitted parameters: " + str(fit[0]))
	print("Fit error: " + str(fit[1]))

	# Evaluate the model at the actual x values. The original used the loop
	# index, which is wrong whenever xlist is not exactly 0..n-1.
	a_fit, b_fit = fit[0][0], fit[0][1]
	fit_curve = [a_fit * exp(-1 * b_fit * x) for x in xlist]

	plot(xlist,fit_curve)
	plot(xlist,add_total)

	#savefig('fitted_plot',dpi=600)	#uncomment this line to save figure
	show()

	return fit


def A_multi_run_avg(data_run_list,disp_all_runs=False):
	"""Average several runs point-by-point, plot them, and fit an exponential.

	data_run_list -- sequence of runs; each run is a sequence of numbers.
	                 Every run is truncated to the length of the shortest run.
	disp_all_runs -- if true, also plot each (truncated) run as a red line.

	If there are no runs, or the shortest run is empty, prints
	"No data to plot" and returns None. Otherwise it plots the average in
	green and calls A_fit_exponential on it (which plots the fit and calls
	show()). It then calls show() again and returns the average as a list.
	"""
	if len(data_run_list) == 0:
		print("No data to plot")
		return

	min_x = min(len(run) for run in data_run_list)
	if min_x == 0:
		# An empty run leaves nothing to average. Without this check the
		# exponential fit would fail indexing an empty list.
		print("No data to plot")
		return

	# float() forces true division under Python 2. The original averaged
	# ints with integer division, which silently truncated the result.
	num_runs = float(len(data_run_list))

	xlist = list(range(min_x))
	avg_data_run_list = [sum(run[j] for run in data_run_list) / num_runs
	                     for j in xlist]

	if disp_all_runs:
		for run in data_run_list:
			# Truncate so x and y have equal length; plot() rejects mismatches.
			plot(xlist, run[:min_x], '-r')

	plot(xlist,avg_data_run_list, 'g')

	A_fit_exponential(xlist,avg_data_run_list)
	#savefig('multi_run_avg',dpi=600)	#uncomment this line to save figure
	show()
	return avg_data_run_list
