Skip to content

Commit

Permalink
gradient pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
Refefer committed Oct 19, 2022
1 parent e29f749 commit 82ef32c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
7 changes: 5 additions & 2 deletions benches/bench_algos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use rand_distr::{Distribution,Uniform};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use cloverleaf::graph::{Graph,CSR,CumCSR};
use cloverleaf::algos::lpa::lpa;
use cloverleaf::algos::ep::{EmbeddingPropagation,FeatureStore,Loss};
use cloverleaf::algos::ep::{EmbeddingPropagation,Loss};
use cloverleaf::algos::utils::FeatureStore;

const SEED: u64 = 2022341;

Expand Down Expand Up @@ -70,8 +71,10 @@ fn embedding_propagation(c: &mut Criterion) {
passes: 50,
seed: 202220222,
indicator: false,
max_nodes: Some(10),
max_features: None,
wd: 0f32,
loss: Loss::MarginLoss(10f32)
loss: Loss::MarginLoss(10f32, 1)
};

let label = format!("ep:{}", num_feats);
Expand Down
48 changes: 38 additions & 10 deletions src/algos/ep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,13 @@ fn extract_grads(
for (feat_id, (var, _)) in vars {
if grads.contains_key(&feat_id) { continue }

let grad = graph.get_grad(&var)
.expect("Should have a gradient!");

if grad.iter().all(|gi| !gi.is_nan()) {
// Can get some nans in weird cases, such as the distance between
// a node and its reconstruction when it shares all features.
// Since that's not all that helpful anyways, we simply ignore it and move on
grads.insert(feat_id, grad.to_vec());
if let Some(grad) = graph.get_grad(&var) {
if grad.iter().all(|gi| !gi.is_nan()) {
// Can get some nans in weird cases, such as the distance between
// a node and its reconstruction when it shares all features.
// Since that's not all that helpful anyways, we simply ignore it and move on
grads.insert(feat_id, grad.to_vec());
}
}
}
}
Expand Down Expand Up @@ -378,13 +377,42 @@ impl Loss {
fn compute(&self, thv: ANode, hv: ANode, hus: Vec<ANode>) -> ANode {
match self {

Loss::MarginLoss(gamma, _) | Loss::StarSpace(gamma, _) => {
Loss::MarginLoss(gamma, _) => {
let d1 = euclidean_distance(thv.clone(), hv);
let d2 = hus.into_iter()
.map(|hu| euclidean_distance(thv.clone(), hu))
.collect::<Vec<_>>().sum_all();

(gamma + d1 - d2).maximum(0f32)
let result = (gamma + d1 - d2).maximum(0f32);
if result.value()[0] > 0f32 {
result
} else {
Constant::scalar(0f32)
}
},

Loss::StarSpace(gamma, _) => {
let thv_norm = il2norm(thv);
let hv_norm = il2norm(hv);

let d1 = gamma - cosine(thv_norm.clone(), hv_norm);
let negs = hus.len();
let pos_losses = hus.into_iter()
.map(|hu| {
let hu_norm = il2norm(hu);
(d1.clone() + cosine(thv_norm.clone(), hu_norm)).maximum(0f32)
})
// Only collect losses which are not zero
.filter(|l| l.value()[0] > 0f32)
.collect::<Vec<_>>();

// Only return positive ones
if pos_losses.len() > 0 {
let n_losses = pos_losses.len() as f32;
pos_losses.sum_all() / n_losses
} else {
Constant::scalar(0f32)
}
},

Loss::Contrastive(tau, _) => {
Expand Down
1 change: 0 additions & 1 deletion src/algos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ pub mod slpa;
pub mod ep;
pub mod ann;
pub mod utils;
pub mod starspace;

0 comments on commit 82ef32c

Please sign in to comment.