// Splits a data set into training and test sets; supports both a sequential mode and a parallel (MapReduce) mode.
public class SplitInput extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(SplitInput.class);

  // size of the test set
  private int testSplitSize = -1;
  // percentage of the data set to use for the test set
  private int testSplitPct = -1;
  private int splitLocation = 100;
  private int testRandomSelectionSize = -1;
  private int testRandomSelectionPct = -1;
  // percentage of the input data set that is kept for splitting
  private int keepPct = 100;
  private Charset charset = Charsets.UTF_8;
  // whether the input is read in SequenceFile format
  private boolean useSequence;
  // whether to split the data in parallel (MapReduce)
  private boolean useMapRed;
  // input directory
  private Path inputDirectory;
  // output directory for the training set (sequential mode only)
  private Path trainingOutputDirectory;
  // output directory for the test set (sequential mode only)
  private Path testOutputDirectory;
  // output directory for the parallel (MapReduce) split
  private Path mapRedOutputDirectory;
  private SplitCallback callback;
  @Override
  public int run(String[] args) throws Exception {
    splitDirectory();
    return 0;
  }

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new SplitInput(), args);
  }
  public void splitDirectory() throws IOException, ClassNotFoundException, InterruptedException {
    this.splitDirectory(inputDirectory);
  }

  public void splitDirectory(Path inputDir) throws IOException, ClassNotFoundException, InterruptedException {
    Configuration conf = getConf();
    splitDirectory(conf, inputDir);
  }

  public void splitDirectory(Configuration conf, Path inputDir)
      throws IOException, ClassNotFoundException, InterruptedException {
    FileSystem fs = inputDir.getFileSystem(conf);
    if (fs.getFileStatus(inputDir) == null) {
      throw new IOException(inputDir + " does not exist");
    }
    if (!fs.getFileStatus(inputDir).isDir()) {
      throw new IOException(inputDir + " is not a directory");
    }
    // decide between the sequential and the parallel (MapReduce) split
    if (useMapRed) {
      // parallel split
      SplitInputJob.run(conf, inputDir, mapRedOutputDirectory,
          keepPct, testRandomSelectionPct);
    } else {
      // list all files in the input directory
      FileStatus[] fileStats = fs.listStatus(inputDir, PathFilters.logsCRCFilter());
      // split each file in turn
      for (FileStatus inputFile : fileStats) {
        if (!inputFile.isDir()) {
          splitFile(inputFile.getPath());
        }
      }
    }
  }
  public void splitFile(Path inputFile) throws IOException {
    Configuration conf = getConf();
    FileSystem fs = inputFile.getFileSystem(conf);
    if (fs.getFileStatus(inputFile) == null) {
      throw new IOException(inputFile + " does not exist");
    }
    if (fs.getFileStatus(inputFile).isDir()) {
      throw new IOException(inputFile + " is a directory");
    }
    Path testOutputFile = new Path(testOutputDirectory, inputFile.getName());
    Path trainingOutputFile = new Path(trainingOutputDirectory, inputFile.getName());

    int lineCount = countLines(fs, inputFile, charset);
    log.info("{} has {} lines", inputFile.getName(), lineCount);

    int testSplitStart = 0;
    int testSplitSize = this.testSplitSize; // don't modify state
    BitSet randomSel = null;

    if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) {
      testSplitSize = this.testRandomSelectionSize;
      if (testRandomSelectionPct > 0) {
        testSplitSize = Math.round(lineCount * testRandomSelectionPct / 100.0f);
      }
      log.info("{} test split size is {} based on random selection percentage {}",
          inputFile.getName(), testSplitSize, testRandomSelectionPct);
      long[] ridx = new long[testSplitSize];
      RandomSampler.sample(testSplitSize, lineCount - 1, testSplitSize, 0, ridx, 0, RandomUtils.getRandom());
      randomSel = new BitSet(lineCount);
      for (long idx : ridx) {
        randomSel.set((int) idx + 1);
      }
    } else {
      if (testSplitPct > 0) { // calculate split size based on percentage
        testSplitSize = Math.round(lineCount * testSplitPct / 100.0f);
        log.info("{} test split size is {} based on percentage {}",
            inputFile.getName(), testSplitSize, testSplitPct);
      } else {
        log.info("{} test split size is {}", inputFile.getName(), testSplitSize);
      }
      if (splitLocation > 0) { // calculate start of split based on percentage
        testSplitStart = Math.round(lineCount * splitLocation / 100.0f);
        if (lineCount - testSplitStart < testSplitSize) {
          // adjust split start downwards based on split size.
          testSplitStart = lineCount - testSplitSize;
        }
        log.info("{} test split start is {} based on split location {}",
            inputFile.getName(), testSplitStart, splitLocation);
      }
      if (testSplitStart < 0) {
        throw new IllegalArgumentException("test split size for " + inputFile + " is too large, it would produce an "
            + "empty training set from the initial set of " + lineCount + " examples");
      } else if (lineCount - testSplitSize < testSplitSize) {
        log.warn("Test set size for {} may be too large, {} is larger than the number of "
            + "lines remaining in the training set: {}",
            inputFile, testSplitSize, lineCount - testSplitSize);
      }
    }
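    // Illustrative note (not part of the original source): with the percentage-based branch above,
    // a file with lineCount = 1000 and testSplitPct = 20 gives testSplitSize = round(1000 * 20 / 100.0f) = 200.
    // With the default splitLocation = 100, testSplitStart is first computed as 1000 and then pulled back to
    // lineCount - testSplitSize = 800, so lines 801..1000 go to the test set and lines 1..800 to the training set.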
    int trainCount = 0;
    int testCount = 0;
    if (!useSequence) {
      BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset));
      Writer trainingWriter = new OutputStreamWriter(fs.create(trainingOutputFile), charset);
      Writer testWriter = new OutputStreamWriter(fs.create(testOutputFile), charset);
      try {
        String line;
        int pos = 0;
        while ((line = reader.readLine()) != null) {
          pos++;
          Writer writer;
          if (testRandomSelectionPct > 0) { // Randomly choose
            writer = randomSel.get(pos) ? testWriter : trainingWriter;
          } else { // Choose based on location
            writer = pos > testSplitStart ? testWriter : trainingWriter;
          }
          if (writer == testWriter) {
            if (testCount >= testSplitSize) {
              writer = trainingWriter;
            } else {
              testCount++;
            }
          }
          if (writer == trainingWriter) {
            trainCount++;
          }
          writer.write(line);
          writer.write('\n');
        }
      } finally {
        Closeables.close(reader, true);
        Closeables.close(trainingWriter, false);
        Closeables.close(testWriter, false);
      }
    } else {
      SequenceFileIterator<Writable, Writable> iterator =
          new SequenceFileIterator<Writable, Writable>(inputFile, false, fs.getConf());
      SequenceFile.Writer trainingWriter = SequenceFile.createWriter(fs, fs.getConf(), trainingOutputFile,
          iterator.getKeyClass(), iterator.getValueClass());
      SequenceFile.Writer testWriter = SequenceFile.createWriter(fs, fs.getConf(), testOutputFile,
          iterator.getKeyClass(), iterator.getValueClass());
      try {
        int pos = 0;
        while (iterator.hasNext()) {
          pos++;
          SequenceFile.Writer writer;
          if (testRandomSelectionPct > 0) { // Randomly choose
            writer = randomSel.get(pos) ? testWriter : trainingWriter;
          } else { // Choose based on location
            writer = pos > testSplitStart ? testWriter : trainingWriter;
          }
          if (writer == testWriter) {
            if (testCount >= testSplitSize) {
              writer = trainingWriter;
            } else {
              testCount++;
            }
          }
          if (writer == trainingWriter) {
            trainCount++;
          }
          Pair<Writable, Writable> pair = iterator.next();
          writer.append(pair.getFirst(), pair.getSecond());
        }
      } finally {
        Closeables.close(iterator, true);
        Closeables.close(trainingWriter, false);
        Closeables.close(testWriter, false);
      }
    }

    log.info("file: {}, input: {} train: {}, test: {} starting at {}",
        inputFile.getName(), lineCount, trainCount, testCount, testSplitStart);

    if (callback != null) {
      callback.splitComplete(inputFile, lineCount, trainCount, testCount, testSplitStart);
    }
  }
  public int getTestSplitSize() {
    return testSplitSize;
  }

  public void setTestSplitSize(int testSplitSize) {
    this.testSplitSize = testSplitSize;
  }

  public int getTestSplitPct() {
    return testSplitPct;
  }

  public void setTestSplitPct(int testSplitPct) {
    this.testSplitPct = testSplitPct;
  }

  public void setKeepPct(int keepPct) {
    this.keepPct = keepPct;
  }

  public void setUseMapRed(boolean useMapRed) {
    this.useMapRed = useMapRed;
  }

  public void setMapRedOutputDirectory(Path mapRedOutputDirectory) {
    this.mapRedOutputDirectory = mapRedOutputDirectory;
  }

  public int getSplitLocation() {
    return splitLocation;
  }

  public void setSplitLocation(int splitLocation) {
    this.splitLocation = splitLocation;
  }

  public Charset getCharset() {
    return charset;
  }

  public void setCharset(Charset charset) {
    this.charset = charset;
  }

  public Path getInputDirectory() {
    return inputDirectory;
  }

  public void setInputDirectory(Path inputDir) {
    this.inputDirectory = inputDir;
  }

  public Path getTrainingOutputDirectory() {
    return trainingOutputDirectory;
  }

  public void setTrainingOutputDirectory(Path trainingOutputDir) {
    this.trainingOutputDirectory = trainingOutputDir;
  }

  public Path getTestOutputDirectory() {
    return testOutputDirectory;
  }

  public void setTestOutputDirectory(Path testOutputDir) {
    this.testOutputDirectory = testOutputDir;
  }

  public SplitCallback getCallback() {
    return callback;
  }

  public void setCallback(SplitCallback callback) {
    this.callback = callback;
  }

  public int getTestRandomSelectionSize() {
    return testRandomSelectionSize;
  }

  public void setTestRandomSelectionSize(int testRandomSelectionSize) {
    this.testRandomSelectionSize = testRandomSelectionSize;
  }

  public int getTestRandomSelectionPct() {
    return testRandomSelectionPct;
  }

  public void setTestRandomSelectionPct(int randomSelectionPct) {
    this.testRandomSelectionPct = randomSelectionPct;
  }
  public static int countLines(FileSystem fs, Path inputFile, Charset charset) throws IOException {
    int lineCount = 0;
    BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset));
    try {
      while (reader.readLine() != null) {
        lineCount++;
      }
    } finally {
      Closeables.close(reader, true);
    }
    return lineCount;
  }

  public interface SplitCallback {
    void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart);
  }
}
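A hedged usage sketch (not part of the Mahout source): the sequential split above can be driven programmatically through the setters shown in the class. The paths and the 20% figure below are hypothetical placeholders.

SplitInput si = new SplitInput();
si.setConf(new Configuration());                          // AbstractJob extends Configured, so getConf() needs a Configuration
si.setInputDirectory(new Path("/tmp/split/input"));       // hypothetical directory of plain-text input files
si.setTrainingOutputDirectory(new Path("/tmp/split/train"));
si.setTestOutputDirectory(new Path("/tmp/split/test"));
si.setTestSplitPct(20);                                   // roughly the last 20% of each file's lines go to the test set
si.splitDirectory();                                      // useMapRed is false by default, so this takes the sequential path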
// Parallel split: in order to divide the data set evenly into two parts, the job installs a custom key comparator.
// The comparator does not actually order the keys; it compares them randomly, so that the different keys arrive
// on the reduce side distributed uniformly at random rather than in sorted order.
public final class SplitInputJob {

  private static final String DOWNSAMPLING_FACTOR =
      "SplitInputJob.downsamplingFactor";
  private static final String RANDOM_SELECTION_PCT =
      "SplitInputJob.randomSelectionPct";
  private static final String TRAINING_TAG = "training";
  private static final String TEST_TAG = "test";

  private SplitInputJob() {
  }
  // inputPath: the input path
  // outputPath: the output path
  // keepPct: percentage of the full data set that is kept for splitting
  // randomSelectionPercent: percentage of the kept data that goes into the test set
  public static void run(Configuration initialConf, Path inputPath,
      Path outputPath, int keepPct, float randomSelectionPercent)
      throws IOException, ClassNotFoundException, InterruptedException {
    // downsamplingFactor can be read as the sampling interval over the data set
    int downsamplingFactor = (int) (100.0 / keepPct);
    initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor);
    initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent);

    // Determine class of keys and values
    FileSystem fs = FileSystem.get(initialConf);
    // obtain the key and value classes stored in the sequence files
    SequenceFileDirIterator<? extends WritableComparable, Writable> iterator =
        new SequenceFileDirIterator<WritableComparable, Writable>(inputPath,
            PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf());
    Class<? extends WritableComparable> keyClass;
    Class<? extends Writable> valueClass;
    if (iterator.hasNext()) {
      Pair<? extends WritableComparable, Writable> pair = iterator.next();
      keyClass = pair.getFirst().getClass();
      valueClass = pair.getSecond().getClass();
    } else {
      throw new IllegalStateException("Couldn't determine class of the input values");
    }
    // Declare a job with the old Hadoop API and configure MultipleOutputs on it
    JobConf oldApiJob = new JobConf(initialConf);
    MultipleOutputs.addNamedOutput(oldApiJob, TRAINING_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);
    MultipleOutputs.addNamedOutput(oldApiJob, TEST_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);

    // Declare the job with the new Hadoop API
    Job job = new Job(oldApiJob);
    job.setJarByClass(SplitInputJob.class);
    FileInputFormat.addInputPath(job, inputPath);
    FileOutputFormat.setOutputPath(job, outputPath);
    job.setNumReduceTasks(1);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(SplitInputMapper.class);
    job.setReducerClass(SplitInputReducer.class);
    // use SplitInputComparator as the comparator that sorts keys on the reduce side
    job.setSortComparatorClass(SplitInputComparator.class);
    job.setOutputKeyClass(keyClass);
    job.setOutputValueClass(valueClass);
    job.submit();
    boolean succeeded = job.waitForCompletion(true);
    if (!succeeded) {
      throw new IllegalStateException("Job failed!");
    }
  }
  public static class SplitInputMapper extends
      Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private int downsamplingFactor;

    @Override
    public void setup(Context context) {
      downsamplingFactor =
          context.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1);
    }

    @Override
    public void run(Context context) throws IOException, InterruptedException {
      setup(context);
      int i = 0;
      // Filter the input key/value pairs: emit every downsamplingFactor-th record and discard the rest
      while (context.nextKeyValue()) {
        if (i % downsamplingFactor == 0) {
          map(context.getCurrentKey(), context.getCurrentValue(), context);
        }
        i++;
      }
      cleanup(context);
    }
  }
  public static class SplitInputReducer extends
      Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private MultipleOutputs multipleOutputs;
    private OutputCollector<WritableComparable<?>, Writable> trainingCollector = null;
    private OutputCollector<WritableComparable<?>, Writable> testCollector = null;
    private final Random rnd = RandomUtils.getRandom();
    private float randomSelectionPercent;

    @SuppressWarnings("unchecked")
    @Override
    protected void setup(Context context) throws IOException {
      randomSelectionPercent =
          context.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0);
      multipleOutputs =
          new MultipleOutputs(new JobConf(context.getConfiguration()));
      trainingCollector = multipleOutputs.getCollector(TRAINING_TAG, null);
      testCollector = multipleOutputs.getCollector(TEST_TAG, null);
    }

    @Override
    protected void reduce(WritableComparable<?> key, Iterable<Writable> values,
        Context context) throws IOException, InterruptedException {
      for (Writable value : values) {
        // randomly split the map output between the files tagged "test" and "training"
        if (rnd.nextInt(100) < randomSelectionPercent) {
          testCollector.collect(key, value);
        } else {
          trainingCollector.collect(key, value);
        }
      }
    }

    @Override
    protected void cleanup(Context context) throws IOException {
      multipleOutputs.close();
    }
  }
  // Keys are distributed randomly; no real comparison or sorting is performed
  public static class SplitInputComparator extends WritableComparator implements Serializable {

    private final Random rnd = RandomUtils.getRandom();

    protected SplitInputComparator() {
      super(WritableComparable.class);
    }

    @Override
    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
      if (rnd.nextBoolean()) {
        return 1;
      } else {
        return -1;
      }
    }
  }
}
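For the parallel path, a hedged invocation sketch (the paths are hypothetical) that also spells out what the two percentages do:

Configuration conf = new Configuration();
// keepPct = 50 gives downsamplingFactor = (int) (100.0 / 50) = 2, so the mapper keeps every 2nd record;
// randomSelectionPercent = 20.0f makes the reducer route roughly 20% of the kept records to the "test"
// named output and the rest to "training".
SplitInputJob.run(conf, new Path("/tmp/split/vectors"), new Path("/tmp/split/output"), 50, 20.0f);

Note that because the factor is computed as (int) (100.0 / keepPct), any keepPct above 50 truncates to a factor of 1, i.e. nothing is downsampled.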