Skip to content

Commit

Permalink
Merge pull request #31 from MarcKaminski/master
Browse files Browse the repository at this point in the history
#29 Add attributes from inputCol to metadata of discretized column
  • Loading branch information
sramirez authored Jun 1, 2017
2 parents 2db9d58 + d1749a9 commit 6d1a8e5
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,36 @@ class DiscretizerModel private[ml] (
*/
@Since("2.1.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val newSchema = transformSchema(dataset.schema, logging = true)
val metadata = newSchema.fields.last.metadata
val discModel = new feature.DiscretizerModel(splits)
val discOp = udf ( (mlVector: Vector) => discModel.transform(OldVectors.fromML(mlVector)).asML )
dataset.withColumn($(outputCol), discOp(col($(inputCol))))
dataset.withColumn($(outputCol), discOp(col($(inputCol))).as($(outputCol), metadata))
}

override def transformSchema(schema: StructType): StructType = {
val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
val origFeatureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
origAttrGroup.attributes.get.zipWithIndex.map(_._1)
} else {
Array.fill[Attribute](origAttrGroup.size)(NominalAttribute.defaultAttr)
}

val buckets = splits.map(_.sliding(2).map(bucket => bucket.mkString(", ")).toArray)
val featureAttributes: Seq[org.apache.spark.ml.attribute.Attribute] =
for(i <- splits.indices) yield new NominalAttribute(

val newFeatureAttributes: Seq[org.apache.spark.ml.attribute.Attribute] = if (origAttrGroup.attributes.nonEmpty) {
for (i <- splits.indices) yield new NominalAttribute(
name = origFeatureAttributes(i).name,
index = origFeatureAttributes(i).index,
isOrdinal = Some(true),
values = Some(buckets(i)))
val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes.toArray)
} else {
for (i <- splits.indices) yield new NominalAttribute(
isOrdinal = Some(true),
values = Some(buckets(i)))
}

val newAttributeGroup = new AttributeGroup($(outputCol), newFeatureAttributes.toArray)
val outputFields = schema.fields :+ newAttributeGroup.toStructField()
StructType(outputFields)
}
Expand Down

0 comments on commit 6d1a8e5

Please sign in to comment.