"""
nnMorphogenesis.py

Created by Daniel Wuellner on 2007-05-15.
Last edited on 2007-06-08

Library of morphogenetic algorithms for neuronal network simulation.  
"""

import sys
import os
import time
from numpy import *
from nnNetwork import *
from nnUtils import *

#-----------------------------------------------------------------------#
# GLOBAL VARIABLES							#
#-----------------------------------------------------------------------#	
# Default rate constants for the compensation rules.  'kH_' entries are
# passed on for neurons classified as 'high' (PSP above the tolerance band),
# 'kL_' entries for neurons classified as 'low'; the suffix names the
# synaptic-element class being changed (bepo, bipo, bpr, fpr, fepo, fipo).
DEFAULT_K_DICT = dict(
    kH_bepo=0.1, kH_fpr=1.0, kH_fepo=1.0, kH_fipo=1.0,
    kL_bpr=0.1, kL_bipo=0.1, kL_fpr=1.0, kL_fepo=1.0, kL_fipo=1.0,
)

#-----------------------------------------------------------------------#	
# general wrapper							#
#-----------------------------------------------------------------------#	
def morphogenesisStep(network,currentTime,logFile=""):
    """Apply one round of activity-dependent morphogenesis to `network`.

    Delegates to stabilizingMorphogenesis using that function's default
    stableRange and rate constants.  A second (Hebbian) method is noted as
    planned but is not implemented here.

    Parameters:
      network     -- network object, modified in place
      currentTime -- simulation time, forwarded unchanged
      logFile     -- writable file-like object, or "" to disable logging
    """
    stabilizingMorphogenesis(network, currentTime, logFile=logFile)
    #method 2: hebbian

#-----------------------------------------------------------------------#	
# wrapper for compensation algorithm					#
#-----------------------------------------------------------------------#	
def stabilizingMorphogenesis(network,currentTime,logFile,stableRange=.3,k_Dict=DEFAULT_K_DICT):
    """Push each neuron's post-synaptic potential back toward threshold.

    For every neuron with a recorded PSP (``network.currentPSP[i] != ''``)
    the deviation ``d = PSP - threshold`` is compared with the symmetric
    tolerance band ``+/- threshold * stableRange / 2``:

      d above the band  -> modulateDown (neuron is over-driven)
      d below the band  -> modulateUp   (neuron is under-driven)
      d inside the band -> left unchanged

    After all neurons are processed, synapseRecombination turns matching
    free elements into new synapses.

    Parameters:
      network     -- network object, modified in place
      currentTime -- not used by this rule; kept for interface compatibility
      logFile     -- writable file-like object, or '' to disable logging
      stableRange -- full width of the tolerance band as a fraction of the
                     threshold
      k_Dict      -- rate constants forwarded to the modulation functions
                     (read only, never mutated)
    """
    t = network.threshold
    halfBand = t * (stableRange / 2.0)
    for i in range(network.numberOfNeurons):
        psp = network.currentPSP[i]
        if psp != '':        # '' marks a neuron with no PSP this step
            deviation = psp - t
            if deviation > halfBand:
                modulateDown(network, i, deviation, k_Dict, logFile)
            elif deviation < -halfBand:
                # the lower edge of the band is -halfBand; comparing with
                # +halfBand would modulate every in-band neuron upward
                modulateUp(network, i, deviation, k_Dict, logFile)
            # otherwise: tolerable response, nothing to do
    synapseRecombination(network)

#-----------------------------------------------------------------------#	
# compensation algorithm - modDown					#
#-----------------------------------------------------------------------#	
def modulateDown(network,neuronIdx,PSPVar,kDict,logFile):
    """Compensation step for a neuron whose PSP is above the tolerated band.

    Applied in four steps, all made in place on `network`:

    1. decrease bepo -- each excitatory input weight
       connectivityMatrix[neuronIdx][i] (i < NE) is reduced by
       kH_bepo * (weight / total excitatory input) * weight, clamped at 0.0;
    2. increase fpr  -- fpr[k][neuronIdx] is set to 1 for one randomly drawn
       k (no change when k == neuronIdx);
    3. decrease fepo -- one randomly chosen entry fpo[neuronIdx][i] (i < NE)
       that equals 1 is reset to 0;
    4. increase fipo -- fpo[neuronIdx][k] is set to 1 for one randomly drawn
       k with NE <= k < numberOfNeurons and k != neuronIdx.  Skipped when the
       network has no neurons at or above index NE.

    Parameters:
      network   -- object providing connectivityMatrix, fpr, fpo, NE and
                   numberOfNeurons
      neuronIdx -- index of the neuron being modulated
      PSPVar    -- PSP minus threshold; not used by this rule
      kDict     -- rate constants; only kDict['kH_bepo'] is read
      logFile   -- writable file-like object, or '' to disable logging of
                   each altered weight
    """
    # total excitatory input weight (row neuronIdx, columns < NE)
    totalBEPO = 0
    for i in range(network.NE):
        totalBEPO = totalBEPO + network.connectivityMatrix[neuronIdx][i]
    if totalBEPO == 0:
        totalBEPO = 1    # avoid dividing by zero
    # float() keeps oldBEPO/totalBEPO from truncating to 0 when the weights
    # are integers (Python 2 '/' on two ints is floor division)
    totalBEPO = float(totalBEPO)

    # 1. decrease bepo
    for i in range(network.NE):
        oldBEPO = network.connectivityMatrix[neuronIdx][i]
        # remove kH_bepo * (this input's share of the total) of its weight
        delta = kDict['kH_bepo'] * (oldBEPO / totalBEPO) * oldBEPO
        newBEPO = oldBEPO - delta
        if newBEPO < 0.0:
            newBEPO = 0.0
        network.connectivityMatrix[neuronIdx][i] = newBEPO
        if logFile != '':
            logFile.write('Altered ('+str(neuronIdx)+','+str(i)+') from '+str(oldBEPO)+' to '+str(newBEPO)+'\n')

    # 2. increase fpr
    fprIndex = int(uniform(0, network.numberOfNeurons))
    if fprIndex != neuronIdx:
        network.fpr[fprIndex][neuronIdx] = 1    # may already be 1

    # 3. decrease fepo
    fepoAvailable = [i for i in range(network.NE) if network.fpo[neuronIdx][i] == 1]
    if fepoAvailable:
        fepoChoice = fepoAvailable[int(uniform(0, len(fepoAvailable)))]
        network.fpo[neuronIdx][fepoChoice] = 0

    # 4. increase fipo.  uniform(NE, NE) returns NE, which is past the last
    #    column, so skip the draw when there are no inhibitory neurons.
    if network.NE < network.numberOfNeurons:
        fipoIndex = int(uniform(network.NE, network.numberOfNeurons))
        if fipoIndex != neuronIdx:
            network.fpo[neuronIdx][fipoIndex] = 1    # may already be 1

#-----------------------------------------------------------------------#	
# compensation algorithm - modUp					#
#-----------------------------------------------------------------------#	
def modulateUp(network,neuronIdx,PSPVar,kDict,logFile):
    """Compensation step for a neuron whose PSP is below the tolerated band.

    Applied in five steps, all made in place on `network`:

    1. decrease bipo -- each inhibitory input weight
       connectivityMatrix[neuronIdx][i] (NE <= i < numberOfNeurons) is
       reduced by kL_bipo * (weight / row total over those columns) * weight,
       clamped at 0.0;
    2. decrease bpr  -- each weight in column neuronIdx,
       connectivityMatrix[i][neuronIdx], is reduced by
       kL_bpr * (weight / column total) * weight, clamped at 0.0 (the column
       total is taken after step 1);
    3. decrease fpr  -- one randomly chosen i with fpr[i][neuronIdx] == 1 is
       reset to 0;
    4. increase fepo -- fpo[neuronIdx][k] is set to 1 for one randomly drawn
       k < NE, k != neuronIdx; skipped when NE == 0;
    5. decrease fipo -- one randomly chosen entry fpo[neuronIdx][i]
       (i >= NE) that equals 1 is reset to 0.

    Parameters:
      network   -- object providing connectivityMatrix, fpr, fpo, NE and
                   numberOfNeurons
      neuronIdx -- index of the neuron being modulated
      PSPVar    -- PSP minus threshold; not used by this rule
      kDict     -- rate constants; kDict['kL_bipo'] and kDict['kL_bpr'] are read
      logFile   -- writable file-like object, or '' to disable logging of
                   each altered weight
    """
    # 1. decrease bipo (row neuronIdx, inhibitory columns)
    totalBIPO = 0
    for i in range(network.NE, network.numberOfNeurons):
        totalBIPO = totalBIPO + network.connectivityMatrix[neuronIdx][i]
    if totalBIPO == 0:
        totalBIPO = 1    # avoid dividing by zero
    # float() keeps the share from truncating to 0 for integer weights
    # under Python 2 '/'
    totalBIPO = float(totalBIPO)
    for i in range(network.NE, network.numberOfNeurons):
        oldBIPO = network.connectivityMatrix[neuronIdx][i]
        delta = kDict['kL_bipo'] * (oldBIPO / totalBIPO) * oldBIPO
        newBIPO = oldBIPO - delta
        if newBIPO < 0.0:
            newBIPO = 0.0
        network.connectivityMatrix[neuronIdx][i] = newBIPO
        if logFile != '':
            logFile.write('Altered ('+str(neuronIdx)+','+str(i)+') from '+str(oldBIPO)+' to '+str(newBIPO)+'\n')

    # 2. decrease bpr (column neuronIdx)
    totalBPR = 0
    for i in range(network.numberOfNeurons):
        totalBPR = totalBPR + network.connectivityMatrix[i][neuronIdx]
    if totalBPR == 0:
        totalBPR = 1    # avoid dividing by zero
    totalBPR = float(totalBPR)
    for i in range(network.numberOfNeurons):
        oldBPR = network.connectivityMatrix[i][neuronIdx]
        delta = kDict['kL_bpr'] * (oldBPR / totalBPR) * oldBPR
        newBPR = oldBPR - delta
        if newBPR < 0.0:
            newBPR = 0.0
        network.connectivityMatrix[i][neuronIdx] = newBPR
        # log inside the loop so every altered entry is recorded, matching
        # the bipo step above
        if logFile != '':
            logFile.write('Altered ('+str(i)+','+str(neuronIdx)+') from '+str(oldBPR)+' to '+str(newBPR)+'\n')

    # 3. decrease fpr
    fprAvailable = [i for i in range(network.numberOfNeurons) if network.fpr[i][neuronIdx] == 1]
    if fprAvailable:
        fprChoice = fprAvailable[int(uniform(0, len(fprAvailable)))]
        network.fpr[fprChoice][neuronIdx] = 0

    # 4. increase fepo.  With NE == 0, uniform(0, 0) returns 0, which is not
    #    an excitatory column, so skip the draw.
    if network.NE > 0:
        fepoIndex = int(uniform(0, network.NE))
        if fepoIndex != neuronIdx:
            network.fpo[neuronIdx][fepoIndex] = 1    # may already be 1

    # 5. decrease fipo
    fipoAvailable = [i for i in range(network.NE, network.numberOfNeurons)
                     if network.fpo[neuronIdx][i] == 1]
    if fipoAvailable:
        fipoChoice = fipoAvailable[int(uniform(0, len(fipoAvailable)))]
        network.fpo[neuronIdx][fipoChoice] = 0
#-----------------------------------------------------------------------#	
# compensation algorithm	- generate new synapses			#
#-----------------------------------------------------------------------#	
def synapseRecombination(network):
    """Convert every matching pair of free elements into a synapse.

    For each ordered pair (row, col) -- visited row-major -- where both
    fpr[col][row] and fpo[row][col] equal 1, the two free elements are
    consumed (reset to 0) and connectivityMatrix[row][col] is given a fresh
    weight from network.getConnectionStrength().
    """
    n = network.numberOfNeurons
    for row in range(n):
        for col in range(n):
            # skip unless both free elements of this pair are present
            if network.fpr[col][row] != 1 or network.fpo[row][col] != 1:
                continue
            network.fpr[col][row] = 0
            network.fpo[row][col] = 0
            network.connectivityMatrix[row][col] = network.getConnectionStrength()

#-----------------------------------------------------------------------#	
# original algorithm (based on Dammasch, 1986)				#
#-----------------------------------------------------------------------#		
def binaryStabilizingMorphogenesis(network,currentTime,stableRange=.1,k_Dict=DEFAULT_K_DICT):
    """Binary-element variant of the compensation rule (after Dammasch, 1986).

    A deviation is computed for every neuron.  For a neuron with
    refractoryStatusArray[i] == 1 and a non-empty PSPHistory[i], it is the
    last recorded PSP minus threshold.  For any other neuron, it is +threshold.
    Each deviation is compared with the symmetric band
    +/- threshold * stableRange / 2:

      above the band -> binaryModulateDown with the 'kH_' constants
      below the band -> binaryModulateUp   with the 'kL_' constants
      inside         -> unchanged

    Finally synapseRecombination(network) forms new synapses.

    Parameters:
      network     -- network object, modified in place
      currentTime -- not used; kept for interface compatibility
      stableRange -- full width of the tolerance band as a fraction of the
                     threshold
      k_Dict      -- rate constants (read only, never mutated)
    """
    t = network.threshold
    halfBand = t * (stableRange / 2.0)

    pspVar = []
    for i in range(network.numberOfNeurons):
        if network.refractoryStatusArray[i] == 1 and len(network.PSPHistory[i]) > 0:
            pspVar.append(network.PSPHistory[i][-1] - t)
        else:
            # NOTE(review): a neuron without a usable PSP gets deviation +t,
            # which the loop below classifies as 'high'; confirm intended.
            pspVar.append(t)

    for i, var in enumerate(pspVar):
        if var > halfBand:          # 'high'-type
            binaryModulateDown(network, i, var, k_Dict['kH_bepo'], k_Dict['kH_fpr'],
                               k_Dict['kH_fepo'], k_Dict['kH_fipo'])
        elif var < -halfBand:       # 'low'-type: lower edge of the band
            binaryModulateUp(network, i, var, k_Dict['kL_bpr'], k_Dict['kL_bipo'],
                             k_Dict['kL_fpr'], k_Dict['kL_fepo'], k_Dict['kL_fipo'])
        # otherwise: tolerable response
    synapseRecombination(network)
#-----------------------------------------------------------------------#	
# original algorithm (based on Dammasch, 1986)				#
#-----------------------------------------------------------------------#		
def binaryModulateDown(network,neuronIdx,pspVar,kH_bepo,kH_fpr,kH_fepo,kH_fipo):
    """Binary (0/1) compensation step for a neuron classified as 'high'.

    Steps, all made in place on `network`:

    1. decrease bepo -- pick one random i < NE with bpo[neuronIdx][i] == 1.
       Clear bpo[neuronIdx][i] and bpr[i][neuronIdx], set fpr[i][neuronIdx]
       to 1 (the element is released as free), and zero
       connectivityMatrix[neuronIdx][i];
    2. increase fpr  -- set fpr[k][neuronIdx] = 1 for one random k
       (no change when k == neuronIdx);
    3. decrease fepo -- clear one random fpo[neuronIdx][i] == 1 with i < NE;
    4. increase fipo -- set fpo[neuronIdx][k] = 1 for one random k >= NE,
       k != neuronIdx; skipped when there are no neurons at or above NE.

    pspVar and the kH_* rates are accepted for interface compatibility but
    are not used by this rule.
    """
    # 1. decrease bepo
    bepoAvailable = [i for i in range(network.NE) if network.bpo[neuronIdx][i] == 1]
    if bepoAvailable:
        bepoChoice = bepoAvailable[int(uniform(0, len(bepoAvailable)))]
        network.bpo[neuronIdx][bepoChoice] = 0
        network.bpr[bepoChoice][neuronIdx] = 0
        # released free (excitatory) presynaptic element
        network.fpr[bepoChoice][neuronIdx] = 1
        network.connectivityMatrix[neuronIdx][bepoChoice] = 0

    # 2. increase fpr
    fprIndex = int(uniform(0, network.numberOfNeurons))
    if fprIndex != neuronIdx:
        network.fpr[fprIndex][neuronIdx] = 1    # may already be 1

    # 3. decrease fepo
    fepoAvailable = [i for i in range(network.NE) if network.fpo[neuronIdx][i] == 1]
    if fepoAvailable:
        fepoChoice = fepoAvailable[int(uniform(0, len(fepoAvailable)))]
        network.fpo[neuronIdx][fepoChoice] = 0

    # 4. increase fipo.  uniform(NE, NE) returns NE, which is past the last
    #    column, so skip the draw when there are no inhibitory neurons.
    if network.NE < network.numberOfNeurons:
        fipoIndex = int(uniform(network.NE, network.numberOfNeurons))
        if fipoIndex != neuronIdx:
            network.fpo[neuronIdx][fipoIndex] = 1    # may already be 1

#-----------------------------------------------------------------------#	
# original algorithm (based on Dammasch, 1986)				#
#-----------------------------------------------------------------------#			

def binaryModulateUp(network,neuronIdx,pspVar,kL_bpr,kL_bipo,kL_fpr,kL_fepo,kL_fipo):
    """Binary (0/1) compensation step for a neuron classified as 'low'.

    Steps, all made in place on `network`:

    1. decrease bipo -- pick one random i >= NE with bpo[neuronIdx][i] == 1.
       Clear bpo[neuronIdx][i] and bpr[i][neuronIdx], set fpr[i][neuronIdx]
       to 1 (the element is released as free), and zero
       connectivityMatrix[neuronIdx][i];
    2. decrease bpr  -- pick one random i with bpr[i][neuronIdx] == 1.
       Clear it and bpo[neuronIdx][i], and zero connectivityMatrix[neuronIdx][i];
    3. decrease fpr  -- clear one random fpr[i][neuronIdx] == 1;
    4. increase fepo -- set fpo[neuronIdx][k] = 1 for one random k < NE,
       k != neuronIdx; skipped when NE == 0;
    5. decrease fipo -- clear one random fpo[neuronIdx][i] == 1 with i >= NE.

    pspVar and the kL_* rates are accepted for interface compatibility but
    are not used by this rule.
    """
    # 1. decrease bipo
    bipoAvailable = [i for i in range(network.NE, network.numberOfNeurons)
                     if network.bpo[neuronIdx][i] == 1]
    if bipoAvailable:
        # pick from bipoAvailable (the original referenced the undefined
        # name bepoAvailable here and raised NameError)
        bipoChoice = bipoAvailable[int(uniform(0, len(bipoAvailable)))]
        network.bpo[neuronIdx][bipoChoice] = 0
        network.bpr[bipoChoice][neuronIdx] = 0
        # released free (inhibitory) presynaptic element
        network.fpr[bipoChoice][neuronIdx] = 1
        network.connectivityMatrix[neuronIdx][bipoChoice] = 0

    # 2. decrease bpr
    # NOTE(review): this indexes bpr[i][neuronIdx], i.e. the same orientation
    # as the bipo step above (a synapse onto neuronIdx), whereas the
    # continuous modulateUp weakens column neuronIdx of connectivityMatrix;
    # confirm which orientation is intended.
    bprAvailable = [i for i in range(network.numberOfNeurons) if network.bpr[i][neuronIdx] == 1]
    if bprAvailable:
        bprChoice = bprAvailable[int(uniform(0, len(bprAvailable)))]
        network.bpr[bprChoice][neuronIdx] = 0
        network.bpo[neuronIdx][bprChoice] = 0
        network.connectivityMatrix[neuronIdx][bprChoice] = 0

    # 3. decrease fpr
    fprAvailable = [i for i in range(network.numberOfNeurons) if network.fpr[i][neuronIdx] == 1]
    if fprAvailable:
        fprChoice = fprAvailable[int(uniform(0, len(fprAvailable)))]
        network.fpr[fprChoice][neuronIdx] = 0

    # 4. increase fepo.  With NE == 0, uniform(0, 0) returns 0, which is not
    #    an excitatory column, so skip the draw.
    if network.NE > 0:
        fepoIndex = int(uniform(0, network.NE))
        if fepoIndex != neuronIdx:
            network.fpo[neuronIdx][fepoIndex] = 1

    # 5. decrease fipo
    fipoAvailable = [i for i in range(network.NE, network.numberOfNeurons)
                     if network.fpo[neuronIdx][i] == 1]
    if fipoAvailable:
        fipoChoice = fipoAvailable[int(uniform(0, len(fipoAvailable)))]
        network.fpo[neuronIdx][fipoChoice] = 0
#-----------------------------------------------------------------------#
def programmingPoint(params=None):
    """Interactive debugging hook: execute Python lines typed at a prompt.

    Each non-empty line read from standard input is executed with exec in
    this function's scope, so the typed code can refer to the local name
    `params`.  An empty line, or end of input, ends the session.

    SECURITY NOTE: this executes arbitrary code read from stdin.  Use it only
    for local interactive debugging, never with untrusted input.

    Parameters:
      params -- object made available to the typed code as `params`;
                defaults to a fresh empty list
    """
    if params is None:
        params = []    # fresh list per call instead of a shared default
    try:
        readLine = raw_input    # Python 2
    except NameError:
        readLine = input        # Python 3
    while True:
        try:
            x = readLine("Enter python code to execute: ")
        except EOFError:
            break               # stdin closed: end the session
        if x == "":
            break
        exec(x)