Skip to content

Commit

Permalink
Enhance DFS.lua to make it easier to implement topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
hansonchar committed Jan 17, 2025
1 parent 3295d5e commit 0b54476
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
21 changes: 9 additions & 12 deletions learning-lua/algo/DFS-tests.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,22 @@ local function dfs_test(input, src, expected_visits, expected_max_level)
end

--- We define a terminal node as the node with the highest topological order in a DAG.
--- We idenity the terminal node by observing that, for the very first time,
--- if the DFS level doesn't increase, then the previous node (ie to node) visited must be the terminal node.
--- We identify the terminal node by detecting the first node with no (unvisited) outgoing edges.
local function topo_sort(input, src)
local terminal -- the terminal node
local prev_level, prev_to = 0, src
local level_visits = {}
local G = load_input(input)
local dfs = DFS:new(G, src)
for from, to, _, level in dfs:iterate() do
-- outgoings is the number of unvisited outgoing edges
for from, to, _, level, outgoings in dfs:iterate() do
-- print(string.format("%d: %s-%s", level, from, to))
if not terminal and level <= prev_level then
terminal = prev_to
local visits = level_visits[prev_level] -- remove the terminal node from level_visits
visits[#visits] = nil -- as we handle the source and terminal node differently.
if not terminal and outgoings == 0 then
terminal = to
else
level_visits[level] = level_visits[level] or {}
local visits = level_visits[level]
visits[#visits + 1] = to
end
level_visits[level] = level_visits[level] or {}
local visits = level_visits[level]
visits[#visits + 1] = to
prev_level, prev_to = level, to
end
local a = {src}
for _, visits in ipairs(level_visits) do
Expand Down
29 changes: 22 additions & 7 deletions learning-lua/algo/DFS.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@ local Stack = require "algo.Stack"
local DFS = GraphSearch:class()
local E = {}

--- Push all unvisited outgoing edges to the stack.
---@param self (table) the current DFS instance
---@param from (any) from node
---@param level (number) the number of hops from the source node
---@param stack (table) the stack to push edges to
---@param visited (table) used to check if a node has been visited
---@return (number) the number of unvisited outoing edges pushed to the stack
local function push_edges(self, from, level, stack, visited)
local vertex, count = self.graph:vertex(from), 0
for to, weight in vertex:outgoings() do
if not visited[to] then
stack:push{to, weight, level, from}
count = count + 1
end
end
return count
end

local function _iterate(self)
local stack, visited = Stack:new(), {}
self._visited_count = 0
Expand All @@ -12,15 +30,12 @@ local function _iterate(self)
while from do
if not visited[from] then
visited[from], self._visited_count = true, self._visited_count + 1
if t then -- t is nil only during the first iteration when we don't yet have an edge to yield
self._yield(t)
end
local vertex = self.graph:vertex(from)
level = level + 1
for to, weight in vertex:outgoings() do
if not visited[to] then
stack:push{to, weight, level, from}
end
local count = push_edges(self, from, level, stack, visited)
if t then -- t is nil only during the first iteration when we are starting with the source node.
t[#t + 1] = count -- append the number of unvisited outgoing edges.
self._yield(t)
end
end
t = (stack:pop() or E)
Expand Down

0 comments on commit 0b54476

Please sign in to comment.