%   This code tests the performance of R-graph for outlier detection on the 
%   Extended Yale B face image database. The code generates results in 
%   Table 1 of the paper
%
%   Chong You, Daniel Robinson, Rene Vidal,
%   "Provable Self-Representation Based Outlier Detection in a Union of 
%   Subspaces", CVPR 2017.

% Copyright Chong You @ Johns Hopkins University, 2017
% chong.you1987@gmail.com

% Start from an empty workspace with no open figures.
clear all
close all

% Put the dataset folder and the helper-function folder on the path, then
% bring the dataset variables into the workspace.
addpath data
addpath tools
load ExtendedYaleB.mat
% Data is preprocessed and saved in the ExtendedYaleB.mat file. 
% EYALEB_DATA is a D by N matrix. Each column is a face image, with N = 2414
% images in total from 38 subjects (about 64 images per subject). Each image
% is downsampled from 192*168 to D = 48*42 = 2016.
% EYALEB_LABEL is a 1 by N vector. Each entry is the label for the
% corresponding column in EYALEB_DATA.
% EYALEB_NAME is a 1 by N cell array. Each entry is the name of the image
% of the corresponding column in EYALEB_DATA
%% Experiment configuration
% Number of inlier groups, i.e. how many subjects contribute the inlier
% images (taken to be 1 or 3 in the paper).
Num_Inlier_Group = 3;

% Fraction of the test set made up of outliers (taken to be 0.35 or 0.15 in
% the paper).
outlier_perc = 0.15;

% Number of random trials; the reported numbers are averages over all trials.
Nexperiment = 50;

% Set to true to draw one ROC curve figure per trial. Keep Nexperiment small
% in that case, since every trial opens a new figure.
flag_ROC = false;

fprintf('EYaleB: Num_Inlier_Group = %d, outlier perc = %f, Nexperiment = %d\n', ...
    Num_Inlier_Group, outlier_perc, Nexperiment);

%% Experiment
% Each trial draws a random inlier/outlier split, scores every data point
% with the R-graph random walk, and records AUC, best F1 and running time.

% Subject labels are read from the data instead of assuming they are 1..38.
subjectList = unique(EYALEB_LABEL);
Nsubject = numel(subjectList);

% Algorithm constants (loop invariant, so defined once).
lambda = 0.95;
alpha = 5;
Nstep = 1000; % number of random-walk steps that are averaged

% Per-column scaling factor used for both solver parameters. It is evaluated
% once per call and handed to the solver wrapper below.
% NOTE(review): the meaning of the two trailing rfss arguments is not visible
% in this file -- confirm against tools/rfss.
scaleFun = @(X, y) alpha * lambda / max(abs(X' * y));
scaledSolver = @(X, y, g) rfss(full(X), full(y), lambda / g, (1 - lambda) / g);
EN_solver = @(X, y) scaledSolver(X, y, scaleFun(X, y));

result = zeros(3, Nexperiment); % rows: AUC, F1, time; one column per trial
for iexperiment = 1:Nexperiment
    subjectIdx = subjectList(randperm(Nsubject, Num_Inlier_Group));

    % --- Prepare data ---
    % inliers are all images of the subject(s) listed in subjectIdx
    datapointIdx = find(ismember(EYALEB_LABEL, subjectIdx));
    data_inlier = EYALEB_DATA(:, datapointIdx);
    name_inlier = EYALEB_NAME(datapointIdx);
    Ninlier = length(datapointIdx);

    % outliers are images of other individuals, with at most one image taken
    % from each individual.
    Noutlier = round(outlier_perc / (1 - outlier_perc) * Ninlier);
    outsubjectIdx = setdiff(subjectList, subjectIdx);
    if Noutlier > numel(outsubjectIdx)
        % Fail with an actionable message instead of the generic randperm error.
        error('EYaleB:tooManyOutliers', ...
            ['Requested %d outliers but only %d non-inlier subjects are available ', ...
             '(at most one outlier image is taken per subject). ', ...
             'Reduce outlier_perc or Num_Inlier_Group.'], ...
            Noutlier, numel(outsubjectIdx));
    end
    % pick #Noutlier subjects
    outsubjectIdx = outsubjectIdx(randperm(numel(outsubjectIdx), Noutlier));
    outpointIdx = zeros(1, Noutlier);
    for ii = 1:Noutlier
        % pick one image from each subject
        candidateIdx = find(ismember(EYALEB_LABEL, outsubjectIdx(ii)));
        outpointIdx(ii) = candidateIdx(randperm(length(candidateIdx), 1));
    end
    data_outlier = EYALEB_DATA(:, outpointIdx);
    name_outlier = EYALEB_NAME(outpointIdx);

    % compose test data: inlier columns first, then outlier columns
    data = [data_inlier, data_outlier];
    s = [zeros(1, Ninlier), ones(1, Noutlier)]; % ground truth, 1 = outlier
    filename = [name_inlier, name_outlier];     % image names, kept for inspection
    N = Ninlier + Noutlier;

    % --- R-graph outlier detection ---
    % Use an explicit timer handle so that a tic call inside any helper
    % function cannot reset the measurement.
    tStart = tic;

    data = dimReduction(data, min(size(data)));
    % step 1: compute representation R from data (line 1 of Alg. 1)
    R = selfRepresentation(cnormalize(data), EN_solver);
    % step 2: compute transition P from R (line 2 of Alg. 1)
    P = cnormalize(abs(R), 1)';
    % step 3: compute \pi from P (line 3 - 7 of Alg. 1)
    % Start from the uniform distribution and average the state over Nstep
    % steps of the walk. The state is not called "pi" to avoid shadowing the
    % built-in constant.
    walkState = ones(1, N) / N;
    pi_bar = zeros(1, N);
    for ii = 1:Nstep
        walkState = walkState * P;
        pi_bar = pi_bar + walkState;
    end
    pi_bar = pi_bar / Nstep;

    feat = - pi_bar; % larger values in feat indicate higher "outlierness"

    % --- Evaluation ---
    elapsedTime = toc(tStart);

    % The thresholds output of perfcurve is not needed.
    [FPR, TPR, ~, AUC] = perfcurve(s, feat, 1);
    [PREC, RECA] = perfcurve(s, feat, 1, 'XCrit', 'prec', 'YCrit', 'reca');
    % Best F1 over all thresholds; max skips NaN entries by default.
    F1 = max(2 * (PREC .* RECA) ./ (PREC + RECA));
    fprintf('Experiment %d: AUC = %f, F1 = %f, time = %f\n', iexperiment, AUC, F1, elapsedTime);

    result(:, iexperiment) = [AUC; F1; elapsedTime];
    if flag_ROC
        figure;
        plot(FPR, TPR, '-r');
        xlabel('False positive rate')
        ylabel('True positive rate')
    end
end
meanResult = mean(result, 2);
fprintf('Mean AUROC: %f, Mean F1: %f, Mean time: %f sec.\n', meanResult(1), meanResult(2), meanResult(3));




    
