cancan123456 @ 2021-11-14 19:26:23
如题，测试点 #1、#5、#7、#14、#17 WA（答案错误），求助。
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
// Capacity of every per-node array; nodes are indexed 0..m (node 0 is the
// virtual caller holding the top-level calls).
constexpr int maxn = 100005;
// 64-bit integer type used for all modular arithmetic below.
using ll = long long;
// Modulus applied to every stored product and sum.
constexpr ll mod = 998244353;
// a[i]: the i-th element of the sequence (1-based).
ll a[maxn];
// G1: forward call graph, one edge caller -> callee per call site
// (node 0 is the virtual caller whose callees are the top-level calls).
std::vector<int> G1[maxn];
// degree1[v]: number of G1 edges entering v, i.e. how many call sites invoke v.
int degree1[maxn];

/// Record one call site "u calls v" in the forward call graph.
void add_edge1(int u, int v) {
    std::vector<int>& out_edges = G1[u];
    out_edges.emplace_back(v);
    degree1[v] += 1;
}
// G2: reversed call graph; main adds one edge callee -> caller per call site.
std::vector<int> G2[maxn];
// degree2[v]: number of G2 edges entering v, i.e. how many call sites v itself contains.
int degree2[maxn];

/// Record one reversed call-graph edge u -> v (u is called by v).
void add_edge2(int u, int v) {
    std::vector<int>& out_edges = G2[u];
    out_edges.emplace_back(v);
    degree2[v] += 1;
}
// Per-function data read from input, indexed 1..m.
int T[maxn];  // function type: 1 = add V[i] to a[P[i]], 2 = scale the sequence by mul[i], otherwise = call a list of functions
int P[maxn];  // type 1: target position in a
int C[maxn];  // type 3: number of callees (main also stores Q in C[0])
// NOTE(review): each g row has only 55 slots (indices 0..54); any writer
// must keep the column index below 55 or it writes past the row.
int g[maxn][55];
// mul[i]: total factor by which the whole sequence is multiplied once
// function i has finished executing (mod `mod`).
ll mul[maxn];
ll V[maxn];   // type 1: value added at position P[i]
int n;        // length of the sequence a
int m;        // number of functions
/// Fold every function's multiplication factor into its callers.
///
/// Visits nodes 0..m in topological order of G2 (callee -> caller), starting
/// from nodes whose degree2 is zero (they contain no calls). When a node is
/// popped, its mul is final. It is multiplied into every caller once per
/// call site, so each caller ends with its initial mul times the final mul of
/// each direct callee. degree2 is decremented as edges are consumed.
void topo1() {
    std::queue<int> ready;
    for (int node = 0; node <= m; ++node) {
        if (!degree2[node]) ready.push(node);
    }
    while (!ready.empty()) {
        const int callee = ready.front();
        ready.pop();
        for (const int caller : G2[callee]) {
            mul[caller] = mul[caller] * mul[callee] % mod;
            if (--degree2[caller] == 0) ready.push(caller);
        }
    }
}
// cnt_of_calls[v]: once the multiplication (type 2) operations are factored
// out, how many times function v is effectively invoked (mod `mod`).
ll cnt_of_calls[maxn];

/// Push effective call counts down the call graph.
///
/// Node 0 (the top-level call list) starts with count 1. Nodes are visited in
/// topological order of G1 (caller -> callee). A caller's callees are walked
/// from its last call back to its first. Multiplication tags accumulate in
/// reverse: every multiplication done by a later call scales what an earlier
/// call added. So the running suffix product is applied to a callee before
/// that callee's own factor mul[callee] is folded into it.
void topo2() {
    cnt_of_calls[0] = 1;
    std::queue<int> ready;
    for (int node = 0; node <= m; ++node) {
        if (!degree1[node]) ready.push(node);
    }
    while (!ready.empty()) {
        const int caller = ready.front();
        ready.pop();
        const std::vector<int>& callees = G1[caller];
        long long suffix = 1;
        for (auto it = callees.rbegin(); it != callees.rend(); ++it) {
            const int callee = *it;
            cnt_of_calls[callee] = (cnt_of_calls[callee] + cnt_of_calls[caller] * suffix) % mod;
            suffix = suffix * mul[callee] % mod;
            if (--degree1[callee] == 0) ready.push(callee);
        }
    }
}
/// Read the sequence, the m function definitions and the Q top-level calls,
/// then print the final sequence.
///
/// Node 0 is a virtual function whose body is the Q top-level calls. Once
/// topo1/topo2 have run:
///   - the whole sequence is scaled by mul[0];
///   - each type-1 function i adds V[i] * cnt_of_calls[i] at position P[i].
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%lld", a + i);
        a[i] %= mod;  // keep a[i] * mul[0] within long long
    }
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        scanf("%d", T + i);
        if (T[i] == 1) {
            mul[i] = 1;  // an addition does not scale the sequence
            scanf("%d %lld", P + i, V + i);
            V[i] %= mod;  // keep V[i] * cnt_of_calls[i] within long long
        } else if (T[i] == 2) {
            scanf("%lld", mul + i);
            mul[i] %= mod;
        } else {
            mul[i] = 1;  // own factor is 1; topo1 folds in the callees' factors
            scanf("%d", C + i);
            for (int j = 1; j <= C[i]; j++) {
                // Read into a local rather than g[i][j]: each g row has only
                // 55 slots, while C[i] may be larger. The callee id is only
                // needed to build the two edges below.
                int callee;
                scanf("%d", &callee);
                add_edge1(i, callee);
                add_edge2(callee, i);
            }
        }
    }
    int Q;
    scanf("%d", &Q);
    C[0] = Q;
    // Set independently of the call loop so that Q == 0 still gives the
    // virtual root a neutral factor instead of zeroing the sequence.
    mul[0] = 1;
    for (int i = 1; i <= Q; i++) {
        int f;  // local instead of g[0][i]: g[0] has only 55 slots but Q can exceed that
        scanf("%d", &f);
        add_edge1(0, f);
        add_edge2(f, 0);
    }
    topo1();
    topo2();
    for (int i = 1; i <= n; i++) {
        a[i] = a[i] * mul[0] % mod;
    }
    // Functions are numbered 1..m inclusive. The previous bound `i < m`
    // skipped function m, so its additions were lost.
    for (int i = 1; i <= m; i++) {
        if (T[i] == 1) {
            a[P[i]] = (a[P[i]] + V[i] * cnt_of_calls[i]) % mod;
        }
    }
    for (int i = 1; i <= n; i++) {
        printf("%lld ", a[i]);
    }
    return 0;
}
by cancan123456 @ 2021-11-15 20:34:41
找到原因了：累加操作 1 贡献的那个循环必须一直遍历到第 m 个函数（包含 m），正确写法是
for (int i = 0; i <= m; i++) {
if (T[i] == 1) {
a[P[i]] = (a[P[i]] + V[i] * cnt_of_calls[i]) % mod;
}
}
而我误写成了下面这样（`i < m` 漏掉了第 m 个函数，导致它的加法贡献没有被计入）：
for (int i = 0; i < m; i++) {
if (T[i] == 1) {
a[P[i]] = (a[P[i]] + V[i] * cnt_of_calls[i]) % mod;
}
}