@@ -93,9 +93,18 @@ class Attack(abc.ABC, metaclass=InputFilter):
93
93
attack_params : List [str ] = list ()
94
94
_estimator_requirements : Optional [Union [Tuple [Any , ...], Tuple [()]]] = None
95
95
96
- def __init__ (self , estimator ):
96
+ def __init__ (
97
+ self ,
98
+ estimator ,
99
+ tensor_board : Union [str , bool ] = False ,
100
+ ):
97
101
"""
98
102
:param estimator: An estimator.
103
+ :param tensor_board: Activate summary writer for TensorBoard: Default is `False` and deactivated summary writer.
104
+ If `True` save runs/CURRENT_DATETIME_HOSTNAME in current directory. Provide `path` in type
105
+ `str` to save in path/CURRENT_DATETIME_HOSTNAME.
106
+ Use hierarchical folder structure to compare between runs easily. e.g. pass in ‘runs/exp1’,
107
+ ‘runs/exp2’, etc. for each new experiment to compare across them.
99
108
"""
100
109
super ().__init__ ()
101
110
@@ -106,6 +115,19 @@ def __init__(self, estimator):
106
115
raise EstimatorError (self .__class__ , self .estimator_requirements , estimator )
107
116
108
117
self ._estimator = estimator
118
+ self .tensor_board = tensor_board
119
+
120
+ if tensor_board :
121
+ from tensorboardX import SummaryWriter
122
+
123
+ if isinstance (tensor_board , str ):
124
+ self .summary_writer = SummaryWriter (tensor_board )
125
+ else :
126
+ self .summary_writer = SummaryWriter ()
127
+ else :
128
+ self .summary_writer = None
129
+
130
+ Attack ._check_params (self )
109
131
110
132
@property
111
133
def estimator (self ):
@@ -129,7 +151,9 @@ def set_params(self, **kwargs) -> None:
129
151
self ._check_params ()
130
152
131
153
def _check_params (self ) -> None :
132
- pass
154
+
155
+ if not isinstance (self .tensor_board , (bool , str )):
156
+ raise ValueError ("The argument `tensor_board` has to be either of type bool or str." )
133
157
134
158
135
159
class EvasionAttack (Attack ):
@@ -305,12 +329,12 @@ def __init__(self, estimator):
305
329
@abc .abstractmethod
306
330
def infer (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
307
331
"""
308
- Infer sensitive properties ( attributes, membership training records) from the targeted estimator. This method
332
+ Infer sensitive attributes from the targeted estimator. This method
309
333
should be overridden by all concrete inference attack implementations.
310
334
311
335
:param x: An array with reference inputs to be used in the attack.
312
336
:param y: Labels for `x`. This parameter is only used by some of the attacks.
313
- :return: An array holding the inferred properties .
337
+ :return: An array holding the inferred attribute values .
314
338
"""
315
339
raise NotImplementedError
316
340
@@ -334,12 +358,41 @@ def __init__(self, estimator, attack_feature: Union[int, slice] = 0):
334
358
@abc .abstractmethod
335
359
def infer (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
336
360
"""
337
- Infer sensitive properties (attributes, membership training records) from the targeted estimator. This method
361
+ Infer sensitive attributes from the targeted estimator. This method
362
+ should be overridden by all concrete inference attack implementations.
363
+
364
+ :param x: An array with reference inputs to be used in the attack.
365
+ :param y: Labels for `x`. This parameter is only used by some of the attacks.
366
+ :return: An array holding the inferred attribute values.
367
+ """
368
+ raise NotImplementedError
369
+
370
+
371
+ class MembershipInferenceAttack (InferenceAttack ):
372
+ """
373
+ Abstract base class for membership inference attack classes.
374
+ """
375
+
376
+ def __init__ (self , estimator : Union ["CLASSIFIER_TYPE" ]):
377
+ """
378
+ :param estimator: A trained estimator targeted for inference attack.
379
+ :type estimator: :class:`.art.estimators.estimator.BaseEstimator`
380
+ :param attack_feature: The index of the feature to be attacked.
381
+ """
382
+ super ().__init__ (estimator )
383
+
384
+ @abc .abstractmethod
385
+ def infer (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
386
+ """
387
+ Infer membership status of samples from the target estimator. This method
338
388
should be overridden by all concrete inference attack implementations.
339
389
340
390
:param x: An array with reference inputs to be used in the attack.
341
391
:param y: Labels for `x`. This parameter is only used by some of the attacks.
342
- :return: An array holding the inferred properties.
392
+ :param probabilities: a boolean indicating whether to return the predicted probabilities per class, or just
393
+ the predicted class.
394
+ :return: An array holding the inferred membership status (1 indicates member of training set,
395
+ 0 indicates non-member) or class probabilities.
343
396
"""
344
397
raise NotImplementedError
345
398
0 commit comments