function [global_groups, global_obj] = MAP_EM_MPPCA(data, n, d, varargin)
% MAP_EM_MPPCA  Hard-assignment (MAP / classification) EM for a mixture of
% probabilistic PCA models: clusters the columns of DATA into n groups, each
% modelled as N(mu_k, B_k*B_k' + sigma_k*I) with a D by d factor B_k.
%
% Inputs
%   data : D by N data matrix (one sample per column).
%   n    : number of subspaces (mixture components).
%   d    : dimension of each subspace; must satisfy d < D.
% Optional name/value pairs (merged with the defaults by vararginParser)
%   'replicates'           : number of random restarts (default 5).
%   'prior_pi'             : n by 1 initial mixing weights (default uniform).
%   'mu_input'             : n by 1 cell of initial means; empty entries are
%                            initialized as the data mean plus randn noise.
%   'B_input'              : n by 1 cell of initial D by d factors; empty
%                            entries are built from a perturbed copy of the
%                            top-d principal directions of the whole data.
%   'noise_variance_input' : n by 1 initial noise variances; entries < eps
%                            are initialized to the mean discarded eigenvalue
%                            of the global data covariance.
%   'max_iter'             : maximum EM iterations per replicate (default 50).
%   'tol'                  : an iteration stops once the objective decreases
%                            by less than tol AND every component is used
%                            (default 1).
% Outputs
%   global_groups : 1 by N labels in 1..n from the replicate with the lowest
%                   objective ([] if no replicate produced a finite one).
%   global_obj    : that objective, -sum_j max_k [log(pi_k) + log p_k(x_j)],
%                   where log p_k comes from logmvnpdf (assumed to return
%                   per-row Gaussian log-densities -- TODO confirm).

% Default options; vararginParser overwrites them with user-supplied pairs.
vararg = {'replicates', 5, ...
          'prior_pi', ones(n, 1) / n, ...
          'mu_input', cell(n, 1), ...
          'B_input', cell(n, 1), ...
          'noise_variance_input', zeros(n, 1), ...
          'max_iter', 50, ...
          'tol', 1};
vararg = vararginParser(vararg, varargin);
% Collect the options in a struct rather than eval-ing caller-controlled
% names into the workspace (which could silently overwrite data, n, d, ...).
opts = struct();
for pair = reshape(vararg, 2, []) % pair is {propName; propValue}
    opts.(pair{1}) = pair{2};
end
replicates = opts.replicates;
prior_pi = opts.prior_pi;
mu_input = opts.mu_input;
B_input = opts.B_input;
noise_variance_input = opts.noise_variance_input;
max_iter = opts.max_iter;
tol = opts.tol;

% definitions
[D, N] = size(data);
if d >= D
    error('MAP_EM_MPPCA:invalidDimension', ...
          'Subspace dimension d (%d) must be smaller than the data dimension D (%d).', d, D);
end

% Global PCA of the data. It does not depend on the replicate, so it is
% computed once here instead of inside the restart loop.
mu0 = mean(data, 2);
data0 = bsxfun(@minus, data, mu0);
Sigma0 = data0 * data0' / N;
[U0, Gamma0] = eigs(Sigma0, d);
lambda0 = diag(Gamma0);
% ML noise variance: mean of the D-d eigenvalues not kept by the subspace.
sigma0 = (trace(Sigma0) - sum(lambda0)) / (D - d);

global_obj = +inf;
global_groups = [];
for ireplicates = 1:replicates
    % initialization
    p_pi = prior_pi;
    mu = cell(n, 1);
    B = cell(n, 1);
    Sigma = cell(1, n);
    sigma = zeros(n, 1);
    for ii = 1:n
        if isempty(mu_input{ii})
            mu{ii} = mu0 + randn(D, 1);
        else
            mu{ii} = mu_input{ii};
        end
        if noise_variance_input(ii) < eps
            sigma(ii) = sigma0;
        else
            sigma(ii) = noise_variance_input(ii);
        end
        if isempty(B_input{ii})
            % Perturb the global principal directions, re-orthonormalize,
            % and scale by sqrt(lambda - sigma), clamped so the factor
            % stays real when sigma exceeds an eigenvalue.
            basis = U0 + 0.01 * randn(D, d);
            [basis, ~, ~] = svd(basis, 'econ');
            B{ii} = basis * diag(sqrt(max(lambda0 - sigma(ii), 0)));
        else
            B{ii} = B_input{ii};
        end
        Sigma{ii} = B{ii} * B{ii}' + sigma(ii) * eye(D);
    end

    % iterations
    iter = 0;
    obj_0 = +inf;
    while true
        % - E-step: assign every sample to its MAP component.
        log_prob = zeros(n, N);
        for ii = 1:n
            log_prob(ii, :) = logmvnpdf(data', mu{ii}', Sigma{ii}')';
        end
        log_post_pi = bsxfun(@plus, log_prob, log(p_pi(:)));
        % Reduce along dimension 1 explicitly so that n == 1 (a 1 by N row)
        % still gives one maximum per sample.
        [best_log_post, groups] = max(log_post_pi, [], 1);
        W = zeros(n, N);
        W(sub2ind([n, N], groups, 1:N)) = 1;
        % Sum only the selected entries: W .* log_post_pi would produce NaN
        % (0 * -Inf) as soon as some component has zero weight.
        obj = -sum(best_log_post);

        % - M-step
        counts = sum(W, 2);
        p_pi = counts / N;
        for ii = 1:n
            if counts(ii) == 0
                % Empty component: keep its previous parameters rather than
                % dividing by zero, which would turn them into NaN.
                continue;
            end
            % update estimate for mu and the sample covariance S
            members = data(:, groups == ii);
            mu{ii} = mean(members, 2);
            centered = bsxfun(@minus, members, mu{ii});
            S = centered * centered' / counts(ii);
            S = (S + S') / 2; % enforce exact symmetry for eigs
            [U, Gamma] = eigs(S, d);
            lambda = diag(Gamma);
            % Closed-form PPCA ML update: first sigma (mean discarded
            % eigenvalue), then B = U * sqrt(Lambda - sigma*I) using the NEW
            % sigma; clamp at zero so B stays real.
            sigma(ii) = (trace(S) - sum(lambda)) / (D - d);
            B{ii} = U * diag(sqrt(max(lambda - sigma(ii), 0)));
            Sigma{ii} = B{ii} * B{ii}' + sigma(ii) * eye(D);
        end

        iter = iter + 1;
        if ((obj_0 - obj < tol) && numel(unique(groups)) == n) || (iter >= max_iter)
            break;
        end
        obj_0 = obj;
    end
    fprintf('MAP_EM_MPPCA: %d in %d finished, objective: %f\n', ireplicates, replicates, obj);
    if obj < global_obj
        global_obj = obj;
        global_groups = groups;
    end
end
