# devNote: 'randomize adjacency' functionality
# devNote: histogram of degrees
# note: mean degree is calculated as ONLY mean out degree (= mean in degree)

import numpy as N
from numpy.random import random


class SpikingNetwork(object):
    """Discrete-time network of leaky threshold ("spiking") nodes.

    Every node holds a potential V.  On each call to update():

      1. the stimulus computed during the previous step is added to V,
      2. nodes whose V is strictly greater than `threshold` fire and are
         set to `resetVal`,
      3. every node j is scheduled to receive, on the next step,
         stimulusVal * (sum of adjacency[i, j] over the firing nodes i),
      4. V is multiplied by (1 - decayConstant),
      5. the firing pattern is written to row `timeStep` of spikeHist.

    adjacency[origin, destination] != 0 means an edge origin -> destination.

    Attributes:
      adjacency      square (nNodes x nNodes) array of edges
      nNodes         number of nodes
      currentVs      current potential of every node (float array)
      nextStimulus   stimulus that will be added at the start of the next update
      spikeHist      (hLength x nNodes) int array, row t = firing pattern at step t
      nodeHist       (hLength x nNodes) float array; allocated here but only
                     written by the commented-out interactive code in update()
      historyLength  hLength as passed to the constructor
      timeStep       index of the last completed update (-1 before the first)
      meanDegree     number of edges / nNodes (mean out degree = mean in degree)
    """

    #NOTE: right now visualization relies on threshold = 1!!! (for efficiency)
    def __init__(self, nodeInfo, hLength, pEdge=.9, pSpikingInit=.8, stimulusVal=.3, decayConstant=.05, threshold=1, resetVal=0.):
        """Build the network and draw the initial potentials.

        nodeInfo      either a node count (an adjacency matrix is then drawn
                      with edge probability pEdge) or a square adjacency matrix
        hLength       number of time steps the history arrays can hold
        pEdge         probability of each directed edge (self loops included)
                      when the adjacency is generated
        pSpikingInit  probability that a node starts just above threshold
        stimulusVal   stimulus delivered along an edge of weight 1 by a firing node
        decayConstant fraction of the potential lost on every update
        threshold     a node fires when its potential exceeds this value
        resetVal      potential assigned to a node when it fires

        Raises ValueError if a supplied adjacency is not a square 2-D array.
        """
        self.pEdge = pEdge
        self.stimulusVal = stimulusVal
        self.threshold = threshold
        self.resetVal = resetVal
        self.decayConstant = decayConstant

        #nodeInfo can either be a number of nodes which uses pEdge to construct adjacency or the actual adjacency matrix
        if isinstance(nodeInfo, (int, N.integer)):
            self.nNodes = int(nodeInfo)
            self.adjacency = self.__generateAdjacency(self.nNodes, pEdge)
        else:
            # asarray hands back the very same object for a plain ndarray, so
            # toggleEdge() still edits the caller's matrix in place.
            adjacency = N.asarray(nodeInfo)
            if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
                raise ValueError('adjacency must be a square 2-D array, got shape %s'
                                 % (adjacency.shape,))
            self.adjacency = adjacency
            self.nNodes = adjacency.shape[0]
            self._meanDegree = self.__meanDegreeOf(adjacency)

        self.currentVs = self.__initCurrentVs(self.nNodes, threshold, pSpikingInit)
        self.nextStimulus = N.zeros_like(self.currentVs)
        self.nodeHist = N.zeros((hLength, self.nNodes), float)
        self.spikeHist = N.zeros((hLength, self.nNodes), int)
        self.historyLength = hLength

        # update() increments before use, so the first step is time step 0
        self.timeStep = -1

    ################ STATIC PROPERTIES ######################

    @staticmethod
    def __meanDegreeOf(adjacency):
        """Return (number of non-zero entries) / (number of nodes); 0.0 if empty."""
        nNodes = adjacency.shape[0]
        if nNodes == 0:
            return 0.0
        return N.count_nonzero(adjacency) / float(nNodes)

    def __generateAdjacency(self, nNodes, pEdge):
        """Draw a random nNodes x nNodes 0/1 float adjacency matrix.

        Each entry is 1 with probability pEdge.  The matrix is filled in
        row-major order from numpy's global random stream, i.e. in the same
        order as a nested i/j loop calling random() once per entry.
        Also stores the resulting mean degree.
        """
        adj = (random((nNodes, nNodes)) < pEdge).astype(float)
        self._meanDegree = self.__meanDegreeOf(adj)
        return adj

    def toggleEdge(self, origin, destination):
        """Flip the edge origin -> destination and keep meanDegree in sync."""
        nextVal = not self.adjacency[origin, destination]
        self.adjacency[origin, destination] = nextVal
        if nextVal:
            self._meanDegree += 1. / self.nNodes
        else:
            self._meanDegree -= 1. / self.nNodes

    def size(self):
        """Return the number of nodes."""
        return self.nNodes

    # Exposed as a property rather than a method: the value used to live in an
    # instance attribute of the same name, which shadowed the method (calling
    # it failed) and was missing altogether when an adjacency matrix was
    # passed in.  Reading `net.meanDegree` keeps working as before.
    @property
    def meanDegree(self):
        """Mean out degree (= mean in degree): number of edges / nNodes."""
        return self._meanDegree

    @meanDegree.setter
    def meanDegree(self, value):
        self._meanDegree = value

    def currentTimeStep(self):
        """Return the index of the last completed update (-1 before the first)."""
        return self.timeStep

    ################ DYNAMIC PROPERTIES ######################

    def __initCurrentVs(self, nNodes, threshold, pSpikingInit):
        """Draw the initial potentials.

        With probability pSpikingInit a node starts just above threshold (it
        fires on the first update); otherwise it starts uniformly in
        [0, threshold).
        """
        V = N.zeros(nNodes, float)
        # Kept as an explicit loop: the second random() is only drawn for the
        # non-spiking nodes, so vectorising would consume the random stream
        # differently and change results for a fixed seed.
        for i in range(nNodes):
            if random() < pSpikingInit:
                V[i] = threshold + .0001
            else:
                V[i] = random() * threshold
        return V

    def update(self):
        """Advance the network by one time step (see the class docstring).

        Raises IndexError once more than historyLength updates have been run,
        because spikeHist has no row left for the new time step.
        """
        self.timeStep = self.timeStep + 1
        self.currentVs += self.nextStimulus

        firing = self.currentVs > self.threshold
        self.currentVs[firing] = self.resetVal

        # Column sums over the rows of the firing nodes: entry j is the total
        # edge weight from firing nodes into node j.  All zeros if nothing fired.
        drive = self.adjacency[firing].sum(axis=0) * self.stimulusVal
        self.nextStimulus = N.asarray(drive, dtype=self.currentVs.dtype)

        self.currentVs -= self.currentVs * self.decayConstant

# use for interactive version
#        self.nodeHist[1:,:] = self.nodeHist[0:-1,:] #shift first n-1 rows down
#        self.nodeHist[0,:] = self.currentVs #copy the current state into the first row

#        self.spikeHist[1:,:] = self.spikeHist[0:-1,:] #shift first n-1 rows down
#        self.spikeHist[0,:] = firing #copy the current state into the first row

        self.spikeHist[self.timeStep, :] = firing #fill the spike history array

    def applyStimiulus(self):
        """Replace the pending stimulus with uniform random values in [0, 1).

        The new stimulus is printed and takes effect on the next update().
        (The misspelled name is kept for existing callers; see applyStimulus.)
        """
        print('applying stimulus!')
        self.nextStimulus = N.random.random(self.currentVs.shape[0])
        print(self.nextStimulus)

    # correctly spelled alias
    applyStimulus = applyStimiulus

    ################ UTIL ######################

    def __repr__(self):
        """Return a string showing size, time step and current potentials."""
        return '%s(nNodes=%d, timeStep=%d, currentVs=%r)' % (
            type(self).__name__, self.nNodes, self.timeStep, self.currentVs)
