/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.hadoop.als;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
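/**
 * <p>MapReduce implementation of the two factorization algorithms described in</p>
 *
 * <p>"Large-scale Parallel Collaborative Filtering for the Netflix Prize" available at
 * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p>
 *
 * <p>"Collaborative Filtering for Implicit Feedback Datasets" available at
 * http://research.yahoo.com/pub/2433</p>
 *
 * <p>Command line arguments specific to this class are:</p>
 *
 * <ol>
 * <li>--input (path): Directory containing one or more text files with the dataset</li>
 * <li>--output (path): path where output should go</li>
 * <li>--lambda (double): regularization parameter to avoid overfitting</li>
 * <li>--implicitFeedback (boolean): whether the data consists of implicit feedback, defaults to false</li>
 * <li>--alpha (double): confidence parameter, only used on implicit feedback, defaults to 40</li>
 * <li>--numFeatures (int): dimension of the feature space</li>
 * <li>--numIterations (int): number of iterations</li>
 * </ol>
 *
 * <p>The user feature matrix and the item feature matrix of the last iteration are written below the
 * output path as U and M.</p>
 */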
public class ParallelALSFactorizationJob extends AbstractJob {
private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);
/* Configuration keys, prefixed with the class name; runSolver() sets them and the solver mappers read them */
static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
/* key under which the path of the fixed feature matrix (U or M) is handed to the solver mappers */
static final String FEATURE_MATRIX = ParallelALSFactorizationJob.class.getName() + ".featureMatrix";
/* values of the command-line options of the same name, assigned in run() */
private boolean implicitFeedback;
private int numIterations;
private int numFeatures;
private double lambda;
private double alpha;
/**
 * Command-line entry point: hands a new job instance together with the raw arguments to {@link ToolRunner}.
 * The status code returned by the tool runner is not evaluated.
 *
 * @param args command-line arguments, forwarded unchanged
 * @throws Exception anything the tool runner or the job propagates
 */
public static void main(String[] args) throws Exception {
  ParallelALSFactorizationJob job = new ParallelALSFactorizationJob();
  ToolRunner.run(job, args);
}
/**
 * Computes the factorization A = U M', where
 * A (users x items) is the matrix of known ratings,
 * U (users x features) is the representation of users in the feature space and
 * M (items x features) is the representation of items in the feature space.
 *
 * <p>The method registers and parses the command-line options, builds A' (item ratings) from the text input,
 * transposes it into A (user ratings), computes the average rating per item, seeds an initial M from these
 * averages and then, for every iteration, recomputes U from the previous M and M from the new U.</p>
 *
 * @param args command-line arguments
 * @return 0 on success, -1 if the arguments could not be parsed or one of the three preparatory jobs failed
 * @throws IllegalArgumentException if numFeatures is not positive
 * @throws Exception anything propagated by option parsing, the Hadoop jobs or the solver runs
 */
@Override
public int run(String[] args) throws Exception {

  addInputOption();
  addOutputOption();
  addOption("lambda", null, "regularization parameter", true);
  addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
  addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
  addOption("numFeatures", null, "dimension of the feature space", true);
  addOption("numIterations", null, "number of iterations", true);

  /* only the null-ness of the parse result matters here, the values are read through getOption() below */
  if (parseArguments(args) == null) {
    return -1;
  }

  numFeatures = Integer.parseInt(getOption("numFeatures"));
  numIterations = Integer.parseInt(getOption("numIterations"));
  lambda = Double.parseDouble(getOption("lambda"));
  alpha = Double.parseDouble(getOption("alpha"));
  implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));

  /* fail before any job is started, a feature vector needs at least one component */
  Preconditions.checkArgument(numFeatures > 0, "numFeatures must be positive but was %s", numFeatures);

  /* create A' */
  Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(),
      TextInputFormat.class, ItemRatingVectorsMapper.class, IntWritable.class,
      VectorWritable.class, VectorSumReducer.class, IntWritable.class,
      VectorWritable.class, SequenceFileOutputFormat.class);
  itemRatings.setCombinerClass(VectorSumReducer.class);
  boolean succeeded = itemRatings.waitForCompletion(true);
  if (!succeeded) {
    return -1;
  }

  /* create A */
  Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
      TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
      VectorWritable.class);
  userRatings.setCombinerClass(MergeVectorsCombiner.class);
  succeeded = userRatings.waitForCompletion(true);
  if (!succeeded) {
    return -1;
  }

  //TODO this could be fiddled into one of the upper jobs
  Path averageRatingsPath = getTempPath("averageRatings");
  Job averageItemRatings = prepareJob(pathToItemRatings(), averageRatingsPath,
      AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
      IntWritable.class, VectorWritable.class);
  averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
  succeeded = averageItemRatings.waitForCompletion(true);
  if (!succeeded) {
    return -1;
  }

  Vector averageRatings = ALSUtils.readFirstRow(averageRatingsPath, getConf());

  /* create an initial M */
  initializeM(averageRatings);

  for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
    /* broadcast M, read A row-wise, recompute U row-wise */
    log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
    runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1));
    /* broadcast U, read A' row-wise, recompute M row-wise */
    log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
    runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration));
  }

  return 0;
}
/**
 * Writes the initial item feature matrix M as a single sequence file named part-m-00000 below
 * {@code pathToM(-1)}.
 *
 * <p>One row is written per non-zero entry of {@code averageRatings}: the row key is the entry's index, the
 * first feature is the entry's value and the remaining features are random numbers.</p>
 *
 * @param averageRatings vector whose non-zero entries become the rows of the initial M
 * @throws IOException if the file cannot be created, written or closed
 */
private void initializeM(Vector averageRatings) throws IOException {
  Random random = RandomUtils.getRandom();

  Path initialM = pathToM(-1);
  FileSystem fs = FileSystem.get(initialM.toUri(), getConf());
  SequenceFile.Writer writer = null;
  /* stays true until everything has been written, so that close() only swallows its own exception
   * when there is already a primary failure to report */
  boolean threw = true;
  try {
    writer = new SequenceFile.Writer(fs, getConf(), new Path(initialM, "part-m-00000"), IntWritable.class,
        VectorWritable.class);

    Iterator<Vector.Element> averages = averageRatings.iterateNonZero();
    while (averages.hasNext()) {
      Vector.Element e = averages.next();
      Vector row = new DenseVector(numFeatures);
      row.setQuick(0, e.get());
      for (int m = 1; m < numFeatures; m++) {
        row.setQuick(m, random.nextDouble());
      }
      writer.append(new IntWritable(e.index()), new VectorWritable(row));
    }
    threw = false;
  } finally {
    /* a failed close of a writer can mean lost data, so it is only swallowed if the body already failed */
    Closeables.close(writer, threw);
  }
}
/**
 * Turns one line of the text input into a single-entry rating vector: the output key is the item ID, the
 * output value is a sparse vector that holds the rating at the index of the user ID.
 */
static class ItemRatingVectorsMapper extends Mapper<LongWritable,Text,IntWritable,VectorWritable> {

  /**
   * Parses the first three tokens of the line as user ID, item ID and rating and emits the rating keyed by item.
   *
   * @param offset position of the line in the input, not used
   * @param line   one preference line, split with {@code TasteHadoopUtils.splitPrefTokens}
   * @param ctx    context the (item ID, rating vector) pair is written to
   * @throws IllegalArgumentException if the line has fewer than three tokens
   */
  @Override
  protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
    String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
    /* name the offending line instead of failing with a bare array index error */
    Preconditions.checkArgument(tokens.length >= 3, "Expected userID, itemID and rating but got: %s", line);

    int userID = Integer.parseInt(tokens[0]);
    int itemID = Integer.parseInt(tokens[1]);
    float rating = Float.parseFloat(tokens[2]);

    Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
    ratings.set(userID, rating);

    ctx.write(new IntWritable(itemID), new VectorWritable(ratings, true));
  }
}
private void runSolver(Path ratings, Path output, Path pathToUorI)
throws ClassNotFoundException, IOException, InterruptedException {
Class extends Mapper> solverMapper = implicitFeedback ?
SolveImplicitFeedbackMapper.class : SolveExplicitFeedbackMapper.class;
Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, solverMapper, IntWritable.class,
VectorWritable.class, SequenceFileOutputFormat.class);
Configuration solverConf = solverForUorI.getConfiguration();
solverConf.set(LAMBDA, String.valueOf(lambda));
solverConf.set(ALPHA, String.valueOf(alpha));
solverConf.setInt(NUM_FEATURES, numFeatures);
solverConf.set(FEATURE_MATRIX, pathToUorI.toString());
boolean succeeded = solverForUorI.waitForCompletion(true);
if (!succeeded)
throw new IllegalStateException("Job failed!");
}
static class SolveExplicitFeedbackMapper extends Mapper {
private double lambda;
private int numFeatures;
private OpenIntObjectHashMap UorM;
private AlternatingLeastSquaresSolver solver;
@Override
protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
lambda = Double.parseDouble(ctx.getConfiguration().get(LAMBDA));
numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
solver = new AlternatingLeastSquaresSolver();
Path UOrIPath = new Path(ctx.getConfiguration().get(FEATURE_MATRIX));
UorM = ALSUtils.readMatrixByRows(UOrIPath, ctx.getConfiguration());
Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");
}
@Override
protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
throws IOException, InterruptedException {
Vector ratings = new SequentialAccessSparseVector(ratingsWritable.get());
List featureVectors = Lists.newArrayList();
Iterator interactions = ratings.iterateNonZero();
while (interactions.hasNext()) {
int index = interactions.next().index();
featureVectors.add(UorM.get(index));
}
Vector uiOrmj = solver.solve(featureVectors, ratings, lambda, numFeatures);
ctx.write(userOrItemID, new VectorWritable(uiOrmj));
}
}
/**
 * Solver mapper for implicit feedback: delegates every rating vector to an
 * {@link ImplicitFeedbackAlternatingLeastSquaresSolver} that is created once per task from the fixed feature
 * matrix.
 */
static class SolveImplicitFeedbackMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {

  private ImplicitFeedbackAlternatingLeastSquaresSolver solver;

  /**
   * Reads lambda, alpha, the number of features and the fixed feature matrix named by {@link #FEATURE_MATRIX}
   * from the configuration and creates the solver from them.
   *
   * @throws IllegalArgumentException if the number of features is missing or not positive
   */
  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    Configuration conf = ctx.getConfiguration();

    double lambda = Double.parseDouble(conf.get(LAMBDA));
    double alpha = Double.parseDouble(conf.get(ALPHA));
    int numFeatures = conf.getInt(NUM_FEATURES, -1);
    /* validate before the value is used to load the matrix and to build the solver */
    Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");

    Path YPath = new Path(conf.get(FEATURE_MATRIX));
    OpenIntObjectHashMap<Vector> Y = ALSUtils.readMatrixByRows(YPath, conf);

    solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, Y);
  }

  /**
   * Solves for the feature vector of one user or item and writes it under the unchanged key.
   *
   * @param userOrItemID    row key, passed through to the output
   * @param ratingsWritable ratings of this row
   * @param ctx             context the (key, feature vector) pair is written to
   */
  @Override
  protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
      throws IOException, InterruptedException {
    Vector ratings = new SequentialAccessSparseVector(ratingsWritable.get());

    Vector uiOrmj = solver.solve(ratings);

    ctx.write(userOrItemID, new VectorWritable(uiOrmj));
  }
}
static class AverageRatingMapper extends Mapper {
@Override
protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
RunningAverage avg = new FullRunningAverage();
Iterator elements = v.get().iterateNonZero();
while (elements.hasNext()) {
avg.addDatum(elements.next().get());
}
Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
vector.setQuick(r.get(), avg.getAverage());
ctx.write(new IntWritable(0), new VectorWritable(vector));
}
}
/**
 * Location of the item feature matrix M of the given iteration: the result of the last iteration
 * (numIterations - 1) lives below the output path as "M", every other one below the temp path as "M-" followed
 * by the iteration number.
 */
private Path pathToM(int iteration) {
  boolean lastIteration = iteration == numIterations - 1;
  if (lastIteration) {
    return getOutputPath("M");
  }
  return getTempPath("M-" + iteration);
}
/**
 * Location of the user feature matrix U of the given iteration: the result of the last iteration
 * (numIterations - 1) lives below the output path as "U", every other one below the temp path as "U-" followed
 * by the iteration number.
 */
private Path pathToU(int iteration) {
  boolean lastIteration = iteration == numIterations - 1;
  if (lastIteration) {
    return getOutputPath("U");
  }
  return getTempPath("U-" + iteration);
}
/** Temp location "itemRatings"; run() writes A' there and reads it back as job input. */
private Path pathToItemRatings() {
  Path itemRatings = getTempPath("itemRatings");
  return itemRatings;
}
/** Location "userRatings" below the output path (not the temp path); run() writes A there. */
private Path pathToUserRatings() {
  Path userRatings = getOutputPath("userRatings");
  return userRatings;
}
}