From 0b5447671e4d3cdf358adf7f434153e5b6aecf7b Mon Sep 17 00:00:00 2001 From: Hanson Char Date: Fri, 17 Jan 2025 10:43:02 -0800 Subject: [PATCH] Enhance DFS.lua to make it easier to implement topological sort --- learning-lua/algo/DFS-tests.lua | 21 +++++++++------------ learning-lua/algo/DFS.lua | 29 ++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/learning-lua/algo/DFS-tests.lua b/learning-lua/algo/DFS-tests.lua index 76c411b..791dde0 100644 --- a/learning-lua/algo/DFS-tests.lua +++ b/learning-lua/algo/DFS-tests.lua @@ -23,25 +23,22 @@ local function dfs_test(input, src, expected_visits, expected_max_level) end --- We define a terminal node as the node with the highest topological order in a DAG. ---- We idenity the terminal node by observing that, for the very first time, ---- if the DFS level doesn't increase, then the previous node (ie to node) visited must be the terminal node. +--- We identify the terminal node by detecting the first node with no (unvisited) outgoing edges. local function topo_sort(input, src) local terminal -- the terminal node - local prev_level, prev_to = 0, src local level_visits = {} local G = load_input(input) local dfs = DFS:new(G, src) - for from, to, _, level in dfs:iterate() do + -- outgoings is the number of unvisited outgoing edges + for from, to, _, level, outgoings in dfs:iterate() do -- print(string.format("%d: %s-%s", level, from, to)) - if not terminal and level <= prev_level then - terminal = prev_to - local visits = level_visits[prev_level] -- remove the terminal node from level_visits - visits[#visits] = nil -- as we handle the source and terminal node differently. + if not terminal and outgoings == 0 then + terminal = to + else + level_visits[level] = level_visits[level] or {} + local visits = level_visits[level] + visits[#visits + 1] = to end - level_visits[level] = level_visits[level] or {} - local visits = level_visits[level] - visits[#visits + 1] = to - prev_level, prev_to = level, to end local a = {src} for _, visits in ipairs(level_visits) do diff --git a/learning-lua/algo/DFS.lua b/learning-lua/algo/DFS.lua index 6609929..e724368 100644 --- a/learning-lua/algo/DFS.lua +++ b/learning-lua/algo/DFS.lua @@ -4,6 +4,24 @@ local Stack = require "algo.Stack" local DFS = GraphSearch:class() local E = {} +--- Push all unvisited outgoing edges to the stack. +---@param self (table) the current DFS instance +---@param from (any) from node +---@param level (number) the number of hops from the source node +---@param stack (table) the stack to push edges to +---@param visited (table) used to check if a node has been visited +---@return (number) the number of unvisited outoing edges pushed to the stack +local function push_edges(self, from, level, stack, visited) + local vertex, count = self.graph:vertex(from), 0 + for to, weight in vertex:outgoings() do + if not visited[to] then + stack:push{to, weight, level, from} + count = count + 1 + end + end + return count +end + local function _iterate(self) local stack, visited = Stack:new(), {} self._visited_count = 0 @@ -12,15 +30,12 @@ local function _iterate(self) while from do if not visited[from] then visited[from], self._visited_count = true, self._visited_count + 1 - if t then -- t is nil only during the first iteration when we don't yet have an edge to yield - self._yield(t) - end local vertex = self.graph:vertex(from) level = level + 1 - for to, weight in vertex:outgoings() do - if not visited[to] then - stack:push{to, weight, level, from} - end + local count = push_edges(self, from, level, stack, visited) + if t then -- t is nil only during the first iteration when we are starting with the source node. + t[#t + 1] = count -- append the number of unvisited outgoing edges. + self._yield(t) end end t = (stack:pop() or E)