
Source Code Analysis of Data Splitting in Mahout Bayes


1.    org.apache.mahout.utils.SplitInput

// Performs the data split; both a serial (local) and a parallel (MapReduce) mode are provided

public class SplitInput extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(SplitInput.class);

  // absolute size of the test set
  private int testSplitSize = -1;
  // percentage of the data used for the test set
  private int testSplitPct = -1;
  private int splitLocation = 100;
  private int testRandomSelectionSize = -1;
  private int testRandomSelectionPct = -1;
  // percentage of the input kept for splitting (the rest is discarded)
  private int keepPct = 100;
  private Charset charset = Charsets.UTF_8;

  // whether the input is read in SequenceFile format
  private boolean useSequence;
  // whether the split is done in parallel with a MapReduce job
  private boolean useMapRed;
  // input directory
  private Path inputDirectory;
  // output directory for the training set (serial mode only)
  private Path trainingOutputDirectory;
  // output directory for the test set (serial mode only)
  private Path testOutputDirectory;
  // output directory of the parallel (MapReduce) split
  private Path mapRedOutputDirectory;

  private SplitCallback callback;

 

  @Override
  public int run(String[] args) throws Exception {

    splitDirectory();

    return 0;
  }

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new SplitInput(), args);
  }

  public void splitDirectory() throws IOException, ClassNotFoundException, InterruptedException {
    this.splitDirectory(inputDirectory);
  }

  public void splitDirectory(Path inputDir) throws IOException, ClassNotFoundException, InterruptedException {
    Configuration conf = getConf();
    splitDirectory(conf, inputDir);
  }

  public void splitDirectory(Configuration conf, Path inputDir)
    throws IOException, ClassNotFoundException, InterruptedException {
    FileSystem fs = inputDir.getFileSystem(conf);
    if (fs.getFileStatus(inputDir) == null) {
      throw new IOException(inputDir + " does not exist");
    }
    if (!fs.getFileStatus(inputDir).isDir()) {
      throw new IOException(inputDir + " is not a directory");
    }
    // choose between the serial and the parallel split
    if (useMapRed) {
      // parallel split via a MapReduce job
      SplitInputJob.run(conf, inputDir, mapRedOutputDirectory,
            keepPct, testRandomSelectionPct);
    } else {
      // list all files in the input directory
      FileStatus[] fileStats = fs.listStatus(inputDir, PathFilters.logsCRCFilter());
      // split each file in turn
      for (FileStatus inputFile : fileStats) {
        if (!inputFile.isDir()) {
          splitFile(inputFile.getPath());
        }
      }
    }
  }

 

  public void splitFile(Path inputFile) throws IOException {
    Configuration conf = getConf();
    FileSystem fs = inputFile.getFileSystem(conf);
    if (fs.getFileStatus(inputFile) == null) {
      throw new IOException(inputFile + " does not exist");
    }
    if (fs.getFileStatus(inputFile).isDir()) {
      throw new IOException(inputFile + " is a directory");
    }

    Path testOutputFile = new Path(testOutputDirectory, inputFile.getName());
    Path trainingOutputFile = new Path(trainingOutputDirectory, inputFile.getName());

    int lineCount = countLines(fs, inputFile, charset);

    log.info("{} has {} lines", inputFile.getName(), lineCount);

    int testSplitStart = 0;
    int testSplitSize = this.testSplitSize; // don't modify state
    BitSet randomSel = null;

    if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) {
      testSplitSize = this.testRandomSelectionSize;

      if (testRandomSelectionPct > 0) {
        testSplitSize = Math.round(lineCount * testRandomSelectionPct / 100.0f);
      }
      log.info("{} test split size is {} based on random selection percentage {}",
               inputFile.getName(), testSplitSize, testRandomSelectionPct);
      long[] ridx = new long[testSplitSize];
      RandomSampler.sample(testSplitSize, lineCount - 1, testSplitSize, 0, ridx, 0, RandomUtils.getRandom());
      randomSel = new BitSet(lineCount);
      for (long idx : ridx) {
        randomSel.set((int) idx + 1);
      }
    } else {
      if (testSplitPct > 0) { // calculate split size based on percentage
        testSplitSize = Math.round(lineCount * testSplitPct / 100.0f);
        log.info("{} test split size is {} based on percentage {}",
                 inputFile.getName(), testSplitSize, testSplitPct);
      } else {
        log.info("{} test split size is {}", inputFile.getName(), testSplitSize);
      }

      if (splitLocation > 0) { // calculate start of split based on percentage
        testSplitStart = Math.round(lineCount * splitLocation / 100.0f);
        if (lineCount - testSplitStart < testSplitSize) {
          // adjust split start downwards based on split size.
          testSplitStart = lineCount - testSplitSize;
        }
        log.info("{} test split start is {} based on split location {}",
                 inputFile.getName(), testSplitStart, splitLocation);
      }

      if (testSplitStart < 0) {
        throw new IllegalArgumentException("test split size for " + inputFile + " is too large, it would produce an "
                + "empty training set from the initial set of " + lineCount + " examples");
      } else if (lineCount - testSplitSize < testSplitSize) {
        log.warn("Test set size for {} may be too large, {} is larger than the number of "
                + "lines remaining in the training set: {}",
                 inputFile, testSplitSize, lineCount - testSplitSize);
      }
    }
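    // Illustrative example (not part of the original source): with lineCount = 1000, testSplitPct = 20
    // and splitLocation = 100, testSplitSize = round(1000 * 20 / 100) = 200; testSplitStart is first
    // computed as 1000 and then adjusted down to 1000 - 200 = 800, so lines 801..1000 go to the test set.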

    int trainCount = 0;
    int testCount = 0;
    if (!useSequence) {
      BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset));
      Writer trainingWriter = new OutputStreamWriter(fs.create(trainingOutputFile), charset);
      Writer testWriter = new OutputStreamWriter(fs.create(testOutputFile), charset);

      try {

        String line;
        int pos = 0;
        while ((line = reader.readLine()) != null) {
          pos++;

          Writer writer;
          if (testRandomSelectionPct > 0) { // Randomly choose
            writer = randomSel.get(pos) ? testWriter : trainingWriter;
          } else { // Choose based on location
            writer = pos > testSplitStart ? testWriter : trainingWriter;
          }

          if (writer == testWriter) {
            if (testCount >= testSplitSize) {
              writer = trainingWriter;
            } else {
              testCount++;
            }
          }
          if (writer == trainingWriter) {
            trainCount++;
          }
          writer.write(line);
          writer.write('\n');
        }

      } finally {
        Closeables.close(reader, true);
        Closeables.close(trainingWriter, false);
        Closeables.close(testWriter, false);
      }
    } else {
      SequenceFileIterator<Writable, Writable> iterator =
              new SequenceFileIterator<Writable, Writable>(inputFile, false, fs.getConf());
      SequenceFile.Writer trainingWriter = SequenceFile.createWriter(fs, fs.getConf(), trainingOutputFile,
          iterator.getKeyClass(), iterator.getValueClass());
      SequenceFile.Writer testWriter = SequenceFile.createWriter(fs, fs.getConf(), testOutputFile,
          iterator.getKeyClass(), iterator.getValueClass());
      try {

        int pos = 0;
        while (iterator.hasNext()) {
          pos++;
          SequenceFile.Writer writer;
          if (testRandomSelectionPct > 0) { // Randomly choose
            writer = randomSel.get(pos) ? testWriter : trainingWriter;
          } else { // Choose based on location
            writer = pos > testSplitStart ? testWriter : trainingWriter;
          }

          if (writer == testWriter) {
            if (testCount >= testSplitSize) {
              writer = trainingWriter;
            } else {
              testCount++;
            }
          }
          if (writer == trainingWriter) {
            trainCount++;
          }
          Pair<Writable, Writable> pair = iterator.next();
          writer.append(pair.getFirst(), pair.getSecond());
        }

      } finally {
        Closeables.close(iterator, true);
        Closeables.close(trainingWriter, false);
        Closeables.close(testWriter, false);
      }
    }
    log.info("file: {}, input: {} train: {}, test: {} starting at {}",
             inputFile.getName(), lineCount, trainCount, testCount, testSplitStart);

    // testing;
    if (callback != null) {
      callback.splitComplete(inputFile, lineCount, trainCount, testCount, testSplitStart);
    }
  }

 

  public int getTestSplitSize() {
    return testSplitSize;
  }

  public void setTestSplitSize(int testSplitSize) {
    this.testSplitSize = testSplitSize;
  }

  public int getTestSplitPct() {
    return testSplitPct;
  }

  public void setTestSplitPct(int testSplitPct) {
    this.testSplitPct = testSplitPct;
  }

  public void setKeepPct(int keepPct) {
    this.keepPct = keepPct;
  }

  public void setUseMapRed(boolean useMapRed) {
    this.useMapRed = useMapRed;
  }

  public void setMapRedOutputDirectory(Path mapRedOutputDirectory) {
    this.mapRedOutputDirectory = mapRedOutputDirectory;
  }

  public int getSplitLocation() {
    return splitLocation;
  }

  public void setSplitLocation(int splitLocation) {
    this.splitLocation = splitLocation;
  }

  public Charset getCharset() {
    return charset;
  }

  public void setCharset(Charset charset) {
    this.charset = charset;
  }

  public Path getInputDirectory() {
    return inputDirectory;
  }

  public void setInputDirectory(Path inputDir) {
    this.inputDirectory = inputDir;
  }

  public Path getTrainingOutputDirectory() {
    return trainingOutputDirectory;
  }

  public void setTrainingOutputDirectory(Path trainingOutputDir) {
    this.trainingOutputDirectory = trainingOutputDir;
  }

  public Path getTestOutputDirectory() {
    return testOutputDirectory;
  }

  public void setTestOutputDirectory(Path testOutputDir) {
    this.testOutputDirectory = testOutputDir;
  }

  public SplitCallback getCallback() {
    return callback;
  }

  public void setCallback(SplitCallback callback) {
    this.callback = callback;
  }

  public int getTestRandomSelectionSize() {
    return testRandomSelectionSize;
  }

  public void setTestRandomSelectionSize(int testRandomSelectionSize) {
    this.testRandomSelectionSize = testRandomSelectionSize;
  }

  public int getTestRandomSelectionPct() {
    return testRandomSelectionPct;
  }

  public void setTestRandomSelectionPct(int randomSelectionPct) {
    this.testRandomSelectionPct = randomSelectionPct;
  }

  public static int countLines(FileSystem fs, Path inputFile, Charset charset) throws IOException {
    int lineCount = 0;
    BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset));
    try {
      while (reader.readLine() != null) {
        lineCount++;
      }
    } finally {
      Closeables.close(reader, true);
    }

    return lineCount;
  }

  public interface SplitCallback {
    void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart);
  }

}
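The class above is driven entirely through its setters, so the serial split can also be used programmatically. Below is a minimal sketch of such a call, assuming hypothetical HDFS paths and a 20% test split taken from the end of each file; it only uses the setters and splitDirectory() shown in the listing.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.utils.SplitInput;

public class SplitInputExample {
  public static void main(String[] args) throws Exception {
    SplitInput splitter = new SplitInput();
    splitter.setConf(new Configuration());
    splitter.setInputDirectory(new Path("/tmp/vectors"));        // hypothetical input directory
    splitter.setTrainingOutputDirectory(new Path("/tmp/train")); // hypothetical output directories
    splitter.setTestOutputDirectory(new Path("/tmp/test"));
    splitter.setTestSplitPct(20);   // 20% of each file's lines go to the test set
    splitter.setSplitLocation(100); // take the test lines from the end of each file
    splitter.splitDirectory();      // serial mode: split every file under the input directory
  }
}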

2.    org.apache.mahout.utils.SplitInputJob

// Parallel split: divides the data set evenly into two parts.
// The job registers a custom key comparator,
// which does not actually order the keys;
// instead it compares them randomly, so the records reach the (single) reducer in a random order rather than grouped and sorted by key.

public final class SplitInputJob {

  private static final String DOWNSAMPLING_FACTOR =
      "SplitInputJob.downsamplingFactor";
  private static final String RANDOM_SELECTION_PCT =
      "SplitInputJob.randomSelectionPct";
  private static final String TRAINING_TAG = "training";
  private static final String TEST_TAG = "test";

  private SplitInputJob() {
  }

  // inputPath: the input path
  // outputPath: the output path
  // keepPct: percentage of the whole data set that is kept for splitting
  // randomSelectionPercent: percentage of the kept data that goes into the test set
  public static void run(Configuration initialConf, Path inputPath,
      Path outputPath, int keepPct, float randomSelectionPercent)
    throws IOException, ClassNotFoundException, InterruptedException {

    // downsamplingFactor can be read as the sampling interval over the input records
    int downsamplingFactor = (int) (100.0 / keepPct);
    initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor);
    initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent);

    // Determine class of keys and values
    FileSystem fs = FileSystem.get(initialConf);

    // read the key and value classes from the sequence files
    SequenceFileDirIterator<? extends WritableComparable, Writable> iterator =
        new SequenceFileDirIterator<WritableComparable, Writable>(inputPath,
            PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf());
    Class<? extends WritableComparable> keyClass;
    Class<? extends Writable> valueClass;
    if (iterator.hasNext()) {
      Pair<? extends WritableComparable, Writable> pair = iterator.next();
      keyClass = pair.getFirst().getClass();
      valueClass = pair.getSecond().getClass();
    } else {
      throw new IllegalStateException("Couldn't determine class of the input values");
    }
    // declare a job with the old Hadoop API and configure MultipleOutputs on it
    JobConf oldApiJob = new JobConf(initialConf);
    MultipleOutputs.addNamedOutput(oldApiJob, TRAINING_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);
    MultipleOutputs.addNamedOutput(oldApiJob, TEST_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);

    // declare the actual job with the new Hadoop API
    Job job = new Job(oldApiJob);
    job.setJarByClass(SplitInputJob.class);
    FileInputFormat.addInputPath(job, inputPath);
    FileOutputFormat.setOutputPath(job, outputPath);
    job.setNumReduceTasks(1);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(SplitInputMapper.class);
    job.setReducerClass(SplitInputReducer.class);

    // use SplitInputComparator as the sort comparator for the keys seen by the reducer
    job.setSortComparatorClass(SplitInputComparator.class);
    job.setOutputKeyClass(keyClass);
    job.setOutputValueClass(valueClass);
    job.submit();
    boolean succeeded = job.waitForCompletion(true);
    if (!succeeded) {
      throw new IllegalStateException("Job failed!");
    }
  }

 

  public static class SplitInputMapper extends
      Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private int downsamplingFactor;

    @Override
    public void setup(Context context) {
      downsamplingFactor =
          context.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1);
    }

    @Override
    public void run(Context context) throws IOException, InterruptedException {
      setup(context);
      int i = 0;
      // filter the input key/value pairs: emit one record every downsamplingFactor records, drop the rest
      while (context.nextKeyValue()) {
        if (i % downsamplingFactor == 0) {
          map(context.getCurrentKey(), context.getCurrentValue(), context);
        }
        i++;
      }
      cleanup(context);
    }

  }

 

  public static class SplitInputReducer extends
      Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private MultipleOutputs multipleOutputs;
    private OutputCollector<WritableComparable<?>, Writable> trainingCollector = null;
    private OutputCollector<WritableComparable<?>, Writable> testCollector = null;
    private final Random rnd = RandomUtils.getRandom();
    private float randomSelectionPercent;

    @SuppressWarnings("unchecked")
    @Override
    protected void setup(Context context) throws IOException {
      randomSelectionPercent =
          context.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0);
      multipleOutputs =
          new MultipleOutputs(new JobConf(context.getConfiguration()));
      trainingCollector = multipleOutputs.getCollector(TRAINING_TAG, null);
      testCollector = multipleOutputs.getCollector(TEST_TAG, null);
    }

    @Override
    protected void reduce(WritableComparable<?> key, Iterable<Writable> values,
        Context context) throws IOException, InterruptedException {
      for (Writable value : values) {
        // randomly route each map output record to the "test" or "training" named output
        if (rnd.nextInt(100) < randomSelectionPercent) {
          testCollector.collect(key, value);
        } else {
          trainingCollector.collect(key, value);
        }
      }

    }

    @Override
    protected void cleanup(Context context) throws IOException {
      multipleOutputs.close();
    }

  }

  // keys are distributed randomly; no real comparison or ordering is performed
  public static class SplitInputComparator extends WritableComparator implements Serializable {

    private final Random rnd = RandomUtils.getRandom();

    protected SplitInputComparator() {
      super(WritableComparable.class);
    }

    @Override
    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
      if (rnd.nextBoolean()) {
        return 1;
      } else {
        return -1;
      }
    }
  }

 

}
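To make the two parameters of SplitInputJob.run concrete: with keepPct = 50 the factor is (int) (100.0 / 50) = 2, so the mapper keeps every second record, and of the kept records roughly randomSelectionPercent percent end up in the test output. The standalone sketch below (illustrative only, not part of Mahout; the parameter values are hypothetical) mimics the mapper's downsampling and the reducer's random routing on a list of record ids.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class SplitLogicDemo {
  public static void main(String[] args) {
    int keepPct = 50;                    // hypothetical: keep half of the input
    float randomSelectionPercent = 20f;  // hypothetical: ~20% of kept records become test data
    int downsamplingFactor = (int) (100.0 / keepPct); // = 2, same formula as in SplitInputJob.run

    Random rnd = new Random(42);
    List<Integer> train = new ArrayList<>();
    List<Integer> test = new ArrayList<>();

    for (int i = 0; i < 1000; i++) {
      if (i % downsamplingFactor != 0) {
        continue; // the mapper drops this record
      }
      // reducer-side routing: records with rnd.nextInt(100) < randomSelectionPercent go to the test output
      if (rnd.nextInt(100) < randomSelectionPercent) {
        test.add(i);
      } else {
        train.add(i);
      }
    }
    System.out.println("kept=" + (train.size() + test.size())
        + " train=" + train.size() + " test=" + test.size());
  }
}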

