Skip to content

Commit

Permalink
[jvm-packages] Support Ranker (#10823)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 21, 2024
1 parent d7599e0 commit 19b55b3
Show file tree
Hide file tree
Showing 6 changed files with 558 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
selectedCols.append(col)
}
val input = dataset.select(selectedCols.toArray: _*)
estimator.repartitionIfNeeded(input)
val repartitioned = estimator.repartitionIfNeeded(input)
estimator.sortPartitionIfNeeded(repartitioned)
}

// visible for testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

package ml.dmlc.xgboost4j.scala.spark

import ai.rapids.cudf.Table
import ai.rapids.cudf.{OrderByArg, Table}
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost}
import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite
import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.SparkConf

import java.io.File
Expand Down Expand Up @@ -94,7 +94,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
}

// spark.rapids.sql.enabled is not set explicitly, default to true
withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)})
withSparkSession(new SparkConf(), spark => {
checkIsEnabled(spark, true)
})

// set spark.rapids.sql.enabled to false
withCpuSparkSession() { spark =>
Expand Down Expand Up @@ -503,6 +505,109 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
}
}

test("The group col should be sorted in each partition") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3")

val xgboostParams: Map[String, Any] = Map(
"device" -> "cuda",
"objective" -> "rank:ndcg"
)
val features = Array("c1", "c2", "c3")
val label = "label"
val group = "group"

val ranker = new XGBoostRanker(xgboostParams)
.setFeaturesCol(features)
.setLabelCol(label)
.setNumWorkers(1)
.setNumRound(1)
.setGroupCol(group)
.setDevice("cuda")

val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
processedDf.rdd.foreachPartition { iter => {
var prevGroup = Int.MinValue
while (iter.hasNext) {
val curr = iter.next()
val group = curr.asInstanceOf[Row].getAs[Int](1)
assert(prevGroup <= group)
prevGroup = group
}
}
}
}
}

test("Ranker: XGBoost-Spark should match xgboost4j") {
withGpuSparkSession() { spark =>
import spark.implicits._

val trainPath = writeFile(Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3"))
val testPath = writeFile(Ranking.test.toDF("label", "weight", "group", "c1", "c2", "c3"))

val df = spark.read.parquet(trainPath)
val testdf = spark.read.parquet(testPath)

val features = Array("c1", "c2", "c3")
val featuresIndices = features.map(df.schema.fieldIndex)
val label = "label"
val group = "group"

val numRound = 100
val xgboostParams: Map[String, Any] = Map(
"device" -> "cuda",
"objective" -> "rank:ndcg"
)

val ranker = new XGBoostRanker(xgboostParams)
.setFeaturesCol(features)
.setLabelCol(label)
.setNumRound(numRound)
.setLeafPredictionCol("leaf")
.setContribPredictionCol("contrib")
.setGroupCol(group)
.setDevice("cuda")

val xgb4jModel = withResource(new GpuColumnBatch(
Table.readParquet(new File(trainPath)
).orderBy(OrderByArg.asc(df.schema.fieldIndex(group))))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices),
batch.select(df.schema.fieldIndex(label)), null, null,
batch.select(df.schema.fieldIndex(group)))
val qdm = new QuantileDMatrix(Seq(cb).iterator, ranker.getMissing,
ranker.getMaxBins, ranker.getNthread)
ScalaXGBoost.train(qdm, xgboostParams, numRound)
}

val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch(
Table.readParquet(new File(testPath)))) { batch =>
val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null
)
val qdm = new DMatrix(cb, ranker.getMissing, ranker.getNthread)
(xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm),
xgb4jModel.predict(qdm))
}

val rows = ranker.fit(df).transform(testdf).collect()

// Check Leaf
val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))
checkEqual(xgb4jLeaf, xgbSparkLeaf)

// Check contrib
val xgbSparkContrib = rows.map(row =>
row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))
checkEqual(xgb4jContrib, xgbSparkContrib)

// Check prediction
val xgbSparkPred = rows.map(row =>
Array(row.getAs[Double]("prediction").toFloat))
checkEqual(xgb4jPred, xgbSparkPred)
}
}

def writeFile(df: Dataset[_]): String = {
def listFiles(directory: String): Array[String] = {
val dir = new File(directory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ object Regression extends TrainTestData {
}

object Ranking extends TrainTestData {
val train = generateRankDataset(300, 10, 555)
val test = generateRankDataset(150, 10, 556)
val train = generateRankDataset(300, 10, 12, 555)
val test = generateRankDataset(150, 10, 12, 556)
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ private[spark] trait XGBoostEstimator[
}
}

/**
* Sort partition for Ranker issue.
* @param dataset
* @return
*/
private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
dataset
}

/**
* Build the columns indices.
*/
Expand Down Expand Up @@ -198,10 +207,10 @@ private[spark] trait XGBoostEstimator[
case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
case _ =>
}
val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))

val columnIndices = buildColumnIndices(input.schema)
(input, columnIndices)
val repartitioned = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
val sorted = sortPartitionIfNeeded(repartitioned)
val columnIndices = buildColumnIndices(sorted.schema)
(sorted, columnIndices)
}

/** visible for testing */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
Copyright (c) 2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
import org.apache.spark.ml.xgboost.SparkUtils
import org.apache.spark.sql.Dataset
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker._uid
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

class XGBoostRanker(override val uid: String,
private val xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostRanker, XGBoostRankerModel]
with XGBoostEstimator[XGBoostRanker, XGBoostRankerModel] with HasGroupCol {

def this() = this(_uid, Map[String, Any]())

def this(uid: String) = this(uid, Map[String, Any]())

def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)

def setGroupCol(value: String): XGBoostRanker = set(groupCol, value)

xgboost2SparkParams(xgboostParams)

/**
* Validate the parameters before training, throw exception if possible
*/
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)

require(isDefinedNonEmpty(groupCol), "groupCol needs to be set")

// If the objective is set explicitly, it must be in RANKER_OBJS
if (isSet(objective)) {
val tmpObj = getObjective
require(RANKER_OBJS.contains(tmpObj),
s"Wrong objective for XGBoostRanker, supported objs: ${RANKER_OBJS.mkString(",")}")
} else {
setObjective("rank:ndcg")
}
}

/**
* Sort partition for Ranker issue.
*
* @param dataset
* @return
*/
override private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]) = {
dataset.sortWithinPartitions(getGroupCol)
}

override protected def createModel(
booster: Booster,
summary: XGBoostTrainingSummary): XGBoostRankerModel = {
new XGBoostRankerModel(uid, booster, Option(summary))
}

override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType =
SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
}

object XGBoostRanker extends DefaultParamsReadable[XGBoostRanker] {
private val _uid = Identifiable.randomUID("xgbranker")
}

class XGBoostRankerModel private[ml](val uid: String,
val nativeBooster: Booster,
val summary: Option[XGBoostTrainingSummary] = None)
extends PredictionModel[Vector, XGBoostRankerModel]
with RankerRegressorBaseModel[XGBoostRankerModel] with HasGroupCol {

def this(uid: String) = this(uid, null)

def setGroupCol(value: String): XGBoostRankerModel = set(groupCol, value)

override def copy(extra: ParamMap): XGBoostRankerModel = {
val newModel = copyValues(new XGBoostRankerModel(uid, nativeBooster, summary), extra)
newModel.setParent(parent)
}

override def predict(features: Vector): Double = {
val values = predictSingleInstance(features)
values(0)
}
}

object XGBoostRankerModel extends MLReadable[XGBoostRankerModel] {
override def read: MLReader[XGBoostRankerModel] = new ModelReader

private class ModelReader extends XGBoostModelReader[XGBoostRankerModel] {
override def load(path: String): XGBoostRankerModel = {
val xgbModel = loadBooster(path)
val meta = SparkUtils.loadMetadata(path, sc)
val model = new XGBoostRankerModel(meta.uid, xgbModel, None)
meta.getAndSetParams(model)
model
}
}
}
Loading

0 comments on commit 19b55b3

Please sign in to comment.