#KM_ODE.py: Plot the fraction of the population that is Susceptible, Infected, or Recovered (Kermack-McKendrick SIR model)

#import modules
from pylab import *

#x,y,z = S,I,R
#f,g,h = Sdot, Idot, Rdot
def RK3DIntegrator(beta,gamma,x,y,z,f,g,h,dt):
	"""Advance three coupled ODEs by one classical fourth-order Runge-Kutta step.

	Each derivative callable is invoked as fn(beta, gamma, x, y, z); within
	every stage they are evaluated in the order f, g, h.

	Returns:
		Tuple (x, y, z) of the state after a step of size dt.
	"""
	derivs = (f, g, h)
	state = (x, y, z)

	def _slopes(sx, sy, sz):
		# One RK stage: dt times each derivative at the probe point (sx, sy, sz).
		return tuple(dt * fn(beta, gamma, sx, sy, sz) for fn in derivs)

	k1 = _slopes(x, y, z)
	# Midpoint probes use half of the previous stage's increment ...
	k2 = _slopes(*(s + k / 2.0 for s, k in zip(state, k1)))
	k3 = _slopes(*(s + k / 2.0 for s, k in zip(state, k2)))
	# ... and the final probe uses the full k3 increment.
	k4 = _slopes(*(s + k for s, k in zip(state, k3)))

	# Weighted average of the four slopes (1, 2, 2, 1) / 6.
	return tuple(
		s + (a + 2.0 * b + 2.0 * c + d) / 6.0
		for s, a, b, c, d in zip(state, k1, k2, k3, k4)
	)

def _incidence(beta,S,I,R):
	"""Return the new-infection rate beta*S*I/N, where N = S + I + R.

	S, I and R may be absolute counts or fractions of the population; the
	result is in the same units per unit time. An empty population (N == 0)
	yields 0.0 rather than dividing by zero.
	"""
	Pop = S + I + R
	if Pop == 0:
		return 0.0
	return beta * S * I / Pop

def SDot(beta,gamma,S,I,R):
	"""Return dS/dt = -beta*S*I/N for the Kermack-McKendrick (SIR) model.

	gamma is unused here; it is accepted so SDot, IDot and RDot all share the
	signature (beta, gamma, S, I, R).
	"""
	return -_incidence(beta,S,I,R)

def IDot(beta,gamma,S,I,R):
	"""Return dI/dt = beta*S*I/N - gamma*I for the Kermack-McKendrick (SIR) model."""
	# Only infection depends on the contact fraction I/N; recovery is gamma*I
	# in the same units as I (using gamma*I/N is only correct when N == 1).
	return _incidence(beta,S,I,R) - gamma * I

def RDot(beta,gamma,S,I,R):
	"""Return dR/dt = gamma*I for the Kermack-McKendrick (SIR) model.

	beta, S and R are unused; they are accepted so SDot, IDot and RDot all
	share the signature (beta, gamma, S, I, R). Together with SDot and IDot
	the three rates sum to zero, so S + I + R is conserved.
	"""
	return gamma * I
	
	
def _KM_Solve(CellNum, Iinitial, beta, gamma, dt, N):
	"""Integrate the Kermack-McKendrick (SIR) model with RK4 and return the trajectory.

	The total population is CellNum**2 (presumably a CellNum x CellNum grid of
	cells); Iinitial individuals start infected, the rest susceptible, none
	recovered.

	Returns:
		(t, S, I, R): lists of length N + 1 holding the time points and the
		fraction (0-1) of the total population in each compartment.

	Raises:
		ValueError: if Iinitial is negative or exceeds CellNum**2, or if the
			total population is zero.
	"""
	S0 = CellNum**2 - Iinitial
	I0 = Iinitial
	R0 = 0.
	if I0 < 0 or S0 < 0:
		raise ValueError('Iinitial must lie between 0 and CellNum**2')
	Pop = float(S0 + I0 + R0)
	if Pop == 0:
		raise ValueError('total population CellNum**2 must be positive')

	# Work in fractions of the total population so every curve lies in [0, 1].
	S = [S0 / Pop]
	I = [I0 / Pop]
	R = [R0 / Pop]
	t = [0.]

	for i in range(N):
		STemp, ITemp, RTemp = RK3DIntegrator(beta, gamma, S[i], I[i], R[i], SDot, IDot, RDot, dt)
		S.append(STemp)
		I.append(ITemp)
		R.append(RTemp)
		# (i + 1) * dt avoids the round-off drift of repeatedly adding dt.
		t.append((i + 1) * dt)
	return t, S, I, R


def KM_Plot(CellNum, Iinitial, beta=0.4, gamma=0.2, dt=0.01, N=10000):
	"""Plot S(t), I(t) and R(t) for the Kermack-McKendrick (SIR) model.

	Args:
		CellNum: the total population is CellNum**2.
		Iinitial: number of initially infected individuals.
		beta: infection rate (default 0.4).
		gamma: recovery rate (default 0.2).
		dt: RK4 time step (default 0.01).
		N: number of time steps (default 10000).

	Returns:
		Whatever pylab's show() returns.

	Raises:
		ValueError: if Iinitial is negative or exceeds CellNum**2, or if the
			total population is zero.
	"""
	t, S, I, R = _KM_Solve(CellNum, Iinitial, beta, gamma, dt, N)

	plot(t, S, 'g', label='S(t)')
	plot(t, I, 'r', label='I(t)')
	plot(t, R, 'k', label='R(t)')
	xlabel('time')
	# The curves are fractions in [0, 1], not percentages.
	ylabel('fraction of total population')
	title('Kermack-McKendrick Model with beta = ' + str(beta) + ', gamma = ' + str(gamma))
	legend()
	return show()
		