Skip to content

Commit

Permalink
Some changes introduced to adapt the MDLP code to Spark 2.1.0. Some i…
Browse files Browse the repository at this point in the history
…mports have been rewritten, and MLlib Vectors transformed to OldVectors.
  • Loading branch information
sramirez committed May 24, 2017
1 parent 8202014 commit 2db9d58
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>1.6.2</version>
<version>2.1.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>1.6.2</version>
<version>2.1.0</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
Expand Down
2 changes: 1 addition & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object ProjectBuild extends Build {
organization := "org.apache.spark",
scalaVersion := "2.11.6",
spName := "apache/spark-MDLP-discretization",
sparkVersion := "1.6.2",
sparkVersion := "2.1.0",
sparkComponents += "mllib",
publishMavenStyle := true,
licenses += "Apache-2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0.html"),
Expand Down
37 changes: 23 additions & 14 deletions src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@ package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg._
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.Model
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* Params for [[MDLPDiscretizer]] and [[DiscretizerModel]].
Expand Down Expand Up @@ -121,18 +125,22 @@ class MDLPDiscretizer (override val uid: String) extends Estimator[DiscretizerMo
/**
* Computes a [[DiscretizerModel]] that contains the cut points (splits) for each input feature.
*/
override def fit(dataset: DataFrame): DiscretizerModel = {
@Since("2.1.0")
override def fit(dataset: Dataset[_]): DiscretizerModel = {

transformSchema(dataset.schema, logging = true)
val input = dataset.select($(labelCol), $(inputCol)).map {
case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.cache() // cache the input to avoid performance warning (see issue #18)
val discretizer = feature.MDLPDiscretizer.train(input, None, $(maxBins), $(maxByPart), $(stoppingCriterion), $(minBinPercentage))
val input: RDD[OldLabeledPoint] =
dataset.select(col($(labelCol)).cast(DoubleType), col($(inputCol))).rdd.map {
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}.cache() // cache the input to avoid performance warning (see issue #18)
val discretizer = feature.MDLPDiscretizer.train(
input, None, $(maxBins), $(maxByPart), $(stoppingCriterion), $(minBinPercentage))
copyValues(new DiscretizerModel(uid, discretizer.thresholds).setParent(this))
}

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down Expand Up @@ -177,17 +185,18 @@ class DiscretizerModel private[ml] (
* NOTE: Vectors to be transformed must be the same length
* as the source vectors given to MDLPDiscretizer.fit.
*/
override def transform(dataset: DataFrame): DataFrame = {
@Since("2.1.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val discModel = new feature.DiscretizerModel(splits)
val discOp = udf { discModel.transform _ }
val discOp = udf ( (mlVector: Vector) => discModel.transform(OldVectors.fromML(mlVector)).asML )
dataset.withColumn($(outputCol), discOp(col($(inputCol))))
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val buckets = splits.map(_.sliding(2).map(bucket => bucket.mkString(", ")).toArray)
val featureAttributes: Seq[attribute.Attribute] = for(i <- splits.indices) yield new NominalAttribute(
val featureAttributes: Seq[org.apache.spark.ml.attribute.Attribute] =
for(i <- splits.indices) yield new NominalAttribute(
isOrdinal = Some(true),
values = Some(buckets(i)))
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes.toArray)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.mllib.feature

import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.SparkContext
import org.apache.spark.rdd._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg._
Expand Down Expand Up @@ -292,7 +293,10 @@ class MDLPDiscretizer private (val data: RDD[LabeledPoint],
thresholds(k) = if (nFeatures > 0) vth.toArray else Array(Float.PositiveInfinity)
})
logInfo("Number of features with thresholds computed: " + allThresholds.length)
logDebug("thresholds = " + thresholds.map(_.mkString(", ")).mkString(";\n"))



("thresholds = " + thresholds.map(_.mkString(", ")).mkString(";\n"))

new DiscretizerModel(thresholds)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.feature

import org.apache.spark.Logging
import org.apache.spark.internal.Logging
import FeatureUtils._


Expand Down

0 comments on commit 2db9d58

Please sign in to comment.