分治算法

1. 分治算法概述

分治算法(Divide and Conquer)是一种重要的算法设计策略,其基本思想是将一个规模较大的问题分解为若干个规模较小、相互独立且与原问题形式相同的子问题,然后递归地解决这些子问题,最后将子问题的解合并起来得到原问题的解。

分治算法通常包含三个步骤:

  1. 分解(Divide):将原问题分解为若干个规模较小、相互独立、与原问题形式相同的子问题。
  2. 解决(Conquer):若子问题规模较小而容易被解决则直接求解,否则递归地解各个子问题。
  3. 合并(Combine):将各个子问题的解合并为原问题的解。

2. 分支算法经典案例

2.1 归并排序

归并排序是分治算法的一个经典应用,它的基本思想是将一个数组分成两个子数组,分别对这两个子数组进行排序,然后将排好序的子数组合并成一个有序的数组。

归并排序算法

2.2 快速排序

快速排序采用分治策略,通过递归地将数组分割成子数组来实现排序。其平均时间复杂度为 **O(n log n)**,在大多数情况下表现良好,因此在实际应用中广泛使用。

快速排序算法

2.3 合并K个有序单链表

给定一个链表数组,每个链表都已经按升序排列,要求使用分治算法将这 k 个有序链表合并成一个有序链表并返回。

输入

1
2
3
4
5
[
1->4->5,
1->3->4,
2->6
]

输出

1
1->1->2->3->4->4->5->6

分治算法的核心思想是将一个复杂的大问题分解为多个相似的小问题,递归地解决这些小问题,最后将小问题的解合并得到原问题的解。对于合并 k 个有序链表的问题,我们可以采用如下步骤:

  1. 分解:将 k 个链表两两分组,不断地将问题规模缩小,直到每组只剩下一个或两个链表。
  2. 解决:对于每组中的一个链表,直接返回该链表;对于每组中的两个链表,将它们合并成一个有序链表。
  3. 合并:将每次分组合并后的结果继续进行分组合并,直到最终得到一个有序链表。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <iostream>
#include <vector>
using namespace std;

// 定义链表节点结构体
struct ListNode
{
int val;
ListNode* next;
ListNode(int x) : val(x), next(nullptr) {}
};

ListNode* initList(const initializer_list<int>& values)
{
ListNode* head = nullptr;
ListNode* current = nullptr;
for (int val : values)
{
if (head == nullptr)
{
head = new ListNode(val);
current = head;
}
else
{
current->next = new ListNode(val);
current = current->next;
}
}
return head;
}

ListNode* mergeTwoLists(ListNode* ls1, ListNode* ls2)
{
ListNode dummy(0);
ListNode* tail = &dummy;
while (ls1 && ls2)
{
if (ls1->val < ls2->val)
{
tail->next = ls1;
ls1 = ls1->next;
}
else
{
tail->next = ls2;
ls2 = ls2->next;
}
tail = tail->next;
}
tail->next = ls1 ? ls1 : ls2;
return dummy.next;
}

ListNode* mergeKLists(vector<ListNode*>& lists, int left, int right)
{
if (left > right) return nullptr;
if (left == right) return lists[left];
int mid = left + (right - left) / 2;
ListNode* ls1 = mergeKLists(lists, left, mid);
ListNode* ls2 = mergeKLists(lists, mid + 1, right);
return mergeTwoLists(ls1, ls2);
}

ListNode* mergeKLists(vector<ListNode*>& lists)
{
return mergeKLists(lists, 0, lists.size() - 1);
}

void printList(ListNode* head)
{
while (head)
{
cout << head->val;
if (head->next) cout << "->";
head = head->next;
}
cout << endl;
}

void freeList(ListNode* head)
{
while (head)
{
ListNode* temp = head;
head = head->next;
delete temp;
}
}

int main()
{
// 初始化链表
ListNode* ls1 = initList({ 1, 5, 12, 17, 18, 22 });
ListNode* ls2 = initList({ 2, 6, 9, 16, 19, 23 });
ListNode* ls3 = initList({ 3, 7, 10, 15, 20, 24 });
ListNode* ls4 = initList({ 4, 8, 11, 13, 21, 25 });

vector<ListNode*> lists = { ls1, ls2, ls3, ls4 };
ListNode* mergedList = mergeKLists(lists);
printList(mergedList);
freeList(mergedList);

return 0;
}

程序输出的结果为:

1
1->2->3->4->5->6->7->8->9->10->11->12->13->15->16->17->18->19->20->21->22->23->24->25
  1. ListNode 结构体:定义了链表节点的结构。
  2. mergeTwoLists 函数:用于合并两个有序链表,通过比较两个链表当前节点的值,将较小值的节点添加到结果链表中。
  3. mergeKLists 递归函数:将链表数组进行二分,递归地合并左右两部分的链表,最终调用 mergeTwoLists 合并两个子问题的结果。
  4. printList 函数:辅助函数,用于打印链表节点的值。
  5. main 函数:创建示例链表,调用 mergeKLists 进行合并,打印结果并释放内存。

2.4 对数时间求中位数

给定两个已排序的数组 nums1nums2,要求在对数时间复杂度 O(log⁡(m+n)) 内找出这两个数组合并后的中位数,其中 mn 分别是 nums1nums2 的长度。

中位数的定义:

  • 如果合并后的数组长度是奇数,中位数就是合并后数组中间位置的元素。
  • 如果合并后的数组长度是偶数,中位数是合并后数组中间两个元素的平均值。

我们可以使用二分查找的方法来解决这个问题。基本思想是通过在较短的数组上进行二分查找,找到合适的分割点,使得两个数组分割后的左半部分元素个数和右半部分元素个数满足一定的条件,同时左半部分的所有元素都小于等于右半部分的所有元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) 
{
if (nums1.size() > nums2.size())
{
return findMedianSortedArrays(nums2, nums1);
}

int m = nums1.size();
int n = nums2.size();
int low = 0, high = m;

while (low <= high)
{
int partition1 = (low + high) / 2;
int partition2 = (m + n + 1) / 2 - partition1;
int maxLeft1 = (partition1 == 0) ? INT_MIN : nums1[partition1 - 1];
int minRight1 = (partition1 == m) ? INT_MAX : nums1[partition1];

int maxLeft2 = (partition2 == 0) ? INT_MIN : nums2[partition2 - 1];
int minRight2 = (partition2 == n) ? INT_MAX : nums2[partition2];
if (maxLeft1 <= minRight2 && maxLeft2 <= minRight1)
{
if ((m + n) % 2 == 1)
{
return max(maxLeft1, maxLeft2);
}
else
{
return (max(maxLeft1, maxLeft2) + min(minRight1, minRight2)) / 2.0;
}
}
else if (maxLeft1 > minRight2)
{
high = partition1 - 1;
}
else
{
low = partition1 + 1;
}
}
}

int main()
{
vector<int> nums1 = { 1, 3, 5, 7, 9, 11, 13 };
vector<int> nums2 = { 2, 4, 6, 8, 10 };
double median = findMedianSortedArrays(nums1, nums2);
cout << "中位数为: " << median << endl;
return 0;
}

二分查找与分割点的关系

我们把两个数组想象成两根长长的绳子,二分查找就像是在绳子上找一个合适的切割位置。

nums1 的长度为 mnums2 的长度为 n。我们要在这两个数组上分别找到分割点 partition1partition2,使得 partition1 + partition2 = (m + n + 1) / 2。这个等式的意思是,两个数组左半部分的元素总数要符合整体分割的要求。

  • 假设 nums1 是长数组,长度 m = 100nums2 是短数组,长度 n = 10。如果我们在长数组 nums1 上进行二分查找,初始的搜索范围是从 0100。二分查找每次会把搜索范围缩小一半,但即使缩小很多次,范围还是可能比较大。
  • 如果在短数组 nums2 上进行二分查找,初始搜索范围是从 010。这个范围小很多,二分查找能更快地缩小范围找到合适的分割点。而且,当确定了 nums2 的分割点 partition2 后,再根据公式计算 nums1 的分割点 partition1,由于 nums1 比较长,partition1 更容易落在合法的范围内。
  • + 1 的作用是让公式在合并后数组长度为奇数和偶数的情况下都能统一使用。在奇数长度时,保证左半部分元素数量比右半部分多 1 个,中位数就在左半部分;在偶数长度时,保证左右两部分元素数量相等,方便计算中间两个数的平均值。所以,partition1 + partition2 = (m + n + 1) / 2 这个公式可以简化代码逻辑,避免对两种情况分别进行处理。

范围调整

下面我们通过具体的数据例子来详细解释代码中根据 maxLeft1minRight2 的大小关系调整二分查找范围的逻辑。

假设我们有两个有序数组:

1
2
nums1 = [1, 3, 5]
nums2 = [2, 4, 6, 8]

初始化变量

1
2
3
int m = nums1.size();    // m = 3
int n = nums2.size(); // n = 4
int low = 0, high = m; // low = 0, high = 3

第一次二分查找

1
2
int partition1 = (low + high) / 2;             // partition1 = (0 + 3) / 2 = 1
int partition2 = (m + n + 1) / 2 - partition1; // partition2 = (3 + 4 + 1) / 2 - 1 = 3

此时,数组分割情况如下:

  • nums1 左半部分:[1],右半部分:[3, 5]
  • nums2 左半部分:[2, 4, 6],右半部分:[8]

计算边界值:

1
2
3
4
int maxLeft1 = (partition1 == 0) ? INT_MIN : nums1[partition1 - 1];  // maxLeft1 = 1
int minRight1 = (partition1 == m) ? INT_MAX : nums1[partition1]; // minRight1 = 3
int maxLeft2 = (partition2 == 0) ? INT_MIN : nums2[partition2 - 1]; // maxLeft2 = 6
int minRight2 = (partition2 == n) ? INT_MAX : nums2[partition2]; // minRight2 = 8

判断条件:

1
2
3
4
5
6
if (maxLeft1 <= minRight2 && maxLeft2 <= minRight1) // 1 <= 8 成立,但 6 <= 3 不成立
else if (maxLeft1 > minRight2) // 1 > 8 不成立
else // 说明 partition1 太小,需要将搜索范围的左边界 low 调整为 partition1 + 1
{
low = partition1 + 1; // low = 1 + 1 = 2
}

第二次二分查找

1
2
int partition1 = (low + high) / 2;             // partition1 = (2 + 3) / 2 = 2
int partition2 = (m + n + 1) / 2 - partition1; // partition2 = (3 + 4 + 1) / 2 - 2 = 2

此时,数组分割情况如下:

  • nums1 左半部分:[1, 3],右半部分:[5]
  • nums2 左半部分:[2, 4],右半部分:[6, 8]

计算边界值:

1
2
3
4
int maxLeft1 = (partition1 == 0) ? INT_MIN : nums1[partition1 - 1]; // maxLeft1 = 3
int minRight1 = (partition1 == m) ? INT_MAX : nums1[partition1]; // minRight1 = 5
int maxLeft2 = (partition2 == 0) ? INT_MIN : nums2[partition2 - 1]; // maxLeft2 = 4
int minRight2 = (partition2 == n) ? INT_MAX : nums2[partition2]; // minRight2 = 6

判断条件:

1
2
3
4
5
6
7
if (maxLeft1 <= minRight2 && maxLeft2 <= minRight1) // 3 <= 6 成立,4 <= 5 成立
{
if ((m + n) % 2 == 1) // (3 + 4) % 2 = 1 成立
{
return max(maxLeft1, maxLeft2); // 返回 max(3, 4) = 4
}
}
  • 在第一次二分查找中,由于 maxLeft1 <= minRight2maxLeft2 > minRight1,说明 partition1 太小,我们将 low 调整为 partition1 + 1
  • 在第二次二分查找中,分割点满足 maxLeft1 <= minRight2maxLeft2 <= minRight1,此时可以根据合并后数组的长度是奇数还是偶数计算中位数。在这个例子中,合并后数组长度为奇数,中位数是左半部分的最大值 4。

通过不断调整二分查找的范围,我们可以找到合适的分割点,从而计算出两个有序数组合并后的中位数。