LeetCode 15: Three Sum

code
two pointers
medium
Author

spa-dev

Published

September 15, 2024

Problem Description

Difficulty: Medium

Given an integer array nums, return all the triplets [nums[i], nums[j], nums[k]] where nums[i] + nums[j] + nums[k] == 0, and the indices i != j, i != k, and j != k.

The solution should not contain duplicate triplets. Output order does not matter.

Example 1:

Input: nums = [-1,0,1,2,-1,-4]
Output: [[-1,-1,2],[-1,0,1]]

Explanation:

nums[0] + nums[1] + nums[2] = (-1) + 0 + 1 = 0
nums[1] + nums[2] + nums[4] = 0 + 1 + (-1) = 0
nums[0] + nums[3] + nums[4] = (-1) + 2 + (-1) = 0

The distinct triplets are [-1,0,1] and [-1,-1,2].

Example 2:

Input: nums = [0,1,1]
Output: []

Explanation: The only possible triplet does not sum up to 0.

Example 3:

Input: nums = [0,0,0]
Output: [[0,0,0]]

Explanation: The only possible triplet sums up to 0.

Constraints:

3 <= nums.length <= 3000 (LeetCode)
3 <= nums.length <= 1000 (NeetCode)
-10^5 <= nums[i] <= 10^5

Initial Solution

from typing import List
class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        triplets = []
        nums.sort() # O(n log n) sorting
        
        for idx, value in enumerate(nums):
            l = idx + 1
            r = len(nums) - 1

            # skip duplicate elements:
            if idx > 0 and value == nums[idx - 1]:
                continue
            
            while l < r:
                triplet = [value, nums[l], nums[r]]
                triplet_sum = sum(triplet)
                
                if triplet_sum > 0:
                    r -= 1
                elif triplet_sum < 0:
                    l += 1
                else:
                    # inefficient check for duplicates:
                    if triplet not in triplets:
                        triplets.append(triplet)
                    l += 1
                                    
        return triplets

Test function

def test_three_sum(solution_class):
    """ Test of threeSum function. Not extensive."""
    solution = solution_class()
    test_cases = [
        {
            "nums": [],
            "expected": []
        },
        {
            "nums": [0, 1, 1],
            "expected": []
        },
        {
            "nums": [0, 0, 0],
            "expected": [[0, 0, 0]]
        },
        {
            "nums": [0, 0, 0, 0],
            "expected": [[0, 0, 0]]
        },
        {
            "nums": [-1, 0, 1, 2, -1, -4],
            "expected": [[-1, -1, 2], [-1, 0, 1]]
        },
        {
            "nums": [-2, 0, 1, 1, 2],
            "expected": [[-2, 0, 2], [-2, 1, 1]]
        },
        {
            "nums": [-1, 0, 1, 2, -1, -4, -2, -3, 3, 0, 4],
            "expected": [[-4, 0, 4], [-4, 1, 3], [-3, -1, 4], [-3, 0, 3], 
                         [-3, 1, 2], [-2, -1, 3], [-2, 0, 2], [-1, -1, 2], [-1, 0, 1]]
        }
    ] 

    for i, test_case in enumerate(test_cases):
        nums = test_case["nums"]
        expected = test_case["expected"]
        results = solution.threeSum(nums)
        
        # Sort both results and expected for comparison
        results_sorted = sorted([sorted(triplet) for triplet in results])
        expected_sorted = sorted([sorted(triplet) for triplet in expected])

        if results_sorted == expected_sorted:
            print(f"Test case {i+1} passed")
        else:
            print(f"Test case {i+1} failed")
            print(f"Expected: {expected_sorted}")
            print(f"Got: {results_sorted}")
test_three_sum(Solution)
Test case 1 passed
Test case 2 passed
Test case 3 passed
Test case 4 passed
Test case 5 passed
Test case 6 passed
Test case 7 passed
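As its docstring says, the test function above is not extensive. One way to stress any version in this post is to compare it against a brute-force oracle on many small random inputs. The following is a sketch under stated assumptions, not part of the original post. The names three_sum_brute, normalise and cross_check are hypothetical helpers, and the trial count, input length and value range are arbitrary choices that keep the O(n^3) oracle fast.

```python
import random
from itertools import combinations
from typing import Callable, List, Optional


def three_sum_brute(nums: List[int]) -> List[List[int]]:
    """O(n^3) reference: try every index triple, dedupe by sorted values."""
    found = set()
    for triple in combinations(nums, 3):
        if sum(triple) == 0:
            found.add(tuple(sorted(triple)))
    return [list(t) for t in sorted(found)]


def normalise(triplets: List[List[int]]) -> List[List[int]]:
    """Canonical form: sort inside each triplet, then sort the triplets."""
    return sorted(sorted(t) for t in triplets)


def cross_check(candidate: Callable[[List[int]], List[List[int]]],
                trials: int = 500, max_len: int = 12, max_abs: int = 5,
                seed: int = 0) -> Optional[List[int]]:
    """Compare candidate with the brute-force oracle on random small inputs.

    Returns the first input on which they disagree, or None if all agree.
    Small values (|x| <= max_abs) make zero sums and duplicates common.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        length = rng.randint(3, max_len)
        nums = [rng.randint(-max_abs, max_abs) for _ in range(length)]
        expected = normalise(three_sum_brute(nums))
        got = normalise(candidate(list(nums)))  # copy: solutions sort in place
        if got != expected:
            return nums
    return None
```

For example, cross_check(Solution().threeSum) returns None when a version agrees with the oracle on every random input; otherwise it returns the first input that exposes a difference, including duplicate triplets in the candidate's output.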

Initial Results

My initial attempt passed on NeetCode but was too slow for LeetCode, failing the last 2 of 313 test cases. I also wrote a somewhat messy algorithm that used a hashmap to store the sums of pairs of values, trading some memory for time, but it failed the speed test too.
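The hashmap attempt itself is not shown, so this is not a reconstruction of it. Instead, it is a sketch of a different, commonly used hash-based formulation: fix the first value, then solve a two-sum problem over the rest of the sorted array using a set of values already seen. It is still O(n^2) time, with O(n) extra space for the set. The name three_sum_hash is hypothetical.

```python
from typing import List


def three_sum_hash(nums: List[int]) -> List[List[int]]:
    """3Sum via a hash set per fixed first value: O(n^2) time, O(n) extra space."""
    nums = sorted(nums)  # sorting lets duplicates be skipped by comparing neighbours
    n = len(nums)
    result: List[List[int]] = []
    for i in range(n):
        if nums[i] > 0:
            break  # all remaining values are positive
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # same first value as the previous pass
        seen = set()  # values at indices i+1 .. j-1
        j = i + 1
        while j < n:
            complement = -nums[i] - nums[j]
            if complement in seen:
                result.append([nums[i], complement, nums[j]])
                # skip equal nums[j] values so this triplet is not added again
                while j + 1 < n and nums[j] == nums[j + 1]:
                    j += 1
            seen.add(nums[j])
            j += 1
    return result
```

Because the complement always comes from an earlier index than j, each appended triplet is already in ascending order.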

Having wasted enough time on this problem, I gave up and looked at the NeetCode solution. However, it later occurred to me that I could store the triplets in a set instead of a list. A set discards duplicates on its own, which removes the "triplet not in triplets" check: a linear scan of the whole results list every time a match is found. The slightly improved algorithm is below.

class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        triplets = set()  # Use a set to store unique triplets
        nums.sort()  # O(n log n) sorting
        
        for idx, value in enumerate(nums):  # O(n)
            l = idx + 1
            r = len(nums) - 1

            # Skip duplicate elements to avoid duplicate triplets
            if idx > 0 and value == nums[idx - 1]:  
                continue
            
            while l < r:  # O(n) for each iteration
                triplet = (value, nums[l], nums[r])  # Use tuple for immutability (needed by set)
                triplet_sum = sum(triplet)
                
                if triplet_sum > 0:
                    r -= 1  
                elif triplet_sum < 0:
                    l += 1 
                else:
                    triplets.add(triplet)  # O(1) on average
                    l += 1  
                                 
        return list(map(list, triplets)) # Convert tuple back to list
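The switch from lists to tuples in the code above is required, not cosmetic. Set elements must be hashable, and Python lists are mutable and therefore unhashable. This short illustration uses arbitrary example triplets:

```python
triplets = set()
triplets.add((-1, 0, 1))
triplets.add((-1, 0, 1))  # same tuple again: the set keeps one copy
triplets.add((-1, -1, 2))
print(len(triplets))  # 2

list_rejected = False
try:
    triplets.add([0, 0, 0])  # lists are mutable, hence unhashable
except TypeError:
    list_rejected = True
print(list_rejected)  # True

# Convert back to lists, as the solution's return statement does.
as_lists = [list(t) for t in triplets]
print(sorted(as_lists))  # [[-1, -1, 2], [-1, 0, 1]]
```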

This version finally passed on LeetCode, but it is still slow: it beats a mere 19% of submissions on runtime and 5% on memory. Pretty terrible, but good enough to pass. Worst-case complexity is:

  • Time complexity: \(O(n^2)\) due to the nested loops
  • Space complexity: \(O(n^2)\) in the worst case, for the set of triplets and the output list built from it
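The quadratic space bound is not just a cost of using the set. For some inputs the answer itself contains on the order of n^2 triplets, so any algorithm that returns them all needs that much output space. One way to see this is to count the zero-sum triplets for nums = list(range(-m, m + 1)), where all n = 2m + 1 values are distinct. The helper count_zero_sum_triplets is a hypothetical illustration, not part of either solution.

```python
from itertools import combinations


def count_zero_sum_triplets(m: int) -> int:
    """Count distinct triplets a < b < c drawn from -m..m with a + b + c == 0.

    Every pair (a, b) fixes c = -(a + b), so the loop just checks that c
    lies above b and inside the range.
    """
    count = 0
    for a in range(-m, m + 1):
        for b in range(a + 1, m + 1):
            c = -(a + b)
            if b < c <= m:
                count += 1
    return count


# Sanity check against direct enumeration of all triples for small m.
for m in range(1, 7):
    brute = sum(1 for t in combinations(range(-m, m + 1), 3) if sum(t) == 0)
    assert count_zero_sum_triplets(m) == brute

# nums = list(range(-m, m + 1)) has n = 2m + 1 distinct values.
for m in (10, 20, 40, 80):
    print(f"n = {2 * m + 1:3d}: {count_zero_sum_triplets(m)} triplets")
```

The printed counts roughly quadruple each time n doubles. In the other direction, each pair of values fixes the third, so no input can produce more than \(O(n^2)\) distinct triplets. The bound is therefore tight.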

NeetCode Solution

class Solution:
    def threeSum(self, nums: List[int]) -> List[List[int]]:
        result = []
        nums.sort()

        for i, a in enumerate(nums):
            if a > 0:
                break

            if i > 0 and a == nums[i - 1]:
                continue

            l, r = i + 1, len(nums) - 1
            while l < r:
                threeSum = a + nums[l] + nums[r]
                if threeSum > 0:
                    r -= 1
                elif threeSum < 0:
                    l += 1
                else:
                    result.append([a, nums[l], nums[r]])
                    l += 1
                    r -= 1
                    while nums[l] == nums[l - 1] and l < r:
                        l += 1
                        
        return result

NeetCode’s solution has two nice optimizations over my code:

  1. Early exit:
if a > 0:
    break

When a, the current value in the outer loop, becomes positive, the loop terminates early. Since the array is sorted, every element after a is at least as large as a. Once a > 0, any triplet starting at a therefore has a positive sum, and no further zero-sum triplets can be found.

  2. Efficient handling of duplicate triplets:
while nums[l] == nums[l - 1] and l < r:
    l += 1

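After finding a valid triplet, both pointers move inward. The additional while loop then advances l past any values equal to the one just used. This prevents the same triplet from being added to the result more than once, so no set or membership check is needed. Skipping duplicates on the right is unnecessary: once nums[l] takes a new value, the value needed at r, -(a + nums[l]), changes too, so the old nums[r] cannot complete the same triplet again.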
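To see both optimizations in action, this sketch puts them behind two switches in a NeetCode-style two-pointer loop and counts inner-loop iterations. The function name and the early_exit / skip_duplicates flags are hypothetical and exist only for this experiment. With both switches on, it follows the same logic as the solution above, except that the l < r test comes first in the duplicate-skipping loop.

```python
from typing import List, Tuple


def three_sum_instrumented(nums: List[int], early_exit: bool = True,
                           skip_duplicates: bool = True) -> Tuple[List[List[int]], int]:
    """Two-pointer 3Sum with the two optimizations as switches.

    Returns (triplets, steps), where steps counts inner while-loop iterations.
    """
    nums = sorted(nums)
    result: List[List[int]] = []
    steps = 0
    for i, a in enumerate(nums):
        if early_exit and a > 0:
            break  # every later element is >= a > 0, so no zero sum remains
        if i > 0 and a == nums[i - 1]:
            continue  # same first value as the previous pass
        l, r = i + 1, len(nums) - 1
        while l < r:
            steps += 1
            total = a + nums[l] + nums[r]
            if total > 0:
                r -= 1
            elif total < 0:
                l += 1
            else:
                result.append([a, nums[l], nums[r]])
                l += 1
                r -= 1
                if skip_duplicates:
                    # move l past values equal to the one just used
                    while l < r and nums[l] == nums[l - 1]:
                        l += 1
    return result, steps


print(three_sum_instrumented([-2, 0, 0, 2, 2], skip_duplicates=False)[0])
# [[-2, 0, 2], [-2, 0, 2]]
print(three_sum_instrumented([-2, 0, 0, 2, 2])[0])
# [[-2, 0, 2]]
```

Turning early_exit off on an input with several positive values, such as [-1, 0, 1, 2, 3, 4, 5], returns the same triplets. The step count goes up, though, because the loop keeps trying first values that can never lead to a zero sum.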

On LeetCode, this nice algorithm beats 98% on runtime and 85% on memory. Worst-case complexity is:

  • Time complexity: \(O(n^2)\)
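  • Space complexity: \(O(n^2)\) in the worst case for the output list; apart from the output, only the \(O(n)\) auxiliary memory that Python's sort may use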