Skip to content

Commit

Permalink
Check keys to be removed are not removed if --remove_keys_from_pickles=false
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaMolod committed Jan 17, 2025
1 parent d4fd6c6 commit d89b0f8
Showing 1 changed file with 55 additions and 39 deletions.
94 changes: 55 additions & 39 deletions test/test_post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
class TestPostPrediction(parameterized.TestCase):
    """Tests for post-prediction processing of AlphaFold result pickles.

    Each parameterized case copies a prediction directory into a scratch
    location and checks which .pkl/.gz files survive post-processing.
    """

    def setUp(self) -> None:
        """Resolve the test-data directory and configure logging."""
        super().setUp()
        # Repository root is two levels above this test file.
        parent_dir = join(dirname(dirname(abspath(__file__))))
        self.input_dir = join(parent_dir, "test/test_data/predictions")
        # INFO-level logging so per-case progress is visible in test output.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@parameterized.parameters(
Expand All @@ -40,71 +37,90 @@ def setUp(self) -> None:
def test_files(self, prediction_dir, compress_pickles, remove_pickles, remove_keys):
temp_dir = tempfile.TemporaryDirectory()
try:
logging.info(f"Running test for prediction_dir='{prediction_dir}', compress_pickles={compress_pickles}, remove_pickles={remove_pickles}, remove_keys={remove_keys}")
logging.info(f"Running test for prediction_dir='{prediction_dir}', "
f"compress_pickles={compress_pickles}, remove_pickles={remove_pickles}, remove_keys={remove_keys}")
temp_dir_path = temp_dir.name
# Copy the files to the temporary directory
shutil.copytree(join(self.input_dir, prediction_dir), join(temp_dir_path, prediction_dir))

# Remove existing gz files
gz_files = [f for f in os.listdir(join(temp_dir_path, prediction_dir)) if f.endswith('.gz')]
for f in gz_files:
os.remove(join(temp_dir_path, prediction_dir, f))
# Run the postprocessing function
post_prediction_process(join(temp_dir_path, prediction_dir), compress_pickles, remove_pickles, remove_keys)
gz_files_existing = [f for f in os.listdir(join(temp_dir_path, prediction_dir)) if f.endswith('.gz')]
for f_ in gz_files_existing:
os.remove(join(temp_dir_path, prediction_dir, f_))

# Get the best model from ranking_debug.json
# Run the postprocessing
post_prediction_process(join(temp_dir_path, prediction_dir),
compress_pickles,
remove_pickles,
remove_keys)

# Identify the best model
with open(join(temp_dir_path, prediction_dir, 'ranking_debug.json')) as f:
best_model = json.load(f)['order'][0]

# Define the expected best result pickle path
best_result_pickle = join(temp_dir_path, prediction_dir, f"result_{best_model}.pkl")

# Check if files are removed and/or compressed based on the parameters
# Gather .pkl and .gz files
pickle_files = [f for f in os.listdir(join(temp_dir_path, prediction_dir)) if f.endswith('.pkl')]
gz_files = [f for f in os.listdir(join(temp_dir_path, prediction_dir)) if f.endswith('.gz')]

# Check if specified keys exist or were removed
if remove_keys:
# Ensure specified keys are removed from the pickle files
for pickle_file in pickle_files:
with open(join(temp_dir_path, prediction_dir, pickle_file), 'rb') as f:
for pf in pickle_files:
with open(join(temp_dir_path, prediction_dir, pf), 'rb') as f:
data = pickle.load(f)
for key in ['aligned_confidence_probs', 'distogram', 'masked_msa']:
self.assertNotIn(key, data, f"Key {key} was not removed from {pickle_file}")
self.assertNotIn(key, data, f"Key '{key}' was not removed from {pf}")
else:
# If we're not removing keys, verify they still exist in the pickle
for pf in pickle_files:
with open(join(temp_dir_path, prediction_dir, pf), 'rb') as f:
data = pickle.load(f)
for key in ['aligned_confidence_probs', 'distogram', 'masked_msa']:
self.assertIn(key, data, f"Key '{key}' was unexpectedly removed from {pf}")

# Now check file counts / compressions
if not compress_pickles and not remove_pickles:
# All pickle files should be present, no gz files
logging.info("Checking condition: not compress_pickles and not remove_pickles")
self.assertEqual(len(pickle_files), 5, f"Expected 5 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 0, f"Expected 0 gz files, found {len(gz_files)}.")
# Expect all .pkl files (5 in your scenario), no .gz
self.assertEqual(len(pickle_files), 5,
f"Expected 5 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 0,
f"Expected 0 gz files, found {len(gz_files)}.")

if compress_pickles and not remove_pickles:
# No pickle files should be present, each compressed separately
logging.info("Checking condition: compress_pickles and not remove_pickles")
self.assertEqual(len(pickle_files), 0, f"Expected 0 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 5, f"Expected 5 gz files, found {len(gz_files)}.")
# Expect 0 .pkl files, all compressed (5)
self.assertEqual(len(pickle_files), 0,
f"Expected 0 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 5,
f"Expected 5 gz files, found {len(gz_files)}.")
# Validate that gz files are readable
for gz_file in gz_files:
with gzip.open(join(temp_dir_path, prediction_dir, gz_file), 'rb') as f:
f.read(1) # Ensure it's a valid gzip file
f.read(1)

if not compress_pickles and remove_pickles:
# Only the best result pickle should be present
logging.info("Checking condition: not compress_pickles and remove_pickles")
self.assertEqual(len(pickle_files), 1, f"Expected 1 pickle file, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 0, f"Expected 0 gz files, found {len(gz_files)}.")
self.assertTrue(os.path.exists(best_result_pickle), f"Best result pickle file does not exist: {best_result_pickle}")
# Only the best pickle remains
self.assertEqual(len(pickle_files), 1,
f"Expected 1 pickle file, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 0,
f"Expected 0 gz files, found {len(gz_files)}.")
self.assertTrue(os.path.exists(best_result_pickle),
f"Best result pickle file does not exist: {best_result_pickle}")

if compress_pickles and remove_pickles:
# Only the best result pickle should be compressed, no pickle files present
logging.info("Checking condition: compress_pickles and remove_pickles")
self.assertEqual(len(pickle_files), 0, f"Expected 0 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 1, f"Expected 1 gz file, found {len(gz_files)}.")
self.assertTrue(os.path.exists(best_result_pickle + ".gz"), f"Best result pickle file not compressed: {best_result_pickle}.gz")
# Only the best pickle is compressed
self.assertEqual(len(pickle_files), 0,
f"Expected 0 pickle files, found {len(pickle_files)}.")
self.assertEqual(len(gz_files), 1,
f"Expected 1 gz file, found {len(gz_files)}.")
self.assertTrue(os.path.exists(best_result_pickle + ".gz"),
f"Best result pickle file not compressed: {best_result_pickle}.gz")
with gzip.open(join(temp_dir_path, prediction_dir, gz_files[0]), 'rb') as f:
f.read(1) # Ensure it's a valid gzip file
f.read(1) # Check it's valid gzip

except AssertionError as e:
logging.error(f"AssertionError: {e}")
all_files = os.listdir(join(temp_dir_path, prediction_dir))
relevant_files = [f for f in all_files if f.endswith('.gz') or f.endswith('.pkl')]
logging.error(f".gz and .pkl files in {join(temp_dir_path, prediction_dir)}: {relevant_files}")
raise # Re-raise the exception to ensure the test is marked as failed
raise
finally:
temp_dir.cleanup()

0 comments on commit d89b0f8

Please sign in to comment.