- When I use Spark MLlib linear regression to forecast traffic flow, the trained weights and the other printed coefficients are all NaN.
Data format (pipe-delimited):
520221 | 0009 | 0009 | 292 | 000541875150 | 2018 | 04 | 18 | 11 | 3 | 137
520626 | 0038 | 0038 | 520626 | 203030001000 | 2018 | 04 | 18 | 3 | 119
520621 | 0024 | 0024 | 005 | 000530002050 | 2018 | 04 | 18 | 11 | 3 | 91
The last field is the label (traffic flow); all preceding fields are used as features.
2. The code is as follows:
package com.spark;
import java.util.Arrays;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.feature.StandardScaler;
import org.apache.spark.mllib.feature.StandardScalerModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;
public class CarPassRegression {
public static void main(String[] args){
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR);//
//
SparkConf conf= new SparkConf();
conf.setAppName("pass_regression").setMaster("local[*]")
.set("spark.sql.warehouse.dir","file:///");
JavaSparkContext sc = new JavaSparkContext(conf);
String trainDataPath ="E://test_data//target_carPass//traindata//*";
//
JavaRDD<String> rdd= sc.textFile(trainDataPath);
JavaRDD<LabeledPoint> traindata=rdd.map(new Function<String, LabeledPoint>() {
@Override
public LabeledPoint call(String s) throws Exception {
String [] part = s.split("\\|");
//label
double lable =Double.parseDouble(part[part.length-1]);
double [] features = new double[part.length-1];
for(int i=0;i<features.length;iPP){
features[i] =Double.parseDouble(part[i]);
}
return new LabeledPoint(lable, Vectors.dense(features));
}
});
traindata.cache();
/*
*/
int numIterations = 10000; //
double stepSize = 0.000001;//
final LinearRegressionModel model= LinearRegressionWithSGD.
train(JavaRDD.toRDD(traindata),numIterations,stepSize);
System.out.println(model.weights()); //
//
JavaRDD<Tuple2<Double, Double>> valuesAndPreds = traindata.map(
new Function<LabeledPoint, Tuple2<Double, Double>>(){
public Tuple2<Double, Double> call(LabeledPoint point){
double prediction = model.predict(point.features());
return new Tuple2<Double, Double>(prediction, point.label());
}
}
);
//
double MSE = new JavaDoubleRDD(valuesAndPreds.map(
new Function<Tuple2<Double, Double>, Object>(){
public Object call(Tuple2<Double, Double> pair){
return Math.pow(pair._1() - pair._2(), 2.0);
}
}
).rdd()).mean();
System.out.println("training MeanSquared Error = " + MSE);
//
JavaRDD<Tuple2<Object, Object>> valuesAndPreds2= traindata.map(new Function<LabeledPoint, Tuple2<Object, Object>>(){
public Tuple2<Object, Object> call(LabeledPoint point)
throws Exception {
double prediction = model.predict(point.features());
return new Tuple2<Object, Object>(prediction, point.label());
}
});
RegressionMetrics metrics = new RegressionMetrics(JavaRDD.toRDD(valuesAndPreds2));
System.out.println("R2()= "+metrics.r2());
System.out.println("MSE() = "+metrics.meanSquaredError());
System.out.println("RMSE() "+metrics.rootMeanSquaredError());
System.out.println("MAE()= "+metrics.meanAbsoluteError());
//
model.save(sc.sc(), "target/tmp/carPassLinearRegressionWithSGDModel");
LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
"target/tmp/carPassLinearRegressionWithSGDModel");
}
}
3. Result: