快速选择:不完全排序的艺术

当你只需要找到第K大的元素时,为什么要把整个数组都排序呢?


什么是快速选择

快速选择(QuickSelect)是一种从无序数组中找到第K小(或第K大)元素的选择算法。它基于快速排序的分区思想,但不需要完全排序。

核心思想:每次分区后,pivot元素的位置就是它在有序数组中的最终位置

形象理解

想象你要在班级里找出第3名:

1
2
普通方法:给全班排序,然后取第3名 → O(n log n)
快速选择:每次排除一半不可能的人 → O(n)

就像打擂台赛,每轮淘汰一半选手,不需要完整排序就能找到冠军!


算法原理

核心思想

快速选择利用快速排序的 partition(分区) 操作:

graph TD
    A[选择pivot] --> B[分区操作]
    B --> C{pivot位置 vs 目标位置}
    C -->|相等| D[找到目标,返回]
    C -->|pivot靠左| E[在右半部分继续]
    C -->|pivot靠右| F[在左半部分继续]
    E --> A
    F --> A
    
    style D fill:#90EE90
    style A fill:#FFE4B5

分区操作详解

Partition将数组分为三部分:

1
2
[小于pivot的元素] | [pivot] | [大于pivot的元素]
区域1 位置p 区域2

关键性质:分区后,pivot元素在它最终排序位置上


算法步骤图解

示例:找第2大元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
数组:[3, 2, 1, 5, 6, 4]
目标:第2大 = 第5小(索引4)

步骤1: 选择pivot=4,分区
[3, 2, 1] [4] [5, 6]
0,1,2 3 4,5
↑ pivot在位置3

步骤2: 目标在位置4,pivot在位置3
目标 > pivot,在右半部分继续
[5, 6]

步骤3: 选择pivot=5,分区
[5] [6]
4 5
↑ pivot在位置4

步骤4: pivot位置 = 目标位置 = 4
找到答案:5 ✓

完整过程可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
原数组:[3, 2, 1, 5, 6, 4]
找第2大

第1轮 Partition (pivot=4):
Before: [3, 2, 1, 5, 6, 4]
After: [3, 2, 1, 4, 5, 6]

位置3,目标在右边

第2轮 Partition (pivot=6):
Range: [5, 6]
After: [5, 6]

位置5,目标在左边

第3轮 Partition (pivot=5):
单元素:5

位置4,找到!

代码实现

基础版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
int quickSelect(vector<int>& nums, int left, int right, int k) {
// 基准情况:只有一个元素
if (left == right) {
return nums[left];
}

// 分区操作
int pivotIndex = partition(nums, left, right);

// 判断pivot位置
if (pivotIndex == k) {
return nums[k]; // 找到目标
} else if (pivotIndex < k) {
return quickSelect(nums, pivotIndex + 1, right, k); // 在右边
} else {
return quickSelect(nums, left, pivotIndex - 1, k); // 在左边
}
}

int partition(vector<int>& nums, int left, int right) {
// 随机选择pivot,避免最坏情况
int randomIndex = left + rand() % (right - left + 1);
swap(nums[randomIndex], nums[right]);

int pivot = nums[right];
int i = left - 1; // i指向小于pivot区域的最后一个位置

for (int j = left; j < right; ++j) {
if (nums[j] <= pivot) {
++i;
swap(nums[i], nums[j]);
}
}

// 将pivot放到正确位置
swap(nums[i + 1], nums[right]);
return i + 1;
}

找第K大元素

1
2
3
4
5
int findKthLargest(vector<int>& nums, int k) {
// 第k大 = 第(n-k)小(从0开始索引)
int targetIndex = nums.size() - k;
return quickSelect(nums, 0, nums.size() - 1, targetIndex);
}

Partition过程详解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
数组: [3, 2, 1, 5, 6, 4]
pivot = 4 (选择最后一个元素)

初始: i = -1
[3, 2, 1, 5, 6, 4]
j pivot

j=0: nums[0]=3 <= 4 ✓
i++, swap(nums[0], nums[0])
[3, 2, 1, 5, 6, 4]
i,j

j=1: nums[1]=2 <= 4 ✓
i++, swap(nums[1], nums[1])
[3, 2, 1, 5, 6, 4]
i,j

j=2: nums[2]=1 <= 4 ✓
i++, swap(nums[2], nums[2])
[3, 2, 1, 5, 6, 4]
i,j

j=3: nums[3]=5 > 4 ✗
不交换
[3, 2, 1, 5, 6, 4]
i j

j=4: nums[4]=6 > 4 ✗
不交换
[3, 2, 1, 5, 6, 4]
i j

最后: swap(nums[i+1], nums[right])
[3, 2, 1, 4, 6, 5]

pivot位置=3

复杂度分析

时间复杂度

情况 复杂度 说明
平均 O(n) 每次减半搜索空间
最好 O(n) 每次分区都在中间
最坏 O(n²) 每次分区都在边界

为什么平均是 O(n)?

1
2
3
4
5
6
第1轮:处理 n 个元素
第2轮:处理 n/2 个元素
第3轮:处理 n/4 个元素
...

总时间 = n + n/2 + n/4 + ... ≈ 2n = O(n)

这是一个几何级数!

graph LR
    A[n个元素] --> B[n/2个元素]
    B --> C[n/4个元素]
    C --> D[n/8个元素]
    D --> E[...]
    
    style A fill:#FF6B6B
    style B fill:#FFB6C1
    style C fill:#FFE4B5

与快速排序对比

算法 时间复杂度(平均) 说明
快速排序 O(n log n) 两边都要递归
快速选择 O(n) 只递归一边

关键区别:快速选择只需要递归一侧,快速排序需要递归两侧!


优化技巧

1. 随机化Pivot

问题:固定选择pivot(如最后一个元素)在有序数组中会退化为O(n²)

解决:随机选择pivot

1
2
int randomIndex = left + rand() % (right - left + 1);
swap(nums[randomIndex], nums[right]);

2. 三数取中法

从首、中、尾三个位置选择中位数作为pivot

1
2
3
4
5
6
7
8
9
10
11
int getMidPivot(vector<int>& nums, int left, int right) {
int mid = left + (right - left) / 2;

// 将三个数的中位数移到right位置
if (nums[left] > nums[mid]) swap(nums[left], nums[mid]);
if (nums[left] > nums[right]) swap(nums[left], nums[right]);
if (nums[mid] > nums[right]) swap(nums[mid], nums[right]);

swap(nums[mid], nums[right]);
return right;
}

3. 迭代版本(避免递归栈溢出)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int quickSelectIterative(vector<int>& nums, int k) {
int left = 0, right = nums.size() - 1;

while (left < right) {
int pivotIndex = partition(nums, left, right);

if (pivotIndex == k) {
return nums[k];
} else if (pivotIndex < k) {
left = pivotIndex + 1; // 在右边
} else {
right = pivotIndex - 1; // 在左边
}
}

return nums[k];
}

经典应用

1. 第K大元素(LeetCode 215)

1
2
3
4
int findKthLargest(vector<int>& nums, int k) {
int targetIndex = nums.size() - k;
return quickSelect(nums, 0, nums.size() - 1, targetIndex);
}

示例

1
2
3
4
输入:[3,2,1,5,6,4], k=2
输出:5

解释:第2大的元素是5

2. 数组中的第K个最大元素(扩展)

找出数组中前K大的元素

1
2
3
4
5
6
7
8
vector<int> topKLargest(vector<int>& nums, int k) {
// 找到第k大元素的位置
int targetIndex = nums.size() - k;
quickSelect(nums, 0, nums.size() - 1, targetIndex);

// 返回从targetIndex开始的所有元素
return vector<int>(nums.begin() + targetIndex, nums.end());
}

3. 寻找中位数

1
2
3
4
5
6
7
8
9
10
11
12
13
double findMedian(vector<int>& nums) {
int n = nums.size();

if (n % 2 == 1) {
// 奇数个元素
return quickSelect(nums, 0, n - 1, n / 2);
} else {
// 偶数个元素
int mid1 = quickSelect(nums, 0, n - 1, n / 2 - 1);
int mid2 = quickSelect(nums, 0, n - 1, n / 2);
return (mid1 + mid2) / 2.0;
}
}

快速选择 vs 其他方法

对比表

方法 时间复杂度 空间复杂度 适用场景
快速选择 O(n) 平均 O(1) 一次性查询
排序 O(n log n) O(1) 多次查询
O(n log k) O(k) K很小时
计数排序 O(n + m) O(m) 数值范围小

选择建议

graph TD
    A{需要什么?} --> B[只查询一次]
    A --> C[多次查询]
    A --> D[K很小]
    
    B --> E[快速选择 O n]
    C --> F[排序 O n log n]
    D --> G[堆 O n log k]
    
    style E fill:#90EE90
    style B fill:#FFE4B5

实战案例

案例1:找数组中的第K个最大元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// LeetCode 215
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int targetIndex = nums.size() - k;
return quickSelect(nums, 0, nums.size() - 1, targetIndex);
}

private:
int quickSelect(vector<int>& nums, int left, int right, int k) {
if (left == right) return nums[left];

int pivotIndex = partition(nums, left, right);

if (pivotIndex == k) {
return nums[k];
} else if (pivotIndex < k) {
return quickSelect(nums, pivotIndex + 1, right, k);
} else {
return quickSelect(nums, left, pivotIndex - 1, k);
}
}

int partition(vector<int>& nums, int left, int right) {
int randomIndex = left + rand() % (right - left + 1);
swap(nums[randomIndex], nums[right]);

int pivot = nums[right];
int i = left - 1;

for (int j = left; j < right; ++j) {
if (nums[j] <= pivot) {
swap(nums[++i], nums[j]);
}
}

swap(nums[i + 1], nums[right]);
return i + 1;
}
};

案例2:找出数组中位数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Solution {
public:
double findMedian(vector<int>& nums) {
int n = nums.size();

if (n % 2 == 1) {
return quickSelect(nums, 0, n - 1, n / 2);
} else {
// 需要找两个中间元素
int left = quickSelect(nums, 0, n - 1, n / 2 - 1);
int right = quickSelect(nums, 0, n - 1, n / 2);
return (left + right) / 2.0;
}
}

private:
int quickSelect(vector<int>& nums, int left, int right, int k) {
if (left == right) return nums[left];

int pivotIndex = partition(nums, left, right);

if (pivotIndex == k) {
return nums[k];
} else if (pivotIndex < k) {
return quickSelect(nums, pivotIndex + 1, right, k);
} else {
return quickSelect(nums, left, pivotIndex - 1, k);
}
}

int partition(vector<int>& nums, int left, int right) {
int randomIndex = left + rand() % (right - left + 1);
swap(nums[randomIndex], nums[right]);

int pivot = nums[right];
int i = left - 1;

for (int j = left; j < right; ++j) {
if (nums[j] <= pivot) {
swap(nums[++i], nums[j]);
}
}

swap(nums[i + 1], nums[right]);
return i + 1;
}
};

常见陷阱与技巧

陷阱1:索引转换错误

1
2
3
4
5
6
// ❌ 错误:直接使用k
return quickSelect(nums, 0, nums.size() - 1, k);

// ✅ 正确:第k大 = 第(n-k)小
int targetIndex = nums.size() - k;
return quickSelect(nums, 0, nums.size() - 1, targetIndex);

陷阱2:忘记随机化

1
2
3
4
5
6
7
// ❌ 危险:固定选择最后一个元素
int pivot = nums[right];

// ✅ 安全:随机选择pivot
int randomIndex = left + rand() % (right - left + 1);
swap(nums[randomIndex], nums[right]);
int pivot = nums[right];

技巧1:记住partition模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int partition(vector<int>& nums, int left, int right) {
// 随机pivot
int random = left + rand() % (right - left + 1);
swap(nums[random], nums[right]);

int pivot = nums[right];
int i = left - 1; // 关键:从left-1开始

for (int j = left; j < right; ++j) {
if (nums[j] <= pivot) {
swap(nums[++i], nums[j]);
}
}

swap(nums[i + 1], nums[right]);
return i + 1;
}

技巧2:口诀记忆

“找大减K,找小用K”

  • 找第K大:targetIndex = n - k
  • 找第K小:targetIndex = k - 1(从0开始索引)

思维导图

graph TD
    A[快速选择] --> B[核心思想]
    A --> C[关键操作]
    A --> D[应用场景]
    A --> E[优化技巧]
    
    B --> B1[基于快排分区]
    B --> B2[只递归一侧]
    B --> B3[O n 平均复杂度]
    
    C --> C1[Partition分区]
    C --> C2[比较pivot位置]
    C --> C3[递归/迭代]
    
    D --> D1[第K大元素]
    D --> D2[中位数]
    D --> D3[TopK问题]
    
    E --> E1[随机化pivot]
    E --> E2[三数取中]
    E --> E3[迭代版本]
    
    style A fill:#FFE4B5
    style B fill:#87CEEB
    style D fill:#90EE90

总结

核心要点

  1. 快速选择 = 快排分区 + 单侧递归
  2. 平均O(n)复杂度:比完全排序快
  3. 随机化pivot:避免最坏情况
  4. 索引转换:第K大 = 第(n-k)小

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// 快速选择模板
int quickSelect(vector<int>& nums, int left, int right, int k) {
if (left == right) return nums[left];

int pivotIndex = partition(nums, left, right);

if (pivotIndex == k) {
return nums[k];
} else if (pivotIndex < k) {
return quickSelect(nums, pivotIndex + 1, right, k);
} else {
return quickSelect(nums, left, pivotIndex - 1, k);
}
}

// Partition模板
int partition(vector<int>& nums, int left, int right) {
int random = left + rand() % (right - left + 1);
swap(nums[random], nums[right]);

int pivot = nums[right];
int i = left - 1;

for (int j = left; j < right; ++j) {
if (nums[j] <= pivot) {
swap(nums[++i], nums[j]);
}
}

swap(nums[i + 1], nums[right]);
return i + 1;
}

推荐练习

推荐阅读