####### some useful functions #######
# round a complex number
def rd(x, pr = 6):
    """Round the real and imaginary parts of x to pr decimal places.

    x is first coerced into CC.  The result is assembled as
    round(Re) + i*round(Im), using the session-global imaginary unit `i`
    (reassigning `i` elsewhere in the session would break this).
    """
    z = CC(x)
    re_part = round(real(z), pr)
    im_part = round(imag(z), pr)
    return re_part + i*im_part

# round a matrix with complex entries
def rd_mtx(A, pr = 6):
    """Return A with every entry rounded via rd(entry, pr)."""
    def _round_entry(entry):
        return rd(entry, pr)
    return A.apply_map(_round_entry)

# squared modulus of a complex number
def sqmd(x):
    """Return |x|^2, computed as Re(x)^2 + Im(x)^2."""
    re_part = real(x)
    im_part = imag(x)
    return re_part^2 + im_part^2

# partition a list [lst] to a given precision [eps] and record indices
def partn(lst, eps = 0):
    """Group the indices of lst by value, merging values that lie within eps.

    Values are scanned in order.  Each value is compared (by absolute
    difference) against the existing group keys in insertion order; it joins
    the first key within eps, otherwise it starts a new group keyed by itself.
    Keys are therefore the first value seen in each group, and the grouping
    is greedy (not transitive): a value within eps of a member, but farther
    than eps from that group's key, starts a new group.

    INPUT:
    - lst -- indexable sequence of numbers (e.g. a list or numpy array)
    - eps -- non-negative tolerance; 0 groups only exactly equal values

    OUTPUT: dict mapping each group key to the list of indices in that group,
    in insertion order.  An empty input yields an empty dict.
    """
    groups = {}
    for idx, value in enumerate(lst):
        for key in groups:
            if abs(value - key) <= eps:
                groups[key].append(idx)
                break
        else:
            # no existing key is close enough: value becomes a new key
            groups[value] = [idx]
    return groups

####### linear algebra #######
import scipy
from scipy import linalg

# spectral decomposition of a Hermitian matrix H
def spd_herm(H, evec = False, eps = 1e-6):
    """Spectral decomposition of a Hermitian matrix H.

    H is converted to CC and diagonalized with scipy.linalg.eigh, which
    returns eigenvalues in ascending order with normalized eigenvectors as
    columns.  Eigenvalues within eps of an earlier group key are merged
    (via partn).  For each group, E holds the group's eigenvectors as
    columns, and the projector onto their span is formed as E * E^H.

    INPUT:
    - H -- matrix, assumed Hermitian (not checked here)
    - evec -- if True, also return E for each eigenvalue
    - eps -- tolerance for treating eigenvalues as equal

    OUTPUT: list of (ev, P) pairs, or (ev, P, E) triples when evec is True,
    one per eigenvalue group; ev is the first eigenvalue seen in the group.
    """
    evs, evecs = linalg.eigh(H.change_ring(CC))
    Es = []
    for ev, inds in partn(evs, eps).items():
        emtx = Matrix(evecs[:, inds])
        eproj = emtx*(emtx.H)
        if evec:
            Es.append((ev, eproj, emtx))
        else:
            Es.append((ev, eproj))
    return Es

# spectral decomposition of a normal matrix M
def spd(M, eps = 1e-6):
    """Spectral decomposition of a normal matrix M.

    M is split as M = A + i*B with A = (M + M^H)/2 and B = i(M^H - M)/2,
    both Hermitian.  For normal M they commute, so on each eigenspace of A
    (orthonormal basis V) B is diagonalized through its compression
    C = V^H B V.  Each eigenpair (b, P) of C gives the eigenvalue a + i*b of M
    with projector V P V^H.

    INPUT:
    - M -- square matrix, assumed normal (not checked here)
    - eps -- tolerance passed to spd_herm for grouping numerically equal
      eigenvalues of A and of each compression C; the default matches the
      value previously used implicitly

    OUTPUT: list of (eigenvalue, projector) pairs.
    """
    A = 1/2*(M + M.H)
    B = i/2*(M.H - M)
    B_CC = B.change_ring(CC)  # loop-invariant, convert once
    Es_M = []
    for ev_A, _, V in spd_herm(A.change_ring(CC), True, eps):
        C = (V.H)*B_CC*V
        for ev_C, P in spd_herm(C, eps = eps):
            Es_M.append((ev_A + i*ev_C, V*P*(V.H)))
    return Es_M

####### discrete-time quantum walks #######
# transition matrix of the arc-reversal walk with Grover coins
def trans_disc_arg(X):
    """Transition matrix of the arc-reversal walk with Grover coins on graph X.

    Arcs of the symmetric digraph DiGraph(X) are indexed in the order of the
    vertices of its line digraph Y.  For a = (u, v) and each successor
    b = (v, w) in Y, the entry U[a, b] is 2/deg(v) - 1 if w == u (the
    reversed arc) and 2/deg(v) otherwise; all other entries are 0.

    The coin weight uses the degree of v, the vertex shared by a and its
    successors.  Each row then has unit norm, whereas using deg(u) gives a
    non-unitary matrix on irregular graphs.

    NOTE(review): states built elsewhere index arcs via
    DiGraph(X).edges(labels=False); this assumes that order matches
    Y.vertices() -- confirm for the Sage version in use.
    """
    Y = DiGraph(X).line_graph(labels = False)
    lst_arcs = Y.vertices()
    m = len(lst_arcs)
    arc_index = {arc: ind for ind, arc in enumerate(lst_arcs)}  # O(1) lookup
    U = Matrix(QQ, m, m)
    for ind_row, arc_row in enumerate(lst_arcs):
        deg = X.degree(arc_row[1])  # head of arc_row == tail of every successor
        for arc_col in Y.neighbors_out(arc_row):
            ind_col = arc_index[arc_col]
            if arc_row[0] == arc_col[1]:
                U[ind_row, ind_col] = 2/deg - 1
            else:
                U[ind_row, ind_col] = 2/deg
    return U

### states ###
# uniform linear combination of outgoing arcs of v
def state_og(X, v):
    """Column vector over CC uniformly supported on the arcs leaving v.

    Arcs are indexed in the order of DiGraph(X).edges(labels = False).  Every
    arc (v, w) gets amplitude 1/sqrt(deg(v)); all other entries are 0.
    """
    deg = X.degree(v)
    arcs = DiGraph(X).edges(labels = False)
    state = Matrix(CC, len(arcs), 1)
    for ind, (tail, _) in enumerate(arcs):
        if tail == v:
            state[ind, 0] = 1/sqrt(deg)
    return state

### probabilities ###
# mixing matrix
def prob_disc(sp):
    """Return a function t -> mixing matrix at time t.

    sp is a spectral decomposition [(eigenvalue, projector), ...] (such as
    spd returns).  U^t is assembled as sum(ev^t * projector), and each entry
    is replaced by its squared modulus via sqmd.
    """
    def mixing(t):
        U_t = sum([ev^t * eproj for ev, eproj in sp])
        return Matrix(U_t).apply_map(sqmd)
    return mixing

# the j-th column of the mixing matrix
def prob_disc_col(sp, j):
    """Return a function t -> j-th column of the mixing matrix at time t.

    sp is a spectral decomposition [(eigenvalue, projector), ...].  The j-th
    column of U^t is sum(ev^t * projector[:, j]).  Like prob_disc and
    prob_disc_entry, each entry is replaced by its squared modulus; the
    original returned the raw complex amplitudes instead.
    """
    sp_col = [(ev, eproj[:, j]) for ev, eproj in sp]
    def column(t):
        amps = sum(ev^t * col for ev, col in sp_col)
        return amps.apply_map(sqmd)
    return column

# the (it, j)-entry of the mixing matrix
def prob_disc_entry(sp, it, j):
    """Return a function t -> (it, j)-entry of the mixing matrix at time t.

    sp is a spectral decomposition [(eigenvalue, projector), ...].  The entry
    of U^t is sum(ev^t * projector[it, j]), and its squared modulus (sqmd)
    is returned.
    """
    coeffs = [(ev, eproj[it, j]) for ev, eproj in sp]
    def entry(t):
        amp = sum(ev^t * c for ev, c in coeffs)
        return sqmd(amp)
    return entry

# probability from vertex a to vertex b
def prob_disc_vxs(sp, X, a, b):
    """Return a function t -> probability of going from vertex a to vertex b.

    The walk starts in state_og(X, a), the uniform state on the arcs leaving
    a.  It is evolved to time t through the spectral decomposition sp
    (presumably of the matrix from trans_disc_arg(X) -- confirm with
    callers).  The probability is the sum of squared moduli of the
    amplitudes on the arcs leaving b.

    The evolved state is computed once per t.  The original recomputed the
    whole spectral sum for every arc of b.
    """
    lst_arcs = DiGraph(X).edges(labels = False)
    ini_state = state_og(X, a)
    sp_state = [(ev, eproj * ini_state) for ev, eproj in sp]
    inds = [ind for ind, arc in enumerate(lst_arcs) if arc[0] == b]

    def prob(t):
        if not inds:
            return 0  # b has no outgoing arcs
        probs = sum(ev^t * vec for ev, vec in sp_state).apply_map(sqmd)
        return sum(probs[ind, 0] for ind in inds)
    return prob