from numpy import arange

class SIRZ:
	"""Four-compartment S/I/R/Z model; the d*dt methods give its right-hand sides.

	The state starts at S = 1, I = I (constructor argument), R = 0, Z = 0, and
	P is initialised to S + I. The derivative methods read the current
	S, I, R, Z attributes, so an integrator can move the state around and
	re-evaluate them.

	Parameter meanings below are read off the equations. Any epidemiological
	interpretation is an assumption (NOTE(review): confirm against the model's
	source).

	Parameters
	----------
	B : float
		Coefficient of the S*I term (presumably the transmission rate).
	g : float
		Rate at which I flows into R.
	a : float
		Rate at which I flows into Z.
	phi1, phi2 : float
		phi = phi1*phi2 scales the Z-driven loss from S and R. Z itself
		decays at rate phi1*(1 - phi2).
	I : float
		Initial value of the I compartment.
	"""

	def __init__(self, B, g, a, phi1, phi2, I):
		self.B = B
		self.a = a
		self.g = g
		self.phi = phi1*phi2
		self.phi1 = phi1
		self.phi2 = phi2
		self.S = 1.
		self.I = I
		self.R = 0.
		self.Z = 0.
		self.P = self.S+I

	def _share(self, x):
		"""Return x / (R + S), or 0.0 when R + S is exactly zero.

		The S and R equations split the Z-driven loss between S and R in
		proportion to their share of R + S. The integrator in this file
		clamps S and R at 0.0, so both can hit zero together. Without this
		guard, dsdt/drdt would then divide by zero. With nothing left in S or
		R there is nothing to remove, so the share is taken as 0.
		"""
		total = self.R + self.S
		if total == 0:
			return 0.0
		return x / total

	def dsdt(self):
		"""dS/dt = -B*S*I - phi*(S/(R+S))*Z."""
		return -(self.B * self.S * self.I) - self.phi*self._share(self.S)*self.Z

	def didt(self):
		"""dI/dt = B*S*I - I*(g + a)."""
		return self.B * self.S * self.I - self.I*(self.g + self.a)

	def drdt(self):
		"""dR/dt = g*I - phi*(R/(R+S))*Z."""
		return self.g*self.I - self.phi*self._share(self.R)*self.Z

	def dzdt(self):
		"""dZ/dt = a*I - Z*phi1*(1 - phi2)."""
		return self.a*self.I - self.Z*self.phi1*(1.-self.phi2)
	
def _scaled_slopes(ODE, dt):
	"""Return dt times (dsdt, didt, drdt, dzdt) evaluated at ODE's current state."""
	return (dt * ODE.dsdt(), dt * ODE.didt(), dt * ODE.drdt(), dt * ODE.dzdt())


def _set_state(ODE, state):
	"""Write an (S, I, R, Z) sequence onto ODE so its derivative methods see it."""
	ODE.S, ODE.I, ODE.R, ODE.Z = state


def _clamp_non_negative(x):
	"""Return 0.0 when x < 0.0, otherwise return x unchanged."""
	return 0.0 if x < 0.0 else x


def RK3DIntegrator(ODE, dt, steps):
	"""Integrate ODE in place with the classical fourth-order Runge-Kutta scheme.

	Despite the name, each step is standard RK4. It takes four slope
	evaluations: one at the start, two at the half step (predicted by the
	previous slope) and one at the full step. These are combined with weights
	1, 2, 2, 1 over 6.

	Parameters
	----------
	ODE : object
		Model exposing mutable attributes S, I, R, Z and methods dsdt(),
		didt(), drdt(), dzdt() that read the current S, I, R, Z (such as
		SIRZ). It is mutated in place: afterwards it holds the final state,
		and ODE.P is set to S + I + R.
	dt : float
		Step size.
	steps : float
		End time, not a step count. The loop runs over
		numpy.arange(0., steps, dt), i.e. roughly steps/dt iterations.

	Returns
	-------
	tuple
		(TrajS, TrajI, TrajR, TrajZ, TrajP, Time): lists with one entry per
		step. After every step S, I and R are clamped to be non-negative;
		Z is not clamped.
	"""
	TrajS = []
	TrajI = []
	TrajR = []
	TrajZ = []
	TrajP = []
	Time = []

	for t in arange(0., steps, dt):
		y0 = (ODE.S, ODE.I, ODE.R, ODE.Z)

		# Stage 1: slope at the start of the step.
		k1 = _scaled_slopes(ODE, dt)
		# Stage 2: slope at the midpoint predicted by k1. Zipping with y0
		# keeps every compartment paired with its own starting value.
		_set_state(ODE, [y + k / 2.0 for y, k in zip(y0, k1)])
		k2 = _scaled_slopes(ODE, dt)
		# Stage 3: slope at the midpoint predicted by k2.
		_set_state(ODE, [y + k / 2.0 for y, k in zip(y0, k2)])
		k3 = _scaled_slopes(ODE, dt)
		# Stage 4: slope at the end of the step predicted by k3.
		_set_state(ODE, [y + k for y, k in zip(y0, k3)])
		k4 = _scaled_slopes(ODE, dt)

		S, I, R, Z = [
			y + (a + 2.0 * b + 2.0 * c + d) / 6.0
			for y, a, b, c, d in zip(y0, k1, k2, k3, k4)
		]
		ODE.S = _clamp_non_negative(S)
		ODE.I = _clamp_non_negative(I)
		ODE.R = _clamp_non_negative(R)
		ODE.Z = Z  # Z is deliberately left unclamped.
		ODE.P = ODE.S + ODE.I + ODE.R

		# NOTE(review): the stored state is the one reached after advancing
		# from t to t + dt, yet it is labelled t, and the initial condition is
		# never recorded. Kept for compatibility; confirm what callers expect.
		TrajS.append(ODE.S)
		TrajI.append(ODE.I)
		TrajR.append(ODE.R)
		TrajZ.append(ODE.Z)
		TrajP.append(ODE.P)
		Time.append(t)
	return TrajS, TrajI, TrajR, TrajZ, TrajP, Time