文章

LeetCode315 Count of Smaller Numbers After Self

题目

传送门

Given an integer array nums, return an integer array counts where counts[i] is the number of smaller elements to the right of nums[i].

这是一个 Hard 级别的题,描述很简单,暴力解法也比较容易想出来:

遍历数组,针对每个元素,遍历它后面的元素得出小于它的个数。时间复杂度为 $O(N^2)$,空间复杂度为 $O(1)$。

这题可以利用归并排序的过程顺便解出。

归并排序

归并排序的思想是把原始数组不断二分,直到分成单个元素,然后两两合并,在合并的过程中排序。

例如对于数组 [6, 5 ,2, 1] 归并排序的流程如下:

不难看出归并排序是一个递归与回溯的过程,代码写法如下(java):

private static void mergeSort(int[] nums) {
  int[] tmp = new int[nums.length];
  msort(nums, 0, nums.length - 1, tmp);
}

private static void msort(int[] nums, int l, int r, int[] tmp) {
  if (l < r) {
    int mid = (l + r) / 2;
    msort(nums, l, mid, tmp);
    msort(nums, mid + 1, r, tmp);
    merge(nums, l, mid, r, tmp);
  }
}

private static void merge(int[] nums, int l, int m, int r, int[] tmp) {
  // [l,m] 是左侧候选元素;[m+1,r] 是右侧候选元素
  int pl = l, pr = m + 1, p = l;
  while (pl <= m && pr <= r) {
    if (nums[pl] <= nums[pr])
      tmp[p++] = nums[pl++];
    else
      tmp[p++] = nums[pr++];
  }
  while (pl <= m)
    tmp[p++] = nums[pl++];
  while (pr <= r)
    tmp[p++] = nums[pr++];
  // 把本轮排序结果复制到原数组供下一轮使用
  while (l <= r) {
    nums[l] = tmp[l];
    l++;
  }
}

计算本题结果

归并排序的合并过程中,每一轮合并,右侧候选元素在原始数组中,也一定在左侧候选元素的右边。那么,合并时每当检测到右侧元素更小时,就意味这它比左侧的当前元素,和其之后的元素更小(因为两侧候选数组各自都是有序的)。

说起来比较抽象,实际演示一下。

划分的结果为 [6] [5] [2] [1],此时还没有开始计算答案,默认为 6(0) 5(0) 2(0) 1(0)

合并 [6] [5] 时,因为右侧的 5 更小,就代表找到了一个在 6 右边,并且比 6 小的元素。此时答案数组更新为 6(1) 5(0) 2(0) 1(0)

接着合并 [2] [1],同理,因为 1 更小并且在右边,所以要更新答案为 6(1) 5(0) 2(1) 1(0)

最后合并 [5,6] [1,2]。首先检测到 1 比 5 小(也一定比 5 之后的元素小)并且在右边,所以更新答案数组:6(2) 5(1) 2(1) 1(0)。接着检测,2 还是比 5 小,所以继续更新答案数组:6(3) 5(2) 2(1) 1(0),这就是最终答案了。

在归并排序过程中,元素的位置发生改变,所以需要额外的数据结构来记录原始下标,才好更新答案数组。 结构如下:

static class Item {
  int val, index;
  Item(int val, int index) {
    this.val = val;
    this.index = index;
  }
}

核心过程如下:

private void merge(Item[] nums, Item[] tmp, int l, int mid, int r, List<Integer> res) {
  int pl = l, pr = mid + 1, p = l;
  while (pl <= mid && pr <= r) {
    if (nums[pl].val <= nums[pr].val) {
      tmp[p] = nums[pl];
      pl++;
    } else {
      for (int i = pl; i <= mid; i++) {
        res.set(nums[i].index, res.get(nums[i].index) + 1);
      }
      tmp[p] = nums[pr];
      pr++;
    }
    p++;
  }
  while (pl <= mid) {
    tmp[p] = nums[pl];
    p++;
    pl++;
  }
  while (pr <= r) {
    tmp[p] = nums[pr];
    p++;
    pr++;
  }

  while (l <= r) {
    nums[l] = tmp[l];
    l++;
  }
}

归并排序是时间复杂度为 $O(N \log N)$。但是我们修改后的实现在合并的时候进行了内部遍历来更新答案数组,使时间复杂度变成了 $O(N^2 \log N)$。

优化

为了避免多次内部遍历,可以先记录找到的符合题目要求的元素个数,在合适的时候一次性更新答案数组。

更新后的完整代码如下:

class Solution {
  static class Item {
    int val, index;

    Item(int val, int index) {
      this.val = val;
      this.index = index;
    }
  }

  public List<Integer> countSmaller(int[] nums) {
    List<Integer> res = new ArrayList<>(nums.length);
    Item[] newNums = new Item[nums.length];
    for (int i = 0; i < nums.length; i++) {
      res.add(0);
      newNums[i] = new Item(nums[i], i);
    }
    Item[] tmp = new Item[nums.length];


    msort(newNums, tmp, 0, nums.length - 1, res);
    return res;
  }

  private void msort(Item[] nums, Item[] tmp, int l, int r, List<Integer> res) {
    if (l < r) {
      int mid = (l + r) / 2;
      msort(nums, tmp, l, mid, res);
      msort(nums, tmp, mid + 1, r, res);
      merge(nums, tmp, l, mid, r, res);
    }
  }

  private void merge(Item[] nums, Item[] tmp, int l, int mid, int r, List<Integer> res) {
    int findEleCount = 0;
    int pl = l, pr = mid + 1, p = l;
    while (pl <= mid && pr <= r) {
      if (nums[pl].val <= nums[pr].val) {
        res.set(nums[pl].index, res.get(nums[pl].index) + findEleCount);
        tmp[p] = nums[pl];
        pl++;
      } else {
        findEleCount++; // 先记录个数,等待统一更新
        tmp[p] = nums[pr];
        pr++;
      }
      p++;
    }
    while (pl <= mid) {
      res.set(nums[pl].index, res.get(nums[pl].index) + findEleCount);
      tmp[p] = nums[pl];
      p++;
      pl++;
    }
    while (pr <= r) {
      tmp[p] = nums[pr];
      p++;
      pr++;
    }

    while (l <= r) {
      nums[l] = tmp[l];
      l++;
    }
  }
}

这样时间复杂度又回到了 $O(N \log N)$。