from Environment import *
from Particle import *
from numpy import *
import random

class ImplicitFitness:
    """Fitness driver that never intervenes in the environment.

    ``runLoop`` hands back exactly the environment object it receives, and
    ``writeData`` forwards whatever the environment itself records.
    """

    def __init__(self, env):
        # Stored for reference; no method of this class reads it.
        self.env = env

    def runLoop(self, env):
        """Return ``env`` itself, untouched."""
        unchanged = env
        return unchanged

    def writeData(self, env, recorder):
        """Pass the result of ``env.recordData()`` to ``recorder.writeData``."""
        snapshot = env.recordData()
        recorder.writeData(snapshot)

class ExplicitFitness:
    """Generational genetic-algorithm driver that rates genes one at a time.

    For each gene in the population, ``setupEnvironment`` resets the shared
    environment and fills it with particles built from that gene.
    ``runLoop`` is presumably called once per simulation step (confirm
    against the caller). It ends the trial once ``env.steps`` reaches
    ``runSteps`` or the particle count drops to ``minPopulation``. The
    gene's fitness is then the mean of the particle ``t`` values collected
    during the trial.

    When every gene has been rated, ``death`` keeps only the fittest
    fraction ``keep``. ``seed`` then refills the population with mutated
    copies of the survivors.
    """

    def __init__(self, env):
        # ``env`` serves as a configuration source: mutationRate,
        # maxPopulation and minPopulation are read from it.
        self.env = env
        self.keep = 0.10  # fraction of genes that survive each generation
        self.geneCount = 20  # population size restored by seed()
        self.genes = []
        self.geneIndex = 0  # index of the gene currently being evaluated
        self.runSteps = 100  # maximum simulation steps per gene
        self.generation = 0
        self.times = []  # particle ``t`` values collected for the current gene
        self.data = {}  # statistics waiting to be returned by recordData()

    def seed(self):
        """Top the population up to ``geneCount`` genes; bump the generation.

        New genes are ``mutate(self.env.mutationRate)`` results of randomly
        chosen existing genes. If the population is empty, fresh
        ``RandomVectorGene`` instances are created instead.
        """
        print("seed")
        need = self.geneCount - len(self.genes)
        print("generation %d" % self.generation)
        self.generation += 1
        newbies = []
        for _ in range(need):
            if self.genes:
                # Parents are drawn only from genes present before this call.
                gene = random.choice(self.genes).mutate(self.env.mutationRate)
            else:
                gene = RandomVectorGene()
            newbies.append(gene)
        self.genes += newbies

    def currGene(self):
        """Return the gene currently being evaluated."""
        return self.genes[self.geneIndex]

    def nextGene(self):
        """Advance to the next gene and return it, or None past the end.

        Once every gene has been visited the index is left past the end and
        None is returned, so the caller can start a new generation.
        """
        self.geneIndex += 1
        if self.geneIndex >= len(self.genes):
            # We've reached the end of this population
            return None
        return self.currGene()

    def setupEnvironment(self, gene, env):
        """Reset ``env`` for a trial of ``gene`` (not of the whole population).

        Replaces the environment's particles with ``self.env.maxPopulation``
        fresh ``Particle(gene)`` instances and turns off ``fixedPopulation``
        and ``reproduce``. It also clears the collected times and installs a
        ``removedParticle`` hook that records the ``t`` of each particle it
        is called with.

        Returns:
            The same ``env`` object, now configured.
        """
        env.reset()
        env.particles = [Particle(gene) for _ in range(self.env.maxPopulation)]
        env.fixedPopulation = False
        env.reproduce = False

        self.times = []

        def removedParticle(particle):
            # Particles removed mid-trial still count toward this gene's
            # fitness (presumably env calls this hook on removal).
            self.times.append(particle.t)

        env.removedParticle = removedParticle
        return env

    def teardownEnvironment(self, env):
        """Add the ``t`` of every particle still in ``env`` to the times."""
        self.times.extend(p.t for p in env.particles)

    def rateFitness(self):
        """Return the current gene's fitness: the mean of the collected times.

        Returns 0.0 when no times were collected. The mean of an empty array
        would be NaN, which compares false against everything and would
        scramble the ranking in ``death``.
        """
        if not self.times:
            return 0.0
        return array(self.times).mean()

    def compareGenes(self, a, b):
        """Three-way compare two genes by fitness, returning -1, 0 or 1.

        Kept for callers that sort with a comparison function. It is written
        without the Python-2-only ``cmp`` builtin.
        """
        return (a.fitness > b.fitness) - (a.fitness < b.fitness)

    def death(self):
        """Rank genes by fitness, record statistics and keep the fittest.

        Genes are sorted best-first (bigger fitness is better). The mean and
        standard deviation of the fitnesses are stored in ``self.data`` under
        'fitness-mean' and 'fitness-std'.

        The top ``keep`` fraction survives, but at least one gene is always
        kept. Otherwise ``seed`` would restart a small population from
        random genes and lose all progress.
        """
        print("death")
        # [more-fit ... less-fit]. A key sort also works on Python 3, where
        # list.sort no longer accepts a comparison function.
        self.genes.sort(key=lambda g: g.fitness, reverse=True)
        fitnessA = array([g.fitness for g in self.genes])
        meanFitness = fitnessA.mean()
        print("fitnesses %s" % str(fitnessA))
        print("mean fitness %f" % meanFitness)
        self.data['fitness-mean'] = meanFitness
        self.data['fitness-std'] = fitnessA.std()
        keepCount = int(len(self.genes) * self.keep)
        if self.genes and keepCount < 1:
            # int() truncates (e.g. 5 * 0.10 -> 0); never wipe out everyone.
            keepCount = 1
        print("Keeping %d individuals" % keepCount)
        self.genes = self.genes[:keepCount]

    def writeData(self, env, recorder):
        """Hand the pending generation statistics to ``recorder.writeData``.

        ``env`` is accepted to match ``ImplicitFitness.writeData`` but is not
        used here.
        """
        recorder.writeData(self.recordData())

    def recordData(self):
        """Return the pending statistics and start a fresh, empty dict."""
        oldData = self.data
        self.data = {}
        return oldData

    def runLoop(self, env):
        """Advance the evaluation state machine and return the env to use.

        On the first call the population is seeded and the environment is set
        up for gene 0. A trial is finished once ``runSteps`` steps have
        elapsed or the particle count falls to ``self.env.minPopulation``.
        At that point the gene is rated and evaluation moves to the next
        gene. After the last gene, ``death`` and ``seed`` produce a new
        generation.
        """
        if not self.genes:
            # First time through the loop
            self.seed()
            env = self.setupEnvironment(self.currGene(), env)
        finished = (env.steps >= self.runSteps
                    or len(env.particles) <= self.env.minPopulation)
        if finished:
            self.teardownEnvironment(env)
            fitness = self.rateFitness()
            print("gene %d fitness %f" % (self.geneIndex, fitness))
            self.currGene().fitness = fitness
            if self.nextGene() is None:
                # We're done testing this population.
                self.death()
                self.seed()
                self.geneIndex = 0
            env = self.setupEnvironment(self.currGene(), env)
        return env
