首页 > 编程 > Java > 正文

基于Java实现的一层简单人工神经网络算法示例

2019-11-26 10:44:09
字体:
来源:转载
供稿:网友

本文实例讲述了基于Java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下:

先来看看笔者绘制的算法图:

2、数据类

import java.util.Arrays;public class Data {  double[] vector;  int dimention;  int type;  public double[] getVector() {    return vector;  }  public void setVector(double[] vector) {    this.vector = vector;  }  public int getDimention() {    return dimention;  }  public void setDimention(int dimention) {    this.dimention = dimention;  }  public int getType() {    return type;  }  public void setType(int type) {    this.type = type;  }  public Data(double[] vector, int dimention, int type) {    super();    this.vector = vector;    this.dimention = dimention;    this.type = type;  }  public Data() {  }  @Override  public String toString() {    return "Data [vector=" + Arrays.toString(vector) + ", dimention=" + dimention + ", type=" + type + "]";  }}

3、简单人工神经网络

package cn.edu.hbut.chenjie;import java.util.ArrayList;import java.util.List;import java.util.Random;import org.jfree.chart.ChartFactory;import org.jfree.chart.ChartFrame;import org.jfree.chart.JFreeChart;import org.jfree.data.xy.DefaultXYDataset;import org.jfree.ui.RefineryUtilities;public class ANN2 {  private double eta;//学习率  private int n_iter;//权重向量w[]训练次数  private List<Data> exercise;//训练数据集  private double w0 = 0;//阈值  private double x0 = 1;//固定值  private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3  private int testSum = 0;//测试数据总数  private int error = 0;//错误次数  DefaultXYDataset xydataset = new DefaultXYDataset();  /**   * 向图表中增加同类型的数据   * @param type 类型   * @param a 所有数据的第一个分量   * @param b 所有数据的第二个分量   */  public void add(String type,double[] a,double[] b)  {    double[][] data = new double[2][a.length];    for(int i=0;i<a.length;i++)    {      data[0][i] = a[i];      data[1][i] = b[i];    }    xydataset.addSeries(type, data);  }  /**   * 画图   */  public void draw()  {    JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);    ChartFrame frame = new ChartFrame("训练数据", jfreechart);    frame.pack();    RefineryUtilities.centerFrameOnScreen(frame);    frame.setVisible(true);  }  public static void main(String[] args)  {    ANN2 ann2 = new ANN2(0.001,100);//构造人工神经网络    List<Data> exercise = new ArrayList<Data>();//构造训练集    //人工模拟1000条训练数据 ,分界线为x2=x1+0.5    for(int i=0;i<1000000;i++)    {      Random rd = new Random();      double x1 = rd.nextDouble();//随机产生一个分量      double x2 = rd.nextDouble();//随机产生另一个分量      double[] da = {x1,x2};//产生数据向量      Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据      exercise.add(d);//将训练数据加入训练集    }    int sum1 = 0;//记录类型1的训练记录数    int sum2 = 0;//记录类型-1的训练记录数    for(int i = 0; i < exercise.size(); i++)    {      if(exercise.get(i).getType()==1)        sum1++;      else if(exercise.get(i).getType()==-1)        sum2++;    }    double[] x1 = new double[sum1];    double[] y1 = new double[sum1];    double[] x2 = new double[sum2];    double[] y2 = new double[sum2];    int index1 = 0;    int index2 = 0;    for(int i = 0; i < exercise.size(); i++)    {      if(exercise.get(i).getType()==1)      {        x1[index1] = exercise.get(i).vector[0];        y1[index1++] = exercise.get(i).vector[1];      }      else if(exercise.get(i).getType()==-1)      {        x2[index2] = exercise.get(i).vector[0];        y2[index2++] = exercise.get(i).vector[1];      }    }    ann2.add("1", x1, y1);    ann2.add("-1", x2, y2);    ann2.draw();    ann2.input(exercise);//将训练集输入人工神经网络    ann2.fit();//训练    ann2.showWeigths();//显示权重向量    //人工生成一千条测试数据    for(int i=0;i<10000;i++)    {      Random rd = new Random();      double x1_ = rd.nextDouble();      double x2_ = rd.nextDouble();      double[] da = {x1_,x2_};      Data test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1);      ann2.predict(test);//测试    }    System.out.println("总共测试" + ann2.testSum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%");  }  /**   *   * @param eta 学习率   * @param n_iter 权重分量学习次数   */  public ANN2(double eta, int n_iter) {    this.eta = eta;    this.n_iter = n_iter;  }  /**   * 输入训练集到人工神经网络   * @param exercise   */  private void input(List<Data> exercise) {    this.exercise = exercise;//保存训练集    weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1    weights[0] = w0;//权重向量第一个分量为w0    for(int i = 1; i < weights.length; i++)      weights[i] = 0;//其余分量初始化为0  }  private void fit() {    for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次    {      for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练      {        int real_result = exercise.get(j).type;//y        int calculate_result = CalculateResult(exercise.get(j));//y'        double delta0 = eta * (real_result - calculate_result);//计算阈值更新        w0 += delta0;//阈值更新        weights[0] = w0;//更新w[0]        for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新权重向量其它分量        {          double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];          //Δw=η*(y-y')*X          weights[k+1] += delta;          //w=w+Δw        }      }    }  }  private int CalculateResult(Data data) {    double z = w0 * x0;    for(int i = 0; i < data.dimention; i++)      z += data.vector[i] * weights[i+1];    //z=w0x0+w1x1+...+WmXm    //激活函数    if(z>=0)      return 1;    else      return -1;  }  private void showWeigths()  {    for(double w : weights)      System.out.println(w);  }  private void predict(Data data) {    int type = CalculateResult(data);    if(type == data.getType())    {      //System.out.println("预测正确");    }    else    {      //System.out.println("预测错误");      error ++;    }    testSum ++;  }}

运行结果:

-0.22000000000000017-0.44168439828154530.442444202054685总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

更多关于java算法相关内容感兴趣的读者可查看本站专题:《Java数据结构与算法教程》、《Java操作DOM节点技巧总结》、《Java文件与目录操作技巧汇总》和《Java缓存操作技巧汇总

希望本文所述对大家java程序设计有所帮助。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表