from typing import List
Problem Description
Difficulty: Medium
You are given an m × n integer array matrix and an integer target with the following properties:
- Each row in matrix is sorted in non-decreasing order.
- The first integer of each row is greater than the last integer of the previous row.
Return true if target exists in matrix, or false otherwise.
The solution must run in \(O(\log(m \cdot n))\) time.
Examples:
Input: matrix = [[1,3,5,8],[10,12,16,20],[23,30,36,50]], target = 3
Output: True
Input: matrix = [[1,2,3,7],[10,11,12,13],[14,20,30,40]], target = 15
Output: False
Constraints:
m == matrix.length
n == matrix[i].length
1 <= m, n <= 100
-10000 <= matrix[i][j], target <= 10000
Initial Solution
class Solution:
    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        top = 0
        bottom = len(matrix) - 1

        while top < bottom:
            midpoint_row = top + ((bottom - top) // 2)

            if matrix[midpoint_row][0] == target:
                return True
            elif matrix[midpoint_row][-1] < target:
                top = midpoint_row + 1
            else:
                bottom = midpoint_row

        row = matrix[top]

        l = 0
        r = len(row) - 1

        while l <= r:
            midpoint = l + ((r - l) // 2)

            if row[midpoint] == target:
                return True
            elif row[midpoint] < target:
                l = midpoint + 1
            else:
                r = midpoint - 1

        return False
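The same two-step idea (pick the row, then search inside it) can also be sketched with the standard library's bisect module. This is only an illustration and not the solution above: the name search_matrix_bisect is hypothetical, and copying the first column costs \(O(m)\), so this version runs in \(O(m + \log n)\) rather than \(O(\log m + \log n)\).

```python
from bisect import bisect_left, bisect_right
from typing import List


def search_matrix_bisect(matrix: List[List[int]], target: int) -> bool:
    """Hypothetical helper: locate the row, then the column, with bisect."""
    # Copying the first column is O(m); done here only for readability.
    first_column = [row[0] for row in matrix]

    # Index of the last row whose first element is <= target.
    row_index = bisect_right(first_column, target) - 1
    if row_index < 0:
        # target is smaller than every element in the matrix.
        return False

    row = matrix[row_index]
    # Leftmost position at which target could sit in the sorted row.
    col_index = bisect_left(row, target)
    return col_index < len(row) and row[col_index] == target
```

It is handy as a second opinion when checking a hand-written binary search, since the boundary handling is delegated to bisect.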
Initial Results:
The left-to-right search after finding the target row was basically copied from a previous binary search problem, so that was relatively easy. It took me a little time to figure out the row search, but I got there eventually. The code above runs with the following worst-case complexity:
Complexity
- Time complexity: \(O(\log m + \log n)\) overall
- The first binary search narrows down the potential row in \(O(\log m)\) time.
- The second binary search finds the target within that row in \(O(\log n)\) time.
- Space complexity: \(O(1)\)
- Only integer variables are used (top, bottom, midpoint_row, l, r), so it operates in constant space.
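The two time bounds combine into a single logarithmic bound, because a sum of logarithms is the logarithm of a product. As a rough worked check, assuming the largest input the constraints allow (m = n = 100), the first term below bounds the row loop and the second bounds the column loop:

```latex
\[
O(\log m + \log n) = O(\log(m \cdot n)),
\qquad
\lceil \log_2 100 \rceil + \left( \lfloor \log_2 100 \rfloor + 1 \right) = 7 + 7 = 14 .
\]
```

So even a 100 × 100 matrix with 10,000 elements needs at most 14 loop iterations in total.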
Test Function
def test_searchMatrix(solution_class):
    """ Test function for searchMatrix. Not extensive. """
    # Instantiate the class
    solution = solution_class()

    # Define the test cases
    test_cases = [
        {"matrix": [[1, 3, 5, 8], [10, 12, 16, 20], [23, 30, 36, 50]],
         "target": 3,
         "expected": True
        },
        {"matrix": [[1, 2, 3, 7], [10, 11, 12, 13], [14, 20, 30, 40]],
         "target": 15,
         "expected": False
        },
        {"matrix": [[1, 2, 3, 7], [10, 11, 12, 13], [14, 15, 30, 40]],
         "target": 15,
         "expected": True
        },
        {"matrix": [[1], [3]],
         "target": 3,
         "expected": True
        },
        {"matrix": [[1], [3]],
         "target": 0,
         "expected": False
        },
        {"matrix": [[-10,-8,-6,-4,-3],[0,2,3,4,6],[8,9,10,10,12]],
         "target": 0,
         "expected": True
        }
    ]

    # Iterate over test cases
    for i, test_case in enumerate(test_cases):
        matrix = test_case["matrix"]
        target = test_case["target"]
        expected = test_case["expected"]

        # Get the result from the searchMatrix method
        result = solution.searchMatrix(matrix, target)

        # Check if the result matches the expected output
        if result == expected:
            print(f"Test case {i+1} passed")
        else:
            print(f"Test case {i+1} failed")
            print(f"Expected: {expected}")
            print(f"Got: {result}")

test_searchMatrix(Solution)
Test case 1 passed
Test case 2 passed
Test case 3 passed
Test case 4 passed
Test case 5 passed
Test case 6 passed
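Since the test function is not extensive, a randomized cross-check against a plain linear scan is one way to gain more confidence. The sketch below is my own addition: make_matrix and cross_check are hypothetical helper names, and the matrix sizes, trial count, and seed are arbitrary choices.

```python
import random
from typing import Callable, List


def make_matrix(rng: random.Random, rows: int, cols: int) -> List[List[int]]:
    """Build a matrix that satisfies both properties from the problem statement."""
    # Distinct sorted values give sorted rows and strictly increasing row starts.
    values = sorted(rng.sample(range(-10000, 10001), rows * cols))
    return [values[r * cols:(r + 1) * cols] for r in range(rows)]


def cross_check(search: Callable[[List[List[int]], int], bool],
                trials: int = 200, seed: int = 0) -> bool:
    """Return True if `search` agrees with a linear scan on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        matrix = make_matrix(rng, rng.randint(1, 8), rng.randint(1, 8))
        if rng.random() < 0.5:
            # Pick a value that is definitely present.
            target = rng.choice(rng.choice(matrix))
        else:
            # Pick any value in the allowed range (usually absent).
            target = rng.randint(-10000, 10000)
        expected = any(target in row for row in matrix)
        if search(matrix, target) != expected:
            return False
    return True
```

Calling cross_check(Solution().searchMatrix) should return True for a correct implementation.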
NeetCode Solution
NeetCode has a few solutions on their website, including one fairly similar to the binary search implemented above. Theirs was a little better than mine, breaking out of the first while loop once a candidate row is found and returning early if no row can contain the target:
if not (top <= bot):
    return False
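To show where that check fits, here is a sketch of a row search with an early exit. It is my reconstruction under assumptions, not NeetCode's code verbatim: the helper name find_row is hypothetical, and I assume the loop runs while top <= bot and compares target against both ends of the middle row.

```python
from typing import List, Optional


def find_row(matrix: List[List[int]], target: int) -> Optional[int]:
    """Hypothetical helper: index of the row whose range covers target, else None."""
    top, bot = 0, len(matrix) - 1

    while top <= bot:
        mid = top + (bot - top) // 2
        if target > matrix[mid][-1]:
            # target lies below this row.
            top = mid + 1
        elif target < matrix[mid][0]:
            # target lies above this row.
            bot = mid - 1
        else:
            # matrix[mid][0] <= target <= matrix[mid][-1]
            return mid

    # The pointers crossed: no row can contain target, so the
    # second binary search can be skipped entirely.
    return None
```

Returning None here plays the same role as the `if not (top <= bot)` check: a target that falls between two rows, or outside the matrix range, never reaches the column search.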
They also had a nice one-pass solution, in which the 2D matrix is conceptually flattened into a 1D (sorted) array:
class Solution:
    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        # Get the number of rows and columns in the matrix
        ROWS, COLS = len(matrix), len(matrix[0])

        # Define left and right pointers for binary search on a virtual 1D array
        l, r = 0, ROWS * COLS - 1

        # Perform binary search
        while l <= r:
            # Calculate the midpoint index in the virtual 1D array
            m = l + (r - l) // 2

            # Convert the 1D index into 2D row and column indices
            row, col = m // COLS, m % COLS

            # Compare target with the current element at (row, col)
            if target > matrix[row][col]:
                # If target is greater, move to the right half
                l = m + 1
            elif target < matrix[row][col]:
                # If target is smaller, move to the left half
                r = m - 1
            else:
                # Found the target, return True
                return True

        # Target not found, return False
        return False
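The index conversion is the only new idea in this version, so it is worth checking on its own. A small sketch using the first example matrix (the variable names are mine):

```python
matrix = [[1, 3, 5, 8], [10, 12, 16, 20], [23, 30, 36, 50]]
rows, cols = len(matrix), len(matrix[0])

# Read the matrix through a 1D index i, exactly as the binary search does.
flattened = [matrix[i // cols][i % cols] for i in range(rows * cols)]

# The virtual 1D array is sorted, which is what makes one binary search enough.
is_sorted = flattened == sorted(flattened)

# divmod(i, cols) returns the same (row, col) pair in a single call.
row, col = divmod(5, cols)  # (1, 1), which holds the value 12
```

Because each row starts above the end of the previous one, reading the rows one after another always gives a sorted sequence.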