rainygame @ 2023-08-07 19:53:50
过了中间四个点。
思路:
对于每个前缀,用哈希存一遍。可以发现总共最多只会存
储存我是用双哈希,因为直接比较字符串的复杂度开销很大。查询的期望时间复杂度是 set
去重),其中我设
清空我只清空用过的,顶多就只会清空 list
中的顶多
总的时间复杂度应该是
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int MOD(1e6+7);
const int BASE1(131);
const int BASE2(13331);
int t, n, q, ha1, base1;
unsigned int ha2, base2;
string str, tmp;
list<pair<unsigned int, int>> li[MOD];
vector<int> vec;
void clear(){
for (int i: vec) li[i].clear();
vec.clear();
}
void insert(string str, int ha1, int ha2, int ind){
vec.push_back(ha1);
li[ha1].push_back({ha2, ind});
}
int query(string str){
set<int> st;
ha1 = ha2 = 0;
base1 = base2 = 1;
for (char i: str){
ha1 = (ha1 + i * base1) % MOD;
base1 = (base1 * BASE1) % MOD;
ha2 += i * base2;
base2 *= BASE2;
}
for (auto i: li[ha1]){
if (i.first == ha2) st.insert(i.second);
}
return st.size();
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> t;
while (t--){
cin >> n >> q;
clear();
for (int i(1); i<=n; ++i){
cin >> str;
tmp = "";
ha1 = ha2 = 0;
base1 = base2 = 1;
for (char j: str){
ha1 = (ha1 + j * base1) % MOD;
base1 = (base1 * BASE1) % MOD;
ha2 += j * base2;
base2 *= BASE2;
tmp += j;
insert(tmp, ha1, ha2, i);
}
}
while (q--){
cin >> str;
cout << query(str) << '\n';
}
}
return 0;
}
by Iniaugoty @ 2023-08-07 20:29:21
@rainygame 不好意思记错题了……不过刚现写了一个map
的,能过,R119552748
by rainygame @ 2023-08-07 20:41:55
@esquigybcu 换了,更慢
明明 list
的常数比 vector
好呀
而且哈希表就是用链表来实现的呀
by rainygame @ 2023-08-07 20:45:16
@gty314159 unordered_map
都比我快
呜呜呜呜呜
by rainygame @ 2023-08-07 20:55:27
卡常路漫漫
by litjohn @ 2024-07-01 20:15:12
@rainygame 其实你可以对相同长度的hash排序,查询时二分查找:
#include <bits/stdc++.h>
using namespace std;
constexpr unsigned long long mod = (1ll << 61) - 1, base = 331;
int t, n, q;
vector<vector<unsigned long long>> h, tmp;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> t;
for (int i = 0; i < t; ++i) {
h.clear();
tmp.clear();
tmp.resize(1000);
cin >> n >> q;
h.resize(n + 1);
basic_string<char> s;
for (int j = 1; j <= n; ++j) {
cin >> s;
unsigned long long res = 0;
for (auto l: s) {
res = (res * base + l) % mod;
h[j].push_back(res);
}
}
for (int j = 1; j <= n; ++j) {
tmp[0].push_back(h[j][0]);
}
for (int j = 1; !tmp[j - 1].empty(); ++j) {
tmp.emplace_back();
for (int k = 1; k <= n; ++k) {
if (j < h[k].size()) {
tmp[j].push_back(h[k][j]);
}
}
}
for (auto &j: tmp) {
sort(j.begin(), j.end());
}
for (int j = 0; j < q; ++j) {
cin >> s;
int ans = 0;
unsigned long long num = 0;
for (auto k: s) {
num = (num * base + k) % mod;
}
if (tmp.size() >= s.length()) {
ans = upper_bound(tmp[s.length() - 1].begin(), tmp[s.length() - 1].end(), num) -
lower_bound(tmp[s.length() - 1].begin(),
tmp[s.length() -
1].end(), num);
}
cout << ans << "\n";
}
}
return 0;
}