Yao Lirong's Blog

Tsinghua DSA Assignment Summary (3)

2021/02/11

CST Data Structures (Fall 2020), PA3

3.1 Not Found

Algorithm

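We want the shortest binary string B that does not occur as a substring of the binary string A. Start by considering substrings of length 24. A has length at most 16777216 = 2^24, and a string of length L contains only L - 23 substrings of length 24 (one per starting position), so A has at most 2^24 - 23 < 2^24 of them. Since there are 2^24 distinct binary strings of length 24, at least one of them must be missing from A, so B is at most 24 characters long.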
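The counting argument can be checked mechanically. The helper below is hypothetical (it is not part of the assignment code): it returns the smallest length k for which a binary string of length L is guaranteed to miss some length-k string, using the same pigeonhole comparison of L - k + 1 windows against 2^k candidates.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: smallest k such that a binary string of length L cannot
// contain all 2^k binary strings of length k. It has L - k + 1 windows of
// length k (none if L < k), and a missing string is guaranteed once that count
// drops below 2^k. Assumes 1 <= L < 2^62 so the shifts below cannot overflow.
int guaranteed_missing_length(std::int64_t L) {
    assert(L >= 1 && L < (std::int64_t{1} << 62));
    for (int k = 1;; k++) {
        std::int64_t windows = L >= k ? L - k + 1 : 0;
        if (windows < (std::int64_t{1} << k)) return k;
    }
}
```

For L = 16777216 = 2^24 it returns 24, matching the argument above; for L = 4 it returns 2. Note that this only bounds the answer from above: the shortest absent string of a particular A may be shorter.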

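While reading A, we use a bitmap to record every length-24 substring that has appeared. This bitmap stores only strings of length 24, so we call it bitmap24. After reading, note that every length-23 substring of A is obtained from some length-24 substring by removing its first or its last character (this assumes A has length at least 24; if A is shorter, we start from its own length instead). So we go through every length-24 string present in bitmap24, drop its first character and its last character, and record both results in bitmap23. Repeating this step fills bitmap22, and so on, down to bitmap1.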
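A compact sketch of the reading and percolation steps, assuming A arrives as a std::string of '0'/'1' characters and using one std::vector<bool> per level instead of the hand-rolled int bitmaps. The function name build_levels and its parameter top (which plays the role of 24 so the sketch can be tried on short strings) are illustrative assumptions. Unlike the post's code, which percolates lazily while searching, this fills every level up front.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// levels[i][x] is true iff the length-i binary string whose value is x occurs in a.
std::vector<std::vector<bool>> build_levels(const std::string& a, int top) {
    assert(top >= 1 && top <= 24);
    int L = static_cast<int>(a.size());
    int start = L < top ? L : top;  // the longest level we fill directly
    std::vector<std::vector<bool>> levels(start + 1);
    for (int i = 0; i <= start; i++) levels[i].assign(std::size_t{1} << i, false);

    // Slide a window of width `start` over a and record every full window.
    unsigned mask = (1u << start) - 1;  // start <= 24, so no overflow
    unsigned s = 0;
    for (int p = 0; p < L; p++) {
        assert(a[p] == '0' || a[p] == '1');
        s = ((s << 1) | static_cast<unsigned>(a[p] - '0')) & mask;
        if (p + 1 >= start) levels[start][s] = true;
    }

    // Every length-(i-1) substring is a length-i substring with its first or
    // last character removed, so level i-1 is derived entirely from level i.
    for (int i = start; i >= 1; i--) {
        for (unsigned x = 0; x < (1u << i); x++) {
            if (!levels[i][x]) continue;
            levels[i - 1][x >> 1] = true;                     // drop the last character
            levels[i - 1][x & ((1u << (i - 1)) - 1)] = true;  // drop the first character
        }
    }
    return levels;
}
```

For A = "0110" and top = 4, level 2 ends up containing 01, 10 and 11 but not 00.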

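Finally, we scan the levels downward starting from length 24 and stop at the first length n for which every string of length n occurs in A. The shortest absent string B then has length n + 1, and the answer is the first (smallest) string missing from bitmap(n+1). Because strings of equal length compare lexicographically exactly as their integer values do, this is also the lexicographically smallest such B. If even level 1 is not full, B has length 1. In the code, each level is percolated down while it is being scanned, so the scan stops as soon as a full level is found.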
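The downward scan can be sketched as follows, assuming levels[i][x] has already been filled (for example by percolating as described above) and that the top level is not full. The function name shortest_absent and the vector<bool> layout are illustrative assumptions, not the post's API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Scan downward; the last non-full level seen before a full level (or level 1)
// holds the answer, and its smallest missing value is the answer string.
std::string shortest_absent(const std::vector<std::vector<bool>>& levels) {
    int top = static_cast<int>(levels.size()) - 1;
    assert(top >= 1);
    int len = 0;
    unsigned ans = 0;
    for (int i = top; i >= 1; i--) {
        int missing = -1;
        for (unsigned x = 0; x < levels[i].size(); x++)
            if (!levels[i][x]) { missing = static_cast<int>(x); break; }
        if (missing < 0) break;  // level i is full, so the answer has length i + 1
        len = i;
        ans = static_cast<unsigned>(missing);
    }
    // Write ans as exactly len binary digits, keeping leading 0s.
    std::string out(len, '0');
    for (int p = len - 1; p >= 0; p--, ans >>= 1)
        out[p] = static_cast<char>('0' + (ans & 1));
    return out;
}
```

For the levels of A = "0110", level 1 is full and level 2 is missing only 00, so the result is "00"; for A = "0000" the result is "1".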

Details

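  1. While reading the string, once 24 characters have been read, every new character pushes the oldest one out of the window, because by the analysis above B is at most 24 characters long.
  2. An int is 4 bytes = 32 bits = 2^5 bits (on the usual platforms), so bitmap24 needs $2^{24}/2^5 = 2^{19}$ ints, while bitmap1 through bitmap5 each need only one int.
  3. Because each binary string is stored in the bitmap as an int, any leading 0s are lost when the value is printed, so we pad the output with leading 0s according to the length n of the bitmap (bitmap-n) it came from.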
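The bookkeeping in details 2 and 3 can be sketched as below. These helpers are hypothetical (they are not the post's setbit/checkbit): they use std::uint32_t so that a word is exactly 32 bits, and they shift an unsigned 1 so that bit 31 can be set without signed overflow.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of 32-bit words needed to hold one bit for each of the 2^len strings.
std::size_t words_needed(int len) {
    assert(len >= 0 && len <= 24);
    std::size_t bits = std::size_t{1} << len;
    return bits < 32 ? 1 : bits / 32;
}

// Bit x lives in word x / 32, at position x % 32 counted from the most
// significant end, matching the layout described in the post.
void set_bit(std::vector<std::uint32_t>& map, std::uint32_t x) {
    map[x >> 5] |= std::uint32_t{1} << (31 - (x & 31));
}

bool check_bit(const std::vector<std::uint32_t>& map, std::uint32_t x) {
    return (map[x >> 5] >> (31 - (x & 31))) & 1u;
}
```

words_needed(24) is 524288 = 2^19 and words_needed(5) is 1, as in detail 2; setting x = 33 touches the second word at its second-highest bit.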

Code

#include <cstdio>
#include <iostream>
using namespace std;

// int_size[i] is the number of 32-bit ints needed to store one bit for each of the 2^i strings of length i
const int int_size[25] = {
    1,                  // there is no real bitmap for strings of length 0,
                        // but we give it 1 int to keep the whole program uniform
    1, 1, 1, 1, 1,      // 2^1 ... 2^5 bits each fit in one int
    2, 4, 8, 16, 32,    // 2^6 ... 2^10 bits
    64, 128, 256, 512,
    1024, 2048, 4096,
    8192, 16384, 32768,
    65536, 131072,
    262144, 524288 };   // 2^24 bits = 2^19 ints

// ones[i] is 2^i - 1, i.e. a mask of i one-bits
const int ones[25] = {
    0, 1, 3, 7, 15, 31,             // i = 0 ... 5
    63, 127, 255, 511,              // i = 6 ... 9
    1023, 2047, 4095,
    8191, 16383, 32767,
    65535, 131071,
    262143, 524287,                 // i = 18, 19
    1048575, 2097151,
    4194303, 8388607, 16777215      // i = 22 ... 24
};

// bitmap[i] is the bitmap for binary strings of length i
unsigned int *bitmap[25];

// make our bitmap contain a binary string x of length n
void setbit(int n, unsigned int x);

// returns true if our bitmap contains a binary string x of length n
bool checkbit(int n, unsigned int x);

// print a binary string x of length n
void print_binary(int n, unsigned int x);

int main()
{
    for (int i = 0; i <= 24; i++) {
        bitmap[i] = new unsigned int[int_size[i]];
        for (int j = 0; j < int_size[i]; j++)
            bitmap[i][j] = 0;
    }

    // n is one more than the number of characters packed into s so far
    //   (the extra 1 accounts for the character currently held in c)
    // s is the string in our sliding window, stored as an int
    // c is the character we just read in (an int, so that EOF can be detected)
    // input is 0 if c is '0', 1 if c is '1'
    unsigned int n = 0, s = 0, input = 0;
    int c = getchar(); n = 1;

    // read until the string ends (newline or EOF) or 23 characters have been packed into s;
    // in the latter case n == 24 and the 24th character, if any, is left in c
    for (; (c == '0' || c == '1') && n < 24; c = getchar()) {
        input = c - '0';
        s = (s << 1) | input;
        n += 1;
    }

    // n - 1 is the actual length of s
    setbit(n - 1, s);

    // if we stopped because n == 24, there may be more to read: keep reading,
    // but keep only the most recent 24 characters in the window
    // this loop is skipped if the string has already ended
    for (; c == '0' || c == '1'; c = getchar()) {
        input = c - '0';
        s = (s << 1) | input;
        s = s & 0xFFFFFF;   // keep only the last 24 characters
        setbit(24, s);
        n += 1;
    }

    n -= 1;  // n was one more than the number of characters read;
             // subtract 1 to obtain the actual string length

    // len is the length of the answer string
    // ans is the answer, a binary string in int representation
    // full is true if all strings of length i are in our bitmap
    unsigned int len = 0, ans = 0;
    bool full = false;
    for (int i = n > 24 ? 24 : n; i > 0 && !full; i--) {
        full = true;  // we assume this level is full
        for (int j = 0; j < ones[i] + 1; j++) {  // iterate over all strings 0 ... 2^i - 1
            if (checkbit(i, j)) {  // percolate down to its two substrings of length i - 1
                setbit(i - 1, j >> 1);           // drop the last character
                setbit(i - 1, j & ones[i - 1]);  // drop the first character
            }
            else if (full) {  // the current string is missing, and all previous strings
                              // of this length exist, so this is the FIRST missing string
                ans = j;
                len = i;
                full = false;
            }
        }
    }

    print_binary(len, ans);
    return 0;
}

void setbit(int n, unsigned int x) {
    // x is stored in int number x/32 (= x >> 5), at bit x%32 (= x & 31) counted from the left
    // shifts count from the right (least significant) end, so we shift 1 left by 31 - x%32
    // that gives a number whose bit x%32 is 1 and all others 0; OR-ing it into the bitmap records x
    // equivalent to: bitmap[n][x/32] |= (1u << (31 - x%32));
    // written with explicit parentheses: `31 - x&31` parses as `(31 - x) & 31`,
    // which only happens to give the same value; 1u avoids shifting into the sign bit of an int
    bitmap[n][x >> 5] |= (1u << (31 - (x & 31)));
}

bool checkbit(int n, unsigned int x) {
    // bitmap[n][x/32] & (1u << (31 - x%32)) depends only on bit x%32 of this int chunk:
    // if that bit is 0 the expression is 0 (false); if it is 1 the expression is nonzero (true)
    return bitmap[n][x >> 5] & (1u << (31 - (x & 31)));
}

void print_binary(int n, unsigned int x) {
    // we can only peel bits off the right (least significant) end of an int,
    // but we must print from left to right, so we store the bits right to left and print them in reverse
    int ans[25]; int m = 0;
    while (x != 0) {
        ans[m] = x & 1;
        x = x >> 1;
        m++;
    }

    // pad with leading 0s
    for (int i = m; i < n; i++) {
        ans[i] = 0;
    }

    for (int i = n - 1; i >= 0; i--) {
        printf("%d", ans[i]);
    }
    printf("\n");
}

Complexity Analysis

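Reading a string of length n takes O(n).

If n > 24 we start from bitmap24, otherwise from bitmap(n); each level holds half as many strings as the one above it, so the scan takes $O(2^{\min(24,n)})$ in total.

Overall this is $O(n + 2^{\min(24,n)})$; when n is on the order of 2^24, it is still O(n). The bitmaps occupy about $\sum_{i=0}^{24} 2^i < 2^{25}$ bits, roughly 4 MB.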

Reference

  1. 用C++实现bitmap (Implementing a bitmap in C++)

3.3 Kth

Algorithm

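The task is to find, among all triples formed by taking one element from each of the arrays a, b, c, the triple whose sum is the k-th smallest. Observe that if a, b, c are each sorted in ascending order, then a[i]+b[j]+c[l] ≤ a[i+1]+b[j]+c[l], a[i]+b[j]+c[l] ≤ a[i]+b[j+1]+c[l], and a[i]+b[j]+c[l] ≤ a[i]+b[j]+c[l+1]: increasing any one index never decreases the sum.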

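So we keep a priority queue ordered by sum: each time (i, j, l) is popped, we push (i+1, j, l), (i, j+1, l) and (i, j, l+1). The k-th triple popped is the one we are looking for. This turns "find the k-th smallest" into a best-first traversal of a 3D grid.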
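A baseline sketch of this idea, using std::priority_queue and a std::set of already-queued triples to avoid pushing a triple twice (the vis approach, which does not scale to 500000^3 triples). It assumes the arrays are plain std::vector<long long> values already sorted ascending; the assignment instead hides them behind compare(), so this only illustrates the search order. The function name kth_with_visited is hypothetical; indices are 0-based and k is 1-based.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <set>
#include <tuple>
#include <vector>

std::tuple<int, int, int> kth_with_visited(const std::vector<long long>& a,
                                           const std::vector<long long>& b,
                                           const std::vector<long long>& c, int k) {
    const int n = static_cast<int>(a.size());
    assert(n > 0 && b.size() == a.size() && c.size() == a.size());
    assert(k >= 1 && k <= 1LL * n * n * n);

    using Item = std::tuple<long long, int, int, int>;  // (sum, x, y, z)
    std::priority_queue<Item, std::vector<Item>, std::greater<Item>> pq;  // smallest sum on top
    std::set<std::tuple<int, int, int>> seen;  // triples that have already been queued

    pq.emplace(a[0] + b[0] + c[0], 0, 0, 0);
    seen.insert(std::make_tuple(0, 0, 0));
    for (int popped = 1;; popped++) {
        Item top = pq.top();
        pq.pop();
        int x = std::get<1>(top), y = std::get<2>(top), z = std::get<3>(top);
        if (popped == k) return std::make_tuple(x, y, z);

        const int nx[3] = {x + 1, x, x}, ny[3] = {y, y + 1, y}, nz[3] = {z, z, z + 1};
        for (int d = 0; d < 3; d++) {
            if (nx[d] >= n || ny[d] >= n || nz[d] >= n) continue;
            if (!seen.insert(std::make_tuple(nx[d], ny[d], nz[d])).second) continue;  // already queued
            pq.emplace(a[nx[d]] + b[ny[d]] + c[nz[d]], nx[d], ny[d], nz[d]);
        }
    }
}
```

With a = {1, 2}, b = {10, 20}, c = {100, 200} all eight sums are distinct, and k = 4 yields indices (1, 1, 0), i.e. 2 + 20 + 100 = 122.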

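In the implementation we must not let the same point enter the queue more than once. A vis array would solve this, but each array can have up to 500000 elements, and a 3D vis array would need 500000^3 entries, far beyond any memory limit. Instead we choose a traversal order in which every point is reached exactly once. Start with the simplest, one-dimensional case: along the x axis we just keep visiting the next element, i, i+1, i+2, .... Two dimensions are just many copies of the one-dimensional case: we reach (i, j) via (0, j), (1, j), ..., (i-1, j). How do we reach (0, j)? By the one-dimensional expansion from (0, 0) along y. In other words, when x = 0 we expand in both the x and y directions, and when x ≠ 0 we expand only in the x direction.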

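For three dimensions, picture the positive x, y, z directions as right, forward and down. We always expand along x; we expand along y only when x = 0; and we expand along z only when x = 0 and y = 0. Under this rule every point has exactly one parent ((x-1, y, z) if x > 0; otherwise (0, y-1, z) if y > 0; otherwise (0, 0, z-1)), so no point is pushed twice. Moreover, a parent's sum is never larger than its child's, and we always expand from the highest-priority point on the frontier, so points are still popped in order of priority.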
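The visit-once rule can be sketched without any visited set. As an illustration only, it assumes the arrays are plain std::vector<long long> values already sorted ascending (the assignment hides them behind compare()); the function name kth_tree is hypothetical, indices are 0-based and k is 1-based.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <tuple>
#include <vector>

std::tuple<int, int, int> kth_tree(const std::vector<long long>& a,
                                   const std::vector<long long>& b,
                                   const std::vector<long long>& c, int k) {
    const int n = static_cast<int>(a.size());
    assert(n > 0 && b.size() == a.size() && c.size() == a.size());
    assert(k >= 1 && k <= 1LL * n * n * n);

    using Item = std::tuple<long long, int, int, int>;  // (sum, x, y, z)
    std::priority_queue<Item, std::vector<Item>, std::greater<Item>> pq;  // smallest sum on top
    auto push = [&](int x, int y, int z) { pq.emplace(a[x] + b[y] + c[z], x, y, z); };

    push(0, 0, 0);
    for (int popped = 1;; popped++) {
        Item top = pq.top();
        pq.pop();
        int x = std::get<1>(top), y = std::get<2>(top), z = std::get<3>(top);
        if (popped == k) return std::make_tuple(x, y, z);

        // Each triple has exactly one parent under these rules, so nothing is queued twice.
        if (x + 1 < n) push(x + 1, y, z);                      // always expand along x
        if (x == 0 && y + 1 < n) push(x, y + 1, z);            // along y only when x == 0
        if (x == 0 && y == 0 && z + 1 < n) push(x, y, z + 1);  // along z only when x == y == 0
    }
}
```

On a = {1, 2}, b = {10, 20}, c = {100, 200} it pops the eight triples in increasing order of sum, so k = 6 gives (1, 0, 1) with sum 212.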

Details

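  • Heap implementation: when sinking, first check whether each child exists (compare the child's index with the number of elements), and keep sinking only while some child does not have lower priority than the current node. If the left child exists and either the right child does not exist or the left child's priority is at least as high as the right child's, swap with the left child; otherwise, if the right child exists and has higher priority, swap with the right child.
  • 3D traversal order: when trying to expand along y, skip the expansion if x != 0; when trying to expand along z, skip it if x != 0 || y != 0. (The code below is 1-indexed, so it tests x != 1 and y != 1 instead.)
  • Sorting the arrays: the interface provided by this problem does not let us read the elements of a, b, c directly, so we create three other arrays s, u, t, where s[i] is the position in a of the i-th smallest element of a. That is, s, u, t store the indices 1…n into a, b, c. To build s we sort s itself (with qsort in the code), but the comparator compares the elements of a that the entries of s point to. (In the code, these index arrays are themselves named a, b, c.)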
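The index-sorting trick can be sketched with std::sort and a comparator that only ever sees positions. This assumes a hypothetical less_at(i, j) oracle standing in for compare(); the post itself uses qsort with C-style comparators, and this sketch is 0-based.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Sorts positions 0..n-1 of a hidden array, given only a strict "less than"
// oracle on two positions. Afterwards order[i] is the position of the i-th
// smallest element.
std::vector<int> sorted_positions(int n, const std::function<bool(int, int)>& less_at) {
    assert(n >= 0);
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);        // identity permutation 0, 1, ..., n-1
    std::sort(order.begin(), order.end(), less_at);  // compares positions via the oracle
    return order;
}
```

For a hidden array {30, 10, 20}, the result is {1, 2, 0}: position 1 holds the smallest value.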

Code

#define _CRT_SECURE_NO_WARNINGS  // MSVC only; it must come before the headers to take effect
#include "kth.h"
#include <cstdio>
#include <cstdlib>   // qsort
#include <iostream>
#include <utility>   // swap
using namespace std;
const int N = 500007, K = 2000003;
int dir[3][3] = { {1,0,0}, {0,1,0}, {0,0,1} };

// sort the x-axis by comparing sums that differ only in the x index
// (compare(...) == 1 is treated as "the first sum is smaller", so this sorts in ascending order)
int sortx_cmp(const void* a, const void* b) {
    if (compare(*(int*)a, 1, 1, *(int*)b, 1, 1) == 1) return -1;
    else if (compare(*(int*)b, 1, 1, *(int*)a, 1, 1) == 1) return 1;
    else return 0;
}

// same for the y-axis
int sorty_cmp(const void* a, const void* b) {
    if (compare(1, *(int*)a, 1, 1, *(int*)b, 1) == 1) return -1;
    else if (compare(1, *(int*)b, 1, 1, *(int*)a, 1) == 1) return 1;
    else return 0;
}

// same for the z-axis
int sortz_cmp(const void* a, const void* b) {
    if (compare(1, 1, *(int*)a, 1, 1, *(int*)b) == 1) return -1;
    else if (compare(1, 1, *(int*)b, 1, 1, *(int*)a) == 1) return 1;
    else return 0;
}

struct triple {
    int x, y, z;

    triple() {
        x = 0; y = 0; z = 0;
    }

    triple(int a, int b, int c) {
        x = a; y = b; z = c;
    }

    triple(const triple& from) {
        this->x = from.x;
        this->y = from.y;
        this->z = from.z;
    }
};
// myPQ is the storage of my priority queue (1-indexed);
// after k - 1 extractions it holds at most 2k - 1 triples
triple myPQ[K * 2];
const triple INF = triple(100000000, 100000000, 100000000);  // filler written into vacated slots

// a, b, c hold positions (1..n) into the hidden arrays of the problem,
// sorted so that a[i] is the position of the i-th smallest element
int a[N], b[N], c[N];

inline bool operator<(const triple& t1, const triple& t2) {
    return compare(a[t1.x], b[t1.y], c[t1.z], a[t2.x], b[t2.y], c[t2.z]);
}

inline bool operator>(const triple& t1, const triple& t2) {
    return compare(a[t2.x], b[t2.y], c[t2.z], a[t1.x], b[t1.y], c[t1.z]);
}

// refactored PQ that only uses strict less-than/greater-than, to be consistent with the compare function
class PriorityQueue {

    int n = 0;
    triple* a = myPQ;  // note: inside the class this member shadows the global array a

public:
    void add(triple x)
    {
        a[++n] = x;
        swim(n);
    }
    triple extract()
    {
        if (n == 0) throw "Nothing to extract";
        triple result = a[1];
        swap(a[1], a[n]);
        a[n--] = INF;
        sink(1);
        return result;
    }

    bool isEmpty() { return n == 0; }

    void print() {
        for (int i = 1; i <= n; i++) {
            printf("%d in heap: (%d, %d, %d)\n", i, ::a[a[i].x], b[a[i].y], c[a[i].z]);
        }
    }

private:
    void swim(int i)
    {
        while (i > 1 && !(a[i / 2] < a[i])) {
            swap(a[i / 2], a[i]);
            i = i / 2;
        }
    }

    void sink(int i)
    {
        int l = i * 2, r = i * 2 + 1;
        // keep sinking while some child does not rank strictly after the current node
        while ((l <= n && !(a[i] < a[l])) || (r <= n && !(a[i] < a[r]))) {
            if (l <= n && (r > n || !(a[l] > a[r]))) { // l is in the heap and (r is not, or l ranks no later than r)
                swap(a[i], a[l]);
                i = l; l = i * 2; r = i * 2 + 1; continue;
            }
            else if (r <= n && a[l] > a[r]) {  // test r <= n first so a[r] is only read inside the heap
                swap(a[i], a[r]);
                i = r; l = i * 2; r = i * 2 + 1; continue;
            }
            else return;
        }
    }
};

void get_kth(int n, int k, int *x, int *y, int *z) {

    // start from the identity permutation 0, 1, ..., n (index 0 is unused)
    for (int i = 0; i <= n; i++) {
        a[i] = b[i] = c[i] = i;
    }

    // sort positions 1..n of each array by the values they point to
    qsort(a + 1, n, sizeof(int), sortx_cmp);
    qsort(b + 1, n, sizeof(int), sorty_cmp);
    qsort(c + 1, n, sizeof(int), sortz_cmp);

    PriorityQueue q;
    q.add(triple(1, 1, 1));
    for (int i = 1; i < k; i++) { // extract k-1 triples
        triple now = q.extract();
        int nowx = now.x, nowy = now.y, nowz = now.z;

        for (int j = 0; j < 3; j++) {

            int nextx = nowx + dir[j][0], nexty = nowy + dir[j][1], nextz = nowz + dir[j][2];

            if (nextx > n || nexty > n || nextz > n) continue;
            // expand along y only when x == 1, and along z only when x == 1 and y == 1 (1-indexed)
            if ((j == 1 && nowx != 1) || (j == 2 && (nowx != 1 || nowy != 1))) continue;

            q.add(triple(nextx, nexty, nextz));
        }
    }
    triple result = q.extract();  // the k-th triple
    *x = a[result.x];
    *y = b[result.y];
    *z = c[result.z];
}

Complexity Analysis

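There are three arrays of n elements each, and we want the k-th triple. Sorting the three arrays takes O(n log n). Each time a triple leaves the priority queue, at most three enter it; this happens k times, so the queue never holds more than about 2k triples and each operation costs O(log k), for O(k log k) in total. The total time is O(n log n + k log k).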

3.4 Component

Algorithm

Mergeable heaps: the leftist tree (leftist heap).

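Every query asks for the weight of the k-th largest vertex in some connected component, where k is the same for the whole input. The k-th largest element is the smallest of the k largest elements. So if we maintain a min-heap for each component that holds at most k elements (whenever it holds more than k, we pop the minimum until k remain), those elements are exactly the k largest weights in the component, and the heap top answers the query. If the heap holds fewer than k elements, the component has fewer than k vertices and we output -1.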
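For a single component, the capped min-heap idea can be sketched with std::priority_queue, assuming the component's weights are simply given as a vector. The function name is hypothetical; the real problem grows components by merging them, which is why the post needs a mergeable heap.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <vector>

// Returns the k-th largest weight, or -1 if there are fewer than k weights.
int kth_largest_or_minus1(const std::vector<int>& weights, int k) {
    assert(k >= 1);
    std::priority_queue<int, std::vector<int>, std::greater<int>> heap;  // min-heap
    for (int w : weights) {
        heap.push(w);
        if (static_cast<int>(heap.size()) > k) heap.pop();  // drop the smallest
    }
    return static_cast<int>(heap.size()) < k ? -1 : heap.top();
}
```

For weights {5, 1, 9, 3} and k = 2 the heap ends as {5, 9}, so the answer is 5; with only two weights and k = 3 the answer is -1.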

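When a new edge (u, v) connects two components that were previously disconnected, their two heaps must be merged. We need a priority queue that supports fast merging, and we choose the leftist heap. Since connecting the components means merging the heaps that u and v belong to, we must be able to find efficiently which heap u (or v) is in, i.e., the root of that heap; we store this information in a union-find (disjoint-set) structure.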
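A minimal standalone union-find sketch with path compression, written iteratively to avoid deep recursion. This is an illustrative variant: the post's code folds the union-find parent into its father[] array and makes the heap root itself the representative, whereas here the representative is arbitrary, so one would have to store each set's heap root next to it (for example in a hypothetical heap_root[] indexed by representative).

```cpp
#include <cassert>
#include <numeric>
#include <vector>

struct DisjointSets {
    std::vector<int> parent;

    explicit DisjointSets(int n) : parent(n) {
        std::iota(parent.begin(), parent.end(), 0);  // every element starts as its own set
    }

    // Returns the representative of x's set, compressing the path on the way.
    int find(int x) {
        assert(x >= 0 && x < static_cast<int>(parent.size()));
        int r = x;
        while (parent[r] != r) r = parent[r];
        while (parent[x] != r) {
            int next = parent[x];
            parent[x] = r;
            x = next;
        }
        return r;
    }

    // Joins the sets of u and v; returns false if they were already joined.
    bool unite(int u, int v) {
        int ru = find(u), rv = find(v);
        if (ru == rv) return false;
        parent[rv] = ru;
        return true;
    }
};
```

The false return of unite corresponds to the "already connected, skip this edge" check in the post's main loop.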

Details

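Every vertex has a number, so instead of building the priority queue as a traditional class, we store each vertex's information directly in arrays indexed by vertex number. This is faster and makes the fields easier to access.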
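An array-based leftist min-heap in the same spirit, as a sketch: node i's fields live in parallel vectors, node 0 is the null node with null path length 0, and a leaf has null path length 1. The struct name LeftistForest and its methods are illustrative; it omits the union-find and size bookkeeping that the post's code also keeps.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Leftist min-heaps over nodes 1..n; index 0 is the null node.
struct LeftistForest {
    std::vector<int> val, lc, rc, npl;

    explicit LeftistForest(const std::vector<int>& values)
        : val(values.size() + 1, 0), lc(values.size() + 1, 0),
          rc(values.size() + 1, 0), npl(values.size() + 1, 1) {
        npl[0] = 0;  // the null node
        for (std::size_t i = 0; i < values.size(); i++) val[i + 1] = values[i];
    }

    // Merges the heaps rooted at a and b (0 = empty heap) and returns the new root.
    int merge(int a, int b) {
        if (a == 0) return b;
        if (b == 0) return a;
        if (val[b] < val[a]) std::swap(a, b);  // a keeps the smaller root
        rc[a] = merge(rc[a], b);
        if (npl[lc[a]] < npl[rc[a]]) std::swap(lc[a], rc[a]);  // keep the leftist property
        npl[a] = npl[rc[a]] + 1;
        return a;
    }

    // Removes the root and returns the root of the remaining heap (0 if it is now empty).
    int pop(int root) {
        assert(root != 0);
        return merge(lc[root], rc[root]);
    }
};
```

Merging the singleton heaps of {5, 3, 8, 1, 7} one by one and then popping repeatedly yields 1, 3, 5, 7, 8.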

Code

#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <iomanip>
#include <utility>  // swap
using namespace std;

const bool DEBUG = false;
const int N = 1000007;
int n, m, k, q;

// We use a min-heap to store the points
// delMax/getMax refer to the "max priority" element, which has the smallest value

// value[i] is the value of point i
// father[i] points toward the root of the heap i belongs to (union-find parent); find(i) returns that root
// lchild[i], rchild[i] are the left and right children of point i in its heap (0 means null)
// npl[i] is the null path length of node i (0 for the null node 0, 1 for a leaf)
// sze[i] is the size of the subtree rooted at i (for a root, the size of its heap)
int value[N], father[N], lchild[N], rchild[N], npl[N], sze[N];

// find(x) returns the root of the heap x belongs to
inline int find(int x) {
    return x == father[x] ? x : (father[x] = find(father[x]));
}

// merge the heaps rooted at a and b, and return the root of the merged heap
int merge(int a, int b) {
    if (a == 0) return b;
    if (b == 0) return a;
    if (value[a] > value[b]) swap(a, b);  // a keeps the smaller root

    rchild[a] = merge(rchild[a], b);
    father[rchild[a]] = find(a);
    if (lchild[a] == 0 || npl[lchild[a]] < npl[rchild[a]]) {  // restore the leftist property
        int temp = lchild[a];
        lchild[a] = rchild[a];
        rchild[a] = temp;
    }
    npl[a] = rchild[a] == 0 ? 1 : npl[rchild[a]] + 1;
    sze[a] = sze[lchild[a]] + sze[rchild[a]] + 1;

    return a;
}

// getMax(x) returns the value at the root of heap x
// Requires: x is the root of a heap
int getMax(int x) {
    // x is the root and the root has the max priority, so we just return the value of x
    return value[x];
}

// delMax(x) deletes the root of heap x and returns the deleted value
// Requires: x is the root of a heap
int delMax(int x) {
    int ans = value[x];
    sze[x] -= 1;
    int new_root = merge(lchild[x], rchild[x]);
    father[new_root] = new_root; // the new root is now a root, so its father is itself
    father[x] = new_root;        // the deleted node, and every node pointing to it, now leads to the new root
    return ans;
}

// deletes the Max element until this heap has no more than k elements
void prune(int x) {
    if (sze[x] <= k) return;
    delMax(find(x)); // needs find(x) because delMax requires a root
    prune(find(x));  // x itself may have been deleted; find(x) now leads to the new root
}

// print all points and their information (only when DEBUG is on)
void print() {
    if (!DEBUG) return;
    cout << "# value Parent Lchild Rchild npl size " << endl;
    for (int i = 1; i <= n; i++) {
        cout << setw(2) << i;
        cout << setw(6) << value[i];
        cout << setw(6) << father[i];
        cout << setw(8) << lchild[i];
        cout << setw(8) << rchild[i];
        cout << setw(6) << npl[i];
        cout << setw(6) << sze[i];
        cout << endl;
    }
}

int main() {
    scanf("%d%d%d%d", &n, &m, &k, &q);
    for (int i = 1; i <= n; i++) {
        scanf("%d", value + i);
        father[i] = i;
        lchild[i] = rchild[i] = 0;  // children point to null
        npl[i] = 1;                 // a single node is a leaf
        sze[i] = 1;
    }
    // node 0 is the null node (the original 10e9 here overflowed int)
    sze[0] = 0; npl[0] = 0; father[0] = lchild[0] = rchild[0] = 0;

    for (int i = 1; i <= m; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        if (find(a) == find(b)) continue; // already connected, another edge makes no difference
        int merged = merge(find(a), find(b));
        prune(merged);
    }

    for (int i = 1; i <= q; i++) {
        int op, a, b;
        scanf("%d", &op);
        if (op == 1) {
            scanf("%d%d", &a, &b);
            if (find(a) == find(b)) continue;
            int merged = merge(find(a), find(b));
            prune(merged);
        }
        else if (op == 2) {
            scanf("%d", &a);
            if (sze[find(a)] < k) printf("-1\n");
            else printf("%d\n", getMax(find(a)));
        }
    }
}

Complexity Analysis

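After initialization, the only heap operations are merges. An edge or an op-1 query whose endpoints are already connected does nothing, so at most n - 1 merges actually join two components. Each vertex is popped at most once (when it drops out of the top k of its component), and each pop is itself a merge of two leftist heaps. A leftist-heap merge costs O(log n), and find with path compression costs O(log n) amortized, so with m edges and q queries the total is O((n + m + q) log n).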

Reference

  1. 题解 P3377 【模板】左偏树(可并堆) (solution write-up for problem P3377, "[Template] Leftist Tree (Mergeable Heap)")
  2. Course code