Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize space usage of ExplorationReport before saving #279

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpgen2/exploration/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def converged(
"""
pass

@abstractmethod
def no_candidate(self) -> bool:
r"""If no candidate configuration is found"""
return all([len(ii) == 0 for ii in self.get_candidate_ids()])
pass

@abstractmethod
def get_candidate_ids(
Expand Down
21 changes: 18 additions & 3 deletions dpgen2/exploration/report/report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def doc() -> str:
Expand Down Expand Up @@ -274,6 +278,10 @@ def record(
# accurate set is substracted by the candidate set
self.accur = self.accur - self.candi
self.model_devi = model_devi
self._no_candidate = len(self.candi) == 0
self._failed_ratio = float(len(self.failed)) / float(self.nframes)
self._accurate_ratio = float(len(self.accur)) / float(self.nframes)
self._candidate_ratio = float(len(self.candi)) / float(self.nframes)

def _record_one_traj(
self,
Expand Down Expand Up @@ -346,29 +354,36 @@ def failed_ratio(
self,
tag=None,
):
return float(len(self.failed)) / float(self.nframes)
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
return float(len(self.accur)) / float(self.nframes)
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
return float(len(self.candi)) / float(self.nframes)
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = self.ntraj
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
26 changes: 20 additions & 6 deletions dpgen2/exploration/report/report_trust_levels_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def args() -> List[Argument]:
Expand Down Expand Up @@ -133,6 +137,16 @@ def record(
assert len(self.traj_accu) == ntraj
assert len(self.traj_fail) == ntraj
self.model_devi = model_devi
self._no_candidate = sum([len(ii) for ii in self.traj_cand]) == 0
self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float(
sum(self.traj_nframes)
)
self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float(
sum(self.traj_nframes)
)
self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float(
sum(self.traj_nframes)
)

def _get_indexes(
self,
Expand Down Expand Up @@ -205,22 +219,22 @@ def failed_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_fail]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_accu]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_cand]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

@abstractmethod
def get_candidate_ids(
Expand Down
4 changes: 4 additions & 0 deletions dpgen2/exploration/report/report_trust_levels_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,21 @@
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy

Check failure on line 44 in dpgen2/exploration/report/report_trust_levels_max.py

View workflow job for this annotation

GitHub Actions / pyright

Operator ">=" not supported for "None" (reportOptionalOperand)

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/report/report_trust_levels_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,23 @@ def converged(
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy
accurate_ratio = self.accurate_ratio()
assert isinstance(accurate_ratio, float)
return accurate_ratio >= self.conv_accuracy

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 4 additions & 4 deletions tests/exploration/test_report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class MockedReport:
self.assertFalse(ter.converged([mr, mr1, mr]))
self.assertTrue(ter.converged([mr1, mr, mr]))

picked = ter.get_candidate_ids(2)
picked = ter.get_candidate_ids(2, clear=False)
npicked = 0
self.assertEqual(len(picked), 2)
for ii in range(2):
Expand Down Expand Up @@ -218,12 +218,12 @@ def faked_choices(
return ret

ter.record(model_devi)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(ter.candi, expected_cand)
self.assertEqual(ter.accur, expected_accu)
self.assertEqual(set(ter.failed), expected_fail)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(len(picked), 2)
self.assertEqual(sorted(picked[0]), [1, 3])
self.assertEqual(sorted(picked[1]), [1, 5, 7])
Expand Down
Loading