- classification: https://www.mathworks.com/help/stats/classification.html
- confusionchart: https://in.mathworks.com/help/stats/confusionchart.html
Gaussian Naive Bayes classification
I have found the following Matlab implementation of a Naive Bayes classifier:
https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab
How can I extend the above implementation to become Gaussian Naive Bayes?
How can I extend the implementation to work with 4 classes? Should I just use a one-vs-all approach?
Thank you very much for the help.
Answers (1)
Abhipsa
on 1 Sep 2025
There is no need to switch to one-vs-all for Naive Bayes as it’s naturally a multiclass classifier.
In Gaussian Naive Bayes, each feature is assumed to follow a normal distribution within each class.
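Written out, the resulting decision rule is a sum of a log-prior and per-feature Gaussian log-densities. The notation here is an assumption chosen for illustration: \mu_{kj} and \sigma_{kj}^2 are the mean and variance of feature j within class k, and d is the number of features.

```latex
\hat{y} = \arg\max_{k} \left[ \log P(y=k)
  + \sum_{j=1}^{d} \left( -\tfrac{1}{2}\log\!\left(2\pi\sigma_{kj}^{2}\right)
  - \frac{(x_j - \mu_{kj})^{2}}{2\sigma_{kj}^{2}} \right) \right]
```

Because the argmax runs over all classes k at once, the same rule covers 2 classes or 4 without any one-vs-all wrapper.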
You can adapt the files from the repository (https://github.com/jjedele/Naive-Bayes-Classifier-Octave-Matlab) as shown below:
function model = gnb_train(X, y)
% X: n-by-d matrix (rows = samples, cols = features)
% y: n-by-1 vector of class labels (numeric/char/string)
if iscell(y), y = string(y); end
classes = unique(y);
[n,d] = size(X); K = numel(classes);
mu = zeros(K,d); varc = zeros(K,d); prior = zeros(1,K);
for k = 1:K
idx = (y == classes(k));
Xk = X(idx,:);
prior(k) = sum(idx)/n; % P(y=k)
mu(k,:) = mean(Xk,1); % feature means
varc(k,:) = var(Xk,1,1); % MLE variances (normalize by N), taken along rows even if a class has one sample
end
% Variance floor (avoid divide-by-zero if a feature is constant)
varc = max(varc, 1e-9);
model.classes = classes;
model.prior = prior;
model.mu = mu;
model.var = varc;
end
function [yhat, logpost] = gnb_predict(model, X)
% X: m-by-d test matrix
% yhat: m-by-1 predicted labels
% logpost: m-by-K unnormalized log-posteriors (diagnostics)
[m,~] = size(X);
K = numel(model.classes);
logpost = zeros(m,K);
% Per-class constant term for diagonal Gaussian
const = -0.5 * sum(log(2*pi*model.var), 2); % K-by-1
logprior = log(model.prior); % 1-by-K
for k = 1:K
dev = X - model.mu(k,:); % deviations from class mean (avoids shadowing built-in diff)
quad = (dev.^2) ./ model.var(k,:); % m-by-d
ll = const(k) - 0.5 * sum(quad, 2); % m-by-1
logpost(:,k) = ll + logprior(k);
end
[~, idx] = max(logpost, [], 2);
yhat = model.classes(idx);
end
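gnb_predict returns unnormalized log-posteriors. If normalized class probabilities are needed, one common approach (not part of the code above) is the log-sum-exp trick: subtract the row-wise maximum before exponentiating so that very negative values do not underflow to zero. The sketch below uses made-up logpost values purely for illustration and assumes MATLAB R2016b or later for implicit expansion.

```matlab
% Made-up unnormalized log-posteriors: 2 samples (rows), 4 classes (cols).
% The first row is so negative that exp() on it directly would underflow to 0.
logpost = [-1050.2 -1047.9 -1060.0 -1049.1;
              -3.1    -0.4    -7.8    -2.2];

% Shift each row so its largest entry is 0, then normalize.
shifted = logpost - max(logpost, [], 2);
post = exp(shifted) ./ sum(exp(shifted), 2);   % m-by-K, rows sum to 1

disp(post);
```

The shift cancels in the ratio, so the predicted class (the row-wise argmax) is the same as for logpost.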
The two functions above can be run on synthetic data, as shown in the script below:
% Seed for reproducibility (optional)
rng(7);
% ---- 1) Make a simple 4-class 2D dataset (Gaussian blobs)
nPerClass = 150;
C1 = mvnrnd([0, 0], [0.5 0; 0 0.3], nPerClass);
C2 = mvnrnd([3, 1], [0.6 0; 0 0.6], nPerClass);
C3 = mvnrnd([-2, 3], [0.4 0; 0 0.7], nPerClass);
C4 = mvnrnd([2, -2], [0.7 0; 0 0.4], nPerClass);
X = [C1; C2; C3; C4];
y = [ones(nPerClass,1);
2*ones(nPerClass,1);
3*ones(nPerClass,1);
4*ones(nPerClass,1)];
% ---- 2) Standardize features (optional: per-feature scaling does not change GNB predictions, apart from the variance floor)
muX = mean(X,1); sigX = std(X,0,1); sigX(sigX==0) = 1;
Xz = (X - muX) ./ sigX;
% ---- 3) Train/test split (70/30)
n = size(Xz,1);
idx = randperm(n);
nTr = round(0.7*n);
tr = idx(1:nTr);
te = idx(nTr+1:end);
Xtr = Xz(tr,:); ytr = y(tr);
Xte = Xz(te,:); yte = y(te);
% ---- 4) Train Gaussian Naive Bayes
model = gnb_train(Xtr, ytr);
% ---- 5) Predict
[yhat, scores] = gnb_predict(model, Xte);
% ---- 6) Accuracy and confusion matrix
acc = mean(yhat == yte);
fprintf('Test accuracy: %.2f%%\n', 100*acc);
K = numel(unique(y));
CM = zeros(K);
for i = 1:numel(yte)
CM(yte(i), yhat(i)) = CM(yte(i), yhat(i)) + 1;
end
disp('Confusion matrix (rows = true, cols = predicted):');
disp(CM);
% ---- 7) (Optional) quick scatter to visualize test set predictions
figure; hold on;
clsColors = lines(4); % 4 distinct colors
for c = 1:4
pts = (yte == c);
scatter(Xte(pts,1), Xte(pts,2), 25, clsColors(c,:), 'filled', ...
'MarkerFaceAlpha', 0.7, 'DisplayName', sprintf('True C%d', c));
end
for c = 1:4
pts = (yhat == c);
scatter(Xte(pts,1), Xte(pts,2), 10, 'k', 'o'); % black rings for predicted class
end
title(sprintf('Gaussian NB on 4 classes (Test acc = %.1f%%)', 100*acc));
xlabel('z-score feature 1'); ylabel('z-score feature 2'); grid on; box on; hold off;
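As an optional sanity check, a hand-written classifier can be compared against the built-in fitcnb, and confusionchart can replace the manual confusion-matrix loop. The sketch below is self-contained and makes several assumptions: Statistics and Machine Learning Toolbox is installed, the release is R2018b or later (for confusionchart), and the data are a simplified stand-in for the blobs above (the centers and the 0.7 spread are illustrative choices). fitcnb may estimate variances slightly differently from the MLE used in gnb_train, so small differences in predictions are possible.

```matlab
rng(7);                                   % reproducibility
nPerClass = 150;
centers = [0 0; 3 1; -2 3; 2 -2];         % illustrative class centers
X = []; y = [];
for c = 1:4
    X = [X; centers(c,:) + 0.7*randn(nPerClass, 2)]; %#ok<AGROW>
    y = [y; c*ones(nPerClass, 1)];                   %#ok<AGROW>
end

% 70/30 train/test split
idx = randperm(size(X, 1));
nTr = round(0.7 * numel(idx));
tr = idx(1:nTr);  te = idx(nTr+1:end);

% Built-in naive Bayes with a normal distribution per feature and class
Mdl = fitcnb(X(tr,:), y(tr), 'DistributionNames', 'normal');
yhatRef = predict(Mdl, X(te,:));
accRef = mean(yhatRef == y(te));
fprintf('fitcnb test accuracy: %.2f%%\n', 100*accRef);

% Confusion matrix chart (rows = true class, columns = predicted class)
confusionchart(y(te), yhatRef);
```

On the script's own variables, the equivalent calls would be fitcnb(Xtr, ytr) and confusionchart(yte, yhat).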
For more details, see the MATLAB documentation on classification and confusionchart linked at the top of this page.
