基于Java實現(xiàn)的一層簡單人工神經網絡算法示例
更新時間:2017年12月08日 11:16:46 作者:土豆拍死馬鈴薯
這篇文章主要介紹了基于Java實現(xiàn)的一層簡單人工神經網絡算法,結合實例形式分析了java實現(xiàn)人工神經網絡的具體實現(xiàn)技巧,需要的朋友可以參考下
本文實例講述了基于Java實現(xiàn)的一層簡單人工神經網絡算法。分享給大家供大家參考,具體如下:
先來看看筆者繪制的算法圖:
2、數(shù)據(jù)類
import java.util.Arrays; public class Data { double[] vector; int dimention; int type; public double[] getVector() { return vector; } public void setVector(double[] vector) { this.vector = vector; } public int getDimention() { return dimention; } public void setDimention(int dimention) { this.dimention = dimention; } public int getType() { return type; } public void setType(int type) { this.type = type; } public Data(double[] vector, int dimention, int type) { super(); this.vector = vector; this.dimention = dimention; this.type = type; } public Data() { } @Override public String toString() { return "Data [vector=" + Arrays.toString(vector) + ", dimention=" + dimention + ", type=" + type + "]"; } }
3、簡單人工神經網絡
package cn.edu.hbut.chenjie; import java.util.ArrayList; import java.util.List; import java.util.Random; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartFrame; import org.jfree.chart.JFreeChart; import org.jfree.data.xy.DefaultXYDataset; import org.jfree.ui.RefineryUtilities; public class ANN2 { private double eta;//學習率 private int n_iter;//權重向量w[]訓練次數(shù) private List<Data> exercise;//訓練數(shù)據(jù)集 private double w0 = 0;//閾值 private double x0 = 1;//固定值 private double[] weights;//權重向量,其長度為訓練數(shù)據(jù)維度+1,在本例中數(shù)據(jù)為2維,故長度為3 private int testSum = 0;//測試數(shù)據(jù)總數(shù) private int error = 0;//錯誤次數(shù) DefaultXYDataset xydataset = new DefaultXYDataset(); /** * 向圖表中增加同類型的數(shù)據(jù) * @param type 類型 * @param a 所有數(shù)據(jù)的第一個分量 * @param b 所有數(shù)據(jù)的第二個分量 */ public void add(String type,double[] a,double[] b) { double[][] data = new double[2][a.length]; for(int i=0;i<a.length;i++) { data[0][i] = a[i]; data[1][i] = b[i]; } xydataset.addSeries(type, data); } /** * 畫圖 */ public void draw() { JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset); ChartFrame frame = new ChartFrame("訓練數(shù)據(jù)", jfreechart); frame.pack(); RefineryUtilities.centerFrameOnScreen(frame); frame.setVisible(true); } public static void main(String[] args) { ANN2 ann2 = new ANN2(0.001,100);//構造人工神經網絡 List<Data> exercise = new ArrayList<Data>();//構造訓練集 //人工模擬1000條訓練數(shù)據(jù) ,分界線為x2=x1+0.5 for(int i=0;i<1000000;i++) { Random rd = new Random(); double x1 = rd.nextDouble();//隨機產生一個分量 double x2 = rd.nextDouble();//隨機產生另一個分量 double[] da = {x1,x2};//產生數(shù)據(jù)向量 Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//構造數(shù)據(jù) exercise.add(d);//將訓練數(shù)據(jù)加入訓練集 } int sum1 = 0;//記錄類型1的訓練記錄數(shù) int sum2 = 0;//記錄類型-1的訓練記錄數(shù) for(int i = 0; i < exercise.size(); i++) { if(exercise.get(i).getType()==1) sum1++; else if(exercise.get(i).getType()==-1) sum2++; } double[] x1 = new double[sum1]; double[] y1 = new double[sum1]; double[] x2 = new double[sum2]; double[] y2 = new double[sum2]; int index1 = 0; int index2 = 0; for(int i = 0; i < exercise.size(); i++) { if(exercise.get(i).getType()==1) { x1[index1] = exercise.get(i).vector[0]; y1[index1++] = exercise.get(i).vector[1]; } else if(exercise.get(i).getType()==-1) { x2[index2] = exercise.get(i).vector[0]; y2[index2++] = exercise.get(i).vector[1]; } } ann2.add("1", x1, y1); ann2.add("-1", x2, y2); ann2.draw(); ann2.input(exercise);//將訓練集輸入人工神經網絡 ann2.fit();//訓練 ann2.showWeigths();//顯示權重向量 //人工生成一千條測試數(shù)據(jù) for(int i=0;i<10000;i++) { Random rd = new Random(); double x1_ = rd.nextDouble(); double x2_ = rd.nextDouble(); double[] da = {x1_,x2_}; Data test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1); ann2.predict(test);//測試 } System.out.println("總共測試" + ann2.testSum + "條數(shù)據(jù),有" + ann2.error + "條錯誤,錯誤率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%"); } /** * * @param eta 學習率 * @param n_iter 權重分量學習次數(shù) */ public ANN2(double eta, int n_iter) { this.eta = eta; this.n_iter = n_iter; } /** * 輸入訓練集到人工神經網絡 * @param exercise */ private void input(List<Data> exercise) { this.exercise = exercise;//保存訓練集 weights = new double[exercise.get(0).dimention + 1];//初始化權重向量,其長度為訓練數(shù)據(jù)維度+1 weights[0] = w0;//權重向量第一個分量為w0 for(int i = 1; i < weights.length; i++) weights[i] = 0;//其余分量初始化為0 } private void fit() { for(int i = 0; i < n_iter; i++)//權重分量調整n_iter次 { for(int j = 0; j < exercise.size(); j++)//對于訓練集中的每條數(shù)據(jù)進行訓練 { int real_result = exercise.get(j).type;//y int calculate_result = CalculateResult(exercise.get(j));//y' double delta0 = eta * (real_result - calculate_result);//計算閾值更新 w0 += delta0;//閾值更新 weights[0] = w0;//更新w[0] for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新權重向量其它分量 { double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k]; //Δw=η*(y-y')*X weights[k+1] += delta; //w=w+Δw } } } } private int CalculateResult(Data data) { double z = w0 * x0; for(int i = 0; i < data.dimention; i++) z += data.vector[i] * weights[i+1]; //z=w0x0+w1x1+...+WmXm //激活函數(shù) if(z>=0) return 1; else return -1; } private void showWeigths() { for(double w : weights) System.out.println(w); } private void predict(Data data) { int type = CalculateResult(data); if(type == data.getType()) { //System.out.println("預測正確"); } else { //System.out.println("預測錯誤"); error ++; } testSum ++; } }
運行結果:
-0.22000000000000017 -0.4416843982815453 0.442444202054685 總共測試10000條數(shù)據(jù),有17條錯誤,錯誤率:0.16999999999999998%
更多關于java算法相關內容感興趣的讀者可查看本站專題:《Java數(shù)據(jù)結構與算法教程》、《Java操作DOM節(jié)點技巧總結》、《Java文件與目錄操作技巧匯總》和《Java緩存操作技巧匯總》
希望本文所述對大家java程序設計有所幫助。
相關文章
Java?for循環(huán)標簽跳轉到指定位置的示例詳解
這篇文章主要介紹了Java?for循環(huán)標簽跳轉到指定位置,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-05-05Spring Boot報錯:No session repository could be auto-configured
這篇文章主要給大家介紹了關于Spring Boot報錯:No session repository could be auto-configured, check your configuration的解決方法,文中給出了詳細的解決方法,對遇到這個問題的朋友們具有一定參考價值,需要的朋友下面來一起看看吧。2017-07-07