Kth Smallest Integer in BST
Problem Statement
Given the root of a binary search tree and an integer k, return the kth smallest value (1-indexed) in the tree.
A binary search tree satisfies the following constraints:
- The left subtree of every node contains only nodes with keys less than the node’s key.
- The right subtree of every node contains only nodes with keys greater than the node’s key.
- Both the left and right subtrees are also binary search trees.
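The subtree condition is stricter than comparing each node with its immediate children. Here is a minimal sketch, assuming a hypothetical `is_bst` helper that is not part of the problem. It checks the definition by passing down the open range (low, high) that every key in a subtree must fall into:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def is_bst(node: Optional[TreeNode], low: float = float("-inf"),
           high: float = float("inf")) -> bool:
    """Return True if every key in this subtree lies strictly between low and high."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    # Keys in the left subtree must stay below node.val; keys on the right, above it.
    return is_bst(node.left, low, node.val) and is_bst(node.right, node.val, high)


# Valid: 2 with children 1 and 3.
valid = TreeNode(2, TreeNode(1), TreeNode(3))
# Invalid: 4 is a fine right child of 1 locally, but it sits in 3's left subtree.
invalid = TreeNode(3, TreeNode(1, None, TreeNode(4)), TreeNode(5))

print(is_bst(valid))    # True
print(is_bst(invalid))  # False
```

Because the comparisons are strict, duplicate keys are rejected. This matches the "less than" and "greater than" wording of the constraints.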
Example 1:
Input: root = [2,1,3], k = 1
Output: 1
Example 2:
Input: root = [4,3,5,2,null], k = 4
Output: 5
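The bracket notation in the examples appears to be the common level-order encoding, with `null` marking a missing child. The statement does not define it, so this reading is an assumption. Below is a minimal sketch of a hypothetical `build_tree` helper that turns such a list into `TreeNode` objects, which is handy for running the solutions locally:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry is this node's left child, the one after it its right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([4, 3, 5, 2, None])  # Example 2
print(root.val, root.left.val, root.right.val, root.left.left.val)  # 4 3 5 2
print(root.left.right)  # None (the "null" entry)
```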
Constraints:
1 <= k <= the number of nodes in the tree <= 1000
0 <= Node.val <= 1000
Recommended Time and Space Complexity
You should aim for a solution as good as or better than O(n) time and O(n) space, where n is the number of nodes in the given tree.
Hint 1
A naive solution would be to store the node values in an array, sort it, and then return the k-th value from the sorted array. This would be an O(n log n) solution due to sorting. Can you think of a better way? Maybe you should try one of the traversal techniques.
Hint 2
We can use Depth First Search (DFS) to traverse the tree. Since the tree is a Binary Search Tree (BST), we can leverage its structure and perform an in-order traversal: first visit the left subtree, then the current node, and finally the right subtree. Why? Because we need the k-th smallest integer, and visiting the left subtree first ensures that we encounter every smaller value before the current node. How can you implement this?
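Because a BST's in-order visit order is ascending, the k-th value the traversal produces is the answer. A minimal sketch makes both points visible. It uses a Python generator, a stylistic alternative to the counter in the next hint, and `kth_smallest_lazy` is a hypothetical name:

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield values in left subtree -> node -> right subtree order."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values; next() pulls the k-th and then
    # stops, so the generator does no work past the answer.
    return next(islice(inorder(root), k - 1, None))


# Example 2: root = [4,3,5,2,null]
root = TreeNode(4, TreeNode(3, TreeNode(2)), TreeNode(5))
print(list(inorder(root)))         # [2, 3, 4, 5]
print(kth_smallest_lazy(root, 4))  # 5
```

Like the recursive solutions, this nests one call per tree level, so a fully skewed tree near the 1000-node limit may approach Python's default recursion limit.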
Hint 3
We keep a counter variable cnt to track the position of the current node in the ascending order of values. When cnt == k, we store the current node’s value in a variable that lives outside the recursion (for example, a global) and return. This lets us identify and return the k-th smallest element during the in-order traversal.
Solution
Brute Force
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    arr = []

    # Collect every value; the visit order does not matter because we sort afterwards.
    def dfs(node):
        if not node:
            return
        arr.append(node.val)
        dfs(node.left)
        dfs(node.right)

    dfs(root)
    arr.sort()
    return arr[k - 1]  # k is 1-indexed
Time complexity: O(n log n)
Space complexity: O(n)
Inorder Traversal
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    arr = []

    # In-order traversal of a BST appends values in ascending order.
    def dfs(node):
        if not node:
            return
        dfs(node.left)
        arr.append(node.val)
        dfs(node.right)

    dfs(root)
    return arr[k - 1]  # already sorted, so no sort is needed
Time complexity: O(n)
Space complexity: O(n)
Recursive DFS (Optimal)
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    cnt = k         # values still to visit before reaching the answer
    res = root.val  # root is non-empty because 1 <= k <= number of nodes

    def dfs(node):
        nonlocal cnt, res
        # Stop at empty subtrees, and skip all remaining work once the answer is found.
        if not node or cnt <= 0:
            return
        dfs(node.left)
        cnt -= 1
        if cnt == 0:
            res = node.val
            return
        dfs(node.right)

    dfs(root)
    return res
Time complexity: O(n)
Space complexity: O(n)
Iterative DFS (Optimal)
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    stack = []
    curr = root

    while stack or curr:
        # Push the whole left spine; the smallest unvisited value ends up on top.
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        k -= 1
        if k == 0:
            return curr.val
        # Continue with the right subtree, which holds the next larger values.
        curr = curr.right

    return -1  # unreachable when 1 <= k <= number of nodes
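Time complexity: O(h + k), where h is the height of the tree; O(n) in the worst case
Space complexity: O(h) for the stack; O(n) for a completely skewed tree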
Morris Traversal
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    curr = root
    res = -1
    # Keep walking after the answer is found: returning early would leave
    # temporary threads (pred.right = curr) in the tree, creating cycles.
    while curr:
        if not curr.left:
            k -= 1
            if k == 0:
                res = curr.val
            curr = curr.right
        else:
            # Find the in-order predecessor: rightmost node of the left subtree.
            pred = curr.left
            while pred.right and pred.right is not curr:
                pred = pred.right

            if not pred.right:
                # First visit: thread the predecessor back to curr, go left.
                pred.right = curr
                curr = curr.left
            else:
                # Second visit: left subtree is done, remove the thread, visit curr.
                pred.right = None
                k -= 1
                if k == 0:
                    res = curr.val
                curr = curr.right

    return res
Time complexity: O(n)
Space complexity: O(1)