import sys, os.path
from Environment import *
from pickle import *
from Tkinter import *
from tkFileDialog import *
from Recorder import *
from ExplicitFitness import *

# Canvas size in pixels; trans()/transinv() scale against width and height.
width, height = 640, 640
size = (width, height)
draw = True           # when False, run() skips redrawing the canvas
speed = [2, 2, 2]

# Colours: RGB triples plus the Tk colour string used when drawing.
black = (0, 0, 0)
white = (255, 255, 255)
whiteRGB = "#ffffff"

# Scale factors used by trans()/transinv() to map model units to pixels.
xscale, yscale = 2.0, 2.0
pause = True          # run() only advances the simulation when this is False
updaters = []         # callbacks run by updateWidgets() to resync the widgets
recorder = None       # Recorder instance while recording, otherwise None
canvas = None         # assigned by initSettingsWindow()
dt = 0.05             # time step handed to env.timeStep() by run()

# Shared simulation state. load1() rebinds this name when a saved state is
# loaded from disk.
env = Environment()
# When True, fitnessEval() returns explicitFitness; otherwise implicitFitness.
explicit = False

# Both evaluators are constructed once, against the initial env.
implicitFitness = ImplicitFitness(env)
explicitFitness = ExplicitFitness(env)

def fitnessEval():
    """Return the fitness evaluator that is currently active.

    Yields ``explicitFitness`` while the module-level ``explicit`` flag is
    truthy and ``implicitFitness`` otherwise.
    """
    return explicitFitness if explicit else implicitFitness

def drawAxis():
    """Draw the two axis segments, each running from -1 to 1, in white."""
    segments = (
        ([0, -1], [0, 1]),    # vertical segment
        ([-1, 0], [1, 0]),    # horizontal segment
    )
    for tail, head in segments:
        drawLines([trans(tail), trans(head)], whiteRGB)

def trans(point):
    """Translate a point from model coordinates into pixel coordinates.

    The model origin lands on the canvas centre, and the y component is
    negated before scaling. Returns ``[x, y]`` as ints.
    """
    x_pixel = int(point[0]/xscale * width/2.0 + width/2.0)
    y_pixel = int(-point[1]/yscale * height/2.0 + height/2.0)
    return [x_pixel, y_pixel]

def transinv(point):
    """Translate a pixel position back into model coordinates.

    This is the inverse mapping of trans(), without the integer
    truncation. Returns ``[x, y]`` as floats.
    """
    x_model = (float(point[0]) * 2.0/width - 1) * xscale
    y_model = (float(-point[1]) * 2.0/height + 1) * yscale
    return [x_model, y_model]

def currpath():
    """Return the directory part of the script path found in sys.argv[0]."""
    return os.path.dirname(sys.argv[0])

def within(minx, maxx, x):
    """Clamp ``x`` into the closed interval [minx, maxx]."""
    capped = x if x < maxx else maxx
    return capped if capped > minx else minx

def colorFtoI(colors):
    """Convert float colour channels to integer channels.

    Each channel is multiplied by 255, truncated with int() and clamped to
    the range 0..255. Returns a new list.
    """
    return [within(0, 255, int(channel * 255)) for channel in colors]

def colorRGB(colors):
    """Format three float colour channels as a Tk ``#rrggbb`` string."""
    channels = tuple(colorFtoI(colors))
    return "#%02x%02x%02x" % channels

def drawArrow(start, end, color = whiteRGB):
    """Draw a line from ``start`` to ``end`` with an arrowhead at ``end``.

    Both points are indexable pixel coordinates; the line is created on the
    module-level canvas.
    """
    x0, y0 = start[0], start[1]
    x1, y1 = end[0], end[1]
    canvas.create_line(x0, y0, x1, y1, fill = color, arrow = 'last')

def togglePause():
    """Flip the module-level ``pause`` flag."""
    global pause
    pause = False if pause else True

def togglePopulation():
    """Invert the environment's ``fixedPopulation`` flag."""
    currentlyFixed = env.fixedPopulation
    env.fixedPopulation = not currentlyFixed

def setMaxPopulation(maxPop):
    """Set the environment's maximum population.

    ``maxPop`` is converted with int(). If the new maximum falls below the
    current minimum, the minimum is lowered to match.
    """
    limit = int(maxPop)
    env.maxPopulation = limit
    if env.minPopulation > limit:
        env.minPopulation = limit

def setMinPopulation(minPop):
    """Set the environment's minimum population.

    ``minPop`` is converted with int(). If the new minimum exceeds the
    current maximum, the maximum is raised to match.
    """
    limit = int(minPop)
    env.minPopulation = limit
    if env.maxPopulation < limit:
        env.maxPopulation = limit

def if_(p, a, b):
    """Return ``a`` when ``p`` is truthy, otherwise ``b``.

    Being an ordinary function, both ``a`` and ``b`` are evaluated by the
    caller before the choice is made.
    """
    return a if p else b

def tryint(x):
    """Convert ``x`` to an int, falling back to 0 when it cannot be converted.

    Only the conversion failures int() itself reports are treated as
    "not a number": TypeError (e.g. None), ValueError (e.g. "abc" or "1.5")
    and OverflowError (e.g. infinity). Anything else, in particular
    KeyboardInterrupt and SystemExit, propagates to the caller instead of
    being silently turned into 0.
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        return 0

def save():
    """Ask for a destination file and pickle the current ``env`` into it.

    The simulation is paused while the dialog is open. The previous pause
    state is always restored afterwards, including when the dialog is
    cancelled or pickling raises. Cancelling the dialog writes nothing.
    """
    global pause
    wasPaused = pause
    pause = True
    try:
        target = asksaveasfile(mode='w', initialdir=".", title="Save state into file")
        if target is None:
            # The dialog was dismissed without choosing a file.
            return
        try:
            dump(env, target)
        finally:
            # Close explicitly so the pickle is flushed to disk.
            target.close()
    finally:
        pause = wasPaused

def load1():
    """Ask for a saved-state file and replace ``env`` with its contents.

    The simulation is paused while the dialog is open. The previous pause
    state is always restored afterwards, including when the dialog is
    cancelled or unpickling raises. Cancelling leaves ``env`` untouched.
    """
    global pause, env
    wasPaused = pause
    pause = True
    try:
        source = askopenfile(mode='r', initialdir=".", title="Load state from file")
        if source is None:
            # The dialog was dismissed without choosing a file.
            return
        try:
            # SECURITY: unpickling can execute arbitrary code; only load
            # state files that come from a trusted source.
            # NOTE(review): implicitFitness/explicitFitness were constructed
            # with the previous env; confirm whether they must be rebuilt.
            env = load(source)
        finally:
            source.close()
    finally:
        pause = wasPaused

def setupRecorder():
    """Ask for an output directory and start recording into it.

    The simulation is paused while the dialog is open and the previous
    pause state is always restored afterwards. If the dialog is cancelled
    (no directory returned) the current ``recorder`` is left unchanged.
    """
    global pause, recorder
    wasPaused = pause
    pause = True
    try:
        dirname = askdirectory(initialdir=".", title="Select directory for results")
        # askdirectory yields an empty value when the dialog is dismissed.
        if dirname:
            recorder = Recorder(dirname)
    finally:
        pause = wasPaused

def stopRecorder():
    """Stop recording by clearing the module-level ``recorder``."""
    global recorder
    recorder = None

# `var` selects the Tk variable class for the scale: IntVar (default) or DoubleVar.
def scaleWidget(root, title, range, getter, setter, var = IntVar, res = 1):
    """Add a labelled horizontal slider to ``root`` and return its frame.

    Parameters:
        root   -- parent widget that receives the label row and slider row.
        title  -- text of the label shown above the slider.
        range  -- indexable pair; ``range[0]`` and ``range[1]`` are the
                  slider's lower and upper bounds. (The name shadows the
                  builtin but is kept so existing callers keep working.)
        getter -- zero-argument callable returning the model value.
        setter -- one-argument callable invoked with the converted slider
                  value whenever the slider reports a change.
        var    -- Tk variable class: IntVar (values converted with tryint)
                  or any other class such as DoubleVar (converted with float).
        res    -- slider resolution.

    An updater that copies ``getter()`` into the slider is run once and
    appended to the module-level ``updaters`` list.
    """
    labelRow = Frame(root)
    caption = Label(labelRow, text = title, justify=LEFT, anchor=W)
    caption.pack(fill="x")
    labelRow.pack(fill="x")

    # Choose the conversion once instead of evaluating both on every event.
    if var == IntVar:
        convert = tryint
    else:
        convert = float

    sliderRow = Frame(root)
    # The ``variable`` option expects a Variable *instance*; passing the class
    # itself only worked because Tkinter registers any callable as a command.
    value = var(root)
    w = Scale(sliderRow, from_ = range[0], to = range[1], orient = HORIZONTAL,
              command = lambda x: setter(convert(x)),
              variable = value, resolution = res)
    # Tie the variable's lifetime to the widget so it is not collected (and
    # its Tcl variable unset) while the slider is still in use.
    w.valueVariable = value

    def updater():
        w.set(getter())

    updater()
    updaters.append(updater)
    w.pack(fill="x")
    sliderRow.pack(fill="x")
    return sliderRow

def canvasWidget():
    """Create the "Dots" window and return its drawing canvas.

    The canvas has a black background and is sized from the module-level
    ``width``/``height``, the same values trans() and transinv() use, so the
    coordinate mapping and the widget cannot drift apart. Left clicks in the
    window are routed to buttonClicked().
    """
    window = Tk()
    window.title("Dots")
    window.bind("<Button-1>", buttonClicked)
    board = Canvas(window, width=width, height=height, bg="black")
    board.pack(expand=YES, fill=BOTH)
    # Placeholder item; drawState() clears every canvas item when it runs.
    board.create_text(50, 10, text = "test")
    return board

def drawLines(points, color):
    """Draw a line through ``points`` on the module-level canvas.

    ``points`` is a sequence of pixel coordinates (for example the
    ``[x, y]`` pairs produced by trans()); each element is passed to
    ``canvas.create_line`` as a positional argument. ``color`` is the Tk
    fill colour.
    """
    # Argument unpacking replaces the deprecated apply() builtin.
    canvas.create_line(*points, fill=color)

def drawCircle(point, radius, **options):
    """Draw a circle centred on ``point`` on the module-level canvas.

    ``point`` is an indexable pixel coordinate and ``radius`` is in pixels.
    Any keyword ``options`` (fill, outline, width, ...) are forwarded to
    ``canvas.create_oval`` unchanged.
    """
    x, y = point[0], point[1]
    # Argument unpacking replaces the deprecated apply() builtin.
    canvas.create_oval(x - radius, y - radius, x + radius, y + radius, **options)

def drawState():
    """Clear the canvas and redraw every particle in ``env``.

    For each particle this draws:
      * a filled circle at its position ``rv``;
      * an arrow for ``geneArrow`` when one is set;
      * its tail as a polyline once the tail holds more than one point;
      * for selected particles with such a tail, a white outline circle
        and the arrows returned by ``gridBehavior()``.

    Errors raised while drawing propagate unchanged. The previous
    ``except e:`` clause referenced an undefined name and replaced any real
    error with a NameError.
    """
    canvas.delete('all')
    for p in env.particles:
        center = trans(p.rv)
        # Convert the colour once per particle; it is used for the body,
        # the tail and the behaviour arrows.
        color = colorRGB(p.color())

        drawCircle(center, int(p.radius * width/4.0), fill = color, width = 0)
        # Identity test: geneArrow is added to rv, so it may be an array,
        # for which "!= None" is not a reliable scalar comparison.
        if p.geneArrow is not None:
            drawArrow(center, trans(p.rv + p.geneArrow))

        if len(p.tail) > 1:
            drawLines([trans(pt) for pt in p.tail], color)
            # NOTE(review): the selection highlight is only drawn for
            # particles that already have a tail; kept as in the original,
            # confirm that this nesting is intended.
            if p.selected:
                drawCircle(center, int(p.radius * width/2.0), outline = whiteRGB, width = 1)
                for arrow in p.gridBehavior():
                    drawArrow(trans(arrow[0]), trans(arrow[1]), color)


def toggleButton(root, title, getter, setter):
    """Add a check button to ``root`` that mirrors a boolean model value.

    Clicking the button calls ``setter`` with the negation of ``getter()``.
    An updater that selects or deselects the button according to
    ``getter()`` is appended to the module-level ``updaters`` list and run
    once immediately.
    """
    def flip():
        setter(not getter())

    box = Checkbutton(root, text = title, anchor=W, justify=LEFT, command = flip)

    def updater():
        sync = box.select if getter() else box.deselect
        sync()

    updaters.append(updater)
    updater()

    box.pack(anchor=W, fill='x')

def buttonClicked(event):
    """Handle a mouse click by toggling selection of the closest particle.

    The click position is converted to model coordinates and printed. If
    the environment returns a particle for that position, its ``selected``
    flag is flipped and its equation is printed.
    """
    pos = array(transinv([event.x, event.y]))
    print("pos = %s" % str(pos))
    p = env.closestParticle(pos)
    if not p:
        return
    p.selected = not p.selected
    print(p.displayEquation())

def setRadius(value):
    """Apply ``value`` as the radius of every particle and as the default.

    Existing particles in ``env`` are updated first, then
    ``Particle.defaultRadius`` is set.
    """
    for particle in env.particles:
        particle.radius = value
    Particle.defaultRadius = value

def getRadius():
    """Return the current ``Particle.defaultRadius``."""
    return Particle.defaultRadius

def getTimeStep():
    """Return the module-level simulation time step ``dt``."""
    return dt

def setTimeStep(value):
    """Store ``value`` as the module-level simulation time step ``dt``."""
    global dt
    dt = value

def initSettingsWindow():
    """Build the "Settings" window and the canvas window; return the root.

    Side effects: assigns the module-level ``canvas`` (via canvasWidget())
    and, through toggleButton()/scaleWidget(), appends one updater per
    control to the module-level ``updaters`` list.

    Controls bound to the environment look ``env`` up each time they are
    used rather than capturing the object that exists while the window is
    built. load1() rebinds ``env``, and controls bound to the old object
    would otherwise keep editing a stale environment after a load.
    """
    global canvas
    root = Tk()
    root.title("Settings")
    canvas = canvasWidget()

    # A lambda holds a single expression only, so callbacks that need
    # statements (such as assigning a global) are nested functions.
    def envGetter(name):
        # Zero-argument accessor for an attribute of the current env.
        return lambda: getattr(env, name)

    def envSetter(name):
        # One-argument mutator for an attribute of the current env.
        return lambda value: setattr(env, name, value)

    def getExplicit():
        return explicit

    def setExplicit(x):
        global explicit
        explicit = x

    def toggleDraw():
        global draw
        draw = not draw

    # NOTE: a "Fixed Population" toggle for env.fixedPopulation used to be
    # created at this point; it is currently disabled.
    row1 = Frame(root)
    toggleButton(row1, "Explicit Fitness",
                 getExplicit,
                 setExplicit)
    row1.pack(side="top", fill='x')

    row2 = Frame(root)
    toggleButton(row2, "Collisions",
                 envGetter('haveCollisions'),
                 envSetter('haveCollisions'))
    toggleButton(row2, "Show Gene Vectors",
                 envGetter('showGeneArrows'),
                 envSetter('showGeneArrows'))
    row2.pack(fill='x')

    row = Frame(root)
    toggleButton(row, "Enable Drag ",
                 curry(getattr, Particle, 'haveDrag'),
                 curry(setattr, Particle, 'haveDrag'))
    row.pack(fill='x')

    # The population setters keep the min and max values consistent.
    scaleWidget(root, "Min Population", a(0, 100),
                envGetter('minPopulation'),
                setMinPopulation)

    scaleWidget(root, "Max Population", a(0, 100),
                envGetter('maxPopulation'),
                setMaxPopulation)

    scaleWidget(root, "Gene Count", a(0, 100),
                curry(getattr, explicitFitness, 'geneCount'),
                curry(setattr, explicitFitness, 'geneCount'))

    scaleWidget(root, "Run Steps", a(0, 1000),
                curry(getattr, explicitFitness, 'runSteps'),
                curry(setattr, explicitFitness, 'runSteps'))

    scaleWidget(root, "Keep", a(0, 1),
                curry(getattr, explicitFitness, 'keep'),
                curry(setattr, explicitFitness, 'keep'), DoubleVar, 0.01)

    scaleWidget(root, "Coefficient of Restitution", a(0, 1.0),
                envGetter('e'),
                envSetter('e'), DoubleVar, 0.1)

    scaleWidget(root, "Mean Reproduction Time", a(0.1, 100.0),
                envGetter('reproMeanTime'),
                envSetter('reproMeanTime'), DoubleVar, 0.1)

    scaleWidget(root, "Mutation Rate", a(0, 0.5),
                envGetter('mutationRate'),
                envSetter('mutationRate'), DoubleVar, 0.01)

    scaleWidget(root, "Tail Size", a(0, 100),
                curry(getattr, Particle, 'tailCount'),
                curry(setattr, Particle, 'tailCount'))

    scaleWidget(root, "Radius", a(0.0, 2.0),
                getRadius,
                setRadius, DoubleVar, 0.01)

    scaleWidget(root, "Time Step", a(0.0, 1.0),
                getTimeStep,
                setTimeStep, DoubleVar, 0.01)

    # Simulation controls. The module-level togglePause() is used directly
    # instead of a nested duplicate with the same body.
    controlRow = Frame(root)
    w = Button(controlRow, text="Pause", command=togglePause)
    w.pack(side="left")
    # Resolve env when clicked so Reset acts on the current environment.
    w = Button(controlRow, text="Reset", command=lambda: env.reset())
    w.pack(side="right")
    controlRow.pack()

    w = Button(controlRow, text="Draw", command=toggleDraw)
    w.pack()

    # Persistence controls.
    fileRow = Frame(root)
    w = Button(fileRow, text="Save", command=save)
    w.pack(side="left")

    w = Button(fileRow, text="Load", command=load1)
    w.pack()
    fileRow.pack()

    # Recorder controls.
    row4 = Frame(root)

    w = Button(row4, text="Setup Recorder", command=setupRecorder)
    w.pack()

    w = Button(row4, text="Stop Recorder", command=stopRecorder)
    w.pack()

    row4.pack()

    return root

def updateWidgets():
    """Run every registered updater, in registration order."""
    for refresh in updaters:
        refresh()

def run():
    """Build the UI and drive the simulation until a window is closed.

    Each pass of the loop pumps both Tk windows, refreshes the widgets,
    redraws the scene when ``draw`` is set and, unless ``pause`` is set,
    advances the environment by ``dt``. While a recorder is active the
    current fitness evaluator writes its data before its run loop step.

    The loop ends cleanly once either window has been closed. Previously
    the only way out was an unhandled TclError from the next Tk call.
    """
    # Explicit local import so the name does not depend on the star import.
    from Tkinter import TclError

    settings = initSettingsWindow()
    while True:
        # Pump Tk events, then check that both windows still exist. A
        # destroyed application raises TclError on any further Tk command.
        try:
            settings.update()
            canvas.update()
            alive = settings.winfo_exists() and canvas.winfo_exists()
        except TclError:
            alive = False
        if not alive:
            break

        # Update widget values
        updateWidgets()
        if draw:
            drawState()
            drawAxis()

        # Next time step
        if not pause:
            env.timeStep(dt)
            # Record if applicable
            if recorder:
                fitnessEval().writeData(env, recorder)
            fitnessEval().runLoop(env)
            
# Start the UI and simulation loop only when executed as a script, so that
# importing this module does not open any windows.
if __name__ == '__main__':
    run()
