# Jesse Singh
# Spectral Equations

# Note: This program requires access to cmpy

from numpy import *
import cmpy

def rIs1(q1, qinf, value):
	"""Return the r = 1 word distribution (as log2 probabilities) for L = 1 or L = 2.

	The r = 1 equations are:
	  1. Pr(0) = Pr(00) + Pr(10) = Pr(00) + Pr(01)
	  2. Pr(11) + Pr(10) + Pr(01) + Pr(00) = 1
	  3. q(1) = Pr(11) + Pr(00)
	  4. qinf = (Pr(00) + Pr(01))**2 + (Pr(10) + Pr(11))**2

	They simplify to
	  Pr(01) = Pr(10) = (1/2) * (1 - q(1))
	  Pr(11) = q(1) - Pr(00)
	and a quadratic a*y**2 + b*y + c = 0 in y = Pr(00), with
	  a = 2
	  b = -2*q(1)
	  c = 2*Pr(01)**2 + q(1)**2 + 2*q(1)*Pr(01) - qinf
	so Pr(00) = (-b +- sqrt(b**2 - 4*a*c)) / (2*a).

	Parameters:
	  q1    -- the q(1) value in the equations above
	  qinf  -- the qinf value in the equations above
	  value -- 1 returns the L = 1 distribution; any other value returns L = 2

	Returns a dict mapping word tuples, e.g. (0,) or (0, 1), to cmpy.log2 of
	the corresponding probability.
	"""
	Pr01 = 0.5 * (1 - q1)
	Pr10 = Pr01

	a = 2
	b = -2*q1
	c = 2*Pr01**2 + q1**2 + 2*q1*Pr01 - qinf

	# Algebraically the discriminant equals 8*qinf - 4, which is exactly zero
	# at the smallest admissible qinf (1/2). Round-off can push it slightly
	# below zero there, and sqrt() would then return nan. Treat that as a
	# double root, the same way rIs2 handles its negative discriminants.
	disc = b**2 - 4 * a * c
	if disc < 0:
		disc = 0.0
	root = sqrt(disc)

	Pr00plus = (-b + root)/(2*a)
	Pr00minus = (-b - root)/(2*a)
	# The two roots sum to q1 (= -b/a), so whichever root is not chosen ends
	# up as Pr(11). Taking the '-' root when the '+' root is positive keeps
	# Pr(00) <= Pr(11).
	if Pr00plus > 0:
		Pr00 = Pr00minus
	else:
		Pr00 = Pr00plus
	Pr11 = q1 - Pr00

	Pr0 = Pr01 + Pr00
	Pr1 = Pr10 + Pr11

	Pr0 = cmpy.log2(Pr0)
	Pr1 = cmpy.log2(Pr1)
	Pr00 = cmpy.log2(Pr00)
	Pr01 = cmpy.log2(Pr01)
	Pr10 = cmpy.log2(Pr10)
	Pr11 = cmpy.log2(Pr11)

	L1 = {(0,): Pr0, (1,): Pr1}
	L2 = {(0,0): Pr00, (0,1): Pr01, (1,0): Pr10, (1,1): Pr11}

	if value == 1:
		return L1
	return L2

def rIs2(q1, q2, q3, qinf):
	"""Return the L = 3 word distribution (as log2 probabilities) for r = 2.

	The r = 2 equations are:
	  1. Pr001 - Pr100 = 0
	  2. Pr011 - Pr110 = 0
	  3. Pr001 + Pr101 - Pr011 - Pr010 = 0
	  4. Pr111 + Pr101 + Pr011 + Pr001 + Pr110 + Pr100 + Pr010 + Pr000 = 1
	  5. q1 = Pr111 + Pr110 + Pr000 + Pr001
	  6. q2 = Pr111 + Pr101 + Pr000 + Pr010
	  7. qinf = (Pr000 + Pr001 + Pr010 + Pr011)**2 + (Pr100 + Pr101 + Pr110 + Pr111)**2
	  8. q3 = ((Pr111**2)/(Pr111 + Pr110)) + ((Pr110*Pr101)/(Pr100 + Pr101))
	        + ((Pr101*Pr011)/(Pr010 + Pr011)) + ((Pr100*Pr001)/(Pr000 + Pr001))
	        + ((Pr000**2)/(Pr000 + Pr001)) + ((Pr001*Pr010)/(Pr010 + Pr011))
	        + ((Pr010*Pr100)/(Pr100 + Pr101)) + ((Pr011*Pr110)/(Pr111 + Pr110))

	From these:
	  Pr001 = Pr100
	  Pr011 = Pr110
	  Pr011 + Pr001 = (1 - q2)/2
	  Pr101 + Pr010 = (q2 + 1 - 2*q1)/2
	  Pr010 + Pr011 = (1 - q1)/2
	  Pr111 + Pr000 = (2*q1 + q2 - 1)/2
	  qinf = (Pr000 + ((1 - q2)/2) + Pr010)**2 + (((1 - q2)/2) + Pr101 + Pr111)**2

	Parameters q1, q2, q3 and qinf are the values named in the equations above.

	Returns a dict mapping each length-3 word tuple to cmpy.log2 of its
	probability.
	"""
	# Combinations of the unknowns fixed by q1 and q2 alone.
	mp = (1 - q1)/2             # Pr010 + Pr011
	np = (1 - q2)/2             # Pr011 + Pr001
	op = (q2 + 1 - 2*q1)/2      # Pr101 + Pr010
	pp = (2*q1 + q2 - 1)/2      # Pr111 + Pr000
	wp = (q1 - q2)/2

	# Quadratic in biSol = mp - cp obtained from the qinf equation.
	A = 2
	B = 4*q1 - 2
	C = 2*(q1**2) - 2*q1 + 1 - qinf

	if B**2 - 4*A*C < 0:
		# No real root: fall back to the double-root value -B/(2A).
		biSolplus = -B / (2*A)
		biSolminus = biSolplus
	else:
		biSolplus = (-B + sqrt(B**2 - 4*A*C))/(2*A)
		biSolminus = (-B - sqrt(B**2 - 4*A*C))/(2*A)

	if biSolplus < 0:
		biSol = biSolplus
	else:
		biSol = biSolminus

	cp = mp - biSol                # Pr111 + Pr011 (= h + g below)
	dp = q1 - cp
	ep = pp - cp
	ap = (np - wp)*(op + wp)       # algebraically mp**2, since np - wp == op + wp == mp

	if mp == 0 or cp == 0 or dp == 0:
		# Degenerate cases: the common denominator cp*dp*mp**2 of the
		# g-quadratic is zero, so use g = 0. Checking *before* dividing avoids
		# a ZeroDivisionError (plain floats) or inf/nan warnings (numpy floats)
		# that the later fallback could not prevent.
		g = 0
	else:
		comDenom = cp*dp*(mp**2)
		AP = ( ap * ( 2 * cp + 2 * dp ) + 2 * cp *dp )/comDenom
		BP = ( ap * ( 2 * cp * ep - 2 * np * cp - 2 * cp * dp) - cp * dp * ( 2 * wp + np + op ))/comDenom
		CP = ( ap * ( cp * ( ep**2 ) + ( cp**2 ) * dp - q3 * cp * dp + cp * ( np**2 )) + cp * dp * np * ( op + wp ))/comDenom

		if BP**2 - 4*AP*CP < 0:
			# No real root: fall back to the double-root value -BP/(2AP).
			gplus = -BP / (2*AP)
			gminus = gplus
		else:
			gplus = (-BP + sqrt(BP**2 - 4*AP*CP))/(2*AP)
			gminus = (-BP - sqrt(BP**2 - 4*AP*CP))/(2*AP)

		if gplus > 0:
			g = gplus
		else:
			g = gminus

	# Back-substitute the remaining probabilities; g is Pr011 (= Pr110).
	h = cp - g      # Pr111
	a = pp - h      # Pr000
	b = np - g      # Pr001
	c = mp - g      # Pr010
	d = b           # Pr100
	e = op - c      # Pr101
	f = g           # Pr110

	a = cmpy.log2(a)
	b = cmpy.log2(b)
	c = cmpy.log2(c)
	d = cmpy.log2(d)
	e = cmpy.log2(e)
	f = cmpy.log2(f)
	g = cmpy.log2(g)
	h = cmpy.log2(h)

	L3 = {(0,0,0): a, (0,0,1): b, (0,1,0): c, (1,0,0): d, (1,0,1): e, (1,1,0): f, (0,1,1): g, (1,1,1): h}
	return L3

# Interactive driver: read the q parameters, build the L = 1, 2 and 3 word
# distributions, then plot (and save) the one(s) the user selects.
print("This program returns the L = 1, 2, and 3 words\n")
q1 = float(raw_input("q1: "))
q2 = float(raw_input("q2: "))
q3 = float(raw_input("q3: "))
qinf = float(raw_input("qinf: "))
L1 = rIs1(q1, qinf, 1)
L2 = rIs1(q1, qinf, 2)
L3 = rIs2(q1,q2,q3,qinf)

# Menu option -> (plot title, distribution, word length, output file name).
PLOTS = {
	1: ("L = 1", L1, 1, 'lis1'),
	2: ("L = 2", L2, 2, 'lis2'),
	3: ("L = 3", L3, 3, 'lis3'),
}


def _plot_words(option, **extra):
	"""Plot and save the distribution registered under menu *option*.

	Extra keyword arguments (e.g. display=False) are forwarded unchanged to
	cmpy.plot_distribution.
	"""
	title, distribution, word_length, save_file = PLOTS[option]
	cmpy.plot_distribution(title=title, distribution=distribution, alphabet=(0,1), word_length=word_length, use_logs=True, ylim=(-5, 3), save_file=save_file, **extra)


print("\n1. L = 1")
print("2. L = 2")
print("3. L = 3")
print("4. No display, save all figures")
try:
	choice = int(raw_input("Choose an option: "))
except ValueError:
	# Non-numeric input is reported like any other invalid option instead of
	# aborting with a traceback.
	choice = None

if choice in PLOTS:
	_plot_words(choice)
elif choice == 4:
	for option in (1, 2, 3):
		_plot_words(option, display=False)
else:
	print("Error\n")
