%% Single layer classifier
% Trains a single tanh neuron (two inputs + bias) to separate two 2-D
% Gaussian classes labelled +1 and -1, then plots the learned output
% surface, the samples and the classification result.
close all;
% Plain CLEAR removes workspace variables; CLEAR ALL would also flush
% compiled functions/classes, which only costs performance here.
clear;
clc;

%% Parameters
% Dataset size: total number of samples. Half of them are drawn for each
% class, so N must be even.
N=1000;
% Learning rate of the per-sample gradient-descent weight update
Eta=0.003;
% Fail early with a clear message instead of a size error in the dataset code
assert(mod(N,2) == 0, 'N must be even: the dataset is split into two equal classes.');




%% Dataset
% Generate N samples, half per class. Each row is [ X , Y , class (+1 or -1) ].

% Class A (+1): 2-D Gaussian with mean (2, 2) and correlated axes.
% It overlaps class B a little, so a few points end up misclassified
% (~ 98% good classification; the exact figure varies between runs
% because no random seed is set).
classA = mvnrnd([2, 2], [5 1.5; 1.5 1], N/2);
% Uncomment the following line to move class A's mean to (2, 4) and
% obtain a 100% good classification:
%classA = mvnrnd([2, 4], [5 1.5; 1.5 1], N/2);
classA = [classA, ones(N/2,1)];

% Class B (-1): 2-D Gaussian with mean (2, -2) and independent axes
classB = mvnrnd([2, -2], [3, 0; 0, 0.5], N/2);
classB = [classB, -ones(N/2,1)];

% Merge classes for creating the dataset
dataset = [classA; classB];
% Shuffle the rows so the per-sample training loop does not see all of one
% class first. size(...,1) is the row count; length() returns the largest
% dimension and would give the column count (3) for fewer than 3 rows.
dataset = dataset(randperm(size(dataset,1)), :);




%% Initialize weight
% W = [weight for X; weight for Y; bias], all starting at zero
W = zeros(3,1);

%% Training loop
% A single pass of online (per-sample) updates for one tanh neuron:
%   W <- W + Eta * (target - out) * x * (1 - out^2)
% i.e. a gradient-descent step on 0.5*(target - out)^2, using
% d/ds tanh(s) = 1 - tanh(s)^2.
for k = 1:size(dataset,1)
    % Augmented input: sample coordinates followed by a constant 1 (bias)
    x = [dataset(k,1:2), 1]';

    % Forward pass
    s = W'*x;
    out = tanh(s);

    % Expected output: the sample's class label (+1 or -1)
    target = dataset(k,3);

    % Weight update (same multiplication order as the textbook delta rule)
    W = W + Eta*(target - out)*x*(1 - out*out);
end


%% Display

% Plot extent: per-column minimum and maximum of the X and Y coordinates
lo = min(dataset(:,1:2), [], 1);
hi = max(dataset(:,1:2), [], 1);
Xmin = lo(1);
Xmax = hi(1);
Ymin = lo(2);
Ymax = hi(2);

%% Output
% Evaluate the trained neuron on a grid with 0.1 spacing and draw the
% resulting output (in [-1, 1]) as a surface
[X,Y] = meshgrid(Xmin:0.1:Xmax, Ymin:0.1:Ymax);
Z = tanh(W(1)*X + W(2)*Y + W(3));
surf(X, Y, Z, 'facecolor', 'texture')
hold on;

%% Display dataset
% Samples drawn above the surface at z = 4 + label:
% class A (+1) in red, class B (-1) in blue
plot3(classA(:,1), classA(:,2), 4+classA(:,3), '.r');
plot3(classB(:,1), classB(:,2), 4+classB(:,3), '.b');
grid on;
axis square equal;




%% Test on training set
% Classify every training sample with the trained weights and overlay the
% result at z = 8 + predicted label:
%   green circle = correct prediction, black cross = wrong prediction.
% An output of exactly 0 has sign 0, which matches neither label and is
% therefore counted as wrong.
nSamples = size(dataset,1);
% Network outputs for all samples at once: [x, y, 1] * W, one row per sample
outputs = tanh([dataset(:,1:2), ones(nSamples,1)] * W);
predicted = sign(outputs);
isGood = (predicted == dataset(:,3));
good = sum(isGood);
% Two vectorized plot calls instead of one graphics object per sample
plot3(dataset(isGood,1),  dataset(isGood,2),  8+predicted(isGood),  'og');
plot3(dataset(~isGood,1), dataset(~isGood,2), 8+predicted(~isGood), 'xk');

% Axis labels and colormap
colormap(jet);
colorbar
shading interp;
xlabel('X');
ylabel('Y');
% Uncomment for a top view
%view(0,90);

% Misclassification ratio over the samples actually tested (displayed)
badly_classified = 1 - good/nSamples


%% Example output (varies from run to run, since no random seed is set):
% badly_classified =
%
% 0.0060