Skip to content

Latest commit

 

History

History
134 lines (117 loc) · 4.36 KB

1530.number-of-good-leaf-nodes-pairs.md

File metadata and controls

134 lines (117 loc) · 4.36 KB

BFS on Graph

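We can build an undirected graph from the tree (each node linked to its parent and its children), then run a BFS from every leaf node to find the other leaf nodes within the given distance. Each pair is found twice, once from each end, so the total is halved.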

/**
 * Counts unordered pairs of distinct leaves whose path length is at most [distance].
 *
 * The tree is first turned into an adjacency map by `buildGraph`, then a level-by-level
 * BFS is run from every leaf. Each pair is seen once from each endpoint, so the raw
 * count is halved before it is returned.
 */
fun countPairs(root: TreeNode?, distance: Int): Int {
    if (root == null) return 0
    val adjacency = HashMap<TreeNode, HashSet<TreeNode>>()
    val leaves = HashSet<TreeNode>()
    buildGraph(root, null, adjacency, leaves)

    var orderedPairs = 0
    for (start in leaves) {
        val seen = hashSetOf(start)
        // `frontier` holds every node exactly `depth` edges away from `start`.
        var frontier: List<TreeNode> = listOf(start)
        var depth = 0
        while (frontier.isNotEmpty()) {
            val nextFrontier = ArrayList<TreeNode>()
            for (node in frontier) {
                // Another leaf reached within range; the starting leaf itself is excluded.
                if (node != start && node in leaves) orderedPairs++
                for (neighbor in adjacency[node].orEmpty()) {
                    // `add` returns true only the first time a node is seen.
                    if (seen.add(neighbor)) nextFrontier.add(neighbor)
                }
            }
            depth++
            // Levels 0..distance have been examined; anything deeper is out of range.
            if (depth > distance) break
            frontier = nextFrontier
        }
    }
    return orderedPairs / 2
}

/**
 * Records the tree rooted at [root] as an undirected adjacency map and collects its leaves.
 *
 * @param root node being visited; a null node is ignored.
 * @param parent the node [root] was reached from, or null for the tree's root.
 * @param graph adjacency map filled in place; every visited node gets an entry linking it
 *   to its parent (if any) and to each non-null child.
 * @param leafNodes filled in place with every node that has no children.
 */
private fun buildGraph(root: TreeNode?, parent: TreeNode?, graph: HashMap<TreeNode, HashSet<TreeNode>>, leafNodes: HashSet<TreeNode>) {
    if (root == null) return
    // Read the children once into locals so the null checks below hold without `!!`
    // and without relying on a smart cast of a property.
    val left = root.left
    val right = root.right
    if (left == null && right == null) {
        leafNodes.add(root)
    }

    val neighbors = graph.getOrPut(root) { HashSet() }
    if (parent != null) neighbors.add(parent)
    if (left != null) {
        neighbors.add(left)
        buildGraph(left, root, graph, leafNodes)
    }
    if (right != null) {
        neighbors.add(right)
        buildGraph(right, root, graph, leafNodes)
    }
}

Postorder

For every node, we keep track of how many leaf nodes in its subtree lie at each particular distance from it.

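Think of each parent node as receiving a distance list from each of its children; the parent then counts every combination of a left leaf and a right leaf that satisfies left_distance + right_distance <= distance. The base case: when we reach a leaf node, we return a list whose entry at index 1 is 1, because the leaf is one edge away from its parent.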

// Running total of good leaf pairs; `dfs` adds to it while it walks the tree.
private var pairs = 0

/**
 * Returns the number of pairs of leaves whose path length is at most [distance],
 * as accumulated by the postorder `dfs` helper.
 */
fun countPairs(root: TreeNode?, distance: Int): Int {
    // Start from zero so a repeated call on the same instance does not inherit the previous total.
    pairs = 0
    dfs(root, distance)
    return pairs
}

// Postorder helper. Returns, for the subtree rooted at `root`, how many of its leaves lie at
// each distance measured from root's PARENT: index i holds the number of leaves i edges away.
// result[1] = 1 is what a leaf returns (it is one edge below its parent).
// result[2] = 3 means three leaves are two edges below the parent.
// Index 0 is never populated, and leaves farther than `distance` are dropped because they
// can no longer be part of a good pair.
// Side effect: adds to `pairs` every (left leaf, right leaf) combination whose path through
// `root` is at most `distance` edges long.
private fun dfs(root: TreeNode?, distance: Int): IntArray {
    // Clamp so a negative `distance` yields an empty tally instead of an invalid array size.
    val leafDistances = IntArray(maxOf(distance, 0) + 1)
    if (root == null) return leafDistances

    if (root.left == null && root.right == null) {
        // With distance < 1 there is no slot for index 1, and no pair can qualify anyway.
        if (distance >= 1) leafDistances[1] = 1
        return leafDistances
    }
    val leftDistances = dfs(root.left, distance)
    val rightDistances = dfs(root.right, distance)
    // The children's arrays are already measured from `root`, so left + right is the
    // length of the path joining the two leaves through `root`.
    for (left in 1 until distance) {
        for (right in 1..distance - left) {
            pairs += leftDistances[left] * rightDistances[right]
        }
    }
    // Shift by one edge so the tallies are measured from this node's parent.
    for (d in 1 until distance) {
        leafDistances[d + 1] = leftDistances[d] + rightDistances[d]
    }
    return leafDistances
}

// Alternative postorder helper that tracks individual leaves instead of per-distance tallies.
// Returns a map: leaf node -> number of edges between that leaf and `root`, covering every
// leaf in the subtree rooted at `root` (a leaf maps itself to 0).
// Side effect: increments `pairs` once for every (left leaf, right leaf) combination whose
// path through `root` is at most `distance` edges long.
private fun dfs(root: TreeNode?, distance: Int): HashMap<TreeNode, Int> {
    val leafDepths = HashMap<TreeNode, Int>()
    if (root == null) return leafDepths
    if (root.left == null && root.right == null) {
        leafDepths[root] = 0
        return leafDepths
    }

    // The children report depths relative to themselves; add the edge up to `root`.
    val leftDepths = dfs(root.left, distance).mapValues { it.value + 1 }
    val rightDepths = dfs(root.right, distance).mapValues { it.value + 1 }

    for (leftDepth in leftDepths.values) {
        for (rightDepth in rightDepths.values) {
            if (leftDepth + rightDepth <= distance) pairs++
        }
    }

    // Both sides must be passed up; ancestors pair these leaves with leaves from other subtrees.
    leafDepths.putAll(leftDepths)
    leafDepths.putAll(rightDepths)
    return leafDepths
}

Good explanation for hash map solution.