import pygame
from pygame.locals import *
from pygame import draw
from pygame import color
    
class NetGUI:
    """Pygame front end for interactively exploring a spiking network.

    The window is split into regions:

    * left half        -- history view: one column per node, one row per
                          time step (node voltages and/or spikes);
    * top-right square -- adjacency matrix; clicking a cell toggles that edge;
    * bottom-right     -- ``controlX``/``controlY`` extents, not drawn or read
                          anywhere in this class.

    The ``net`` object must provide ``nNodes``, ``adjacency``, ``nodeHist``,
    ``spikeHist``, ``toggleEdge(i, j)`` and ``applyStimiulus()``; those are the
    only members used here.  ``nodeHist``/``spikeHist`` are sliced as
    ``h[a:b, :]`` with rows = time steps and columns = nodes (presumably numpy
    arrays -- TODO confirm).
    """

    def __init__(self, net, W=1000., H=1000., visHistLength=40):
        """Open the window, lay out the regions and draw the adjacency matrix.

        :param net: network to visualise (see class docstring).
        :param W: window width in pixels.
        :param H: window height in pixels.
        :param visHistLength: number of time steps shown in the history view.
            ``__render`` takes the rows ``[-visHistLength-1:-1]`` of the net's
            history, so the net must keep more than ``visHistLength`` rows.
        """
        pygame.init()
        # set_mode expects integer pixel sizes; the defaults are floats.
        self.screen = pygame.display.set_mode((int(W), int(H)))
        pygame.display.set_caption('Spiking Networks Interactive Explorer')
        #pygame.display.toggle_fullscreen()

        # Visual history length; vHL is the short alias used when slicing.
        self.visHistLength = visHistLength
        self.vHL = visHistLength

        # (start, end) pixel extents of every region along x and y.
        self.histX = 0., W / 2.
        self.histY = 0., H
        self.adjX = W / 2., W
        self.adjY = 0., H / 2.
        self.controlX = W / 2., W
        self.controlY = H / 2., H

        # Cell sizes: the history view holds nNodes columns x visHistLength
        # rows; the adjacency matrix holds nNodes x nNodes square cells.
        self.hSize = (self.histX[1] - self.histX[0],
                      self.histY[1] - self.histY[0])
        self.histCellSize = min(self.hSize[0] / net.nNodes,
                                self.hSize[1] / self.visHistLength)
        self.aSize = (self.adjX[1] - self.adjX[0],
                      self.adjY[1] - self.adjY[0])
        self.adjCellSize = min(self.aSize[0] / net.nNodes,
                               self.aSize[1] / net.nNodes)

        # Shrink the adjacency region to a square that exactly fits its cells.
        side = min(self.aSize)
        self.aSize = side, side
        self.adjX = self.adjX[0], self.adjX[0] + self.aSize[0]
        self.adjY = self.adjY[0], self.adjY[0] + self.aSize[1]

        self.net = net

        self.adjColorMap = self.__BWColorMap2
        self.spikeColorMap = self.__BWColorMap2
        self.nodeColorMap = self.__BWColorMap2
        self.spikeAndNodeColorMap = self.__RedColorMap2

        # The adjacency matrix is only redrawn when it changes, so draw it now
        # and push it to the display (update() refreshes only the history).
        self.__dispAdj(self.net.adjacency, self.net.nNodes, self.screen,
                       self.adjCellSize)
        pygame.display.update(self.__adjRect())

        # Default history view shows spikes; __toggleHistDisp cycles it.
        self.__dispHist = self.__dispSpikeHist

    # ------------------------------------------------------------ geometry

    def __adjRect(self):
        """Return the adjacency region as an ``(x, y, w, h)`` rect."""
        return (self.adjX[0], self.adjY[0], self.aSize[0], self.aSize[1])

    def __histRect(self):
        """Return the history region as an ``(x, y, w, h)`` rect."""
        return (self.histX[0], self.histY[0], self.hSize[0], self.hSize[1])

    def __histCellRect(self, i, j, histCellSize):
        """Return the rect of history cell (time row ``i``, node column ``j``).

        Cells are one pixel larger than the pitch so adjacent fills overlap
        instead of leaving seams from fractional sizes.
        """
        return [self.histX[0] + j * histCellSize,
                self.histY[0] + i * histCellSize,
                histCellSize + 1, histCellSize + 1]

    @staticmethod
    def __gridFlag(i, j):
        """Return 1 on every 10th row or column (tinted as a grid line), else 0."""
        return 1 if (i + 1) % 10 == 0 or (j + 1) % 10 == 0 else 0

    # -------------------------------------------------------------- drawing

    def update(self):
        """Redraw the history view and refresh only that part of the screen."""
        self.__render()
        pygame.display.update(self.__histRect())

    def updateAdj(self):
        """Redraw the adjacency matrix from ``net.adjacency``; refresh the screen."""
        self.__dispAdj(self.net.adjacency, self.net.nNodes, self.screen,
                       self.adjCellSize)
        pygame.display.update()

    def __render(self):
        """Draw the latest ``vHL`` history rows with the active history view.

        The slice ``[-vHL-1:-1]`` excludes the very last stored row
        (presumably the one still being filled -- TODO confirm).
        """
        self.__dispHist(self.net.nodeHist[-self.vHL - 1:-1, :],
                        self.net.spikeHist[-self.vHL - 1:-1, :],
                        self.visHistLength, self.net.nNodes, self.screen,
                        self.histCellSize)

    def __BWColorMap(self, x):
        """Map ``x`` (clamped to [0, 1]) to a grey level."""
        x = max(0., min(x, 1))
        return int(255 * x), int(255 * x), int(255 * x)

    def __BWColorMap2(self, x, flag):
        """Map ``x`` (clamped to [0, 1]) to grey; grid cells (``flag=1``) get a
        blue channel of ``|255x - 255|`` instead of ``255x``.

        Clamping the lower bound keeps negative inputs from producing negative
        RGB components, which pygame rejects.
        """
        x = max(0., min(x, 1))
        return int(255 * x), int(255 * x), int(abs(255 * x - 255 * flag))

    def __RedColorMap2(self, x, flag):
        """Map ``x`` (clamped to [0, 1]) to red; at ``x=1`` normal cells are
        pure red and grid cells (``flag=1``) are magenta."""
        x = max(0., min(x, 1))
        flag = not flag
        return int(255 * x), int(0), int(abs(255 * x - 255 * flag))

    def __dispAdj(self, adj, nNodes, surface, adjCellSize):
        """Draw the adjacency matrix and a yellow border around it.

        Entry ``adj[i][j]`` is drawn at column ``i``, row ``j`` -- the same
        convention ``posToNode`` uses when mapping clicks back to cells.
        """
        for i in range(nNodes):
            for j in range(nNodes):
                surface.fill(self.adjColorMap(adj[i][j], self.__gridFlag(i, j)),
                             [self.adjX[0] + i * adjCellSize,
                              self.adjY[0] + j * adjCellSize,
                              adjCellSize + 1, adjCellSize + 1])
        draw.rect(surface, (255, 255, 0), self.__adjRect(), 3)

    def __dispNodeHist(self, nHist, sHist, hSize, nNodes, surface, histCellSize):
        """Draw ``hSize`` rows of node values using ``nodeColorMap``."""
        for i in range(hSize):
            for j in range(nNodes):
                surface.fill(self.nodeColorMap(nHist[i][j], self.__gridFlag(i, j)),
                             self.__histCellRect(i, j, histCellSize))

    def __dispSpikeHist(self, nHist, sHist, hSize, nNodes, surface, histCellSize):
        """Draw ``hSize`` rows of spike values using ``spikeColorMap``."""
        for i in range(hSize):
            for j in range(nNodes):
                surface.fill(self.spikeColorMap(sHist[i][j], self.__gridFlag(i, j)),
                             self.__histCellRect(i, j, histCellSize))

    def __dispSpikeAndNodeHist(self, nHist, sHist, hSize, nNodes, surface,
                               histCellSize):
        """Draw node values, overlaying spiking cells with ``spikeAndNodeColorMap``."""
        for i in range(hSize):
            for j in range(nNodes):
                flag = self.__gridFlag(i, j)
                if sHist[i][j]:
                    colour = self.spikeAndNodeColorMap(1, flag)
                else:
                    colour = self.nodeColorMap(nHist[i][j], flag)
                surface.fill(colour, self.__histCellRect(i, j, histCellSize))

    def __toggleHistDisp(self):
        """Cycle the history view: node -> spike -> spike+node -> node."""
        if self.__dispHist == self.__dispNodeHist:
            self.__dispHist = self.__dispSpikeHist
        elif self.__dispHist == self.__dispSpikeHist:
            self.__dispHist = self.__dispSpikeAndNodeHist
        else:
            self.__dispHist = self.__dispNodeHist

    ################ EVENTS #################################

    def handleEvents(self, paused):
        """Process all pending pygame events.

        * window close / Esc                -> stop running
        * click in adjacency matrix         -> toggle that edge, redraw matrix
        * right click (3) in history        -> ``net.applyStimiulus()``
        * middle click (2) in history       -> toggle pause
        * any other button in history       -> cycle the history view
                                               (redrawn at once when paused)

        :param paused: current pause state.
        :returns: ``(running, paused)``.
        """
        running = True
        for event in pygame.event.get():
            if event.type == QUIT:
                running = False
            if event.type == KEYDOWN and event.key == K_ESCAPE:
                running = False
            if event.type == MOUSEBUTTONDOWN:
                # event.pos is where the click happened; mouse.get_pos() would
                # report wherever the pointer is when the queue is drained.
                section, cell = self.posToNode(event.pos)
                if section == 'adj':
                    # NOTE(review): every button toggles an edge, including
                    # wheel scrolling if it arrives as buttons 4/5.
                    self.net.toggleEdge(cell[0], cell[1])
                    self.__dispAdj(self.net.adjacency, self.net.nNodes,
                                   self.screen, self.adjCellSize)
                    pygame.display.update(self.__adjRect())
                elif section == 'hist':
                    if event.button == 3:
                        # Method name spelled as the net object provides it.
                        self.net.applyStimiulus()
                    elif event.button == 2:
                        paused = not paused
                    else:
                        self.__toggleHistDisp()
                        if paused:
                            self.update()

        return running, paused

    def posToNode(self, pos):
        """Map a pixel position to the screen region it falls in.

        Regions are half-open ``[start, end)`` so every pixel belongs to at
        most one region.

        :param pos: ``(x, y)`` pixel position.
        :returns: ``('adj', (i, j))`` with integer cell indices clamped to
            ``[0, nNodes-1]``, ``('hist', 0)`` for the history view, or
            ``('none', 0)`` anywhere else.
        """
        x, y = pos
        if (self.adjX[0] <= x < self.adjX[1]
                and self.adjY[0] <= y < self.adjY[1]):
            last = self.net.nNodes - 1
            # Floor to integer indices; clamp guards against float rounding
            # at the far edge.
            i = min(int((x - self.adjX[0]) // self.adjCellSize), last)
            j = min(int((y - self.adjY[0]) // self.adjCellSize), last)
            return ('adj', (i, j))
        if (self.histX[0] <= x < self.histX[1]
                and self.histY[0] <= y < self.histY[1]):
            return ('hist', 0)
        return ('none', 0)

#class NetControl():

#    def __init__(self):

#    def toggleSpikeView():
    
#    def toggleNodeView():
    
#    def toggleCycleView():
