generated from JacksonBurns/blank-python-project
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Adding TargetValue and MolecularWeight Sampler (#172)
This PR supersedes #165 and #170 and resolves #168. #165 implemented a sampler which would divide molecules based on their molecular weight (either ascending or descending), and #170 added basically the same thing but for arbitrary target values. This PR just refactors #170 as if #165 where already implemented, reducing code duplication quite a bit. There is also some small internal cleanup, namely 00eeca6, 98b3efd, and 3b8bca1 (as well as some long-overdue CI updates).
Showing
24 changed files
with
721 additions
and
346 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
name: Continuous Integration | ||
on: | ||
schedule: | ||
- cron: "0 8 * * 1-5" | ||
push: | ||
branches: [main] | ||
pull_request: | ||
branches: [main] | ||
workflow_dispatch: | ||
|
||
concurrency: | ||
group: actions-id-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
check-formatting: | ||
name: Check Build and Formatting Errors | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Install Dependencies | ||
run: | | ||
python -m pip install pycodestyle isort | ||
- name: Check Build | ||
run: | | ||
python -m pip install . | ||
- name: Run pycodestyle | ||
run: | | ||
pycodestyle --statistics --count --max-line-length=150 --show-source --ignore=E203 . | ||
- name: Check Import Ordering Errors | ||
run: | | ||
isort --check-only --verbose . | ||
build-and-test: | ||
needs: check-formatting | ||
continue-on-error: true | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] | ||
os: [ubuntu-latest, windows-latest, macos-latest] | ||
|
||
runs-on: ${{ matrix.os }} | ||
defaults: | ||
run: | ||
shell: bash -el {0} | ||
name: ${{ matrix.os }} Python ${{ matrix.python-version }} Subtest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: mamba-org/setup-micromamba@main | ||
with: | ||
environment-name: temp | ||
condarc: | | ||
channels: | ||
- defaults | ||
- conda-forge | ||
channel_priority: flexible | ||
create-args: | | ||
python=${{ matrix.python-version }} | ||
- name: Install Dependencies | ||
run: | | ||
python -m pip install -e .[molecules] | ||
python -m pip install coverage pytest | ||
- name: Run Tests | ||
run: | | ||
coverage run --source=. --omit=astartes/__init__.py,setup.py,test/* -m pytest -v | ||
- name: Show Coverage | ||
run: | | ||
coverage report -m | ||
ipynb-ci: | ||
needs: check-formatting | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
nb-file: | ||
["barrier_prediction_with_RDB7/RDB7_barrier_prediction_example", "train_val_test_split_sklearn_example/train_val_test_split_example", "split_comparisons/split_comparisons", "mlpds_2023_astartes_demonstration/mlpds_2023_demo"] | ||
runs-on: ubuntu-latest | ||
defaults: | ||
run: | ||
shell: bash -el {0} | ||
name: Check ${{ matrix.nb-file }} Notebook Execution | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: mamba-org/setup-micromamba@main | ||
with: | ||
environment-name: temp | ||
condarc: | | ||
channels: | ||
- defaults | ||
- conda-forge | ||
channel_priority: flexible | ||
create-args: | | ||
python=3.11 | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install -e .[molecules,demos] | ||
python -m pip install notebook | ||
- name: Test Execution | ||
run: | | ||
cd examples/$(dirname ${{ matrix.nb-file }}) | ||
jupyter nbconvert --to script $(basename ${{ matrix.nb-file }}).ipynb | ||
ipython $(basename ${{ matrix.nb-file }}).py | ||
coverage-check: | ||
if: contains(github.event.pull_request.labels.*.name, 'PR Ready for Review') | ||
needs: [build-and-test, ipynb-ci] | ||
runs-on: ubuntu-latest | ||
defaults: | ||
run: | ||
shell: bash -el {0} | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: conda-incubator/setup-miniconda@v2 | ||
with: | ||
auto-update-conda: true | ||
python-version: "3.10" | ||
- name: Install Dependencies | ||
run: | | ||
python -m pip install -e .[molecules] | ||
python -m pip install coverage | ||
- name: Run Tests | ||
run: | | ||
coverage run --source=. --omit=astartes/__init__.py,setup.py,test/*,astartes/samplers/sampler.py -m unittest discover -v | ||
- name: Show Coverage | ||
run: | | ||
coverage report -m > temp.txt | ||
cat temp.txt | ||
python .github/workflows/coverage_helper.py | ||
echo "COVERAGE_PERCENT=$(cat temp2.txt)" >> $GITHUB_ENV | ||
- name: Request Changes via Review | ||
if: ${{ env.COVERAGE_PERCENT < 90 }} | ||
uses: andrewmusgrave/[email protected] | ||
with: | ||
repo-token: ${{ secrets.GITHUB_TOKEN }} | ||
event: REQUEST_CHANGES | ||
body: "Increase test coverage from ${{ env.COVERAGE_PERCENT }}% to at least 90% before merging." | ||
|
||
- name: Approve PR if Coverage Sufficient | ||
if: ${{ env.COVERAGE_PERCENT > 89 }} | ||
uses: andrewmusgrave/[email protected] | ||
with: | ||
repo-token: ${{ secrets.GITHUB_TOKEN }} | ||
event: APPROVE | ||
body: "Test coverage meets or exceeds 90% threshold (currently ${{ env.COVERAGE_PERCENT }}%)." | ||
|
||
ci-report-status: | ||
name: report CI status | ||
needs: [build-and-test, ipynb-ci] | ||
runs-on: ubuntu-latest | ||
steps: | ||
- run: | | ||
result_1="${{ needs.build-and-test.result }}" | ||
result_2="${{ needs.ipynb-ci.result }}" | ||
if test $result_1 == "success" && test $result_2 == "success"; then | ||
exit 0 | ||
else | ||
exit 1 | ||
fi | ||
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
from .dbscan import DBSCAN | ||
from .kmeans import KMeans | ||
from .molecular_weight import MolecularWeight | ||
from .optisim import OptiSim | ||
from .scaffold import Scaffold | ||
from .sphere_exclusion import SphereExclusion | ||
from .target_property import TargetProperty | ||
from .time_based import TimeBased |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
""" | ||
This sampler partitions the data based on molecular weight. It first sorts the | ||
molecules by molecular weight and then places the smallest molecules in the training set, | ||
the next smallest in the validation set if applicable, and finally the largest molecules | ||
in the testing set. | ||
""" | ||
|
||
import numpy as np | ||
|
||
try: | ||
from astartes.utils.aimsim_featurizer import featurize_molecules | ||
except ImportError: | ||
# this is in place so that the import of this from parent directory will work | ||
# if it fails, it is caught in molecules instead and the error is more helpful | ||
NO_MOLECULES = True | ||
|
||
from .scaffold import Scaffold | ||
from .target_property import TargetProperty | ||
|
||
|
||
# inherit sample method from TargetProperty | ||
class MolecularWeight(TargetProperty): | ||
def _before_sample(self): | ||
# check for invalid data types using the method in the Scaffold sampler | ||
Scaffold._validate_input(self.X) | ||
# calculate the average molecular weight of the molecule | ||
self.y_backup = self.y | ||
self.y = featurize_molecules((Scaffold.str_to_mol(i) for i in self.X), "mordred:MW", fprints_hopts={}) | ||
|
||
def _after_sample(self): | ||
# restore the original y values | ||
self.y = self.y_backup |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
This sampler partitions the data based on the regression target y. It first sorts the | ||
data by y value and then constructs the training set to have either the smallest (largest) | ||
y values, the validation set to have the next smallest (largest) set of y values, and the | ||
testing set to have the largest (smallest) y values. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from astartes.samplers import AbstractSampler | ||
|
||
|
||
class TargetProperty(AbstractSampler): | ||
def _sample(self): | ||
""" | ||
Implements the target property sampler to create an extrapolation split. | ||
""" | ||
data = [(y, idx) for y, idx in zip(self.y, np.arange(len(self.y)))] | ||
|
||
# by default, the smallest property values are placed in the training set | ||
sorted_list = sorted(data, reverse=self.get_config("descending", False)) | ||
|
||
self._samples_idxs = np.array([idx for time, idx in sorted_list], dtype=int) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import warnings | ||
|
||
import numpy as np | ||
|
||
from astartes.utils.exceptions import MoleculesNotInstalledError | ||
|
||
try: | ||
""" | ||
aimsim depends on sklearn_extra, which uses a version checking technique that is due to | ||
be deprecated in a version of Python after 3.11, so it is throwing a deprecation warning | ||
We ignore this warning since we can't do anything about it (sklearn_extra seems to be | ||
abandonware) and in the future it will become an error that we can deal with. | ||
""" | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore", category=DeprecationWarning) | ||
from aimsim.chemical_datastructures import Molecule | ||
from aimsim.exceptions import LoadingError | ||
except ImportError: # pragma: no cover | ||
raise MoleculesNotInstalledError("""To use molecule featurizer, install astartes with pip install astartes[molecules].""") | ||
|
||
|
||
def featurize_molecules(molecules, fingerprint, fprints_hopts): | ||
"""Call AIMSim's Molecule to featurize the molecules according to the arguments. | ||
Args: | ||
molecules (np.array): SMILES strings or RDKit molecule objects. | ||
fingerprint (str): The molecular fingerprint to be used. | ||
fprints_hopts (dict): Hyperparameters for AIMSim. | ||
Returns: | ||
np.array: X array (featurized molecules) | ||
""" | ||
X = [] | ||
for molecule in molecules: | ||
try: | ||
if type(molecule) in (np.str_, str): | ||
mol = Molecule(mol_smiles=molecule) | ||
else: | ||
mol = Molecule(mol_graph=molecule) | ||
except LoadingError as le: | ||
raise RuntimeError( | ||
"Unable to featurize molecules using '{:s}' with this configuration: fprint_hopts={:s}" | ||
"\nCheck terminal output for messages from the RDkit logger. ".format(fingerprint, repr(fprints_hopts)) | ||
) from le | ||
mol.descriptor.make_fingerprint( | ||
mol.mol_graph, | ||
fingerprint, | ||
fingerprint_params=fprints_hopts, | ||
) | ||
X.append(mol.descriptor.to_numpy()) | ||
return np.array(X) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
165 changes: 165 additions & 0 deletions
165
test/unit/samplers/extrapolative/test_molecular_weight.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from astartes import train_test_split | ||
from astartes.samplers import MolecularWeight | ||
|
||
|
||
class Test_MolecularWeight(unittest.TestCase): | ||
""" | ||
Test the various functionalities of MolecularWeight. | ||
""" | ||
|
||
@classmethod | ||
def setUpClass(self): | ||
"""Convenience attributes for later tests.""" | ||
self.X = np.array( | ||
[ | ||
"C", | ||
"CC", | ||
"CCC", | ||
"CCCC", | ||
"CCCCC", | ||
"CCCCCC", | ||
"CCCCCCC", | ||
"CCCCCCCC", | ||
"CCCCCCCCC", | ||
"CCCCCCCCCC", | ||
] | ||
) | ||
self.X_inchi = np.array( | ||
[ | ||
"InChI=1S/CH4/h1H4", | ||
"InChI=1S/C2H6/c1-2/h1-2H3", | ||
"InChI=1S/C3H8/c1-3-2/h3H2,1-2H3", | ||
"InChI=1S/C4H10/c1-3-4-2/h3-4H2,1-2H3", | ||
"InChI=1S/C5H12/c1-3-5-4-2/h3-5H2,1-2H3", | ||
"InChI=1S/C6H14/c1-3-5-6-4-2/h3-6H2,1-2H3", | ||
"InChI=1S/C7H16/c1-3-5-7-6-4-2/h3-7H2,1-2H3", | ||
"InChI=1S/C8H18/c1-3-5-7-8-6-4-2/h3-8H2,1-2H3", | ||
"InChI=1S/C9H20/c1-3-5-7-9-8-6-4-2/h3-9H2,1-2H3", | ||
"InChI=1S/C10H22/c1-3-5-7-9-10-8-6-4-2/h3-10H2,1-2H3", | ||
] | ||
) | ||
self.y = np.arange(len(self.X)) | ||
self.labels = np.array( | ||
[ | ||
"methane", | ||
"ethane", | ||
"propane", | ||
"butane", | ||
"pentane", | ||
"hexane", | ||
"heptane", | ||
"octane", | ||
"nonane", | ||
"decane", | ||
] | ||
) | ||
|
||
def test_molecular_weight_sampling(self): | ||
"""Use MolecularWeight in the train_test_split and verify results.""" | ||
( | ||
X_train, | ||
X_test, | ||
y_train, | ||
y_test, | ||
labels_train, | ||
labels_test, | ||
) = train_test_split( | ||
self.X, | ||
self.y, | ||
labels=self.labels, | ||
test_size=0.2, | ||
train_size=0.8, | ||
sampler="molecular_weight", | ||
hopts={}, | ||
) | ||
|
||
# test that the known arrays equal the result from above | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_train, | ||
self.X[:8], # X was already sorted by ascending molecular weight | ||
), | ||
"Train X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_test, | ||
self.X[8:], # X was already sorted by ascending molecular weight | ||
), | ||
"Test X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_train, | ||
self.y[:8], # y was already sorted by ascending molecular weight | ||
), | ||
"Train y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_test, | ||
self.y[8:], # y was already sorted by ascending molecular weight | ||
), | ||
"Test y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_train, | ||
self.labels[:8], # labels was already sorted by ascending molecular weight | ||
), | ||
"Train labels incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_test, | ||
self.labels[8:], # labels was already sorted by ascending molecular weight | ||
), | ||
"Test labels incorrect.", | ||
) | ||
|
||
def test_molecular_weight(self): | ||
"""Directly instantiate and test MolecularWeight.""" | ||
molecular_weight_instance = MolecularWeight( | ||
self.X, | ||
self.y, | ||
self.labels, | ||
{}, | ||
) | ||
self.assertIsInstance( | ||
molecular_weight_instance, | ||
MolecularWeight, | ||
"Failed instantiation.", | ||
) | ||
self.assertFalse( | ||
len(molecular_weight_instance.get_clusters()), | ||
"Clusters was set when it should not have been.", | ||
) | ||
self.assertTrue( | ||
len(molecular_weight_instance._samples_idxs), | ||
"Sample indices not set.", | ||
) | ||
|
||
def test_incorrect_input(self): | ||
"""Calling with something other than SMILES, InChI, or RDKit Molecule should raise TypeError""" | ||
with self.assertRaises(TypeError): | ||
train_test_split( | ||
np.array([[1], [2]]), | ||
sampler="molecular_weight", | ||
) | ||
|
||
def test_mol_from_inchi(self): | ||
"""Ability to load data from InChi inputs""" | ||
MolecularWeight( | ||
self.X_inchi, | ||
None, | ||
None, | ||
{}, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
194 changes: 194 additions & 0 deletions
194
test/unit/samplers/extrapolative/test_target_property.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from astartes import train_test_split | ||
from astartes.samplers import TargetProperty | ||
|
||
|
||
class Test_TargetProperty(unittest.TestCase): | ||
""" | ||
Test the various functionalities of TargetProperty. | ||
""" | ||
|
||
@classmethod | ||
def setUpClass(self): | ||
"""Convenience attributes for later tests.""" | ||
self.X = np.array( | ||
[ | ||
"C", | ||
"CC", | ||
"CCC", | ||
"CCCC", | ||
"CCCCC", | ||
"CCCCCC", | ||
"CCCCCCC", | ||
"CCCCCCCC", | ||
"CCCCCCCCC", | ||
"CCCCCCCCCC", | ||
] | ||
) | ||
|
||
self.y = np.arange(len(self.X)) | ||
self.labels = np.array( | ||
[ | ||
"methane", | ||
"ethane", | ||
"propane", | ||
"butane", | ||
"pentane", | ||
"hexane", | ||
"heptane", | ||
"octane", | ||
"nonane", | ||
"decane", | ||
] | ||
) | ||
|
||
def test_target_property_sampling_ascending(self): | ||
"""Use TargetProperty in the train_test_split and verify results.""" | ||
( | ||
X_train, | ||
X_test, | ||
y_train, | ||
y_test, | ||
labels_train, | ||
labels_test, | ||
) = train_test_split( | ||
self.X, | ||
self.y, | ||
labels=self.labels, | ||
test_size=0.2, | ||
train_size=0.8, | ||
sampler="target_property", | ||
hopts={"descending": False}, | ||
) | ||
|
||
# test that the known arrays equal the result from above | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_train, | ||
self.X[:8], # X was already sorted by ascending target value | ||
), | ||
"Train X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_test, | ||
self.X[8:], # X was already sorted by ascending target value | ||
), | ||
"Test X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_train, | ||
self.y[:8], # y was already sorted by ascending target value | ||
), | ||
"Train y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_test, | ||
self.y[8:], # y was already sorted by ascending target value | ||
), | ||
"Test y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_train, | ||
self.labels[:8], # labels was already sorted by ascending target value | ||
), | ||
"Train labels incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_test, | ||
self.labels[8:], # labels was already sorted by ascending target value | ||
), | ||
"Test labels incorrect.", | ||
) | ||
|
||
def test_target_property_sampling_descending(self): | ||
"""Use TargetProperty in the train_test_split and verify results.""" | ||
( | ||
X_train, | ||
X_test, | ||
y_train, | ||
y_test, | ||
labels_train, | ||
labels_test, | ||
) = train_test_split( | ||
self.X, | ||
self.y, | ||
labels=self.labels, | ||
test_size=0.2, | ||
train_size=0.8, | ||
sampler="target_property", | ||
hopts={"descending": True}, | ||
) | ||
|
||
# test that the known arrays equal the result from above | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_train, | ||
np.flip(self.X)[:8], | ||
), | ||
"Train X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
X_test, | ||
np.flip(self.X)[8:], | ||
), | ||
"Test X incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_train, | ||
np.flip(self.y)[:8], | ||
), | ||
"Train y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
y_test, | ||
np.flip(self.y)[8:], | ||
), | ||
"Test y incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_train, | ||
np.flip(self.labels)[:8], | ||
), | ||
"Train labels incorrect.", | ||
) | ||
self.assertIsNone( | ||
np.testing.assert_array_equal( | ||
labels_test, | ||
np.flip(self.labels)[8:], | ||
), | ||
"Test labels incorrect.", | ||
) | ||
|
||
def test_target_property(self): | ||
"""Directly instantiate and test TargetProperty.""" | ||
target_property_instance = TargetProperty( | ||
self.X, | ||
self.y, | ||
self.labels, | ||
{}, | ||
) | ||
self.assertIsInstance( | ||
target_property_instance, | ||
TargetProperty, | ||
"Failed instantiation.", | ||
) | ||
self.assertFalse( | ||
len(target_property_instance.get_clusters()), | ||
"Clusters was set when it should not have been.", | ||
) | ||
self.assertTrue( | ||
len(target_property_instance._samples_idxs), | ||
"Sample indices not set.", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters