function [x, TimeCost, InnVec, Primal, phi_xy, alpha_vec, err, KL_vec, varargout] = ...
    AEM_MD(z, A, AT, beta, varargin)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% AEM (Alternating Extragradient Method) for Total Variation based
% image restoration from Poisson data
% 
% S. BONETTINI, V. RUGGIERO: "An Alternating Extragradient Method for 
% Total Variation based image restoration from Poisson data", Inverse 
% Problems, 2011.
%
% The function solves a deblurring or denoising problem from data corrupted
% by Poisson noise, by minimizing 
%
%  min f(x) == (f0(x) + beta f1(x))
% subject to x >= eta
%
% where f0(x) is the generalized Kullback Leibler (KL) divergence, f1(x) is
% a regularization function (Total Variation or Hypersurface potential) and
% beta is the positive regularization parameter. 
%
%% SYNOPSIS
%
% [x, TimeCost, InnVec, Primal, phi_xy, alpha_vec, err, KL_vec] = ...
%      AEM_MD(z, A, AT, beta, varargin)
%    
%% MANDATORY INPUT
% 
% z          = detected image;
% A          = convolution operator with the PSF, given either as a
%              function handle or as a matrix; a matrix must have columns
%              summing to 1 (checked with tolerance 1e4*eps);
% AT         = transpose of the convolution operator; it must be a
%              function handle when A is a function handle, and it is
%              ignored (rebuilt as A') when A is a matrix;
% beta       = regularization parameter;
%
%% OPTIONAL INPUT (name/value pairs, names are case-insensitive)
%
% BG             = background term
%                  default = 0
% X_INITIAL      = starting point
%                  default = max(eta,z)
% PHI_PAR        = parameter for the regularization function; 
%                  default = 0
% PHI            = penalty functional
%                  default = 'phi_TV'
% ETA            = array with the same size of z containing the lower bound 
%                  of the solution; for deblurring problems, 
%                  eta=zeros(size(z)); for denoising problems, we can 
%                  define eta in the main script by the following statements 
%                  zeroindex = z <= 0;
%                  nonzeroindex = ~zeroindex;
%                  eta = min(z(nonzeroindex))*ones(size(z));
%                  eta(zeroindex) = 0;
%                  default = zeros(size(z))
% MAXIT          = maximum number of iterations; 
%                  default = 1000
% TOL            = tolerance for the stopping criterion
%                  default = 1e-6
% VERBOSE        = flag; if verbose=1, intermediate information is 
%                  displayed; if verbose=0, no intermediate output is
%                  displayed; 
%                  default = 0
% NALPHA         = length of memory for the choice of the tentative value
%                  of the steplength; Nalpha=-1 means that the tentative 
%                  value for the steplength is the last computed value. A
%                  recommended value is 10;   
%                  default = 10
% EPSILON        = parameter used in the backtracking procedure for the 
%                  determination of a steplength assuring the convergence 
%                  conditions; a recommended value is 1e-8;
%                  default = 1e-8
% OBJ            = cell array of nobj images to compute the relative difference
%                  between x and these images; in obj we can put the 
%                  original image and/or the ideal solution of the 
%                  minimization problem.  
%                  default = {} (no reference image)
% ERM            = dimension for the stopping criterion error vector
%                  default = 1
% STOPCRITERIUM  = choice of stopping criterion. The stopping criteria 
%                  are the following:
%                  1: maximum number of iterations
%                  2: relative distance between two successive iterates: 
%                     |(x_k,y_k) - (x_{k-1},y_{k-1})| / |(x_k,y_k)|
%                  3: distance between the primal function value and the 
%                     primal dual function value at the current iterate, related 
%                     to the primal function value; moreover, the mean of 
%                     the last ERM values of this relative difference is
%                     checked
%                  4: Bertero's discrepancy (see Bertero et al, "A
%                     discrepancy principle for Poisson data", Inverse
%                     Problems, 26, 2010)
%                  default = 1
% INALPHA        = initial alpha
%                  default = 1
%
%% OUTPUT
% x         = array of restored image
% TimeCost  = array of the elapsed time in seconds; TimeCost(i+1) is the 
%             time at the end of the i-th iteration from the start of the
%             method; TimeCost(end) is the elapsed time of the method; it 
%             is measured with the matlab tic/toc functions;
% InnVec    = array that contains the number of backtracking steps at any
%             iterate (InnVec(i) refers to iteration i; the array is
%             returned with one extra trailing zero so that it has the
%             same length as the other histories); the sum of the entries
%             of InnVec is the number of total backtracking steps performed;
% Primal    = array of the values of the primal function f(x) at the 
%             current iterate x; Primal(i+1) is f(x) at the iterate i;
% phi_xy    = array of the values of the primal dual function at the 
%             current iterate (x,y1,y2,y3); phi_xy(i+1) is the function at 
%             the iterate i;
% alpha_vec = array of the adaptive steplength parameters   
% err       = cell array of nobj relative errors in euclidean norm computed
%             at each iteration with respect to images in obj{i}; 
%             err{i}(end) is the final relative error; when OBJ is not
%             given, err is the string 'no true object'
% KL_vec    = array of the values of the KL function at the current iterate
%             x; KL_vec(i+1) is KL(x) at the iterate i;
% 
% OPTIONAL
% varargout{1} = Y; in the columns of this matrix the dual variables are
%                stored; in particular, for 2D images Y has 2 columns of 
%                length prod(objsize), for 3D images Y has 3 columns, of 
%                length prod(objsize);
%                for example, if you want the variables in the original size, 
%                you may use Y1 = reshape(Y(:,1),objsize);
%                            Y2 = reshape(Y(:,2),objsize); (for 2D images)
%                            Y3 = reshape(Y(:,3),objsize); (for 3D images)
% varargout{2} = y3; the dual variable corresponding to the case when
%                PHI_PAR is greater than zero (see the paper cited above).
%                Again, if you want it in the original size, you may use
%                y3 = reshape(y3,objsize);
%
%
%
% This software is developed within the research project
%
%        PRISMA - Optimization methods and software for inverse problems
%                           http://www.unife.it/prisma
%
% funded by the Italian Ministry for University and Research (MIUR), under the
% PRIN2008 initiative, grant n. 2008T5KA4L, 2010-2012.
%
% Version: 2.0
% Date:    May 2012
%
% -------------------------------------------------------------------------
% % Authors: 
%   Silvia Bonettini, Valeria Ruggiero, Alessandro Benfenati
%    Dept. of Mathematics, University of Ferrara, Italy
%    silvia.bonettini@unife.it, rgv@unife.it, bnflsn@unife.it
%
% Software homepage: http://www.unife.it/prisma/software
%
% Copyright (C) 2012 by S. Bonettini, V. Ruggiero, A. Benfenati
% -------------------------------------------------------------------------
% COPYRIGHT NOTIFICATION
%
% Permission to copy and modify this software and its documentation for 
% internal research use is granted, provided that this notice is retained 
% thereon and on all copies or modifications. The authors and their
% respective Universities makes no representations as to the suitability 
% and operability of this software for any purpose. It is provided "as is"
% without express or implied warranty. Use of this software for commercial
% purposes is expressly prohibited without contacting the authors.
%
% This program is free software; you can redistribute it and/or modify it
% under the terms of the GNU General Public License as published by the
% Free Software Foundation, either version 3 of the License, or (at your 
% option) any later version.
%
% This program is distributed in the hope that it will be useful, but 
% WITHOUT ANY WARRANTY; without even the implied warranty of 
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
% See the GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License along 
% with this program; if not, either visite http://www.gnu.org/licenses/
% or write to
% Free Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
% =========================================================================
%

%% Argument count check
% Done first so that a call with too few inputs reports this message
% instead of failing on an undefined mandatory input below.
if (nargin-length(varargin)) ~= 4
    error('Wrong number of required parameters');
end

%% Default parameters
%
bg              = 0;  
eta             = zeros(size(z));
x_initial       = [];   % resolved after option parsing, see below
NIT             = 1000;
tol             = 1e-6;
verbose         = 0;
Nalpha          = 10;
epsilon         = 1e-8;
inalpha         = 1;
dim_erm         = 1; 
stop            = 1; 
phi             = 'phi_TV';
phi_prime       = ''; % For compatibility with the STATE_EDGEPRES_REG structure 
phi_par         = 0;
obj             = {}; % no reference images unless 'OBJ' is given

%% Optional name/value pairs
if (rem(length(varargin),2)==1)
    error('Optional parameters should always go by pairs');
end
for i=1:2:(length(varargin)-1)
    switch upper(varargin{i})
        case 'BG'
            bg = varargin{i+1};
        case 'X_INITIAL'
            x_initial = varargin{i+1};
        case 'PHI_PAR'
            phi_par = varargin{i+1};
        case 'PHI'
            phi = varargin{i+1};
        case 'ETA'
            eta = varargin{i+1};
        case 'MAXIT'
            NIT = varargin{i+1};
        case 'TOL'
            tol = varargin{i+1};
        case 'VERBOSE'
            verbose = varargin{i+1};
        case 'NALPHA'
            Nalpha = varargin{i+1};
        case 'EPSILON'
            epsilon = varargin{i+1};
        case 'INALPHA'
            inalpha = varargin{i+1};
        case 'ERM'
            dim_erm = varargin{i+1};
        case 'STOPCRITERIUM'
            stop = varargin{i+1};
        case 'OBJ'
            obj = varargin{i+1};
        otherwise
            error(['Unrecognized option: ''' varargin{i} '''']);
    end
end

% The documented default starting point is max(eta,z); it has to be built
% after the options are read so that a user-supplied ETA is honoured.
% Both operands are vectorized because x_initial is vectorized below anyway.
if isempty(x_initial)
    x_initial = max(eta(:), z(:));
end

% Reference images: norms for the relative errors and per-iteration storage
nobj = length(obj);
for i = 1:nobj
    normobj(i) = norm(obj{i}(:));
    err{i}     = zeros(NIT+1,1); %arrays to store errors per iteration
end

%% function handles

if ( isa(A,'function_handle') )
    if ( isempty(AT) )
        error('Missing parameter: AT');
    end
    if( ~isa(AT,'function_handle') )
        error('AT is not a function handle');
    end
else
    % Check column-normalization condition on A
    sumColsA = sum(A)';
    tolCheckA = 1.0e4*eps;
    checkA = find(abs(sumColsA-1) > tolCheckA); 
    if (~isempty(checkA))
        errmsg = sprintf('\n\t%d %s\n\t%s %d:\n\t%s%d%s%e,  %s%e',...
                         length(checkA),...
                         'not-normalized columns found in blurring matrix A.',...
                         'The first one is column',checkA(1),...
                         '|sum(A(:,',checkA(1),')) - 1| = ',...
                         abs(sumColsA(checkA(1))-1),'tolerance = ',tolCheckA);
        error('Not-normalized blurring matrix A: %s%s',...
              'provide a normalized A (see documentation).',errmsg);
    end
    
    % Wrap the matrix so that the rest of the code only deals with handles
    AT = @(x) A'*x;
    A  = @(x) A*x;
end

%% Setting data and functions

% Functions initialization: the regularization helpers are selected by the
% number of dimensions of z and called by name through feval/eval.
obj_size  = size(z);
dimen     = numel(size(z));

edgepreserving       = sprintf('edgepreserving_%1dD',dimen);
fval_edgepreserving  = sprintf('fval_edgepreserving_%1dD',dimen);
fval_div             = sprintf('fval_div_%1dD',dimen);
clear_edgepreserving = sprintf('clear_edgepreserving_%1dD;',dimen);

feval(edgepreserving, obj_size, phi, phi_prime, phi_par);

% Vectorization
z         = z(:);
x_initial = x_initial(:);
eta       = eta(:);

N         = length(z);

% Output initialization
Primal    = zeros(NIT+1,1); 
phi_xy    = zeros(NIT+1,1); 
alpha_vec = zeros(NIT+1,1);
TimeCost  = zeros(NIT+1,1);
kkt_vec   = zeros(NIT+1,1);
KL_vec    = zeros(NIT+1,1);
InnVec    = zeros(NIT+1,1);
ERM       = ones(dim_erm,1);

% Additional options
% parameters for the backtracking procedure
gamma   = 0.99;
theta   = 0.99;

ONE          = ones(N,1);
zeroindex    = z <= 0;
nonzeroindex = ~zeroindex;
rapp         = zeros(N,1);   % stays 0 where z <= 0

% Initial points 
x  = x_initial(:); 
Y  = zeros(N,dimen);
y3 = zeros(N,1);

% Projection of the initial dual variables on the set {||y||<=1}
ynorm= max(1, sqrt(sum(Y.^2,2)+y3.^2));
Y  = Y./(ynorm*ones(1,dimen));
y3 = y3./ynorm;


%% Computation of Primal function, primal dual function, KL function at the start of the method

% TV
[J_R, D] = feval(fval_edgepreserving,x);
% KL
den = A(x) + bg;
rapp(nonzeroindex) = z(nonzeroindex)./ den(nonzeroindex);
KL = sum( z(nonzeroindex).* log(rapp(nonzeroindex)) + den(nonzeroindex) - z(nonzeroindex) );
KL = KL + sum(den(zeroindex));
% gradient of the KL function
g_KL = ONE - AT(rapp);
% Divergence 
Ay = feval(fval_div,Y); 

Primal(1) = KL+beta*J_R; 
phi_xy(1) = KL + beta*sum(Ay.*x) + beta*phi_par*sum(y3);
KL_vec(1) = KL;

% relative error
for i=1:nobj
    err{i}(1) = norm(obj{i}(:)-x(:))/normobj(i);
end

if verbose
    fprintf('\nInitial: Primal=%8.3e', Primal(1));
    for i=1:nobj
        fprintf(' err{%g} %g', i, err{i}(1));
    end
end

alpha        = inalpha;
alphaOK      = zeros(1,max(NIT,1));   % accepted steplengths (row vector)
alphaOK(1)   = alpha;
alpha_vec(1) = alpha;

TimeCost(1) = 0;
t0          = tic;                %Start CPU clock

for itr=1:NIT
    
    % Data of the current iterate, needed to retry the step with a
    % smaller steplength
    Yold    = Y;
    y3old   = y3;
    xold    = x;
    g_KLold = g_KL;
    Dold    = D;

    % Trial step with backtracking: the step is first tried with the
    % tentative alpha and repeated with a reduced alpha until the
    % convergence condition step_cond >= epsilon holds.
    ired = 0;
    while true
        Y  = Yold + alpha*beta*Dold;
        y3 = y3old + alpha*beta*phi_par;
        
        % Projection on the set Y={||y||<=1}; y3 only enters through ynorm
        % here, it is recomputed and projected after the loop
        ynorm = max(1, sqrt(sum(Y.^2,2)+y3.^2));
        Y  = Y./(ynorm*ones(1,dimen));
        Ay = feval(fval_div,Y); 
        
        x = xold - alpha*( g_KLold + beta*Ay);
        % Projection on the set X
        x = max(x,eta);

        [J_R, D] = feval(fval_edgepreserving,x);
        
        den  = A(x) + bg;
        rapp(nonzeroindex) = z(nonzeroindex)./ den(nonzeroindex);
        g_KL = ONE - AT(rapp);

        Deltax     = x - xold; 
        normDeltax = norm(Deltax);
        
        if normDeltax > 0
            Ak = norm(g_KLold - g_KL) / normDeltax;
            % Computation of Bk
            s = 0;
            for d = 1:dimen
                s = s + norm(D(:,d)-Dold(:,d))^2;
            end
            Bk = beta*sqrt(s)/normDeltax;
        else
            % x did not move: the difference quotients are 0/0; use 0 so
            % that the condition below is satisfied instead of yielding NaN
            Ak = 0;
            Bk = 0;
        end
        Bk2 = Bk^2;
        
        step_cond = 1-2*alpha*Ak-2*alpha^2*Bk2;
        if Bk2 ~= 0
            alphabar = gamma*(sqrt(Ak^2+2*Bk2*(1-epsilon))-Ak)/(2*Bk2);
        elseif Ak > 0
            % limit of the expression above for Bk -> 0
            alphabar = gamma*(1-epsilon)/(2*Ak);
        else
            % no curvature information: keep the current steplength
            alphabar = alpha;
        end
        
        % Written as a negation so that a NaN condition leaves the loop,
        % as "while cond < epsilon" would
        if ~(step_cond < epsilon)
            break;
        end
        
        alpha = min(alphabar,theta*alpha);
        ired  = ired + 1;
    end
    
    InnVec(itr) = ired;
    
    % Dual update with the accepted alpha and the new gradient D
    alpha_vec(itr + 1) = alpha;
    Y  = Yold + alpha*beta*D;
    y3 = y3old + alpha*beta*phi_par;
    
    % Projection on the set Y={||y||<=1}
    ynorm= max(1, sqrt(sum(Y.^2,2)+y3.^2));
    Y  = Y./(ynorm*ones(1,dimen));
    y3 = y3./ynorm;
    Ay = feval(fval_div,Y);
    
    KL = sum( z(nonzeroindex).* log(rapp(nonzeroindex)) + den(nonzeroindex) - z(nonzeroindex) );
    KL = KL + sum(den(zeroindex));

    % Tentative steplength for the next iteration: mean of the last
    % accepted values (memory Nalpha) and of the bound alphabar
    alphaOK(itr) = alpha;
    alpha = mean([alphaOK(max(1,itr-Nalpha):itr) alphabar]);


    Primal(itr+1)   = KL + beta*J_R;
    phi_xy(itr+1)   = beta*sum(Ay.*x) + beta*phi_par* sum(y3) + KL ;
    KL_vec(itr+1)   = KL;
    TimeCost(itr+1) = toc(t0);
    
    for i=1:nobj
        err{i}(itr + 1) = norm(obj{i}(:)-x(:))/normobj(i);
    end
    
    if verbose
        fprintf('\n%4d): f(x)=%g Phi(x,y)=%g KL= %f alpha=%g', itr, ...
            Primal(itr+1), phi_xy(itr+1), KL_vec(itr+1),alpha );
        for i=1:nobj
            fprintf(' err{%g} %g', i, err{i}(itr + 1));
        end
    end
    
    % Relative distance between successive primal-dual iterates.
    % The squared norm of y is sum(Y(:).^2)+sum(y3.^2); adding the scalar
    % sum(y3.^2) to the N-by-1 row sums would count it N times.
    normDeltay =sqrt( sum( sum( sum((Y-Yold).^2,1),2) + sum((y3-y3old).^2) ) ); 
    normy = sqrt( sum(Y(:).^2) + sum(y3.^2) );
    kkt_vec(itr+1) = sqrt(normDeltax^2+normDeltay^2)/sqrt(norm(x)^2+normy^2);
           
    % Stopping criteria
    switch stop
        case 1 
            if itr+1>NIT
                break;
            end
        case 2 
            if kkt_vec(itr+1)<tol
                break;
            end
        case 3 
            % sliding window of the last dim_erm relative primal/primal-dual gaps
            ERM(1:dim_erm-1) = ERM(2:end);
            ERM(end) = (Primal(itr+1)-phi_xy(itr+1))/abs(Primal(itr+1));
            m_ERM = mean(ERM);
            if (ERM(end)<tol && m_ERM<10*tol)
                break;
            end
        case 4 
            if KL*2/N<1
                break;
            end
    end
end
%end of the main loop

x = reshape(x,obj_size);

%% Cutting
% Drop the unused tail of the preallocated histories
for i = 1:nobj
    err{i}(itr + 2:end) = [];
end
Primal(itr+2:end)    = [];
phi_xy(itr+2:end)    = [];
alpha_vec(itr+2:end) = [];
TimeCost(itr+2:end)  = [];
KL_vec(itr +2:end)   = [];
% InnVec is filled at index itr (not itr+1), so this leaves one trailing
% zero; kept so that the returned length does not change for callers
InnVec(itr +2:end)   = [];
if verbose
    fprintf('\n');
end

if nargout>8
    varargout{1} = Y;
    varargout{2} = y3;
end

if nobj<1
    err = 'no true object';
end

% The evaluated string is built above from the number of dimensions only
eval(clear_edgepreserving);

return
% ==============================================================================
% End of AEM_MD.m file 
% ==============================================================================

