Friday, March 23, 2018

LeetCode 4. Median of Two Sorted Arrays

LeetCode 4

Yifeng Zeng

Description

There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1:
nums1 = [1, 3]
nums2 = [2]
The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5

Idea Report

Since both input arrays are sorted and the required time complexity is O(log(m+n)), binary search is the natural tool: at each step we want to throw out about half of the impossible candidates. The median can be obtained from the kth smallest element of the two arrays combined, where k depends on the total length len: for an odd len it is the (len / 2 + 1)-th smallest, and for an even len it is the average of the (len / 2)-th and (len / 2 + 1)-th smallest. So the problem reduces to finding the kth smallest element in two sorted arrays. To find it, we compare the (k/2)-th remaining element of each array (n1 and n2). It is safe to throw away the smaller one together with all the elements before it, because none of them can be the kth smallest. If one array has fewer than k/2 elements left, we treat its candidate as Integer.MAX_VALUE so that we discard from the other array instead. Then we search recursively from the new index, and since we threw away k/2 numbers, we only need to look for the (k - k/2)-th smallest among what remains. When k == 1, the result is the smaller of the two current elements.
Code:
class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int len = nums1.length + nums2.length;
        if (len % 2 == 1) {
            return findKth(nums1, nums2, 0, 0, len / 2 + 1);
        }
        return (findKth(nums1, nums2, 0, 0, len / 2)
                + findKth(nums1, nums2, 0, 0, len / 2 + 1)) / 2.0;
    }

    private int findKth(int[] nums1, int[] nums2, int i, int j, int k) {
        if (i >= nums1.length) {
            return nums2[j + k - 1];
        }
        if (j >= nums2.length) {
            return nums1[i + k - 1];
        }

        if (k == 1) {
            return Math.min(nums1[i], nums2[j]);
        }

        int n1 = i + k / 2 - 1 >= nums1.length ?
                 Integer.MAX_VALUE : nums1[i + k / 2 - 1];
        int n2 = j + k / 2 - 1 >= nums2.length ?
                 Integer.MAX_VALUE : nums2[j + k / 2 - 1];
        if (n1 < n2) {
            return findKth(nums1, nums2, i + k / 2, j, k - k / 2);
        }
        return findKth(nums1, nums2, i, j + k / 2, k - k / 2);
    }
}
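
The same halving idea can also be written as a loop. The following is only a sketch, not the submitted solution: the class name KthOfTwoSorted and the helper median are made up for illustration, and instead of the Integer.MAX_VALUE sentinel it clamps the step to the number of elements left in each array.

```java
final class KthOfTwoSorted {
    // Returns the k-th smallest (1-based) element of the two sorted arrays combined.
    // Assumes 1 <= k <= a.length + b.length.
    static int findKth(int[] a, int[] b, int k) {
        int i = 0; // first index of a that has not been discarded
        int j = 0; // first index of b that has not been discarded
        while (true) {
            if (i >= a.length) {
                return b[j + k - 1];
            }
            if (j >= b.length) {
                return a[i + k - 1];
            }
            if (k == 1) {
                return Math.min(a[i], b[j]);
            }
            // Clamp each step so we never read past the end of an array.
            int stepA = Math.min(k / 2, a.length - i);
            int stepB = Math.min(k / 2, b.length - j);
            if (a[i + stepA - 1] < b[j + stepB - 1]) {
                i += stepA; // these stepA elements cannot be the k-th smallest
                k -= stepA;
            } else {
                j += stepB;
                k -= stepB;
            }
        }
    }

    // Hypothetical helper: median of the two arrays, assuming at least one element in total.
    static double median(int[] a, int[] b) {
        int len = a.length + b.length;
        if (len % 2 == 1) {
            return findKth(a, b, len / 2 + 1);
        }
        // Widen to double before adding so large values do not overflow.
        return ((double) findKth(a, b, len / 2) + findKth(a, b, len / 2 + 1)) / 2.0;
    }
}
```

With the two examples from the description, median(new int[]{1, 3}, new int[]{2}) gives 2.0 and median(new int[]{1, 2}, new int[]{3, 4}) gives 2.5.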

Summary

  • Throw away half of the candidates at a time: the binary search methodology.

LeetCode 215. Kth Largest Element in an Array

LeetCode 215

Yifeng Zeng

Description

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.
For example,
Given [3,2,1,5,6,4] and k = 2, return 5.
Note: You may assume k is always valid, 1 ≤ k ≤ array's length.

Idea Report

My first idea for a "top k" problem is always to use a heap. We can maintain a min heap of size k that holds the k largest numbers seen so far. When a new number n comes in, we compare n with the smallest of the k numbers in the heap (heap.peek()). If n is larger, we poll the smallest number out of the min heap, because it can no longer be the kth largest number, and then put n in the min heap as a candidate. In the end, the smallest number in the min heap is the kth largest number in the array. This takes O(nlogk) time, because we have n numbers and each heap update takes O(logk) time.
Code
class Solution {
    // AC
    public int findKthLargest(int[] nums, int k) {
        Queue<Integer> pq = new PriorityQueue<>(); // min heap
        for (int n : nums) {
            if (pq.size() < k) {
                pq.offer(n);
            } else {
                if (pq.peek() < n) {
                    pq.poll();
                    pq.offer(n);
                }
            }
        }
        return pq.peek();
    }
}
Another way is to apply the binary search idea: throw out the half that cannot contain the answer and keep partitioning the other half, similar to quick sort. We first decrease k by one, so that we are looking for the element at index k as if the array were sorted in decreasing order. Like quick sort, each time we choose a pivot (I like to use the element in the middle between start and end), move numbers larger than the pivot to the left of the array, and move numbers smaller than the pivot to the right of the array. Then we check whether index k falls in the left part (larger than or equal to the pivot) or in the right part (smaller than or equal to the pivot). We choose the correct part, look for index k there recursively, and throw away the other part, which is no longer possible. Let me give an example with nums = [3,2,4,5,6] and k = 2. 's' is start, 'e' is end, 'p' is pivot, 'l' is left, 'r' is right. We decrease k by 1 to find the index k = 1.
  • Step 1, we see nums[l] = 3 <= p = 4, nums[r] = 6 >= p = 4, so we swap them.
  • Step 2, we see nums[l] = 2 <= p = 4, nums[r] = 5 >= p = 4, so we swap them.
  • Step 3, now l == r == 2, and we see nums[l] = 4 <= p = 4, nums[r] = 4 >= p = 4, so we swap them as well (the element is swapped with itself). After swapping, index r = 1 and index l = 3, and the array is split into three parts.
    • Part 1, [s,r] = [0,1], which has all elements larger than or equal to the pivot
    • Part 2, (r, l) = (1,3), i.e. index 2, which has all elements equal to the pivot
    • Part 3, [l,e] = [3,4], which has all elements smaller than or equal to the pivot
  • Step 4, comparing indices s, r, l, e with index k, we know which part index k belongs to, and we recursively search in that part.
  • Step 5, in our example k = 1, so k belongs to Part 1 and we search recursively in Part 1. Now s = 0, e = 1, p = 6. Because nums[r] = nums[1] = 5 < pivot = 6, we move r to the left.
  • Step 6, we see nums[l] = 6 <= p = 6, nums[r] = 6 >= p = 6, so we swap them, and now r = -1, l = 1.
  • Step 7, we check that k = 1 belongs to Part 3, [l,e] = [1,1], so we search recursively there.
  • Step 8, we find s == e = 1, so we just return nums[s] = 5, which is the final result. The time complexity is O(n) on average, because we throw away about half of the candidates at a time and only care about the other half, which has the correct answer. In the worst case, when the pivots keep splitting the array badly, it degrades to O(n^2).
1. [3,2,4,5,6], p = 4
    s       e
    l       r  -> swap
2. [6,2,4,5,3]
    s       e
      l   r    -> swap
3. [6,5,4,2,3]
    s        e
        lr     -> swap
4. [6,5,4,2,3]
    s        e
      r   l    -> throw away half
5. [6,5,4,2,3], p = 6
    s e
    l r        -> move r to left because nums[r] < pivot
6. [6,5,4,2,3], p = 6
    s e
    l
    r          -> swap
7. [6,5,4,2,3], p = 6
    s e
      l
  r = - 1      -> l <= k && k <= e, recursive
8. [6,5,4,2,3], p = 6
      s
      e        -> s == e, return nums[s] = 5
Code
class Solution {
    // AC
    public int findKthLargest(int[] nums, int k) {
        // k is the index now.
        return quickSelect(nums, 0, nums.length - 1, k - 1);
    }

    private int quickSelect(int[] nums, int start, int end, int k) {
        if (start == end) {
            return nums[start];
        }

        int left = start;
        int right = end;
        int mid = (end - start) / 2 + start;
        int pivot = nums[mid];

        while (left <= right) {
            while (left <= right && nums[left] > pivot) {
                left++;
            }
            while (left <= right && nums[right] < pivot) {
                right--;
            }
            if (left <= right) {
                swap(nums, left, right);
                left++;
                right--;
            }
        }

        if (start <= k && k <= right) {
            return quickSelect(nums, start, right, k);
        } else if (left <= k && k <= end) {
            return quickSelect(nums, left, end, k);
        }
        return nums[right + 1];
    }

    private void swap(int[] nums, int i, int j) {
        int t = nums[i];
        nums[i] = nums[j];
        nums[j] = t;
    }
}
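
The O(n) bound for quick select is an average; with a fixed middle pivot, a crafted input can push it toward O(n^2). A common hedge is to pick the pivot at random. The following is only a sketch of that variant, written as a loop instead of recursion; the class name RandomizedQuickSelect is made up for illustration, and the partition step is the same as in the code above.

```java
import java.util.concurrent.ThreadLocalRandom;

final class RandomizedQuickSelect {
    // Returns the k-th largest element (1-based). Reorders nums in place.
    // Assumes 1 <= k <= nums.length.
    static int findKthLargest(int[] nums, int k) {
        int target = k - 1; // index of the answer in decreasing order
        int start = 0;
        int end = nums.length - 1;
        while (start < end) {
            // Assumption: a uniformly random pivot instead of the middle element.
            int pivot = nums[ThreadLocalRandom.current().nextInt(start, end + 1)];
            int left = start;
            int right = end;
            while (left <= right) {
                while (left <= right && nums[left] > pivot) {
                    left++;
                }
                while (left <= right && nums[right] < pivot) {
                    right--;
                }
                if (left <= right) {
                    int t = nums[left];
                    nums[left] = nums[right];
                    nums[right] = t;
                    left++;
                    right--;
                }
            }
            if (target <= right) {
                end = right;         // answer is in the "larger or equal" part
            } else if (target >= left) {
                start = left;        // answer is in the "smaller or equal" part
            } else {
                return nums[target]; // target sits in the gap, which equals the pivot
            }
        }
        return nums[start];
    }
}
```

For the example in the description, findKthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2) returns 5 regardless of which pivots are drawn.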

Summary

  • A "top k" problem may be solved by maintaining a heap.
  • Throw away half of the candidates at a time: the binary search methodology.

LeetCode 307. Range Sum Query - Mutable

LeetCode 307

Yifeng Zeng

Description

Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

Idea Report

We are asked to get the sum of the elements in a range, so my first idea is to use a prefix sum array, int[] prefixSum. For the input array nums, prefixSum[i] is the sum of the first i elements of nums, and prefixSum[0] is 0. For example, for the input nums = [1,3,5], prefixSum is [0,1,4,9]: prefixSum[0] is 0, prefixSum[1] is the first element, which is 1, prefixSum[2] is the sum of the first 2 elements, which is 1 + 3 = 4, and prefixSum[3] is the sum of the first 3 elements, which is 1 + 3 + 5 = 9. To query the sum of the range from i to j, we just return prefixSum[j + 1] - prefixSum[i], which is the sum over indices [0, 1, 2,...,i,...,j] minus the sum over indices [0, 1, 2,...,i-1], exactly the sum over indices [i,i+1,...,j-1,j]. To update index i to the value val, we find all the prefix sums that include the number at index i, subtract the old value and add the new value. The old value can be obtained with sumRange(i, i). Building the prefix sum array takes O(n) time, sumRange(i, j) takes O(1) time, and updating index i takes O(n - i) time, which is essentially O(n).
Code
class NumArray {
    // TLE
    int[] prefixSum;
    public NumArray(int[] nums) {
        prefixSum = new int[nums.length + 1];
        for (int i = 1; i < prefixSum.length; i++) {
            prefixSum[i] = prefixSum[i - 1] + nums[i - 1];
        }
    }

    public void update(int i, int val) {
        int oldValue = sumRange(i, i);
        for (int j = i + 1; j < prefixSum.length; j++) {
            prefixSum[j] = prefixSum[j] - oldValue + val;
        }
    }

    public int sumRange(int i, int j) {
        return prefixSum[j + 1] - prefixSum[i];
    }
}
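
To make the bookkeeping concrete, here is a small sketch that replays the example from the description on a bare prefix sum array. The class PrefixSumDemo and its static helpers are hypothetical names used only for this illustration; the logic mirrors the class above.

```java
final class PrefixSumDemo {
    // prefix[i] holds the sum of the first i elements, so prefix[0] == 0.
    static int[] build(int[] nums) {
        int[] prefix = new int[nums.length + 1];
        for (int i = 1; i < prefix.length; i++) {
            prefix[i] = prefix[i - 1] + nums[i - 1];
        }
        return prefix;
    }

    // Sum of nums[i..j], inclusive, in O(1).
    static int sumRange(int[] prefix, int i, int j) {
        return prefix[j + 1] - prefix[i];
    }

    // Sets nums[i] = val by shifting every prefix sum that includes index i: O(n).
    static void update(int[] prefix, int i, int val) {
        int delta = val - sumRange(prefix, i, i);
        for (int j = i + 1; j < prefix.length; j++) {
            prefix[j] += delta;
        }
    }
}
```

For nums = [1, 3, 5], build gives [0, 1, 4, 9] and sumRange(prefix, 0, 2) is 9. After update(prefix, 1, 2) the array becomes [0, 1, 3, 8], so sumRange(prefix, 0, 2) is 8.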
The prefix sum method takes O(n) time to update, which is relatively time consuming when the input size is large. And the problem states that "calls to update and sumRange function is distributed evenly." so we need a better approach than O(n), namely O(logN). So we can think of a tree structure, an advanced data structure called a segment tree. Taking the input [1,3,4] as an example, we can draw the tree as follows:
              [0, 2 : 8]
            /           \
      [0, 1 : 4]         [2, 2 : 4]
      /        \
[0, 0 : 1]     [1, 1 : 3]
Each node has three integers and two children. The first integer is start, the start index of the range the node covers. The second integer is end, the end index of that range. The third integer is sum, the sum of the elements between start and end inclusive. For a leaf node, the sum is the element value at that index; for example, [1,1:3] is the node that starts at index 1 and ends at index 1, so its sum is nums[1] = 3. For other nodes, the sum covers a range; for example, [0,1:4] is the node that starts at index 0 and ends at index 1, so its sum is nums[0] + nums[1] = 4. To query, we check whether the query range i, j matches node.start, node.end. If it matches, we just return node.sum; if it doesn't match, we split the range into two parts and get the result from the left child, the right child, or both. For example, if we want to query [0,2], it matches the root and we return 8 directly. If we want to query [1,2], we see that node [0,2:8] doesn't match, so we divide the query into two halves: [1,1] from the left child and [2,2] from the right child (where we get [2,2:4] directly). At the left child [0,1:4], the range doesn't match [1,1], so we split again and reach [1,1:3]. We then add 3 + 4 to get query[1,2] = 7. To update the value at index i, we find the [i,i:value] node and update all the nodes along the path from the root to that node by subtracting the old value and adding the new value.
Code
class NumArray {

    public NumArray(int[] nums) {
        root = build(nums, 0, nums.length - 1);
    }

    public void update(int i, int val) {
        int oldVal = query(root, i, i);
        modify(root, i, val, oldVal);
    }

    public int sumRange(int i, int j) {
        return query(root, i, j);
    }

    class Node {
        int start;
        int end;
        int sum;
        Node left;
        Node right;
        public Node(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
        }
    }

    Node root;

    private Node build(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }

        Node node = new Node(start, end, 0);

        if (start == end) {
            node.sum = nums[start];
            return node;
        }

        int mid = (end - start) / 2 + start;
        node.left = build(nums, start, mid);
        node.right = build(nums, mid + 1, end);
        if (node.left != null) {
            node.sum += node.left.sum;
        }
        if (node.right != null) {
            node.sum += node.right.sum;
        }
        return node;
    }

    private int query(Node root, int start, int end) {
        if (start > end) {
            return 0;
        }
        if (start <= root.start && root.end <= end) {
            return root.sum;
        }
        int rootMid = (root.end - root.start) / 2 + root.start;
        if (end <= rootMid) {
            return query(root.left, start, end);
        } else if (rootMid < start) {
            return query(root.right, start, end);
        }
        return query(root.left, start, rootMid) + query(root.right, rootMid + 1, end);
    }

    private void modify(Node root, int i, int val, int oldVal) {
        if (root == null) {
            return;
        }
        root.sum = root.sum - oldVal + val;
        if (root.start == i && root.end == i) {
            return;
        }

        int rootMid = (root.end - root.start) / 2 + root.start;
        if (i <= rootMid) {
            modify(root.left, i, val, oldVal);
        } else {
            modify(root.right, i, val, oldVal);
        }
    }
}
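
The same tree can also be stored in a plain int array instead of linked Node objects. The following is only a sketch of that layout, with a hypothetical class name ArraySegmentTree: node v keeps its children at 2v and 2v + 1, and update recomputes each parent from its children on the way back up, so it does not need to query the old value first.

```java
final class ArraySegmentTree {
    private final int n;
    private final int[] tree; // tree[1] is the root; children of node v are 2v and 2v + 1

    ArraySegmentTree(int[] nums) {
        n = nums.length;
        tree = new int[Math.max(1, 4 * n)]; // 4n slots are enough for any n
        if (n > 0) {
            build(nums, 1, 0, n - 1);
        }
    }

    private void build(int[] nums, int node, int start, int end) {
        if (start == end) {
            tree[node] = nums[start];
            return;
        }
        int mid = start + (end - start) / 2;
        build(nums, 2 * node, start, mid);
        build(nums, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Sets the element at index i to val in O(log n).
    void update(int i, int val) {
        update(1, 0, n - 1, i, val);
    }

    private void update(int node, int start, int end, int i, int val) {
        if (start == end) {
            tree[node] = val;
            return;
        }
        int mid = start + (end - start) / 2;
        if (i <= mid) {
            update(2 * node, start, mid, i, val);
        } else {
            update(2 * node + 1, mid + 1, end, i, val);
        }
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Sum of the elements in [i, j], inclusive, in O(log n).
    int sumRange(int i, int j) {
        return query(1, 0, n - 1, i, j);
    }

    private int query(int node, int start, int end, int i, int j) {
        if (j < start || end < i) {
            return 0; // no overlap with the query range
        }
        if (i <= start && end <= j) {
            return tree[node]; // node range fully inside the query range
        }
        int mid = start + (end - start) / 2;
        return query(2 * node, start, mid, i, j)
             + query(2 * node + 1, mid + 1, end, i, j);
    }
}
```

With nums = [1, 3, 4], sumRange(1, 2) returns 7, matching the walk-through above; with nums = [1, 3, 5], sumRange(0, 2) is 9 and becomes 8 after update(1, 2).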

Summary

  • Prefix sum takes O(1) to get the sum of a range in the array, which would be a good choice if the array is immutable.
  • Segment tree takes O(n) time to build and O(logN) time to query and update.

LeetCode 436. Find Right Interval

LeetCode 436

Yifeng Zeng

Description

Given a set of intervals, for each of the interval i, check if there exists an interval j whose start point is bigger than or equal to the end point of the interval i, which can be called that j is on the "right" of i.
For any interval i, you need to store the minimum interval j's index, which means that the interval j has the minimum start point to build the "right" relationship for interval i. If the interval j doesn't exist, store -1 for the interval i. Finally, you need output the stored value of each interval as an array.
Note:
You may assume the interval's end point is always bigger than its start point.
You may assume none of these intervals have the same start point.
Example 3:
Input: [ [1,4], [2,3], [3,4] ]
Output: [-1, 2, -1]
Explanation:
There is no satisfied "right" interval for [1,4] and [3,4]. For [2,3], the interval [3,4] has minimum-"right" start point.

Idea Report

The primitive idea is to compare every pair of intervals, which takes O(n^2) time. To get to O(nlogn), we can think of sorting. But if we sort only by start, or only by end, we cannot find a good way to solve the problem, so we sort the start and end points together. For each point we store its coordinate (Point.i), which interval it belongs to (Point.intervalIndex), and whether it is a start or an end (Point.isStart). We sort the points by Point.i, and when two points have the same coordinate we put the end point before the start point. So for the input {[1,4],[2,3],[3,4]}, writing each point as [i, isStart, intervalIndex], we get the point list [1,1,0],[4,0,0],[2,1,1],[3,0,1],[3,1,2],[4,0,2], which after sorting is [1,1,0],[2,1,1],[3,0,1],[3,1,2],[4,0,0],[4,0,2]. We scan from the end to the beginning, with rightInterval initialized to -1. For an end point, we record the current rightInterval as the answer for its interval (res[cur.intervalIndex] = rightInterval); for a start point, we update rightInterval to its interval index (rightInterval = cur.intervalIndex).
Code
class Solution {
    class Point {
        int i;
        int isStart;
        int intervalIndex;
        public Point(int i, int isStart, int intervalIndex) {
            this.i = i;
            this.isStart = isStart; // 1 is true, 0 is false
            this.intervalIndex = intervalIndex;
        }
    }

    public int[] findRightInterval(Interval[] intervals) {
        List<Point> list = new ArrayList<>();
        for (int i = 0; i < intervals.length; i++) {
            list.add(new Point(intervals[i].start, 1, i));
            list.add(new Point(intervals[i].end, 0, i));
        }

        // Sort by coordinate; on ties, end points (0) come before start points (1).
        // Integer.compare avoids the overflow that a.i - b.i can hit.
        Collections.sort(list, (a, b) -> (a.i == b.i) ?
                         a.isStart - b.isStart : Integer.compare(a.i, b.i));

        int[] res = new int[intervals.length];
        int rightInterval = -1;
        for (int i = list.size() - 1; i >= 0; i--) {
            Point cur = list.get(i);
            if (cur.isStart == 0) {
                res[cur.intervalIndex] = rightInterval;
            } else {
                rightInterval = cur.intervalIndex;
            }
        }
        return res;
    }
}
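
The code above relies on LeetCode's Interval class, which is not defined here. The following is a self-contained sketch of the same sweep, assuming the intervals are given as int[][] pairs of {start, end} instead; the class name RightIntervalSweep and the int[] point encoding are choices made for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

final class RightIntervalSweep {
    static int[] findRightInterval(int[][] intervals) {
        // Each point is {coordinate, isStart (1 or 0), intervalIndex}.
        List<int[]> points = new ArrayList<>();
        for (int idx = 0; idx < intervals.length; idx++) {
            points.add(new int[] {intervals[idx][0], 1, idx});
            points.add(new int[] {intervals[idx][1], 0, idx});
        }

        // Sort by coordinate; on ties, end points (0) come before start points (1).
        points.sort((a, b) -> a[0] != b[0]
                ? Integer.compare(a[0], b[0])
                : Integer.compare(a[1], b[1]));

        int[] res = new int[intervals.length];
        int rightInterval = -1; // index of the nearest start point seen so far
        for (int p = points.size() - 1; p >= 0; p--) {
            int[] cur = points.get(p);
            if (cur[1] == 0) {
                res[cur[2]] = rightInterval; // end point: record the answer
            } else {
                rightInterval = cur[2];      // start point: becomes the new candidate
            }
        }
        return res;
    }
}
```

For the input {{1,4},{2,3},{3,4}} this returns [-1, 2, -1], the same as the expected output in the description.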

Summary

  • Sweepline method.
  • Similar to 252. Meeting Rooms, 253. Meeting Rooms II
  • Because of time limitations, I still need to add other approaches and revise the writing.

Follow up

Qinyuan introduced a two-pointer approach: an O(nlogn) sort followed by an O(n) scan.
class Solution {
    // AC
    class Point {
        int i;
        int intervalIndex;
        public Point(int i, int index) {
            this.i = i;
            this.intervalIndex = index;
        }
    }

    public int[] findRightInterval(Interval[] intervals) {
        List<Point> start = new ArrayList<>();
        List<Point> end = new ArrayList<>();
        for (int i = 0; i < intervals.length; i++) {
            start.add(new Point(intervals[i].start, i));
            end.add(new Point(intervals[i].end, i));
        }

        Collections.sort(start, (a, b) -> a.i - b.i);
        Collections.sort(end, (a, b) -> a.i - b.i);
        int[] res = new int[intervals.length];
        Arrays.fill(res, -1);
        int j = 0;
        for (int i = 0; i < end.size(); i++) {
            Point pointEnd = end.get(i);
            while (j < start.size() && pointEnd.i > start.get(j).i) {
                j++;
            }
            if (j >= start.size()) {
                break;
            }
            res[pointEnd.intervalIndex] = start.get(j).intervalIndex;
        }
        return res;
    }
}
TreeMap.ceilingEntry()
class Solution {
    public int[] findRightInterval(Interval[] intervals) {
        int[] res = new int[intervals.length];
        TreeMap<Integer, Integer> relationships = new TreeMap<>();

        for (int i = 0; i < intervals.length; i++) {
            relationships.put(intervals[i].start, i);
        }

        for (int i = 0; i < intervals.length; i++) {
            Map.Entry<Integer, Integer> entry =
                relationships.ceilingEntry(intervals[i].end);
            res[i] = entry == null ? -1 : entry.getValue();
        }

        return res;
    }
}
Bucket sort
class Solution {
    public int[] findRightInterval(Interval[] intervals) {
        int min = Integer.MAX_VALUE;
        int max = Integer.MIN_VALUE;
        for (int i = 0; i < intervals.length; i++) {
            min = Math.min(min, intervals[i].start);
            max = Math.max(max, intervals[i].end);
        }

        int[] bucket = new int[max - min + 1];
        Arrays.fill(bucket, -1);
        for (int i = 0; i < intervals.length; i++) {
            bucket[intervals[i].start - min] = i;
        }

        for (int i = bucket.length - 2; i >= 0; i--) {
            if (bucket[i] == -1) {
                bucket[i] = bucket[i + 1];
            }
        }

        int[] res = new int[intervals.length];
        for (int i = 0; i < intervals.length; i++) {
            res[i] = bucket[intervals[i].end - min];
        }

        return res;
    }
}