Skip to content

Commit

Permalink
weights added to the simple version of the costing algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
EddyCMWF committed Nov 12, 2024
1 parent 7054d86 commit fc43586
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
4 changes: 3 additions & 1 deletion cads_adaptors/adaptors/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def estimate_costs(self, request: Request, **kwargs: Any) -> dict[str, int]:
**costing_kwargs,
)
# size is a fast and rough estimate of the number of fields
costs["size"] = costing.estimate_number_of_fields(self.form, mapped_request)
costs["size"] = costing.estimate_number_of_fields(
self.form, mapped_request, **costing_kwargs
)
# Safety net for integration tests:
costs["number_of_fields"] = costs["size"]
return costs
Expand Down
12 changes: 11 additions & 1 deletion cads_adaptors/costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,27 @@ def get_excluded_keys(
def estimate_number_of_fields(
form: list[dict[str, Any]] | dict[str, Any] | None,
request: dict[str, Any],
**kwargs,
) -> int:
weighted_values = kwargs.get("weighted_values", {})
weighted_keys = kwargs.get("weighted_keys", {})
excluded_variables = get_excluded_keys(form)
number_of_values = []
for variable_id, variable_value in request.items():
weights_v = weighted_values.get(variable_id, {})
weight_k = weighted_keys.get(variable_id, 1)
if isinstance(variable_value, set):
variable_value = list(variable_value)
if not isinstance(variable_value, (list, tuple)):
variable_value = [
variable_value,
]
if variable_id not in excluded_variables:
number_of_values.append(len(variable_value))
# Extend values according to weights
for val, weight in weights_v.items():
if val in variable_value:
variable_value.extend([val] * (weight - 1))
# Append number of values, multiplied by weight
number_of_values.append(len(variable_value) * weight_k)
number_of_fields = math.prod(number_of_values)
return number_of_fields
4 changes: 2 additions & 2 deletions tests/test_10_costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def test_estimate_costs_2() -> None:
)
costs = weighted_adaptor.estimate_costs(request)
assert costs["precise_size"] == 8
assert costs["size"] == 2
assert costs["size"] == 8

request = {
"variable": "maximum_temperature",
Expand All @@ -762,4 +762,4 @@ def test_estimate_costs_2() -> None:
}
costs = weighted_adaptor.estimate_costs(request)
assert costs["precise_size"] == 10
assert costs["size"] == 2
assert costs["size"] == 10

0 comments on commit fc43586

Please sign in to comment.