fix failing test
bdamokos committed Jan 1, 2025
1 parent 6bf68f8 · commit 2aecc59
Showing 2 changed files with 48 additions and 44 deletions.
90 changes: 46 additions & 44 deletions src/mobility_db_api/external_gtfs.py
@@ -5,6 +5,7 @@
import zipfile
import shutil
from .api import MobilityAPI, DatasetMetadata
import json


class ExternalGTFSAPI(MobilityAPI):
@@ -118,24 +119,27 @@ def extract_gtfs(

# Try to find existing provider by name or generate new ID
if not provider_id:
# First, try to find an exact match by provider name if provided
if provider_name:
# If provider name is provided, look for exact match by name only
for key, meta in self.datasets.items():
if meta.provider_name == provider_name:
provider_id = meta.provider_id
self.logger.info(f"Found existing provider ID {provider_id} for name {provider_name}")
break

# If no match by name, try to find by hash (only if name not provided)
if not provider_id and not provider_name:
# If no match found by name, generate new ID
if not provider_id:
provider_id = self._get_next_provider_id()
self.logger.info(f"Generated new provider ID {provider_id} for name {provider_name}")
else:
# Only do hash matching if no provider name is provided
for key, meta in self.datasets.items():
if meta.file_hash == file_hash:
provider_id = meta.provider_id
provider_name = meta.provider_name
provider_name = meta.provider_name # Use the name from the matching dataset
break

# If still no match, generate new ID
if not provider_id:
provider_id = self._get_next_provider_id()
# If no match found by hash, generate new ID
if not provider_id:
provider_id = self._get_next_provider_id()

# Create a temporary directory for extraction
base_dir = Path(download_dir) if download_dir else self.data_dir
@@ -162,8 +166,23 @@
provider_dir = base_dir / f"{provider_id}_{safe_name}"
provider_dir.mkdir(exist_ok=True)

# Generate dataset ID using timestamp
dataset_id = f"direct_{datetime.now().strftime('%Y%m%d%H%M%S')}"
# Find old dataset for this provider/name combination
old_dataset_path = None
old_key = None
for key, meta in list(self.datasets.items()):
if meta.provider_id == provider_id and meta.provider_name == provider_name:
old_dataset_path = meta.download_path
old_key = key
self.logger.info(f"Found old dataset {key} for provider {provider_id} ({provider_name})")
break # Only remove the first match

# Generate dataset ID using timestamp and counter if needed
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
dataset_id = f"direct_{timestamp}"
counter = 1
while any(meta.dataset_id == dataset_id for meta in self.datasets.values()):
dataset_id = f"direct_{timestamp}_{counter}"
counter += 1
dataset_dir = provider_dir / dataset_id

# Move extracted contents to final location
@@ -174,15 +193,6 @@
# Get feed dates
feed_start_date, feed_end_date = self._get_feed_dates(dataset_dir)

# Find old dataset for this provider/name combination
old_dataset_path = None
old_key = None
for key, meta in list(self.datasets.items()):
if meta.provider_id == provider_id and meta.provider_name == provider_name:
old_dataset_path = meta.download_path
old_key = key
break

# Create metadata
metadata = DatasetMetadata(
provider_id=provider_id,
@@ -205,42 +215,34 @@
# Add new dataset
dataset_key = f"{provider_id}_{dataset_id}"
self.datasets[dataset_key] = metadata

# Save metadata with new dataset
if download_dir:
self._save_metadata(base_dir)
else:
self._save_metadata()
self.logger.info(f"Added new dataset {dataset_key} for provider {provider_id} ({provider_name})")

# Clean up old dataset if it exists
if old_dataset_path and old_dataset_path.exists():
if old_dataset_path and old_dataset_path.exists() and old_key:
self.logger.info(f"Cleaning up old dataset at {old_dataset_path}")
cleanup_success = False
try:
# Clean up the old dataset
shutil.rmtree(old_dataset_path)
# Only remove old dataset from metadata if cleanup was successful
if old_key and old_key in self.datasets:
# Save the old provider name before deletion
old_provider_name = self.datasets[old_key].provider_name
del self.datasets[old_key]
# Restore any other datasets with the same provider name
for key, meta in list(self.datasets.items()):
if meta.provider_name == old_provider_name and key != old_key:
self.datasets[key] = meta
if download_dir:
self._save_metadata(base_dir)
else:
self._save_metadata()
del self.datasets[old_key]
cleanup_success = True
self.logger.info(f"Successfully cleaned up old dataset {old_key}")
except Exception as e:
self.logger.error(f"Failed to clean up old dataset: {str(e)}")
# If cleanup failed, remove the new dataset from metadata
if dataset_key in self.datasets:
del self.datasets[dataset_key]
if download_dir:
self._save_metadata(base_dir)
else:
self._save_metadata()
return None
return None

# Save metadata once at the end
if download_dir:
self._save_metadata(base_dir)
else:
self._save_metadata()

# Log final state
self.logger.info(f"Final datasets: {list(self.datasets.keys())}")
return dataset_dir

except Exception as e:
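Note on the ID scheme above: dataset IDs are built from a second-resolution timestamp, so two extractions in the same second would previously produce the same ID; the new loop appends a counter until the ID is unused. A minimal standalone sketch of that scheme, with a trimmed-down stand-in for DatasetMetadata (assumption: only the dataset_id field matters for uniqueness):

from datetime import datetime
from dataclasses import dataclass


@dataclass
class DatasetMetadata:  # trimmed-down stand-in; the real class carries many more fields
    dataset_id: str


def unique_dataset_id(existing: dict) -> str:
    """Return a 'direct_<timestamp>' ID, appending a counter if that ID is already taken."""
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    dataset_id = f"direct_{timestamp}"
    counter = 1
    # Second-resolution timestamps collide when two extractions run within the same second.
    while any(meta.dataset_id == dataset_id for meta in existing.values()):
        dataset_id = f"direct_{timestamp}_{counter}"
        counter += 1
    return dataset_id


if __name__ == "__main__":
    taken = {"200_direct_20250101120000": DatasetMetadata("direct_20250101120000")}
    print(unique_dataset_id(taken))  # falls back to direct_<timestamp>_1 only on a collision

The uniqueness check scans dataset_id values across all stored datasets, not just the current provider's, so the suffix also guards against collisions between different providers extracted in the same second.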
2 changes: 2 additions & 0 deletions tests/test_api_initialization.py
@@ -42,9 +42,11 @@ def write_metadata_process(data_dir: str, dataset_id: str, delay: float = 0):
maximum_longitude=None,
)

print(f"Process {dataset_id} starting to write metadata.")
# Add to API's datasets and save
api.datasets[f"test_dataset_{dataset_id}"] = metadata
api._save_metadata() # This will merge with existing metadata
print(f"Process {dataset_id} finished writing metadata.")

def read_metadata_process(data_dir: str):
"""Helper function to read metadata from a separate process"""
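The helper above leans on _save_metadata() merging with whatever other processes have already written (per the inline comment), which is what lets the surrounding test run several writers concurrently. A rough sketch of that merge-on-save idea, under the assumption that metadata lives in a JSON dict keyed by dataset key; this is illustrative only, not the library's actual implementation:

import json
from pathlib import Path


def save_metadata_merged(metadata_file: Path, own_entries: dict) -> None:
    """Merge this process's entries into the shared metadata file instead of overwriting it."""
    existing = {}
    if metadata_file.exists():
        try:
            existing = json.loads(metadata_file.read_text())
        except json.JSONDecodeError:
            # Treat a half-written file from a concurrent writer as empty and rebuild from our entries.
            existing = {}
    existing.update(own_entries)  # this process's keys win on conflict
    metadata_file.write_text(json.dumps(existing, indent=2))

A real implementation would also need file locking around the read-modify-write so two writers cannot interleave; the sketch only shows the merge step.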
