from Particle import *
from Shane import *
from random import *

def averageGeneVector(env):
    """Return the mean gene vector of the population in ``env``.

    Collects ``particle.gene.v`` for every particle in ``env.particles``,
    sums them with ``sumv`` and divides by the particle count.
    """
    geneVecs = [particle.gene.v for particle in env.particles]
    count = float(len(geneVecs))
    return sumv(geneVecs) / count

def cosineAngle(v1, v2):
    """Return the angle (radians, in [0, pi]) between vectors v1 and v2.

    Despite the name, this returns the angle itself, i.e. the ``arccos``
    of the cosine.  Returns 0.0 when either vector has zero length, since
    the angle is undefined there.
    """
    denom = mag(v1) * mag(v2)
    if denom == 0:
        # Undefined angle for a zero-length vector: keep the historical
        # 0.0 fallback instead of letting inf/nan leak out.
        return 0.0
    cosTheta = dot(v1, v2) / denom
    # Floating-point rounding can push the ratio slightly outside [-1, 1]
    # (e.g. for parallel vectors), where arccos returns nan; clamp it.
    cosTheta = max(-1.0, min(1.0, cosTheta))
    return arccos(cosTheta)

class Environment:
    """A 2-D box of particles that move, collide, reproduce and get culled.

    Particles whose position leaves the box given by ``xRange``/``yRange``
    are removed (or, with ``fixedPopulation``, replaced by a mutated copy
    of a random particle).  Reproduction is stochastic; see ``checkTimes``.
    """

    def __init__(self):
        self.steps = 0
        self.xRange = a(-2, 2)
        self.yRange = a(-2, 2)
        self.mutationRate = 0.1
        self.e = 0.5  # coefficient of restitution used by collisions()
        self.fixedPopulation = False  # True: replace escapees, no births
        self.maxPopulation = 25  # initial population size and birth cap
        self.minPopulation = 3   # topped up with babies when below this
        #self.expRate = 0.001
        #self.expMean = 1000
        self.reproMeanTime = 3.0  # converted by meanTimeToLambda()
        self.reproduce = True
        self.haveCollisions = True
        self.removedParticle = None  # optional callback(particle) on removal
        self.showGeneArrows = False
        self.recomputeGene = 20  # steps between gene-arrow recomputations
        # self.recordOnStep = None  # None to not record, 1 to record
        #                           # every step, 2 to record every
        #                           # second step
        self.reproTimes = []  # reproTime of each parent when it gave birth
        self.reset()

    def reset(self):
        """Replace the population with maxPopulation fresh particles."""
        self.particles = [Particle() for _ in range(self.maxPopulation)]
        self.steps = 0

    def setGeneVectors(self):
        """Set each particle's ``geneArrow`` relative to the mean gene vector.

        The arrow points at 4x the angle between the particle's gene vector
        and the population mean; its length is 0.2 scaled by the ratio of
        the particle's gene magnitude to the mean's magnitude.
        """
        if not self.particles:
            return  # no mean exists for an empty population
        avg = averageGeneVector(self)
        avgMag = mag(avg)
        normalSize = 0.2
        for p in self.particles:
            angle = cosineAngle(avg, p.gene.v) * 4
            if avgMag == 0:
                # Gene vectors cancel out: there is no reference length,
                # so draw the unscaled arrow instead of dividing by zero.
                scale = normalSize
            else:
                scale = normalSize * mag(p.gene.v) / avgMag
            p.geneArrow = scale * a(cos(angle), sin(angle))

    def timeStep(self, dt):
        """Advance the simulation by dt: move, collide, cull, reproduce."""
        if self.showGeneArrows and self.steps % self.recomputeGene == 0:
            self.setGeneVectors()
        self.steps += 1
        for p in self.particles:
            p.timeStep(dt)
        if self.haveCollisions:
            self.collisions()
        self.checkBounds()
        # Disabled recording hook:
        # if self.recordOnStep and self.steps % self.recordOnStep == 0:
        #     self.recordData()
        if not self.fixedPopulation and self.reproduce:
            self.checkTimes()

    def collisions(self):
        """Push apart every pair of overlapping particles.

        For an overlapping pair, each velocity changes by
        |v_ab| * (e + 1) / 2 along n = normalize(r_ab), away from the other
        particle.  NOTE: the usual "only if approaching" test is disabled
        (the original condition was ``dot(v_ab, r_ab) < 0 or True``), so
        overlapping pairs are pushed apart even when already separating.
        """
        particles = self.particles
        count = len(particles)
        for i in range(count):
            pa = particles[i]
            for j in range(i + 1, count):
                pb = particles[j]
                r_ab = pa.rv - pb.rv
                distSq = dot(r_ab, r_ab)
                # distSq > 0 skips coincident centres: there is no collision
                # normal, and normalizing a zero vector would presumably
                # put nan/inf into both velocities.
                if distSq < (pa.radius + pb.radius) ** 2 and distSq > 0:
                    n = normalize(r_ab)
                    m = mag(pa.vv - pb.vv)
                    J = - m * (self.e + 1) / 2
                    pa.vv -= J * n
                    pb.vv += J * n

    def checkTimes(self):
        """Let particles reproduce stochastically.

        A particle is triggered with probability
        cumulativeExpDist(rate, p.reproTime).  A triggered particle always
        has its reproTime reset to 0, but only produces a mutated child
        while the population, including this step's newborns, is below
        maxPopulation.  Newborns join after the loop, so they cannot
        reproduce in the step they are born.
        """
        rate = meanTimeToLambda(self.reproMeanTime)  # loop-invariant
        newborns = []
        for p in self.particles:
            if cumulativeExpDist(rate, p.reproTime) > random():
                # Count newborns too, otherwise one step could almost
                # double the population past maxPopulation.
                if len(self.particles) + len(newborns) < self.maxPopulation:
                    self.reproTimes.append(p.reproTime)
                    print("gave birth when t was %f mean time %f" % (p.reproTime, array(self.reproTimes).mean()))
                    newborns.append(p.mutate(self.mutationRate))
                p.reproTime = 0.0
        self.particles += newborns

    def checkBounds(self):
        """Handle particles outside xRange/yRange, then top up to minPopulation.

        With fixedPopulation an escaped particle is reported through
        handleRemovedParticle and replaced in place by babyParticle();
        otherwise it is reported and deleted.
        """
        xmin, xmax = self.xRange
        ymin, ymax = self.yRange
        escaped = []
        for i, p in enumerate(self.particles):
            outside = (p.rv[0] < xmin or p.rv[0] > xmax or
                       p.rv[1] < ymin or p.rv[1] > ymax)
            if not outside:
                continue
            if self.fixedPopulation:
                self.handleRemovedParticle(p)
                self.particles[i] = self.babyParticle()
            else:
                escaped.append(i)
        # Delete from the highest index down so earlier deletions do not
        # shift the indexes still waiting to be removed.
        for i in reversed(escaped):
            self.handleRemovedParticle(self.particles[i])
            del self.particles[i]
        shortfall = self.minPopulation - len(self.particles)
        if shortfall > 0:
            self.particles += [self.babyParticle() for _ in range(shortfall)]

    def handleRemovedParticle(self, particle):
        """Pass a departing particle to the removedParticle callback, if set."""
        if self.removedParticle:
            self.removedParticle(particle)

    def babyParticle(self):
        """Return a mutated copy of a random particle, or a fresh Particle if none."""
        if not self.particles:
            return Particle()
        j = int(random() * len(self.particles))
        return self.particles[j].mutate(self.mutationRate)

    def closestParticle(self, position, maxDist=5.0):
        """Return the particle nearest to position.

        Returns None if no particle is strictly closer than maxDist.
        """
        closest = None
        best = maxDist
        for p in self.particles:
            d = mag(p.rv - position)
            if d < best:
                closest = p
                best = d
        return closest

    def recordData(self):
        """Return a dict of population statistics for the current step.

        Reports max/mean/std of reproduction counts, min/mean/std of gene
        generation, age stats, particle count, gene-vector magnitude stats,
        speed stats and distance-from-origin stats.  Presumably raises if
        there are no particles (max/min of an empty array) -- TODO confirm.
        """
        particles = self.particles
        data = {}
        # find the mean of reproduction
        babies = array([p.babies for p in particles])
        data['reproduction-count-max'] = babies.max()
        data['reproduction-count-mean'] = babies.mean()
        data['reproduction-count-std'] = babies.std()
        genes = [p.gene for p in particles]
        generations = array([g.generation for g in genes])
        data['generation-number-min'] = generations.min()
        data['generation-number-mean'] = generations.mean()
        data['generation-number-std'] = generations.std()
        times = array([p.t for p in particles])
        data['age-max'] = times.max()
        data['age-mean'] = times.mean()
        data['age-std'] = times.std()
        data['particle-count'] = len(particles)
        avgGene = averageGeneVector(self)
        data['gene-mean-mag'] = mag(avgGene)
        geneMagnitudes = array([mag(g.v) for g in genes])
        data['gene-mag-mean'] = geneMagnitudes.mean()
        data['gene-mag-std'] = geneMagnitudes.std()

        speed = array([mag(p.vv) for p in particles])
        data['speed-mean'] = speed.mean()
        data['speed-std'] = speed.std()

        dist = array([mag(p.rv) for p in particles])
        data['dist-mean'] = dist.mean()
        data['dist-std'] = dist.std()
        return data

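# Cumulative distribution function of the exponential distribution:
# P(X <= x) = 1 - exp(-l * x) for rate l.
# I don't think I'm using this correctly; somehow I'm missing something.
# A likely culprit: checkTimes() uses CDF(reproTime) as the chance of
# giving birth on *each* step, so the per-step chance grows with the time
# since the last birth.  A memoryless exponential waiting time would
# instead use a constant per-step probability, so the observed mean
# reproduction time will not match 1/l.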
def cumulativeExpDist(l, x):
    """Return the exponential CDF, 1 - exp(-l * x), for rate l at x."""
    tail = exp(-l * x)  # survival probability P(X > x)
    return 1.0 - tail

def meanTimeToLambda(meantime):
    """Convert a desired mean reproduction time into a cumulativeExpDist rate.

    Returns 0.0034 / meantime, using the empirical slope from the fit noted
    at the end of this module.  The offset-corrected inverse of that fit,
    0.0034 / (meantime - 3.55), is not used.  It is negative, or divides by
    zero, for meantime <= 3.55 (e.g. Environment's default reproMeanTime of
    3.0), which is presumably why the simpler form was kept.
    """
    slope = 0.0034
    return slope / meantime

# With mean reproduction rate set at 1000, the mean time of
# reproduction is 7.3

# With mean reproduction rate set at 2000, the mean time of
# reproduction is 10.6

# With mean reproduction rate set at 500, the mean time of
# reproduction is 5.4

# Linear fit: meantime = 0.0034 * (mean reproduction rate) + 3.55, with
# R^2 = 0.9987 (pretty good).  So I have an empirical way of setting my
# mean reproduction time.  Not great, but I could do worse.
                
            
