
"""
nnMainMultiSim.py

Created by Daniel Wuellner on 2007-05-15.
Last edited on 2007-06-08

Main program driver for neuronal network simulation.  Requires nnNetwork.py,
nnUtils.py, nnMorphogenesis.py modules.
"""

import sys
import os
import time
from numpy import *
from nnNetwork import *
from nnUtils import *
from nnMorphogenesis import *

#=======================================================================#
# GLOBAL VARIABLES							#
#=======================================================================#
# simulation clock and morphogenesis schedule (read by timeStep)
currentTimeStep = 0
timeBeforeMorphogenesis = 10	# earliest time step at which morphogenesis may run
morphogenesisFrequency = 10	# morphogenesis runs on time steps divisible by this

# run-mode switches
verbose = True	# report progress on the terminal
debug = False	# stop at programming points for interactive inspection
dynamic = False	# morphogenesis
logMode = True	# write network info and ICs to the log file

# network parameters keyed by the label shown to the user at the prompt;
# every default is taken from the neuronalNetwork class
parameterDictionary = {
	"Number Of Neurons": neuronalNetwork.NN_DEFAULT,
	"Initial Activity Percentage": neuronalNetwork.INITIAL_ACTIVITY_PERCENTAGE_DEFAULT,
	"Threshold": neuronalNetwork.THRESHOLD_DEFAULT,
	"Percent Inhibitory": neuronalNetwork.PERCENT_INHIBITORY_DEFAULT,
	"Inhibitory Weight": neuronalNetwork.INHIBITORY_WEIGHT_DEFAULT,
	"Connection Strength Mean": neuronalNetwork.CONNECTION_STRENGTH_MEAN_DEFAULT,
	"Connection Strength SD": neuronalNetwork.CONNECTION_STRENGTH_SD_DEFAULT,
	"Percent Connections Per Neuron Mean": neuronalNetwork.CONNECTIONS_MEAN_PERCENT_DEFAULT,
	"Percent Connections Per Neuron SD": neuronalNetwork.CONNECTIONS_SD_PERCENT_DEFAULT,
}

defaultFileString = 'timeSeries'	# file name prefix for simulation data

#=======================================================================#
# MODULE FUNCTIONS							#
#=======================================================================#

#-----------------------------------------------------------------------#	
# write a single vector to file						#
#-----------------------------------------------------------------------#	
def writeActivityVector(targetFile,ntwk):
	"""
	Write ntwk.currentState to the open file object targetFile as one line.

	The first ntwk.numberOfNeurons entries are written in order, each
	followed by a single space, and the line is ended with a newline.
	"""
	for neuron in range(ntwk.numberOfNeurons):
		entry=str(ntwk.currentState[neuron])
		targetFile.write(entry+" ")
	targetFile.write("\n")

#-----------------------------------------------------------------------#	
# single time step wrapper: updates network state and implements	#
# morphogenesis if applicable						#
#-----------------------------------------------------------------------#
def timeStep(network,currentTime,dynamic=False):
	"""
	Advance the network by one time step and return the new time.

	Order of operations: synaptic activity update (which also yields the
	incremented time), PSP bookkeeping, optional morphogenesis, refractory
	update.  Morphogenesis runs only when `dynamic` is true, the new time
	has reached timeBeforeMorphogenesis, and the new time is a multiple of
	morphogenesisFrequency.
	"""
	newTime=synapticActivityStep(network,currentTime)
	PSPCalculationStep(network)
	# the call is the same whether or not logging is enabled; a variant
	# that also receives the log file is not wired in
	if dynamic and newTime>=timeBeforeMorphogenesis and newTime%morphogenesisFrequency==0:
		morphogenesisStep(network,newTime)
	updateRefractoryStateStep(network)
	return newTime

#-----------------------------------------------------------------------#	
# logs currentPSP and calculates deviation from average PSP.		#
# (for use in compensatory algorithms)					#
#-----------------------------------------------------------------------#
def PSPCalculationStep(network):
	"""
	Record PSP statistics for every neuron that is not refractory.

	First pass: for each non-refractory neuron whose currentPSP entry is
	not the empty string, append its deviation from the latest average
	PSP to PSPVarHistory (skipped when no average exists yet).
	Second pass: for each non-refractory neuron, append the average of
	its last 10 PSPs to AvePSPHistory (skipped while fewer than 10 PSPs
	have been recorded).
	"""
	# pass 1: deviations, computed against the averages of earlier steps
	for neuron in range(network.numberOfNeurons):
		notRefractory=(network.refractoryStatusArray[neuron]==0)
		if notRefractory and network.currentPSP[neuron]!="":
			deviation=calculatePSPDeviation(network,neuron,network.currentPSP[neuron])
			if deviation!="":
				network.PSPVarHistory[neuron].append(deviation)

	# pass 2: extend the running-average history
	for neuron in range(network.numberOfNeurons):
		if network.refractoryStatusArray[neuron]==0:
			average=updateAveragePSP(network,neuron,10)
			if average!="":
				network.AvePSPHistory[neuron].append(average)

#-----------------------------------------------------------------------#	
# calculates deviation of one neuron from average PSP.			#
# (for use in compensatory algorithms)					#
#-----------------------------------------------------------------------#
def calculatePSPDeviation(network,neuronIdx,PSP):
	"""
	Return PSP minus the most recent average PSP of neuron neuronIdx.

	Returns the empty string when that neuron has no average recorded yet.
	"""
	averages=network.AvePSPHistory[neuronIdx]
	if len(averages)>0:
		latestAverage=averages[-1]
		return (PSP-latestAverage)
	return ""

#-----------------------------------------------------------------------#	
# returns the mean of the most recent PSPs for one neuron		#
# (for use in compensatory algorithms)					#
#-----------------------------------------------------------------------#
def updateAveragePSP(network,neuronIdx,interval=""):
	"""
	Return the mean of the last `interval` entries of a neuron's PSPHistory.

	network   -- object exposing PSPHistory (one list of PSP values per neuron)
	neuronIdx -- index of the neuron
	interval  -- number of most recent entries to average; the default ""
	             means "the whole history"

	Returns the empty string (this module's "no value" marker) when the
	history holds fewer than `interval` entries, or when there is nothing
	to average (interval of zero or less, e.g. the default interval on an
	empty history).  PSPHistory itself is not modified.
	"""
	neuronHist=network.PSPHistory[neuronIdx]
	if interval=="":
		interval=len(neuronHist)
	if interval>len(neuronHist):
		return ""
	if interval<=0:
		# nothing to average; avoids dividing by zero below
		return ""
	# explicit accumulation, oldest to newest: `sum` is avoided because the
	# star import from numpy at the top of the file may shadow the builtin
	totPSP=0
	for value in neuronHist[len(neuronHist)-interval:]:
		totPSP=totPSP+value
	return (1.0/interval)*(totPSP)


#-----------------------------------------------------------------------#	
# refractory status logic; if a neuron fired this step, set its 	#
# refractory state to 1.  If a neuron's refractory state is currently 1,#
# reset it to 0.							#
#									#
# This algorithm assumes the refractory period is identical for all 	#
# neurons and is equal to one time step.				#
#-----------------------------------------------------------------------#
def updateRefractoryStateStep(network):
	"""
	Update network.refractoryStatusArray in place after a time step.

	For each neuron, a refractory flag of 1 is first cleared to 0; then,
	if the neuron's currentState is 1 (it fired this step), the flag is
	set to 1.
	"""
	refractory=network.refractoryStatusArray
	for neuron in range(network.numberOfNeurons):
		# clear last step's refractory mark before looking at this step's firing
		if refractory[neuron]==1:
			refractory[neuron]=0
		if network.currentState[neuron]==1:
			refractory[neuron]=1
	
#-----------------------------------------------------------------------#
# network activity logic (fast time scale): calculate each neuron's 	#
# PSP and update network.currentState vector for those neurons with	#
# PSP>=threshold.  							#
#									#
# This algorithm is implemented in a synchronous fashion; the state of	#
# all neurons is calculated based on the current setting and then 	#
# updated simultaneously.						#
#-----------------------------------------------------------------------#
def synapticActivityStep(network,currentTime):
	"""
	Compute every neuron's PSP from the current state and update the state.

	For neuron k the PSP is the sum over columns [0, NE) of
	CM[k][j]*currentState[j], minus the sum over columns
	[NE, numberOfNeurons) of CM[k][j]*inhibitoryWeight*currentState[j].

	A neuron that is not refractory and has a non-zero excitatory or
	inhibitory sum gets its PSP appended to PSPHistory and stored in
	currentPSP, and fires (next state 1) when the PSP reaches
	network.threshold.  Every other neuron gets currentPSP set to ''
	and a next state of 0.

	The update is synchronous: all PSPs are computed from the old state,
	then network.currentState is replaced by the new state list.

	Returns currentTime+1.
	"""
	# reference to (not a copy of) the connectivity matrix at time t
	CM=network.connectivityMatrix
	zNext=[0]*network.numberOfNeurons
	# for kth neuron, calculate postsynaptic potential psp_k(t)
	for k in range(network.numberOfNeurons):
		exSum=0.0
		inSum=0.0
		for j in range(network.NE):
			exSum=exSum+((CM[k][j])*network.currentState[j])
		for j in range(network.NE,network.numberOfNeurons):
			inSum=inSum+(CM[k][j]*network.inhibitoryWeight*network.currentState[j])
		psp_k=exSum-inSum
		if (network.refractoryStatusArray[k]==0 and (exSum!=0 or inSum!=0)):
			network.PSPHistory[k].append(psp_k)
			network.currentPSP[k]=psp_k
			# calculate this neuron's output state
			if (psp_k>=network.threshold):
				zNext[k]=1
		else:
			# refractory, or no input at all: no PSP is recorded this step
			network.currentPSP[k]=''
	# update network's currentState
	network.currentState=zNext
	return (currentTime+1)


#-----------------------------------------------------------------------#
# instantiates a new neuronalNetwork object given parameters		#
#-----------------------------------------------------------------------#
def getNewNetwork(parameterDictionary):
	"""
	Build and return a neuronalNetwork from a parameter dictionary.

	parameterDictionary must contain every label listed below; each value
	is passed to the neuronalNetwork constructor under the paired keyword.
	"""
	# (constructor keyword, label used as key in parameterDictionary)
	keywordLabels=(
		("numberOfNeurons","Number Of Neurons"),
		("threshold","Threshold"),
		("percentInhibitory","Percent Inhibitory"),
		("inhibitoryWeight","Inhibitory Weight"),
		("connectionStrengthMean","Connection Strength Mean"),
		("connectionStrengthSD","Connection Strength SD"),
		("connectionsMeanPercent","Percent Connections Per Neuron Mean"),
		("connectionsSDPercent","Percent Connections Per Neuron SD"),
	)
	constructorArgs={}
	for keyword,label in keywordLabels:
		constructorArgs[keyword]=parameterDictionary[label]
	return neuronalNetwork(**constructorArgs)

#-----------------------------------------------------------------------#
# returns a new initial activity vector; an element of {0,1}^networkSize#
#-----------------------------------------------------------------------#
def getNewIC(initialActivityPercentage,networkSize):
	"""
	Return a random initial activity vector: a list of networkSize 0/1 entries.

	initialActivityPercentage -- despite the name this is used as a
	        fraction: int(initialActivityPercentage*networkSize) distinct
	        positions are set to 1
	networkSize -- number of neurons (length of the returned list)

	Positions are chosen by drawing indices with uniform(0,networkSize)
	(a name supplied by one of the star imports at the top of the file)
	and redrawing whenever the position is already active.

	Raises ValueError when more active neurons are requested than the
	network contains; without this check the redraw loop could never end.
	"""
	initialActivity=int(initialActivityPercentage*networkSize)
	if initialActivity>networkSize:
		raise ValueError("initial activity of "+str(initialActivity)+
				" neurons exceeds network size "+str(networkSize))
	#generate initial condition list
	ICvec=[0]*networkSize
	#populate initial condition list
	for i in range(initialActivity):
		match=0
		while not(match):
			test=int(uniform(0,networkSize))
			# test<networkSize: redraw if the draw landed on the upper bound
			if (test<networkSize and ICvec[test]==0):
				ICvec[test]=1
				match=1
	return ICvec
#-----------------------------------------------------------------------#
# debugging utility - returns control to command window during exec. 	#
#-----------------------------------------------------------------------#
def programmingPoint(params=[]):
	"""
	Debugging utility: hand control to the terminal in the middle of a run.

	Repeatedly prompts for one line of Python and executes it with the
	exec statement; entering an empty line ends the loop.

	params is not used by the function body itself.  Presumably it is
	there so the typed code can inspect objects handed over by the
	caller (the commented-out call sites in this file pass lists of
	locals) -- confirm before changing the signature.

	WARNING: whatever is typed is executed as-is.  Debug use only; never
	expose this prompt to untrusted input.
	"""
	stop=0
	while not stop:
		x=raw_input("Enter python code to execute: ")
		# an empty line requests exit; it is still passed to exec below (a no-op)
		if x=="": stop=1
		# executed inside this function, so the typed code can see params, x and stop
		exec x
#-----------------------------------------------------------------------#
# starts a new simulation 						#
#-----------------------------------------------------------------------#
def runNewSimulation(steps,sims=1):
	"""
	Run `sims` simulations of `steps` time steps each on new networks.

	Each simulation builds a network from the module-level
	parameterDictionary and passes it to runNetwork together with the
	module-level ICs (number of initial conditions) and the
	"Initial Activity Percentage" parameter.  When logMode is set, the
	start of every simulation is written to the module-level log file
	logF.  Progress is printed when verbose is set, and a programming
	point is entered after each simulation when debug is set.
	"""
	for simIndex in range(sims):
		#generateNetwork
		freshNetwork=getNewNetwork(parameterDictionary)
		if logMode:
			logToFile(logF,freshNetwork,"sim","simStart",parameterDictionary)
		# looked up per simulation so edits made at a programming point take effect
		activityFraction=parameterDictionary["Initial Activity Percentage"]
		runNetwork(freshNetwork,steps,ICs,
				initialActivityPercentage=activityFraction,
				sim=simIndex)
		if verbose:
			print("Finished simulation "+str(simIndex))
		if debug:
			print("*  PP at end of sim:  *")
			programmingPoint()

#-----------------------------------------------------------------------#
# instantiates a network based on the passed connectivity matrix 	#
#-----------------------------------------------------------------------#
def reloadNetwork(networkMatrix,steps,newICs=1,initialCondition=[],morphoModes=('S','D')):
	"""
	Rebuild a network from a saved connectivity matrix and run it.

	networkMatrix    -- connectivity matrix as a sequence of rows; its
	                    length is used as the number of neurons
	steps            -- number of time steps per run
	newICs           -- number of new initial conditions to generate; only
	                    used when initialCondition is empty
	initialCondition -- initial activity vector to reuse; an empty list
	                    means "generate new ones" (never modified here)
	morphoModes      -- morphogenesis modes handed to runNetwork
	"""
	#NEEDS REVISION: fixes critical params to defaults for now (esp. NI)
	neuronCount=len(networkMatrix)
	matrixAsArray=array(networkMatrix)
	network=neuronalNetwork(numberOfNeurons=neuronCount,sourceConnectivityMatrix=matrixAsArray)
	useStoredIC=not (initialCondition==[])
	if useStoredIC:
		# a stored IC is run once; newICs is deliberately not passed on
		runNetwork(network,steps,ICTemplate=initialCondition,modes=morphoModes,reloaded=True)
	else:
		runNetwork(network,steps,newICs,modes=morphoModes,reloaded=True)


#-----------------------------------------------------------------------#
# main simulation driver						#
#-----------------------------------------------------------------------#
def runNetwork(networkTemplate,steps,ICs=1,ICTemplate=[],
		initialActivityPercentage=neuronalNetwork.INITIAL_ACTIVITY_PERCENTAGE_DEFAULT,
		modes=('S','D'),sim=0,reloaded=False):
	"""
	Main simulation driver: run a network for every IC and every mode.

	networkTemplate -- network to copy for each initial condition
	steps           -- maximum number of time steps per run
	ICs             -- number of initial conditions to run
	ICTemplate      -- initial activity vector to use for every IC; an
	                   empty list means "generate a random one per IC"
	                   (never modified here)
	initialActivityPercentage -- passed to getNewIC for generated ICs
	modes           -- sequence of mode codes; 'D' enables morphogenesis,
	                   any other code runs without it
	sim             -- simulation index, used in labels and file names
	reloaded        -- when true, 'R' is inserted into the file name

	For each (ic, mode) the activity vector is written once per time step
	to the file defaultFileString+['R']+mode+"_"+sim+"_ic"+ic.  A run ends
	early once no neuron is active.  When the module-level logMode is
	set, network and IC snapshots are written to the module-level logF at
	the start and end of each run.

	NOTE(review): the network copy is made once per IC and reused for
	every mode of that IC, so a later mode starts from whatever the
	earlier mode left in the network -- confirm this is intended.
	"""
	generateICs=(ICTemplate==[])
	for ic in range(ICs):
		tempCM=networkTemplate.connectivityMatrix.copy()	#array copy method
		network=networkTemplate.copy(tempCM)	#need a fresh copy of the network for each IC
		if generateICs:
			tempIC=getNewIC(initialActivityPercentage,network.numberOfNeurons)
		else:
			tempIC=ICTemplate[:]
		for k in modes:
			runLabel=str(k)+"_"+str(sim)+"_ic"+str(ic)
			# the start snapshot is taken before the IC is applied
			if logMode:
				logToFile(logF,network,"network",runLabel+"_start")
			network.currentState=tempIC[:]
			if logMode:
				logToFile(logF,network,"ic","IC_"+runLabel)
			currentTimeStep=0
			dynamic=(k=='D')
			# neurons active in the IC start out refractory; a separate list is
			# used so in-place refractory updates cannot write into the state
			network.refractoryStatusArray=network.currentState[:]
			suffix=""
			if reloaded: suffix='R'
			timeSeriesFile=defaultFileString+suffix+runLabel
			f=open(timeSeriesFile,'w')
			#-----------------------------------------------------------------------#
			# (2) SIMULATION 							#
			#-----------------------------------------------------------------------#
			try:
				# write initial time step
				writeActivityVector(f,network)
				for step in range(steps):
					if max(network.currentState)==0:
						# activity has died out; no further step would change anything
						break
					currentTimeStep=timeStep(network,currentTimeStep,dynamic)
					#write currentState to file
					writeActivityVector(f,network)
			finally:
				# close the time series even if a step raised
				f.close()
			if verbose:
				print("IC"+str(ic))
			if debug:
				print("*  PP at end of IC:  *")
				programmingPoint()
			if logMode:
				logToFile(logF,network,"network",runLabel+"_end")
				logToFile(logF,network,"ic","END_"+runLabel)
	


#=======================================================================#
# I/O and PROGRAM EXECUTION						#
#=======================================================================#

title="NEURONAL NETWORK PLASTICITY SIMULATOR V0.1\nDaniel Wuellner, 2007\n"
print(title)

def _askWithDefault(prompt,default):
	"""Prompt on the terminal; return the reply, or `default` when the reply is empty."""
	reply=raw_input(prompt)
	if reply=="":
		return default
	return reply

# load saved network or start new
inp=raw_input("Start new simulation [y]/n? ")
reloadSim=(inp.upper()=='N')

# this is still buggy
if reloadSim:
	lFile=_askWithDefault("Enter log file to search [nn_sim.log]: ",'nn_sim.log')
	networkLabel=_askWithDefault("Enter network label to load [S_0_ic0_start]: ",'S_0_ic0_start')
	lF=open(lFile)
	try:
		networkMatrix=getNetwork(lF,networkLabel)
	finally:
		lF.close()

	inp=raw_input("Start new IC [y]/n? ")
	reloadIC=(inp.upper()=='N')
	loadedIC=[]
	ICs=1
	if reloadIC:
		ICLabel=_askWithDefault("Enter IC label to load [IC_S_0_ic0]: ","IC_S_0_ic0")
		lF=open(lFile)
		try:
			# the IC vector is the line that follows the first line containing the label
			match=False
			for line in lF:
				if match:
					loadedIC=[int(x) for x in line.split()]
					break
				if line.find(ICLabel)>-1:
					match=True
		finally:
			lF.close()
		if loadedIC==[]:
			# reloadNetwork treats an empty IC as "generate a new one"
			print("No IC found for label "+ICLabel+"; a new IC will be generated.")
	else:
		ICs=int(_askWithDefault("Enter number of different initial conditions per simulation [10]: ",10))

	steps=int(_askWithDefault("Enter number of steps [20]: ",20))

	inp=raw_input("Choose morphogenesis modes to run (D)ynamic, (S)tatic, [(B)]oth: ")
	mode=inp.upper()
	if mode=='D':
		modes=('D',)
	elif mode=='S':
		modes=('S',)
	else:
		# NOTE(review): "both" runs dynamic first here, while runNetwork's own
		# default order is ('S','D') -- confirm which order is intended
		modes=('D','S')

	inp=raw_input("Log network information to file [y]/n? ")
	logMode=(inp.upper()!='N')
	if logMode:
		# NOTE: the default output name equals the default input log name,
		# and opening with 'w' truncates that file
		logFileName=_askWithDefault("Enter log file name [nn_sim.log]: ",'nn_sim.log')
		logF=open(logFileName,'w')
	try:
		# the run itself does not depend on logging being enabled
		reloadNetwork(networkMatrix,steps,ICs,loadedIC,modes)
		if logMode:
			logToFile(logF,"","sim","simEnd")
	finally:
		if logMode:
			logF.close()

# start new sim
else:
	print("Current default settings: ")
	for item in parameterDictionary.items():
		print(item[0]+": "+str(item[1]))
	inp=raw_input("Accept current defaults [y]/n? ")
	if (inp.upper()=='N'):
		for item in list(parameterDictionary.items()):
			newInp=raw_input( "Enter "+item[0]+"("+str(item[1])+"): ")
			if newInp!="":
				if item[0]=="Number Of Neurons":
					# kept integral: the neuron count is used as a range() bound
					parameterDictionary[item[0]]=int(float(newInp))
				else:
					parameterDictionary[item[0]]=float(newInp)

	sims=int(_askWithDefault("Enter number of simulations [1]: ",1))
	ICs=int(_askWithDefault("Enter number of different initial conditions per simulation [10]: ",10))
	steps=int(_askWithDefault("Enter number of steps [20]: ",20))
	inp=raw_input("Log network information to file [y]/n? ")
	logMode=(inp.upper()!='N')
	if logMode:
		logFileName=_askWithDefault("Enter log file name [nn_sim.log]: ",'nn_sim.log')

	defaultFileString='timeSeries'
	#-----------------------------------------------------------------------#
	# (1) INITIAL STAGE: generate simulation, network, and ICs		#
	#-----------------------------------------------------------------------#
	NN=parameterDictionary["Number Of Neurons"]
	if logMode:
		logF=open(logFileName,'w')
	try:
		# the run itself does not depend on logging being enabled
		runNewSimulation(steps,sims)
		if logMode:
			logToFile(logF,"","sim","simEnd")
	finally:
		if logMode:
			logF.close()

if debug:
	print("*  PP at end of execution:  *")
	programmingPoint()

#=======================================================================#	

	