Skip to content

Commit

Permalink
update unittest and exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Marika-K committed Mar 18, 2024
1 parent 3ca5be0 commit 3ba08e7
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 30 deletions.
11 changes: 10 additions & 1 deletion src/gurobi_optimods/line_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,16 @@ def all_shortest_paths(
# define a variable for each such shortest path
f = {} # stuv from s to t using edge uv
for s, t in demands:
paths = list(nx.all_shortest_paths(G, source=s, target=t, weight="time"))
if s not in G:
raise ValueError(f"demand node {s} not found in edges")
if t not in G:
raise ValueError(f"demand node {t} not found in edges")
try:
paths = list(
nx.all_shortest_paths(G, source=s, target=t, weight="time")
)
except nx.NetworkXNoPath:
raise ValueError(f"no path found for connection from {s} to {t}")
demand_expr = 0
for p in paths:
y = model.addVar(
Expand Down
125 changes: 96 additions & 29 deletions tests/test_line_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

DATA_FILE_DIR = pathlib.Path(__file__).parent / "data"

edge_data2A = """
edges_wrong = """
sourc,target,length,time
0,1,15,40
0,2,8,40
"""

edge_data2 = """
edges = """
source,target,time
0,1,10
0,2,10
Expand All @@ -33,23 +33,35 @@
4,5,30
"""

node_data2 = """
nodes = """
number
0
1
2
3
4
5
"""

lines2 = """
nodes_unknown_node = """
number
0
1
2
3
4
5
6
"""

lines = """
linename,capacity,fixCost,operatingCost
L1,20,9,4
L2,20,9,2
L3,20,9,3
"""

line_path2 = """
linepath = """
linename,edgeSource,edgeTarget
L1,0,1
L1,1,3
Expand All @@ -62,20 +74,33 @@
L3,4,5
"""

demand2 = """
demand = """
source,target,demand
0,3,20
0,5,30
3,4,30
"""

demand2A = """
demand_negative = """
source,target,demand
0,4,-30
0,3,10
2,1,10
"""

demand_unknown_node = """
source,target,demand
0,3,20
0,5,30
3,4,30
2,6,30
"""

demand_no_connection = """
source,target,demand
2,1,20
"""

# python -m unittest tests.test_line_optimization.Testlop


Expand Down Expand Up @@ -103,11 +128,11 @@ def test_sol(self):
self.assertEqual(len(final_lines), 8)

def test_wrong_data_format(self):
edge_data = pd.read_csv(io.StringIO(edge_data2A))
node_data = pd.read_csv(io.StringIO(node_data2))
linepath_data = pd.read_csv(io.StringIO(line_path2))
line_data = pd.read_csv(io.StringIO(lines2))
demand_data = pd.read_csv(io.StringIO(demand2))
edge_data = pd.read_csv(io.StringIO(edges_wrong))
node_data = pd.read_csv(io.StringIO(nodes))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand))
frequencies = [3, 6]

captured_output = io.StringIO()
Expand All @@ -126,13 +151,13 @@ def test_wrong_data_format(self):
"column source not present in edge_data\n", captured_output.getvalue()
)

def test_wrong_data_format2(self):
edge_data = pd.read_csv(io.StringIO(edge_data2))
node_data = pd.read_csv(io.StringIO(node_data2))
linepath_data = pd.read_csv(io.StringIO(line_path2))
line_data = pd.read_csv(io.StringIO(lines2))
demand_data = pd.read_csv(io.StringIO(demand2A))
frequencies = [3, 6, 9, 18]
def test_wrong_data_neg_demand(self):
edge_data = pd.read_csv(io.StringIO(edges))
node_data = pd.read_csv(io.StringIO(nodes))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand_negative))
frequencies = [3, 18]
captured_output = io.StringIO()
try:
with contextlib.redirect_stdout(captured_output):
Expand All @@ -150,12 +175,54 @@ def test_wrong_data_format2(self):
captured_output.getvalue(),
)

@unittest.skipIf(nx is None, "networkx is not installed")
def test_wrong_data_node_not_found(self):
edge_data = pd.read_csv(io.StringIO(edges))
node_data = pd.read_csv(io.StringIO(nodes_unknown_node))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand_unknown_node))
frequencies = [3, 6, 9, 18]
with self.assertRaises(ValueError) as error:
lop.line_optimization(
node_data,
edge_data,
line_data,
linepath_data,
demand_data,
frequencies,
True,
)
self.assertEqual(str(error.exception), "demand node 6 not found in edges")

@unittest.skipIf(nx is None, "networkx is not installed")
def test_wrong_data_no_path_found(self):
edge_data = pd.read_csv(io.StringIO(edges))
node_data = pd.read_csv(io.StringIO(nodes))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand_no_connection))
frequencies = [18]
with self.assertRaises(ValueError) as error:
lop.line_optimization(
node_data,
edge_data,
line_data,
linepath_data,
demand_data,
frequencies,
True,
)
self.assertEqual(
str(error.exception), "no path found for connection from 2 to 1"
)

def test_smallAllPaths(self):
edge_data = pd.read_csv(io.StringIO(edge_data2))
node_data = pd.read_csv(io.StringIO(node_data2))
linepath_data = pd.read_csv(io.StringIO(line_path2))
line_data = pd.read_csv(io.StringIO(lines2))
demand_data = pd.read_csv(io.StringIO(demand2))
edge_data = pd.read_csv(io.StringIO(edges))
node_data = pd.read_csv(io.StringIO(nodes))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand))
frequencies = [1, 2, 3]
obj_cost, final_lines = lop.line_optimization(
node_data,
Expand All @@ -171,11 +238,11 @@ def test_smallAllPaths(self):

@unittest.skipIf(nx is None, "networkx is not installed")
def test_shortestPath(self):
edge_data = pd.read_csv(io.StringIO(edge_data2))
node_data = pd.read_csv(io.StringIO(node_data2))
linepath_data = pd.read_csv(io.StringIO(line_path2))
line_data = pd.read_csv(io.StringIO(lines2))
demand_data = pd.read_csv(io.StringIO(demand2))
edge_data = pd.read_csv(io.StringIO(edges))
node_data = pd.read_csv(io.StringIO(nodes))
linepath_data = pd.read_csv(io.StringIO(linepath))
line_data = pd.read_csv(io.StringIO(lines))
demand_data = pd.read_csv(io.StringIO(demand))
frequencies = [1, 2, 3]
obj_cost, final_lines = lop.line_optimization(
node_data,
Expand Down

0 comments on commit 3ba08e7

Please sign in to comment.