#!/usr/bin/env python

from SpikingNetwork import SpikingNetwork    
from SpikeExploreGUI import NetGUI    
#import sys
#import time
from random import random
try:
    import numpy as N
    import pygame
    import pygame.surfarray as surfarray
    from pygame.locals import *
except ImportError:
    raise ImportError, "numpy required."

# devNote: for presentation purposes, allow passing initial conditions (to recreate the adjacency matrix and spike train graphically)

################ MAIN ###################################

def wordProbs(data):
    """Return the empirical probability of each distinct row ("word") in data.

    data: 2-D array; each row data[i, :] is one observation.

    Returns a dict mapping str(row) to the fraction of rows equal to that row.
    An array with no rows yields an empty dict.
    """
    numRows = data.shape[0]
    counts = {}
    for i in range(numRows):
        # str() of the row is used as a hashable key identifying the word.
        # NOTE(review): numpy summarizes very long arrays with '...' in str(),
        # so rows longer than numpy's print threshold could collide.
        word = str(data[i, :])
        counts[word] = counts.get(word, 0) + 1
    # Divide once per word: count / numRows is correctly rounded, whereas
    # adding 1/numRows repeatedly accumulates error (ten 0.1s != 1.0).
    return {word: float(count) / numRows for word, count in counts.items()}

def entropy(data):
    """Return the Shannon entropy, in bits, of the rows of data.

    Each distinct row of the 2-D array data is one symbol; its probability is
    the empirical frequency computed by wordProbs.
    """
    # list() makes this work whether dict.values() is a list (Python 2) or a
    # view (Python 3, which numpy cannot turn into a numeric array).
    probs = N.array(list(wordProbs(data).values()), float)
    return -N.sum(probs * N.log2(probs))

def mutualInfo(data, node1, node2):
    """Return the mutual information I(X;Y), in bits, between two columns.

    Computed as H(X) + H(Y) - H(X,Y), each entropy estimated from empirical
    row frequencies by entropy().

    data: 2-D array with one column per node.
    node1, node2: column indices of the two nodes.
    """
    hFirst = entropy(N.atleast_2d(data[:, node1]).T)
    hSecond = entropy(N.atleast_2d(data[:, node2]).T)
    hJoint = entropy(data[:, [node1, node2]])
    return hFirst + hSecond - hJoint

def pairwiseMI(data):
    """Return the matrix of pairwise mutual information between columns.

    data: 2-D array, one column per node.

    Returns a (numNodes, numNodes) float array with
    MI[i, j] = H(i) + H(j) - H(i, j) in bits. The diagonal equals each
    node's own entropy H(i).
    """
    numNodes = data.shape[1]
    entropies = N.zeros((numNodes, numNodes), float)

    for i in range(numNodes):
        # Diagonal: entropy of node i alone.
        entropies[i, i] = entropy(N.atleast_2d(data[:, i]).T)
        # Joint entropy does not depend on column order, so compute each
        # unordered pair once and mirror it.
        for j in range(i):
            entropies[i, j] = entropies[j, i] = entropy(data[:, [i, j]])

    selfH = N.diag(entropies)
    # Broadcast H(i) + H(j) over all pairs, then subtract the joint entropies.
    return selfH[:, None] + selfH[None, :] - entropies
    
def shiftedJointEntropies(data, maxShift):
    """Return joint entropies of node pairs at relative time lags.

    data: 2-D array, one row per time step and one column per node.
    maxShift: largest lag (in time steps) examined, in both directions.

    Returns an array H of shape (2*maxShift+1, numNodes, numNodes). For lag s
    in [-maxShift, maxShift] and j < i, H[s + maxShift, i, j] is the joint
    entropy (bits) of node i at time t paired with node j at time t+s, over
    every t where both samples exist. Entries with j >= i (diagonal and upper
    triangle) are left at 0.
    """
    numSteps, numNodes = data.shape
    sEntropies = N.zeros((maxShift * 2 + 1, numNodes, numNodes), float)
    for s in range(-maxShift, maxShift + 1):
        # Overlapping windows, each of length numSteps - |s|: node i is read
        # from [i_min, i_max) and node j from [j_min, j_max), offset by s.
        j_min = max(0, s)
        j_max = min(numSteps, numSteps + s)
        i_min = max(0, -s)
        i_max = min(numSteps, numSteps - s)
        for i in range(numNodes):
            for j in range(i):
                pair = N.vstack((data[i_min:i_max, i],
                                 data[j_min:j_max, j])).T
                sEntropies[s + maxShift, i, j] = entropy(pair)
    return sEntropies
    
def selfEntropies(data):
    """Return the entropy (bits) of each column of data taken on its own.

    data: 2-D array, one column per node.
    Returns a 1-D float array of length data.shape[1].
    """
    columnH = [entropy(N.atleast_2d(data[:, k]).T)
               for k in range(data.shape[1])]
    return N.array(columnH, float)
            
def shiftedPairwiseMI(data, maxShift):
    """Return pairwise mutual information between nodes at relative time lags.

    data: 2-D array, one row per time step and one column per node.
    maxShift: largest lag (in time steps) examined, in both directions.

    Returns MI of shape (2*maxShift+1, numNodes, numNodes). For lag s in
    [-maxShift, maxShift] and j < i:
        MI[s + maxShift, i, j] = H(i) + H(j) - H(i at t, j at t+s)   (bits)
    Entries with j >= i (diagonal and upper triangle) are 0, so
    N.argmax(MI, 0) - maxShift gives the best lag for each lower-triangle pair.

    NOTE: H(i) and H(j) use the full series, while the joint entropy uses only
    the overlapping window, so values at large lags are approximate.
    """
    numNodes = data.shape[1]
    MI = N.zeros((maxShift * 2 + 1, numNodes, numNodes), float)

    selfH = selfEntropies(data)
    shiftH = shiftedJointEntropies(data, maxShift)
    # Strictly-lower-triangle (j < i) pairs: the only ones shiftH fills in.
    rows, cols = N.tril_indices(numNodes, -1)
    # Axis 0 of both arrays is indexed by lag + maxShift; fill every lag at once.
    MI[:, rows, cols] = (selfH[rows] + selfH[cols]) - shiftH[:, rows, cols]
    return MI

network_data = True   # simulate SpikingNetwork(s) and analyze their spike history
test_data = False     # otherwise, analyze synthetic random data
single_run = False    # network_data: one long run (True) or a batch of short runs
random_adj = False    # single_run: random adjacency (True) or the fixed matrix below
visualize_it = True   # draw the network and spike train with NetGUI

if network_data:

    nNodes = 6
    historySize = 2000
    pEdge = .65  # .8
    pSpikingInit = 1
    stimulusVal = .5001
    decayConstant = .01

    if single_run:

        maxShift = 3         # largest lag (time steps) examined by shiftedPairwiseMI
        transientSteps = 20  # leading time steps dropped before the MI analysis

        # construct/initialize network
        if random_adj:
            net = SpikingNetwork(nNodes, historySize, pEdge, pSpikingInit,
                                 stimulusVal, decayConstant)
        else:
            # Fixed 5x5 adjacency passed where nNodes would otherwise go;
            # presumably SpikingNetwork takes the network size from it -- TODO confirm.
            adj = N.array([[1, 1, 0, 0, 0],
                           [1, 1, 1, 0, 0],
                           [0, 0, 0, 1, 0],
                           [0, 0, 0, 0, 1],
                           [0, 0, 0, 0, 0]])
            net = SpikingNetwork(adj, historySize, pEdge, pSpikingInit,
                                 stimulusVal, decayConstant)

        if visualize_it:
            # initialize display stuff
            gui = NetGUI(net, 950, 950, 40)
            gui.update()
            pygame.display.flip()

        # generate data
        while net.currentTimeStep() < historySize - 1:
            net.update()

        if visualize_it:
            gui.update()

        data = net.spikeHist
        sPMI = shiftedPairwiseMI(data[transientSteps:, :], maxShift)
        # Axis 0 of sPMI is indexed by lag + maxShift, so this prints the lag
        # with the largest MI for each node pair.
        print(N.argmax(sPMI, 0) - maxShift)
        # MI values rounded to 3 decimal places.
        print(N.around(1000 * sPMI) / 1000)

    else:  # batch run

        nNodes = 6
        historySize = 41
        pSpikingInit = 1
        stimulusVal = .5001
        decayConstant = .05

        numTrials = 100
        numShown = 15  # trials displayed from each end of the entropy ranking
        # visualize_it is taken from the configuration flags above.

        allAdj = N.zeros((numTrials, nNodes, nNodes), bool)
        allSpikes = N.zeros((numTrials, historySize, nNodes), bool)
        allH = N.zeros(numTrials, float)

        for i in range(numTrials):
            print('Trial %(n)d' % {'n': i})

            # Each trial draws its own edge probability.
            pEdge = N.random.uniform(.25, .75)
            net = SpikingNetwork(nNodes, historySize, pEdge, pSpikingInit,
                                 stimulusVal, decayConstant)
            # generate data
            while net.currentTimeStep() < historySize - 1:
                net.update()
            allAdj[i][:, :] = net.adjacency
            allSpikes[i] = net.spikeHist
            # allH[i] = entropy(net.spikeHist[50:, :])  # check non-transients
            allH[i] = entropy(net.spikeHist)  # check beginning transients

        # Trial indices ordered from highest to lowest entropy.
        rankedH = N.argsort(allH)[::-1]

        if visualize_it:
            # initialize display stuff (built from the last trial's network;
            # each displayed trial's state is swapped into it below)
            gui = NetGUI(net, 950, 950, 40)
            gui.update()
            pygame.display.flip()

        # Highest-entropy trials: report edge count and entropy.
        for i in rankedH[:numShown]:
            print('%(n)d of %(m)d possible edges'
                  % {'n': allAdj[i].sum(), 'm': nNodes * nNodes})
            print('Entropy: %(n)g' % {'n': allH[i]})

            if visualize_it:
                net.adjacency = allAdj[i]
                net.spikeHist = allSpikes[i]
                gui.updateAdj()
                gui.update()

            raw_input()  # wait for Enter before showing the next trial

        # Lowest-entropy trials, including the very lowest one.
        for i in rankedH[-numShown:]:
            print(allH[i])

            if visualize_it:
                net.adjacency = allAdj[i]
                net.spikeHist = allSpikes[i]
                gui.updateAdj()
                gui.update()

            raw_input()  # wait for Enter before showing the next trial

elif test_data:
    # Synthetic test data. These sizes are otherwise only set in the
    # network_data branch, so they must be defined here as well.
    nNodes = 6
    historySize = 2000
    # nNodes sparse columns (P(spike) = .2) plus one denser column (P(spike) = .5).
    data = (N.random.rand(historySize, nNodes) > .8)
    data2 = (N.random.rand(historySize, 1) > .5)
    data = N.append(data, data2, 1)

    H = selfEntropies(data)
    print(N.argsort(H))  # column indices ordered by increasing entropy

    # Sanity checks:
    #   data[:, -1] = data[:, 0]        -> pairwiseMI(data)[0, -1] should equal H[0]
    #   data[0:-2, -1] = data[1:-1, 0]  -> a time-shifted copy; should instead
    #                                      peak at a nonzero lag in shiftedPairwiseMI
    # print(pairwiseMI(data))
    # print(shiftedJointEntropies(data, 1))
    # print(shiftedPairwiseMI(data, 1))