"""OptimalAPriori.py:
    Code for design algorithms, hypothesis tests, and numerical experiments in
    Optimal A Priori Balance in the Design of Controlled Experiments by
    Nathan Kallus."""

__author__ = "Nathan Kallus"
__credits__ = ["Nathan Kallus"]
__version__ = "1.0"
__maintainer__ = "Nathan Kallus"
__email__ = "kallus@cornell.edu"

try:
    import mosek 
    import mosek.fusion 
    from   mosek.fusion import * 
    mosek_available = True
except ImportError:
    mosek_available = False

try:
    import gurobipy as gp
    gurobi_available = True
except ImportError:
    gurobi_available = False

try:
    import networkx as nx
    networkx_available = True
except ImportError:
    networkx_available = False

import numpy as np
import numpy.linalg
import scipy
import scipy.linalg
from scipy.spatial.distance import pdist, squareform

def CompleteRand(n, B = 1):
    """
    Draw B assignments of subjects from a completely randomized design

    Each draw is a uniformly random permutation of n//2 entries equal to -1
    and n//2 entries equal to +1.

    Args:
        n: number of subjects (expected to be even; if n is odd each returned
           assignment only has n-1 entries)
        B: number of draws
    Returns:
        list of lists of +/-1 denoting assignment
    """
    # Floor division keeps the group size an integer on both Python 2 and 3
    # (true division would make the list repetition below fail on Python 3).
    n2 = n // 2
    template = np.array([-1,]*n2 + [1,]*n2)
    A = []
    for b in range(B):
        # Shuffle a fresh copy so every draw starts from the same ordering.
        zz = template.copy()
        np.random.shuffle(zz)
        A.append(zz.tolist())
    return A

def BlockOrthant(x, B = 1):
    """
    Draw B assignments of subjects from a design blocking on the sign of each
    covariate (i.e., block by orthant)

    Within every orthant an even number of subjects is split evenly at
    random; when an orthant holds an odd number of subjects one of them is
    picked at random and set aside, and the subjects set aside are then split
    evenly at random among themselves.

    Args:
        x: n by d array of covariates (n is expected to be even so that the
           number of set-aside subjects is even as well)
        B: number of draws
    Returns:
        list of lists of +/-1 denoting assignment
    """
    # reduce is a builtin only on Python 2; this import works on 2.6+ and 3.
    from functools import reduce

    n, d = x.shape
    nar = np.arange(n)
    # Encode the sign pattern of each subject as one integer (one bit per
    # covariate). This does not depend on the draw, so compute it once.
    quad = reduce(lambda z, y: 2*z + y, np.signbit(x.T))
    blocks = [nar[quad == j].tolist() for j in set(quad)]

    A = []
    for b in range(B):
        leftovers = []
        assgn = np.zeros(n, dtype=np.int_)
        for block in blocks:
            idx = list(block)  # copy: the odd subject is removed below
            if len(idx) % 2 != 0:
                idxleft = idx[np.random.randint(len(idx))]
                leftovers.append(idxleft)
                idx.remove(idxleft)
            nb2 = len(idx) // 2
            zz = np.array([-1,]*nb2 + [1,]*nb2, dtype=np.int_)
            np.random.shuffle(zz)
            assgn[idx] = zz
        # Balance the subjects that were set aside from odd-sized blocks.
        nb2 = len(leftovers) // 2
        zz = np.array([-1,]*nb2 + [1,]*nb2, dtype=np.int_)
        np.random.shuffle(zz)
        assgn[leftovers] = zz
        A.append(assgn.tolist())
    return A

def PairwiseMatch(x, B = 1):
    """
    Draw B assignments of subjects from the optimal pairwise-matched design
    with respect to the Mahalanobis metric

    The matching is computed once (maximum-cardinality matching of minimum
    total distance); each draw then randomizes, independently for every
    matched pair, which member gets +1 and which gets -1.

    Args:
        x: n by d array of covariates
        B: number of draws
    Returns:
        list of lists of +/-1 denoting assignment (a subject left unmatched,
        e.g. when n is odd, is assigned -1)
    Raises:
        ImportError: if NetworkX is not installed
    """
    if not networkx_available:
        raise ImportError('NetworkX not available.')
    s = safeinvcov(x)
    n = len(x)
    D = squareform(pdist(x, 'mahalanobis', VI = s))

    # Build the complete graph explicitly: constructing the graph straight
    # from the matrix -D skips zero entries, which would forbid pairing two
    # subjects with identical covariates (distance exactly zero).
    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_weighted_edges_from((i, j, -D[i, j])
                              for i in range(n) for j in range(i+1, n))
    m = nx.max_weight_matching(G, True)

    # NetworkX 1.x returns a dict mapping each node to its mate (so every
    # pair shows up twice); NetworkX 2+ returns a set of edges.
    if isinstance(m, dict):
        pairs = []
        seen = set()
        for i in m:
            if i in seen:
                continue
            seen.add(i)
            seen.add(m[i])
            pairs.append((i, m[i]))
    else:
        pairs = list(m)

    A = []
    for b in range(B):
        zz = np.array([0,1])
        result = np.zeros(n)
        for i, j in pairs:
            np.random.shuffle(zz)
            result[i] = zz[0]
            result[j] = zz[1]
        A.append([1 if t > 0.5 else -1 for t in result])
    return A

def safeinvcov(x):
    """
    Safely invert the sample covariance matrix

    Args:
        x: n by d array of covariates
    Returns:
        d by d array: the inverse of the sample covariance of the columns of
        x, or its Moore-Penrose pseudo-inverse when the covariance matrix is
        rank deficient. For d == 1 this is the reciprocal of np.var(x)
        (population variance), or 0 when the variance is exactly zero.
    """
    n, d = x.shape
    if d == 1:
        # np.cov would return a 0-d array here, so handle one covariate
        # separately (the original referenced an undefined name `array`).
        var = np.var(x)
        # A constant covariate has no inverse variance; return 0, which is
        # what the pseudo-inverse branch below yields for a zero matrix.
        return np.array(0. if var == 0 else 1./var).reshape((1,1))
    covar = np.cov(x, rowvar=0)
    # A rank test is more reliable than comparing the determinant with
    # exactly 0., which floating-point round-off almost never produces.
    if np.linalg.matrix_rank(covar) < d:
        return scipy.linalg.pinv(covar)
    return scipy.linalg.inv(covar)

def RerandMR(x, B = 1, p = 0.01):
    """
    Draw B assignments of subjects from the re-randomized design a la Morgan &
    Rubin with (exact) acceptance probability p

    int(B/p) completely randomized assignments are generated and the B with
    the smallest Mahalanobis distance between group covariate means are kept.

    Args:
        x: n by d array of covariates (n is expected to be even)
        B: number of draws
        p: acceptance probability
    Returns:
        list of lists of +/-1 denoting assignment, in increasing order of
        imbalance
    """
    nrand = int(float(B)/p)
    s = safeinvcov(x)
    n = len(x)
    n2 = n // 2  # floor division: stays an integer on Python 3 as well
    zz = np.array([-1,]*n2 + [1,]*n2)
    l = []
    for i in range(nrand):
        np.random.shuffle(zz)
        # Difference of group covariate means under assignment zz.
        y = np.dot(zz, x)/float(n2)
        l.append((np.dot(np.dot(y.T, s), y), zz.tolist()))
    l.sort(key=lambda z: z[0])
    # The sign of each kept assignment is flipped, as in the original code;
    # the balance criterion is quadratic in zz and so unaffected by the sign.
    return [[-1 if t > 0 else 1 for t in ll[1]] for ll in l[:B]]

def QuadMatch(K, B = 1):
    """
    Return the top B solutions in increasing objective value to the following
    quadratic optimization problem (with symmetry eliminated)
    minimize    u^T K u
    subject to  u in {-1, +1}^n
                sum_i u_i = 0
                u_1 = -1

    The problem is solved with Gurobi in 0/1 variables z = (u+1)/2. After
    each solve a cut excluding the solution just found is added and the
    model is re-optimized to obtain the next-best solution.

    Args:
        K: n by n array representing a PSD matrix
        B: number of solutions
    Returns:
        list of lists of +/-1 denoting assignment (fewer than B lists if the
        solver cannot produce another solution)
    Raises:
        ImportError: if Gurobi is not installed
    """
    if not gurobi_available:
        raise ImportError('Gurobi not available.')
    zs = []
    n = len(K)
    k = n // 2  # floor division: stays an integer on Python 3 as well
    K1 = np.dot(np.ones((1,n)),K)
    K11 = np.dot(K1,np.ones((n,1)))
    m = gp.Model()
    m.setParam("OutputFlag", 0)
    m.setParam('Threads',1)
    # z[0] is fixed to 0 (u_1 = -1) to eliminate the sign symmetry.
    z=[0,]+[m.addVar(lb = 0., ub = 1., vtype=gp.GRB.BINARY) for i in range(n-1)]
    m.update()
    # With u = 2z - 1:  u^T K u = 4 z^T K z - 4 (1^T K) z + 1^T K 1.
    m.setObjective(gp.quicksum(float(4.*K[i,j])*z[i]*z[j] for i in range(1,n)
        for j in range(1,n))+gp.quicksum(float(-4.*K1[0,i])*z[i]
        for i in range(1,n))+float(K11[0,0]), gp.GRB.MINIMIZE)
    m.addConstr(gp.quicksum(z[1:])==k)
    m.optimize()
    zs.append((0,)+tuple(1 if zz.x>.5 else 0 for zz in z[1:]))
    # Report progress roughly ten times; at least every iteration when
    # B < 10 (the original computed b % 0 and crashed for 2 <= B <= 9).
    step = max(B // 10, 1)
    for b in range(B-1):
        if b % step == 0:
            print('QuadMatch: retrieving %s th solution' % b)
        try:
            # Cut off the current solution: its k selected variables may not
            # all be selected again.
            m.addConstr(gp.quicksum(zz for zz in z[1:] if zz.x>.5) <= k-1)
            m.optimize()
            zs.append((0,)+tuple(1 if zz.x>.5 else 0 for zz in z[1:]))
        except (gp.GurobiError, AttributeError):
            # No further solution is available (e.g. the model became
            # infeasible, so variable values cannot be read): stop early.
            break
    return [[2*zz-1 for zz in z] for z in zs]

def PSOD(K):
    """
    Compute the single best solution of the balance problem
    minimize    u^T K u
    subject to  u in {-1, +1}^n
                sum_i u_i = 0
                u_1 = -1
    (the last constraint removes the sign symmetry).

    Args:
        K: n by n array representing a PSD matrix
    Returns:
        list of +/-1 denoting assignment; its first entry is always -1, so
        the overall sign should be randomized before use (see PSODDraw)
    """
    top_solutions = QuadMatch(K, 1)
    return top_solutions[0]

def PSODBoot(K, B = 1):
    """
    Draw B bootstrap resamples of indices and for each one, return the indices
    and the (single) optimal solution to the quadratic optimization problem
    minimize    u^T K[bootstrapped indices] u
    subject to  u in {-1, +1}^n
                sum_i u_i = 0
                u_1 = -1

    Args:
        K: n by n array representing a PSD matrix
        B: number of bootstrap draws
    Returns:
        list of pairs of indices and lists of +/-1 denoting assignment
    """
    n = len(K)
    # Report progress roughly ten times; at least every iteration when
    # B < 10 (the original computed b % 0 and crashed in that case).
    step = max(B // 10, 1)
    A = []
    for b in range(B):
        if b % step == 0:
            print('PSODBoot: running %s th resample' % b)
        # Resample subjects with replacement and restrict K to them.
        indexes = np.random.randint(n, size=n)
        A.append( (indexes.tolist(), PSOD(K[indexes][:,indexes])) )
    return A

def SDPHeuristic(K, us, mixbound = 0.05):
    """
    Solve the semi-definite optimization problem in Algorithm 4.2

    Finds mixing weights t over the candidate assignments in us that
    minimize z subject to
        z*I - sum_i t_i * K^(1/2) u_i u_i^T K^(1/2)  being PSD,
        sum_i t_i = 1,
        0 <= t_i <= mixbound   (only t_i >= 0 when mixbound is None).

    Args:
        K:        n by n array representing a PSD matrix
        us:       list of lists of +/-1 denoting candidate assignments
        mixbound: upper bound on each weight, or None for no upper bound.
                  The weights must sum to one, so the problem can only be
                  feasible when len(us)*mixbound >= 1.
    Returns:
        tuple of (optimal value z, weights t, Mosek primal solution status).
        Any solution status is accepted, so check the returned status.
    Raises:
        ImportError: if Mosek is not installed
    """
    if not mosek_available:
        raise ImportError('Mosek not available.')
    # reduce is a builtin only on Python 2; this import works on 2.6+ and 3.
    from functools import reduce

    n = len(K)
    K = np.asarray(K, dtype=float)
    # K is symmetric PSD, so use the symmetric eigensolver: it returns real
    # eigenvalues and orthonormal eigenvectors, which is what makes
    # v diag(sqrt(l)) v^T a valid matrix square root.
    (l,v)=np.linalg.eigh((K + K.T)/2.)
    l[l<0]=0  # clip round-off negatives before taking square roots
    Ksqrt=np.dot(np.dot(v,np.diag(np.sqrt(l))),v.T)
    ZZs=[DenseMatrix(np.dot(np.dot(Ksqrt,np.outer(zz,zz)),
            Ksqrt).astype(float).tolist()) for zz in us]
    I=Matrix.diag([1.,]*n)
    with Model("match") as M:
        M.setSolverParam('numThreads', 1)
        t = M.variable('t', len(ZZs), Domain.greaterThan(0.0)
            if mixbound is None else Domain.inRange(0.0, mixbound))
        z = M.variable("z",Domain.greaterThan(0.0))
        sum1cons=M.constraint(Expr.sum(t), Domain.equalsTo(1.0))
        opnormcons=M.constraint("z>=opnorm", Expr.sub(Expr.mul(z,I),
            reduce(Expr.add, (Expr.mul(t.index(i), ZZs[i])
            for i in range(len(ZZs)))) ), Domain.inPSDCone(n))
        M.objective(ObjectiveSense.Minimize, z)
        M.acceptedSolutionStatus(AccSolutionStatus.Anything)
        M.solve()
        return (z.level()[0], t.level(), M.getPrimalSolutionStatus())

def MSODHeuristic(K, B):
    """
    Compute the MSOD as per heuristic Algorithm 4.3

    The top B solutions of the balance problem are computed with QuadMatch
    and then mixed with the weights returned by SDPHeuristic.

    Args:
        K: n by n array representing a PSD matrix
        B: number of top solutions to use
    Returns:
        weights for each assignment vector,
        list of lists of +/-1 denoting assignment
        (first of each assignment is always -1 so always randomize the sign;
         see MSODDraw below)
    """
    print('MSODHeuristic: getting top %s solutions' % (B,))
    us = QuadMatch(K,B)
    print('MSODHeuristic: solving SDP')
    res = SDPHeuristic(K, us)
    # Pass through anything that is not the expected (value, weights, status)
    # triple unchanged.
    if not isinstance(res, tuple):
        return res
    z2,t,stat2 = res
    return (t, us)

def LinearKernel(x, normalize=True):
    """
    Gram matrix of the linear kernel

    Args:
        x:          n by d array of covariates
        normalize:  if True, center the data and weight the inner product by
                    the inverse sample covariance (see safeinvcov); if False,
                    use plain inner products of the rows of x
    Returns:
        n by n Gram matrix
    """
    if not normalize:
        return np.dot(x, x.T)
    precision = safeinvcov(x)
    centered = x - x.mean(0)
    weighted = np.dot(centered, precision)
    return np.dot(weighted, centered.T)

def GaussianKernel(x, s=1., normalize=True):
    """
    Gram matrix of the Gaussian kernel, exp(-dist^2 / s^2)

    Args:
        x:          n by d array of covariates
        s:          bandwidth
        normalize:  if True, dist is the Mahalanobis distance computed by
                    scipy's pdist with its default inverse covariance; if
                    False, dist is the Euclidean distance
    Returns:
        n by n Gram matrix
    """
    if normalize:
        sq_dists = pdist(x, 'mahalanobis')**2
    else:
        sq_dists = pdist(x, 'sqeuclidean')
    pairwise_sq = squareform(sq_dists)
    return np.exp(-pairwise_sq / s**2)

def PolynomialKernel(x, deg=2, normalize=True):
    """
    Gram matrix of the polynomial kernel, (<a, b>/deg + 1)^deg

    Args:
        x:          n by d array of covariates
        deg:        degree
        normalize:  if True, center the data and weight the inner product by
                    the inverse sample covariance (see safeinvcov); if False,
                    use plain inner products of the rows of x
    Returns:
        n by n Gram matrix
    """
    if normalize:
        precision = safeinvcov(x)
        centered = x - x.mean(0)
        inner = np.dot(np.dot(centered, precision), centered.T)
    else:
        inner = np.dot(x, x.T)
    return (inner/float(deg) + 1.)**deg

def ExpKernel(x, normalize=True):
    """
    Gram matrix of the exponential kernel, exp(<a, b>)

    Args:
        x:          n by d array of covariates
        normalize:  if True, center the data and weight the inner product by
                    the inverse sample covariance (see safeinvcov); if False,
                    use plain inner products of the rows of x
    Returns:
        n by n Gram matrix
    """
    if normalize:
        precision = safeinvcov(x)
        centered = x - x.mean(0)
        inner = np.dot(np.dot(centered, precision), centered.T)
    else:
        inner = np.dot(x, x.T)
    return np.exp(inner)

def MSODDraw(msod):
    """
    Draw a random assignment from the MSOD

    One candidate assignment is selected with probability proportional to its
    weight and its overall sign is then flipped with probability 1/2.

    Args:
        msod: the output of MSODHeuristic, i.e. a pair (weights, list of
              lists of +/-1 denoting assignment)
    Returns:
        list of +/-1 denoting assignment
    """
    weights, assignments = msod[0], msod[1]
    cs = np.cumsum(weights)/np.sum(weights)
    # Inverse-CDF sampling. Round-off can leave the last normalized
    # cumulative weight slightly below 1, in which case the count could equal
    # len(assignments); clamp it to the last valid index.
    idx = min(int(np.sum(cs < np.random.rand())), len(assignments) - 1)
    sign = (-1)**np.random.randint(2)
    return (np.array(assignments[idx])*sign).tolist()

def PSODDraw(psod):
    """
    Draw a random assignment from the PSOD by flipping the overall sign of
    the given assignment with probability 1/2

    Args:
        psod: list of +/-1 denoting assignment (e.g. the output of PSOD)
    Returns:
        list of +/-1 denoting assignment
    """
    flip = (-1)**np.random.randint(2)
    signed = np.array(psod)*flip
    return signed.tolist()

def Test_Boot(psodboot, y, tauhat):
    """
    PSOD bootstrap test (Algorithm 5.1)

    For every bootstrap resample the simple-differences statistic is
    recomputed on the resampled outcomes under that resample's PSOD; the
    p-value is (2*#{|stat| >= |tauhat|} + 1) / (2*#resamples + 1).

    Args:
        psodboot: list of pairs of indices and lists of +/-1 denoting
                  assignment (output of PSODBoot): bootstrap resamples of
                  indices and the corresponding PSOD on the resampled data
                  (with symmetry eliminated)
        y:        length n array of recorded outcomes
        tauhat:   the measured simple differences estimator
    Returns:
        p value
    Example:
        >>> x = np.random.randn(120).reshape((30,4))
        >>> K = PolynomialKernel(x)
        >>> psod = PSOD(K)
        >>> u = PSODDraw(psod)
        >>> # assign according to u, apply treatments, record y
        >>> # or, simulate fake outcomes like below for the sake of example
        >>> tau = 1.
        >>> y = np.linalg.norm(x, axis=1)**2 + 0.5*(np.array(u)+1.)*tau
        >>> # compute our estimator (should be based on actual y outcomes)
        >>> tauhat = np.dot(u, y)/float(len(y)/2)
        >>> psodboot = PSODBoot(K, 100)
        >>> pval = Test_Boot(psodboot, y, tauhat)
    """
    group_size = float(len(y)/2)
    boot_stats = np.array([np.dot(assignment, y[indices])/group_size
                           for indices, assignment in psodboot])
    n_extreme = np.sum(np.abs(boot_stats) >= np.abs(tauhat))
    return (2*n_extreme + 1.)/float(2*len(boot_stats) + 1)

def Test_WeightRand(msod, y, tauhat):
    """
    Weighted version of Fisher's randomization test (Algorithm 8.1)

    The p-value is the total weight of the candidate assignments whose
    simple-differences statistic is at least as large in absolute value as
    tauhat, divided by the total weight.

    Args:
        msod:   a pair consisting of weights for each assignment vector and
                list of lists of +/-1 denoting assignment (output of
                MSODHeuristic), where one of these assignments was the one
                actually used (at the prescribed probability)
        y:      length n array of recorded outcomes
        tauhat: the measured simple differences estimator
    Returns:
        p value
    Example:
        >>> x = np.random.randn(120).reshape((30,4))
        >>> K = ExpKernel(x)
        >>> msod = MSODHeuristic(K, 100)
        >>> u = MSODDraw(msod)
        >>> # assign according to u, apply treatments, record y
        >>> # or, simulate fake outcomes like below for the sake of example
        >>> tau = 1.
        >>> y = np.linalg.norm(x, axis=1)**2 + 0.5*(np.array(u)+1.)*tau
        >>> # compute our estimator (should be based on actual y outcomes)
        >>> tauhat = np.dot(u, y)/float(len(y)/2)
        >>> pval = Test_WeightRand(msod, y, tauhat)
    """
    weights = msod[0]
    candidates = msod[1]
    stats = np.dot(candidates, y)/float(len(y)/2)
    at_least_as_extreme = np.abs(stats) >= np.abs(tauhat)
    return np.dot(weights, at_least_as_extreme)/np.sum(weights)

def Test_UnifRand(us, y, tauhat):
    """
    Fisher's randomization test (Algorithm 8.2)

    The p-value is the fraction of the given assignments whose
    simple-differences statistic is at least as large in absolute value as
    tauhat.

    Args:
        us:     list of lists of +/-1 denoting assignment (output of
                CompleteRand, BlockOrthant, PairwiseMatch, or RerandMR),
                where one of these assignments was the one actually used
        y:      length n array of recorded outcomes
        tauhat: the measured simple differences estimator
    Returns:
        p value
    Example:
        >>> x = np.random.randn(120).reshape((30,4))
        >>> us = CompleteRand(len(x), 200)
        >>> u = us[0]
        >>> # assign according to u, apply treatments, record y
        >>> # or, simulate fake outcomes like below for the sake of example
        >>> tau = 1.
        >>> y = np.linalg.norm(x, axis=1)**2 + 0.5*(np.array(u)+1.)*tau
        >>> # compute our estimator (should be based on actual y outcomes)
        >>> tauhat = np.dot(u, y)/float(len(y)/2)
        >>> pval = Test_UnifRand(us, y, tauhat)
    """
    group_size = float(len(y)/2)
    stats = np.dot(us, y)/group_size
    at_least_as_extreme = np.abs(stats) >= np.abs(tauhat)
    return np.mean(at_least_as_extreme)

# Coefficients of the default control-outcome function f0 of
# RunOneSimulationExperiment, which only looks at the first two covariates:
#   f0(x) = C . x[:2]  +  x[:2] . (x[:2] . Z)
# i.e. C holds the linear coefficients and Z the matrix of the quadratic form.
Z = np.array([[0.5,-0.5],[0.5,-0.5]])
C = np.array([1.,-1.])
def RunOneSimulationExperiment(n, d,
    f0 = lambda xx: np.dot(C,xx[:2])+np.dot(xx[:2],np.dot(xx[:2],Z)),
    taus = [0.,0.15], alpha = 0.05, sigma = 0, seed = None, boot = 100):
    """
    Run one replicate of the experiment in Examples 2.2 and 5.1

    Covariates are drawn uniformly on [-1, 1]^d, control outcomes are
    f0(x) plus Gaussian noise, and a constant effect tau is added for treated
    subjects. Conditional on these draws, the variance of the
    simple-differences estimator and the probability of rejecting the null
    of no effect are computed for each design.

    Args:
        n:     number of subjects (must be even)
        d:     dimension of covariates (the default f0 reads the first two
               covariates, so it needs d >= 2)
        f0:    the conditional expectation function of control outcomes
        taus:  the different effect sizes to test against
        alpha: the desired testing significance
        sigma: the standard deviation of the residuals
        seed:  random seed to set (if not None)
        boot:  the number of bootstrap resamples passed to PSODBoot and the
               number of top solutions passed to MSODHeuristic
    Returns:
        a dictionary of estimation variance under each of the designs, and
        a dictionary of the frequency of rejection the null hypothesis
        for each effect size and under each of the designs
    Example:
        >>> def run(seed): return RunOneSimulationExperiment(30, 2, seed=seed)
        >>> from multiprocessing import Pool
        >>> pool = Pool(16)
        >>> results = pool.map(run, range(100), 1)
        >>> pool.terminate()
        >>> import itertools
        >>> variances = {k : np.mean([x[1] for x in g]) for k,g in 
                itertools.groupby(sorted(itm for run in results for itm in 
                run[0].items()), lambda x: x[0])}
        >>> rejprobs = {k : np.mean([x[1] for x in g], axis=0) for k,g in 
                itertools.groupby(sorted(itm for run in results for itm in 
                run[1].items()), lambda x: x[0])}
    """
    
    if seed is not None: np.random.seed(seed)
    
    # Group size; only ever used as a divisor below, so keep it as a float to
    # get true division on both Python 2 and Python 3.
    n2 = float(n // 2)
    x  = np.random.rand(n*d).reshape((n,d))*2-1
    
    # Designs whose randomization distribution is approximated by 500 draws.
    classic = {
        'comprand':  CompleteRand(n,500),
        'blocking':  BlockOrthant(x,500),
        'pairmatch': PairwiseMatch(x,500),
        'rerandom':  RerandMR(x,500,.01)
    }
    
    Ks   = {
        'lin':  LinearKernel(x, normalize=False),
        'quad': PolynomialKernel(x, deg=2, normalize=False),
        'gaus': GaussianKernel(x, s=1.),
        'exp':  ExpKernel(x, normalize=False)
    }
    MSODKs = [ 'gaus', 'exp' ]
    
    psods     = {k: PSOD(Ks[k]) for k in Ks}
    psodboots = {k: PSODBoot(Ks[k], boot) for k in Ks}
    msods     = {k: MSODHeuristic(Ks[k], boot) for k in MSODKs}
    
    # A list comprehension instead of map(): map() is lazy on Python 3 and
    # np.array() would wrap the iterator rather than its values.
    y0 = np.array([f0(xx) for xx in x]) + sigma*np.random.randn(n)
    
    condvar     = {}
    condprobrej = {}
    
    for k in classic:
        us = np.array(classic[k])
        stats = np.dot(us, y0)/n2
        condvar[k] = (stats**2).mean()
        # Add the sign-flipped assignments as well.
        us = np.vstack((us,-us))
        stats = np.hstack((stats,-stats))
        # Entry [i, j]: |statistic| of assignment i evaluated on the outcomes
        # that would be observed had assignment j been applied with effect
        # tau; the diagonal holds the observed statistics.
        statsefabss = [np.abs(np.dot(us, (((us+1)/2)*tau).T)/n2 + stats)
                        for tau in taus]
        condprobrej[k] = [((np.diag(statsefabs).reshape(len(us),1) <=
                        statsefabs).mean(1) <= alpha).mean()
                        for statsefabs in statsefabss]
    
    for k in Ks:
        u         = np.array(psods[k])
        # Named differently from the `boot` parameter to avoid shadowing it.
        resamples = psodboots[k]
        stat      = np.dot(u, y0)/n2
        condvar[k+'_PSOD'] = (stat**2)

        bootindexes = np.array([zz[0] for zz in resamples])
        bootassgnts = np.array([zz[1] for zz in resamples])
        # Duplicate every resample with the opposite sign.
        bootindexes = np.vstack((bootindexes,  bootindexes))
        bootassgnts = np.vstack((bootassgnts, -bootassgnts))

        condprobrej[k+'_PSOD'] = []

        for tau in taus:
            probrej = 0.

            # The PSOD is applied with a random sign; each sign has
            # probability 1/2.
            for sg in [-1,1]:
                y1          = y0 + ((sg*u+1)/2)*tau
                mystatabs   = np.abs(np.dot(sg*u,y1)/n2)
                bootstatabs = np.abs((bootassgnts * y1[bootindexes]).sum(1)/n2)
                pval        = (1+(bootstatabs >= mystatabs).sum()
                               ) / float(1+len(bootindexes))
                if pval <= alpha:
                    probrej += 0.5

            condprobrej[k+'_PSOD'].append(probrej)
    
    for k in MSODKs:
        t = np.array(msods[k][0])
        z = np.array(msods[k][1])
        
        stats              = np.dot(z, y0)/n2
        condvar[k+'_MSOD']  = np.dot(t, stats**2)/t.sum()
        
        # Add the sign-flipped assignments with the same weights.
        t     = np.hstack((t, t))
        z     = np.vstack((z, -z))
        stats = np.hstack((stats, -stats))
    
        statsefabss = [np.abs(np.dot(z, (((z+1)/2)*tau).T)/n2 + stats)
                        for tau in taus]
    
        # Weighted analogue of the computation for the classic designs.
        condprobrej[k+'_MSOD'] = [np.dot(t, (np.dot((np.diag(statsefabs
                            ).reshape(len(z),1) <= statsefabs), t)/t.sum()
                            <= alpha))/t.sum() for statsefabs in statsefabss]
    
    return (condvar, condprobrej)