from typing import List
Problem Description
Difficulty: Medium
Given an integer array nums, return all the triplets [nums[i], nums[j], nums[k]] where nums[i] + nums[j] + nums[k] == 0, and the indices i != j, i != k, and j != k.
The solution should not contain duplicate triplets. Output order does not matter.
Example 1:
Input: nums = [-1,0,1,2,-1,-4]
Output: [[-1,-1,2],[-1,0,1]]
Explanation:
nums[0] + nums[1] + nums[2] = (-1) + 0 + 1 = 0
nums[1] + nums[2] + nums[4] = 0 + 1 + (-1) = 0
nums[0] + nums[3] + nums[4] = (-1) + 2 + (-1) = 0
The distinct triplets are [-1,0,1] and [-1,-1,2].
Example 2:
Input: nums = [0,1,1]
Output: []
Explanation: The only possible triplet does not sum up to 0.
Example 3:
Input: nums = [0,0,0]
Output: [[0,0,0]]
Explanation: The only possible triplet sums up to 0.
Constraints:
3 <= nums.length <= 3000 # LeetCode
3 <= nums.length <= 1000 # NeetCode
-10^5 <= nums[i] <= 10^5
Initial Solution
class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        """Return every unique triplet of values in ``nums`` that sums to zero.

        Initial attempt: sort the array, then for each anchor value sweep two
        pointers (``l`` and ``r``) inward over the remainder. Duplicate
        triplets are filtered with a linear ``in`` scan of the result list,
        which is what makes this version slow on large inputs.

        Note: ``nums`` is sorted in place.

        Args:
            nums: The input integers.

        Returns:
            A list of ``[a, b, c]`` lists with ``a <= b <= c`` and
            ``a + b + c == 0``, each appearing once.
        """
        triplets = []
        # O(n log n) sorting
        nums.sort()
        for idx, value in enumerate(nums):
            l = idx + 1
            r = len(nums) - 1
            # skip duplicate elements:
            if idx > 0 and value == nums[idx - 1]:
                continue
            while l < r:
                triplet = [value, nums[l], nums[r]]
                triplet_sum = sum(triplet)
                if triplet_sum > 0:
                    # Sum too large: move the right pointer to a smaller value.
                    r -= 1
                elif triplet_sum < 0:
                    # Sum too small: move the left pointer to a larger value.
                    l += 1
                else:
                    # inefficient check for duplicates: O(len(triplets)) list scan.
                    # Still needed because equal values under the left pointer
                    # can rebuild the same triplet.
                    if triplet not in triplets:
                        triplets.append(triplet)
                    # Always advance, otherwise the loop would never terminate.
                    l += 1
        return triplets
Test function
def test_three_sum(solution_class):
""" Test of threeSum function. Not extensive."""
= solution_class()
solution = [
test_cases
{"nums": [],
"expected": []
},
{"nums": [0, 1, 1],
"expected": []
},
{"nums": [0, 0, 0],
"expected": [[0, 0, 0]]
},
{"nums": [0, 0, 0, 0],
"expected": [[0, 0, 0]]
},
{"nums": [-1, 0, 1, 2, -1, -4],
"expected": [[-1, -1, 2], [-1, 0, 1]]
},
{"nums": [-2, 0, 1, 1, 2],
"expected": [[-2, 0, 2], [-2, 1, 1]]
},
{"nums": [-1, 0, 1, 2, -1, -4, -2, -3, 3, 0, 4],
"expected": [[-4, 0, 4], [-4, 1, 3], [-3, -1, 4], [-3, 0, 3],
-3, 1, 2], [-2, -1, 3], [-2, 0, 2], [-1, -1, 2], [-1, 0, 1]]
[
}
]
for i, test_case in enumerate(test_cases):
= test_case["nums"]
nums = test_case["expected"]
expected = solution.threeSum(nums)
results
# Sort both results and expected for comparison
= sorted([sorted(triplet) for triplet in results])
results_sorted = sorted([sorted(triplet) for triplet in expected])
expected_sorted
if results_sorted == expected_sorted:
print(f"Test case {i+1} passed")
else:
print(f"Test case {i+1} failed")
print(f"Expected: {expected_sorted}")
print(f"Got: {results_sorted}")
# Run the (non-exhaustive) checks against the Solution class defined above.
test_three_sum(solution_class=Solution)
Test case 1 passed
Test case 2 passed
Test case 3 passed
Test case 4 passed
Test case 5 passed
Test case 6 passed
Test case 7 passed
Initial Results
My initial attempt passed NeetCode but is too slow to pass LeetCode, failing the last 2 of 313 testcases. I also created a somewhat messy algorithm that used a hashmap to store sums from pairs of values in an attempt to trade off some memory for time, but that failed the speed test too.
Having wasted enough time on this problem, I gave up and looked at the NeetCode solution. However, it later occurred to me that I could store the triplets in a set instead of a list: then there is no need to check whether a triplet has already been found, because the set discards duplicates automatically. The slightly improved algorithm is below.
class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        """Return every unique triplet of values in ``nums`` that sums to zero.

        Same sort + two-pointer sweep as the initial attempt, but found
        triplets are stored as tuples in a set, so duplicates are discarded
        by hashing instead of a linear scan of a result list.

        Note: ``nums`` is sorted in place, and because the triplets come out
        of a set, the order of the returned list is unspecified.

        Args:
            nums: The input integers.

        Returns:
            A list of ``[a, b, c]`` lists with ``a <= b <= c`` and
            ``a + b + c == 0``, each appearing once.
        """
        triplets = set()  # Use a set to store unique triplets
        # O(n log n) sorting
        nums.sort()
        for idx, value in enumerate(nums):  # O(n)
            l = idx + 1
            r = len(nums) - 1
            # Skip duplicate elements to avoid duplicate triplets
            if idx > 0 and value == nums[idx - 1]:
                continue
            while l < r:  # O(n) for each iteration
                # Use tuple for immutability (needed by set)
                triplet = (value, nums[l], nums[r])
                triplet_sum = sum(triplet)
                if triplet_sum > 0:
                    r -= 1
                elif triplet_sum < 0:
                    l += 1
                else:
                    # O(1) on average
                    triplets.add(triplet)
                    l += 1
        return list(map(list, triplets))  # Convert tuple back to list
This one finally passed the LeetCode submission, but is unfortunately very slow. It beats a mere 19% on runtime and 5% on memory. Pretty terrible, but good enough to pass. Worst-case complexity is:
- Time complexity: \(O(n^2)\) due to the nested loops
- Space complexity: \(O(n^2)\) due to storing the results list (also the set)
NeetCode Solution
class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        """Return every unique triplet of values in ``nums`` that sums to zero.

        NeetCode's version: sort, then for each anchor ``a`` sweep two
        pointers inward. Duplicates are avoided structurally (skipping equal
        anchors and equal left-pointer values), so no set or membership check
        is needed. The outer loop stops as soon as the anchor is positive.

        Note: ``nums`` is sorted in place.

        Args:
            nums: The input integers.

        Returns:
            A list of ``[a, b, c]`` lists with ``a <= b <= c`` and
            ``a + b + c == 0``, each appearing once, in ascending order.
        """
        result = []
        nums.sort()

        for i, a in enumerate(nums):
            # Sorted array: every later value is also > 0, so no zero sum remains.
            if a > 0:
                break
            # Skip a repeated anchor value; its triplets were already found.
            if i > 0 and a == nums[i - 1]:
                continue

            l, r = i + 1, len(nums) - 1
            while l < r:
                threeSum = a + nums[l] + nums[r]
                if threeSum > 0:
                    r -= 1
                elif threeSum < 0:
                    l += 1
                else:
                    result.append([a, nums[l], nums[r]])
                    l += 1
                    r -= 1
                    # Move past left values equal to the one just used, so the
                    # same triplet is not emitted again.
                    while nums[l] == nums[l - 1] and l < r:
                        l += 1
        return result
NeetCode’s solution has two nice optimizations over my code:
- Early exit:
if a > 0:
break
Once the current number a in nums becomes positive, the loop terminates early. Since the array is sorted, every number after a is also positive, so it's impossible to find two other numbers whose sum, along with a, equals zero.
- Efficient handling of duplicate triplets:
while nums[l] == nums[l - 1] and l < r:
l += 1
After finding a valid triplet, the left pointer is incremented (and the right pointer decremented), and the additional while loop skips over any values equal to the previous nums[l]. This prevents the same triplet from being added to the result multiple times, without needing a set or a membership check.
On LeetCode, this nice algorithm beats 98% on runtime and 85% on memory. Worst-case complexity is:
- Time complexity: \(O(n^2)\)
- Space complexity: \(O(n^2)\)