Skip to content

Commit

Permalink
fix: multidimensional sort
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Dec 13, 2023
1 parent a3da652 commit 4f9e1c7
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pdf_agents/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,16 @@ def _sample_uncertainty_proxy(self, batch_size=1):
"""
# Borrowing from Dan's jupyter fun
# from measurements, perform k-means
sorted_independents, sorted_observables = zip(*sorted(zip(self.independent_cache, self.observable_cache)))
try:
sorted_independents, sorted_observables = zip(
*sorted(zip(self.independent_cache, self.observable_cache))
)
except ValueError:
# Multidimensional case
sorted_independents, sorted_observables = zip(
*sorted(zip(self.independent_cache, self.observable_cache), key=lambda x: (x[0][0], x[0][1]))
)

sorted_independents = np.array(sorted_independents)
sorted_observables = np.array(sorted_observables)
self.model.fit(sorted_observables)
Expand Down

0 comments on commit 4f9e1c7

Please sign in to comment.