Add MMoE; support multi-label LIBSVM and multi-label AUC #106

Open · wants to merge 1 commit into base: branch-0.2.0
12 changes: 9 additions & 3 deletions cpp/src/angel/pytorch/angel_torch.cc
@@ -268,6 +268,12 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
std::vector<torch::jit::IValue> inputs;
std::vector<std::pair<std::string, void *>> ptrs;

int multi_forward_out = 1;
if (angel::jni_map_contain(env, jparams, "multi_forward_out")) {
multi_forward_out =
angel::jni_map_get_int(env, jparams, "multi_forward_out");
}

int batch_size = angel::jni_map_get_int(env, jparams, "batch_size");
// data inputs
inputs.emplace_back(batch_size);
@@ -282,7 +288,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
}
auto output = ptr->serving_forward(inputs);
auto output_ptr = output.data_ptr();
DEFINE_JFLOATARRAY(output_ptr, batch_size);
DEFINE_JFLOATARRAY(output_ptr, batch_size * multi_forward_out);

// release java arrays
release_array(env, ptrs, jparams);
@@ -291,7 +297,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_tencent_angel_pytorch_Torch_forward
add_inputs(env, &inputs, &ptrs, jparams, ptr->get_type());
auto output = ptr->forward(inputs).toTensor();
auto output_ptr = output.data_ptr();
DEFINE_JFLOATARRAY(output_ptr, batch_size);
DEFINE_JFLOATARRAY(output_ptr, batch_size * multi_forward_out);

// release java arrays
release_array(env, ptrs, jparams);
@@ -603,4 +609,4 @@ JNIEXPORT void JNICALL Java_com_tencent_angel_pytorch_Torch_gcnSave
ptr->save(path);
env->ReleaseStringUTFChars(jpath, path);
release_array(env, ptrs, jparams);
}
}
@@ -50,16 +50,17 @@ public static Tuple3<CooLongFloatMatrix, long[], String[]> parsePredict(String[]
private static Tuple2<CooLongFloatMatrix, float[]> parseLIBSVM(String[] lines) {
LongArrayList rows = new LongArrayList();
LongArrayList cols = new LongArrayList();
LongArrayList fields = null;
FloatArrayList vals = new FloatArrayList();
float[] targets = new float[lines.length];
FloatArrayList targets = new FloatArrayList();

int index = 0;
for (int i = 0; i < lines.length; i++) {
String[] parts = lines[i].split(" ");
float label = Float.parseFloat(parts[0]);
targets[i] = label;

String[] labels = parts[0].split("#");
for (int l = 0; l < labels.length; l += 1) {
float label = Float.parseFloat(labels[l]);
targets.add(label);
}
for (int j = 1; j < parts.length; j++) {
String[] kv = parts[j].split(":");
long key = Long.parseLong(kv[0]) - 1;
@@ -75,8 +76,7 @@ private static Tuple2<CooLongFloatMatrix, float[]> parseLIBSVM(String[] lines) {

CooLongFloatMatrix coo = MFactory.cooLongFloatMatrix(rows.toLongArray(),
cols.toLongArray(), vals.toFloatArray(), null);

return new Tuple2<CooLongFloatMatrix, float[]>(coo, targets);
return new Tuple2<CooLongFloatMatrix, float[]>(coo, targets.toFloatArray());
}

private static Tuple3<CooLongFloatMatrix, long[], float[]> parseLIBFFM(String[] lines) {
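
For context, a sketch of the multi-label LIBSVM line format this parser change accepts: the first token carries several labels joined by '#', followed by the usual 1-based index:value pairs. The sample values below are made up, and the Scala snippet only mirrors the Java logic above; it is not part of the patch.

// Hypothetical multi-label LIBSVM lines (labels joined by '#', then index:value features):
//   0#1#0 3:1.0 7:0.5 21:2.0
//   1#0#1 2:0.3 9:1.0
val line = "0#1#0 3:1.0 7:0.5 21:2.0"
val parts = line.split(" ")
val labels = parts(0).split("#").map(_.toFloat)   // Array(0.0f, 1.0f, 0.0f), one label per task
val features = parts.drop(1).map { kv =>
  val Array(k, v) = kv.split(":")
  (k.toLong - 1, v.toFloat)                       // convert the 1-based index to 0-based
}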
22 changes: 22 additions & 0 deletions java/src/main/java/com/tencent/angel/pytorch/torch/TorchModel.java
@@ -34,6 +34,7 @@ public class TorchModel implements Serializable {

// load library of torch and torch_angel
static {
System.loadLibrary("torch");
System.loadLibrary("torch_angel");
}

@@ -242,6 +243,27 @@ public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, fl
return Torch.forward(ptr, params, false);
}

public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] embeddings, int embeddingDim, float[] mats, int[] matSizes, int multiForwardOut) {
Map<String, Object> params = buildParams(batchSize, batch, bias, weights);
params.put("embedding", embeddings);
params.put("embedding_dim", embeddingDim);
params.put("mats", mats);
params.put("mats_sizes", matSizes);
params.put("multi_forward_out", multiForwardOut);
return Torch.forward(ptr, params, false);
}

public float[] forward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] embeddings, int embeddingDim, float[] mats, int[] matSizes, long[] fields, int multiForwardOut) {
Map<String, Object> params = buildParams(batchSize, batch, bias, weights);
params.put("embedding", embeddings);
params.put("embedding_dim", embeddingDim);
params.put("mats", mats);
params.put("mats_sizes", matSizes);
params.put("fields", fields);
params.put("multi_forward_out", multiForwardOut);
return Torch.forward(ptr, params, false);
}

public float backward(int batchSize, CooLongFloatMatrix batch, float[] bias, float[] weights, float[] targets) {
Map<String, Object> params = buildParams(batchSize, batch, bias, weights);
params.put("targets", targets);
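
For orientation, a hedged sketch of what the extra multiForwardOut argument changes on the caller side (Scala; the values below are made up, and the layout assumes the TorchScript module returns a row-major [batchSize, multiForwardOut] tensor, e.g. an MMoE head with one score per task): the JNI layer now copies batchSize * multiForwardOut floats back instead of batchSize.

// Assumed sketch of reading back the flat array that
// forward(batchSize, batch, bias, weights, embeddings, embeddingDim, mats, matSizes, numTasks)
// would return; all values here are hypothetical.
val numTasks = 3
val batchSize = 2
val scores = Array(0.9f, 0.1f, 0.4f,    // sample 0: one score per task
                   0.2f, 0.8f, 0.7f)    // sample 1
require(scores.length == batchSize * numTasks)
def score(i: Int, t: Int): Float = scores(i * numTasks + t)   // score of sample i for task t
val perSample = scores.grouped(numTasks).toArray              // one Array of task scores per sample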
@@ -0,0 +1,48 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD

import scala.language.implicitConversions

// evaluation metrics for multi-label tasks
private[pytorch]
abstract class EvaluationM extends Serializable {

def calculate(pairs: RDD[(Double, Double)]): String
}

private[pytorch]
object EvaluationM {

def eval(metrics: Array[String], pairs: RDD[(Double, Double)], numLabels: Int = 1): Map[String, String] = {
metrics.map(name => (name.toLowerCase(), EvaluationM.apply(name, numLabels).calculate(pairs))).toMap
}

def apply(name: String, numLabels: Int = 1): EvaluationM = {
name.toLowerCase match {
case "multi_auc" => new MultiLabelAUC(numLabels)
case "multi_auc_collect" => new MultiLabelAUCCollect(numLabels)
}
}

implicit def pairNumericRDDToPairDoubleRDD[T](rdd: RDD[(T, T)])(implicit num: Numeric[T])
: RDD[(Double, Double)] = {
rdd.map(x => (num.toDouble(x._1), num.toDouble(x._2)))
}
}
@@ -0,0 +1,59 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class MultiLabelAUC(numLabels: Int) extends EvaluationM {

def calculate_(pairs: RDD[(Double, Double)]): Double = {
// sort by predict
val sorted = pairs.sortBy(f => f._2)
sorted.cache()

val numTotal = sorted.count()
val numPositive = sorted.filter(f => f._1 > 0).count()
val numNegative = numTotal - numPositive

// calculate the summation of ranks for positive samples
val sumRanks_ = sorted.zipWithIndex().filter(f => f._1._1.toInt == 1).persist(StorageLevel.MEMORY_ONLY)
val sumRanks = sumRanks_.map(f => f._2 + 1).reduce(_ + _)
val auc = sumRanks * 1.0 / numPositive / numNegative - (numPositive + 1.0) / 2.0 / numNegative

sorted.unpersist()
sumRanks_.unpersist()
auc
}

override
def calculate(pairs: RDD[(Double, Double)]): String = {
pairs.persist(StorageLevel.MEMORY_ONLY)
val data = pairs.mapPartitions { part =>
val p = part.toArray
p.sliding(numLabels, numLabels).map(_.toArray)
}.persist(StorageLevel.MEMORY_ONLY)
val re = new Array[Double](numLabels)
var i = 0
while (i < numLabels) {
re(i) = calculate_(data.map(_(i)))
i += 1
}
re.mkString(",")
}

}
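
Two notes on the class above, plus a toy check of its rank-sum formula. calculate assumes the incoming RDD lays out numLabels consecutive (label, score) pairs per sample, which sliding(numLabels, numLabels) regroups into per-sample arrays, so each partition has to hold whole samples in order. For one label the AUC it computes is S/(P*N) - (P+1)/(2N), where S is the sum of 1-based ranks of the positives after sorting by score, P the positive count and N the negative count. The snippet below is plain Scala on made-up data, not part of the patch.

// Toy check of the rank-sum AUC used in calculate_ (hypothetical data, no Spark needed).
val pairs = Array((1.0, 0.9), (0.0, 0.3), (1.0, 0.8), (0.0, 0.1))   // (label, score)
val sorted = pairs.sortBy(_._2)                                      // ascending by score
val p = sorted.count(_._1 > 0)                                       // positives: 2
val n = sorted.length - p                                            // negatives: 2
val s = sorted.zipWithIndex.collect { case ((label, _), r) if label > 0 => r + 1 }.sum   // 3 + 4 = 7
val auc = s.toDouble / p / n - (p + 1.0) / 2.0 / n                   // 7/4 - 3/4 = 1.0, a perfect ranking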
@@ -0,0 +1,58 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.eval

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * Collect prediction results to the driver to compute multi-label AUC.
 * Suitable when the number of train/predict samples is small enough to collect, e.g. fewer than 10,000,000.
 */
class MultiLabelAUCCollect(numLabels: Int) extends EvaluationM {

def calculate_(pairs: Array[(Double, Double)]): Double = {
// sort by predict
val sorted = pairs.sortBy(f => f._2)

val numTotal = sorted.length
val numPositive = sorted.count(f => f._1 > 0)
val numNegative = numTotal - numPositive

// calculate the summation of ranks for positive samples
val sumRanks_ = sorted.zipWithIndex.filter(f => f._1._1.toInt == 1)
val sumRanks = sumRanks_.map(f => f._2.toLong + 1).sum
val auc = sumRanks * 1.0 / numPositive / numNegative - (numPositive + 1.0) / 2.0 / numNegative
auc
}

override
def calculate(pairs: RDD[(Double, Double)]): String = {
pairs.persist(StorageLevel.MEMORY_ONLY)
val data = pairs.mapPartitions { part =>
val p = part.toArray
p.sliding(numLabels, numLabels).map(_.toArray)
}.collect()
val re = new Array[Double](numLabels)
var i = 0
while (i < numLabels) {
re(i) = calculate_(data.map(_(i)))
i += 1
}
re.mkString(",")
}
}
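
A hedged end-to-end sketch of driving the new metrics through EvaluationM.eval (the demo object, toy data, and SparkContext handling are assumptions, not part of the PR; the caller sits under com.tencent.angel.pytorch because EvaluationM is private[pytorch]).

package com.tencent.angel.pytorch.eval   // hypothetical demo, placed here because EvaluationM is private[pytorch]

import org.apache.spark.SparkContext

object MultiAucDemo {
  def run(sc: SparkContext): Unit = {
    // Flattened (label, score) pairs: numLabels consecutive entries per sample (toy data).
    val pairs = sc.parallelize(Seq(
      (1.0, 0.9), (0.0, 0.2), (0.0, 0.1),    // sample 1, labels 0..2
      (0.0, 0.4), (1.0, 0.7), (1.0, 0.6)),   // sample 2, labels 0..2
      numSlices = 1)                         // keep whole samples in one partition
    val result = EvaluationM.eval(Array("multi_auc_collect"), pairs, numLabels = 3)
    println(result)                          // e.g. Map(multi_auc_collect -> 1.0,1.0,1.0)
  }
}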
@@ -49,6 +49,7 @@ object RecommendationExample {
val torchOutputModelPath = params.getOrElse("torchOutputModelPath", "")
val rowType = params.getOrElse("rowType", "T_FLOAT_DENSE")
val evals = params.getOrElse("evals", "auc")
val numLabels = params.getOrElse("numLabels", "1").toInt
val level = params.getOrElse("storageLevel", "memory_only").toUpperCase()

val recommendation = new Recommendation(torchModelPath)
@@ -60,6 +61,7 @@
recommendation.setDecay(decay)
recommendation.setAsync(async)
recommendation.setEvaluations(evals)
recommendation.setNumLabels(numLabels)
recommendation.setStorageLevel(StorageLevel.fromString(level))

var numPartitions = start(mode)
@@ -0,0 +1,30 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.pytorch.params

import org.apache.spark.ml.param.{IntParam, Params}

trait HasNumLabels extends Params {

final val numLabels = new IntParam(this, "numLabels", "numLabels")

final def getNumLabels: Int = $(numLabels)

setDefault(numLabels, 1)

final def setNumLabels(value: Int): this.type = set(numLabels, value)
}
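
A minimal sketch of mixing this trait into a Spark ML Params holder (the class below is hypothetical; in this PR the trait is consumed through Recommendation, whose setNumLabels call appears in the RecommendationExample change above).

package com.tencent.angel.pytorch.params   // hypothetical demo class, same package as the trait

import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

class DemoTask(override val uid: String) extends Params with HasNumLabels {
  def this() = this(Identifiable.randomUID("demoTask"))
  override def copy(extra: ParamMap): Params = defaultCopy(extra)
}

// new DemoTask().setNumLabels(3).getNumLabels == 3; the default is 1 when never set.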