
Training the Naive Bayes Model in Mahout


1. Generating the Model

Training on the sample set is handled by the class org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob. Its input is the sequence files produced by vectorizing the corpus with mahout seqdirectory and mahout seq2sparse, and its output is a model (the trained classifier). The concrete flow:

(1) Sum the tf-idf vectors that share the same label

This is done by the TrainNaiveBayesJob-IndexInstancesMapper-Reducer job. The MR job reads a cached file, labelindex, which stores the id assigned to each label. Its output lands in the temporary directory summedObservations. In effect, this step computes, for each label, the score of every feature.

(2) Sum the vectors from the previous step over all labels and over all features

This is done by the TrainNaiveBayesJob-WeightsMapper-Reducer job, whose output goes to the weights directory under tmp. Adding up all vectors from (1) gives weightsPerFeature, a weight vector over features; adding up the elements of each label's vector from (1) gives weightsPerLabel, a weight vector over labels. In effect, this step computes each label's total score and each feature's score across all labels.

(3) Build the model

A new Matrix, scoresPerLabelAndFeature, is created: a two-dimensional matrix with as many rows as weightsPerLabel has elements and as many columns as weightsPerFeature. It is filled with the per-label, per-feature scores produced in (1), so scoresPerLabelAndFeature holds the score of every feature under every label. Together with weightsPerFeature and weightsPerLabel from (2), this suffices to construct a NaiveBayesModel, which is written to model/naiveBayesModel.bin on Hadoop. A compact in-memory sketch of these three steps follows.
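Here is that sketch, a hedged illustration using Mahout's math classes on a toy corpus; the two labels, three features, all numbers, and the class name are invented for illustration and are not part of the actual MapReduce jobs:

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

// Toy corpus: 2 labels (ids 0 and 1), 3 features; values stand in for tf-idf scores.
public class TrainingStepsSketch {
  public static void main(String[] args) {
    // Step (1): sum the tf-idf vectors that share a label
    // (logically what IndexInstancesMapper + VectorSumReducer produce per label id).
    Vector[] docsLabel0 = {new DenseVector(new double[]{1.0, 0.5, 0.0}),
                           new DenseVector(new double[]{0.5, 0.5, 0.0})};
    Vector[] docsLabel1 = {new DenseVector(new double[]{0.0, 1.0, 2.0})};
    Vector summed0 = new DenseVector(3);
    for (Vector doc : docsLabel0) {
      summed0 = summed0.plus(doc);
    }
    Vector summed1 = new DenseVector(3);
    for (Vector doc : docsLabel1) {
      summed1 = summed1.plus(doc);
    }

    // Step (2): weightsPerFeature is the element-wise sum over all labels;
    // weightsPerLabel[label] is the sum of that label's vector elements.
    Vector weightsPerFeature = summed0.plus(summed1);   // [1.5, 2.0, 2.0]
    Vector weightsPerLabel = new DenseVector(2);
    weightsPerLabel.set(0, summed0.zSum());             // 2.5
    weightsPerLabel.set(1, summed1.zSum());             // 3.0

    // Step (3): rows = labels, columns = features; each label's summed vector
    // becomes one row of scoresPerLabelAndFeature.
    Matrix scoresPerLabelAndFeature = new DenseMatrix(2, 3);
    scoresPerLabelAndFeature.assignRow(0, summed0);
    scoresPerLabelAndFeature.assignRow(1, summed1);

    System.out.println("weightsPerFeature = " + weightsPerFeature);
    System.out.println("weightsPerLabel   = " + weightsPerLabel);
  }
}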

 

public final class TrainNaiveBayesJob extends AbstractJob {

  private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
  private static final String ALPHA_I = "alphaI";
  private static final String LABEL_INDEX = "labelIndex";
  private static final String EXTRACT_LABELS = "extractLabels";
  private static final String LABELS = "labels";
  public static final String WEIGHTS_PER_FEATURE = "__SPF";
  public static final String WEIGHTS_PER_LABEL = "__SPL";
  public static final String LABEL_THETA_NORMALIZER = "_LTN";

  public static final String SUMMED_OBSERVATIONS = "summedObservations";
  public static final String WEIGHTS = "weights";
  public static final String THETAS = "thetas";

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
  }

  @Override
  public int run(String[] args) throws Exception {
    // (option registration and argument parsing are elided in this excerpt)

    Path labPath;
    String labPathStr = getOption(LABEL_INDEX);
    if (labPathStr != null) {
      labPath = new Path(labPathStr);
    } else {
      labPath = getTempPath(LABEL_INDEX);
    }

    // Build the label index, i.e. assign a numeric id to every label
    long labelSize = createLabelIndex(labPath);
    float alphaI = Float.parseFloat(getOption(ALPHA_I));
    boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY);

    HadoopUtil.setSerializations(getConf());
    // Put the file mapping labels to label ids into the distributed cache
    HadoopUtil.cacheFiles(labPath, getConf());

    // Using the label index, replace the key of each input (key, value) pair
    // with the label's id, then sum the vectors that share the same label id
    Job indexInstances = prepareJob(getInputPath(),
                                    getTempPath(SUMMED_OBSERVATIONS),
                                    SequenceFileInputFormat.class,
                                    IndexInstancesMapper.class,
                                    IntWritable.class,
                                    VectorWritable.class,
                                    VectorSumReducer.class,
                                    IntWritable.class,
                                    VectorWritable.class,
                                    SequenceFileOutputFormat.class);
    indexInstances.setCombinerClass(VectorSumReducer.class);
    boolean succeeded = indexInstances.waitForCompletion(true);
    if (!succeeded) {
      return -1;
    }

    // Sum the results above and emit them under the key __SPF; also sum the
    // elements of each label's vector and collect the totals into a vector
    // keyed __SPL, where each label's total sits at the index of its label id
    Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
                                  getTempPath(WEIGHTS),
                                  SequenceFileInputFormat.class,
                                  WeightsMapper.class,
                                  Text.class,
                                  VectorWritable.class,
                                  VectorSumReducer.class,
                                  Text.class,
                                  VectorWritable.class,
                                  SequenceFileOutputFormat.class);
    weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
    weightSummer.setCombinerClass(VectorSumReducer.class);
    succeeded = weightSummer.waitForCompletion(true);
    if (!succeeded) {
      return -1;
    }

    // Put the weights produced above into the distributed cache
    HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());

    // Store the alphaI parameter in the job configuration
    getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);

    // Build the naive Bayes model
    NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
    naiveBayesModel.validate();
    // Write the model to a file
    naiveBayesModel.serialize(getOutputPath(), getConf());

    return 0;
  }

  private long createLabelIndex(Path labPath) throws IOException {
    long labelSize = 0;
    if (hasOption(LABELS)) {
      Iterable<String> labels = Splitter.on(",").split(getOption(LABELS));
      labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
    } else if (hasOption(EXTRACT_LABELS)) {
      Iterable<Pair<Text,IntWritable>> iterable =
          new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
                                                         PathType.LIST,
                                                         PathFilters.logsCRCFilter(),
                                                         getConf());
      // Write the label-to-id mapping to a file and return the number of labels
      labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
    }
    return labelSize;
  }
}
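For reference, a hedged sketch of launching the trainer programmatically. The long flags mirror the option names registered in the code above plus AbstractJob's standard --input/--output; every path here is made up:

// Hypothetical invocation; adjust paths to your own cluster layout.
String[] trainArgs = {
    "--input", "/tmp/news-vectors/tfidf-vectors", // output of mahout seq2sparse
    "--output", "/tmp/model",                     // naiveBayesModel.bin is written here
    "--extractLabels",                            // derive labels from the input keys
    "--labelIndex", "/tmp/labelindex",            // where the label-to-id map goes
    "--alphaI", "1.0"                             // smoothing parameter
};
ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), trainArgs);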

 

1.1. BayesUtils.readModelFromDir

// Read the results of the preceding MR jobs from the base path
// and initialize the naive Bayes model
public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {

    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);

    Vector scoresPerLabel = null;
    Vector scoresPerFeature = null;
    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
      String key = record.getFirst().toString();
      VectorWritable value = record.getSecond();
      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
        scoresPerFeature = value.get();
      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
        scoresPerLabel = value.get();
      }
    }

    Preconditions.checkNotNull(scoresPerFeature);
    Preconditions.checkNotNull(scoresPerLabel);

    // Rows are labels, columns are features; each label's summed tf-idf
    // vector from summedObservations fills that label's row
    Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
    }

    Vector perlabelThetaNormalizer = scoresPerLabel.like();
    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perlabelThetaNormalizer,
        alphaI);
  }

1.2. naiveBayesModel.serialize

  public void serialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);

    // Write the model to output/naiveBayesModel.bin
    FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
    try {
      out.writeFloat(alphaI);
      VectorWritable.writeVector(out, weightsPerFeature);
      VectorWritable.writeVector(out, weightsPerLabel);
      VectorWritable.writeVector(out, perlabelThetaNormalizer);
      // One row per label: that label's per-feature weights
      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
        VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
      }
    } finally {
      Closeables.close(out, false);
    }
  }
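Reading the model back simply mirrors this write order. Mahout's own reader is NaiveBayesModel.materialize(Path, Configuration), used by the test driver in section 2; the sketch below (the method name is ours) only illustrates the file layout implied by serialize:

  // Sketch of the matching read, mirroring the write order above; Mahout's
  // actual implementation is NaiveBayesModel.materialize(Path, Configuration).
  public static NaiveBayesModel deserialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);
    FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
    try {
      float alphaI = in.readFloat();
      Vector weightsPerFeature = VectorWritable.readVector(in);
      Vector weightsPerLabel = VectorWritable.readVector(in);
      Vector perLabelThetaNormalizer = VectorWritable.readVector(in);
      // One row per label, in label-id order, exactly as written
      Matrix weightsPerLabelAndFeature = new SparseMatrix(weightsPerLabel.size(), weightsPerFeature.size());
      for (int label = 0; label < weightsPerLabel.size(); label++) {
        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
      }
      return new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
          perLabelThetaNormalizer, alphaI);
    } finally {
      Closeables.close(in, true);
    }
  }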

2. Testing the Model

The trained model is then used to classify test data, via org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver. This stage runs a single MR job, TestNaiveBayesDriver-BayesTestMapper-Reducer, whose output key is the expected (true) label and whose value is the vector of scores under each label. The driver first reads back the model generated above and builds a NaiveBayesModel object, after which the score of the vector being classified can be computed under every label. Because of the naive Bayes feature-independence assumption, scoring a vector under a label amounts to summing the scores of its features under that label (the per-feature scores are logarithms, so multiplying probabilities becomes adding scores). Depending on whether the classifier is a ComplementaryNaiveBayesClassifier or a StandardNaiveBayesClassifier, the per-feature score under a label is computed with a different formula; both are given below, followed by a short code transcription:

(1) StandardNaiveBayesClassifier (Bayes):

Weight = Log [ ( W-N-Tf-Idf + alpha_i ) / ( Sigma_k + N ) ]

where

featureLabelWeight is W-N-Tf-Idf

labelWeight is Sigma_k

numFeatures is N

(2) ComplementaryNaiveBayesClassifier (CBayes):

Weight = Log [ ( Sigma_j - W-N-Tf-Idf + alpha_i ) / ( Sigma_jSigma_k - Sigma_k + N ) ]

where

featureWeight is Sigma_j

featureLabelWeight is W-N-Tf-Idf

labelWeight is Sigma_k

totalWeight is Sigma_jSigma_k, the sum of the entries of labelWeight

numFeatures is N
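Here is that transcription, using the variable names defined above; it illustrates the formulas as stated in this article and is not a copy of Mahout's classifier source:

// Transcription of the two per-feature scores defined above; the names follow
// this article's definitions, and the method names are ours, not Mahout's.
public final class WeightFormulas {

  // Bayes: Weight = log[(W-N-Tf-Idf + alpha_i) / (Sigma_k + N)]
  public static double standardWeight(double featureLabelWeight, double labelWeight,
                                      double alphaI, double numFeatures) {
    return Math.log((featureLabelWeight + alphaI) / (labelWeight + numFeatures));
  }

  // CBayes: Weight = log[(Sigma_j - W-N-Tf-Idf + alpha_i)
  //                      / (Sigma_jSigma_k - Sigma_k + N)]
  public static double complementaryWeight(double featureWeight, double featureLabelWeight,
                                           double labelWeight, double totalWeight,
                                           double alphaI, double numFeatures) {
    return Math.log((featureWeight - featureLabelWeight + alphaI)
        / (totalWeight - labelWeight + numFeatures));
  }

  private WeightFormulas() { }
}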

 

public class TestNaiveBayesDriver extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);

  public static final String COMPLEMENTARY = "class"; // b for bayes, c for complementary
  private static final Pattern SLASH = Pattern.compile("/");

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
  }

  @Override
  public int run(String[] args) throws Exception {
    // (option registration is elided in this excerpt; parsedArgs comes from
    // AbstractJob.parseArguments and is needed by runMapReduce below)
    Map<String, List<String>> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    boolean complementary = hasOption("testComplementary");
    boolean sequential = hasOption("runSequential");
    // sequential == true: classify serially in-process
    // otherwise: classify in parallel with MapReduce
    if (sequential) {
      FileSystem fs = FileSystem.get(getConf());
      NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
      AbstractNaiveBayesClassifier classifier;
      // Choose between the ComplementaryNaiveBayesClassifier
      // and the StandardNaiveBayesClassifier
      if (complementary) {
        classifier = new ComplementaryNaiveBayesClassifier(model);
      } else {
        classifier = new StandardNaiveBayesClassifier(model);
      }
      SequenceFile.Writer writer =
          new SequenceFile.Writer(fs, getConf(), getOutputPath(), Text.class, VectorWritable.class);
      Reader reader = new Reader(fs, getInputPath(), getConf());
      Text key = new Text();
      VectorWritable vw = new VectorWritable();
      while (reader.next(key, vw)) {
        // The key is the true label; the value is the vector of scores under each label
        writer.append(new Text(SLASH.split(key.toString())[1]),
            new VectorWritable(classifier.classifyFull(vw.get())));
      }
      writer.close();
      reader.close();
    } else {
      boolean succeeded = runMapReduce(parsedArgs);
      if (!succeeded) {
        return -1;
      }
    }

    // Load the label-to-id mapping
    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));

    // Loop over the results and create the confusion matrix
    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
        new SequenceFileDirIterable<Text, VectorWritable>(getOutputPath(),
                                                          PathType.LIST,
                                                          PathFilters.partFilter(),
                                                          getConf());
    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
    analyzeResults(labelMap, dirIterable, analyzer);

    log.info("{} Results: {}", complementary ? "Complementary" : "Standard NB", analyzer);
    return 0;
  }

  // Parallel scoring
  private boolean runMapReduce(Map<String, List<String>> parsedArgs) throws IOException,
      InterruptedException, ClassNotFoundException {
    Path model = new Path(getOption("model"));
    HadoopUtil.cacheFiles(model, getConf());
    // The output key is the expected value; the output value holds the scores for all the labels
    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
            Text.class, VectorWritable.class, SequenceFileOutputFormat.class);

    boolean complementary = hasOption("testComplementary"); // or: complementary = parsedArgs.containsKey("--testComplementary");
    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
    return testJob.waitForCompletion(true);
  }

  private static void analyzeResults(Map<Integer, String> labelMap,
                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
                                     ResultAnalyzer analyzer) {
    // The label with the highest score becomes the predicted class
    for (Pair<Text, VectorWritable> pair : dirIterable) {
      int bestIdx = Integer.MIN_VALUE;
      double bestScore = Long.MIN_VALUE;
      for (Vector.Element element : pair.getSecond().get().all()) {
        if (element.get() > bestScore) {
          bestScore = element.get();
          bestIdx = element.index();
        }
      }
      if (bestIdx != Integer.MIN_VALUE) {
        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
      }
    }
  }
}
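To close the loop, here is a minimal sketch of classifying a single vector sequentially; the model path, the toy instance, and the class name are hypothetical. This is also, in essence, what BayesTestMapper computes per record in the parallel path:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class ClassifySketch {
  public static void main(String[] args) throws Exception {
    // Hypothetical model path; materialize reads naiveBayesModel.bin back into memory.
    Configuration conf = new Configuration();
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path("/tmp/model"), conf);
    AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // A toy tf-idf vector; in practice it must use the same dictionary as the training data.
    Vector instance = new DenseVector(new double[]{0.0, 1.2, 0.4});

    // classifyFull returns one score per label; the arg max is the prediction.
    Vector scores = classifier.classifyFull(instance);
    int best = scores.maxValueIndex();
    System.out.println("predicted label id = " + best + ", score = " + scores.get(best));
  }
}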

 

