# Shane.py
#
# Just a collection of functions I typically use.

from sys import *
from numpy import *

# A shorter shorthand for defining a numpy array.
def a(*v):
    """Build a numpy array from the positional arguments.

    a(1, 2, 3) is shorthand for array((1, 2, 3)).
    """
    elements = v
    return array(elements)

# Small vector helpers.  (numpy.linalg.norm provides the magnitude; numpy has
# no direct equivalent of normalize().)
def mag(a):
    """Return sqrt(dot(a, a)).

    For a 1-D real vector this is its Euclidean length.
    """
    squared_length = dot(a, a)
    return sqrt(squared_length)

def normalize(a):
    """Return a divided by its magnitude, as computed by mag()."""
    length = mag(a)
    return a / length

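# The sum() in scope here is numpy's (via `from numpy import *`), which adds up
# every element instead of summing a sequence of arrays element-wise.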
def sumv(vs):
    """Return the element-wise sum of the values (typically arrays) in vs.

    vs is any iterable whose items support `+`.  An empty iterable yields
    the integer 0.  The inputs are never modified.
    """
    s = 0
    for v in vs:
        # Deliberately not `s += v`: once s is an ndarray that becomes an
        # in-place add, which cannot widen the dtype, so an int array followed
        # by a float array would fail (or be truncated on old numpy versions).
        s = s + v
    return s


def rcurry(f, *a):
    """Return a function that calls f with the arguments a appended on the right.

    rcurry(f, c, d)(x, y) is f(x, y, c, d).
    """
    # f(*args) replaces the apply() builtin, which was deprecated in
    # Python 2.3 and no longer exists in Python 3.
    return lambda *x: f(*(x + a))

def curry(f, *a):
    """Return a function that calls f with the arguments a supplied first.

    curry(f, c, d)(x, y) is f(c, d, x, y).
    """
    # f(*args) replaces the apply() builtin, which was deprecated in
    # Python 2.3 and no longer exists in Python 3.
    return lambda *x: f(*(a + x))

def compose(f, g):
    """Return the composition of f after g.

    compose(f, g)(x, y, ...) is f(g(x, y, ...)): g receives all positional
    arguments and f receives g's single result.
    """
    # g(*x) replaces the apply() builtin, which was deprecated in
    # Python 2.3 and no longer exists in Python 3.
    return lambda *x: f(g(*x))

def const(k):
    """Return a function that ignores its positional arguments and returns k."""
    def _always(*_ignored):
        return k
    return _always

def iterate(f, i, n):
    """Apply f to i repeatedly, n times, and return the final value.

    iterate(f, i, 0) is i, iterate(f, i, 2) is f(f(i)).

    f -- function of one argument
    i -- initial value
    n -- number of applications; a non-negative integer

    Raises ValueError if n is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))
    # A loop rather than recursion, so a large n cannot exhaust the
    # interpreter's recursion limit.
    remaining = n
    while remaining > 0:
        i = f(i)
        remaining -= 1
    return i

def iterateCollect(f, i, n):
    """Apply f to i repeatedly, n times, collecting every intermediate result.

    Returns the list [f(i), f(f(i)), ...] with n entries.  The initial value
    i itself is not part of the result.

    f -- function of one argument
    i -- initial value
    n -- number of applications; a non-negative integer

    Raises ValueError if n is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))
    r = []
    # A loop rather than recursion, so a large n cannot exhaust the
    # interpreter's recursion limit.
    remaining = n
    while remaining > 0:
        i = f(i)
        r.append(i)
        remaining -= 1
    return r

def minimum(l):
    """Return the smallest element of l after converting it to a numpy array."""
    values = array(l)
    return values.min()

def frange(x1, x2, dx):
    """Return the list [x1, x1 + dx, x1 + 2*dx, ...] of values below x2.

    Works like range() for non-integer steps: x1 is included, x2 is not.
    An empty list is returned when x1 is not below x2.

    Raises ValueError if x1 < x2 and dx is not positive, because the
    sequence would then never reach x2.
    """
    if x1 < x2 and dx <= 0:
        raise ValueError("frange() needs dx > 0 when x1 < x2, got %r" % (dx,))
    result = []
    k = 0
    x = x1
    while x < x2:
        result.append(x)
        k += 1
        # Compute each value from x1 directly instead of accumulating
        # `x += dx`; accumulated rounding error otherwise drifts and can
        # produce an extra element (frange(0, 1, 0.1) used to yield 11).
        x = x1 + k * dx
    return result

def menuOptions(options, default = -1):
    """Print a numbered menu and read the user's choice from stdin.

    The menu is numbered from 1 and is shown again until a valid number is
    entered.  An empty line selects the default, if there is one.

    options -- sequence of entries to display (formatted with %s)
    default -- zero-based index of the entry chosen on empty input,
               or -1 for no default

    Returns the zero-based index of the chosen entry.

    Raises ValueError if options is empty, and EOFError if stdin ends
    before a valid choice has been made.
    """
    count = len(options)
    if count == 0:
        # With nothing to choose from, no answer could ever be accepted.
        raise ValueError("menuOptions() requires at least one option")
    answer = -1
    while answer < 1 or answer > count:
        number = 0
        for option in options:
            number += 1
            stdout.write("%d. %s\n" % (number, option))
        if default == -1:
            stdout.write("?")
        else:
            stdout.write("[%d] ?" % (default + 1) )
        # The prompt has no trailing newline, so flush it before blocking.
        stdout.flush()
        s = stdin.readline()
        text = s.strip()
        if text == "":
            answer = default + 1
            # readline() returns "" only at end of input; without a usable
            # default the loop would otherwise spin forever.
            if s == "" and (answer < 1 or answer > count):
                raise EOFError("end of input while waiting for a menu choice")
        else:
            try:
                answer = int(text)
            except ValueError:
                # Not a number: treat as invalid and ask again.
                answer = -1
    return answer - 1

def yesNo(message, default = None):
    """Ask a yes/no question on stdout and read the answer from stdin.

    The question is repeated until the user enters "y" or "n" (in either
    case), or enters an empty line while a default is set.

    message -- text of the question
    default -- value returned on empty input: True, False, or None for
               no default

    Returns True for yes, False for no, or default on empty input.

    Raises EOFError if stdin ends and there is no default to fall back on.
    """
    response = { None : "y/n", True : "Y/n", False : "y/N" }
    prompt = "%s (%s)" % (message, response[default])
    answer = None
    while answer is None:
        stdout.write(prompt)
        # The prompt has no trailing newline, so flush it before blocking.
        stdout.flush()
        line = stdin.readline()
        reply = line.strip().lower()
        if reply == "":
            # readline() returns "" only at end of input; without a default
            # the loop would otherwise spin forever.
            if line == "" and default is None:
                raise EOFError("end of input while waiting for a y/n answer")
            answer = default
        elif reply == "y":
            answer = True
        elif reply == "n":
            answer = False
        # Anything else leaves answer as None, so the question is repeated.
    return answer

def askFloat(message, default = None):
    """Prompt on stdout and read a floating-point number from stdin.

    The prompt is repeated until the input parses as a float, or an empty
    line is entered while a default is set.

    message -- prompt text
    default -- value returned on empty input (also shown in the prompt),
               or None for no default

    Returns the entered number as a float, or default on empty input.

    Raises EOFError if stdin ends and there is no default to fall back on.
    """
    if default is None:
        prompt = "%s" % message
    else:
        prompt = "%s [%f] " % (message, default)
    answer = None
    while answer is None:
        stdout.write(prompt)
        # The prompt has no trailing newline, so flush it before blocking.
        stdout.flush()
        line = stdin.readline()
        text = line.strip()
        if text == "":
            # readline() returns "" only at end of input; without a default
            # the loop would otherwise spin forever.
            if line == "" and default is None:
                raise EOFError("end of input while waiting for a number")
            answer = default
        else:
            try:
                answer = float(text)
            except ValueError:
                # Not a number: ask again.
                answer = None
    return answer
