This documentation is automatically generated by online-judge-tools/verification-helper
#include "convolution/ntt.hpp"
$\bmod 998244353$ でのみ動作する、計算量 $O(N\log N)$ の畳み込み。
NTTをするために必要なデータを格納している。
$n$ を ビット幅 bit_size でビットリバースする。
配列 $a$ をその場で NTT（数論変換）する。`inverse = true` のときは逆変換を行い、最後に各要素を $n$ で割る。$a$ の長さは $2$ の冪（かつ $2^{23}$ 以下）でなければならない。
$a$ と $b$ を畳み込んだ配列を返す。計算量は $O(N\log N)$。
#pragma once
#include "../template/template.hpp"
#include "../utility/modint.hpp"
namespace lib {
// Element type for every transform in this file: integers modulo 998244353.
using mint = modint998244353;
struct ntt_info {
static constexpr int rank2 = 23;
const int g = 3;
std::array<std::array<mint, rank2 + 1>, 2> root;
ntt_info() {
root[0][rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
root[1][rank2] = root[0][rank2].inv();
rrep(i, 0, rank2) {
root[0][i] = root[0][i + 1] * root[0][i + 1];
root[1][i] = root[1][i + 1] * root[1][i + 1];
}
}
};
void butterfly(std::vector<mint>& a, bool inverse) {
static ntt_info info;
int n = a.size();
int bit_size = 0;
while ((1 << bit_size) < n) bit_size++;
assert(1 << bit_size == n);
for (int i = 0, j = 1; j < n - 1; j++) {
for (int k = n >> 1; k > (i ^= k); k >>= 1);
if (j < i) {
std::swap(a[i], a[j]);
}
}
rep(bit, 0, bit_size) {
rep(i, 0, n / (1 << (bit + 1))) {
mint zeta1 = 1;
mint zeta2 = info.root[inverse][1];
mint w = info.root[inverse][bit + 1];
rep(j, 0, 1 << bit) {
int idx = i * (1 << (bit + 1)) + j;
int jdx = idx + (1 << bit);
mint p1 = a[idx];
mint p2 = a[jdx];
a[idx] = p1 + zeta1 * p2;
a[jdx] = p1 + zeta2 * p2;
zeta1 *= w;
zeta2 *= w;
}
}
}
if (inverse) {
mint inv_n = mint(n).inv();
rep(i, 0, n) a[i] *= inv_n;
}
}
std::vector<mint> convolution(const std::vector<mint>& f,
const std::vector<mint>& g) {
int n = 1;
while (n < int(f.size() + g.size() - 1)) n <<= 1;
std::vector<mint> a(n), b(n);
std::copy(f.begin(), f.end(), a.begin());
std::copy(g.begin(), g.end(), b.begin());
butterfly(a, false);
butterfly(b, false);
rep(i, 0, n) {
a[i] *= b[i];
}
butterfly(a, true);
a.resize(f.size() + g.size() - 1);
return a;
}
} // namespace lib
#line 2 "convolution/ntt.hpp"
#line 2 "template/template.hpp"
#include <bits/stdc++.h>
// rep(i, s, n): declares int i and iterates it over [s, n) in increasing order.
#define rep(i, s, n) for (int i = (int)(s); i < (int)(n); i++)
// rrep(i, s, n): declares int i and iterates it over [s, n) in decreasing order
// (from n - 1 down to s).
#define rrep(i, s, n) for (int i = (int)(n)-1; i >= (int)(s); i--)
// all(v): expands to `v.begin(), v.end()` for passing a whole container to an
// algorithm. v is not parenthesized, so pass a simple expression.
#define all(v) v.begin(), v.end()
// Short aliases for frequently used wide arithmetic types.
using ll = long long;
using ld = long double;
using ull = unsigned long long;
// Replaces a with b when a is not already <= b (i.e. b is a better minimum).
// Returns true exactly when a was overwritten.
template <typename T> bool chmin(T &a, const T &b) {
    const bool update = !(a <= b);
    if (update) {
        a = b;
    }
    return update;
}
// Replaces a with b when a is not already >= b (i.e. b is a better maximum).
// Returns true exactly when a was overwritten.
template <typename T> bool chmax(T &a, const T &b) {
    const bool update = !(a >= b);
    if (update) {
        a = b;
    }
    return update;
}
// Makes every std:: name usable unqualified inside namespace lib.
// NOTE(review): using-directives are transitive, so any code that writes
// `using namespace lib;` also pulls in all of std.
namespace lib {
using namespace std;
} // namespace lib
// using namespace lib;
#line 2 "utility/modint.hpp"
#line 4 "utility/modint.hpp"
namespace lib {
// Integer modulo the compile-time constant m, always stored reduced into [0, m).
// Requirements (not checked): m >= 1, and (m - 1) * (m - 1) must fit in
// long long, because operator*= multiplies two reduced values as ll before
// taking % m. inv(), operator/ and pow() with a negative exponent additionally
// require m to be prime (they rely on Fermat's little theorem).
template <ll m> struct modint {
    using mint = modint;
    ll a;
    // Accepts any ll, including negative values, and reduces it into [0, m).
    modint(ll x = 0) : a((x % m + m) % m) {}
    static constexpr ll mod() {
        return m;
    }
    ll val() const {
        return a;
    }
    // Direct mutable access to the stored value; the caller must keep it in [0, m).
    ll& val() {
        return a;
    }
    // Returns this^n by binary exponentiation, O(log |n|) multiplications.
    // For n < 0 returns inv()^(-n) (requires m prime). The plain loop alone
    // would never terminate for negative n: n >>= 1 on a negative value stops at -1.
    mint pow(ll n) const {
        if (n < 0) {
            const mint b = inv();
            // b^(-n) written as b^(-(n + 1)) * b so that negating n cannot
            // overflow when n == LLONG_MIN.
            return b.pow(-(n + 1)) * b;
        }
        mint res = 1;
        mint x = a;
        while (n) {
            if (n & 1) res *= x;
            x *= x;
            n >>= 1;
        }
        return res;
    }
    // Multiplicative inverse computed as a^(m-2) (Fermat); requires m prime.
    // 0 has no inverse, so the result for 0 is meaningless.
    mint inv() const {
        return pow(m - 2);
    }
    mint& operator+=(const mint rhs) {
        a += rhs.a;
        if (a >= m) a -= m;
        return *this;
    }
    mint& operator-=(const mint rhs) {
        if (a < rhs.a) a += m;
        a -= rhs.a;
        return *this;
    }
    mint& operator*=(const mint rhs) {
        a = a * rhs.a % m;
        return *this;
    }
    // Division by multiplying with rhs.inv(); requires m prime and rhs != 0.
    mint& operator/=(mint rhs) {
        *this *= rhs.inv();
        return *this;
    }
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const modint &lhs, const modint &rhs) {
        return lhs.a == rhs.a;
    }
    friend bool operator!=(const modint &lhs, const modint &rhs) {
        return !(lhs == rhs);
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
};
// Convenience aliases for two commonly used prime moduli.
using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1'000'000'007>;
} // namespace lib
#line 5 "convolution/ntt.hpp"
namespace lib {
// Element type for every transform in this file: integers modulo 998244353.
using mint = modint998244353;
struct ntt_info {
static constexpr int rank2 = 23;
const int g = 3;
std::array<std::array<mint, rank2 + 1>, 2> root;
ntt_info() {
root[0][rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
root[1][rank2] = root[0][rank2].inv();
rrep(i, 0, rank2) {
root[0][i] = root[0][i + 1] * root[0][i + 1];
root[1][i] = root[1][i + 1] * root[1][i + 1];
}
}
};
void butterfly(std::vector<mint>& a, bool inverse) {
static ntt_info info;
int n = a.size();
int bit_size = 0;
while ((1 << bit_size) < n) bit_size++;
assert(1 << bit_size == n);
for (int i = 0, j = 1; j < n - 1; j++) {
for (int k = n >> 1; k > (i ^= k); k >>= 1);
if (j < i) {
std::swap(a[i], a[j]);
}
}
rep(bit, 0, bit_size) {
rep(i, 0, n / (1 << (bit + 1))) {
mint zeta1 = 1;
mint zeta2 = info.root[inverse][1];
mint w = info.root[inverse][bit + 1];
rep(j, 0, 1 << bit) {
int idx = i * (1 << (bit + 1)) + j;
int jdx = idx + (1 << bit);
mint p1 = a[idx];
mint p2 = a[jdx];
a[idx] = p1 + zeta1 * p2;
a[jdx] = p1 + zeta2 * p2;
zeta1 *= w;
zeta2 *= w;
}
}
}
if (inverse) {
mint inv_n = mint(n).inv();
rep(i, 0, n) a[i] *= inv_n;
}
}
std::vector<mint> convolution(const std::vector<mint>& f,
const std::vector<mint>& g) {
int n = 1;
while (n < int(f.size() + g.size() - 1)) n <<= 1;
std::vector<mint> a(n), b(n);
std::copy(f.begin(), f.end(), a.begin());
std::copy(g.begin(), g.end(), b.begin());
butterfly(a, false);
butterfly(b, false);
rep(i, 0, n) {
a[i] *= b[i];
}
butterfly(a, true);
a.resize(f.size() + g.size() - 1);
return a;
}
} // namespace lib