Skip to content

Commit

Permalink
Properly handle arguments for PMTObserver
Browse files Browse the repository at this point in the history
  • Loading branch information
csbnw committed Jun 5, 2024
1 parent e21681a commit db9fc45
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/cuda/vector_add_observers_pmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def tune():
tune_params = dict()
tune_params["block_size_x"] = [128+64*i for i in range(15)]

pmtobserver = PMTObserver(["nvidia", "rapl"])
pmtobserver = PMTObserver([("nvidia", 0), "rapl"])

metrics = OrderedDict()
metrics["GPU W"] = lambda p: p["nvidia_power"]
Expand Down
4 changes: 2 additions & 2 deletions kernel_tuner/observers/pmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(self, observable=None):
if type(observable) is dict:
pass
elif type(observable) is list:
# user specifies a list of platforms as observable
observable = dict([(obs, 0) for obs in observable])
# user specifies a list of platforms as observable, optionally with an argument
observable = dict([obs if isinstance(obs, tuple) else (obs, None) for obs in observable])
else:
# User specifices a string (single platform) as observable
observable = {observable: None}
Expand Down

0 comments on commit db9fc45

Please sign in to comment.