From 93d78d3d2854c361462843555742bf5ca7f28674 Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Wed, 8 Aug 2018 10:52:47 -0700
Subject: [PATCH] add the confidence interval computation

---
 .../data/RandomEffectDataSetIntegTest.scala   |   4 +-
 .../photon/ml/data/LocalDataSet.scala         | 161 +++++++++++++++++-
 .../photon/ml/data/RandomEffectDataSet.scala  |  30 +++-
 .../photon/ml/estimators/GameEstimator.scala  |  19 ++-
 .../photon/ml/data/LocalDataSetTest.scala     | 113 +++++++++++-
 .../game/training/GameTrainingDriver.scala    |   7 +
 .../scala/com/linkedin/photon/ml/Types.scala  |   4 +
 7 files changed, 323 insertions(+), 15 deletions(-)

diff --git a/photon-api/src/integTest/scala/com/linkedin/photon/ml/data/RandomEffectDataSetIntegTest.scala b/photon-api/src/integTest/scala/com/linkedin/photon/ml/data/RandomEffectDataSetIntegTest.scala
index 0ed7305c..92385874 100644
--- a/photon-api/src/integTest/scala/com/linkedin/photon/ml/data/RandomEffectDataSetIntegTest.scala
+++ b/photon-api/src/integTest/scala/com/linkedin/photon/ml/data/RandomEffectDataSetIntegTest.scala
@@ -123,7 +123,7 @@ class RandomEffectDataSetIntegTest extends SparkTestUtils {
       Some(activeDataLowerBound))
     val partitioner = new RandomEffectDataSetPartitioner(sc.broadcast(partitionMap))
 
-    val randomEffectDataSet = RandomEffectDataSet(rdd, randomEffectDataConfig, partitioner, None)
+    val randomEffectDataSet = RandomEffectDataSet(rdd, randomEffectDataConfig, partitioner, None, None)
     val numUniqueRandomEffects = randomEffectDataSet.activeData.keys.count()
 
     assertEquals(numUniqueRandomEffects, expectedUniqueRandomEffects)
@@ -155,7 +155,7 @@ class RandomEffectDataSetIntegTest extends SparkTestUtils {
       Some(activeDataLowerBound))
     val partitioner = new RandomEffectDataSetPartitioner(sc.broadcast(partitionMap))
 
-    val randomEffectDataSet = RandomEffectDataSet(rdd, randomEffectDataConfig, partitioner, Some(existingIdsRDD))
+    val randomEffectDataSet = RandomEffectDataSet(rdd, randomEffectDataConfig, partitioner, Some(existingIdsRDD), None)
     val numUniqueRandomEffects = randomEffectDataSet.activeData.keys.count()
 
     assertEquals(numUniqueRandomEffects, expectedUniqueRandomEffects)
diff --git a/photon-api/src/main/scala/com/linkedin/photon/ml/data/LocalDataSet.scala b/photon-api/src/main/scala/com/linkedin/photon/ml/data/LocalDataSet.scala
index 64f15def..83201af0 100644
--- a/photon-api/src/main/scala/com/linkedin/photon/ml/data/LocalDataSet.scala
+++ b/photon-api/src/main/scala/com/linkedin/photon/ml/data/LocalDataSet.scala
@@ -18,11 +18,11 @@ import java.util.Random
 
 import scala.collection.{Map, Set, mutable}
 import scala.reflect.ClassTag
-
 import breeze.linalg.{SparseVector, Vector}
 
 import com.linkedin.photon.ml.constants.MathConst
 import com.linkedin.photon.ml.projector.Projector
+import com.linkedin.photon.ml.stat.BasicStatisticalSummary
 
 /**
  * Local dataset implementation.
@@ -126,6 +126,36 @@ protected[ml] case class LocalDataSet(dataPoints: Array[(Long, LabeledPoint)]) {
     }
   }
 
+  /**
+   *
+   * Filter features using the ratio confidence interval lower bound.
+   *
+   * @param intervalBound The lower bound threshold; only features whose confidence interval lower bound exceeds it are kept
+   * @param percentage The standard normal quantile (z-score) used to construct the confidence interval
+   * @param globalFeatureShardStats The feature statistics of the global population for this feature shard
+   * @return The filtered dataset
+   */
+
+  def filterFeaturesByRatioCIBound(
+      intervalBound: Double,
+      percentage: Double,
+      globalFeatureShardStats: BasicStatisticalSummary): LocalDataSet = {
+    val labelAndFeatures = dataPoints.map { case (_, labeledPoint) => (labeledPoint.label, labeledPoint.features) }
+    val lowerBounds = LocalDataSet.computeRatioCILowerBound(labelAndFeatures, percentage, globalFeatureShardStats)
+    val filteredFeaturesIndexSet = lowerBounds
+      .toArray
+      .filter(_._2 > intervalBound)
+      .map(_._1)
+      .toSet
+
+    val filteredActivities = dataPoints.map { case (id, LabeledPoint(label, features, offset, weight)) =>
+      val filteredFeatures = LocalDataSet.filterFeaturesWithFeatureIndexSet(features, filteredFeaturesIndexSet)
+      (id, LabeledPoint(label, filteredFeatures, offset, weight))
+    }
+    LocalDataSet(filteredActivities)
+  }
+
+
   /**
    * Filter features by Pearson correlation score.
    *
@@ -212,6 +242,135 @@ object LocalDataSet {
 
     new SparseVector(filteredIndexBuilder.result(), filteredDataBuilder.result(), features.length)
   }
 
+  /**
+   * Compute Ratio Confidence Interval lower bounds.
+   *
+   * @param randomLabelAndFeatures An array of (label, feature) tuples
+   * @param quartile The standard normal quantile (z-score) corresponding to the desired confidence level
+   * @param globalFeatureShardStats The global population feature statistics used to compute the ratios
+   */
+
+  protected[ml] def computeRatioCILowerBound(
+      randomLabelAndFeatures: Array[(Double, Vector[Double])],
+      quartile: Double,
+      globalFeatureShardStats: BasicStatisticalSummary): Map[Int, Double] = {
+
+    val dummyBoundForNonBinaryOrIntercept = 2.0
+    val lowerBounds = mutable.Map[Int, Double]()
+    val globalPopulationNumSamples = globalFeatureShardStats.count
+    val globalPopulationFirstOrderMeans = globalFeatureShardStats.mean.toArray
+    val globalFeatureNonZero = globalFeatureShardStats.numNonzeros.toArray
+    val globalMean = globalFeatureShardStats.mean.toArray
+
+    val randomEffectNumSamples: Long = randomLabelAndFeatures.length.toLong
+    val randomFeatureFirstOrderSums = randomLabelAndFeatures.map(_._2).toSeq.reduce(_ + _)
+
+    val m: Long = randomEffectNumSamples
+    val n: Long = globalPopulationNumSamples
+    val lastColumn = randomFeatureFirstOrderSums.keySet.max
+
+    randomFeatureFirstOrderSums.keySet.foreach { key =>
+      // Only compute the ratio for binary, non-intercept features
+      if (isBinary(globalFeatureNonZero, globalPopulationNumSamples, globalMean, key) && key != lastColumn) {
+        var x: Double = randomFeatureFirstOrderSums(key)
+        var y: Double = globalPopulationFirstOrderMeans(key) * globalPopulationNumSamples
+        var lowerBound = None: Option[Double]
+        var upperBound = None: Option[Double]
+        var py = None: Option[Double]
+        // Handle the extreme cases of x and y before computing the interval
+        (x, y) match {
+          case d if (d._1 == 0 && d._2 == 0) => {
+            lowerBound = Some(0.0)
+            upperBound = Some(Double.PositiveInfinity)
+          }
+          case d if (d._1 == 0 && d._2 != 0) => {
+            lowerBound = Some(0.0)
+            x = 0.5
+          }
+          case d if (d._1 != 0 && d._2 == 0) => {
+            y = 0.5
+            py = Some(y / n)
+            upperBound = Some(Double.PositiveInfinity)
+          }
+          case d if (d._1 == m && d._2 == n) => {
+            x = m - 0.5
+            y = n - 0.5
+            py = Some(y / n)
+          }
+          case _ => // no adjustment needed
+        }
+        // We already have the global mean from the statistics, so reuse it here
+        py match {
+          case Some(_) => // already set by the extreme-case handling above
+          case None => py = Some(globalPopulationFirstOrderMeans(key))
+        }
+        val (t, variance) = computeMeanAndVariance(x, y, m, n, py.get)
+
+        upperBound match {
+          case Some(_) => // already set by the extreme-case handling above
+          case None => {
+            upperBound = Some(computeUpperBound(t, variance, quartile))
+          }
+        }
+
+        lowerBound match {
+          case Some(_) => // already set by the extreme-case handling above
+          case None => {
+            lowerBound = Some(computeLowerBound(t, variance, quartile))
+            if (t < 1.0) {
+              lowerBound = Some(1.0 / upperBound.get)
+            }
+          }
+        }
+
+        lowerBounds.update(key, lowerBound.get)
+      } else {
+        // Non-binary and intercept features cannot be ranked by this ratio test,
+        // so assign a dummy bound that always passes the filter
+
+        lowerBounds.update(key, dummyBoundForNonBinaryOrIntercept)
+      }
+    }
+
+    lowerBounds
+  }
+
+  protected[ml] def computeMeanAndVariance(
+      x: Double,
+      y: Double,
+      m: Long,
+      n: Long,
+      py: Double
+    ): (Double, Double) = {
+    val t = (x / m) / py
+    val variance = 1.0 / x - 1.0 / m + 1.0 / y - 1.0 / n
+    (t, variance)
+  }
+
+  protected[ml] def computeLowerBound(
+      t: Double,
+      variance: Double,
+      quartile: Double
+    ): Double = {
+    t * math.exp(-math.sqrt(variance) * quartile)
+  }
+
+  protected[ml] def computeUpperBound(
+      t: Double,
+      variance: Double,
+      quartile: Double
+    ): Double = {
+    t * math.exp(math.sqrt(variance) * quartile)
+  }
+
+  protected[ml] def isBinary(
+      globalFeatureNonZero: Array[Double],
+      globalPopulationNumSamples: Long,
+      globalMean: Array[Double],
+      key: Int
+    ): Boolean = {
+    (globalFeatureNonZero(key) * 1.0) / globalPopulationNumSamples == globalMean(key)
+  }
 
   /**
    * Compute Pearson correlation scores.
   *
diff --git a/photon-api/src/main/scala/com/linkedin/photon/ml/data/RandomEffectDataSet.scala b/photon-api/src/main/scala/com/linkedin/photon/ml/data/RandomEffectDataSet.scala
index 1f1f792d..c7e6f424 100644
--- a/photon-api/src/main/scala/com/linkedin/photon/ml/data/RandomEffectDataSet.scala
+++ b/photon-api/src/main/scala/com/linkedin/photon/ml/data/RandomEffectDataSet.scala
@@ -16,13 +16,12 @@ package com.linkedin.photon.ml.data
 
 import scala.collection.Set
 import scala.util.hashing.byteswap64
-
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{StorageLevel => SparkStorageLevel}
 import org.apache.spark.{Partitioner, SparkContext}
 
-import com.linkedin.photon.ml.Types.{FeatureShardId, REId, REType, UniqueSampleId}
+import com.linkedin.photon.ml.Types.{FeatureShardId, FeatureShardStatisticsMapOpt, REId, REType, UniqueSampleId}
 import com.linkedin.photon.ml.constants.StorageLevel
 import com.linkedin.photon.ml.data.scoring.CoordinateDataScores
 import com.linkedin.photon.ml.spark.{BroadcastLike, RDDLike}
@@ -240,7 +239,8 @@ object RandomEffectDataSet {
       gameDataSet: RDD[(UniqueSampleId, GameDatum)],
       randomEffectDataConfiguration: RandomEffectDataConfiguration,
       randomEffectPartitioner: Partitioner,
-      existingModelKeysRddOpt: Option[RDD[REId]]): RandomEffectDataSet = {
+      existingModelKeysRddOpt: Option[RDD[REId]],
+      globalFeatureStats: FeatureShardStatisticsMapOpt): RandomEffectDataSet = {
 
     val randomEffectType = randomEffectDataConfiguration.randomEffectType
     val featureShardId = randomEffectDataConfiguration.featureShardId
@@ -252,7 +252,7 @@ object RandomEffectDataSet {
       randomEffectDataConfiguration,
       randomEffectPartitioner,
       existingModelKeysRddOpt)
-    val activeData = featureSelectionOnActiveData(rawActiveData, randomEffectDataConfiguration)
+    val activeData = featureSelectionOnActiveData(rawActiveData, randomEffectDataConfiguration, globalFeatureStats)
       .setName("Active data")
       .persist(StorageLevel.INFREQUENT_REUSE_RDD_STORAGE_LEVEL)
 
@@ -488,7 +488,8 @@
    */
   private def featureSelectionOnActiveData(
       activeData: RDD[(REId, LocalDataSet)],
-      randomEffectDataConfiguration: RandomEffectDataConfiguration): RDD[(REId, LocalDataSet)] = {
+      randomEffectDataConfiguration: RandomEffectDataConfiguration,
+      globalFeatureStats: FeatureShardStatisticsMapOpt): RDD[(REId, LocalDataSet)] = {
 
     randomEffectDataConfiguration
       .numFeaturesToSamplesRatioUpperBound
@@ -498,9 +499,24 @@
       .map { numFeaturesToSamplesRatioUpperBound =>
         var numFeaturesToKeep = math.ceil(numFeaturesToSamplesRatioUpperBound * localDataSet.numDataPoints).toInt
         // In case the above product overflows
         if (numFeaturesToKeep < 0) numFeaturesToKeep = Int.MaxValue
-        val filteredLocalDataSet = localDataSet.filterFeaturesByPearsonCorrelationScore(numFeaturesToKeep)
+        globalFeatureStats match {
+          case Some(featureShardStatsMap) => {
+            val featureShardId = randomEffectDataConfiguration.featureShardId
+            val globalFeatureShardStats = featureShardStatsMap(featureShardId)
+            val filteredLocalDataSet = localDataSet.filterFeaturesByRatioCIBound(
+              1.0, // threshold on the confidence interval lower bound
+              2.575, // standard normal quantile for a 99% confidence interval
+              globalFeatureShardStats)
+
+            filteredLocalDataSet
+          }
+          case None => {
+            val filteredLocalDataSet = localDataSet.filterFeaturesByPearsonCorrelationScore(numFeaturesToKeep)
+
+            filteredLocalDataSet
+          }
+        }
 
-        filteredLocalDataSet
       }
     }
     .getOrElse(activeData)
diff --git a/photon-api/src/main/scala/com/linkedin/photon/ml/estimators/GameEstimator.scala b/photon-api/src/main/scala/com/linkedin/photon/ml/estimators/GameEstimator.scala
index f97bffb0..3e93990b 100644
--- a/photon-api/src/main/scala/com/linkedin/photon/ml/estimators/GameEstimator.scala
+++ b/photon-api/src/main/scala/com/linkedin/photon/ml/estimators/GameEstimator.scala
@@ -15,17 +15,15 @@ package com.linkedin.photon.ml.estimators
 
 import scala.language.existentials
-
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators, Params}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.slf4j.Logger
-
 import com.linkedin.photon.ml.TaskType
 import com.linkedin.photon.ml.TaskType.TaskType
-import com.linkedin.photon.ml.Types.{CoordinateId, FeatureShardId, UniqueSampleId}
+import com.linkedin.photon.ml.Types.{CoordinateId, FeatureShardId, FeatureShardStatisticsMapOpt, UniqueSampleId}
 import com.linkedin.photon.ml.algorithm._
 import com.linkedin.photon.ml.constants.{MathConst, StorageLevel}
 import com.linkedin.photon.ml.data._
@@ -43,6 +41,7 @@ import com.linkedin.photon.ml.supervised.classification.{LogisticRegressionModel
 import com.linkedin.photon.ml.supervised.regression.{LinearRegressionModel, PoissonRegressionModel}
 import com.linkedin.photon.ml.util._
 
+
 /**
  * Estimator implementation for GAME models.
 *
@@ -57,6 +56,7 @@ class GameEstimator(val sc: SparkContext, implicit val logger: Logger) extends P
 
   type SingleNodeLossFunctionConstructor = (PointwiseLossFunction) => SingleNodeGLMLossFunction
   type DistributedLossFunctionConstructor = (PointwiseLossFunction) => DistributedGLMLossFunction
 
+  private implicit val parent: Identifiable = this
   private val defaultNormalizationContext: NormalizationContextWrapper =
     NormalizationContextBroadcast(sc.broadcast(NoNormalization()))
@@ -129,6 +129,10 @@ class GameEstimator(val sc: SparkContext, implicit val logger: Logger) extends P
     "Flag to ignore the random effect samples lower bound when encountering a random effect ID without an existing " +
       "model during warm-start training.")
 
+  val featureShardStats: Param[FeatureShardStatisticsMapOpt] = ParamUtils.createParam[FeatureShardStatisticsMapOpt](
+    "feature shard statistics",
+    "The global population statistics for each feature shard, used to compute the ratio of random effect to " +
+      "global feature occurrence rates when selecting features for per-entity models.")
   //
   // Initialize object
   //
@@ -166,6 +170,7 @@ class GameEstimator(val sc: SparkContext, implicit val logger: Logger) extends P
 
   def setIgnoreThresholdForNewModels(value: Boolean): this.type = set(ignoreThresholdForNewModels, value)
 
+  def setFeatureStats(value: FeatureShardStatisticsMapOpt): this.type = set(featureShardStats, value)
   //
   // Params trait extensions
   //
@@ -192,6 +197,7 @@ class GameEstimator(val sc: SparkContext, implicit val logger: Logger) extends P
     setDefault(computeVariance, false)
     setDefault(treeAggregateDepth, DEFAULT_TREE_AGGREGATE_DEPTH)
     setDefault(ignoreThresholdForNewModels, false)
+    setDefault(featureShardStats, None: FeatureShardStatisticsMapOpt)
   }
 
   /**
@@ -535,7 +541,12 @@ class GameEstimator(val sc: SparkContext, implicit val logger: Logger) extends P
         None
       }
 
-      val rawRandomEffectDataSet = RandomEffectDataSet(gameDataSet, reConfig, partitioner, existingModelKeysRddOpt)
+      val rawRandomEffectDataSet = RandomEffectDataSet(
+        gameDataSet,
+        reConfig,
+        partitioner,
+        existingModelKeysRddOpt,
+        getOrDefault(featureShardStats))
         .setName(s"Random Effect Data Set: $coordinateId")
         .persistRDD(StorageLevel.INFREQUENT_REUSE_RDD_STORAGE_LEVEL)
         .materialize()
diff --git a/photon-api/src/test/scala/com/linkedin/photon/ml/data/LocalDataSetTest.scala b/photon-api/src/test/scala/com/linkedin/photon/ml/data/LocalDataSetTest.scala
index 25fdd8da..875439f7 100644
--- a/photon-api/src/test/scala/com/linkedin/photon/ml/data/LocalDataSetTest.scala
+++ b/photon-api/src/test/scala/com/linkedin/photon/ml/data/LocalDataSetTest.scala
@@ -19,9 +19,12 @@ import java.util.Random
 
 import breeze.linalg.{SparseVector, Vector}
 import org.testng.Assert._
 import org.testng.annotations.Test
-
 import com.linkedin.photon.ml.constants.MathConst
+import com.linkedin.photon.ml.stat.BasicStatisticalSummary
 import com.linkedin.photon.ml.test.CommonTestUtils
+import com.linkedin.photon.ml.util.VectorUtils
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector => SparkMLVector}
 
 /**
  *
@@ -64,6 +67,59 @@ class LocalDataSetTest {
     }
   }
 
+
+  @Test(dependsOnGroups = Array[String]("testComputeRatioCILowerBound", "testCore"))
+  def testFilterFeaturesByRatioCIBound(): Unit = {
+    val numSamples = 4
+    val numFeatures = 8
+    val labels = Array(1.0, 4.0, 6.0, 9.0)
+    // 8 columns of features (7 test columns plus an intercept):
+    // col 1 for x = 0, y = 0
+    // col 2 for x = 0, y != 0
+    // col 3 for a non-binary feature
+    // col 4 for x != 0, y = 0
+    // col 5 for x = m, y = n
+    // col 6 for t > 1
+    // col 7 for t < 1
+    // col 8 for the intercept
+    val features = Array(
+      Vector(0.0, 0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 0.0, 1.0, 1.0, 1.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 0.0, 1.0, 0.0, 0.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 1.0, 1.0, 0.0, 0.0, 1.0)
+    )
+    val expected = Map(
+      0 -> 0.0,
+      1 -> 0.0,
+      2 -> 2.0,
+      3 -> 0.22921243,
+      4 -> 0.64467032,
+      5 -> 0.24526171,
+      6 -> 0.28175655,
+      7 -> 2.0)
+    val labelAndFeatures = labels.zip(features)
+
+    val globalStats = BasicStatisticalSummary(
+      mean = Vector(0.0, 0.3, 2.0, 0.0, 1.0, 0.4, 0.6, 1.0),
+      variance = Vector(0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 0.0),
+      count = 10,
+      numNonzeros = Vector(0.0, 3.0, 2.0, 0.0, 10.0, 4.0, 6.0, 10.0),
+      max = Vector(0.0, 0.0, 3.0, 1.0, 0.0, 1.0, 1.0, 1.0),
+      min = Vector(0.0, 0.0, 3.0, 1.0, 0.0, 1.0, 1.0, 1.0),
+      normL1 = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0),
+      normL2 = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0),
+      meanAbs = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0)
+    )
+    val localDataSet =
+      LocalDataSet(
+        Array.tabulate(numSamples)(i => (i.toLong, LabeledPoint(labels(i), features(i), offset = 0.0, weight = 1.0))))
+
+    val filteredDataPoints0 = localDataSet.filterFeaturesByRatioCIBound(1.0, 2.575, globalStats).dataPoints
+    assertEquals(filteredDataPoints0.length, numSamples)
+    assertTrue(filteredDataPoints0.forall(_._2.features.activeSize == 2))
+  }
+
+
   @Test(dependsOnGroups = Array[String]("testPearsonCorrelationScore", "testCore"))
   def testFilterFeaturesByPearsonCorrelationScore(): Unit = {
@@ -192,6 +248,61 @@
     }
   }
 
+  /**
+   * Test the ratio confidence interval lower bound computation
+   */
+  @Test(groups = Array[String]("testComputeRatioCILowerBound", "testCore"))
+  def testComputeRatioCILowerBound(): Unit = {
+    val labels = Array(1.0, 4.0, 6.0, 9.0)
+    // 8 columns of features (7 test columns plus an intercept):
+    // col 1 for x = 0, y = 0
+    // col 2 for x = 0, y != 0
+    // col 3 for a non-binary feature
+    // col 4 for x != 0, y = 0
+    // col 5 for x = m, y = n
+    // col 6 for t > 1
+    // col 7 for t < 1
+    // col 8 for the intercept
+    val features = Array(
+      Vector(0.0, 0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 0.0, 1.0, 1.0, 1.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 0.0, 1.0, 0.0, 0.0, 1.0),
+      Vector(0.0, 0.0, 3.0, 1.0, 1.0, 0.0, 0.0, 1.0)
+    )
+    val expected = Map(
+      0 -> 0.0,
+      1 -> 0.0,
+      2 -> 2.0,
+      3 -> 0.22921243,
+      4 -> 0.64467032,
+      5 -> 0.24526171,
+      6 -> 0.28175655,
+      7 -> 2.0)
+    val labelAndFeatures = labels.zip(features)
+
+    val globalStats = BasicStatisticalSummary(
+      mean = Vector(0.0, 0.3, 2.0, 0.0, 1.0, 0.4, 0.6, 1.0),
+      variance = Vector(0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 0.0),
+      count = 10,
+      numNonzeros = Vector(0.0, 3.0, 2.0, 0.0, 10.0, 4.0, 6.0, 10.0),
+      max = Vector(0.0, 0.0, 3.0, 1.0, 0.0, 1.0, 1.0, 1.0),
+      min = Vector(0.0, 0.0, 3.0, 1.0, 0.0, 1.0, 1.0, 1.0),
+      normL1 = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0),
+      normL2 = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0),
+      meanAbs = Vector(0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0)
+    )
+
+    val computed = LocalDataSet.computeRatioCILowerBound(labelAndFeatures, 2.575, globalStats)
+    assertEquals(computed.size, expected.size)
+    computed.foreach { case (key, value) =>
+      assertEquals(
+        value,
+        expected(key),
+        CommonTestUtils.LOW_PRECISION_TOLERANCE,
+        s"Computed ratio confidence interval lower bound is $value, while the expected value is ${expected(key)} for key $key.")
+    }
+  }
+
 
   /**
    * Test the Pearson correlation score
   */
diff --git a/photon-client/src/main/scala/com/linkedin/photon/ml/cli/game/training/GameTrainingDriver.scala b/photon-client/src/main/scala/com/linkedin/photon/ml/cli/game/training/GameTrainingDriver.scala
index 4a1887f4..5ea0bdcb 100644
--- a/photon-client/src/main/scala/com/linkedin/photon/ml/cli/game/training/GameTrainingDriver.scala
+++ b/photon-client/src/main/scala/com/linkedin/photon/ml/cli/game/training/GameTrainingDriver.scala
@@ -420,6 +420,12 @@
       calculateAndSaveFeatureShardStats(trainingData, featureIndexMapLoaders)
     }
 
+    // Convert the per-shard statistics, when they were computed, into the map form expected by the estimator
+    val featureShardStatsInMap: FeatureShardStatisticsMapOpt = featureShardStats match {
+      case Some(stats) => Some(stats.toMap[FeatureShardId, BasicStatisticalSummary])
+      case None => None
+    }
+
     val normalizationContexts = Timed("Prepare normalization contexts") {
       prepareNormalizationContexts(trainingData, featureIndexMapLoaders, featureShardStats)
     }
@@ -441,6 +447,7 @@
       .setCoordinateDescentIterations(getRequiredParam(coordinateDescentIterations))
       .setComputeVariance(getOrDefault(computeVariance))
       .setIgnoreThresholdForNewModels(getOrDefault(ignoreThresholdForNewModels))
+      .setFeatureStats(featureShardStatsInMap)
 
     get(inputColumnNames).foreach(estimator.setInputColumnNames)
     modelOpt.foreach(estimator.setInitialModel)
diff --git a/photon-lib/src/main/scala/com/linkedin/photon/ml/Types.scala b/photon-lib/src/main/scala/com/linkedin/photon/ml/Types.scala
index 282e47f2..9f06fe8b 100644
--- a/photon-lib/src/main/scala/com/linkedin/photon/ml/Types.scala
+++ b/photon-lib/src/main/scala/com/linkedin/photon/ml/Types.scala
@@ -14,6 +14,8 @@
  */
 package com.linkedin.photon.ml
 
+import com.linkedin.photon.ml.stat.BasicStatisticalSummary
+
 /**
  * Some types that make the code easier to read and more documented. This class should be visible from everywhere in
  * photon-ml.
@@ -41,4 +43,6 @@
   // A "feature shard" is an arbitrary set of "feature bags"
   // A random effect model corresponds to a single feature shard
   type FeatureShardId = String
+  type FeatureShardStatisticsMap = Map[FeatureShardId, BasicStatisticalSummary]
+  type FeatureShardStatisticsMapOpt = Option[FeatureShardStatisticsMap]
 }
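For reference, the math introduced by computeRatioCILowerBound, written as a standalone, runnable sketch. This is not part of the patch; the object name, method name, and parameter names below are illustrative only. It assumes a binary feature that is set in x of the m per-entity (random effect) samples and in y of the n global-population samples, with x > 0 and y > 0 (the patch handles the zero and boundary cases separately before applying these formulas).

object RatioCILowerBoundSketch {

  /**
   * Confidence interval lower bound on the ratio between the per-entity and global occurrence rates.
   *
   * @param x number of per-entity samples in which the feature is set
   * @param m total number of per-entity samples
   * @param y number of global-population samples in which the feature is set
   * @param n total number of global-population samples
   * @param z standard normal quantile, e.g. 2.575 for a 99% two-sided interval
   */
  def ratioCILowerBound(x: Double, m: Long, y: Double, n: Long, z: Double): Double = {
    // Point estimate of the ratio between the per-entity rate x / m and the global rate y / n
    val t = (x / m) / (y / n)
    // Approximate variance of log(t), mirroring computeMeanAndVariance in LocalDataSet.scala
    val variance = 1.0 / x - 1.0 / m + 1.0 / y - 1.0 / n
    val lower = t * math.exp(-math.sqrt(variance) * z)
    val upper = t * math.exp(math.sqrt(variance) * z)
    // As in the patch, a ratio below 1 is ranked by the reciprocal of its upper bound, so that
    // under-represented features are scored symmetrically to over-represented ones
    if (t < 1.0) 1.0 / upper else lower
  }

  def main(args: Array[String]): Unit = {
    // Column 6 of the unit test above: set in 2 of 4 per-entity samples and 4 of 10 global samples
    println(ratioCILowerBound(x = 2.0, m = 4L, y = 4.0, n = 10L, z = 2.575)) // ~0.2453
    // Column 7: set in 2 of 4 per-entity samples and 6 of 10 global samples; t < 1, so 1 / upper bound
    println(ratioCILowerBound(x = 2.0, m = 4L, y = 6.0, n = 10L, z = 2.575)) // ~0.2818
  }
}

End to end, the patch wires this up as follows: GameTrainingDriver converts the per-shard statistics produced by calculateAndSaveFeatureShardStats into a FeatureShardStatisticsMap and passes it to GameEstimator.setFeatureStats; RandomEffectDataSet forwards it to featureSelectionOnActiveData, which keeps a feature only when its confidence interval lower bound exceeds the threshold (1.0 above) and falls back to the existing Pearson correlation filter when no statistics are supplied.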