#Sharon Chang
#UC Davis Physics
#Physics 150 Project


#SIR_CA.py: Cellular automaton model for the spread of epidemics, based on the SIR model. 
#The user chooses which type of simulation to run: a regular disease (SIR) or the common cold (SIS, referred to as SIRS in some comments)
#Other parameters, such as the initial infected population, population size, and probabilities of infection and recovery,
	# can also be set by the user.
#Program will plot the percentage of the population that is susceptible, infected, or recovered as a function of time step
#The SIS/SIRS model will also plot the number of times each cell has been infected (the more times infected, the whiter the cell).
#Contains code for displaying plots of the Kermack-McKendrick ODEs


#import modules
import sys

from numpy import *
from numpy.random import *
from pylab import *
from CAFuncs import *
from KM_ODE import *
from SIS_ICount import *

#make sure Pygame installed
try:
	import pygame
	import pygame.surfarray as surfarray
except ImportError:
	raise ImportError, "PyGame required."
	
#initial pop N = CellNum*CellNum
#initial infected I(0) = Iinitial
#initial susceptible S(0) = N - I(0)
#initial recovered R(0) = 0

print "Run Simulation for:"
print "1)SIR Model"
print "2)SIS Model (i.e., common cold)"
ModelType = raw_input("Enter 1 or 2: ")
ModelType = int(ModelType)


print "\nHow many cells will be infected initially?"
Iinitial = raw_input("I(0) = ")
Iinitial = int(Iinitial)

print "\nHow long to run simulation?"
time = raw_input("time = ")
time = int(time)

print "\nHow many cells on each side of square?"
print "Total population = (Number of Cells)^2"
CellNum = raw_input("Number of Cells = ")
CellNum = int(CellNum)

print "\nProbability of an infected cell infecting one of its neighbors?"
p = raw_input("p = ")
p = float(p)

print "\nProbability of an infected cell recovering?"
q = raw_input("q = ")
q = float(q)

#to display plot of Kermack-McKendrick model, uncomment
#KM_Plot(CellNum, Iinitial)

#to run simulation more than once and average results, enter number of simulations to run here
SimNum = 5

CellSize = 5
size = (CellSize*CellNum,CellSize*CellNum)

#Start pygame, then choose the RGB color drawn for each cell state
pygame.init()
SColor = (0, 255, 0)		# susceptible cells: green
IColor = (255, 0, 0)		# infected cells: red
RColor = (0, 0, 0)			# recovered cells: black

#Per-run S, I, R count histories: one list per simulation, averaged once all
#runs have finished (previously the storage arrays were re-created inside the
#loop, so only the last run survived, and the "average" was never divided by SimNum)
SRuns = []
IRuns = []
RRuns = []

for sim in range(SimNum):
	#initialize the cells for S, I, or R
	state = Init(CellNum,Iinitial)

	#SIS (common cold, called SIRS elsewhere) runs also track how many times
	#(ITimes) each cell gets infected; RCount is the extra state Update_Cold needs
	if ModelType == 2:
		ITimes = InitICount(CellNum,state)
		RCount = zeros((CellNum,CellNum),float)

	#running totals of S, I, and R cells; SIRCount adds an entry at each time step
	#(assumed to be exactly one per step so every run has time+1 entries)
	STot = [CellNum*CellNum - Iinitial]
	ITot = [Iinitial]
	RTot = [0]

	#display the initial cell distribution
	screen = pygame.display.set_mode(size)
	pygame.display.set_caption('2D SIR Simulator')
	SpDisplayState(state,CellNum,screen,CellSize,SColor,IColor,RColor)
	pygame.display.flip()

	for t in range(time):
		#to save image every 10 time steps, uncomment
		# if t%10 == 0:
			# if ModelType == 1:
				# filename = 'SIR_p'+str(p)+'_q'+str(q)+'_t'+str(t)
			# if ModelType == 2:
				# filename = 'SIS_p'+str(p)+'_q'+str(q)+'_t'+str(t)
			# pygame.image.save(screen, filename+'.bmp')

		#closing the window ends the whole program
		for event in pygame.event.get():
			if event.type == pygame.QUIT: sys.exit()
		#find new state
		if ModelType == 1:
			state = Update(state,p,q,CellNum)
		if ModelType == 2:
			state,RCount = Update_Cold(state,p,q,CellNum,RCount)
			ITimes = ICountUpdate(CellNum,ITimes,state)			#update the number of times each cell gets infected
		#redraw the display
		SpDisplayState(state,CellNum,screen,CellSize,SColor,IColor,RColor)
		pygame.display.flip()
		#count up number of S, I, R cells and add to the running totals
		SIRCount(state,CellNum,STot,ITot,RTot)

	#keep this run's full history for averaging
	SRuns.append(STot)
	IRuns.append(ITot)
	RRuns.append(RTot)

#to save image at end, uncomment
# if ModelType == 1:
	# filename = 'SIR_p'+str(p)+'_q'+str(q)+'_t'+str(t)
# if ModelType == 2:
	# filename = 'SIS_p'+str(p)+'_q'+str(q)+'_t'+str(t)
# pygame.image.save(screen, filename+'.bmp')

#SAvg,IAvg,RAvg: mean S, I, and R at each time step over all SimNum runs
#(rows = runs, columns = time steps)
SAvg = array(SRuns, float).mean(axis=0)
IAvg = array(IRuns, float).mean(axis=0)
RAvg = array(RRuns, float).mean(axis=0)

#plot population stats versus time
x = arange(0., float(time)+1,1.)
#fraction of the total population (CellNum^2 cells) that is S, I, or R
Population = float(CellNum*CellNum)
SPercent = SAvg/Population
IPercent = IAvg/Population
RPercent = RAvg/Population
#subplot(212)
plot(x,SPercent,'g',label = 'S(t)')
plot(x,IPercent,'r',label = 'I(t)')
plot(x,RPercent,'k',label = 'R(t)')
xlabel('time')
ylabel('percent of total population')
if ModelType == 1:
	title('Cellular Automata of SIR Model with p = ' + str(p) + ', q = ' + str(q))
if ModelType == 2:
	title('Cellular Automata of SIS Model with p = ' + str(p) + ', q = ' + str(q))
legend()

show()

#plot the number of times each cell gets infected (ITimes from the last run)
if ModelType == 2:
	SIS_Plot(ITimes)