Skip to content

Commit

Permalink
Major refactoring - migrate source vetex from Graph's constructor to …
Browse files Browse the repository at this point in the history
…the iterate method as a param.
  • Loading branch information
hansonchar committed Jan 26, 2025
1 parent 199de83 commit de9196e
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 83 deletions.
4 changes: 2 additions & 2 deletions learning-lua/algo/BFS-tests.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ end

local function bfs_test(input, src, expected_visits, expected_max_level)
local G = load_input(input)
local bfs, count, prev_level = BFS:new(G, src), 1, 0
for from, to, weight, level in bfs:iterate() do
local bfs, count, prev_level = BFS:new(G), 1, 0
for from, to, weight, level in bfs:iterate(src) do
count = count + 1
assert(prev_level <= level)
prev_level = level
Expand Down
11 changes: 5 additions & 6 deletions learning-lua/algo/BFS.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ local E = {}

local TO<const>, LEVEL<const> = 2, 4

local function _iterate(self)
local function _iterate(self, src)
local q, visited = Queue:new(), {
[self.src_vertex] = true
[src] = true
}
self._visited_count = 1
local level = 0
local from = self.src_vertex
local from = src
repeat
local vertex = self.graph:vertex(from)
level = level + 1
Expand All @@ -31,9 +31,8 @@ local function _iterate(self)
end

---@param G (table) graph
---@param src (any) source vertex
function BFS:new(G, src, func_iterate)
return getmetatable(self):new(G, src, _iterate)
function BFS:new(G, func_iterate)
return getmetatable(self):new(G, _iterate)
end

return BFS
12 changes: 6 additions & 6 deletions learning-lua/algo/DFS-tests.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ end

local function dfs_test(input, src, expected_visits, expected_max_level)
local G = load_edges(input)
local dfs, count, max_level = DFS:new(G, src), 1, 0
for from, to, weight, level in dfs:iterate() do
local dfs, count, max_level = DFS:new(G), 1, 0
for from, to, weight, level in dfs:iterate(src) do
count = count + 1
max_level = level > max_level and level or max_level
-- print(string.format("%d: %s-%s=%d", level, from, to, weight))
Expand All @@ -31,9 +31,9 @@ local function topo_sort(input, src)
local terminal -- the terminal node
local level_visits = {}
local G = load_edges(input)
local dfs = DFS:new(G, src)
local dfs = DFS:new(G)
-- outgoings is the number of unvisited outgoing edges
for from, to, _, depth, outgoings in dfs:iterate() do
for from, to, _, depth, outgoings in dfs:iterate(src) do
-- print(string.format("%d: %s-%s", level, from, to))
if not terminal and outgoings == 0 then
terminal = to
Expand Down Expand Up @@ -107,9 +107,9 @@ local function single_vertex_test()
print("Testing single vertex ...")
local G = Graph:new()
G:add('a')
local dfs = DFS:new(G, 'a')
local dfs = DFS:new(G)
local count = 0
for from, to, weight, depth in dfs:iterate() do
for from, to, weight, depth in dfs:iterate('a') do
count = count + 1
assert(from == 'a')
assert(not to)
Expand Down
15 changes: 8 additions & 7 deletions learning-lua/algo/DFS.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ local Stack = require "algo.Stack"
local DFS = GraphSearch:class()
local E = {}

-- Element index of a stack entry with the format: {from, to, weight, depth, count_unvisited, begin_vertex}
local FROM<const>, TO<const>, WEIGHT<const>, DEPTH<const>, UNVISITED<const>, BEGIN_VERTEX<const> = 1, 2, 3, 4, 5, 6
-- Element index of a stack entry with the format: {from, to, weight, depth, count_unvisited, begin_vertex, is_visited}
local FROM<const>, TO<const>, WEIGHT<const>, DEPTH<const>, UNVISITED<const>, BEGIN_VERTEX<const>, IS_VISITED<const> = 1, 2, 3, 4, 5, 6, 7

local function debug(...)
-- print(...)
Expand Down Expand Up @@ -60,15 +60,17 @@ local function vertices_of(G)
return vertices
end

local function _iterate(self)
---@param src (any) optional source vertex; this takes precedence.
---@param is_include_visited (boolean) true if visited nodes are returned in addition to unvisited node.
local function _iterate(self, src, is_include_visited)
local stack, visited = Stack:new(), {}
-- Applicable only if a single source is not specified for this DFS.
-- If a single source is specified, the DFS will only be performed from that source vertex.
-- Otherwise, DFS is performed from potentially many source vertices until all vertices have been explored.
local unvisited_vertices -- Contains vertices that have not been visited; visited ones are erased.
local src_spec_idx = 0
self._visited_count = 0
local node = self.src_vertex -- DFS from a single source vertex
local node = src -- DFS from a single source vertex
if not node then -- DFS from potentially many source vertices
unvisited_vertices = vertices_of(self.graph)
local next_unvisited = next(unvisited_vertices) -- we are done if all vertices have been visited.
Expand Down Expand Up @@ -132,11 +134,10 @@ local function _iterate(self)
end

---@param G (table) graph
---@param src (any) source vertex
---@param nav_spec (table) opional navigation spec in the format of {from_1={to_1, to_2, ...}, ...} e.g. {['3']={'5','11'}, ['5']={'7','9'}}
---@param src_spec (table) optional source vertex spec in the format of {v1, v2, ...} e.g. {'1', '2', '3', ...}; applicable only if 'src' is not specified
function DFS:new(G, src, nav_spec, src_spec)
return getmetatable(self):new(G, src, _iterate, nav_spec, src_spec)
function DFS:new(G, nav_spec, src_spec)
return getmetatable(self):new(G, _iterate, nav_spec, src_spec)
end

return DFS
14 changes: 6 additions & 8 deletions learning-lua/algo/GraphSearch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ local function _yield(entry)
end

---@param src (any) optional source vertex; this takes precedence.
function GraphSearch:iterate(src)
---@param is_include_visited (boolean) true if visited nodes are returned in addition to unvisited node. (Currently only DFS supports this parameter.)
function GraphSearch:iterate(src, is_include_visited)
assert(not src or self.graph:vertex(src), "Source vertex not found in graph")
self._visited_count = 0
-- self._nav = build_navigation(self)
return coroutine.wrap(function()
self:_iterate(src)
self:_iterate(src, is_include_visited)
end)
end

Expand All @@ -34,18 +36,14 @@ function GraphSearch:class(o)
end

---@param G (table) graph
---@param src (any) source vertex (optional)
---@param func_iterate (function) function for iteration
---@param nav_spec (table) optional navigation spec in the format of {from_1={to_1, to_2, ...}, ...} e.g. {['3']={'5','11'}, ['5']={'7','9'}}
---@param src_spec (table) optional source vertex spec in the format of {v1, v2, ...} e.g. {'1', '2', '3', ...}; applicable only if 'src' is not specified
function GraphSearch:new(G, src, func_iterate, nav_spec, src_spec)
function GraphSearch:new(G, func_iterate, nav_spec, src_spec)
assert(G, "Missing Graph")
assert(Graph.isGraph(G), "G must be a graph object")
-- assert(src, "Missing source vertex")
assert(not src or G:vertex(src), "Source vertex not found in graph")
local o = GraphSearch:class{
graph = G,
src_vertex = src
graph = G
}
o._iterate = func_iterate
o._nav_spec = nav_spec or E
Expand Down
2 changes: 1 addition & 1 deletion learning-lua/algo/SCCSearch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

---@param G (table) graph
function SCCSearch:new(G)
return getmetatable(self):new(G, nil, _iterate)
return getmetatable(self):new(G, _iterate)
end

return SCCSearch
4 changes: 2 additions & 2 deletions learning-lua/algo/SccDfsSearch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ local function _iterate(self)
local scc_id, count = 0, 0
local src_spec = reversed(src_spec)
debug(table.concat(src_spec, ","))
local dfs = DFS:new(G, nil, nil, src_spec)
local dfs = DFS:new(G, nil, src_spec)
local scc_src_vertex
for from, to, _, _, _, src_vertex in dfs:iterate() do
debug(string.format("from=%s to=%s, src_vertex=%s", from, to, src_vertex))
Expand All @@ -60,7 +60,7 @@ end

---@param G (table) graph
function SccDfsSearch:new(G)
return getmetatable(SccDfsSearch):new(G, nil, _iterate)
return getmetatable(SccDfsSearch):new(G, _iterate)
end

return SccDfsSearch
28 changes: 14 additions & 14 deletions learning-lua/algo/ShortestPathSearch-tests.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ local function basic_tests()
G:add('v', 't', 6)
G:add('w', 't', 3)

local search = ShortestPathSearch:new(G, 's')
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate('s') do
debug(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
end

Expand Down Expand Up @@ -43,8 +43,8 @@ local function tim_test()
local src<const> = 's'
local G = load_input(input)
local level_counts = {}
local search = ShortestPathSearch:new(G, src)
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate(src) do
level_counts[level] = level_counts[level] or 0
level_counts[level] = level_counts[level] + 1
-- print(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
Expand Down Expand Up @@ -73,8 +73,8 @@ local function geek_test()
local src<const> = '0'
local G = load_input(input)
local level_counts = {}
local search = ShortestPathSearch:new(G, src)
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate(src) do
level_counts[level] = level_counts[level] or 0
level_counts[level] = level_counts[level] + 1
-- print(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
Expand Down Expand Up @@ -122,8 +122,8 @@ local function redblobgames_test()

local G = load_input(input)
local level_counts = {}
local search = ShortestPathSearch:new(G, src)
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate(src) do
level_counts[level] = level_counts[level] or 0
level_counts[level] = level_counts[level] + 1
-- print(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
Expand Down Expand Up @@ -159,8 +159,8 @@ local function algodaily_test()

local G = load_input(input)
local level_counts = {}
local search = ShortestPathSearch:new(G, src)
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate(src) do
level_counts[level] = level_counts[level] or 0
level_counts[level] = level_counts[level] + 1
-- print(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
Expand Down Expand Up @@ -199,8 +199,8 @@ local function scott_moura_test()

local G = load_input(input)
local level_counts = {}
local search = ShortestPathSearch:new(G, src)
for from, to, weight, level, min_cost in search:iterate() do
local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate(src) do
level_counts[level] = level_counts[level] or 0
level_counts[level] = level_counts[level] + 1
-- print(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
Expand Down Expand Up @@ -241,8 +241,8 @@ local function negative_tests()
print("ShortestPathSearch negative_tests...")
local G = Graph:new()
G:add('s', 'v', -1)
local search = ShortestPathSearch:new(G, 's')
local iterator = search:iterate()
local search = ShortestPathSearch:new(G)
local iterator = search:iterate('s')
local ok, errmsg = pcall(iterator, search)
assert(not ok)
assert(string.match(errmsg, "Weight must not be negative$"))
Expand Down
70 changes: 35 additions & 35 deletions learning-lua/algo/ShortestPathSearch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,6 @@ local E = {}

local FROM<const>, TO<const>, WEIGHT<const>, DEPTH<const>, COST_SO_FAR<const> = 1, 2, 3, 4, 5

-- Uses Dijkstra's algorithm
local function _iterate(self)
local G, s, sssp = self.graph, self.src_vertex, self.sssp
local heap = BinaryHeap:new({}, function(a, b)
local a, b = a[TO], b[TO]
return sssp.vertices[a].min_cost <= sssp.vertices[b].min_cost
end)
local depth = 0
local node = self.src_vertex
repeat
sssp.vertices[node].ref = nil -- nullify from's heap reference as from is no longer on the heap
local vertex = self.graph:vertex(node)
depth = depth + 1
for to, weight in vertex:outgoings() do
assert(weight >= 0, "Weight must not be negative")
local v_info, v_cost_so_far = sssp.vertices[to], sssp.vertices[node].min_cost + weight
if v_cost_so_far < v_info.min_cost then
v_info.min_cost = v_cost_so_far
v_info.from = node
local v_ref = v_info.ref
if v_ref then
local entry = heap:remove(v_ref.pos)
assert(entry[TO] == to, string.format("removed: %s, to: %s", entry[TO], to)) -- remove from heap if necessary before adding
end
v_info.ref = heap:add{node, to, weight, depth, v_cost_so_far}
end
end
local item = self._yield(heap:remove())
node, depth = item[TO], item[DEPTH]
until not node -- note a cyclical path would lead to infinite iteration
end

local function shortest_path_of(sssp, dst)
local path = {assert(dst)}
local from = sssp.vertices[dst].from
Expand Down Expand Up @@ -80,6 +48,39 @@ local function new_sssp(s)
return sssp
end

-- Uses Dijkstra's algorithm
local function _iterate(self, src)
self.sssp = new_sssp(src)
local G, sssp = self.graph, self.sssp
local heap = BinaryHeap:new({}, function(a, b)
local a, b = a[TO], b[TO]
return sssp.vertices[a].min_cost <= sssp.vertices[b].min_cost
end)
local depth = 0
local node = src
repeat
sssp.vertices[node].ref = nil -- nullify from's heap reference as from is no longer on the heap
local vertex = self.graph:vertex(node)
depth = depth + 1
for to, weight in vertex:outgoings() do
assert(weight >= 0, "Weight must not be negative")
local v_info, v_cost_so_far = sssp.vertices[to], sssp.vertices[node].min_cost + weight
if v_cost_so_far < v_info.min_cost then
v_info.min_cost = v_cost_so_far
v_info.from = node
local v_ref = v_info.ref
if v_ref then
local entry = heap:remove(v_ref.pos)
assert(entry[TO] == to, string.format("removed: %s, to: %s", entry[TO], to)) -- remove from heap if necessary before adding
end
v_info.ref = heap:add{node, to, weight, depth, v_cost_so_far}
end
end
local item = self._yield(heap:remove())
node, depth = item[TO], item[DEPTH]
until not node -- note a cyclical path would lead to infinite iteration
end

--- ShortestPathSearch.lua enables graph traveral using an iterator
--- (while following the Dijkstra’s graph search strategy).
--- Notes:
Expand All @@ -93,9 +94,8 @@ end
--- * No full computation on the entire graph necessary a priori.
---@param G (table) graph
---@param src (any) source vertex
function ShortestPathSearch:new(G, src)
local o = getmetatable(self):new(G, src, _iterate)
o.sssp = new_sssp(src)
function ShortestPathSearch:new(G)
local o = getmetatable(self):new(G, _iterate)
o.shortest_paths = function(self)
return self.sssp
end
Expand Down
4 changes: 2 additions & 2 deletions learning-lua/algo/TopologicalSearch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ local function _iterate(self, src)
if src then
stack:push(src)
end
for from, to in DFS:new(self.graph, src, nav_spec, src_spec):iterate() do
for from, to in DFS:new(self.graph, nav_spec, src_spec):iterate(src) do
if not src then -- This condition can be true at most once at the beginning of a topo search.
debug(string.format("Topo search starting from %s", from))
src = from
Expand Down Expand Up @@ -55,7 +55,7 @@ end
---@param nav_spec (table) optional navigation spec in the format of {from_1={to_1, to_2, ...}, ...} e.g. {['3']={'5','11'}, ['5']={'7','9'}}
---@param src_spec (table) optional source vertex spec in the format of {v1, v2, ...} e.g. {'1', '2', '3', ...}; applicable only if a single source vertex is not specified
function TopologicalSearch:new(G, nav_spec, src_spec)
return getmetatable(self):new(G, nil, _iterate, nav_spec, src_spec)
return getmetatable(self):new(G, _iterate, nav_spec, src_spec)
end

return TopologicalSearch

0 comments on commit de9196e

Please sign in to comment.