Problem Statement

In a rooted binary tree, define the level of a vertex to be the number of edges on the path from that vertex to the root. So, for instance, the level of the root is 0, the level of the vertices immediately adjacent to the root is 1, and so on.

Assume you are given a binary tree whose node values are integers. Write the function max_level that, when given the root of the tree, returns the level of the tree whose values have the maximum sum. When implementing this function, please do not use recursion. However, feel free to implement auxiliary data structures. Additionally, please provide runtime and space complexity analysis for your solution.

Assume the input root is a TreeNode with access methods TreeNode.val, TreeNode.left, TreeNode.right which return the value at the node, its left child, and its right child (which are also of type TreeNode), respectively. If the node doesn’t have a particular child, the value of the corresponding access method is None.


class TreeNode:
    """
    # Implementation Hidden #
    Methods:
    self.val    -   returns the value of the node
    self.left   -   returns left child of the node, which is a TreeNode,
                    or None if the child doesn't exist
    self.right  -   returns right child of the node, which is a TreeNode,
                    or None if the child doesn't exist
    """

def max_level(root):
    # your code here
    pass
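Because the TreeNode implementation is hidden, a minimal stand-in is handy for trying solutions locally. The constructor signature `TreeNode(val, left=None, right=None)` below is an assumption for illustration, not the hidden class's actual API; it only reproduces the three attributes described above.

```python
from typing import Optional


class TreeNode:
    """Hypothetical stand-in for the hidden TreeNode: exposes .val, .left, .right."""

    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val      # integer stored at this node
        self.left = left    # left child, or None if absent
        self.right = right  # right child, or None if absent


# The tree from the Examples section:
#          11
#     15        -3
#   1    1
root = TreeNode(11,
                TreeNode(15, TreeNode(1), TreeNode(1)),
                TreeNode(-3))
```

Any object with the same three attributes works equally well with the solutions below.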

Followups:

(1) What if the tree were a complete binary search tree? Could you make your solution faster knowing this? What if, additionally, all entries were positive?

(2) How would you solve this recursively?

Examples

Example

>>> root
         11
    15        -3
  1    1        
>>> max_level(root)
1

Explanation: The sums of the levels are 11, 15 + (-3) = 12, and 1 + 1 = 2, so the level with the largest sum is 1.
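As a quick check of that arithmetic, here is a sketch that builds the example tree (with a hypothetical stand-in TreeNode) and sums one whole level at a time, replacing the list of the current level's nodes with the list of their children. This frontier-list formulation is an alternative to the queue used in the sample solution, not the document's own method; the name level_sums is chosen for illustration.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical stand-in for the hidden TreeNode class.
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val, self.left, self.right = val, left, right


def level_sums(root: Optional[TreeNode]) -> List[int]:
    """Sum of the values on each level, computed one whole level at a time."""
    sums: List[int] = []
    frontier = [root] if root is not None else []
    while frontier:
        sums.append(sum(node.val for node in frontier))
        # Next level: every existing child of the current level, left to right.
        frontier = [child for node in frontier
                    for child in (node.left, node.right) if child is not None]
    return sums


root = TreeNode(11, TreeNode(15, TreeNode(1), TreeNode(1)), TreeNode(-3))
sums = level_sums(root)
print(sums)                   # [11, 12, 2]
print(sums.index(max(sums)))  # 1
```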

Hints

How can you keep track of the sum while traversing the tree?

Sample Solution

The idea is to traverse the tree with a breadth-first search, using a queue so that we visit the vertices in level order. We keep track of the running sum of the current level, and whenever we move on to the next level we compare that sum to the maximum sum seen so far, recording the index of the best level. Since nodes are enqueued in level order, we never over-count or leave out any nodes. Finally, we return the index of the best level.


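from collections import deque


class Queue:
    """FIFO queue backed by collections.deque, so enqueue and dequeue are O(1)."""

    def __init__(self):
        self.items = deque()

    def enqueue(self, item):
        self.items.append(item)

    def dequeue(self):
        # Returns None when the queue is empty.
        if self.items:
            return self.items.popleft()

    def is_empty(self):
        return not self.items


def max_level(root):
    if root is None:
        return None  # assumption: an empty tree has no levels
    q = Queue()
    # sums: running sum of the current level; max_sums: best completed sum so far
    sums, max_sums, best_level = 0, -float('inf'), 0
    q.enqueue((root, 0))
    curr_level = 0
    while not q.is_empty():
        node, level = q.dequeue()
        if level > curr_level:
            # The previous level is complete; compare it with the best so far.
            if sums > max_sums:
                best_level = curr_level
                max_sums = sums
            sums = 0
            curr_level = level
        sums += node.val
        if node.left is not None:
            q.enqueue((node.left, level + 1))
        if node.right is not None:
            q.enqueue((node.right, level + 1))
    # The last level is never followed by a deeper one, so check it here.
    if sums > max_sums:
        best_level = curr_level
    return best_level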

Solutions to Followups

(1) Even if the tree is a complete BST, there isn’t enough structure for us to avoid looking at every element at least once, so we can’t really improve on this solution. However, if we additionally require every entry to be positive, then we claim we only need to compare the deepest complete level with the last level. To see why, note that since the tree is complete, every node above the deepest complete level has two children. Moreover, because this is a BST, the right child is at least as large as its parent, and because every entry is positive, the left child adds a positive amount on top of that, so the sum of the two children is always larger than the parent. Summing over a level, every complete level therefore has a strictly larger sum than the level above it, so the deepest complete level beats every level above it. Hence, we only need to check the last two levels. This is at best a constant-factor improvement, roughly a factor of 2: the levels above the last two contain fewer than half of the nodes.
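A sketch of the follow-up (1) idea, assuming the input really is a complete BST with positive entries (nothing here checks that) and using a hypothetical stand-in TreeNode; the function name is also hypothetical. It keeps only the last two levels while walking down and sums just those two. Ties go to the shallower level, matching the strict comparison in the sample solution. It still walks every node to reach the bottom, so the asymptotic cost is unchanged.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical stand-in for the hidden TreeNode class.
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val, self.left, self.right = val, left, right


def max_level_complete_positive_bst(root: Optional[TreeNode]) -> Optional[int]:
    """Assumes (without checking) a complete BST whose values are all positive."""
    if root is None:
        return None
    prev: List[TreeNode] = []      # the level directly above `curr`
    curr: List[TreeNode] = [root]
    depth = 0                      # level index of `curr`
    while True:
        nxt = [c for n in curr for c in (n.left, n.right) if c is not None]
        if not nxt:
            break                  # `curr` is the last level
        prev, curr = curr, nxt
        depth += 1
    # Only the last two levels are summed: each complete level beats all levels above it.
    curr_sum = sum(n.val for n in curr)
    if prev and sum(n.val for n in prev) >= curr_sum:
        return depth - 1           # tie goes to the shallower level
    return depth
```

For example, the complete BST with root 4, children 2 and 6, and grandchildren 1 and 3 under 2 has level sums 4, 8, 4, and the function returns 1.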

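(2) Recursion offers a very elegant solution using a helper function. The idea is to do a depth-first traversal that passes each node’s level down to its children and adds the node’s value to a list of per-level sums indexed by level; the answer is the index of the largest entry.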

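def max_level(root):
    if root is None:
        return None  # assumption: an empty tree has no levels
    sums = []  # sums[i] is the running total of level i

    def helper(node, level):
        if node:
            if len(sums) <= level:
                sums.append(0)  # first node seen on this level
            sums[level] += node.val
            helper(node.left, level + 1)
            helper(node.right, level + 1)

    helper(root, 0)
    # Levels are 0-indexed, so the list index is the level itself.
    return sums.index(max(sums))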
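The recursive helper's traversal can also be replayed without recursion by managing the call stack by hand, which gives a depth-first alternative to the queue-based main solution. A minimal sketch, again with a hypothetical stand-in TreeNode and the name max_level_dfs chosen for illustration:

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Hypothetical stand-in for the hidden TreeNode class.
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val, self.left, self.right = val, left, right


def max_level_dfs(root: Optional[TreeNode]) -> Optional[int]:
    """Same per-level sums as the recursive helper, but with an explicit stack."""
    if root is None:
        return None
    sums: List[int] = []
    stack: List[Tuple[TreeNode, int]] = [(root, 0)]
    while stack:
        node, level = stack.pop()
        if len(sums) <= level:
            sums.append(0)  # first node seen on this level
        sums[level] += node.val
        # Push right first so the left subtree is processed first, as in the helper.
        if node.right is not None:
            stack.append((node.right, level + 1))
        if node.left is not None:
            stack.append((node.left, level + 1))
    return sums.index(max(sums))
```

A node is only pushed after its parent has been visited, so the list always already has an entry for every shallower level; the per-level sums do not depend on the visiting order.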

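Runtime and Space

Each algorithm visits every vertex of the tree exactly once and does a constant amount of work per vertex (assuming O(1) enqueue and dequeue), so it runs in O(n) time, where n is the number of vertices. For space, the iterative solution's queue holds at most the remaining vertices of one level plus some of the next, so it uses O(w) extra space, where w is the maximum width of the tree; this is O(n) in the worst case. The recursive solution instead uses O(h) space for the call stack and the list of level sums, where h is the height of the tree, which is O(n) for a degenerate, path-like tree.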
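To make the space bound concrete, the sketch below instruments a deque-based level-order traversal and records the largest queue length on perfect trees. TreeNode, perfect_tree, and peak_queue_length are hypothetical helpers for illustration. For a perfect tree with n = 2^d - 1 nodes the peak is 2^(d-1) = (n + 1)/2, reached when the whole bottom level is queued.

```python
from collections import deque


class TreeNode:
    # Hypothetical stand-in for the hidden TreeNode class.
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def perfect_tree(depth):
    """Build a perfect tree with `depth` levels (2**depth - 1 nodes), heap-style."""
    nodes = [TreeNode(i) for i in range(2 ** depth - 1)]
    for i, node in enumerate(nodes):
        if 2 * i + 1 < len(nodes):
            node.left = nodes[2 * i + 1]
        if 2 * i + 2 < len(nodes):
            node.right = nodes[2 * i + 2]
    return nodes[0]


def peak_queue_length(root):
    """Run a level-order traversal and report the largest queue size seen."""
    q = deque([root])
    peak = len(q)
    while q:
        node = q.popleft()
        for child in (node.left, node.right):
            if child is not None:
                q.append(child)
                peak = max(peak, len(q))
    return peak


for depth in range(1, 6):
    n = 2 ** depth - 1
    print(n, peak_queue_length(perfect_tree(depth)))
# prints, one pair per line: 1 1, 3 2, 7 4, 15 8, 31 16
```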