matlab 使用svm进行分类(适用于二分类和多分类

1. 简单二分类

clear,clc

%% 二分类
%训练数据20×2,20行代表20个训练样本点,第一列代表横坐标,第二列纵坐标
Train_Data =[-3 0;4 0;4 -2;3 -3;-3 -2;1 -4;-3 -4;0 1;-1 0;2 2;3 3;-2 -1;-4.5 -4;2 -1;5 -4;-2 2;-2 -3;0 2;1 -2;2 0];
%Group 20 x 1,20行代表训练数据对应点属于哪一类(1,-1)
Train_labels =[1 -1 -1 -1 1 -1 1 1 1 -1 -1 1 1 -1 -1 1 1 1 -1 -1]';
TestData = [3 -1;3 1;-2 1;-1 -2;2 -3;-3 -3];%测试数据
classifier  = fitcsvm(Train_Data,Train_labels); %train
test_labels  = predict(classifier ,TestData); % test

这里 test_labels 就是最后的分类结果啦,大家可以按照这个格式对自己的数据进行修改

2. 多分类(不调用工具箱)

因为

%% 多分类
TrainingSet=[ 1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42];%训练数据
TestSet=[3 34; 1 14; 2.2 25; 6.2 63];%测试数据
GroupTrain=[1;1;2;2;3;3;2;2];%训练标签
results =my_MultiSvm(TrainingSet, GroupTrain, TestSet);
disp('multi class problem');
disp(results);

results为最终的分类结果,上述中有用到 my_MultiSvm.m() 函数,以下是my_MultiSvm.m函数的全部内容

function [y_predict,models] = my_MultiSvm(X_train, y_train, X_test)
% multi svm
% one vs all 模型
% Input:
% X_train: n*m矩阵 n为训练集样本数 m为特征数
% y_train: n*1向量 为训练集label,支持任意多种类
% X_test: n*m矩阵 n为测试集样本数 m为特征数
% Output:
% y_predict: n*1向量 测试集的预测结果
% 
% Copyright(c) lihaoyang 2020
%

    y_labels = unique(y_train);
    n_class = size(y_labels, 1);
    models = cell(n_class, 1);
    % 训练n个模型
    for i = 1:n_class
        class_i_place = find(y_train == y_labels(i));
        svm_train_x = X_train(class_i_place,:);
        sample_num = numel(class_i_place);
        class_others = find(y_train ~= y_labels(i));
        randp = randperm(numel(class_others));
        svm_train_minus = randp(1:sample_num)';
        svm_train_x = [svm_train_x; X_train(svm_train_minus,:)];
        svm_train_y = [ones(sample_num, 1); -1*ones(sample_num, 1)];
        disp(['生成模型:', num2str(i)])
        models{i} = fitcsvm(svm_train_x, svm_train_y);
    end
    test_num = size(X_test, 1);
    y_predict = zeros(test_num, 1);
    % 对每条数据,n个模型分别进行预测,选择label为1且概率最大的一个作为预测类别
    for i = 1:test_num
        if mod(i, 100) == 0
            disp(['预测个数:', num2str(i)])
        end
        bagging = zeros(n_class, 1);
        for j = 1:n_class
            model = models{j};
            [label, rat] = predict(model, X_test(i,:));
            bagging(j) = bagging(j) + rat(2);
        end
        [maxn, maxp] = max(bagging);
        y_predict(i) = y_labels(maxp);
    end
end

3.多分类(调用libsvm工具箱)

以下代码是调用matlab工具箱libsvm的一种方法

TrainingSet=[ 1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42];%训练数据
TestSet=[3 34; 1 14; 2.2 25; 6.2 63];%测试数据
GroupTrain=[1;1;2;2;3;3;2;2];%训练标签
GroupTest=[1;2;1;3];%测试标签

%svm分类
model = svmtrain(GroupTrain,TrainingSet);
% SVM网络预测
[predict_label] = svmpredict(GroupTest,TestSet,model);

之所以放到最后,是因为需要在matlab安装libsvm的工具箱,具体方法可参看此链接在Matlab中安装LibSVM工具箱

下载libsvm也可以百度网盘:百度网盘libsvm
提取码:25ft

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。