
Implementing a Random Forest with Spark

2019-11-26 08:43:50
Source: reposted
Contributed by: a reader

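This article walks through a complete example of training a random forest classifier with Spark MLlib. The code is a JUnit test that loads the LIBSVM sample data, splits it into training and test sets, trains the forest, prints each test label next to its prediction, and finally prints the learned trees: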

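import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

public class RandomForestClassficationTest extends TestCase implements Serializable {

    private static final long serialVersionUID = 7802523720751354318L;

    // True label and predicted label of one test point.
    // Static nested class, so instances do not carry a reference to the test class.
    static class PredictResult implements Serializable {
        private static final long serialVersionUID = -168308887976477219L;
        final double label;
        final double prediction;

        PredictResult(double label, double prediction) {
            this.label = label;
            this.prediction = prediction;
        }

        @Override
        public String toString() {
            return label + " : " + prediction;
        }
    }

    public void test_randomForest() {
        SparkConf sparkConf = new SparkConf()
                .setAppName("RandomForest")
                .setMaster("local");
        SparkContext sc = new SparkContext(sparkConf);
        try {
            // sample_libsvm_data.txt must sit at the root of the test classpath;
            // getResource("/").getPath() already ends with "/".
            String dataPath = RandomForestClassficationTest.class.getResource("/").getPath()
                    + "sample_libsvm_data.txt";

            // Load the LIBSVM file and split it 70% / 30% into training and test data (seed 1)
            JavaRDD<LabeledPoint> dataSet = MLUtils.loadLibSVMFile(sc, dataPath).toJavaRDD();
            JavaRDD<LabeledPoint>[] rddList = dataSet.randomSplit(new double[]{0.7, 0.3}, 1L);
            JavaRDD<LabeledPoint> trainingData = rddList[0];
            JavaRDD<LabeledPoint> testData = rddList[1];

            // Training parameters:
            // numClasses: number of classes (2, binary classification)
            // categoricalFeatureInfos: empty map, so every feature is treated as continuous
            // numTrees: number of trees in the forest
            // featureSubsetStrategy: how many features are considered at each split ("auto" lets Spark decide)
            // impurity: split criterion ("gini" for classification)
            // maxDepth: maximum depth of each tree
            // maxBins: maximum number of bins used to discretize continuous features
            // seed: random seed for bootstrapping and feature sampling
            int numClasses = 2;
            Map<Integer, Integer> categoricalFeatureInfos = new HashMap<>();
            int numTrees = 3;
            String featureSubsetStrategy = "auto";
            String impurity = "gini";
            int maxDepth = 4;
            int maxBins = 32;
            int seed = 1;

            final RandomForestModel model = RandomForest.trainClassifier(trainingData,
                    numClasses, categoricalFeatureInfos, numTrees, featureSubsetStrategy,
                    impurity, maxDepth, maxBins, seed);

            // Predict every test point and pair the prediction with the true label
            JavaRDD<PredictResult> predictRddResult = testData.map(
                    new Function<LabeledPoint, PredictResult>() {
                        private static final long serialVersionUID = 1L;

                        @Override
                        public PredictResult call(LabeledPoint point) {
                            double prediction = model.predict(point.features());
                            return new PredictResult(point.label(), prediction);
                        }
                    });

            List<PredictResult> predictResultList = predictRddResult.collect();
            for (PredictResult result : predictResultList) {
                System.out.println(result);
            }

            // Print the structure of every tree in the forest
            System.out.println(model.toDebugString());
        } finally {
            sc.stop();
        }
    }
}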
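The code passes featureSubsetStrategy = "auto" without saying what it resolves to. As documented for spark.mllib, "auto" means "all" for a single tree, "sqrt" for a classification forest and "onethird" for a regression forest. The plain-Java sketch below (no Spark needed) reproduces that rule and the per-node feature count for the named strategies. The class FeatureSubsetSketch and its methods are hypothetical helpers written for illustration, not Spark API, and the value 692 assumes the feature count of Spark's bundled sample_libsvm_data.txt.

```java
public class FeatureSubsetSketch {

    // Hypothetical helper: resolves "auto" following the spark.mllib documentation
    // (one tree -> "all", forest classification -> "sqrt", forest regression -> "onethird").
    static String resolveAuto(String strategy, int numTrees, boolean classification) {
        if (!strategy.equals("auto")) {
            return strategy;
        }
        if (numTrees == 1) {
            return "all";
        }
        return classification ? "sqrt" : "onethird";
    }

    // Hypothetical helper: number of candidate features examined at each tree node.
    static int featuresPerNode(String strategy, int numFeatures) {
        switch (strategy) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                throw new IllegalArgumentException("unknown strategy: " + strategy);
        }
    }

    public static void main(String[] args) {
        // numTrees = 3 and classification, as in the test above
        String resolved = resolveAuto("auto", 3, true);
        // Assumption: 692 features, as in sample_libsvm_data.txt
        int k = featuresPerNode(resolved, 692);
        System.out.println(resolved + " -> " + k); // prints "sqrt -> 27"
    }
}
```

With three trees, each split in this example therefore looks at a random subset of about 27 of the 692 features, which is what decorrelates the trees of the forest.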

The model printed by model.toDebugString() looks like this:

TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 435 <= 0.0)
     If (feature 516 <= 0.0)
      Predict: 0.0
     Else (feature 516 > 0.0)
      Predict: 1.0
    Else (feature 435 > 0.0)
     Predict: 1.0
  Tree 1:
    If (feature 512 <= 0.0)
     Predict: 1.0
    Else (feature 512 > 0.0)
     Predict: 0.0
  Tree 2:
    If (feature 377 <= 1.0)
     Predict: 0.0
    Else (feature 377 > 1.0)
     If (feature 455 <= 0.0)
      Predict: 1.0
     Else (feature 455 > 0.0)
      Predict: 0.0
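For a classification forest, spark.mllib combines the trees by majority vote: each tree predicts a label and the label with the most votes wins. The hypothetical plain-Java sketch below hard-codes the three trees exactly as printed and applies that vote, to show how model.predict arrives at an answer. It is an illustration only, not Spark's implementation; the class ForestVoteSketch is invented here, and features are stored sparsely with absent indices read as 0.0, as in the LIBSVM format.

```java
import java.util.HashMap;
import java.util.Map;

public class ForestVoteSketch {

    // Sparse lookup: indices missing from the map are 0.0, as in LIBSVM data.
    static double f(Map<Integer, Double> x, int index) {
        return x.getOrDefault(index, 0.0);
    }

    // Tree 0 as printed by toDebugString()
    static double tree0(Map<Integer, Double> x) {
        if (f(x, 435) <= 0.0) {
            return f(x, 516) <= 0.0 ? 0.0 : 1.0;
        }
        return 1.0;
    }

    // Tree 1 as printed by toDebugString()
    static double tree1(Map<Integer, Double> x) {
        return f(x, 512) <= 0.0 ? 1.0 : 0.0;
    }

    // Tree 2 as printed by toDebugString()
    static double tree2(Map<Integer, Double> x) {
        if (f(x, 377) <= 1.0) {
            return 0.0;
        }
        return f(x, 455) <= 0.0 ? 1.0 : 0.0;
    }

    // Majority vote over the three trees for binary labels 0.0 / 1.0
    static double predict(Map<Integer, Double> x) {
        double[] votes = {tree0(x), tree1(x), tree2(x)};
        int ones = 0;
        for (double v : votes) {
            if (v == 1.0) {
                ones++;
            }
        }
        return ones * 2 > votes.length ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        // All features zero: votes are 0, 1, 0
        System.out.println(predict(new HashMap<>())); // prints 0.0

        // Feature 435 set: votes are 1, 1, 0
        Map<Integer, Double> x = new HashMap<>();
        x.put(435, 10.0);
        System.out.println(predict(x)); // prints 1.0
    }
}
```

With three trees and two classes a tie cannot occur, so the simple "more than half" rule is enough here.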

