"""powerofoptimization.py: Code for numerical experiments in The Power of Optimization Over Randomization in Designing Experiments Involving Small Samples by Dimitris Bertsimas, Mac Johnson, Nathan Kallus (2014)."""

__author__ = "Dimitris Bertsimas, Mac Johnson, Nathan Kallus"
__copyright__ = "Copyright 2014"
__credits__ = ["Dimitris Bertsimas", "Mac Johnson", "Nathan Kallus"]
__version__ = "1.0"
__maintainer__ = "Nathan Kallus"
__email__ = "kallus@mit.edu"

import numpy
np=numpy
from numpy import *
import scipy
import scipy.linalg
from scipy.spatial.distance import pdist, squareform
from random import shuffle
from gurobipy import *

def optassign(covardata, m, rho = 0.5):
    """ Optimally assign univariate covariates to m groups

    Builds and solves a Gurobi mixed-integer program that minimizes d subject
    to k*d >= |S1_p - S1_q| + rho*|S2_p - S2_q| for every pair of groups
    (p, q), where S1 and S2 are the group sums of the standardized covariate
    and of its square, and k = n/m is the common group size.

    Args:
        covardata: covariate data, list/array of floats
        m: number of groups
        rho: exchange rate between first and second moments

    Returns:
        list of m lists of indices, representing the optimal partition; the
        order of the groups is randomly shuffled.

    Raises:
        Exception: if the number of subjects is not divisible by m.
    """
    data = array(covardata, dtype=float)
    n = len(data)
    if n % m != 0:
        raise Exception("Subjects do not divide into m groups.")
    k = n // m

    # Standardize. A constant covariate is only centred (every partition is
    # then optimal), which avoids dividing by a zero standard deviation.
    data = data - data.mean()
    scale = data.std()
    if scale > 0:
        data = data / scale
    # Plain Python floats as coefficients for Gurobi's expression arithmetic.
    m1coef = data.tolist()
    m2coef = (data * data).tolist()

    model = Model("match")
    model.setParam('Symmetry', 2)

    # Symmetry breaking: subject i may only join groups 0..min(i, m-1), so
    # x[j][i - j] is the binary indicator "subject i is in group j" (i >= j).
    x = [[model.addVar(obj = 0., lb = 0., ub = 1., vtype = GRB.BINARY, name = 'x'+str(i)+','+str(j)) for i in range(j,n)] for j in range(m)]
    d = model.addVar(obj = 1., lb = 0., ub = GRB.INFINITY, vtype = GRB.CONTINUOUS, name = 'd')
    model.update()

    # Group sums of the covariate (mu) and of its square (sig); x[p][t] is
    # the indicator for subject p + t, hence the aligned slices.
    mu = [quicksum(c * v for c, v in zip(m1coef[p:], x[p])) for p in range(m)]
    sig = [quicksum(c * v for c, v in zip(m2coef[p:], x[p])) for p in range(m)]

    for p in range(m):
        model.addConstr(quicksum(x[p]) == k)
        for q in range(p + 1, m):
            # Linearization of k*d >= |mu_p - mu_q| + rho*|sig_p - sig_q|.
            model.addConstr(float(k) * d >= mu[p] - mu[q] + rho * sig[p] - rho * sig[q])
            model.addConstr(float(k) * d >= mu[p] - mu[q] + rho * sig[q] - rho * sig[p])
            model.addConstr(float(k) * d >= mu[q] - mu[p] + rho * sig[p] - rho * sig[q])
            model.addConstr(float(k) * d >= mu[q] - mu[p] + rho * sig[q] - rho * sig[p])

    # Every subject is placed in exactly one of the groups it may join.
    for i in range(n):
        model.addConstr(quicksum(x[p][i - p] for p in range(min(i + 1, m))) == 1)

    model.optimize()

    assgn = [[i for i in range(j, n) if x[j][i - j].x > .5] for j in range(m)]
    shuffle(assgn)
    return assgn

def optassign2(covardata, m, rho = 0.5):
    """ Optimally assign univariate covariates to 2 groups

    Solves a Gurobi mixed-integer program minimizing d subject to
    k*d >= |D1| + rho*|D2|, where D1 and D2 are the differences between the
    two groups' sums of the standardized covariate and of its square, and
    k = n/2.

    Args:
        covardata: covariate data, list/array of floats
        m: number of groups, must be = 2
        rho: exchange rate between first and second moments

    Returns:
        list of lists of indices, representing the optimal partition (the
        order of the two groups is randomly shuffled)

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    data = array(covardata, dtype=float)
    n = len(data)
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")
    k = n // 2

    # Standardize. A constant covariate is only centred (every partition is
    # then optimal), which avoids dividing by a zero standard deviation.
    data = data - data.mean()
    scale = data.std()
    if scale > 0:
        data = data / scale
    m1coef = data.tolist()
    m2coef = (data * data).tolist()

    model = Model("match")
    model.setParam('Threads',1)
    model.setParam('Symmetry', 2)

    # x[i] == 1 puts subject i in the first group; subject 0 is fixed to the
    # first group to break the label symmetry.
    x = [1,] + [model.addVar(obj = 0., lb = 0., ub = 1., vtype = GRB.BINARY, name = 'x'+str(i)) for i in range(n-1)]
    d = model.addVar(obj = 1., lb = 0., ub = GRB.INFINITY, vtype = GRB.CONTINUOUS, name = 'd')
    model.update()

    # 2*x[i] - 1 is +1 for the first group and -1 for the second.
    mudiff = quicksum(m1coef[i] * (2 * x[i] - 1) for i in range(n))
    sigdiff = quicksum(m2coef[i] * (2 * x[i] - 1) for i in range(n))

    model.addConstr(quicksum(x) == k)
    # Linearization of k*d >= |mudiff| + rho*|sigdiff|.
    model.addConstr(float(k) * d >= mudiff + rho * sigdiff)
    model.addConstr(float(k) * d >= mudiff - rho * sigdiff)
    model.addConstr(float(k) * d >= -mudiff + rho * sigdiff)
    model.addConstr(float(k) * d >= -mudiff - rho * sigdiff)

    model.optimize()

    assgn = [[0,] + [i for i in range(1, n) if x[i].x > .5], [i for i in range(1, n) if x[i].x <= .5]]
    shuffle(assgn)
    return assgn

def safeinvcov(w):
    """ Inverse (or pseudo-inverse) covariance matrix of the covariates in w

    Args:
        w: n-by-d array of covariate data; a 1-D array is treated as n-by-1.

    Returns:
        d-by-d array. For d == 1 this is 1/var(w) with numpy's default
        (ddof=0) variance; otherwise the inverse of numpy's default (ddof=1)
        covariance matrix, or its Moore-Penrose pseudo-inverse when that
        matrix has zero determinant.
        NOTE(review): the two branches use different ddof, which only changes
        an overall scale factor -- confirm this is intended.
    """
    w = asarray(w)
    if w.ndim == 1:
        w = w.reshape((-1, 1))
    d = w.shape[1]
    if d == 1:
        return array(1. / var(w)).reshape((1, 1))
    # Compute the covariance once and reuse it for the test and the inverse.
    covar = cov(w, rowvar=0)
    if linalg.det(covar) == 0.:
        return scipy.linalg.pinv(covar)
    return scipy.linalg.inv(covar)

def optassign2multi(covardata, m, rho = 0.5):
    """ Optimally assign multivariate covariates to 2 groups

    The covariates are centred and whitened with the square root of
    safeinvcov(data). A Gurobi mixed-integer program then minimizes
        sum_j |D1_j| + rho * sum_j |D2_jj| + 2*rho * sum_{j<l} |D2_jl|
    where D1_j is the difference between the two groups' sums of covariate j
    and D2_jl the difference between their sums of the products x_j*x_l.

    Args:
        covardata: covariate data, n-by-d lists/array of floats (a flat list
            is treated as n-by-1)
        m: number of groups, must be = 2
        rho: exchange rate between first and second moments

    Returns:
        list of lists of indices, representing the optimal partition (the
        order of the two groups is randomly shuffled)

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    data = array(covardata, dtype=float)
    if data.ndim == 1:
        data = data.reshape((-1, 1))
    n, d = data.shape
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")

    # sqrtm can return a complex array with negligible imaginary parts (e.g.
    # for a pseudo-inverse); Gurobi needs real coefficients.
    whitener = np.real(scipy.linalg.sqrtm(safeinvcov(data)))
    data = dot(data - data.mean(0), whitener)

    # Strict upper triangle (offset 1): exactly d*(d-1)/2 pairs j < l, one
    # cross-moment column x_j*x_l per covabsd variable below.
    rows, cols = triu_indices(d, 1)
    npairs = len(rows)
    mucoef = data.tolist()
    varcoef = (data * data).tolist()
    covcoef = (data[:, rows] * data[:, cols]).tolist()

    model = Model("match")
    model.setParam('Threads',1)
    model.setParam('Symmetry', 2)

    # x[i] == 1 puts subject i in the first group; subject 0 is fixed to the
    # first group to break the label symmetry.
    x = [1,] + [model.addVar(obj = 0., lb = 0., ub = 1., vtype = GRB.BINARY, name = 'x'+str(i)) for i in range(n-1)]
    muabsd  = [model.addVar(obj = 1., lb = 0., ub = GRB.INFINITY, vtype = GRB.CONTINUOUS, name = 'muabsd'+str(j)) for j in range(d)]
    varabsd = [model.addVar(obj = rho, lb = 0., ub = GRB.INFINITY, vtype = GRB.CONTINUOUS, name = 'varabsd'+str(j)) for j in range(d)]
    covabsd = [model.addVar(obj = 2.*rho, lb = 0., ub = GRB.INFINITY, vtype = GRB.CONTINUOUS, name = 'covabsd'+str(j)) for j in range(npairs)]
    model.update()

    def signeddiff(coef, j):
        # First-group sum minus second-group sum of column j (2*x[i]-1 is +/-1).
        return quicksum(coef[i][j] * (2 * x[i] - 1) for i in range(n))

    model.addConstr(quicksum(x) == n // 2)
    # Each pair of constraints makes the auxiliary variable >= |difference|.
    for j in range(d):
        mudiff = signeddiff(mucoef, j)
        vardiff = signeddiff(varcoef, j)
        model.addConstr(muabsd[j] >= mudiff)
        model.addConstr(muabsd[j] >= -mudiff)
        model.addConstr(varabsd[j] >= vardiff)
        model.addConstr(varabsd[j] >= -vardiff)
    for j in range(npairs):
        covdiff = signeddiff(covcoef, j)
        model.addConstr(covabsd[j] >= covdiff)
        model.addConstr(covabsd[j] >= -covdiff)

    model.optimize()

    assgn = [[0,] + [i for i in range(1, n) if x[i].x > .5], [i for i in range(1, n) if x[i].x <= .5]]
    shuffle(assgn)
    return assgn

def completerandomization(n, m):
    """ Assign subjects completely at random

    Shuffles the indices 0..n-1 uniformly and cuts them into m consecutive
    blocks of n/m indices each.

    Args:
        n: number of subjects
        m: number of groups

    Returns:
        list of m lists of indices

    Raises:
        Exception: if n is not divisible by m.
    """
    if n % m != 0:
        raise Exception("Subjects do not divide into m groups.")
    k = n // m
    # A real list is required: a Python 3 range cannot be shuffled in place.
    dataIndex = list(range(n))
    np.random.shuffle(dataIndex)
    return [dataIndex[j * k:(j + 1) * k] for j in range(m)]

def pairmatch(covardata, m):
    """ Pairwise-match assign univariate covariates to 2 groups

    Subjects are sorted by covariate value and paired consecutively; within
    each pair one member is sent to each group at random.

    Args:
        covardata: covariate data, list/array of floats
        m: number of groups, must be = 2

    Returns:
        list of lists of indices

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    data = array(covardata)
    n = len(data)
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")

    order = sorted(range(n), key=data.__getitem__)
    grouping = [[], []]
    for start in range(0, n, 2):
        pair = order[start:start + 2]
        np.random.shuffle(pair)
        grouping[0].append(pair[0])
        grouping[1].append(pair[1])
    return grouping

import networkx as nx
def pairmatchmulti(covardata, m):
    """ Pairwise-match assign multivariate covariates to 2 groups

    Subjects are paired by a minimum-total-distance perfect matching under
    the Mahalanobis distance (inverse covariance from safeinvcov); within each
    pair one member is sent to each group at random.

    Args:
        covardata: covariate data, n-by-d lists/array of floats (a flat list
            is treated as n-by-1)
        m: number of groups, must be = 2

    Returns:
        list of lists of indices

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    w = array(covardata, dtype=float)
    if w.ndim == 1:
        w = w.reshape((-1, 1))
    n = len(w)
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")
    if n == 0:
        return [[], []]

    D = squareform(pdist(w, 'mahalanobis', VI = safeinvcov(w)))
    # A minimum-distance perfect matching is a maximum-cardinality matching of
    # maximum total weight C - D for any constant C > max(D). Keeping every
    # weight positive means no edge (not even a zero-distance one) is lost.
    C = D.max() + 1.
    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_weighted_edges_from((i, j, C - D[i, j]) for i in range(n) for j in range(i + 1, n))
    match = nx.max_weight_matching(G, maxcardinality=True)

    # networkx 1.x returns a {node: mate} dict, 2.x and later a set of pairs.
    pairs = match.items() if isinstance(match, dict) else match
    result = zeros(n)
    placed = set()
    for u, v in pairs:
        if u in placed:
            continue
        placed.update((u, v))
        side = np.random.randint(2)
        result[u] = side
        result[v] = 1 - side
    return [[i for i in range(n) if result[i] > 0.5], [i for i in range(n) if result[i] <= 0.5]]

def rerand(covardata, m, p=0.05, nrand=500):
    """ Re-randomization assign univariate covariates to 2 groups

    Draws nrand random balanced +/-1 assignments and scores each by the
    squared difference in group means. It then returns one of the best
    round(nrand*p) draws (clipped to [1, nrand]), chosen uniformly at random.

    Args:
        covardata: covariate data, list/array of floats
        m: number of groups, must be = 2
        p: acceptance probability (exact up to rounding nrand*p)
        nrand: number of re-randomizations

    Returns:
        list of lists of indices

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    data = array(covardata)
    n = len(data)
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")
    n2 = n // 2
    signs = array([-1,] * n2 + [1,] * n2)
    draws = []
    for _ in range(nrand):
        np.random.shuffle(signs)
        y = dot(signs, data) / float(n2)
        draws.append((y * y, signs.tolist()))
    # Stable sort on the imbalance only (ties keep draw order).
    draws.sort(key=lambda draw: draw[0])
    # numpy's randint upper bound is exclusive, so exactly the best naccept
    # draws are eligible.
    naccept = int(np.clip(np.rint(nrand * p), 1, nrand))
    chosen = draws[np.random.randint(naccept)][1]
    grouping = [[j for j in range(n) if chosen[j] == 1], [j for j in range(n) if chosen[j] == -1]]
    return grouping

def rerandmulti(covardata, m, p=0.05, nrand=500):
    """ Re-randomization assign multivariate covariates to 2 groups

    Draws nrand random balanced +/-1 assignments and scores each by y' S y.
    Here y is the difference in group-mean covariate vectors and S is
    safeinvcov(covardata). It then returns one of the best round(nrand*p)
    draws (clipped to [1, nrand]), chosen uniformly at random.

    Args:
        covardata: covariate data, n-by-d lists/array of floats (a flat list
            is treated as n-by-1)
        m: number of groups, must be = 2
        p: acceptance probability (exact up to rounding nrand*p)
        nrand: number of re-randomizations

    Returns:
        list of lists of indices

    Raises:
        Exception: if m != 2 or the number of subjects is odd.
    """
    w = array(covardata, dtype=float)
    if w.ndim == 1:
        w = w.reshape((-1, 1))
    n = len(w)
    if m != 2:
        raise Exception("Only 2 groups allowed.")
    if n % 2 != 0:
        raise Exception("Subjects do not divide into 2 groups.")
    n2 = n // 2

    s = safeinvcov(w)
    signs = array([-1,] * n2 + [1,] * n2)
    draws = []
    for _ in range(nrand):
        np.random.shuffle(signs)
        y = dot(signs, w) / float(n2)
        draws.append((dot(dot(y, s), y), signs.tolist()))
    # Stable sort on the Mahalanobis imbalance only (ties keep draw order).
    draws.sort(key=lambda draw: draw[0])
    # numpy's randint upper bound is exclusive, so exactly the best naccept
    # draws are eligible.
    naccept = int(np.clip(np.rint(nrand * p), 1, nrand))
    chosen = draws[np.random.randint(naccept)][1]
    grouping = [[j for j in range(n) if chosen[j] == 1], [j for j in range(n) if chosen[j] == -1]]
    return grouping

def FSM(covardata, m):
    """ Assign covariates using finite selection method

    Groups take turns (0, 1, ..., m-1, 0, ...) picking a subject from the
    remaining pool. With S = pinv(Xg' Xg) for the current group's covariate
    rows Xg, each candidate row x is scored by
        v(x) = x' S^2 x / (1 + x' S x)
    and the candidate with the smallest v is taken (ties: earliest in pool).

    Args:
        covardata: covariate data, either a list/array of floats (univariate) or a n-by-d lists/array of floats (multivariate)
        m: number of groups

    Returns:
        list of m lists of indices; the order of the groups is randomly shuffled.

    Raises:
        Exception: if the number of subjects is not divisible by m.
    """
    data = array(covardata)
    n = len(data)
    if n % m != 0:
        raise Exception("Subjects do not divide into m groups.")
    indexes = [[] for _ in range(m)]
    if n == 0:
        return indexes

    # One row per subject; univariate input becomes an n-by-1 matrix.
    X = data.reshape((n, data.size // n))
    # A real list is required: a Python 3 range has no .remove().
    reservoir = list(range(n))
    k = 0
    while reservoir:
        Xn = X[indexes[k]]
        Sn = linalg.pinv(dot(Xn.T, Xn))
        Sn2 = dot(Sn, Sn)
        # By Sherman-Morrison (for invertible Xg'Xg), v(x) is the drop in
        # trace((Xg'Xg)^-1) when x is added. NOTE(review): the smallest v is
        # chosen here, i.e. the least reduction -- confirm against the paper.
        besti = reservoir[0]
        bestv = inf
        for i in reservoir:
            x = X[i]
            v = dot(x, dot(Sn2, x)) / (1. + dot(x, dot(Sn, x)))
            if v < bestv:
                bestv = v
                besti = i
        reservoir.remove(besti)
        indexes[k].append(besti)
        k = (k + 1) % m

    shuffle(indexes)

    return indexes

def hypt(covardata, initialassignment, outcomes, assigner, boot = True, T = 159):
    """ Test for differences in two treatments using a randomization or bootstrap test

    The statistic is (sum of outcomes in group 0 - sum in group 1) / (n//2),
    i.e. the difference in means for equal-sized groups. Each of the T
    replications resamples (boot=True, with replacement) or permutes
    (boot=False) the subjects and re-runs `assigner` on the reordered
    covariates. The p-value is (1 + #{|replicate| >= |observed|}) / (1 + T).

    Args:
        covardata: covariate data.
        initialassignment: the assignment that was used in the experiment.
        outcomes: the recorded outcomes after treatment.
        assigner: assignment mechanism used as a lambda taking covariate data.
        boot: whether to use the bootstrap test, otherwise the randomization test.
        T: number of randomization or bootstrap runs.

    Returns:
        (
            estimate of the effect,
            p-value
        )

    Example:
        record covariates
        >>> data = 5.*random.randn(30) # fake covariate data
        design subject assignment
        >>> initialassignment = optassign2(data, 2)
        run the experiment and measure outcomes
        >>> outcomes = data+data*data+array([1.0 if i in initialassignment[0] else 0. for i in range(30)])+random.randn(30) # fake outcomes with an effect size of 1
        test for a difference in treatments
        >>> (estimator, pvalue) = hypt(data, initialassignment, outcomes, lambda d: optassign2(d, 2))
    """
    data = array(covardata)
    y = np.asarray(outcomes, dtype=float)
    n = len(data)
    n2 = float(n // 2)

    def effect(values, assignment):
        # Vectorized group sums (avoids numpy.sum on a generator, which is
        # what a bare `sum` resolves to after `from numpy import *`).
        return (values[list(assignment[0])].sum() - values[list(assignment[1])].sum()) / n2

    mystat = effect(y, initialassignment)
    stats = zeros(T)
    for t in range(T):
        if boot:
            indexes = np.random.randint(n, size=n)
        else:
            indexes = np.arange(n)
            np.random.shuffle(indexes)
        assgnt = list(assigner(data[indexes]))
        stats[t] = effect(y[indexes], assgnt)
    return (mystat, float(1 + (abs(stats) >= abs(mystat)).sum()) / float(1 + T))


import scipy.stats
from scipy.integrate import odeint

def gompexdtlambda(a,alpha,xc):
    """ Build the right-hand side f(x, t) of the growth ODE integrated by gompex.

    The returned callable computes
        dx/dt = x * max(a, alpha * log(xc * exp(a/alpha) / x)),
    i.e. the larger of a constant rate a and a log-ratio rate, times x. The
    time argument t is accepted (odeint passes it) but not used.
    """
    def rhs(x, t):
        lograte = alpha * log(xc * exp(a / alpha) / x)
        return x * max(a, lograte)

    return rhs

def gompex(x,a,alpha,xc):
    """ Integrate the gompexdtlambda growth ODE from t = 0 to t = 1.

    Args:
        x: initial state passed to odeint
        a, alpha, xc: growth parameters forwarded to gompexdtlambda

    Returns:
        the first state component of the odeint solution at t = 1
    """
    rhs = gompexdtlambda(a, alpha, xc)
    path = odeint(rhs, x, [0., 1.])
    return path[1][0]

def doExperiment(k, effect):
    """ Run an instance of the experiment in Section 6

    The experiment proceeds as follows:
      1. Draw 2k baseline tumor sizes from a normal distribution (mean 300,
         standard deviation 200) truncated below at 0.
      2. Assign the subjects with each of four designs.
      3. Grow every tumor with gompex.
      4. Add `effect` to the outcomes of each design's first group.
      5. Test each design with hypt: the bootstrap test for the optimization
         design (index 0) and the randomization test for the others.

    Args:
        k: number of subjects per group
        effect: negative of delta_0 as described in Section 6

    Returns:
        List of (estimator, p-value) pairs for opt, comp rand, pair match, and re-rand

    Example:
        >>> doExperiment(15, -50.0)
    """
    tumor0mean = 300.0 # mean
    tumor0std = 200.0 # standard deviation
    # Growth-model parameters forwarded to gompex.
    mya = 1.
    myalpha = 5.
    myxc = 400.

    n = 2 * k

    assigners = [
        lambda d: optassign2(d,2,0.5), # opt
        lambda d: completerandomization(len(d),2), # comp rand
        lambda d: pairmatch(d, 2), # pair match
        lambda d: rerand(d, 2), # re-rand
    ]

    # Standardized lower bound (0 - mean)/std: truncate the normal at zero.
    tumor0 = scipy.stats.truncnorm.rvs((0 - tumor0mean) / tumor0std, np.inf, loc=tumor0mean, scale=tumor0std, size=n)

    assignments = [assigner(tumor0) for assigner in assigners]

    controloutcomes = np.array([gompex(x0, mya, myalpha, myxc) for x0 in tumor0])

    outcomes = []
    for assignment in assignments:
        treated = set(assignment[0])  # O(1) membership instead of a list scan
        shift = np.array([effect if i in treated else 0. for i in range(n)])
        outcomes.append(controloutcomes + shift)

    return [hypt(tumor0, assignments[j], outcomes[j], assigners[j], j == 0) for j in range(len(assigners))]