From d1749a9b75dbe0bf66da5b4483c0c808ab9cccda Mon Sep 17 00:00:00 2001 From: MarcKaminski Date: Thu, 1 Jun 2017 13:53:11 +0200 Subject: [PATCH] #29 Add attributes from inputCol to metadata of discretized column --- .../spark/ml/feature/MDLPDiscretizer.scala | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala b/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala index 544f94d..4accff6 100644 --- a/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala +++ b/src/main/scala/org/apache/spark/ml/feature/MDLPDiscretizer.scala @@ -187,19 +187,36 @@ class DiscretizerModel private[ml] ( */ @Since("2.1.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + val newSchema = transformSchema(dataset.schema, logging = true) + val metadata = newSchema.fields.last.metadata val discModel = new feature.DiscretizerModel(splits) val discOp = udf ( (mlVector: Vector) => discModel.transform(OldVectors.fromML(mlVector)).asML ) - dataset.withColumn($(outputCol), discOp(col($(inputCol)))) + dataset.withColumn($(outputCol), discOp(col($(inputCol))).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { + val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) + val origFeatureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + origAttrGroup.attributes.get.zipWithIndex.map(_._1) + } else { + Array.fill[Attribute](origAttrGroup.size)(NominalAttribute.defaultAttr) + } + val buckets = splits.map(_.sliding(2).map(bucket => bucket.mkString(", ")).toArray) - val featureAttributes: Seq[org.apache.spark.ml.attribute.Attribute] = - for(i <- splits.indices) yield new NominalAttribute( + + val newFeatureAttributes: Seq[org.apache.spark.ml.attribute.Attribute] = if (origAttrGroup.attributes.nonEmpty) { + for (i <- splits.indices) yield new NominalAttribute( + name = origFeatureAttributes(i).name, + index = origFeatureAttributes(i).index, isOrdinal = Some(true), values = Some(buckets(i))) - val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes.toArray) + } else { + for (i <- splits.indices) yield new NominalAttribute( + isOrdinal = Some(true), + values = Some(buckets(i))) + } + + val newAttributeGroup = new AttributeGroup($(outputCol), newFeatureAttributes.toArray) val outputFields = schema.fields :+ newAttributeGroup.toStructField() StructType(outputFields) }