function [obj, Y, output] = QAP_ADMM_diag(A,B,options)
% This mfile solves the symmetry-reduced SDP relaxation of the QAP with cost
% matrices A and B, where A is defined as
%               A := \sum_{i=0}^{d} a_{i}A_{i} for some given a in R^{d+1},
% where {A_{0},..,A_{d}} is the basis of the Bose-Mesner algebra of the
% binary Hamming scheme H(d,2).
%
% INPUTS:
% A       := matrix in S^{n} in the Hamming scheme, n = 2^d
% B       := matrix in S^{n}
% options := struct with fields
%            tol      - stopping tolerance; ADMM stops once both the primal
%                       residual (pres) and dual residual (dres) are <= tol
%            max_stag - ADMM also stops after this many consecutive progress
%                       reports (one every 100 iterations) in which the gap
%                       changed by less than 1e-4
%
% OUTPUTS:
% obj    := primal objective value at the final iterate (rescaled back to B)
% Y      := the solution to the symmetrized QAP relaxation, as the stacked
%           vector y = (vec(Y_0);...;vec(Y_d))
% output := struct with more details (best/last lower bound, residuals,
%           iteration count, time, final Z and R blocks, ADMM settings)
%
% This code is an implementation of the method described in
% Facial Reduction for Symmetry Reduced Semidefinite Programs
% by Hao Hu, Renata Sotirov and Henry Wolkowicz
%
% last update 2020-9-12

% get the coefficients of A = sum_i a_{i}A_{i}
a = get_coeff(A);

% rescale B; obj and the lower bounds are multiplied back by `scalar`
scalar = norm(B,'fro')*1.5;
if abs(scalar)<1e-3
    scalar = 1;
end
B = B/scalar;
n = length(B);
d = length(a)-1;
q = 2;

% get the character table of the Hamming scheme H(d,q)
P = hamming_eigmat(q,d);
blkmu_map = get_blkmu_map(q,d);
w = get_J_map(q,d)'; % the weight for projection to a weighted simplex subproblem
blkmu_map_mat = cell2mat(blkmu_map);

% the multiplicity of eigenvalues
mu = zeros(d+1,1);
for i = 0:d
    mu(i+1) = nchoosek(d,i)*(q-1)^i;
end

% get C in the objective function
C = cell(d+1,1);
for i = 1:d+1
    C{i} = sqrt(mu(i))*(P(i,:)*a)*B(:);
end
obj_map = cell2mat(C)'*blkmu_map_mat; % reuse the stacked block map

% the gangster constraints ==> y(~g) = 0
idx1 = ones(n) - eye(n);
idx2 = eye(n);
g = ~logical([idx1(:); kron(ones(d,1),idx2(:))]);

% obtain the exposing vectors and its orthogonal complement
Vhat = cell(d+1,1);
Vhat{1} = mu(1)^(1/4)*ones(n,1);
Vhat{1} = orth_col(Vhat{1});
for i = 2:d+1
    Vhat{i} = mu(i)^(1/4)*[eye(n-1); -ones(1,n-1)];
    Vhat{i} = orth_col(Vhat{i}); % make the columns orthogonal
end
Vdim = [1; ones(d,1)*(n-1)];

% initialize the variables
y = zeros((d+1)*n^2,1);
R = cell(d+1,1);
Z = cell(d+1,1);
VRV =  cell(d+1,1);
for i = 1:d+1
    Z{i} = zeros(n^2,1);
    R{i} = zeros(Vdim(i));
    VRV{i} = Vhat{i}*R{i}*Vhat{i}';
end
Zdiff = Inf;
ydiff = Inf;
pres = Inf;
dres = Inf;
stag = 0;
max_stag = options.max_stag;
lb_best = -99999;
lb = 0;
gap = Inf;

% parameters for the admm
beta = q^d;
gamma = (1+sqrt(5))/2;
tol = options.tol;

% running the admm
admm_t = tic;
fprintf('\nADMM **with** facial reduction ====== \n')
fprintf('%-8s','iter#');
fprintf('%-15s','obj','lb','lb_best','gap','pres','dres','cpu-per-iter','time');

fprintf('\n')
iter = 0;
iter_t = tic;
% note: dres is only refreshed every 100 iterations (in proj_lb below)
while (( (pres > tol) || (dres > tol) ) && (stag < max_stag))
    iter = iter + 1;

    %%%%%%% Y-subproblem
    y_old = y;
    yhat = cell(d+1,1);
    for i = 1:d+1
        yhat{i} = reshape((VRV{i}'),n^2,1) - (C{i}+Z{i})/beta;
    end
    yhat = cell2mat(yhat);

    % explicit solution: weighted-simplex projection (proj_sim, defined
    % elsewhere) on the entries not fixed to zero by the gangster constraints
    Ay = blkmu_map_mat'*yhat;
    y = zeros((d+1)*n^2,1);
    y(g)  = proj_sim(Ay(g),n^2,w(g));
    ydiff = norm(y - y_old,'fro'); % y-steps

    %%%%%% R-subproblem
    for i = 1:d+1
        temp = zeros(n);
        temp(:) = blkmu_map{i}*y + Z{i}/(beta);
        temp = (temp + temp')/2;
        temp = Vhat{i}'*temp*Vhat{i};
        R{i} = proj_psd(temp);
    end

    %%%%%% dual update
    Z_old = Z;
    VRV =  cell(d+1,1);
    pres = 0;
    for i = 1:d+1
        VRV{i} =  Vhat{i}*R{i}*Vhat{i}';
        temp = (blkmu_map{i}*y - VRV{i}(:));
        Z{i} = Z{i} + beta*gamma*temp;
        pres = pres + norm(temp,'fro')^2;
    end
    pres = sqrt(pres);

    % compute Z-step
    Zdiff = 0;
    for i = 1:d+1
        Zdiff = Zdiff + norm(Z{i} - Z_old{i},'fro')^2;
    end
    Zdiff = sqrt(Zdiff);

    % print the results
    if mod(iter,100) == 0
        % compute the objective value
        obj = obj_map*y*scalar;

        % compute lb and dual residual
        [lb,dres] = proj_lb(q,d,Z,Vhat,C,w,n^2,scalar,blkmu_map);

        % save the best bound first, so that the gap printed below is
        % computed from the same lb_best that is printed next to it
        lb_best = max(lb_best, lb);

        % compute the relative gap between primal and dual obj
        old_gap = gap;
        gap = abs((obj-lb_best)/(1+obj+lb_best));

        % detect stagnation
        if abs(gap-old_gap) < 1e-4
            stag = stag + 1;
        else
            stag = 0;
        end

        % print the results
        print_admm_iter(iter, [obj,lb,lb_best,gap, ...
            pres,dres,toc(iter_t)/100,toc(admm_t)]);

        % restart the timer
        iter_t = tic;
    end

end

% print the last iteration; obj, lb, dres and gap are refreshed here since
% inside the loop they are only computed every 100 iterations and would
% otherwise be stale values from an earlier iterate
if mod(iter,100)~=0
    obj = obj_map*y*scalar;
    [lb,dres] = proj_lb(q,d,Z,Vhat,C,w,n^2,scalar,blkmu_map);
    lb_best = max(lb_best, lb);
    gap = abs((obj-lb_best)/(1+obj+lb_best));
    print_admm_iter(iter, [obj,lb,lb_best,gap, ...
        pres,dres,toc(iter_t)/mod(iter,100),toc(admm_t)]);
end

% save the solution
Y = y;
obj = obj_map*y*scalar;
output.time = toc(admm_t);
output.obj = obj;
output.lb = lb_best; % the best lower bound
output.lb_last = lb; % lower bound from the last iteration
output.iter = iter;
output.pres = pres;
output.dres = dres; % previously overwrote output.pres by mistake
output.ydiff = ydiff;
output.Zdiff = Zdiff;
output.stag = (stag >= max_stag); % 1 if exit due to stagnation
output.Z = Z;
output.R = R;

% save the ADMM setting
output.beta = beta;
output.scalar = scalar;

fprintf('\n')
fprintf('objective value      : %.5f \n', output.obj)
fprintf('lower bound(best)    : %.5f \n', output.lb)
fprintf('lower bound(last)    : %.5f \n', output.lb_last)
fprintf('pres                 : %d \n', pres)
fprintf('dres                 : %d \n', dres)
fprintf('gap                  : %d \n', gap)
fprintf('Total# iterations    : %d \n', iter)
fprintf('ADMM time(s)         : %.3f (s) \n', output.time)
fprintf('\n')

end   % end of main function

function print_admm_iter(iter, vals)
% Print one row of the ADMM progress table: the iteration number followed by
% vals = [obj, lb, lb_best, gap, pres, dres, cpu-per-iter, time].
fprintf('%-8i',iter);
fprintf('%-15.4e',vals);
fprintf('\n')
end

function [lb,dres] = proj_lb(q,d,Z,V,C,w,b,scalar,blkmu_map)
% Dual lower bound and dual residual obtained from the ADMM dual variable.
%
% INPUTS:  q,d       - Hamming scheme parameters (matrix size n = q^d)
%          Z         - cell of dual blocks from ADMM, each of length n^2
%          V         - cell of range matrices of the feasible solutions
%          C         - cell of objective blocks (vectorized here)
%          w         - weights of the weighted-simplex constraint, one per
%                      entry of the stacked variable y
%          b         - right-hand side of that constraint
%          scalar    - scaling factor of B; the bound is multiplied back by it
%          blkmu_map - cell of block maps (see get_blkmu_map)
% OUTPUTS: lb   - scalar * min_j c_j*b/w_j over the entries of y that are
%                 not fixed to zero by the gangster constraints
%          dres - Frobenius distance between Z and its blockwise projection
%                 onto {Z | V'ZV <= 0}

dim = q^d;
nblk = size(V,1);

% project every dual block and accumulate the squared residual
Zproj = cell(nblk,1);
dres_sq = 0;
for k = 1:nblk
    Zproj{k} = reshape(proj_Z(V{k}, reshape(Z{k},dim,dim)), dim^2, 1);
    dres_sq = dres_sq + norm(Zproj{k} - Z{k},'fro')^2;
end
dres = sqrt(dres_sq);

% vectorize the objective blocks
Cvec = cellfun(@(blk) blk(:), C, 'UniformOutput', false);

% gangster pattern: off-diagonal entries of Y_0 and diagonal entries of
% Y_1..Y_d are forced to zero; only the remaining entries enter the bound
diagmask = logical(eye(dim));
keep = ~[~diagmask(:); repmat(diagmask(:), d, 1)];

% the dual problem has an explicit solution
c = ((cell2mat(Cvec) + cell2mat(Zproj))'*cell2mat(blkmu_map))';
lb = min(c(keep).*(b./w(keep)))*scalar;

end

function Z_p = proj_Z(V,Z)
% Project (the symmetric part of) Z onto the cone {Z | V'*Z*V <= 0},
% i.e. Z is negative semidefinite on range(V).
%
% INPUTS:  V - n-by-r matrix of full column rank
%          Z - n-by-n matrix
% OUTPUT:  Z_p - symmetric n-by-n projection
%
% Method: build an orthogonal W = [U U_perp] with range(U) = range(V). In
% that basis only the leading r-by-r block is constrained, so it is replaced
% by its nearest negative semidefinite matrix and the other blocks are kept.

r = size(V,2);
W = [V null(V')];

% Gram-Schmidt over ALL columns so that W'*W = I even when the columns of V
% are not orthonormal (starting at r+1 left columns 2..r untouched)
W(:,1)  = W(:,1)/norm(W(:,1));
for i = 2:size(W,2)
    alpha = W(:,1:i-1)'*W(:,i);
    W(:,i) = W(:,i) - W(:,1:i-1)*alpha;
    W(:,i) = W(:,i)/norm(W(:,i));
end

W_p = W'*Z*W;
W_p = (W_p+W_p')/2;

% nearest NSD matrix to the leading block: keep only its negative eigenvalues
[v,D] = eig(-W_p(1:r,1:r));
IND =  find(diag(D) > 0);
W11 = v(:,IND)*D(IND,IND)*v(:,IND)';
W11 = (W11+W11')/2;

% W_p is the projected Z expressed in the basis W
W_p(1:r,1:r) = -W11;

% back to the original coordinates and symmetrize against round-off
Z_p = W*W_p*W';
Z_p = (Z_p + Z_p')/2;

end

%%%%%%%%%%%%%% start of helper subfunctions (block maps, weights, coefficients)
function blk_map = get_blk_map(q,d)
% Cell array of sparse operators acting on the stacked vector
% y = (vec(Y_0);...;vec(Y_d)). For k = 1..d+1,
%     blk_map{k}*y = \sum_{i=1}^{d+1} P(k,i)*vec(Y_{i-1}),
% where P is the character table of the Hamming scheme H(d,q), i.e.
% blk_map{k} produces the k-th distinct block.

P = hamming_eigmat(q,d);

len = (q^d)^2;          % length of each vec(Y_i)
Iblk = speye(len);      % identity acting on a single vec(Y_i)

blk_map = cell(d+1,1);
for k = 1:d+1
    blk_map{k} = kron(P(k,:), Iblk);
end

end

function blkmu_map = get_blkmu_map(q,d)
% Same operators as get_blk_map, but block k (k = 0..d) is scaled by
% sqrt(mu(k)), where mu(k) = nchoosek(d,k)*(q-1)^k is the multiplicity of the
% k-th eigenvalue of the Hamming scheme H(d,q). So blkmu_map{k+1}*y equals
%     sqrt(mu(k)) * \sum_{i=0}^{d} P(k+1,i+1)*vec(Y_i).

blkmu_map = get_blk_map(q,d);
for k = 0:d
    multiplicity = nchoosek(d,k)*(q-1)^k;
    blkmu_map{k+1} = sqrt(multiplicity)*blkmu_map{k+1};
end

end

function J_map = get_J_map(q,d)
% Row vector of weights such that, for y = (vec(Y_0);...;vec(Y_d)),
%     J_map*y = \sum_{i=0}^{d} c_i*sum(Y_i(:)) = \sum_i c_i*trace(J*Y_i),
% where c_i = nchoosek(d,i)*(q-1)^i*q^d is the number of ordered pairs of
% vertices at Hamming distance i. Each c_i is repeated (q^d)^2 times.

dist = 0:d;
npairs = arrayfun(@(s) nchoosek(d,s)*(q-1)^s*q^d, dist);

blocklen = (q^d)^2;
J_map = kron(npairs, ones(1,blocklen));

end

function a = get_coeff(A)
% INPUT:  matrix A of size n-by-n, n = 2^d
% OUTPUT: vector a in R^{d+1}
%
% We assume A := \sum_{i=0}^{d} a_{i}A_{i} for some given a in R^{d+1},
% where {A_{0},..,A_{d}} is the basis of the Bose-Mesner algebra of the
% binary Hamming scheme H(d,2) returned by get_hamming.
%
% Raises get_coeff:notPowerOfTwo if n is not a power of 2.

n = length(A);

% log(n)/log(2) can land one ulp off an integer, which is not a valid size
% for zeros()/get_hamming; log2 + round gives an exact integer d
d = round(log2(n));
if 2^d ~= n
    error('get_coeff:notPowerOfTwo', ...
        'A must be n-by-n with n a power of 2 (got n = %d).', n);
end

[~,~,basis] = get_hamming(2,d);
a = zeros(d+1,1);
for i = 1:d+1
    % assuming basis{i} is the 0/1 matrix A_{i-1} and A lies in the algebra,
    % A is constant on its support, so the first nonzero position gives a_i
    I = find(basis{i},1);
    a(i) = A(I);
end

end