Skip to content

Commit

Permalink
feat: Running batch to grow
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Dec 14, 2023
1 parent 974a022 commit 4cd3369
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pdf_agents/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def _sample_uncertainty_proxy(self, batch_size=1):
# Assume a 1d scan
# generate 'uncertainty weights' - as a polynomial fit of the golf-score for each point
_x = np.arange(*self.bounds, self.motor_resolution)
if batch_size is None:
batch_size = len(_x)
uwx = polyval(_x, polyfit(sorted_independents, min_landscape, deg=5))
# Chose from the polynomial fit
return pick_from_distribution(_x, uwx, num_picks=batch_size), centers
Expand All @@ -214,12 +216,12 @@ def _sample_uncertainty_proxy(self, batch_size=1):
labels = self.model.predict(sorted_observables)
proby_preds = LogisticRegression().fit(sorted_independents, labels).predict_proba(grid)
shannon = -np.sum(proby_preds * np.log(1 / proby_preds), axis=-1)
top_indicies = np.argsort(shannon)[-batch_size:]
top_indicies = np.argsort(shannon) if batch_size is None else np.argsort(shannon)[-batch_size:]
return grid[top_indicies], centers

def ask(self, batch_size=1):
"""Get's a relative position from the agent. Returns a document and hashes the suggestion for redundancy"""
suggestions, centers = self._sample_uncertainty_proxy(batch_size)
suggestions, centers = self._sample_uncertainty_proxy(None)
kept_suggestions = []
if not isinstance(suggestions, Iterable):
suggestions = [suggestions]
Expand All @@ -234,6 +236,8 @@ def ask(self, batch_size=1):
else:
self.knowledge_cache.add(hashable_suggestion)
kept_suggestions.append(suggestion)
if len(kept_suggestions) >= batch_size:
break

base_doc = dict(
cluster_centers=centers,
Expand Down

0 comments on commit 4cd3369

Please sign in to comment.