We can build an undirected graph from the tree, then run a breadth-first search from each leaf node to find the other leaf nodes that lie within the given distance.
/**
 * Counts the unordered pairs of distinct leaf nodes whose path length
 * (in edges) is at most [distance].
 *
 * The tree is first turned into an adjacency map by `buildGraph`; a
 * level-by-level breadth-first search is then started from every leaf.
 *
 * @param root root of the tree; `null` yields 0.
 * @param distance largest path length that still counts as a pair.
 * @return number of qualifying leaf pairs.
 */
fun countPairs(root: TreeNode?, distance: Int): Int {
    if (root == null) return 0

    val adjacency = HashMap<TreeNode, HashSet<TreeNode>>()
    val leaves = HashSet<TreeNode>()
    buildGraph(root, null, adjacency, leaves)

    var orderedPairs = 0
    for (start in leaves) {
        val seen = HashSet<TreeNode>()
        val frontier = ArrayDeque<TreeNode>()
        seen.add(start)
        frontier.addLast(start)

        // `depth` is the distance from `start` of the nodes currently in `frontier`.
        var depth = 0
        while (depth <= distance && frontier.isNotEmpty()) {
            val levelSize = frontier.size
            for (step in 1..levelSize) {
                val current = frontier.removeFirst()
                // Any leaf other than the starting one is within range here.
                if (current != start && current in leaves) {
                    orderedPairs++
                }
                for (neighbor in adjacency[current].orEmpty()) {
                    // `add` returns true only for nodes not seen before.
                    if (seen.add(neighbor)) {
                        frontier.addLast(neighbor)
                    }
                }
            }
            depth++
        }
    }

    // Every pair was found twice, once from each of its two leaves.
    return orderedPairs / 2
}
/**
 * Records the subtree rooted at [root] in [graph] as an undirected adjacency
 * map and collects its leaves.
 *
 * Each visited node gets an entry in [graph] that contains its parent (when
 * there is one) and its non-null children. Nodes without children are added
 * to [leafNodes].
 *
 * @param root node to process; `null` is a no-op.
 * @param parent the node this call was reached from, or `null` for the tree root.
 * @param graph adjacency map that is filled in place.
 * @param leafNodes set that receives every leaf encountered.
 */
private fun buildGraph(root: TreeNode?, parent: TreeNode?, graph: HashMap<TreeNode, HashSet<TreeNode>>, leafNodes: HashSet<TreeNode>) {
    if (root == null) return

    // Snapshot the children into locals so the null checks below smart-cast
    // reliably, regardless of how `left`/`right` are declared on TreeNode.
    val left = root.left
    val right = root.right
    if (left == null && right == null) {
        leafNodes.add(root)
    }

    // One lookup for the node's neighbour set instead of repeated `graph[root]!!`.
    val neighbors = graph.getOrPut(root) { HashSet() }
    if (parent != null) neighbors.add(parent)
    if (left != null) {
        neighbors.add(left)
        buildGraph(left, root, graph, leafNodes)
    }
    if (right != null) {
        neighbors.add(right)
        buildGraph(right, root, graph, leafNodes)
    }
}
For each subtree root, we have to keep track of how many leaf nodes lie at each particular distance from it.
Think of it as each parent node receiving the distance lists of its children and then counting the combinations that satisfy left_distance + right_distance <= distance. The base case: when we reach a leaf node, we return a list whose entry for distance 1 is 1, because the leaf is one edge away from its parent.
// Running total of qualifying leaf pairs; incremented by dfs during the traversal.
private var pairs = 0

/**
 * Counts the pairs of leaf nodes whose path length (in edges) is at most
 * [distance], using a single post-order traversal performed by `dfs`.
 *
 * @param root root of the tree; may be `null`.
 * @param distance largest path length that still counts as a pair.
 * @return number of qualifying leaf pairs.
 */
fun countPairs(root: TreeNode?, distance: Int): Int {
    // Start from zero on every call; otherwise a second call would return the
    // previous total plus the new one.
    pairs = 0
    dfs(root, distance)
    return pairs
}
/**
 * Post-order traversal that counts leaf pairs joined through [root] and
 * reports leaf distances to the caller.
 *
 * The returned array is indexed by distance as seen from the *parent* of
 * [root]: `result[i]` is the number of leaves in this subtree that are `i`
 * edges away from that parent. For example `result[1] = 1` is what a leaf
 * returns, and `result[2] = 3` means three leaves sit two edges below the
 * parent. Index 0 is never written. Leaves farther than [distance] are
 * dropped because they can no longer be part of a qualifying pair.
 *
 * As a side effect, `pairs` is increased by the number of (left leaf,
 * right leaf) combinations whose combined distance through [root] is at
 * most [distance].
 *
 * @param root subtree to process; `null` yields an all-zero array.
 * @param distance largest path length that still counts as a pair.
 * @return the distance histogram described above.
 */
private fun dfs(root: TreeNode?, distance: Int): IntArray {
    // Clamp the size so that distance < 0 does not request a negative-length array.
    val leafDistances = IntArray(maxOf(distance, 0) + 1)
    if (root == null) return leafDistances
    if (root.left == null && root.right == null) {
        // Slot 1 only exists when distance >= 1; for smaller values no pair
        // can qualify anyway, so the leaf is simply not recorded.
        if (distance >= 1) leafDistances[1] = 1
        return leafDistances
    }

    val leftDistances = dfs(root.left, distance)
    val rightDistances = dfs(root.right, distance)

    // Only visit combinations with left + right <= distance.
    for (left in 1 until distance) {
        val leftCount = leftDistances[left]
        if (leftCount == 0) continue
        for (right in 1..(distance - left)) {
            pairs += leftCount * rightDistances[right]
        }
    }

    // Shift every bucket by one edge: the caller is one level further up.
    for (d in 1 until distance) {
        leafDistances[d + 1] = leftDistances[d] + rightDistances[d]
    }
    return leafDistances
}
/**
 * Alternative post-order traversal that tracks individual leaves in a map
 * instead of a distance histogram.
 *
 * The returned map goes from leaf node to the number of edges between that
 * leaf and [root]; a leaf maps itself to 0. While combining the children's
 * maps, every (left leaf, right leaf) combination whose path through [root]
 * is at most [distance] edges long increments `count`.
 *
 * NOTE(review): `count` is not declared in the visible snippet; presumably
 * it is a class-level counter like `pairs` - confirm.
 *
 * @param root subtree to process; `null` yields an empty map.
 * @param distance largest path length that still counts as a pair.
 * @return leaf -> distance from [root], for every leaf of the subtree.
 */
private fun dfs(root: TreeNode?, distance: Int): HashMap<TreeNode, Int> {
    val leafDepths = HashMap<TreeNode, Int>()
    if (root == null) return leafDepths
    if (root.left == null && root.right == null) {
        leafDepths[root] = 0
        return leafDepths
    }

    val leftMap = dfs(root.left, distance)
    val rightMap = dfs(root.right, distance)

    // Child maps hold distances measured from the child; add one edge per
    // side to measure from `root` instead.
    for (leftDepth in leftMap.values) {
        for (rightDepth in rightMap.values) {
            if ((leftDepth + 1) + (rightDepth + 1) <= distance) count++
        }
    }

    // Hand the leaves of BOTH subtrees up to the caller, one edge further away.
    for ((leaf, depth) in leftMap) {
        leafDepths[leaf] = depth + 1
    }
    for ((leaf, depth) in rightMap) {
        leafDepths[leaf] = depth + 1
    }
    return leafDepths
}
Good explanation for hash map solution.