#CAFucns.py: contains functions used in SIR_CA.py

#import modules
from numpy import *
from numpy.random import *
from pylab import *
#make sure Pygame installed
try:
	import pygame
	import pygame.surfarray as surfarray
except ImportError:
	# Call form is valid on both Python 2 and 3; the old ``raise E, msg``
	# statement is a SyntaxError on Python 3 and breaks the whole module.
	raise ImportError("PyGame required.")
	
#for initializing the grid
def Init(CellNum,Iinitial):
	"""Return a CellNum x CellNum float grid with Iinitial infected cells.

	Cells hold 0 (susceptible) everywhere except the infected cells, which
	hold 1.  A single initial infection is placed in the middle of the grid;
	otherwise the infected cells are drawn with numpy.random.randint at
	distinct random positions.

	CellNum  -- side length of the square grid
	Iinitial -- number of initially infected cells; a value larger than the
	            grid is capped at CellNum*CellNum (every cell infected)
	"""
	state = zeros((CellNum,CellNum),float)
	if Iinitial == 1:				#if only one initially infected cell, start in the middle
		i = int(CellNum/2)
		j = int(CellNum/2)
		state[i][j] = 1
	else:
		# Redraw positions that are already infected so exactly the requested
		# number of distinct cells is infected (duplicate draws would otherwise
		# silently infect fewer).  The cap guarantees the redraw loop ends.
		for n in range(min(Iinitial, CellNum*CellNum)):
			i = randint(0,CellNum)
			j = randint(0,CellNum)
			while state[i][j] == 1:
				i = randint(0,CellNum)
				j = randint(0,CellNum)
			state[i][j] = 1
	return state

	
#state update routine for SIR Model
def Update(state,p,q,CellNum):
	"""Advance the SIR grid by one time step; the grid is updated in place and returned.

	Cell codes: 0 = susceptible (S), 1 = infected (I), 2 = recovered (R).
	The top-left CellNum x CellNum cells are treated as a torus (periodic
	boundaries), and each cell looks at its 8 surrounding neighbours.

	- An S cell with n infected neighbours becomes I with probability
	  1 - (1 - p)**n (each infected neighbour transmits with probability p).
	- An I cell becomes R with probability q.
	- Any other cell is set to R (recovery is permanent).

	Neighbour counts are taken from the grid as it was at the start of the
	step, so all cells update synchronously and the outcome does not depend
	on the order in which cells are visited.
	"""
	# Snapshot of which cells are infected *before* any cell changes.
	infected = (asarray(state)[:CellNum, :CellNum] == 1).astype(int)
	# INeighbor[i][j] = number of infected cells among the 8 neighbours of
	# (i, j); roll wraps around the edges, giving periodic boundaries.
	INeighbor = zeros((CellNum, CellNum), int)
	for di in (-1, 0, 1):
		for dj in (-1, 0, 1):
			if di or dj:				#skip the cell itself
				INeighbor += roll(roll(infected, di, axis=0), dj, axis=1)
	for i in range(CellNum):
		for j in range(CellNum):
			if state[i][j] == 0:
				Pinfect = 1 - (1 - p)**INeighbor[i][j]
				# strict < : random() may return 0.0, and a zero probability must never fire
				if random() < Pinfect:
					state[i][j] = 1
			elif state[i][j] == 1:
				if random() < q:			#recover with probability q
					state[i][j] = 2
			else:
				state[i][j] = 2				#or else just stay at recovered
	return state

	
#state update routine for SIRS Model
def Update_Cold(state,p,q,CellNum,RCount,RSteps=10):
	"""Advance the SIRS grid by one time step; return (state, RCount).

	Cell codes: 0 = susceptible (S), 1 = infected (I), 2 = recovered (R).
	The top-left CellNum x CellNum cells are treated as a torus (periodic
	boundaries), and each cell looks at its 8 surrounding neighbours.

	- An S cell with n infected neighbours becomes I with probability
	  1 - (1 - p)**n (each infected neighbour transmits with probability p).
	- An I cell becomes R with probability q.
	- Every other cell increments its RCount entry; once that reaches RSteps
	  the cell becomes S again and its counter is reset to 0, otherwise it
	  stays R.

	Neighbour counts are taken from the grid as it was at the start of the
	step, so all cells update synchronously.  state and RCount are both
	modified in place and returned.

	RSteps -- number of counted steps an R cell stays recovered before it
	          becomes susceptible again (default 10)
	"""
	# Snapshot of which cells are infected *before* any cell changes.
	infected = (asarray(state)[:CellNum, :CellNum] == 1).astype(int)
	# INeighbor[i][j] = number of infected cells among the 8 neighbours of
	# (i, j); roll wraps around the edges, giving periodic boundaries.
	INeighbor = zeros((CellNum, CellNum), int)
	for di in (-1, 0, 1):
		for dj in (-1, 0, 1):
			if di or dj:				#skip the cell itself
				INeighbor += roll(roll(infected, di, axis=0), dj, axis=1)
	for i in range(CellNum):
		for j in range(CellNum):
			if state[i][j] == 0:
				Pinfect = 1 - (1 - p)**INeighbor[i][j]
				# strict < : random() may return 0.0, and a zero probability must never fire
				if random() < Pinfect:
					state[i][j] = 1
			elif state[i][j] == 1:
				if random() < q:			#recover with probability q
					state[i][j] = 2
			else:
				RCount[i][j] += 1
				if RCount[i][j] >= RSteps:
					state[i][j] = 0					#recovered for RSteps steps: back to susceptible
					RCount[i][j] = 0				#reset RCount
				else:
					state[i][j] = 2					#or else still recovered
	return state, RCount
				
				
				
	
#determine which colors to display on grid
def SpDisplayState(state,CellNum,surface,CellSize,SColor,IColor,RColor):
	"""Paint the top-left CellNum x CellNum cells of the grid onto surface.

	Cell (i, j) becomes a CellSize x CellSize square whose top-left corner is
	at (i*CellSize, j*CellSize), filled via surface.fill with SColor, IColor
	or RColor when the cell equals 0, 1 or 2 respectively.  Cells holding any
	other value are not drawn.  Nothing is returned.
	"""
	palette = ((0, SColor), (1, IColor), (2, RColor))
	for row in range(CellNum):
		for col in range(CellNum):
			value = state[row][col]
			for code, color in palette:
				if value == code:
					surface.fill(color,[row*CellSize,col*CellSize,CellSize,CellSize])



#keep track of the number of S, I, and R cells at each time step
def SIRCount(state,CellNum,STot,ITot,RTot):
	"""Append this step's S, I and R cell counts to STot, ITot and RTot.

	Only the top-left CellNum x CellNum cells are examined.  Cells equal to
	0, 1 and 2 count as susceptible, infected and recovered respectively;
	any other value is ignored.  The lists are extended in place (S first,
	then I, then R) and nothing is returned.
	"""
	cells = [state[row][col] for row in range(CellNum) for col in range(CellNum)]
	for tally, code in ((STot, 0), (ITot, 1), (RTot, 2)):
		tally.append(cells.count(code))
