-
Notifications
You must be signed in to change notification settings - Fork 4
/
trie.h
75 lines (75 loc) · 1.43 KB
/
trie.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
using namespace std;
// Number of child slots per trie node: one per lowercase letter 'a'..'z'.
const int K = 26;

// Maps a letter to its child slot by its offset from 'a'
// ('a' -> 0, ..., 'z' -> 25 for ASCII input).
// No range check is done here: a character outside 'a'..'z' yields a value
// outside [0, K), so callers must validate before using it as an index.
inline int getId(char c) {
    const int base = 'a';
    return static_cast<int>(c) - base;
}
// One trie node; nodes are stored by index in Trie::trie.
struct Vertex {
    int next[K];   // child node index for each letter id, or -1 if no child
    int leaf;      // number of stored strings that end exactly at this node
    int count;     // number of stored strings whose path passes through this node

    // A fresh node has no children and no strings through or ending at it.
    Vertex() : leaf(0), count(0) {
        for (int &child : next)
            child = -1;
    }
};
struct Trie{
vector<Vertex> trie;
Trie(){
trie.emplace_back();
}
void add(string const& s) {
int v = 0;
trie[v].count++;
for(char ch: s) {
int c = getId(ch);
if (trie[v].next[c] == -1) {
trie[v].next[c] = trie.size();
trie.emplace_back();
}
v = trie[v].next[c];
trie[v].count++;
}
trie[v].leaf++;
}
int countStr(string const& s) {
int v = 0;
for (char ch : s) {
int c = getId(ch);
if (trie[v].next[c] == -1)
return 0;
v = trie[v].next[c];
}
return trie[v].leaf;
}
int countPre(string const& s) {
int v = 0;
for (char ch : s) {
int c = getId(ch);
if (trie[v].next[c] == -1)
return 0;
v = trie[v].next[c];
}
return trie[v].count;
}
bool remove(string const& s) {
vector<int> rm;
int v = 0;
rm.push_back(v);
for(char ch: s) {
int c = getId(ch);
if (trie[v].next[c] == -1)
return false;
v = trie[v].next[c];
rm.push_back(v);
}
if(trie[v].leaf > 0){
trie[v].leaf--;
for(int x: rm)
trie[x].count--;
return true;
}
return false;
}
};