From 6533418bff5e5ace8bb9fb4ffe5bd214b168b87b Mon Sep 17 00:00:00 2001 From: Christopher Cave-Ayland Date: Thu, 27 Jun 2024 17:20:23 +0100 Subject: [PATCH 01/10] First pass at duckdb data interface --- .../example/default_new_input/commodities.csv | 4 +- src/muse/new_input/readers.py | 76 +++++++++++ tests/test_readers.py | 118 ++++++++++++++---- 3 files changed, 173 insertions(+), 25 deletions(-) create mode 100644 src/muse/new_input/readers.py diff --git a/src/muse/data/example/default_new_input/commodities.csv b/src/muse/data/example/default_new_input/commodities.csv index cec5cbf65..09857851f 100644 --- a/src/muse/data/example/default_new_input/commodities.csv +++ b/src/muse/data/example/default_new_input/commodities.csv @@ -1,6 +1,6 @@ -commodity_name,description,type,unit +name,description,type,unit electricity,Electricity,energy,PJ gas,Gas,energy,PJ heat,Heat,energy,PJ wind,Wind,energy,PJ -C02f,Carbon dioxide,energy,kt +CO2f,Carbon dioxide,energy,kt diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py new file mode 100644 index 000000000..eafa4fb07 --- /dev/null +++ b/src/muse/new_input/readers.py @@ -0,0 +1,76 @@ +import duckdb +import numpy as np +import xarray as xr + + +def read_inputs(data_dir): + data = {} + con = duckdb.connect(":memory:") + + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) # noqa: F841 + + with open(data_dir / "commodities.csv") as f: + commodities = read_commodities_csv(f, con) + + with open(data_dir / "demand.csv") as f: + demand = read_demand_csv(f, con) # noqa: F841 + + data["global_commodities"] = calculate_global_commodities(commodities) + return data + + +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + name VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT name FROM rel;") + return con.sql("SELECT name from regions").fetchnumpy() + + +def read_commodities_csv(buffer_, con): + sql = """CREATE TABLE commodities ( + name VARCHAR PRIMARY KEY, + type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')), + unit VARCHAR, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") + + return con.sql("select name, type, unit from commodities").fetchnumpy() + + +def calculate_global_commodities(commodities): + names = commodities["name"].astype(np.dtype("str")) + types = commodities["type"].astype(np.dtype("str")) + units = commodities["unit"].astype(np.dtype("str")) + + type_array = xr.DataArray( + data=types, dims=["commodity"], coords=dict(commodity=names) + ) + + unit_array = xr.DataArray( + data=units, dims=["commodity"], coords=dict(commodity=names) + ) + + data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) + return data + + +def read_demand_csv(buffer_, con): + sql = """CREATE TABLE demand ( + year BIGINT, + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + demand DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") + return con.sql("SELECT * from demand").fetchnumpy() diff --git a/tests/test_readers.py b/tests/test_readers.py index 1543959df..22aad16a0 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -1,7 +1,9 @@ +from io import 
StringIO from itertools import chain, permutations from pathlib import Path from unittest.mock import patch +import duckdb import numpy as np import toml import xarray as xr @@ -861,40 +863,110 @@ def default_new_input(tmp_path): from muse.examples import copy_model copy_model("default_new_input", tmp_path) - return tmp_path + return tmp_path / "model" -@mark.xfail -def test_read_new_global_commodities(default_new_input): - from muse.new_input.readers import read_inputs +@fixture +def con(): + return duckdb.connect(":memory:") - all_data = read_inputs(default_new_input) - data = all_data["global_commodities"] + +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + +@fixture +def populate_commodities(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + with open(default_new_input / "commodities.csv") as f: + return read_commodities_csv(f, con) + + +@fixture +def populate_demand(default_new_input, con, populate_regions, populate_commodities): + from muse.new_input.readers import read_demand_csv + + with open(default_new_input / "demand.csv") as f: + return read_demand_csv(f, con) + + +def test_read_regions(populate_regions): + assert populate_regions["name"] == np.array(["R1"]) + + +def test_read_new_global_commodities(populate_commodities): + data = populate_commodities + assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["type"]) == ["energy"] * 5 + assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] + + +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) assert isinstance(data, xr.Dataset) assert set(data.dims) == {"commodity"} - assert dict(data.dtypes) == dict( - type=np.dtype("str"), - unit=np.dtype("str"), - ) + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["name"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_read_new_global_commodities_type_constraint(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + csv = StringIO("name,type,unit\nfoo,invalid,bar\n") + with raises(duckdb.ConstraintException): + read_commodities_csv(csv, con) - assert list(data.coords["commodity"].values) == [ - "electricity", - "gas", - "heat", - "wind", - "CO2f", - ] - assert list(data.data_vars["type"].values) == ["energy"] * 5 - assert list(data.data_vars["unit"].values) == ["PJ"] * 4 + ["kt"] + +def test_new_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_new_read_demand_csv_commodity_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def 
test_new_read_demand_csv_region_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) @mark.xfail -def test_read_demand(default_new_input): - from muse.new_input.readers import read_inputs +def test_demand_dataset(default_new_input): + import duckdb + from muse.new_input.readers import read_commodities, read_demand, read_regions - all_data = read_inputs(default_new_input) - data = all_data["demand"] + con = duckdb.connect(":memory:") + + read_regions(default_new_input, con) + read_commodities(default_new_input, con) + data = read_demand(default_new_input, con) assert isinstance(data, xr.DataArray) assert data.dtype == np.float64 From c5e1b24d95e981f9bf94eda573481b2df44f01ec Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 3 Jul 2024 16:03:19 +0100 Subject: [PATCH 02/10] New db tables --- .../default_new_input/commodity_trade.csv | 2 +- src/muse/new_input/readers.py | 111 +++++++++++++----- 2 files changed, 84 insertions(+), 29 deletions(-) diff --git a/src/muse/data/example/default_new_input/commodity_trade.csv b/src/muse/data/example/default_new_input/commodity_trade.csv index eb23c4b6c..d32a72acc 100644 --- a/src/muse/data/example/default_new_input/commodity_trade.csv +++ b/src/muse/data/example/default_new_input/commodity_trade.csv @@ -1 +1 @@ -commodity,region,net_import,year +commodity,region,year,import,export diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index eafa4fb07..a02f40a84 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -7,28 +7,26 @@ def read_inputs(data_dir): data = {} con = duckdb.connect(":memory:") - with open(data_dir / "regions.csv") as f: - regions = read_regions_csv(f, con) # noqa: F841 - with open(data_dir / "commodities.csv") as f: commodities = read_commodities_csv(f, con) + with open(data_dir / "commodity_trade.csv") as f: + commodity_trade = read_commodity_trade_csv(f, con) # noqa: F841 + + with open(data_dir / "commodity_costs.csv") as f: + commodity_costs = read_commodity_costs_csv(f, con) # noqa: F841 + with open(data_dir / "demand.csv") as f: demand = read_demand_csv(f, con) # noqa: F841 - data["global_commodities"] = calculate_global_commodities(commodities) - return data + with open(data_dir / "demand_slicing.csv") as f: + demand_slicing = read_demand_slicing_csv(f, con) # noqa: F841 + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) # noqa: F841 -def read_regions_csv(buffer_, con): - sql = """CREATE TABLE regions ( - name VARCHAR PRIMARY KEY, - ); - """ - con.sql(sql) - rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT name FROM rel;") - return con.sql("SELECT name from regions").fetchnumpy() + data["global_commodities"] = calculate_global_commodities(commodities) + return data def read_commodities_csv(buffer_, con): @@ -41,25 +39,38 @@ def read_commodities_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") - return con.sql("select name, type, unit from commodities").fetchnumpy() -def calculate_global_commodities(commodities): - names = commodities["name"].astype(np.dtype("str")) - types = commodities["type"].astype(np.dtype("str")) - 
units = commodities["unit"].astype(np.dtype("str")) - - type_array = xr.DataArray( - data=types, dims=["commodity"], coords=dict(commodity=names) - ) +def read_commodity_trade_csv(buffer_, con): + sql = """CREATE TABLE commodity_trade ( + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + year BIGINT, + import DOUBLE, + export DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO commodity_trade SELECT + commodity, region, year, import, export FROM rel;""") + return con.sql("SELECT * from commodity_trade").fetchnumpy() - unit_array = xr.DataArray( - data=units, dims=["commodity"], coords=dict(commodity=names) - ) - data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) - return data +def read_commodity_costs_csv(buffer_, con): + sql = """CREATE TABLE commodity_costs ( + year BIGINT, + region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(name), + value DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO commodity_costs SELECT + year, region, commodity_name, value FROM rel;""") + return con.sql("SELECT * from commodity_costs").fetchnumpy() def read_demand_csv(buffer_, con): @@ -74,3 +85,47 @@ def read_demand_csv(buffer_, con): rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") return con.sql("SELECT * from demand").fetchnumpy() + + +def read_demand_slicing_csv(buffer_, con): + sql = """CREATE TABLE demand_slicing ( + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + timeslice VARCHAR, + fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + year BIGINT, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO demand_slicing SELECT + commodity, region, timeslice, fraction, year FROM rel;""") + return con.sql("SELECT * from demand_slicing").fetchnumpy() + + +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + name VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT name FROM rel;") + return con.sql("SELECT name from regions").fetchnumpy() + + +def calculate_global_commodities(commodities): + names = commodities["name"].astype(np.dtype("str")) + types = commodities["type"].astype(np.dtype("str")) + units = commodities["unit"].astype(np.dtype("str")) + + type_array = xr.DataArray( + data=types, dims=["commodity"], coords=dict(commodity=names) + ) + + unit_array = xr.DataArray( + data=units, dims=["commodity"], coords=dict(commodity=names) + ) + + data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) + return data From 4afc763b32a405cdf06c74f9d47b735ec1ec7afd Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 4 Jul 2024 09:29:16 +0100 Subject: [PATCH 03/10] Update tables for new csv columns --- src/muse/new_input/readers.py | 40 +++++++++++++++++------------------ tests/test_readers.py | 28 ++++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index a02f40a84..b9228f5bf 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -31,21 +31,21 @@ def read_inputs(data_dir): def 
read_commodities_csv(buffer_, con): sql = """CREATE TABLE commodities ( - name VARCHAR PRIMARY KEY, + id VARCHAR PRIMARY KEY, type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')), unit VARCHAR, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") - return con.sql("select name, type, unit from commodities").fetchnumpy() + con.sql("INSERT INTO commodities SELECT id, type, unit FROM rel;") + return con.sql("select * from commodities").fetchnumpy() def read_commodity_trade_csv(buffer_, con): sql = """CREATE TABLE commodity_trade ( - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, import DOUBLE, export DOUBLE, @@ -54,68 +54,68 @@ def read_commodity_trade_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_trade SELECT - commodity, region, year, import, export FROM rel;""") + commodity_id, region_id, year, import, export FROM rel;""") return con.sql("SELECT * from commodity_trade").fetchnumpy() def read_commodity_costs_csv(buffer_, con): sql = """CREATE TABLE commodity_costs ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, - region VARCHAR REFERENCES regions(name), - commodity VARCHAR REFERENCES commodities(name), value DOUBLE, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_costs SELECT - year, region, commodity_name, value FROM rel;""") + commidity_id, region_id, year, value FROM rel;""") return con.sql("SELECT * from commodity_costs").fetchnumpy() def read_demand_csv(buffer_, con): sql = """CREATE TABLE demand ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), demand DOUBLE, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") + con.sql("INSERT INTO demand SELECT commodity_id, region_id, year, demand FROM rel;") return con.sql("SELECT * from demand").fetchnumpy() def read_demand_slicing_csv(buffer_, con): sql = """CREATE TABLE demand_slicing ( - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), + year BIGINT, timeslice VARCHAR, fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), - year BIGINT, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO demand_slicing SELECT - commodity, region, timeslice, fraction, year FROM rel;""") + commodity_id, region_id, year, timeslice, fraction FROM rel;""") return con.sql("SELECT * from demand_slicing").fetchnumpy() def read_regions_csv(buffer_, con): sql = """CREATE TABLE regions ( - name VARCHAR PRIMARY KEY, + id VARCHAR PRIMARY KEY, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT name FROM rel;") - return con.sql("SELECT name from regions").fetchnumpy() + con.sql("INSERT INTO regions SELECT id FROM rel;") + return con.sql("SELECT * from 
regions").fetchnumpy() def calculate_global_commodities(commodities): - names = commodities["name"].astype(np.dtype("str")) + names = commodities["id"].astype(np.dtype("str")) types = commodities["type"].astype(np.dtype("str")) units = commodities["unit"].astype(np.dtype("str")) diff --git a/tests/test_readers.py b/tests/test_readers.py index 497b9601f..adb1f8a23 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -858,14 +858,6 @@ def con(): return duckdb.connect(":memory:") -@fixture -def populate_regions(default_new_input, con): - from muse.new_input.readers import read_regions_csv - - with open(default_new_input / "regions.csv") as f: - return read_regions_csv(f, con) - - @fixture def populate_commodities(default_new_input, con): from muse.new_input.readers import read_commodities_csv @@ -882,13 +874,21 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi return read_demand_csv(f, con) +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + def test_read_regions(populate_regions): - assert populate_regions["name"] == np.array(["R1"]) + assert populate_regions["id"] == np.array(["R1"]) def test_read_new_global_commodities(populate_commodities): data = populate_commodities - assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] assert list(data["type"]) == ["energy"] * 5 assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] @@ -903,7 +903,7 @@ def test_calculate_global_commodities(populate_commodities): for dt in data.dtypes.values(): assert np.issubdtype(dt, np.dtype("str")) - assert list(data.coords["commodity"].values) == list(populate_commodities["name"]) + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) @@ -911,7 +911,7 @@ def test_calculate_global_commodities(populate_commodities): def test_read_new_global_commodities_type_constraint(default_new_input, con): from muse.new_input.readers import read_commodities_csv - csv = StringIO("name,type,unit\nfoo,invalid,bar\n") + csv = StringIO("id,type,unit\nfoo,invalid,bar\n") with raises(duckdb.ConstraintException): read_commodities_csv(csv, con) @@ -929,7 +929,7 @@ def test_new_read_demand_csv_commodity_constraint( ): from muse.new_input.readers import read_demand_csv - csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n") + csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") with raises(duckdb.ConstraintException, match=".*foreign key.*"): read_demand_csv(csv, con) @@ -939,7 +939,7 @@ def test_new_read_demand_csv_region_constraint( ): from muse.new_input.readers import read_demand_csv - csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n") + csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") with raises(duckdb.ConstraintException, match=".*foreign key.*"): read_demand_csv(csv, con) From e9ed5ede2e976f3c319d2ecf612f323bc362e485 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 5 Jul 2024 14:15:29 +0100 Subject: [PATCH 04/10] Split new tests into new file --- tests/test_new_readers.py | 221 ++++++++++++++++++++++++++++++++++++++ tests/test_readers.py | 219 
+------------------------------------ 2 files changed, 222 insertions(+), 218 deletions(-) create mode 100644 tests/test_new_readers.py diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py new file mode 100644 index 000000000..9665df14d --- /dev/null +++ b/tests/test_new_readers.py @@ -0,0 +1,221 @@ +from io import StringIO + +import duckdb +import numpy as np +import xarray as xr +from pytest import approx, fixture, mark, raises + + +@fixture +def default_new_input(tmp_path): + from muse.examples import copy_model + + copy_model("default_new_input", tmp_path) + return tmp_path / "model" + + +@fixture +def con(): + return duckdb.connect(":memory:") + + +@fixture +def populate_commodities(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + with open(default_new_input / "commodities.csv") as f: + return read_commodities_csv(f, con) + + +@fixture +def populate_demand(default_new_input, con, populate_regions, populate_commodities): + from muse.new_input.readers import read_demand_csv + + with open(default_new_input / "demand.csv") as f: + return read_demand_csv(f, con) + + +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + +def test_read_regions(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + +def test_read_new_global_commodities(populate_commodities): + data = populate_commodities + assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["type"]) == ["energy"] * 5 + assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] + + +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_read_new_global_commodities_type_constraint(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + csv = StringIO("id,type,unit\nfoo,invalid,bar\n") + with raises(duckdb.ConstraintException): + read_commodities_csv(csv, con) + + +def test_new_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_new_read_demand_csv_commodity_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def test_new_read_demand_csv_region_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") + with raises(duckdb.ConstraintException, match=".*foreign 
key.*"): + read_demand_csv(csv, con) + + +@mark.xfail +def test_demand_dataset(default_new_input): + import duckdb + from muse.new_input.readers import read_commodities, read_demand, read_regions + + con = duckdb.connect(":memory:") + + read_regions(default_new_input, con) + read_commodities(default_new_input, con) + data = read_demand(default_new_input, con) + + assert isinstance(data, xr.DataArray) + assert data.dtype == np.float64 + + assert set(data.dims) == {"year", "commodity", "region", "timeslice"} + assert list(data.coords["region"].values) == ["R1"] + assert list(data.coords["timeslice"].values) == list(range(1, 7)) + assert list(data.coords["year"].values) == [2020, 2050] + assert set(data.coords["commodity"].values) == { + "electricity", + "gas", + "heat", + "wind", + "CO2f", + } + + assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 + + +@mark.xfail +def test_new_read_initial_market(default_new_input): + from muse.new_input.readers import read_inputs + + all_data = read_inputs(default_new_input) + data = all_data["initial_market"] + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"region", "year", "commodity", "timeslice"} + assert dict(data.dtypes) == dict( + prices=np.float64, + exports=np.float64, + imports=np.float64, + static_trade=np.float64, + ) + assert list(data.coords["region"].values) == ["R1"] + assert list(data.coords["year"].values) == list(range(2010, 2105, 5)) + assert list(data.coords["commodity"].values) == [ + "electricity", + "gas", + "heat", + "CO2f", + "wind", + ] + month_values = ["all-year"] * 6 + day_values = ["all-week"] * 6 + hour_values = [ + "night", + "morning", + "afternoon", + "early-peak", + "late-peak", + "evening", + ] + + assert list(data.coords["timeslice"].values) == list( + zip(month_values, day_values, hour_values) + ) + assert list(data.coords["month"]) == month_values + assert list(data.coords["day"]) == day_values + assert list(data.coords["hour"]) == hour_values + + assert all(var.coords.equals(data.coords) for var in data.data_vars.values()) + + prices = data.data_vars["prices"] + assert approx( + prices.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + - 14.81481, + abs=1e-4, + ) + + exports = data.data_vars["exports"] + assert ( + exports.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 + + imports = data.data_vars["imports"] + assert ( + imports.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 + + static_trade = data.data_vars["static_trade"] + assert ( + static_trade.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 diff --git a/tests/test_readers.py b/tests/test_readers.py index adb1f8a23..7545bff4a 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -1,13 +1,11 @@ -from io import StringIO from itertools import chain, permutations from pathlib import Path from unittest.mock import patch -import duckdb import numpy as np import toml import xarray as xr -from pytest import approx, fixture, mark, raises +from pytest import fixture, mark, raises @fixture @@ -843,218 +841,3 @@ def test_check_utilization_and_minimum_service_factors_fail_missing_utilization( with raises(ValueError): check_utilization_and_minimum_service_factors(df, "file.csv") - - -@fixture -def default_new_input(tmp_path): - from 
muse.examples import copy_model - - copy_model("default_new_input", tmp_path) - return tmp_path / "model" - - -@fixture -def con(): - return duckdb.connect(":memory:") - - -@fixture -def populate_commodities(default_new_input, con): - from muse.new_input.readers import read_commodities_csv - - with open(default_new_input / "commodities.csv") as f: - return read_commodities_csv(f, con) - - -@fixture -def populate_demand(default_new_input, con, populate_regions, populate_commodities): - from muse.new_input.readers import read_demand_csv - - with open(default_new_input / "demand.csv") as f: - return read_demand_csv(f, con) - - -@fixture -def populate_regions(default_new_input, con): - from muse.new_input.readers import read_regions_csv - - with open(default_new_input / "regions.csv") as f: - return read_regions_csv(f, con) - - -def test_read_regions(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_read_new_global_commodities(populate_commodities): - data = populate_commodities - assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] - assert list(data["type"]) == ["energy"] * 5 - assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] - - -def test_calculate_global_commodities(populate_commodities): - from muse.new_input.readers import calculate_global_commodities - - data = calculate_global_commodities(populate_commodities) - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"commodity"} - for dt in data.dtypes.values(): - assert np.issubdtype(dt, np.dtype("str")) - - assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) - assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) - assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) - - -def test_read_new_global_commodities_type_constraint(default_new_input, con): - from muse.new_input.readers import read_commodities_csv - - csv = StringIO("id,type,unit\nfoo,invalid,bar\n") - with raises(duckdb.ConstraintException): - read_commodities_csv(csv, con) - - -def test_new_read_demand_csv(populate_demand): - data = populate_demand - assert np.all(data["year"] == np.array([2020, 2050])) - assert np.all(data["commodity"] == np.array(["heat", "heat"])) - assert np.all(data["region"] == np.array(["R1", "R1"])) - assert np.all(data["demand"] == np.array([10, 30])) - - -def test_new_read_demand_csv_commodity_constraint( - default_new_input, con, populate_commodities, populate_regions -): - from muse.new_input.readers import read_demand_csv - - csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") - with raises(duckdb.ConstraintException, match=".*foreign key.*"): - read_demand_csv(csv, con) - - -def test_new_read_demand_csv_region_constraint( - default_new_input, con, populate_commodities, populate_regions -): - from muse.new_input.readers import read_demand_csv - - csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") - with raises(duckdb.ConstraintException, match=".*foreign key.*"): - read_demand_csv(csv, con) - - -@mark.xfail -def test_demand_dataset(default_new_input): - import duckdb - from muse.new_input.readers import read_commodities, read_demand, read_regions - - con = duckdb.connect(":memory:") - - read_regions(default_new_input, con) - read_commodities(default_new_input, con) - data = read_demand(default_new_input, con) - - assert isinstance(data, xr.DataArray) - assert data.dtype == np.float64 - - assert set(data.dims) == {"year", "commodity", "region", 
"timeslice"} - assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["timeslice"].values) == list(range(1, 7)) - assert list(data.coords["year"].values) == [2020, 2050] - assert set(data.coords["commodity"].values) == { - "electricity", - "gas", - "heat", - "wind", - "CO2f", - } - - assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 - - -@mark.xfail -def test_new_read_initial_market(default_new_input): - from muse.new_input.readers import read_inputs - - all_data = read_inputs(default_new_input) - data = all_data["initial_market"] - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"region", "year", "commodity", "timeslice"} - assert dict(data.dtypes) == dict( - prices=np.float64, - exports=np.float64, - imports=np.float64, - static_trade=np.float64, - ) - assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["year"].values) == list(range(2010, 2105, 5)) - assert list(data.coords["commodity"].values) == [ - "electricity", - "gas", - "heat", - "CO2f", - "wind", - ] - month_values = ["all-year"] * 6 - day_values = ["all-week"] * 6 - hour_values = [ - "night", - "morning", - "afternoon", - "early-peak", - "late-peak", - "evening", - ] - - assert list(data.coords["timeslice"].values) == list( - zip(month_values, day_values, hour_values) - ) - assert list(data.coords["month"]) == month_values - assert list(data.coords["day"]) == day_values - assert list(data.coords["hour"]) == hour_values - - assert all(var.coords.equals(data.coords) for var in data.data_vars.values()) - - prices = data.data_vars["prices"] - assert approx( - prices.sel( - year=2010, - region="R1", - commodity="electricity", - timeslice=("all-year", "all-week", "night"), - ) - - 14.81481, - abs=1e-4, - ) - - exports = data.data_vars["exports"] - assert ( - exports.sel( - year=2010, - region="R1", - commodity="electricity", - timeslice=("all-year", "all-week", "night"), - ) - ) == 0 - - imports = data.data_vars["imports"] - assert ( - imports.sel( - year=2010, - region="R1", - commodity="electricity", - timeslice=("all-year", "all-week", "night"), - ) - ) == 0 - - static_trade = data.data_vars["static_trade"] - assert ( - static_trade.sel( - year=2010, - region="R1", - commodity="electricity", - timeslice=("all-year", "all-week", "night"), - ) - ) == 0 From 81a837f083fc233facb81750d288da2895106883 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 8 Jul 2024 10:28:08 +0100 Subject: [PATCH 05/10] Tests for new tables --- src/muse/new_input/readers.py | 2 +- tests/test_new_readers.py | 80 ++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index b9228f5bf..4e2c09de0 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -69,7 +69,7 @@ def read_commodity_costs_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_costs SELECT - commidity_id, region_id, year, value FROM rel;""") + commodity_id, region_id, year, value FROM rel;""") return con.sql("SELECT * from commodity_costs").fetchnumpy() diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 9665df14d..010883287 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -27,6 +27,26 @@ def populate_commodities(default_new_input, con): return read_commodities_csv(f, con) +@fixture +def populate_commodity_trade( + default_new_input, con, 
populate_commodities, populate_regions +): + from muse.new_input.readers import read_commodity_trade_csv + + with open(default_new_input / "commodity_trade.csv") as f: + return read_commodity_trade_csv(f, con) + + +@fixture +def populate_commodity_costs( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_commodity_costs_csv + + with open(default_new_input / "commodity_costs.csv") as f: + return read_commodity_costs_csv(f, con) + + @fixture def populate_demand(default_new_input, con, populate_regions, populate_commodities): from muse.new_input.readers import read_demand_csv @@ -35,6 +55,16 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi return read_demand_csv(f, con) +@fixture +def populate_demand_slicing( + default_new_input, con, populate_regions, populate_commodities +): + from muse.new_input.readers import read_demand_slicing_csv + + with open(default_new_input / "demand_slicing.csv") as f: + return read_demand_slicing_csv(f, con) + + @fixture def populate_regions(default_new_input, con): from muse.new_input.readers import read_regions_csv @@ -43,17 +73,43 @@ def populate_regions(default_new_input, con): return read_regions_csv(f, con) -def test_read_regions(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_read_new_global_commodities(populate_commodities): +def test_read_commodities_csv(populate_commodities): data = populate_commodities assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] assert list(data["type"]) == ["energy"] * 5 assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] +def test_read_commodity_trade_csv(populate_commodity_trade): + data = populate_commodity_trade + assert data["commodity"].size == 0 + assert data["region"].size == 0 + assert data["year"].size == 0 + assert data["import"].size == 0 + assert data["export"].size == 0 + + +def test_read_commodity_costs_csv(populate_commodity_costs): + data = populate_commodity_costs + # Only checking the first element of each array, as the table is large + assert next(iter(data["commodity"])) == "electricity" + assert next(iter(data["region"])) == "R1" + assert next(iter(data["year"])) == 2010 + assert next(iter(data["value"])) == approx(14.81481) + + +def test_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_read_regions_csv(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + def test_calculate_global_commodities(populate_commodities): from muse.new_input.readers import calculate_global_commodities @@ -69,7 +125,7 @@ def test_calculate_global_commodities(populate_commodities): assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) -def test_read_new_global_commodities_type_constraint(default_new_input, con): +def test_read_global_commodities_type_constraint(default_new_input, con): from muse.new_input.readers import read_commodities_csv csv = StringIO("id,type,unit\nfoo,invalid,bar\n") @@ -77,15 +133,7 @@ def test_read_new_global_commodities_type_constraint(default_new_input, con): read_commodities_csv(csv, con) -def test_new_read_demand_csv(populate_demand): - data = populate_demand - assert np.all(data["year"] == np.array([2020, 2050])) - assert np.all(data["commodity"] == 
np.array(["heat", "heat"])) - assert np.all(data["region"] == np.array(["R1", "R1"])) - assert np.all(data["demand"] == np.array([10, 30])) - - -def test_new_read_demand_csv_commodity_constraint( +def test_read_demand_csv_commodity_constraint( default_new_input, con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv @@ -95,7 +143,7 @@ def test_new_read_demand_csv_commodity_constraint( read_demand_csv(csv, con) -def test_new_read_demand_csv_region_constraint( +def test_read_demand_csv_region_constraint( default_new_input, con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv From c28e8aa42879f9e9b62ba75e24c67c0433f6fa05 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 15 Aug 2024 10:39:58 +0100 Subject: [PATCH 06/10] Add functions for demand data and (in progress) initial market --- .../default_new_input/demand_slicing.csv | 20 +- .../example/default_new_input/time_slices.csv | 7 - .../example/default_new_input/timeslices.csv | 7 + src/muse/new_input/readers.py | 264 ++++++++++++++++-- tests/test_new_readers.py | 128 ++++++--- 5 files changed, 350 insertions(+), 76 deletions(-) delete mode 100644 src/muse/data/example/default_new_input/time_slices.csv create mode 100644 src/muse/data/example/default_new_input/timeslices.csv diff --git a/src/muse/data/example/default_new_input/demand_slicing.csv b/src/muse/data/example/default_new_input/demand_slicing.csv index 6877d5663..662b32869 100644 --- a/src/muse/data/example/default_new_input/demand_slicing.csv +++ b/src/muse/data/example/default_new_input/demand_slicing.csv @@ -1,7 +1,13 @@ -commodity_id,region_id,year,timeslice,fraction -heat,R1,,night,0.1 -heat,R1,,morning,0.15 -heat,R1,,afternoon,0.1 -heat,R1,,early-peak,0.15 -heat,R1,,late-peak,0.3 -heat,R1,,evening,0.2 +commodity_id,region_id,year,timeslice_id,fraction +heat,R1,2020,1,0.1 +heat,R1,2020,2,0.15 +heat,R1,2020,3,0.1 +heat,R1,2020,4,0.15 +heat,R1,2020,5,0.3 +heat,R1,2020,6,0.2 +heat,R1,2050,1,0.1 +heat,R1,2050,2,0.15 +heat,R1,2050,3,0.1 +heat,R1,2050,4,0.15 +heat,R1,2050,5,0.3 +heat,R1,2050,6,0.2 diff --git a/src/muse/data/example/default_new_input/time_slices.csv b/src/muse/data/example/default_new_input/time_slices.csv deleted file mode 100644 index 376022d96..000000000 --- a/src/muse/data/example/default_new_input/time_slices.csv +++ /dev/null @@ -1,7 +0,0 @@ -season,day,time_of_day,fraction -all,all,night,0.1667 -all,all,morning,0.1667 -all,all,afternoon,0.1667 -all,all,early-peak,0.1667 -all,all,late-peak,0.1667 -all,all,evening,0.1667 diff --git a/src/muse/data/example/default_new_input/timeslices.csv b/src/muse/data/example/default_new_input/timeslices.csv new file mode 100644 index 000000000..c83832c90 --- /dev/null +++ b/src/muse/data/example/default_new_input/timeslices.csv @@ -0,0 +1,7 @@ +id,season,day,time_of_day,fraction +1,all,all,night,0.1667 +2,all,all,morning,0.1667 +3,all,all,afternoon,0.1667 +4,all,all,early-peak,0.1667 +5,all,all,late-peak,0.1667 +6,all,all,evening,0.1667 diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index 4e2c09de0..b216ba0d4 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -1,5 +1,6 @@ import duckdb import numpy as np +import pandas as pd import xarray as xr @@ -7,28 +8,54 @@ def read_inputs(data_dir): data = {} con = duckdb.connect(":memory:") + with open(data_dir / "timeslices.csv") as f: + timeslices = read_timeslices_csv(f, con) + with open(data_dir / "commodities.csv") as f: commodities = 
read_commodities_csv(f, con) + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) + with open(data_dir / "commodity_trade.csv") as f: - commodity_trade = read_commodity_trade_csv(f, con) # noqa: F841 + commodity_trade = read_commodity_trade_csv(f, con) with open(data_dir / "commodity_costs.csv") as f: - commodity_costs = read_commodity_costs_csv(f, con) # noqa: F841 + commodity_costs = read_commodity_costs_csv(f, con) with open(data_dir / "demand.csv") as f: - demand = read_demand_csv(f, con) # noqa: F841 + demand = read_demand_csv(f, con) with open(data_dir / "demand_slicing.csv") as f: - demand_slicing = read_demand_slicing_csv(f, con) # noqa: F841 - - with open(data_dir / "regions.csv") as f: - regions = read_regions_csv(f, con) # noqa: F841 + demand_slicing = read_demand_slicing_csv(f, con) data["global_commodities"] = calculate_global_commodities(commodities) + data["demand"] = calculate_demand( + commodities, regions, timeslices, demand, demand_slicing + ) + data["initial_market"] = calculate_initial_market( + commodities, regions, timeslices, commodity_trade, commodity_costs + ) return data +def read_timeslices_csv(buffer_, con): + sql = """CREATE TABLE timeslices ( + id VARCHAR PRIMARY KEY, + season VARCHAR, + day VARCHAR, + time_of_day VARCHAR, + fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql( + "INSERT INTO timeslices SELECT id, season, day, time_of_day, fraction FROM rel;" + ) + return con.sql("SELECT * from timeslices").fetchnumpy() + + def read_commodities_csv(buffer_, con): sql = """CREATE TABLE commodities ( id VARCHAR PRIMARY KEY, @@ -42,6 +69,17 @@ def read_commodities_csv(buffer_, con): return con.sql("select * from commodities").fetchnumpy() +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + id VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT id FROM rel;") + return con.sql("SELECT * from regions").fetchnumpy() + + def read_commodity_trade_csv(buffer_, con): sql = """CREATE TABLE commodity_trade ( commodity VARCHAR REFERENCES commodities(id), @@ -49,6 +87,7 @@ def read_commodity_trade_csv(buffer_, con): year BIGINT, import DOUBLE, export DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -64,6 +103,7 @@ def read_commodity_costs_csv(buffer_, con): region VARCHAR REFERENCES regions(id), year BIGINT, value DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -79,6 +119,7 @@ def read_demand_csv(buffer_, con): region VARCHAR REFERENCES regions(id), year BIGINT, demand DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -92,28 +133,19 @@ def read_demand_slicing_csv(buffer_, con): commodity VARCHAR REFERENCES commodities(id), region VARCHAR REFERENCES regions(id), year BIGINT, - timeslice VARCHAR, + timeslice VARCHAR REFERENCES timeslices(id), fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + PRIMARY KEY (commodity, region, year, timeslice), + FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year) ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO demand_slicing SELECT - commodity_id, region_id, year, timeslice, fraction FROM rel;""") + commodity_id, region_id, year, timeslice_id, fraction FROM rel;""") return con.sql("SELECT * from 
demand_slicing").fetchnumpy() -def read_regions_csv(buffer_, con): - sql = """CREATE TABLE regions ( - id VARCHAR PRIMARY KEY, - ); - """ - con.sql(sql) - rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT id FROM rel;") - return con.sql("SELECT * from regions").fetchnumpy() - - def calculate_global_commodities(commodities): names = commodities["id"].astype(np.dtype("str")) types = commodities["type"].astype(np.dtype("str")) @@ -129,3 +161,195 @@ def calculate_global_commodities(commodities): data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) return data + + +def calculate_demand( + commodities, regions, timeslices, demand, demand_slicing +) -> xr.DataArray: + """Calculate demand data for all commodities, regions, years, and timeslices. + + Result: A DataArray with a demand value for every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the demand table + - timeslice: all timeslices specified in the timeslices table + + Checks: + - If demand data is specified for one year, it must be specified for all years. + - If demand is nonzero, slicing data must be present. + - If slicing data is specified for a commodity/region/year, the sum of + the fractions must be 1, and all timeslices must be present. + + Fills: + - If demand data is not specified for a commodity/region combination, the demand is + 0 for all years and timeslices. + + Todo: + - Interpolation to allow for missing years in demand data. + - Ability to leave the year field blank in both tables to indicate all years + - Allow slicing data to be missing -> demand is spread equally across timeslices + - Allow more flexibility for timeslices (e.g. 
can specify "winter" to apply to all + winter timeslices, or "all" to apply to all timeslices) + """ + # Prepare dataframes + df_demand = pd.DataFrame(demand).set_index(["commodity", "region", "year"]) + df_slicing = pd.DataFrame(demand_slicing).set_index( + ["commodity", "region", "year", "timeslice"] + ) + + # DataArray dimensions + all_commodities = commodities["id"].astype(np.dtype("str")) + all_regions = regions["id"].astype(np.dtype("str")) + all_years = df_demand.index.get_level_values("year").unique() + all_timeslices = timeslices["id"].astype(np.dtype("str")) + + # CHECK: all years are specified for each commodity/region combination + check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years) + + # CHECK: if slicing data is present, all timeslices must be specified + check_all_values_specified( + df_slicing, ["commodity", "region", "year"], "timeslice", all_timeslices + ) + + # CHECK: timeslice fractions sum to 1 + check_timeslice_sum = df_slicing.groupby(["commodity", "region", "year"]).apply( + lambda x: np.isclose(x["fraction"].sum(), 1) + ) + if not check_timeslice_sum.all(): + raise DataValidationError + + # CHECK: if demand data >0, fraction data must be specified + check_fraction_data_present = ( + df_demand[df_demand["demand"] > 0] + .index.isin(df_slicing.droplevel("timeslice").index) + .all() + ) + if not check_fraction_data_present.all(): + raise DataValidationError + + # FILL: demand is zero if unspecified + df_demand = df_demand.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years], + names=["commodity", "region", "year"], + ), + fill_value=0, + ) + + # FILL: slice data is zero if unspecified + df_slicing = df_slicing.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years, all_timeslices], + names=["commodity", "region", "year", "timeslice"], + ), + fill_value=0, + ) + + # Create DataArray + da_demand = df_demand.to_xarray()["demand"] + da_slicing = df_slicing.to_xarray()["fraction"] + data = da_demand * da_slicing + return data + + +def calculate_initial_market( + commodities, regions, timeslices, commodity_trade, commodity_costs +) -> xr.Dataset: + """Calculate trade and price data for all commodities, regions and years. + + Result: A Dataset with variables: + - prices + - exports + - imports + - static_trade + For every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the commodity_costs table + - timeslice (multiindex): all timeslices specified in the timeslices table + + Checks: + - If trade data is specified for one year, it must be specified for all years. + - If price data is specified for one year, it must be specified for all years. 
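+      (Illustrative example: prices given for 2020 but missing for 2050 for some
+      commodity/region raise DataValidationError via check_all_values_specified.)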
+
+    Fills:
+    - If trade data is not specified for a commodity/region combination, imports and
+    exports are both zero
+    - If price data is not specified for a commodity/region combination, the price is
+    zero
+
+    """
+    from muse.timeslices import QuantityType, convert_timeslice
+
+    # Prepare dataframes
+    df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"])
+    df_costs = (
+        pd.DataFrame(commodity_costs)
+        .set_index(["commodity", "region", "year"])
+        .rename(columns={"value": "prices"})
+    )
+    df_timeslices = pd.DataFrame(timeslices).set_index(["season", "day", "time_of_day"])
+
+    # DataArray dimensions
+    all_commodities = commodities["id"].astype(np.dtype("str"))
+    all_regions = regions["id"].astype(np.dtype("str"))
+    all_years = df_costs.index.get_level_values("year").unique()
+
+    # CHECK: all years are specified for each commodity/region combination
+    check_all_values_specified(df_trade, ["commodity", "region"], "year", all_years)
+    check_all_values_specified(df_costs, ["commodity", "region"], "year", all_years)
+
+    # FILL: price is zero if unspecified
+    df_costs = df_costs.reindex(
+        pd.MultiIndex.from_product(
+            [all_commodities, all_regions, all_years],
+            names=["commodity", "region", "year"],
+        ),
+        fill_value=0,
+    )
+
+    # FILL: trade is zero if unspecified
+    df_trade = df_trade.reindex(
+        pd.MultiIndex.from_product(
+            [all_commodities, all_regions, all_years],
+            names=["commodity", "region", "year"],
+        ),
+        fill_value=0,
+    )
+
+    # Calculate static trade
+    df_trade["static_trade"] = df_trade["export"] - df_trade["import"]
+
+    # Create Dataset
+    df_full = df_costs.join(df_trade)
+    data = df_full.to_xarray()
+    ts = df_timeslices.to_xarray()["fraction"]
+    ts = ts.stack(timeslice=("season", "day", "time_of_day"))
+    convert_timeslice(data, ts, QuantityType.EXTENSIVE)
+
+    return data
+
+
+class DataValidationError(ValueError):
+    pass
+
+
+def check_all_values_specified(
+    df: pd.DataFrame, group_by_cols: list[str], column_name: str, values: list
+) -> None:
+    """Check that the required values are specified in a dataframe.
+
+    Checks that a row exists for all specified values of column_name for each
+    group in the grouped dataframe.
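+
+    Illustrative example: with group_by_cols=["commodity", "region"],
+    column_name="year" and values=[2020, 2050], a dataframe whose ("heat", "R1")
+    group contains a row for 2020 but none for 2050 fails the check.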
+ """ + if not ( + df.groupby(group_by_cols) + .apply( + lambda x: ( + set(x.index.get_level_values(column_name).unique()) == set(values) + ) + ) + .all() + ).all(): + msg = "" # TODO + raise DataValidationError(msg) diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 010883287..483d42627 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -57,7 +57,7 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi @fixture def populate_demand_slicing( - default_new_input, con, populate_regions, populate_commodities + default_new_input, con, populate_regions, populate_commodities, populate_demand ): from muse.new_input.readers import read_demand_slicing_csv @@ -73,6 +73,28 @@ def populate_regions(default_new_input, con): return read_regions_csv(f, con) +@fixture +def populate_timeslices(default_new_input, con): + from muse.new_input.readers import read_timeslices_csv + + with open(default_new_input / "timeslices.csv") as f: + return read_timeslices_csv(f, con) + + +def test_read_timeslices_csv(populate_timeslices): + data = populate_timeslices + assert len(data["id"]) == 6 + assert next(iter(data["id"])) == "1" + assert next(iter(data["season"])) == "all" + assert next(iter(data["day"])) == "all" + assert next(iter(data["time_of_day"])) == "night" + assert next(iter(data["fraction"])) == approx(0.1667) + + +def test_read_regions_csv(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + def test_read_commodities_csv(populate_commodities): data = populate_commodities assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] @@ -106,26 +128,18 @@ def test_read_demand_csv(populate_demand): assert np.all(data["demand"] == np.array([10, 30])) -def test_read_regions_csv(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_calculate_global_commodities(populate_commodities): - from muse.new_input.readers import calculate_global_commodities - - data = calculate_global_commodities(populate_commodities) - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"commodity"} - for dt in data.dtypes.values(): - assert np.issubdtype(dt, np.dtype("str")) - - assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) - assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) - assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) +def test_read_demand_slicing_csv(populate_demand_slicing): + data = populate_demand_slicing + assert np.all(data["commodity"] == "heat") + assert np.all(data["region"] == "R1") + # assert np.all(data["timeslice"] == np.array([0, 1])) + assert np.all( + data["fraction"] + == np.array([0.1, 0.15, 0.1, 0.15, 0.3, 0.2, 0.1, 0.15, 0.1, 0.15, 0.3, 0.2]) + ) -def test_read_global_commodities_type_constraint(default_new_input, con): +def test_read_commodities_csv_type_constraint(con): from muse.new_input.readers import read_commodities_csv csv = StringIO("id,type,unit\nfoo,invalid,bar\n") @@ -134,7 +148,7 @@ def test_read_global_commodities_type_constraint(default_new_input, con): def test_read_demand_csv_commodity_constraint( - default_new_input, con, populate_commodities, populate_regions + con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv @@ -143,9 +157,7 @@ def test_read_demand_csv_commodity_constraint( read_demand_csv(csv, con) -def test_read_demand_csv_region_constraint( - default_new_input, con, populate_commodities, 
populate_regions -): +def test_read_demand_csv_region_constraint(con, populate_commodities, populate_regions): from muse.new_input.readers import read_demand_csv csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") @@ -153,23 +165,44 @@ def test_read_demand_csv_region_constraint( read_demand_csv(csv, con) -@mark.xfail -def test_demand_dataset(default_new_input): - import duckdb - from muse.new_input.readers import read_commodities, read_demand, read_regions +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities - con = duckdb.connect(":memory:") + data = calculate_global_commodities(populate_commodities) - read_regions(default_new_input, con) - read_commodities(default_new_input, con) - data = read_demand(default_new_input, con) + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, +): + from muse.new_input.readers import calculate_demand + + data = calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, + ) assert isinstance(data, xr.DataArray) assert data.dtype == np.float64 assert set(data.dims) == {"year", "commodity", "region", "timeslice"} assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["timeslice"].values) == list(range(1, 7)) + assert list(data.coords["timeslice"].values) == ["1", "2", "3", "4", "5", "6"] assert list(data.coords["year"].values) == [2020, 2050] assert set(data.coords["commodity"].values) == { "electricity", @@ -179,15 +212,26 @@ def test_demand_dataset(default_new_input): "CO2f", } - assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 + assert data.sel(year=2020, commodity="heat", region="R1", timeslice="1") == 1 @mark.xfail -def test_new_read_initial_market(default_new_input): - from muse.new_input.readers import read_inputs - - all_data = read_inputs(default_new_input) - data = all_data["initial_market"] +def test_calculate_initial_market( + populate_commodities, + populate_regions, + populate_timeslices, + populate_commodity_trade, + populate_commodity_costs, +): + from muse.new_input.readers import calculate_initial_market + + data = calculate_initial_market( + populate_commodities, + populate_regions, + populate_timeslices, + populate_commodity_trade, + populate_commodity_costs, + ) assert isinstance(data, xr.Dataset) assert set(data.dims) == {"region", "year", "commodity", "timeslice"} @@ -197,15 +241,15 @@ def test_new_read_initial_market(default_new_input): imports=np.float64, static_trade=np.float64, ) - assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["year"].values) == list(range(2010, 2105, 5)) - assert list(data.coords["commodity"].values) == [ + assert set(data.coords["region"].values) == {"R1"} + assert set(data.coords["year"].values) == set(range(2010, 2105, 5)) + assert set(data.coords["commodity"].values) == { "electricity", "gas", "heat", "CO2f", "wind", - ] + } month_values = ["all-year"] * 6 
day_values = ["all-week"] * 6 hour_values = [ From 2c3f77b0d83d198d070f536d2dc6976ad09393e3 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 19 Aug 2024 12:25:32 +0100 Subject: [PATCH 07/10] Convert timeslice id to int, fix failing test --- src/muse/new_input/readers.py | 6 +++--- tests/test_new_readers.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index b216ba0d4..67c5dd9aa 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -41,7 +41,7 @@ def read_inputs(data_dir): def read_timeslices_csv(buffer_, con): sql = """CREATE TABLE timeslices ( - id VARCHAR PRIMARY KEY, + id BIGINT PRIMARY KEY, season VARCHAR, day VARCHAR, time_of_day VARCHAR, @@ -133,7 +133,7 @@ def read_demand_slicing_csv(buffer_, con): commodity VARCHAR REFERENCES commodities(id), region VARCHAR REFERENCES regions(id), year BIGINT, - timeslice VARCHAR REFERENCES timeslices(id), + timeslice BIGINT REFERENCES timeslices(id), fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), PRIMARY KEY (commodity, region, year, timeslice), FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year) @@ -201,7 +201,7 @@ def calculate_demand( all_commodities = commodities["id"].astype(np.dtype("str")) all_regions = regions["id"].astype(np.dtype("str")) all_years = df_demand.index.get_level_values("year").unique() - all_timeslices = timeslices["id"].astype(np.dtype("str")) + all_timeslices = timeslices["id"].astype(np.dtype("int")) # CHECK: all years are specified for each commodity/region combination check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years) diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 483d42627..68f715d55 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -57,7 +57,12 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi @fixture def populate_demand_slicing( - default_new_input, con, populate_regions, populate_commodities, populate_demand + default_new_input, + con, + populate_regions, + populate_commodities, + populate_demand, + populate_timeslices, ): from muse.new_input.readers import read_demand_slicing_csv @@ -84,7 +89,7 @@ def populate_timeslices(default_new_input, con): def test_read_timeslices_csv(populate_timeslices): data = populate_timeslices assert len(data["id"]) == 6 - assert next(iter(data["id"])) == "1" + assert next(iter(data["id"])) == 1 assert next(iter(data["season"])) == "all" assert next(iter(data["day"])) == "all" assert next(iter(data["time_of_day"])) == "night" @@ -202,7 +207,7 @@ def test_calculate_demand( assert set(data.dims) == {"year", "commodity", "region", "timeslice"} assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["timeslice"].values) == ["1", "2", "3", "4", "5", "6"] + assert set(data.coords["timeslice"].values) == set(range(1, 7)) assert list(data.coords["year"].values) == [2020, 2050] assert set(data.coords["commodity"].values) == { "electricity", @@ -212,7 +217,7 @@ def test_calculate_demand( "CO2f", } - assert data.sel(year=2020, commodity="heat", region="R1", timeslice="1") == 1 + assert data.sel(year=2020, commodity="heat", region="R1", timeslice=1) == 1 @mark.xfail From 3e57b257fef5a886bc699584c26e4b1166db3443 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 19 Aug 2024 14:13:27 +0100 Subject: [PATCH 08/10] Finish initial market reader --- .../example/default_new_input/timeslices.csv | 14 ++--- 
src/muse/new_input/readers.py | 60 ++++++++++++++----- src/muse/timeslices.py | 2 +- tests/test_new_readers.py | 45 +++++++------- 4 files changed, 75 insertions(+), 46 deletions(-) diff --git a/src/muse/data/example/default_new_input/timeslices.csv b/src/muse/data/example/default_new_input/timeslices.csv index c83832c90..8d7e721cf 100644 --- a/src/muse/data/example/default_new_input/timeslices.csv +++ b/src/muse/data/example/default_new_input/timeslices.csv @@ -1,7 +1,7 @@ -id,season,day,time_of_day,fraction -1,all,all,night,0.1667 -2,all,all,morning,0.1667 -3,all,all,afternoon,0.1667 -4,all,all,early-peak,0.1667 -5,all,all,late-peak,0.1667 -6,all,all,evening,0.1667 +id,month,day,hour,fraction +1,all-year,all-week,night,0.1667 +2,all-year,all-week,morning,0.1667 +3,all-year,all-week,afternoon,0.1667 +4,all-year,all-week,early-peak,0.1667 +5,all-year,all-week,late-peak,0.1667 +6,all-year,all-week,evening,0.1667 diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index 67c5dd9aa..54688f3ec 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd import xarray as xr +from muse.timeslices import QuantityType def read_inputs(data_dir): @@ -42,17 +43,15 @@ def read_inputs(data_dir): def read_timeslices_csv(buffer_, con): sql = """CREATE TABLE timeslices ( id BIGINT PRIMARY KEY, - season VARCHAR, + month VARCHAR, day VARCHAR, - time_of_day VARCHAR, + hour VARCHAR, fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql( - "INSERT INTO timeslices SELECT id, season, day, time_of_day, fraction FROM rel;" - ) + con.sql("INSERT INTO timeslices SELECT id, month, day, hour, fraction FROM rel;") return con.sql("SELECT * from timeslices").fetchnumpy() @@ -278,9 +277,11 @@ def calculate_initial_market( - If price data is not specified for a commodity/region combination, the price is zero - """ - from muse.timeslices import QuantityType, convert_timeslice + Todo: + - Allow data to be specified on a timeslice level (optional) + - Interpolation, missing year field, flexible timeslice specification as above + """ # Prepare dataframes df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"]) df_costs = ( @@ -288,7 +289,7 @@ def calculate_initial_market( .set_index(["commodity", "region", "year"]) .rename(columns={"value": "prices"}) ) - df_timeslices = pd.DataFrame(timeslices).set_index(["season", "day", "time_of_day"]) + df_timeslices = pd.DataFrame(timeslices).set_index(["month", "day", "hour"]) # DataArray dimensions all_commodities = commodities["id"].astype(np.dtype("str")) @@ -320,13 +321,17 @@ def calculate_initial_market( # Calculate static trade df_trade["static_trade"] = df_trade["export"] - df_trade["import"] - # Create Data - df_full = df_costs.join(df_trade) - data = df_full.to_xarray() - ts = df_timeslices.to_xarray()["fraction"] - ts = ts.stack(timeslice=("season", "day", "time_of_day")) - convert_timeslice(data, ts, QuantityType.EXTENSIVE) + # Create xarray datasets + xr_costs = df_costs.to_xarray() + xr_trade = df_trade.to_xarray() + # Project over timeslices + ts = df_timeslices.to_xarray()["fraction"].stack(timeslice=("month", "day", "hour")) + xr_costs = project_timeslice(xr_costs, ts, QuantityType.EXTENSIVE) + xr_trade = project_timeslice(xr_trade, ts, QuantityType.INTENSIVE) + + # Combine data + data = xr.merge([xr_costs, xr_trade]) return data @@ -353,3 +358,30 @@ 
def check_all_values_specified( ).all(): msg = "" # TODO raise DataValidationError(msg) + + +def project_timeslice( + data: xr.Dataset, timeslices: xr.DataArray, quantity_type: QuantityType +) -> xr.Dataset: + """Project a dataset over a new timeslice dimension. + + The projection can be done in one of two ways, depending on whether the + quantity type is extensive or intensive. See `QuantityType`. + + Args: + data: Dataset to project + timeslices: DataArray of timeslice levels, with values between 0 and 1 + representing the timeslice length (fraction of the year) + quantity_type: Type of projection to perform. QuantityType.EXTENSIVE or + QuantityType.INTENSIVE + + Returns: + Projected dataset + """ + assert "timeslice" in timeslices.dims + assert "timeslice" not in data.dims + + if quantity_type is QuantityType.INTENSIVE: + return data * timeslices + if quantity_type is QuantityType.EXTENSIVE: + return data * xr.ones_like(timeslices) diff --git a/src/muse/timeslices.py b/src/muse/timeslices.py index 001152bfe..5e6618c76 100644 --- a/src/muse/timeslices.py +++ b/src/muse/timeslices.py @@ -405,7 +405,7 @@ def convert_timeslice( ) -> Union[DataArray, Dataset]: '''Adjusts the timeslice of x to match that of ts. - The conversion can be done in on of two ways, depending on whether the + The conversion can be done in one of two ways, depending on whether the quantity is extensive or intensive. See `QuantityType`. Example: diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 68f715d55..07a3889e3 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -3,7 +3,7 @@ import duckdb import numpy as np import xarray as xr -from pytest import approx, fixture, mark, raises +from pytest import approx, fixture, raises @fixture @@ -206,9 +206,9 @@ def test_calculate_demand( assert data.dtype == np.float64 assert set(data.dims) == {"year", "commodity", "region", "timeslice"} - assert list(data.coords["region"].values) == ["R1"] + assert set(data.coords["region"].values) == {"R1"} assert set(data.coords["timeslice"].values) == set(range(1, 7)) - assert list(data.coords["year"].values) == [2020, 2050] + assert set(data.coords["year"].values) == {2020, 2050} assert set(data.coords["commodity"].values) == { "electricity", "gas", @@ -220,7 +220,6 @@ def test_calculate_demand( assert data.sel(year=2020, commodity="heat", region="R1", timeslice=1) == 1 -@mark.xfail def test_calculate_initial_market( populate_commodities, populate_regions, @@ -240,12 +239,8 @@ def test_calculate_initial_market( assert isinstance(data, xr.Dataset) assert set(data.dims) == {"region", "year", "commodity", "timeslice"} - assert dict(data.dtypes) == dict( - prices=np.float64, - exports=np.float64, - imports=np.float64, - static_trade=np.float64, - ) + for dt in data.dtypes.values(): + assert dt == np.dtype("float64") assert set(data.coords["region"].values) == {"R1"} assert set(data.coords["year"].values) == set(range(2010, 2105, 5)) assert set(data.coords["commodity"].values) == { @@ -266,28 +261,30 @@ def test_calculate_initial_market( "evening", ] - assert list(data.coords["timeslice"].values) == list( + assert set(data.coords["timeslice"].values) == set( zip(month_values, day_values, hour_values) ) - assert list(data.coords["month"]) == month_values - assert list(data.coords["day"]) == day_values - assert list(data.coords["hour"]) == hour_values + assert set(data.coords["month"].values) == set(month_values) + assert set(data.coords["day"].values) == set(day_values) + assert 
set(data.coords["hour"].values) == set(hour_values) assert all(var.coords.equals(data.coords) for var in data.data_vars.values()) prices = data.data_vars["prices"] - assert approx( - prices.sel( - year=2010, - region="R1", - commodity="electricity", - timeslice=("all-year", "all-week", "night"), + assert ( + approx( + prices.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ), + abs=1e-4, ) - - 14.81481, - abs=1e-4, + == 14.81481 ) - exports = data.data_vars["exports"] + exports = data.data_vars["export"] assert ( exports.sel( year=2010, @@ -297,7 +294,7 @@ def test_calculate_initial_market( ) ) == 0 - imports = data.data_vars["imports"] + imports = data.data_vars["import"] assert ( imports.sel( year=2010, From af2ef3713257d498cd088fa59e70911802152ba8 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 19 Aug 2024 14:24:04 +0100 Subject: [PATCH 09/10] Fix test --- tests/test_new_readers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 07a3889e3..e7d7e31d9 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -90,9 +90,9 @@ def test_read_timeslices_csv(populate_timeslices): data = populate_timeslices assert len(data["id"]) == 6 assert next(iter(data["id"])) == 1 - assert next(iter(data["season"])) == "all" - assert next(iter(data["day"])) == "all" - assert next(iter(data["time_of_day"])) == "night" + assert next(iter(data["month"])) == "all-year" + assert next(iter(data["day"])) == "all-week" + assert next(iter(data["hour"])) == "night" assert next(iter(data["fraction"])) == approx(0.1667) From 190a49a9012874bd81987b596d56737929fbc775 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:05:37 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/muse/new_input/readers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index 54688f3ec..c8833e902 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd import xarray as xr + from muse.timeslices import QuantityType