这学期在学模式识别,老师布置作业让实现一些比较基础简单的算法

下面是感知器算法的实现过程

感知器算法是线性分类器中一个比较基础但是比较重要的算法

W为权向量,g(x)为线性判别函数

通过对W的调整,可实现判别函数g(x) =WTX > RT

其中RT为响应阈值

定义感知准则函数:只考虑错分样本

定义:其中x0为错分样本

当分类发生错误时就有WTX <0,或-WTX >0, 所以J(W)

总是正值,错误分类愈少, J(W)就愈小。

理想情况为即求最小值的问题。

感知器算法:

1.错误分类修正wk

w(k)Tx≤0并且x∈ω1

w(k+1)=

w(k)+ρ

(k)x

w(k)Tx≥0并且x∈ω2

w(k+1)=

w(k)-ρ

(k)x

2.正确分类 ,wk不修正

如w(k)Tx>0并且x∈ω1

如w(k)Tx<0并且x∈ω2     w(k+1)= w(k)

换句话说感知器算法也就是对于正确分类的样本w(k+1)=w(k)而对于错误分类的样本就进行校正

程序如下:

说明:

其中w1是一类,w2是一类,因为感知器算法在模式线性可分的情况下才能收敛得到最后的权向量,所以在次我在两个区域随机生成了点,为了其可线性可分

/*

感知器算法

分为两类

*/

#define N1 30

#define N2 40

#include

#include

#include

#include

//模式结构体

struct Pattern {

//x1和x2表示特征

int x1;

int x2;

//c代表添加的量

int c;

};

//定义w权矢量

Pattern w;

//定义w1和w2数组

Pattern w1[N1],w2[N2];

void init()

{

glClearColor(1.0f , 1.0f , 1.0f , 1.0f);

glClear(GL_COLOR_BUFFER_BIT);

//glEnable(GL_POINT_SMOOTH);

}

void display()

{

float temp1,temp2;

if(w.x2==0)

{

temp1 = (float)(-w.c)/(float)(w.x1);

glColor3f(1.0f,0.0f,0.0f);

glEnable(GL_POINT_SMOOTH);

glBegin(GL_LINES);

glVertex2f(temp1,-1);

glVertex2f(temp1,20);

glEnd();

glFlush();

}

else

{

temp1 = (float)(-w.c+2*w.x1)/(float)w.x2;

temp2 = (float)(-w.c-20*w.x1)/(float)w.x2;

glColor3f(1.0f,0.0f,0.0f);

glEnable(GL_POINT_SMOOTH);

glBegin(GL_LINES);

glVertex2f(-2,temp1);

glVertex2f(20,temp2);

glEnd();

glFlush();

}

glPointSize(10.0);

glBegin(GL_POINTS);

int i,j;

for(i=0;i

{

glColor3f(1.0,0.0,0.0);

glVertex2f(w1[i].x1,w1[i].x2);

}

for(j=0;j

{

glColor3f(0.0,0.0,1.0);

glVertex2f(w2[j].x1,w2[j].x2);

}

glEnd();

glFlush();

}

int main(int argc, char **argv)

{

int i,j,count=0;

w.x1 = 1;

w.x2 = 1;

w.c = 1;

//产生随机数的种子

printf("这是第一类中的点\n");

srand((unsigned)time(NULL));

for(i=0;i

{

w1[i].x1=rand()%10;

w1[i].x2=rand()%10;

w1[i].c=1;

//打印随机产生的点

printf("%d %d\n",w1[i].x1,w1[i].x2);

}

printf("这是第二类中的点\n");

srand((unsigned)time(NULL));

for(j=0;j

{

w2[j].x1=10+rand()%10;

w2[j].x2=10+rand()%10;

w2[j].c=1;

//打印随机产生的点

printf("%d %d\n",w2[j].x1,w2[j].x2);

}

//感知器迭代判断

while(1)

{

for(i=0;i

{

if(count>=N1+N2)

{

break;

}

if(((w1[i].x1*w.x1)+(w1[i].x2*w.x2)+(w1[i].c*w.c))>0)

{

count++;

}

else

{

count=0;

w.x1=w.x1+w1[i].x1;

w.x2=w.x2+w1[i].x2;

w.c=w.c+w1[i].c;

}

}

for(j=0;j

{

if(count>=N1+N2)

{

break;

}

if((w2[j].x1*w.x1+w2[j].x2*w.x2+w2[j].c*w.c)<0)

{

count++;

}

else

{

count=0;

w.x1=w.x1-w2[j].x1;

w.x2=w.x2-w2[j].x2;

w.c=w.c-w2[j].c;

}

}

if(count==N1+N2)

{

break;

}

}

printf("线性分类器的权矢量的值");

printf("%2d,%2d,%2d\n",w.x1,w.x2,w.c);

glutInit(&argc, argv); //初始化glut

glutInitDisplayMode(GLUT_SINGLE | GLUT_RGB); //设置显示属性为RGB颜色, 单缓冲

glutInitWindowSize(800,550); //设置窗口大小

glutInitWindowPosition(200,100); //设置窗口位置

glutCreateWindow("线性分类器的设计"); //生成窗口

init();

gluOrtho2D(-1,20,-1,20);

glutDisplayFunc(&display); //显示回调

glutMainLoop();

//getchar();

//getchar();

system("pause");

return 0;

}程序结果:

0818b9ca8b590ca3270a3433284dd417.png

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐