LeetCode:Top K问题

Top K问题,即找到第K大(小)、最大(小)的K个元素一类问题。

第K个最大元素

第K大、第K小的转换

题目中的第K大的index是从1开始的,转成index从0开始则为第K-1大,可以继续转换成第N-K小。

直接排序

最简单的办法就是直接排序数组,然后取第K大的元素即可,但是复杂度较高。

快速选择算法

在快速排序算法基础上稍作改动,就得到快速选择算法。思路是每次partition结束,然后根据pivot和K的大小关系,下一步只需要处理分割后的其中一侧即可,因此提高了性能。

对于第K小的情况(第K大同理),因为只需要得到第K小的元素,partition分割后,假设pivot小于K,也就是 start < pivot < K < end,则 start~pivot 的元素肯定都比第K个元素小,因此下一步只需要处理pivot右侧也就是 pivot+1 ~ end 的部分,左侧部分直接不用管了。这个思路有点像二分法(写出来的代码也很像)。

复杂度:

  • 时间复杂度 O(N)
  • 空间复杂度 O(1)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public int findKthLargest(int[] nums, int k) {
// 转换成第k小
k = nums.length - k;
int start = 0, end = nums.length - 1;
while (start < end) {
int pivot = partition(nums, start, end);
if (pivot == k) {
break;
} else if (pivot < k) {
start = pivot + 1;
} else {
end = pivot - 1;
}
}
return nums[k];
}

public int partition(int[] nums, int start, int end) {
// ...
}

也可以用堆实现。维护一个小顶堆,将数据逐一添加进去,堆的大小超过K时,就把最小的从队列头部移除。遍历完成后,堆里面保存的就是最大的K个数,堆顶就是第K大的数。

  • 时间复杂度 O(N·logK)
  • 空间复杂度 O(K)

因为复杂度和K的大小有关,这种解法还能进一步优化,判断K和N+1-K哪个更小,如果N+1-K更小,可以转换成求解第N+1-K小,耗时更少了。

1
2
3
4
5
6
7
8
9
10
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> pq = new PriorityQueue<>(); // 小顶堆
for (int val : nums) {
pq.add(val);
if (pq.size() > k) { // 维护堆的大小为 K
pq.poll();
}
}
return pq.peek();
}

前 K 个高频元素

和前一道题有两个区别:

  • 一个是找第K个数,一个是找前K个数。
  • 这道题需要先预处理,把数组转成frequency的HashMap,然后根据出现频率排序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public List<Integer> topKFrequent(int[] nums, int k) {
// number -> frequency
Map<Integer,Integer> frequencies = new HashMap<>();
for (int num : nums) {
int f = frequencies.getOrDefault(num, 0) + 1;
frequencies.put(num, f);
}
// min heap
PriorityQueue<Integer> heap = new PriorityQueue(new Comparator<Integer>() {
public int compare(Integer a, Integer b) {
// compare a,b by frequency
return frequencies.get(a) - frequencies.get(b);
}
});
// put each number to heap
for (int num : frequencies.keySet()) {
heap.add(num);
if (heap.size() > k) {
heap.poll();
}
}
// heap to list
return new ArrayList<Integer>(heap);
}

快速选择

题目没有要求必须按顺序输出结果,因此也可以用快速选择解决。

注意,如果要求按顺序输出结果,快速选择不能实现。

分析:假设K=60,有100个元素,第一次partition返回pivot为50,之后会在51~99之间继续partition,此时0~49只能符合前K大的要求,但并没有被排序,最后输出来的顺序就是乱的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Solution {
public List<Integer> topKFrequent(int[] nums, int k) {
// number -> frequency
Map<Integer,Integer> frequencies = new HashMap<>();
for (int num : nums) {
int f = frequencies.getOrDefault(num, 0) + 1;
frequencies.put(num, f);
}
// [num, freq], [num, freq], ...
int[][] array = new int[frequencies.size()][2];
int index = 0;
for (Map.Entry<Integer, Integer> entry : frequencies.entrySet()) {
array[index++] = new int[]{entry.getKey(), entry.getValue()};
}
// quick select
quickSelect(array, 0, array.length-1, k);
// for (int[] a : array) {
// System.out.print(a[0] + ":" + a[1] + ", ");
// }
List<Integer> result = new ArrayList<Integer>(k);
for (int i = 0; i < k; ++i) {
result.add(array[i][0]);
}
return result;
}

public void quickSelect(int[][] array, int start, int end, int k) {
while (start < end) {
int pivot = partition(array, start, end);
if (pivot == k) {
break;
} else if (pivot < k) {
start = pivot + 1;
} else {
end = pivot - 1;
}
}
}

public int partition(int[][] array, int start, int end) {
swap(array, start, new Random().nextInt(end - start) + start);
int pivot = array[start][1];
int i = start, j = end;
while (i < j) {
while (i < j && array[j][1] <= pivot) --j;
while (i < j && array[i][1] >= pivot) ++i;
if (i < j) {
swap(array, i, j);
}
}
swap(array, start, i);
return i;
}

public void swap(int[][] array, int a, int b) {
int[] tmp = array[a];
array[a] = array[b];
array[b] = tmp;
}
}

桶排序

这道题也可以使用任意排序算法,例如桶排序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution {
public List<Integer> topKFrequent(int[] nums, int k) {
// number -> frequency
Map<Integer,Integer> frequencies = new HashMap<>();
for (int num : nums) {
int f = frequencies.getOrDefault(num, 0) + 1;
frequencies.put(num, f);
}
// frequency = 0 ~ N
// create N+1 buckets
// buckets[i] : numbers with freq = i
List<Integer>[] buckets = new List[nums.length + 1];
for (Map.Entry<Integer, Integer> entry : frequencies.entrySet()) {
int f = entry.getValue();
int num = entry.getKey();
if (buckets[f] == null) {
buckets[f] = new LinkedList<>();
}
buckets[f].add(num);
}
// read k most frequent elements from buckets
List<Integer> result = new ArrayList<Integer>(k);
for (int i = buckets.length-1; i >= 0; --i) {
List<Integer> bucket = buckets[i];
if (bucket == null) {
continue;
}
int remain = k - result.size();
if (remain == bucket.size()) {
result.addAll(bucket);
break;
} else if (remain > bucket.size()) {
result.addAll(bucket);
} else {
result.addAll(bucket.subList(0, remain-1));
break;
}
}
return result;
}
}