Trie字典树基础

一些Trie基础(主要是异或Trie)

查找字符串Trie

直接放模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
struct trie
{
ll nex[maxn][26];
ll cnt;
bool exist[maxn];
void init()
{
memset(nex[0], 0, sizeof(nex[0]));
exist[cnt] = 0;
cnt = 0;
}
void insert(string s)
{
ll now = 0;
for(auto&ch:s)
{
if(nex[now][ch]==0)
{
cnt++;
memset(nex[cnt], 0, sizeof(nex[cnt]));
exist[cnt] = 0;
nex[now][ch] = cnt;
}
now = nex[now][ch];
}
if(s.size())
{
exist[now] = 1;
}
}
void find(string t)
{
ll now = 0;
for(auto&ch:t)
{
if(nex[now][ch]==0)
{
return 0;
}
now = nex[now][ch];
}
return exist[now];
}
}

其中,maxn一般是所有字符串的总长度(即最多能拓展出多少节点)

此处设置根节点是0,exist表示:存在至少一个字符串,其等于当前节点所表示的字符串前缀

插入和寻找的复杂度都是字符串长度

AC自动机

即在trie树上进行kmp

直接上模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
namespace AC
{
ll tr[maxn][26], num[maxn], fail[maxn];
ll tot;
void clear(ll x)
{
num[x] = 0;
for (ll i = 0; i < 26; i++)
{
tr[x][i] = 0;
}
}
void init()
{
tot = 0;
clear(tot);
}
void insert(string s)
{
ll u = 0;
for (auto &ch : s)
{
if (!tr[u][ch - 'a'])
{
tr[u][ch - 'a'] = ++tot;
clear(tot);
}
u = tr[u][ch - 'a'];
}
num[u]++;
}
queue<ll> q;
void build()
{
for (ll i = 0; i < 26; i++)
{
if (tr[0][i])
{
q.push(tr[0][i]);
}
}
ll u;
while (q.size())
{
u = q.front(), q.pop();
for (ll i = 0; i < 26; i++)
{
if (tr[u][i])
{
fail[tr[u][i]] = tr[fail[u]][i];
q.push(tr[u][i]);
}
else
{
tr[u][i] = tr[fail[u]][i];
}
}
}
}
ll query(string t)
{
ll u = 0, ret = 0;
for (auto &ch : t)
{
u = tr[u][ch - 'a'];
for (ll j = u; j && num[j] != -1;j=fail[j])
{
ret += num[j];
num[j] = -1;
}
}
return ret;
}
}

注意,该code中,num表示等于当前节点表示前缀的字符串数量,在ask中,为避免重复计算,计入后就记为-1,故只能ask一次(或许可采用vis来优化?记录答案后便记vis为1,并将答案压入vector,查询完毕后遍历vector来清空vis)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
//vis优化版的query
/*
bitset<maxn>vis;
ll query(string t)
{
vector<ll> tmp;
ll u = 0, ret = 0;
for (auto &ch : t)
{
u = tr[u][ch - 'a'];
for (ll j = u; j && !vis[j];j=fail[j])
{
ret += num[j];
tmp.push_back(j);
vis[j] = 1;
}
}
for(auto&e:tmp)
{
vis[e] = 0;
}
return ret;
}
*/

0-1trie

支持查询异或和、异或极值、全局加一的操作

直接上模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
namespace trie
{
const ll MAXD = 21;//01树最大深度,由数值范围决定
ll nex[maxn * (MAXD+1)][2];//存下一节点
ll w[maxn * (MAXD + 1)];
ll xorv[maxn * (MAXD + 1)];
ll tot = 0;
ll mknode()
{
++tot;
nex[tot][1] = nex[tot][0] = w[tot] = xorv[tot] = 0;
return tot;
}
void maintain(ll o)
{
w[0] = xorv[o] = 0;
if(nex[o][0])
{
w[o] += w[nex[o][0]];
xorv[o] ^= (xorv[nex[o][0]]) << 1;
}
if(nex[o][1])
{
w[o] += w[nex[o][1]];
xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
}
w[o] = w[o] & 1;
}
void insert(ll &o,ll x,ll dep)
{
if(!o)
{
o = mknode();
}
if(dep>MAXD)
{
w[o]++;
return;
}
insert(nex[o][x & 1], x >> 1, dep + 1);
maintain(o);
}
voi erase(ll o,ll x,ll dep)
{
if(dep>MAXD)
{
w[o]--;
return;
}
erase(nex[o][x & 1], x >> 1, dep + 1);
maintain(o);
}
void addall(ll o)//all+1
{
swap(nex[o][0], nex[o][1]);
if(nex[o][0])
{
addall(ch[o][0]);
}
maintain(o);
}
ll merge(ll a,ll b)
{
if(!a)
{
return b;
}
if(!b)
{
return a;
}
w[a] = w[a] + w[b];
xorv[a] ^= xorv[b];
nex[a][0] = merge(nex[a][0], nex[b][0]);
nex[a][1] = merge(nex[a][1], nex[b][1]);
return a;
}
}

xorv[i]表示以i节点作为根节点的异或和,w[i]表示该节点表示前缀数量,都用maintain维护

全局加1,是利用了+1操作相当于将011111变成100000,即0变1,1变0,1变0时会有进位,对下一位继续addall

merge是将以b为根节点的trie合并到以a为根节点的trie上,注意,由于是合并操作,所以没用maintain