import sys, pygame
import numpy as N
import math as M
import random as R
from Scientific.Geometry import Vector
from visual import *
from time import clock
from visual.graph import *


# Ask the user for the four simulation settings. Every prompt repeats until
# the reply is a whole number (digits only, surrounding whitespace ignored)
# that passes that prompt's own acceptance test.


def _ask_int(prompt_lines, accept):
    """Print *prompt_lines*, read one line, and repeat until it is accepted.

    A reply is accepted when, after stripping surrounding whitespace, it
    consists only of digits and ``accept(int(reply))`` is true. The accepted
    value is returned as an ``int``.
    """
    while True:
        for line in prompt_lines:
            print(line)
        reply = raw_input().strip()
        if reply.isdigit():
            value = int(reply)
            if accept(value):
                return value


# Window size: menu choice 1, 2 or 3.
sizeSelection = _ask_int(
    ['Please select the window size.',
     '1. 600 x 600 window',
     '2. 900 x 400 window',
     '3. 1200 x 300 window'],
    lambda choice: choice in (1, 2, 3))

# Number of balls. At least one is required: with zero balls there is nothing
# to simulate, and the histogram's y-range (proportional to the ball count)
# would collapse to zero.
num = _ask_int(['Please enter the number of balls. For example 15'],
               lambda count: count > 0)

# Initial velocity range. If the user enters 5, each ball's x and y velocity
# components are drawn at random from [-5, 5].
maxSpeed = _ask_int(['Please enter the initial velocity range. For example 10'],
                    lambda speed: speed > 0)

# Whether to show the speed histogram: 2 shows it, 1 does not.
histogramOption = _ask_int(['To see the histogram, enter 2. Otherwise, enter 1'],
                           lambda option: option in (1, 2))


pygame.init()

# (width, height) in pixels for each window-size menu choice; any other
# choice falls back to the 1200 x 300 window.
_WINDOW_SIZES = {1: (600, 600), 2: (900, 400)}
width, height = _WINDOW_SIZES.get(sizeSelection, (1200, 300))

size = width, height
white = 255, 255, 255
black = 0, 0, 0
screen = pygame.display.set_mode(size)


# Parameters for the optional speed-histogram window.
win = 500  # histogram window width in pixels (its height is half that)
numOfBalls = num  # ball count; scales the histogram's y-axis
# Constants kept from a 3-D gas template; the simulation below never reads
# their values.
L = 1.  # container is a cube L on a side
k = 1.4E-23  # Boltzmann constant
T = 30.  # around room temperature

observation = 0  # placeholder; replaced by the histogram object if requested
if histogramOption == 2:
    deltav = 1.  # width of each speed bin
    # Collisions redistribute energy between balls, so speeds soon exceed the
    # initial maximum sqrt(2)*maxSpeed (both components at +/-maxSpeed).
    # If the gas relaxes toward the 2-D Maxwell distribution with the same
    # mean kinetic energy, only about e**-12 of the balls are faster than
    # twice that, so bin up to 2*sqrt(2)*maxSpeed. Adding deltav to the stop
    # value makes arange include the upper edge instead of excluding it.
    vmax = 2. * (2. * maxSpeed ** 2) ** 0.5
    vdist = gdisplay(x=0, y=win, ymax=numOfBalls * deltav / 2.,
                     width=win, height=win // 2, xtitle='v', ytitle='dN')
    observation = ghistogram(bins=arange(0., vmax + deltav, deltav),
                             accumulate=1, average=1, color=color.red)



# Per-ball state, indexed by ball number.
balls = []       # pygame image (Surface) drawn for each ball
ballrects = []   # pygame Rect where each ball's image is drawn
positions = []   # [x, y] centre of each ball, in pixels
speeds = []      # [vx, vy] velocity of each ball, in pixels per iteration
radius = []      # collision radius of each ball, in pixels
mass = []        # mass of each ball

# Seven differently coloured images (ball1.bmp ... ball7.bmp) are assigned to
# the balls in rotation. Each file is loaded once and its Surface is shared.
_ball_images = {}

for i in range(num):
    image_number = i % 7 + 1
    filename = "ball" + str(image_number) + ".bmp"
    if filename not in _ball_images:
        _ball_images[filename] = pygame.image.load(filename)
    ball = _ball_images[filename]
    balls.append(ball)

    # Pick a random top-left corner that keeps the whole image (32 x 32 for
    # the bundled images) inside the window, and place the rect there. The
    # rect must be moved to the corner, not to the centre, or the ball can
    # start partly outside the window.
    ballrect = ball.get_rect()
    xPosition = (width - ballrect.width) * R.random()
    yPosition = (height - ballrect.height) * R.random()
    ballrect.topleft = (int(xPosition), int(yPosition))
    ballrects.append(ballrect)

    # The ball's centre lies half an image away from its top-left corner.
    positions.append([xPosition + ballrect.width / 2.0,
                      yPosition + ballrect.height / 2.0])

    # Each velocity component is drawn uniformly from [-maxSpeed, maxSpeed].
    speeds.append([R.uniform(-maxSpeed, maxSpeed),
                   R.uniform(-maxSpeed, maxSpeed)])

    # All balls have identical mass 1 and radius 16 (half the image width).
    mass.append(1.0)
    radius.append(16)


# Running totals of ball-to-wall and ball-to-ball collisions.
ball_collisions = 0
wall_collisions = 0

# Velocities are in pixels per iteration, not physical units. If one
# iteration stands for one second, 24*60*60 iterations simulate one day. The
# loop stops after that many iterations, or earlier if the window is closed.
run_count = 24 * 60 * 60
while run_count > 0:
    run_count -= 1
    if any(event.type == pygame.QUIT for event in pygame.event.get()):
        # Window closed: stop early, but fall through so the collision
        # totals are still reported.
        break

    # 1. Move every ball and bounce it off the walls. Positions are kept as
    #    floats so that fractional speeds accumulate instead of being
    #    truncated by the integer Rect; the rect is only used for drawing.
    for i in range(num):
        pos = positions[i]
        vel = speeds[i]
        half_w = ballrects[i].width / 2.0
        half_h = ballrects[i].height / 2.0
        pos[0] += vel[0]
        pos[1] += vel[1]

        # Left/right walls: reverse vx and mirror the overshoot back inside.
        if pos[0] - half_w < 0:
            vel[0] = -vel[0]
            pos[0] = 2 * half_w - pos[0]
            wall_collisions += 1
        elif pos[0] + half_w > width:
            vel[0] = -vel[0]
            pos[0] = 2 * (width - half_w) - pos[0]
            wall_collisions += 1

        # Top/bottom walls: same treatment for vy.
        if pos[1] - half_h < 0:
            vel[1] = -vel[1]
            pos[1] = 2 * half_h - pos[1]
            wall_collisions += 1
        elif pos[1] + half_h > height:
            vel[1] = -vel[1]
            pos[1] = 2 * (height - half_h) - pos[1]
            wall_collisions += 1

    # 2. Resolve ball-to-ball collisions. Every ball has already moved, so all
    #    pairs are compared at the same instant, and each unordered pair is
    #    visited once. Visiting (i, j) and then (j, i) would apply a second
    #    impulse that largely undoes the first.
    for i in range(num):
        for j in range(i + 1, num):
            x21 = positions[j][0] - positions[i][0]
            y21 = positions[j][1] - positions[i][1]
            r12 = radius[i] + radius[j]  # sum of the two radii
            d = M.hypot(x21, y21)        # distance between the centres
            if d >= r12:
                continue  # not touching
            vx21 = speeds[j][0] - speeds[i][0]
            vy21 = speeds[j][1] - speeds[i][1]
            if vx21 * x21 + vy21 * y21 >= 0:
                # Moving apart (or relatively at rest): this contact was
                # already resolved; bouncing again would make the balls stick.
                continue
            ball_collisions += 1

            # Rewind along the relative motion to the moment the balls first
            # touched. gammav is the direction ball i moves relative to ball j,
            # gammaxy is the direction from i to j, and alpha is the angle
            # between gammav and the line of centres at first contact.
            gammav = M.atan2(-vy21, -vx21)
            gammaxy = M.atan2(y21, x21)
            dr = d * M.sin(gammaxy - gammav) / r12  # normalised impact parameter
            alpha = M.asin(max(-1.0, min(1.0, dr)))

            # Elastic collision: exchange momentum along the unit normal n,
            # which points from ball i to ball j at contact. This is the same
            # update as tan-based slope a = tan(gammav + alpha), written with
            # cos/sin so a vertical normal is not singular.
            nx = M.cos(gammav + alpha)
            ny = M.sin(gammav + alpha)
            m21 = mass[j] / mass[i]  # ratio of the two masses
            dv = -2.0 * (vx21 * nx + vy21 * ny) / (1.0 + m21)
            speeds[j][0] += dv * nx
            speeds[j][1] += dv * ny
            speeds[i][0] -= m21 * dv * nx
            speeds[i][1] -= m21 * dv * ny

    # 3. Draw the frame, centring each image on its ball's position.
    screen.fill(white)
    for i in range(num):
        ballrects[i].center = (int(round(positions[i][0])),
                               int(round(positions[i][1])))
        screen.blit(balls[i], ballrects[i])

    # Feed every ball's current speed into the histogram.
    if histogramOption == 2:
        observation.plot(data=[M.hypot(vx, vy) for vx, vy in speeds])

    pygame.display.flip()


# Report how many ball-to-wall, ball-to-ball and total collisions occurred.
print('Wall collisions = %d' % wall_collisions)
print('Ball collisions = %d' % ball_collisions)
print('Total collisions = %d' % (wall_collisions + ball_collisions))

	
	







