
Leetcode Problem - All Nodes Distance K in Binary Tree

2nd of 15 Leetcode problems

Problem link: https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/
Difficulty: Medium
Category: Trees

Step 1: Grasping the Problem

Although the input is a binary tree, this is really a graph problem. We are given the root of a binary tree, a target node, and an integer K, and we need to return the values of all nodes at distance K from the target, where the distance between two nodes is the number of edges on the path between them. In the example below, the target node (marked yellow) is 5 and K = 2, so the answer is [7, 4, 1] (marked blue).

[Figure: the example tree, with the target node 5 in yellow and the answer nodes 7, 4 and 1 in blue]
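To experiment with the example, it helps to rebuild the tree in code. The sketch below assumes the figure shows LeetCode's first example, whose level-order input is [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4] (None marks a missing child). `build_tree` and `find` are hypothetical helpers for this post, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node, mirroring LeetCode's definition."""
    def __init__(self, x: int):
        self.val = x
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    pending = deque([root])          # nodes still waiting for their children
    i = 1
    while pending and i < len(values):
        node = pending.popleft()
        if i < len(values) and values[i] is not None:   # left child
            node.left = TreeNode(values[i])
            pending.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:   # right child
            node.right = TreeNode(values[i])
            pending.append(node.right)
        i += 1
    return root


def find(node: Optional[TreeNode], val: int) -> Optional[TreeNode]:
    """Hypothetical helper: return the node holding val (values are unique here)."""
    if node is None or node.val == val:
        return node
    return find(node.left, val) or find(node.right, val)


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
target = find(root, 5)
print(target.right.left.val, target.right.right.val)  # 7 4
```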

A node at distance K from the target falls into one of three cases:

  1. It is a descendant of the target, K levels below it.
  2. It is an ancestor of the target, K levels above it.
  3. It lies in a different branch, at distance l + m from the target: l is the distance from the target up to an ancestor par, and m is the distance from par down to the node through par's other child. (See the picture above.)
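The three cases can also be turned directly into an algorithm without building a graph: walk down from the root to the target, and at every ancestor par at distance l, either report par itself (when l = K) or collect the nodes K - l - 1 levels below par's other child. This is a different technique from the BFS used in the rest of this post; the sketch below, with hypothetical names `collect_down` and `distance_k_by_cases`, is only meant to make the case analysis concrete.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, x: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = x
        self.left = left
        self.right = right


def collect_down(node: Optional[TreeNode], depth: int, out: List[int]) -> None:
    """Append the values of nodes exactly `depth` edges below `node`."""
    if node is None or depth < 0:
        return
    if depth == 0:
        out.append(node.val)
        return
    collect_down(node.left, depth - 1, out)
    collect_down(node.right, depth - 1, out)


def distance_k_by_cases(root: Optional[TreeNode], target: TreeNode, k: int) -> List[int]:
    out: List[int] = []

    def walk(node: Optional[TreeNode]) -> int:
        """Return the distance from node down to target, or -1 if target is not below."""
        if node is None:
            return -1
        if node is target:
            collect_down(node, k, out)          # case 1: descendants of the target
            return 0
        for child, other in ((node.left, node.right), (node.right, node.left)):
            d = walk(child)
            if d != -1:
                up = d + 1                      # l: distance from target up to this ancestor
                if up == k:
                    out.append(node.val)        # case 2: the ancestor itself
                else:
                    # case 3: nodes m = k - up edges below this ancestor, via the other child
                    collect_down(other, k - up - 1, out)
                return up
        return -1

    walk(root)
    return out


n5 = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, n5, TreeNode(1, TreeNode(0), TreeNode(8)))
print(distance_k_by_cases(root, n5, 2))  # [7, 4, 1]
```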

Here, distance means the length of the shortest path. If K is even, a walk of K steps could end back at the target, but the target's distance from itself is 0, so the target belongs in the answer only when K = 0.

Step 2: Build a graph

To reach the nodes in cases 2 and 3, we have to move from the target up to its parent. A tree node, however, only stores links to its children, so there is no way to go upward.

So the first step is to convert the tree into a bidirectional (undirected), unweighted graph. We traverse the tree, and every time we move to a child node, we record its parent on the child. This can be done either recursively or iteratively.

Unlike statically typed languages, Python lets us add new attributes to an object at runtime, outside its class definition. We will use this for convenience and store each node's parent directly on the node.

# Time Complexity: O(n)
def dfs(node, parent=None):
    if node: # Check if the node is not None
        node.parent = parent # Create a parent attribute
        dfs(node.left, node)
        dfs(node.right, node)

dfs(root) # Now every node in the graph has a parent attribute
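The recursive dfs can hit CPython's default recursion limit (about 1000 frames) on a very deep, skewed tree. The iterative alternative uses an explicit stack instead. The sketch below is one way to write it; unlike the post's code it stores parents in a dictionary rather than as node attributes, and `build_parents` is a hypothetical name.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, x: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = x
        self.left = left
        self.right = right


def build_parents(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Map every node to its parent (the root maps to None) using an explicit stack."""
    parents: Dict[TreeNode, Optional[TreeNode]] = {}
    if root is None:
        return parents
    parents[root] = None
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node   # record the child -> parent edge
                stack.append(child)
    return parents


leaf = TreeNode(7)
root = TreeNode(3, TreeNode(5, None, leaf), TreeNode(1))
parents = build_parents(root)
print(parents[leaf].val)  # 5
```

Nodes are hashed by identity by default, so they can be used as dictionary keys even when two nodes share a value.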

Step 3: BFS

Now that we have the graph, we can run a BFS (Breadth-First Search) from the target, record each node's distance from the target, and collect the nodes whose distance is K. Again, we can store the distance on the node itself.

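BFS visits nodes in order of nondecreasing distance. So when the node at the front of the queue is at distance K, every node at distance K - 1 has already been expanded and no node at distance K has been expanded yet: the queue holds exactly the nodes at distance K. We can stop the BFS and return them.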

# Time Complexity: O(n*n)
target.dist = 0     # distance from itself is zero
queue = [target]    # BFS queue
seen = {target}     # A set to check if visited

while queue:                # run until queue is empty
    if queue[0].dist == K:  # Found one of our answer nodes
        # Return the queue
        ans = []
        for node in queue:
            ans.append(node.val)
        return ans

    node = queue.pop(0)     # Takes O(len(queue))

    # Go through the neighbors of current node
    for neighbor in (node.left, node.right, node.parent):
        if neighbor and neighbor not in seen:
            neighbor.dist = node.dist + 1
            seen.add(neighbor)
            queue.append(neighbor)
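The early exit works because a BFS queue never holds nodes from more than two consecutive distances at once. The same idea can be made explicit by expanding the BFS one whole level per round, which removes the need for a per-node `dist`. This is a sketch of that variant, assuming the parents are already available in a dictionary (a hypothetical `parents` mapping, not the attribute used in this post).

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = x
        self.left = left
        self.right = right


def nodes_at_distance(start: TreeNode,
                      parents: Dict[TreeNode, Optional[TreeNode]],
                      k: int) -> List[int]:
    """Level-by-level BFS: after k rounds the queue holds exactly level k."""
    queue = deque([start])
    seen = {start}
    for _ in range(k):
        # Expand exactly the nodes currently in the queue (one level).
        for _ in range(len(queue)):
            node = queue.popleft()
            for nb in (node.left, node.right, parents.get(node)):
                if nb is not None and nb not in seen:
                    seen.add(nb)
                    queue.append(nb)
        if not queue:       # ran out of nodes before reaching level k
            return []
    return [node.val for node in queue]


n5, n1 = TreeNode(5), TreeNode(1)
root = TreeNode(3, n5, n1)
parents = {root: None, n5: root, n1: root}
print(nodes_at_distance(n5, parents, 2))  # [1]
```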

Step 4: Handle corner cases

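There are two corner cases:

  1. No node lies at distance K, because K is larger than the largest distance from the target to any node.
  2. The tree is empty. (LeetCode's constraints guarantee a non-empty tree that contains target, so this cannot happen there.)

In both cases, we return an empty list. In the first case the BFS loop simply empties the queue, so it is enough to return an empty list after the loop: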

return []
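The trailing `return []` only covers the first corner case: on an empty tree the snippets above would fail at `target.dist = 0`, because target would be None. A minimal sketch with an explicit guard is below. `distance_k` is a hypothetical standalone function, and unlike the post's code it keeps parents in a dictionary and distances in (node, distance) tuples instead of node attributes.

```python
import collections
from typing import List, Optional


class TreeNode:
    def __init__(self, x: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = x
        self.left = left
        self.right = right


def distance_k(root: Optional[TreeNode], target: Optional[TreeNode], k: int) -> List[int]:
    # Corner case 2: an empty tree (or a missing target) has nothing to report.
    if root is None or target is None:
        return []

    parents = {root: None}
    stack = [root]
    while stack:                        # record every node's parent
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                stack.append(child)

    queue = collections.deque([(target, 0)])
    seen = {target}
    while queue:
        if queue[0][1] == k:            # everything left in the queue is at distance k
            return [node.val for node, _ in queue]
        node, dist = queue.popleft()
        for nb in (node.left, node.right, parents.get(node)):
            if nb is not None and nb not in seen:
                seen.add(nb)
                queue.append((nb, dist + 1))
    # Corner case 1: k exceeds every distance from target, so the queue emptied.
    return []


print(distance_k(None, None, 2))  # []
```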

Step 5: Optimize

The complete algorithm currently runs in O(n + n*n) = O(n*n) time. The culprit is the queue: we use a Python list for it, and list.pop(0) has to shift every remaining element, which costs O(len(queue)). The standard library's collections.deque supports O(1) appends and pops at both ends, so popleft() removes this cost. In a few places we can also use list comprehensions; they do not affect the complexity, but they may run slightly faster in CPython.

import collections

target.dist = 0
queue = collections.deque([target])
seen = {target}
while queue:
    if queue[0].dist == K:
        return [node.val for node in queue]
    node = queue.popleft()
    for neighbor in (node.left, node.right, node.parent):
        if neighbor and neighbor not in seen:
            neighbor.dist = node.dist + 1
            seen.add(neighbor)
            queue.append(neighbor)
return []
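The gap between list.pop(0) and deque.popleft() is easy to observe. The sketch below times draining both containers from the front; `drain_list` and `drain_deque` are illustrative names, and absolute timings vary by machine, so no numbers are claimed here.

```python
import collections
import timeit


def drain_list(n: int) -> int:
    """Empty a list from the front; each pop(0) shifts all remaining items."""
    items = list(range(n))
    total = 0
    while items:
        total += items.pop(0)
    return total


def drain_deque(n: int) -> int:
    """Empty a deque from the front; each popleft() is O(1)."""
    items = collections.deque(range(n))
    total = 0
    while items:
        total += items.popleft()
    return total


if __name__ == "__main__":
    n = 20_000
    for fn in (drain_list, drain_deque):
        seconds = timeit.timeit(lambda: fn(n), number=3)
        print(f"{fn.__name__}: {seconds:.3f}s for 3 runs")
```

Doubling n should roughly quadruple the list timing but only double the deque timing, matching O(n*n) versus O(n).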

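Finally, the time complexity is O(n) and the space complexity is also O(n), where n is the number of nodes in the tree: the DFS and the BFS each visit every node at most once, and the parent links, the seen set and the queue each hold at most n entries. This is asymptotically optimal in the worst case, because without parent pointers, locating the target's ancestors may require examining every node.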

Python:
import collections
from typing import List

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

class Solution(object):
    def distanceK(self, root: TreeNode, target: TreeNode, K: int) -> List[int]:
        def dfs(node, parent=None):
            if node:                  # Check if the node is not None
                node.parent = parent  # Create a parent attribute
                dfs(node.left, node)
                dfs(node.right, node)
        dfs(root)  # Now every node in the graph has a parent attribute

        target.dist = 0                      # Distance from itself is zero
        queue = collections.deque([target])  # BFS queue
        seen = {target}                      # A set to check if visited
        while queue:                         # Run until queue is empty
            if queue[0].dist == K:           # Every node left in the queue is at distance K
                return [node.val for node in queue]
            node = queue.popleft()

            # Go through the neighbors of current node
            for neighbor in (node.left, node.right, node.parent):
                if neighbor and neighbor not in seen:
                    neighbor.dist = node.dist + 1
                    seen.add(neighbor)
                    queue.append(neighbor)
        return []                            # No node at distance K
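Attaching `parent` and `dist` at runtime works because ordinary Python instances carry a per-instance `__dict__`. If a node class declares `__slots__`, the same assignment raises AttributeError, which is one reason to prefer an external dictionary in code you do not control. The classes below are hypothetical stand-ins used only to show the difference.

```python
class PlainNode:
    """Ordinary class: instances accept new attributes at runtime."""
    def __init__(self, val):
        self.val = val


class SlottedNode:
    """__slots__ removes the per-instance __dict__, so only 'val' can be set."""
    __slots__ = ("val",)

    def __init__(self, val):
        self.val = val


def can_attach_parent(node) -> bool:
    """Try to add a parent attribute, as the Python solution does."""
    try:
        node.parent = None
        return True
    except AttributeError:
        return False


print(can_attach_parent(PlainNode(1)))    # True
print(can_attach_parent(SlottedNode(1)))  # False
```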
C++:
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    vector<int> distanceK(TreeNode* root, TreeNode* target, int K) {
        unordered_map<TreeNode*, TreeNode*> parents;    // Parent of every node
        dfs(parents, root, nullptr);

        queue<pair<TreeNode*, int> > q;
        unordered_set<TreeNode*> seen;

        q.push({target, 0});                // Distance from itself is zero
        seen.insert(target);                // To check if visited
        while(!q.empty()) {                 // run until queue is empty
            if(q.front().second == K) {     // Every node left in the queue is at distance K
                vector<int> result;
                while(!q.empty()) {
                    result.push_back(q.front().first->val);
                    q.pop();
                }
                return result;
            }

            // Go through the neighbors of current node
            auto cur = q.front();
            q.pop();
            auto node = cur.first;
            auto dist = cur.second;

            TreeNode* neighbors[] = {
                node->left,
                node->right,
                parents[node]
            };
            for(auto neighbor: neighbors) {
                if(neighbor != nullptr && seen.find(neighbor) == seen.end()) {
                    auto new_dist = dist + 1;
                    seen.insert(neighbor);
                    q.push({neighbor, new_dist});
                }
            }
        }
        return {};
    }

private:
    void dfs(unordered_map<TreeNode*, TreeNode*> &parents, TreeNode* node, TreeNode* parent) {
        if(node != nullptr) {                   // Check if the node is not null
            parents[node] = parent;             // Create a parent
            dfs(parents, node->left, node);
            dfs(parents, node->right, node);
        }
    }
};
Java:
import java.util.*;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Pair {
    public TreeNode first;
    public Integer second;
    public Pair(TreeNode first, Integer second) {
        this.first = first;
        this.second = second;
    }
}
class Solution {
    public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
        Map<TreeNode, TreeNode> parents = new HashMap<>();
        dfs(parents, root, null);                   // Initial parents for all nodes

        Queue<Pair> queue = new LinkedList<>();
        Set<TreeNode> seen = new HashSet<>();

        queue.add(new Pair(target, 0));             // Distance from itself is zero
        seen.add(target);                           // To check if visited
        while(!queue.isEmpty()) {                   // run until queue is empty
            if(queue.peek().second == K) {          // Found one of our answer nodes
                List<Integer> result = new ArrayList<>();
                while(!queue.isEmpty())
                    result.add(queue.remove().first.val);
                return result;
            }

            // Go through the neighbors of current node
            Pair cur = queue.remove();
            TreeNode node = cur.first;
            int dist = cur.second;

            TreeNode[] neighbors = {
                node.left,
                node.right,
                parents.get(node)
            };
            for(TreeNode neighbor : neighbors) {
                if(neighbor != null && !seen.contains(neighbor)) {
                    int new_dist = dist + 1;
                    seen.add(neighbor);
                    queue.add(new Pair(neighbor, new_dist));
                }
            }
        }
        return new ArrayList<Integer>();
    }

    private void dfs(Map<TreeNode, TreeNode> parents, TreeNode node, TreeNode parent) {
        if(node != null) {                  // Check if the node is not null
            parents.put(node, parent);      // Record the node's parent
            dfs(parents, node.left, node);
            dfs(parents, node.right, node);
        }
    }
}
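One way to gain confidence in the early-exit BFS used by all three versions is to compare a Python port of it against a brute-force distance computation on random trees. The sketch below measures each distance by climbing both nodes up to their lowest common ancestor; `random_tree`, `brute_force` and `bfs_distance_k` are hypothetical helpers written for this check, and the port keeps parents in a dictionary rather than on the nodes.

```python
import random
from collections import deque


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def random_tree(n, rng):
    """Attach nodes 1..n-1 one at a time under a random node with a free slot."""
    nodes = [TreeNode(0)]
    for v in range(1, n):
        while True:
            p, side = rng.choice(nodes), rng.choice(("left", "right"))
            if getattr(p, side) is None:
                child = TreeNode(v)
                setattr(p, side, child)
                nodes.append(child)
                break
    return nodes[0], nodes


def parent_map(root):
    parent, depth, stack = {root: None}, {root: 0}, [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child], depth[child] = node, depth[node] + 1
                stack.append(child)
    return parent, depth


def brute_force(root, target, k):
    parent, depth = parent_map(root)

    def dist(a, b):
        d = 0
        while a is not b:          # climb the deeper node until both meet at the LCA
            if depth[a] < depth[b]:
                a, b = b, a
            a = parent[a]
            d += 1
        return d

    return sorted(n.val for n in depth if dist(n, target) == k)


def bfs_distance_k(root, target, k):
    parent, _ = parent_map(root)
    queue, seen = deque([(target, 0)]), {target}
    while queue:
        if queue[0][1] == k:       # early exit, as in the solutions above
            return sorted(n.val for n, _ in queue)
        node, d = queue.popleft()
        for nb in (node.left, node.right, parent[node]):
            if nb is not None and nb not in seen:
                seen.add(nb)
                queue.append((nb, d + 1))
    return []


rng = random.Random(0)
for _ in range(200):
    root, nodes = random_tree(rng.randint(1, 30), rng)
    target, k = rng.choice(nodes), rng.randint(0, 8)
    assert bfs_distance_k(root, target, k) == brute_force(root, target, k)
print("all checks passed")
```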

Written by Rahat Zaman, Graduate Research Assistant, School of Computing