"""
nnUtils.py

Created by Daniel Wuellner on 2007-05-15.
Last edited on 2007-06-08

Utility functions for logging and analyzing neuronal network simulations.  Requires nnGenericUtilityFunctions.py, plus pylab, numpy and ScientificPython (Scientific.Statistics).
"""
import sys
import os


from nnGenericUtilityFunctions import *
from pylab import *
from numpy import *
from Scientific.Statistics import average, standardDeviation
from Scientific.Statistics.Histogram import Histogram
import time
#-----------------------------------------------------------------------#
# TOOLS FOR ANALYZING SIMULATION					#
#-----------------------------------------------------------------------#	
def _readLoggedNetwork(logFile,label):
	"""Open logFile, return getNetwork(file, label), and always close the file."""
	f=open(logFile)
	try:
		return getNetwork(f,label)
	finally:
		# Close even if getNetwork raises (e.g. float() on a malformed row).
		f.close()

def analyzeMorphogenesis(logFile,networkLabel):
	"""Compare the start and end connectivity matrices of one network.

	logFile is the path of a simulation log.  The matrices labelled
	networkLabel+"_start" and networkLabel+"_end" are read from it with
	getNetwork (the file is re-opened for each read).

	Returns (diffArray, rowSum, colSum): diffArray = start - end as an
	array, and rowSum / colSum are its sums over axis 1 and axis 0.
	"""
	startNetwork=array(_readLoggedNetwork(logFile,networkLabel+"_start"))
	endNetwork=array(_readLoggedNetwork(logFile,networkLabel+"_end"))
	diffArray=(startNetwork-endNetwork)
	rowSum=diffArray.sum(axis=1)
	colSum=diffArray.sum(axis=0)
	return diffArray,rowSum,colSum
	

def getNetwork(lF,networkLabel):
	"""Parse one connectivity matrix out of a simulation log.

	lF is any iterable of text lines (e.g. an open log file).  The matrix
	is made of the lines that follow the first line containing
	networkLabel (plain substring match), up to but not including the
	next line containing the end marker backslash-N (logNetwork writes
	"<\\N>").  Each line is split on whitespace and converted to floats;
	blank lines are skipped.

	Returns a list of rows (lists of floats), or [] if the label is never
	found.  If the input ends before the end marker, the rows read so far
	are returned instead of raising StopIteration.
	"""
	networkMatrix=[]
	inNetwork=False
	for line in lF:
		if not inNetwork:
			# Skip everything up to (and including) the label line.
			if line.find(networkLabel)>-1:
				inNetwork=True
			continue
		# The end marker is the two characters backslash + N.  Spelled
		# "\\N" because a bare "\N" is a SyntaxError on Python 3.
		if line.find("\\N")>-1:
			break
		row=[float(y) for y in line.split()]
		if row:
			networkMatrix.append(row)
	return networkMatrix

def logSimulation(logFile,fromWhere,parameterDictionary):
	"""Write the opening or closing section of a simulation log.

	fromWhere == "simStart": writes the header with the current
	time.time() stamp and one "<key=value>" line per item of
	parameterDictionary.  An empty or None parameterDictionary (logToFile
	passes "" by default) writes no parameter lines instead of raising
	AttributeError.
	fromWhere == "simEnd": writes the end time stamp and closing tags.
	Any other fromWhere writes nothing.

	NOTE(review): closing tags are written as "<\\title>", "<\\head>", ...
	(backslash, not HTML's "/"); left unchanged so the on-disk log format
	stays the same.
	"""
	if fromWhere=="simStart":
		logFile.write("<html>\n<head>\n")
		logFile.write("<title>NETWORK SIMULATION LOG (started at "+str(time.time())+")<\\title>\n")
		logFile.write("\nNetwork Parameters:\n")
		if parameterDictionary:
			for key,value in parameterDictionary.items():
				logFile.write("<"+str(key)+"="+str(value)+">")
				logFile.write("\n")
		logFile.write("\n<\\head>\n<body>\n")
	elif fromWhere=="simEnd":
		logFile.write("<P>NETWORK SIMULATION LOG (ended at "+str(time.time())+")\n")
		logFile.write("<\\body>\n<\\html>")

def logIC(logFile,network,fromWhere,ic):
	"""Append an initial-condition record to the log.

	Writes "<IC>", a "<label=fromWhere>" line, the values of ic (via
	outputList) and the closing "<\\IC>" tag.  network is not used here.
	"""
	header="<IC>\n<label="+fromWhere+">\n"
	footer="\n<\\IC>\n"
	logFile.write(header)
	outputList(logFile,ic)
	logFile.write(footer)

def logNetwork(logFile,network,fromWhere):
	"""Append a network's connectivity matrix to the log.

	Writes "<N>", a "<label=fromWhere>" line, network.connectivityMatrix
	(row by row, via outputMatrix) and the closing "<\\N>" tag, which is
	the end marker getNetwork searches for.
	"""
	logFile.write("<N>\n<label="+fromWhere+">\n")
	outputMatrix(logFile,network.connectivityMatrix)
	# Spelled "\\N": a bare "\N" is an ambiguous escape (a SyntaxError on
	# Python 3); "\\N" is the same two characters on every version.
	logFile.write("<\\N>\n")


def logToFile(logFile,network,what,fromWhere,parameterDictionary=""):
	"""Route a log request to the matching log* writer.

	what selects the record type:
	  "sim"     -- simulation header/footer (logSimulation)
	  "network" -- connectivity matrix (logNetwork)
	  "ic"      -- network.currentState as an initial condition (logIC)
	Any other value of what writes nothing.
	"""
	if what=="ic":
		logIC(logFile,network,fromWhere,network.currentState)
		return
	if what=="network":
		logNetwork(logFile,network,fromWhere)
		return
	if what=="sim":
		logSimulation(logFile,fromWhere,parameterDictionary)
	
def outputList(targetFile,lst):
	"""Write the items of lst on one line, each followed by one space.

	The line is terminated by a newline; an empty lst writes only the
	newline.
	"""
	for item in lst:
		targetFile.write(str(item)+" ")
	targetFile.write("\n")
def outputMatrix(targetFile,matrix):
	"""Write a 2-D matrix row by row, values separated by spaces.

	Each value is followed by a single space and each row ends with a
	newline.  matrix must expose .shape with at least two entries (e.g. a
	2-D numpy array); only the first two dimensions are walked.
	"""
	nRows=matrix.shape[0]
	nCols=matrix.shape[1]
	for r in range(nRows):
		row=matrix[r]
		for c in range(nCols):
			targetFile.write(str(row[c])+" ")
		targetFile.write("\n")

def inputMatrix(targetFile):
	"""Read a whitespace-separated numeric matrix from a text file.

	targetFile is a path.  Every non-blank line becomes one row of
	floats; blank lines are skipped.  Returns array(rows).

	Raises ValueError if a value cannot be converted by float(); the file
	is closed in that case too.
	"""
	f=open(targetFile)
	try:
		matrix=[]
		for line in f:
			thisRow=[float(column) for column in line.split()]
			if thisRow:
				matrix.append(thisRow)
	finally:
		f.close()
	return array(matrix)

def getActivity(fileName,transients=0):
	"""Return the row sums of the matrix stored in fileName.

	The file is read with inputMatrix; each row (one recorded state per
	line) is summed, giving one activity value per row.  transients is
	accepted but not used.
	"""
	states=array(inputMatrix(fileName))
	return states.sum(axis=1)


def getTransients(fileName):
	"""Measure the transient length and time-to-silence of a recorded run.

	fileName is a text file with one network state per line
	(whitespace-separated numbers).  Lines are read until the first one
	whose largest value, truncated to int, is not 1; that line is treated
	as the start of null activity and parsing stops there.  Blank lines
	are skipped.  States are compared by their whitespace-split fields.

	Returns (transients, iterationsToNullActivity):
	  transients -- the position at which the first repeated state (in
	    time order) first occurred; for a deterministic run this is the
	    transient length before the cycle.  0 if no state repeats, which
	    is indistinguishable from a cycle starting at position 0.
	  iterationsToNullActivity -- number of active states before the null
	    line; 0 if no null line is found (also 0 if the very first line
	    is null).
	"""
	firstSeen={}	# state -> index of its first occurrence
	transients=None	# None = "no repeat found yet" (0 is a valid answer)
	iteration=0
	iterationsToNullActivity=0
	f=open(fileName)
	try:
		for line in f:
			entries=line.split()
			if not entries:
				continue	# blank line: max() of nothing would raise
			# Numeric max (the string max is lexicographic and int("1.0") fails).
			if int(max(float(e) for e in entries))!=1:
				iterationsToNullActivity=iteration
				break	# stop parsing this file; it's all zeroes from now on
			state=tuple(entries)
			if state in firstSeen:
				if transients is None:
					transients=firstSeen[state]
			else:
				firstSeen[state]=iteration
			iteration=iteration+1
	finally:
		f.close()
	if transients is None:
		transients=0
	return transients,iterationsToNullActivity

#-----------------------------------------------------------------------#
# getFileList: 	generates list of strings containing data file names	#
#		to analyze; queries user if no parameters passed	#
#-----------------------------------------------------------------------#
def getFileList(baseName="",simNum="",icRange=""):
	"""Return the file names "<baseName>_<simNum>_ic<k>" for k = 0 .. icRange-1.

	Any argument left as "" is asked for on the console; an empty answer
	falls back to the default shown in the prompt (timeSeries, 0, 10).
	"""
	if baseName=="":
		baseName=raw_input("Enter base file name [timeSeries]: ") or "timeSeries"
	if simNum=="":
		simNum=raw_input("Enter simulation number [0]: ") or 0
	if icRange=="":
		icRange=raw_input("Enter number of ICs to analyze [10]: ") or 10
	prefix=baseName+"_"+str(simNum)+"_ic"
	return [prefix+str(k) for k in range(int(icRange))]
def transientStatistics(baseName="",simNum="",icRange=""):
	"""Print transient and dead-run statistics for one simulation's ICs.

	Runs getTransients on every file returned by getFileList(baseName,
	simNum, icRange); getFileList prompts for any argument left as "".
	Prints the mean and standard deviation of the transient lengths, the
	fraction of runs that reached null activity (time to null > 0) and,
	if there are any, the mean time to null activity of those runs.
	Prints a short notice and returns if there are no files.
	"""
	transientList=[]
	dead=[]
	for fileName in getFileList(baseName,simNum,icRange):
		t,z=getTransients(fileName)
		transientList.append(t)
		if z>0:
			dead.append(z)
	if not transientList:
		# Avoid dividing by zero / averaging an empty list.
		print("No files to analyze.")
		return
	print("Mean: "+str(average(transientList)))
	print("SD: "+str(standardDeviation(transientList)))
	# The value printed is a fraction in [0, 1], despite the label.
	print("Percent Dead Runs: "+str(len(dead)/float(len(transientList))))
	if dead:
		print("Mean Transients to Dead: "+str(average(dead)))

def getSingleNetworkStats(baseName="",simNum="",icRange=""):
	"""Summarize getTransients over every IC file of one simulation.

	Files come from getFileList(baseName, simNum, icRange).  Returns a
	dict with:
	  MeanTransients -- mean transient length
	  SDTransients   -- standard deviation of the transient lengths
	  PercentDead    -- fraction of runs whose time to null activity is > 0
	  MeanTimeToDead -- mean of those positive times (0 if there are none)
	"""
	results=[getTransients(name) for name in getFileList(baseName,simNum,icRange)]
	transientList=[t for t,_ in results]
	deadList=[z for _,z in results if z>0]
	stats={}
	stats["MeanTransients"]=average(transientList)
	stats["SDTransients"]=standardDeviation(transientList)
	stats["PercentDead"]=len(deadList)/float(len(transientList))
	stats["MeanTimeToDead"]=0
	if deadList:
		stats["MeanTimeToDead"]=average(deadList)
	return stats
		

def networkActivityHistogram(dataDict,key,bins=20):
	"""Return Histogram(dataDict[key], bins), or [] when key is not in dataDict."""
	if key in dataDict:
		return Histogram(dataDict[key],bins)
	return []
					
def plotNetworkActivityHistogram(dataDict,key="",bins=""):
	"""Plot the histogram of dataDict[key] as red dots and show it.

	key == "": chosen interactively from dataDict's keys with
	interactiveList (presumably from nnGenericUtilityFunctions).
	bins == "": prompted for, default 20, converted to int.
	Each histogram entry is plotted as (entry[0], entry[1]).
	"""
	if key=="":
		key=interactiveList(dataDict.keys(),"Data")
	if bins=="":
		bins=int(raw_input("Enter number of bins [20]: ") or 20)
	print(key)
	figure(1,(8,6))
	title('Histogram of '+key)
	xlabel(key)			# x-axis: the data values
	ylabel('Count('+key+')')	# y-axis: counts per bin
	hg=networkActivityHistogram(dataDict,key,bins)
	centers=[]
	counts=[]
	for entry in hg:
		centers.append(entry[0])
		counts.append(entry[1])
	plot(centers,counts,'ro')
	show()

#-----------------------------------------------------------------------#
# Plot activity v time							#
#-----------------------------------------------------------------------#	
def plotGlobalActivity(baseName="",simNum="",icRange=""):
	"""Plot total activity against time step for each IC of one simulation.

	Arguments left as "" are prompted for (defaults timeSeriesD, 0, 10).
	For each IC k the file "<baseName>_<simNum>_ic<k>" is read with
	getActivity, its values are plotted against their index, and show()
	is called.
	"""
	if baseName=="":
		baseName=raw_input("Enter base file name [timeSeriesD]: ") or "timeSeriesD"
	if simNum=="":
		simNum=raw_input("Enter simulation number [0]: ") or 0
	if icRange=="":
		icRange=raw_input("Enter number of ICs to analyze [10]: ") or 10
	for ic in range(int(icRange)):
		name=baseName+"_"+str(simNum)+"_ic"+str(ic)
		activity=getActivity(name)
		plot(arange(len(activity)),activity,label='activity over time')
		title(name)
		legend()
		show()
#-----------------------------------------------------------------------#
# Plot average activity v time						#
#-----------------------------------------------------------------------#	     
def plotAverageActivity(baseName="",simNum="",icRange=""):
	"""Plot the two-step running mean of activity for each IC of one simulation.

	Arguments left as "" are prompted for (defaults timeSeriesD, 0, 10).
	For each IC k the activity of "<baseName>_<simNum>_ic<k>" (from
	getActivity) is smoothed as the mean of each pair of consecutive
	values (one fewer point than the input), plotted, and show() is
	called.
	"""
	if baseName=="":
		baseName=raw_input("Enter base file name [timeSeriesD]: ") or "timeSeriesD"
	if simNum=="":
		simNum=raw_input("Enter simulation number [0]: ") or 0
	if icRange=="":
		icRange=raw_input("Enter number of ICs to analyze [10]: ") or 10
	for ic in range(int(icRange)):
		name=baseName+"_"+str(simNum)+"_ic"+str(ic)
		activity=getActivity(name)
		# element j is (activity[j] + activity[j+1]) / 2
		smoothed=(activity[:-1]+activity[1:])/2.0
		plot(arange(len(smoothed)),smoothed,label='Average activity over time')
		title(name)
		legend()
		show()
         
#-----------------------------------------------------------------------#
# Retrieve basic stats about series of simulations			#
#-----------------------------------------------------------------------#	
def getMultiNetworkStats():
	"""Collect getSingleNetworkStats over several simulations.

	Prompts for the base file name (default timeSeries), the number of
	simulations (default 1) and the number of ICs (default 10).  Returns
	a dict mapping each statistic name to the list of its values, one
	per simulation, in simulation order.
	"""
	baseName=raw_input("Enter base file name [timeSeries]: ") or "timeSeries"
	simNumRange=raw_input("Enter number of simulations to analyze [1]: ") or 1
	icRange=raw_input("Enter number of ICs to analyze [10]: ") or 10
	multiStats={}
	for sim in range(int(simNumRange)):
		singleStats=getSingleNetworkStats(baseName,sim,icRange)
		for statName in singleStats:
			multiStats.setdefault(statName,[]).append(singleStats[statName])
	return multiStats
		
	