首页 > 学院 > 开发设计 > 正文

SVM简单实例-A simple implementation of SVM using Matlab

2019-11-06 08:59:21
字体:
来源:转载
供稿:网友

本文给出一个用 MATLAB 实现的简单 SVM 实例（使用 SMO 算法训练），仅供参考；如有不足之处，欢迎指正。

主程序如下。

% simple SMO with 2D data
clear; close all; clc;

% Generate two noisy 'circles' of radius 10 (label +1) and radius 20
% (label -1), i.e. diameters 20 and 40, each consisting of 100 points.
[X1, y1] = generateData(100, 2, 10, 1);
[X2, y2] = generateData(100, 2, 20, -1);
X = [X1; X2];
y = [y1; y2];
m = size(X,1);

% Plot the two classes.
xPos = X(y == 1,:);
xNeg = X(y == -1,:);
scatter(xPos(:,1),xPos(:,2),'y+');
hold on
scatter(xNeg(:,1),xNeg(:,2),'k+');
axis equal
hold off

% Initialization of the Lagrange multipliers and the bias.
alphas = zeros(m,1);
b = 0;
% Penalty (box-constraint) factor C.
C = 1;

% Main SMO loop.
% A pass is "idle" when no pair of multipliers was updated; update_alpha_pair
% rejects a step when |alpha_new - alpha| < delta.
% The loop stops after max_idle_passes consecutive idle passes.
delta = 1e-10;
max_idle_passes = 6;
count_not_updated = 0;
iter = 0;
while count_not_updated < max_idle_passes
    % Progress display (no semicolon on purpose).
    % NOTE: w = sum(alpha_i*y_i*x_i) is the primal weight vector only for a
    % linear kernel; with the polynomial kernel in ker() it is informational.
    w = (y .* alphas)' * X
    b
    [alphas, b, num_updated] = update_alphas(alphas, b, X, y, C, delta);
    if num_updated == 0
        count_not_updated = count_not_updated + 1;
    else
        count_not_updated = 0;
    end
    iter = iter + 1;
    fprintf('iteration = %d\n', iter);
    fprintf('count_not_updated = %d\n', count_not_updated);
end

% Final w (meaningful only for a linear kernel, see the note in the loop).
w = (y .* alphas)' * X;

% Evaluate the decision function on the training data and on fresh points.
[res1, res2, res3, res4, res5] = test_svm(X, y, alphas, b);

下面两个函数用于更新拉格朗日乘子。

function [alphas_new, b_new, updated] = update_alphas(alphas, b, X, y, C, delta)
% UPDATE_ALPHAS  One SMO pass over all training samples.
%   For every index2 whose multiplier violates the KKT conditions, pick a
%   partner index1 among the LATER KKT violators that maximises |E1 - E2|,
%   then try to optimise the pair with update_alpha_pair.
%
%   Inputs:
%     alphas - m-by-1 Lagrange multipliers
%     b      - current bias
%     X      - training inputs, one sample per row
%     y      - m-by-1 labels (+1 / -1 in the main script)
%     C      - box constraint
%     delta  - minimum change of alpha2 for a step to count as an update
%   Outputs:
%     alphas_new, b_new - multipliers and bias after the pass
%     updated           - number of pairs actually updated in this pass
alphas_new = alphas;
b_new = b;
updated = 0;
m = size(X,1);
for index2 = 1:m
    u2 = calculate_u(alphas_new, b_new, X, y, X(index2,:));
    if is_kkt(index2, y, u2, alphas_new, C)
        continue;   % only KKT violators are optimised
    end
    E2 = u2 - y(index2);

    % The partner must be chosen afresh for every index2; resetting here
    % prevents a stale index1 from a previous index2 being reused.
    index1 = 0;
    best_diff = 0;
    for candidate = (index2 + 1):m
        u1 = calculate_u(alphas_new, b_new, X, y, X(candidate,:));
        if ~is_kkt(candidate, y, u1, alphas_new, C)
            E1 = u1 - y(candidate);
            dif = abs(E2 - E1);
            if dif > best_diff
                best_diff = dif;
                index1 = candidate;
            end
        end
    end
    if index1 == 0
        % No later violator: fall back to the next index, wrapping around so
        % that index2 == m does not index past the end of alphas.
        index1 = mod(index2, m) + 1;
    end

    [alpha1, alpha2, b_new, ok] = update_alpha_pair(index1, index2, alphas_new, b_new, X, y, C, delta);
    if ok == 1
        updated = updated + 1;
        alphas_new(index1) = alpha1;
        alphas_new(index2) = alpha2;
    end
end
end
function [alpha1, alpha2, b, bool] = update_alpha_pair(index_a1,index_a2, alphas, b_, X, y, C, delta)
% UPDATE_ALPHA_PAIR  One SMO step: jointly optimise the multipliers of
% samples index_a1 and index_a2 (cf. Platt's takeStep).
%   Returns the new alpha1/alpha2, the new bias b and bool = 1 when the pair
%   was updated. On every early exit the original alphas(index_a1),
%   alphas(index_a2) and b_ are returned with bool = 0.
alpha1 = alphas(index_a1);
alpha2 = alphas(index_a2);
b = b_;
bool = 0;
if index_a1 == index_a2
    return;
end

[L, H] = calculateLH(index_a1, index_a2, alphas, y ,C);
if L == H
    return;   % feasible segment is a single point: no progress possible
end

% eta = K11 + K22 - 2*K12, computed with the kernel ker().
eta = calculateLs(index_a1, index_a2, X);
if eta <= 0
    return;   % non-positive curvature is not handled by this simplified SMO
end

x1 = X(index_a1,:);
x2 = X(index_a2,:);
y1 = y(index_a1);
y2 = y(index_a2);
E1 = calculate_u(alphas, b_, X, y, x1) - y1;
E2 = calculate_u(alphas, b_, X, y, x2) - y2;

% Unconstrained optimum for alpha2, clipped to [L, H].
alpha2_new = alphas(index_a2) + y2 * (E1 - E2) / eta;
if alpha2_new > H
    alpha2_new = H;
elseif alpha2_new < L
    alpha2_new = L;
end
if abs(alphas(index_a2) - alpha2_new) < delta
    return;   % change too small to count as an update
end
% Move alpha1 so that y1*alpha1 + y2*alpha2 stays constant.
alpha1_new = alphas(index_a1) + y1 * y2 * (alphas(index_a2) - alpha2_new);

% b1 (b2) is the bias that makes the updated output at x1 (x2) equal its label.
d1 = y1 * (alpha1_new - alphas(index_a1));
d2 = y2 * (alpha2_new - alphas(index_a2));
b1 = b_ - E1 - d1 * ker(x1,x1) - d2 * ker(x1,x2);
b2 = b_ - E2 - d1 * ker(x1,x2) - d2 * ker(x2,x2);
b = update_b(alpha1_new, alpha2_new, b1, b2, C);

alpha1 = alpha1_new;
alpha2 = alpha2_new;
bool = 1;
end

下面这个函数用于计算输出函数u。

function u = calculate_u(alphas, b, X, y, x)
% CALCULATE_U  Decision-function value at x:
%   u = sum_i alphas(i) * y(i) * ker(X(i,:), x) + b
%   alphas, y - m-by-1 multipliers and labels
%   X         - training inputs, one sample per row
%   x         - query point, a row vector with as many columns as X
m = size(X,1);
Xk = zeros(m,1);
for index = 1:m
    % Use the full row so the function works for any input dimension and
    % matches callers that pass full rows (e.g. test_svm).
    Xk(index) = ker(X(index,:), x);
end
u = (alphas .* y)' * Xk + b;
end

下面这个函数用于计算拉格朗日乘子的上下边界。

function [L, H] = calculateLH(index1, index2, alphas, y ,C)
% CALCULATELH  Lower/upper bound [L, H] for the new value of alphas(index2),
% keeping both multipliers inside [0, C] while y1*alpha1 + y2*alpha2 is fixed.
a1 = alphas(index1);
a2 = alphas(index2);
if y(index1) ~= y(index2)
    % Opposite labels: alpha2 - alpha1 is conserved.
    L = max(0, a2 - a1);
    H = min(C, C + a2 - a1);
else
    % Equal labels: alpha2 + alpha1 is conserved.
    L = max(0, a2 + a1 - C);
    H = min(C, a2 + a1);
end
end

下面这个函数利用核函数计算二阶导。

function Ls = calculateLs(index1, index2, X)
% CALCULATELS  eta = K(p,p) + K(q,q) - 2*K(p,q) for rows p = X(index1,:)
% and q = X(index2,:), evaluated with the kernel ker(). This is the
% curvature SMO uses to step along the pair's constraint line.
p = X(index1,:);
q = X(index2,:);
Kpp = ker(p, p);
Kqq = ker(q, q);
Kpq = ker(p, q);
Ls = Kpp + Kqq - 2 * Kpq;
end

下面这个函数更新参数b。

function b = update_b(alpha1, alpha2, b1, b2, C)
% UPDATE_B  Choose the new bias after an SMO step.
%   b1 is used if alpha1 lies strictly inside (0, C), otherwise b2 if alpha2
%   does; when neither does, the midpoint of b1 and b2 is returned.
if alpha1 > 0 && alpha1 < C
    b = b1;
elseif alpha2 > 0 && alpha2 < C
    b = b2;
else
    b = (b1 + b2) / 2;
end
end

下面这个函数用于判断拉格朗日乘子是否满足KKT条件。

function bool = is_kkt(index, y, u, alphas, C, tol)
% IS_KKT  Return 1 if sample `index` satisfies the KKT conditions of the
% soft-margin SVM dual (within tolerance tol), 0 otherwise.
%   With r = y(index)*u - 1 the conditions are
%     alpha == 0      =>  r >= 0
%     0 < alpha < C   =>  r == 0
%     alpha == C      =>  r <= 0
%   so the sample VIOLATES them iff (r < -tol and alpha < C) or
%   (r > tol and alpha > 0). A sample with y*u == 1 exactly therefore
%   satisfies KKT for every alpha in [0, C].
%   tol is optional and defaults to 0 (exact check); Platt's SMO paper
%   uses a tolerance of about 1e-3.
if nargin < 6
    tol = 0;
end
r = y(index) * u - 1;
a = alphas(index);
violated = (r < -tol && a < C) || (r > tol && a > 0);
bool = double(~violated);
end

下面这个函数用于定义核函数。

function kernel = ker(x1, x2)
% KER  Inhomogeneous polynomial kernel of degree 2: (x1*x2' + 1)^2.
%   x1 and x2 are row vectors of equal length.
%   Alternative kept from the original (commented out):
%   kernel = x1 * x2' + (x1.^2) * (x2.^2)' + x1(1) * x1(2) * x2(1) * x2(2);
inner = x1 * x2';
kernel = (inner + 1)^2;
end

下面这个函数用于测试。

function [res1, res2, res3, res4, res5] = test_svm(X, y, alphas, b)
% TEST_SVM  Evaluate the trained decision function (calculate_u) on five sets:
%   res1 - 100 new points near radius 5  (inside the inner circle)
%   res2 - first half of the training set  (radius-10, label +1 class in the main script)
%   res3 - 100 new points near radius 15 (between the two circles)
%   res4 - second half of the training set (radius-20, label -1 class in the main script)
%   res5 - 100 new points near radius 40 (outside the outer circle)
% The linear form P*w' + b is NOT used: the kernel in ker() is polynomial,
% so w = sum(alpha_i*y_i*x_i) does not represent the decision function.
n_test = 100;
m = size(X,1);
half = fix(m/2);   % was hard-coded as 100, i.e. assumed m == 200
[Xt1, ~] = generateData(n_test, 2, 5, 0);
[Xt2, ~] = generateData(n_test, 2, 15, 0);
[Xt3, ~] = generateData(n_test, 2, 40, 0);
% Column vector of u(P(i,:)) for every row of P.
score = @(P) arrayfun(@(i) calculate_u(alphas, b, X, y, P(i,:)), (1:size(P,1))');
res1 = score(Xt1);
res2 = score(X(1:half,:));
res3 = score(Xt2);
res4 = score(X(half+1:m,:));
res5 = score(Xt3);
end

下面这个函数用于生成训练和测试所用的数据。

function [data, labels] = generateData(m, n, r, label)
% GENERATEDATA  Sample m points near a circle of radius r centred at 0.
%   Column 1 is uniform in [-r, r]; column 2 is +/-sqrt(r^2 - x^2) (random
%   sign) plus a uniform perturbation in [0, 1). Any columns beyond the
%   second stay zero. Every label equals `label`.
noise_scale = 1;
labels = label .* ones(m, 1);
data = zeros(m, n);
for k = 1:m
    % rand is drawn in the same order as before: x, then sign, then noise.
    px = 2*r*rand - r;
    side = round(rand)*2 - 1;
    data(k,1) = px;
    data(k,2) = (sqrt(r^2 - px^2)) * side + rand * noise_scale;
end
end

生成的数据大致如下图所示。

参考资料:

Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines, John C. Platt.

支持向量机通俗导论——理解SVM的三层境界，July、pluskid（http://blog.csdn.net/v_july_v/article/details/7624837）


发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表