文章

LeetCode4 两个有序数组的中位数

题目

传送门:4. Median of Two Sorted Arrays

这是一道 Hard 级别的题目,主要难点是要求实现 $O(\log(m+n))$ 的时间复杂度。

暴力算法

最容易想到的就是双指针,分别指向两个数组的首个元素。比较两个指针元素的大小,向后移动较小的一个并计数。此算法时间复杂度为 $O(m+n)$,其实不算差,但不满足题目要求。

二分排除

  • 时间复杂度:$ O(\log(k)) $ = $ O(\log(m+n)) $
  • 空间复杂度:$ O(1) $

首先「求中位数」可以转换为「求第 k 小的数 (k>=1)」,根据总元素个数 totalNum 的不同,就可以轻松得出中位数:

  • totalNum 为奇数时,中位数是第 totalNum / 2 小的数。
  • totalNum 为偶数时,中位数是第 totalNum / 2 小的数与第 totalNum / 2 + 1 小的数的平均数。

O(N) 时间复杂度降低到 O(log N),最常见的就是二分法。暴力算法中每次排除一个元素,现在可以尝试每次排除 k/2 个元素。当然这里 k 是动态的。

给定数组 nums1nums2 以及各自的查找范围 [start1, end1], [start2, end2]。我们比较它们各自的第 k/2 个元素,也就是下标为 start+k/2-1 (k>=2) 的元素,令:

  • halfKIndex1 = star1 + k/2 - 1
  • halfKIndex2 = star2 + k/2 - 1

(注意此处暂未考虑数组长度不够的情况)

那么有三个可能性:

  • nums1[halfKIndex1] > nums2[halfKIndex2]:这意味着 nums2 的前 k/2 个元素不可能是要找的。因为此时在 nums2[halfKIndex2] 之前(包括自己)最多只可能有 k-1 个元素,它们是:nums2[start2 .. halfKIndex2](k/2 个) 以及 nums1[start1 .. halfKIndex1-1](k/2-1 个),显然第 k 个在它们后面。只需在剩下范围内继续找第 k - (halfKIndex1 - start1 + 1) 小的元素,其中 (halfKIndex1 - start1 + 1) 是本次排除的元素个数。不直接用 k/2 是因为数组可能没那么长。
  • nums1[halfKIndex1] < nums2[halfKIndex2]:同理可以排除 nums1 的前 k/2 个元素。
  • nums1[halfKIndex1] = nums2[halfKIndex2]:同理,可以扔掉任意一个(因为相等,留哪个都一样)。

注意,nums1[halfKIndex1] > nums2[halfKIndex2] 时不可以排除 nums1[halfKIndex1] 之前的元素,因为我们无从得知这些元素与 nums2[halfKIndex2] 之后元素的大小关系。已知:A < B, D < E < F, B > E,无法得出 AF 的关系。

例如:nums1 = [4,5], nums2 = [1,2,3,6],令 k=4, halfKIndex1=1,但第 4 小的数却是 nums1[halfKIndex1-1]

另外注意一下边界情况:

  • 若有数组的查找范围是空,则另一个数组中的第 k 个元素就是答案。
  • k==1 则只需比较两个数组的第一个元素,取较小的一个。
class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int totalNums = nums1.length + nums2.length;
        if (totalNums % 2 == 0) {
            return (findKthNum(nums1, 0, nums2, 0, totalNums / 2) +
                    findKthNum(nums1, 0, nums2, 0, totalNums / 2 + 1)) / 2.0;
        } else {
            return findKthNum(nums1, 0, nums2, 0, totalNums / 2 + 1);
        }
    }

    private int findKthNum(int[] nums1, int start1, int[] nums2, int start2, int k) {
        if (start1 >= nums1.length)
            return nums2[start2 + k - 1];
        if (start2 >= nums2.length)
            return nums1[start1 + k - 1];
        if (k == 1)
            return Math.min(nums1[start1], nums2[start2]);

        int halfKIndex1 = start1 + Math.min(nums1.length - start1, k / 2) - 1;
        int halfKIndex2 = start2 + Math.min(nums2.length - start2, k / 2) - 1;

        if (nums1[halfKIndex1] > nums2[halfKIndex2]) {
            return findKthNum(nums1, start1, nums2, halfKIndex2 + 1, k - (halfKIndex2 - start2 + 1));
        } else {
            return findKthNum(nums1, halfKIndex1 + 1, nums2, start2, k - (halfKIndex1 - start1 + 1));
        }
    }
}

二分分割

  • 时间复杂度:$ O(\log(\min(m,n))) $
  • 空间复杂度:$ O(1) $

所谓中位数,就是把一个有序序列分成两个长度相同序列的元素(根据元素个数奇偶的不同略有区别)。如果有两个序列,可以推广为找两个分割点,分别把它们分成两个部分:A 分成 A1, A2,B 分成 B1, B2,且满足两个条件:

  1. len(A1) + len(B1) = len(A2) + len(B2),相当于单个数组中左右两部分元素个数相同。
  2. max(A1, B1) <= min(A2, B2),相当于单个数组中左边所有元素均小于或等于右边。

由于第一个条件的存在,实际上只要找一个分割点,另一个可直接计算出来。对于长度为 n 的数组,可能的分割点有 n+1 个。显然我们应该从较短的那个数组来尝试。若枚举分割点则时间复杂度为 $O(\min(m,n))$,但若使用二分查找寻找分割点,则可以优化为 $O(\log(\min(m,n)))$。

若 A 的分割点太靠左,可以想象,此时 A 左侧的最大值会偏小,而 B 右侧的最小值会偏大,由此可以推测出分割点往右移动的条件,反之则应向左移动。

class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        // make sure nums1 is shorter than nums2
        if (nums1.length > nums2.length) {
            return findMedianSortedArrays(nums2, nums1);
        }

        int totalNums = nums1.length + nums2.length;

        // the split point range of nums1
        int left = 0, right = nums1.length;

        while (left <= right) {
            // [split, len-1] belongs to the right part
            int split1 = left + (right - left) / 2;
            int split2 = (nums1.length - 2 * split1 + nums2.length) / 2;

          	// corner cases
            int leftMax1 = split1 >= 1 ? nums1[split1 - 1] : Integer.MIN_VALUE;
            int leftMax2 = split2 >= 1 ? nums2[split2 - 1] : Integer.MIN_VALUE;
            int rightMin1 = split1 < nums1.length ? nums1[split1] : Integer.MAX_VALUE;
            int rightMin2 = split2 < nums2.length ? nums2[split2] : Integer.MAX_VALUE;

            int leftMax = Math.max(leftMax1, leftMax2);
            int rightMin = Math.min(rightMin1, rightMin2);
            if (leftMax <= rightMin) {
                if (totalNums % 2 == 0) {
                    return (leftMax + rightMin) / 2.0;
                } else {
                    return rightMin;
                }
            } else if (leftMax1 < rightMin2) {
                // split point should move towards right
                left = split1 + 1;
            } else if (leftMax1 > rightMin2) {
                // split point should move towards left
                right = split1 - 1;
            }
        }
        throw new IllegalStateException();
    }
}