
Commit

add ml code
xubo245 committed Aug 7, 2016
1 parent adbb16c commit 464a703
Showing 26 changed files with 2,203 additions and 79 deletions.
4 changes: 4 additions & 0 deletions file/data/mllib/input/basic/sample_libsvm_data_simple.txt
@@ -0,0 +1,4 @@
0 128:51 129:159 130:253
1 159:124 160:253 161:255
1 125:145 126:255 127:211
1 153:5 154:63 155:197 181:20
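
These rows are in the LIBSVM sparse format: a label followed by space-separated index:value pairs for the non-zero features (indices are 1-based). A minimal sketch of loading this file into a DataFrame, assuming the same sc/sqlContext setup used in DTSuite.scala further down:

import org.apache.spark.mllib.util.MLUtils

// Parse the sparse "label index:value ..." lines into an RDD[LabeledPoint].
val simpleRDD = MLUtils.loadLibSVMFile(sc,
  "file/data/mllib/input/basic/sample_libsvm_data_simple.txt")

// Convert to a DataFrame with "label" and "features" columns (Spark 1.x-era API).
val simpleDF = sqlContext.createDataFrame(simpleRDD)
simpleDF.show()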
87 changes: 87 additions & 0 deletions src/main/java/org/apache/spark/ml/Example/ExampleETP.java
@@ -0,0 +1,87 @@

//package org.apache.spark.ml.Example;
//
//import java.util.Arrays;
//
//import org.apache.spark.SparkConf;
//import org.apache.spark.api.java.JavaRDD;
//import org.apache.spark.api.java.JavaSparkContext;
//import org.apache.spark.ml.classification.LogisticRegression;
//import org.apache.spark.ml.classification.LogisticRegressionModel;
//import org.apache.spark.ml.param.ParamMap;
//import org.apache.spark.mllib.linalg.Vectors;
//import org.apache.spark.mllib.regression.LabeledPoint;
//import org.apache.spark.sql.DataFrame;
//import org.apache.spark.sql.Row;
//import org.apache.spark.sql.SQLContext;
//
///**
// *
// * @author xingyun.xb
// * @version $Id: ExampleETP.java, v 0.1 2016-07-23 17:58 xingyun.xb Exp $
// */
//public class ExampleETP {
// public static void main(String[] args) {
//
// SparkConf conf = new SparkConf().setAppName("Simple Application").setMaster("local[4]");
// JavaSparkContext sc = new JavaSparkContext(conf);
// SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
//
// // Prepare training data.
// // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans
// // into DataFrames, where it uses the bean metadata to infer the schema.
// // Arrays.asList returns a java.util.List, which cannot be cast to a JavaRDD;
// // parallelize the list first so the createDataFrame(JavaRDD, Class) overload applies.
// DataFrame training = sqlContext.createDataFrame(
//     sc.parallelize(Arrays.asList(new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
//         new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
//         new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
//         new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))), LabeledPoint.class);
//
// // Create a LogisticRegression instance. This instance is an Estimator.
// LogisticRegression lr = new LogisticRegression();
// // Print out the parameters, documentation, and any default values.
// System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");
//
// // We may set parameters using setter methods.
// lr.setMaxIter(10).setRegParam(0.01);
//
// // Learn a LogisticRegression model. This uses the parameters stored in lr.
// LogisticRegressionModel model1 = lr.fit(training);
// // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// // we can view the parameters it used during fit().
// // This prints the parameter (name: value) pairs, where names are unique IDs for this
// // LogisticRegression instance.
// System.out
// .println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
//
// // We may alternatively specify parameters using a ParamMap.
// ParamMap paramMap = new ParamMap().put(lr.maxIter().w(20)) // Specify 1 Param.
// .put(lr.maxIter(), 30) // This overwrites the original maxIter.
// .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
//
// // One can also combine ParamMaps.
// ParamMap paramMap2 = new ParamMap().put(lr.probabilityCol().w("myProbability")); // Change output column name
// ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
//
// // Now learn a new model using the paramMapCombined parameters.
// // paramMapCombined overrides all parameters set earlier via lr.set* methods.
// LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
// System.out
// .println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
//
// // Prepare test documents.
// // Same fix as above: parallelize the list instead of casting it to a JavaRDD.
// DataFrame test = sqlContext.createDataFrame(
//     sc.parallelize(Arrays.asList(new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
//         new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
//         new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))), LabeledPoint.class);
//
// // Make predictions on test documents using the Transformer.transform() method.
// // LogisticRegression.transform will only use the 'features' column.
// // Note that model2.transform() outputs a 'myProbability' column instead of the usual
// // 'probability' column since we renamed the lr.probabilityCol parameter previously.
// DataFrame results = model2.transform(test);
// for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) {
// System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
// + ", prediction=" + r.get(3));
// }
// }
//}
106 changes: 106 additions & 0 deletions src/main/java/org/apache/spark/ml/Example/PipelineLearning.java
@@ -0,0 +1,106 @@

//package org.apache.spark.ml.Example;
//
//import org.apache.spark.SparkConf;
//import org.apache.spark.api.java.JavaRDD;
//import org.apache.spark.api.java.JavaSparkContext;
//import org.apache.spark.rdd.RDD;
//import org.apache.spark.sql.SQLContext;
//
//import java.io.Serializable;
//import java.util.Arrays;
//import java.util.List;
//
//import org.apache.spark.ml.Pipeline;
//import org.apache.spark.ml.PipelineModel;
//import org.apache.spark.ml.PipelineStage;
//import org.apache.spark.ml.classification.LogisticRegression;
//import org.apache.spark.ml.feature.HashingTF;
//import org.apache.spark.ml.feature.Tokenizer;
//import org.apache.spark.sql.DataFrame;
//import org.apache.spark.sql.Row;
///**
// *
// * @author xingyun.xb
// * @version $Id: PipelineLearning.java, v 0.1 2016-07-23 18:10 xingyun.xb Exp $
// */
//public class PipelineLearning {
// public static void main(String[] args) {
// SparkConf conf = new SparkConf().setAppName("Simple Application").setMaster("local[4]");
// JavaSparkContext sc = new JavaSparkContext(conf);
// SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
//
//
//
//// Prepare training documents, which are labeled.
// // Arrays.asList cannot be cast to a JavaRDD; parallelize the list first.
// DataFrame training = sqlContext.createDataFrame(sc.parallelize(Arrays.asList(
//     new LabeledDocument(0L, "a b c d e spark", 1.0),
//     new LabeledDocument(1L, "b d", 0.0),
//     new LabeledDocument(2L, "spark f g h", 1.0),
//     new LabeledDocument(3L, "hadoop mapreduce", 0.0)
// )), LabeledDocument.class);
//
//// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
// Tokenizer tokenizer = new Tokenizer()
// .setInputCol("text")
// .setOutputCol("words");
// HashingTF hashingTF = new HashingTF()
// .setNumFeatures(1000)
// .setInputCol(tokenizer.getOutputCol())
// .setOutputCol("features");
// LogisticRegression lr = new LogisticRegression()
// .setMaxIter(10)
// .setRegParam(0.01);
// Pipeline pipeline = new Pipeline()
// .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
//
//// Fit the pipeline to training documents.
// PipelineModel model = pipeline.fit(training);
//
//// Prepare test documents, which are unlabeled.
// // Parallelize the list here as well rather than casting it to a JavaRDD.
// DataFrame test = sqlContext.createDataFrame(sc.parallelize(Arrays.asList(
//     new Document(4L, "spark i j k"),
//     new Document(5L, "l m n"),
//     new Document(6L, "mapreduce spark"),
//     new Document(7L, "apache hadoop")
// )), Document.class);
//
//// Make predictions on test documents.
// DataFrame predictions = model.transform(test);
// for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
// System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
// + ", prediction=" + r.get(3));
// }
// }
//}
//
//
//// Labeled and unlabeled instance types.
//// Spark SQL can infer schema from Java Beans.
//class Document implements Serializable {
// private long id;
// private String text;
//
// public Document(long id, String text) {
// this.id = id;
// this.text = text;
// }
//
// public long getId() { return this.id; }
// public void setId(long id) { this.id = id; }
//
// public String getText() { return this.text; }
// public void setText(String text) { this.text = text; }
//}
//
//class LabeledDocument extends Document implements Serializable {
// private double label;
//
// public LabeledDocument(long id, String text, double label) {
// super(id, text);
// this.label = label;
// }
//
// public double getLabel() { return this.label; }
// public void setLabel(double label) { this.label = label; }
//}
29 changes: 29 additions & 0 deletions src/main/java/org/apache/spark/ml/Example/test.java
@@ -0,0 +1,29 @@
package org.apache.spark.ml.Example;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;

import java.util.Arrays;
import java.util.List;

/**
*
* @author xingyun.xb
* @version $Id: test.java, v 0.1 2016-07-23 18:06 xingyun.xb Exp $
*/
public class test {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Simple Application").setMaster("local[4]");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);

List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> distData = sc.parallelize(data);

System.out.println(distData.count());

sc.stop();
}
}
147 changes: 147 additions & 0 deletions src/main/scala/org/apache/spark/ml/DecisionTrees/DTSuite.scala
@@ -0,0 +1,147 @@
package org.apache.spark.ml.DecisionTrees

import org.apache.log4j.{Level, Logger}
import org.apache.spark.util.SparkLearningFunSuite

/**
* Created by xingyun.xb on 2016/7/24.
*/
class DTSuite extends SparkLearningFunSuite{


test("Classification Suite"){
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.ERROR)

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file, converting it to a DataFrame.
val dataRDD = MLUtils.loadLibSVMFile(sc, "file/data/mllib/input/basic/sample_libsvm_data.txt")
val data = sqlContext.createDataFrame(dataRDD)

data.show()

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4) // features with > 4 distinct values are treated as continuous
.fit(data)

// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline
val pipeline = new Pipeline()
.setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("precision")
val accuracy = evaluator.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)

}

test("Regression Suite"){
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.ERROR)

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file, converting it to a DataFrame.
val dataRDD = MLUtils.loadLibSVMFile(sc, "file/data/mllib/input/basic/sample_libsvm_data.txt")
val data = sqlContext.createDataFrame(dataRDD)

data.show()
// Automatically identify categorical features, and index them.
// Here, we treat features with > 4 distinct values as continuous.
val featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(data)

// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeRegressor()
.setLabelCol("label")
.setFeaturesCol("indexedFeatures")

// Chain indexer and tree in a Pipeline
val pipeline = new Pipeline()
.setStages(Array(featureIndexer, dt))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error
val evaluator = new RegressionEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println("Root Mean Squared Error (RMSE) on test data = " + rmse)

val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel]
println("Learned regression tree model:\n" + treeModel.toDebugString)

}


test("Suite"){
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.ERROR)

}

}