Commit

Fix #5
cmutel committed May 14, 2024
1 parent c838bf4 commit 5f9d62e
Showing 9 changed files with 123 additions and 25 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -56,11 +56,11 @@ That will return something like:

```python
Speedup(
-    database_name='ecoinvent-3.10-cutoff',
-    time_with_aggregation=4.297600030899048,
-    time_without_aggregation=2.22904896736145,
-    time_difference_absolute=2.0685510635375977,
-    time_difference_relative=1.9279971386120622
+    database_name='USEEIO-2.0',
+    time_with_aggregation=0.06253910064697266,
+    time_without_aggregation=0.026948928833007812,
+    time_difference_absolute=0.035590171813964844,
+    time_difference_relative=2.3206525585674855
)
```

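
For context, a minimal sketch of the call that produces a `Speedup` report like the one above, assuming `bw_aggregation` is installed and the named database (a placeholder here) exists in the active Brightway project:

```python
from bw_aggregation import AggregatedDatabase

# Placeholder database name; any installed database with process datasets works.
speedup = AggregatedDatabase.estimate_speedup("USEEIO-2.0")
print(speedup.time_difference_relative)  # ratio between the two timed runs
```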
8 changes: 6 additions & 2 deletions bw_aggregation/calculator.py
@@ -1,7 +1,7 @@
from time import time

import numpy as np
-from bw2calc import LCA, PYPARDISO, spsolve
+from bw2calc import LCA, spsolve
from bw2data import databases, prepare_lca_inputs
from bw2data.database import DatabaseChooser
from bw_graph_tools import guess_production_exchanges
@@ -34,7 +34,11 @@ def calculate(self) -> None:
prod_rows, prod_cols = guess_production_exchanges(self.lca.technosphere_mm)
# Not very efficient; could be SQL query but that would break IOTable
matrix_db_process_ids = np.array(
-[self.lca.dicts.activity[obj.id] for obj in self.db]
+[
+    self.lca.dicts.activity[obj.id]
+    for obj in self.db
+    if obj.get("type", "process") == "process"
+]
)

# Get boolean mask for the column indices of the processes in the database
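
For context, a tiny illustration of the filtering rule used in the new comprehension, with made-up dictionaries standing in for database nodes; datasets without an explicit `type` count as processes:

```python
datasets = [
    {"name": "steel production"},             # no type -> treated as "process"
    {"name": "steel", "type": "product"},     # skipped
    {"name": "CO2", "type": "emission"},      # skipped
]
process_like = [ds for ds in datasets if ds.get("type", "process") == "process"]
assert [ds["name"] for ds in process_like] == ["steel production"]
```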
12 changes: 12 additions & 0 deletions bw_aggregation/errors.py
@@ -0,0 +1,12 @@
class IncompatibleDatabase(Exception):
"""Database can't be used for inventory calculations.
Usually because it only has biosphere flows."""

pass


class ObsoleteAggregatedDatapackage(Exception):
"""The results from this aggregated datapackage are obsolete"""

pass
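
For context, a hedged sketch of how these exceptions surface to callers, based on the `convert_existing` guard added later in this commit; the database name is a placeholder:

```python
from bw_aggregation import AggregatedDatabase
from bw_aggregation.errors import IncompatibleDatabase

try:
    # Raises if the named database holds only biosphere flows.
    AggregatedDatabase.convert_existing("some-biosphere-db")  # placeholder name
except IncompatibleDatabase as exc:
    print(f"Cannot aggregate: {exc}")
```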
42 changes: 35 additions & 7 deletions bw_aggregation/estimator.py
@@ -1,10 +1,14 @@
import random
from collections import Counter
from dataclasses import dataclass
from time import time

from bw2calc import LCA
-from bw2data import Database, prepare_lca_inputs
+from bw2data import Database, Node, prepare_lca_inputs

from .errors import IncompatibleDatabase
from .override import AggregationContext
from .utils import check_processes_in_database, get_process_type_counts


@dataclass
@@ -20,6 +24,34 @@ class CalculationDifferenceEstimator:
def __init__(self, database_name: str):
self.name = database_name
self.db = Database(database_name)
self.random_product = self.pick_random_product()

def pick_random_product(self) -> Node:
"""Check database structure and pick random product"""
if not check_processes_in_database(self.db.name):
dataset_types_formatted = "\n\t".join(
[
f"{a}: {b} objects"
for a, b in get_process_type_counts(self.db.name).items()
]
)
ERROR = f"""
This database has the wrong kind of flows for an inventory calculation.
It should have either only "process" flow types, or "process" and "product" flows.
The following flows were found in database {self.db.name}: \n\t{dataset_types_formatted}
"""
raise IncompatibleDatabase(ERROR)

for ds in self.db:
production = [exc for exc in ds.production() if exc["amount"]]
if len(production) == 1:
return production[0].output

ERROR = f"""
The database {self.db.name} has no processes with a single non-zero production exchange.
We can't find a suitable process to do the example calculations with.
"""
raise IncompatibleDatabase(ERROR)

def difference(self) -> Speedup:
without = self.calculate_without_speedup()
@@ -36,10 +68,8 @@ def difference(self) -> Speedup:
def calculate_with_speedup(self):
from .main import AggregatedDatabase

-        process = self.db.random()

with AggregationContext({self.name: False}):
-            fu, data_objs, _ = prepare_lca_inputs({process: 1})
+            fu, data_objs, _ = prepare_lca_inputs({self.random_product: 1})
data_objs[-1] = AggregatedDatabase(self.name).process_aggregated(
in_memory=True
)
@@ -52,10 +82,8 @@ def calculate_without_speedup(self):
return end - start

def calculate_without_speedup(self):
-        process = self.db.random()

with AggregationContext({self.name: False}):
-            fu, data_objs, _ = prepare_lca_inputs({process: 1})
+            fu, data_objs, _ = prepare_lca_inputs({self.random_product: 1})

start = time()
lca = LCA(fu, data_objs=data_objs)
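
For context, a short usage sketch of the estimator after this change: the product used for the timing runs is now validated and chosen once in `__init__` rather than via `Database.random()` inside each run. The database name is a placeholder:

```python
from bw_aggregation.errors import IncompatibleDatabase
from bw_aggregation.estimator import CalculationDifferenceEstimator

try:
    estimator = CalculationDifferenceEstimator("USEEIO-2.0")  # placeholder name
    speedup = estimator.difference()  # one timed LCA with and one without aggregation
    print(speedup)
except IncompatibleDatabase as exc:
    print(exc)
```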
24 changes: 13 additions & 11 deletions bw_aggregation/main.py
@@ -1,8 +1,5 @@
import warnings
from datetime import datetime as dt

import tqdm
from bw2calc import LCA, PYPARDISO
from bw2data import databases
from bw2data.backends import SQLiteBackend
from bw_processing import (
@@ -15,14 +12,10 @@
from fs.zipfs import ZipFS

from .calculator import AggregationCalculator
-from .estimator import CalculationDifferenceEstimator
+from .errors import IncompatibleDatabase, ObsoleteAggregatedDatapackage
+from .estimator import CalculationDifferenceEstimator, Speedup
from .override import AggregationContext, aggregation_override


-class ObsoleteAggregatedDatapackage(Exception):
-    """The results from this aggregated datapackage are obsolete"""
-
-    pass
+from .utils import check_processes_in_data, check_processes_in_database


class AggregatedDatabase(SQLiteBackend):
@@ -78,7 +71,7 @@ def aggregation_datapackage_valid(self) -> bool:
)

@staticmethod
-    def estimate_speedup(database_name: str) -> float:
+    def estimate_speedup(database_name: str) -> Speedup:
"""Estimate how much quicker calculations could be when using aggregated emissions.
Prints to `stdout` and returns a `Speedup` with the ratio of calculation speed with aggregation
@@ -93,6 +86,10 @@ def convert_existing(database_name: str) -> None:
if databases[database_name]["backend"] == "aggregated":
print(f"Database '{database_name}' is already aggregated")
return
if not check_processes_in_database(database_name):
raise IncompatibleDatabase(
"This database only has biosphere flows, and can't be aggregated."
)

db = AggregatedDatabase(database_name)
db.process_aggregated()
@@ -163,6 +160,11 @@ def refresh_all() -> None:
AggregatedDatabase(db_name).refresh()

def write(self, data, process=True, searchable=True) -> None:
if not check_processes_in_data(data.values()):
raise IncompatibleDatabase(
"This data only has biosphere flows, and can't be aggregated."
)

super().write(data=data, process=process, searchable=searchable)

if process:
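
For context, a sketch of the new `write()` guard using a made-up database and dataset; data containing no `process` datasets now raises `IncompatibleDatabase` instead of being written:

```python
from bw_aggregation import AggregatedDatabase
from bw_aggregation.errors import IncompatibleDatabase

flows_only = {
    ("flows-only", "co2"): {"name": "CO2", "type": "emission", "exchanges": []},
}

try:
    AggregatedDatabase("flows-only").write(flows_only)  # hypothetical database name
except IncompatibleDatabase:
    print("Refused: no process datasets to aggregate")
```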
26 changes: 26 additions & 0 deletions bw_aggregation/utils.py
@@ -0,0 +1,26 @@
from collections import Counter
from collections.abc import Iterable
from typing import Optional

from bw2data import Database, databases


def check_processes_in_database(database_name: str) -> bool:
"""Check to make sure database is usable for aggregated calculations."""
if database_name not in databases:
raise KeyError(f"{database_name} not in databases")
return check_processes_in_data(Database(database_name))


def check_processes_in_data(objects: Iterable) -> bool:
"""Check if any object in the input data has type `process`"""
return any(obj.get("type", "process") == "process" for obj in objects)


def get_process_type_counts(database_name: str) -> dict[Optional[str], int]:
"""Get count of each process type in database"""
if database_name not in databases:
raise KeyError(f"{database_name} not in databases")
return dict(
Counter([obj.get("type") for obj in Database(database_name)]).most_common()
)
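
For context, a minimal sketch of the helpers above; `check_processes_in_data` accepts any iterable of mapping-like objects, and datasets without a `type` key count as processes. Database names are placeholders:

```python
from bw_aggregation.utils import (
    check_processes_in_data,
    check_processes_in_database,
    get_process_type_counts,
)

# Plain dictionaries are enough for check_processes_in_data.
assert check_processes_in_data([{"name": "steel"}, {"name": "CO2", "type": "emission"}])

if not check_processes_in_database("some-db"):  # placeholder name
    print(get_process_type_counts("some-db"))   # e.g. {"emission": 4}
```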
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -90,6 +90,7 @@ def background():
},
],
},
("a", "4"): {"name": "CO2", "type": "emission", "exchanges": []},
}
Database("a").write(a_data)

6 changes: 6 additions & 0 deletions tests/test_main.py
@@ -7,6 +7,7 @@
from bw2data.backends import SQLiteBackend

from bw_aggregation import AggregatedDatabase, ObsoleteAggregatedDatapackage
from bw_aggregation.errors import IncompatibleDatabase


def check_a_database_matrices_unaggregated(lca: LCA):
@@ -266,3 +267,8 @@ def test_refresh_all(background):

assert Database("a").aggregation_datapackage_valid()
assert Database("b").aggregation_datapackage_valid()


def test_incompatible_database_only_biosphere_flows(background):
with pytest.raises(IncompatibleDatabase):
AggregatedDatabase.convert_existing("bio")
19 changes: 19 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,19 @@
import pytest
from bw2data import Database

from bw_aggregation.utils import (
check_processes_in_data,
check_processes_in_database,
get_process_type_counts,
)


def test_check_processes_in_database(background):
assert check_processes_in_database("a")
assert not check_processes_in_database("bio")
with pytest.raises(KeyError):
check_processes_in_database("missing")


def test_get_process_type_counts(background):
assert get_process_type_counts("a") == {None: 3, "emission": 1}
