Skip to content

Commit

Permalink
Add a metric for TA-DP-FTRL,
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590444460
  • Loading branch information
nightldj authored and tensorflower-gardener committed Dec 13, 2023
1 parent fbe5879 commit 7cbb281
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""

import collections
from typing import Any, NamedTuple

import dp_accounting
import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation

Expand Down Expand Up @@ -476,6 +475,10 @@ def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
previous_tree_noise=global_state.previous_tree_noise,
)

def derive_metrics(self, global_state):
"""Returns the clip norm as a metric."""
return collections.OrderedDict(tree_agg_dpftrl_clip=global_state.clip_value)

@classmethod
def build_l2_gaussian_query(cls,
clip_norm,
Expand Down
12 changes: 12 additions & 0 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,18 @@ def test_sum_tree_aggregator_instance(self, use_efficient, tree_class):
)
self.assertIsInstance(query._tree_aggregator, tree_class)

def test_derive_metrics(self):
specs = tf.TensorSpec([])
l2_clip = 2
query = tree_aggregation_query.TreeResidualSumQuery(
clip_fn=_get_l2_clip_fn(),
clip_value=l2_clip,
noise_generator=_get_noise_fn(specs, 1.0),
record_specs=specs,
)
metrics = query.derive_metrics(query.initial_global_state())
self.assertEqual(metrics['tree_agg_dpftrl_clip'], l2_clip)

@parameterized.named_parameters(
('s0t1f1', 0., 1., 1),
('s0t1f2', 0., 1., 2),
Expand Down

0 comments on commit 7cbb281

Please sign in to comment.