icpc_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ebi-fly13/icpc_library

✔ test/convolution/Convolution_mod.test.cpp

Depends on: convolution/ntt.hpp, template/template.hpp, utility/modint.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include "../../convolution/ntt.hpp"
#include "../../template/template.hpp"
#include "../../utility/modint.hpp"

using mint = lib::modint998244353;

int main() {
    // Speed up bulk iostream I/O: stop synchronizing with C stdio and stop
    // flushing cout before every read from cin. Safe because this program
    // performs all I/O through iostreams.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Input: lengths n and m, then the coefficients of a and b.
    int n, m;
    std::cin >> n >> m;
    std::vector<mint> a(n), b(m);
    rep(i, 0, n) {
        int x;
        std::cin >> x;
        a[i] = x;
    }
    rep(i, 0, m) {
        int x;
        std::cin >> x;
        b[i] = x;
    }
    auto c = lib::convolution(a, b);
    // Space-separated output with a newline after the last coefficient.
    rep(i, 0, c.size()) {
        std::cout << c[i].val() << " \n"[i == int(c.size()) - 1];
    }
}
#line 1 "test/convolution/Convolution_mod.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#line 2 "convolution/ntt.hpp"

#line 2 "template/template.hpp"

#include <bits/stdc++.h>

// rep(i, s, n): loop with an int i over s, s+1, ..., n-1. The bound `n` is
// re-evaluated on every iteration, so it should be cheap and side-effect free.
#define rep(i, s, n) for (int i = static_cast<int>(s); i < static_cast<int>(n); i++)
// rrep(i, s, n): loop with an int i over n-1, n-2, ..., s. The bound `s` is
// re-evaluated on every iteration.
#define rrep(i, s, n) for (int i = static_cast<int>(n) - 1; i >= static_cast<int>(s); i--)
// all(v): expands to the iterator pair `begin, end` of v. The argument is
// parenthesized so that expressions such as `all(*p)` expand correctly.
#define all(v) (v).begin(), (v).end()

// Short aliases for frequently used arithmetic types.
using ll = long long;
using ld = long double;
using ull = unsigned long long;

// Overwrites `a` with `b` whenever `a <= b` does not hold (i.e. b is smaller);
// returns true exactly when `a` was overwritten.
template <typename T> bool chmin(T &a, const T &b) {
    const bool improves = !(a <= b);
    if (improves) a = b;
    return improves;
}
// Overwrites `a` with `b` whenever `a >= b` does not hold (i.e. b is larger);
// returns true exactly when `a` was overwritten.
template <typename T> bool chmax(T &a, const T &b) {
    const bool improves = !(a >= b);
    if (improves) a = b;
    return improves;
}

namespace lib {

// Makes every std name usable unqualified inside namespace lib (e.g.
// `vector` instead of `std::vector`). The directive is scoped to lib, but
// using-directives are transitive: a `using namespace lib;` elsewhere would
// also pull in all of std.
using namespace std;

}  // namespace lib

// using namespace lib;
#line 2 "utility/modint.hpp"

#line 4 "utility/modint.hpp"

namespace lib {

template <ll m> struct modint {
    using mint = modint;
    ll a;

    modint(ll x = 0) : a((x % m + m) % m) {}
    static constexpr ll mod() {
        return m;
    }
    ll val() const {
        return a;
    }
    ll& val() {
        return a;
    }
    mint pow(ll n) const {
        mint res = 1;
        mint x = a;
        while (n) {
            if (n & 1) res *= x;
            x *= x;
            n >>= 1;
        }
        return res;
    }
    mint inv() const {
        return pow(m - 2);
    }
    mint& operator+=(const mint rhs) {
        a += rhs.a;
        if (a >= m) a -= m;
        return *this;
    }
    mint& operator-=(const mint rhs) {
        if (a < rhs.a) a += m;
        a -= rhs.a;
        return *this;
    }
    mint& operator*=(const mint rhs) {
        a = a * rhs.a % m;
        return *this;
    }
    mint& operator/=(mint rhs) {
        *this *= rhs.inv();
        return *this;
    }
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const modint &lhs, const modint &rhs) {
        return lhs.a == rhs.a;
    }
    friend bool operator!=(const modint &lhs, const modint &rhs) {
        return !(lhs == rhs);
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
};

using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1'000'000'007>;

}  // namespace lib
#line 5 "convolution/ntt.hpp"

namespace lib {

using mint = modint998244353;

struct ntt_info {
    static constexpr int rank2 = 23;
    const int g = 3;
    std::array<std::array<mint, rank2 + 1>, 2> root;

    ntt_info() {
        root[0][rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
        root[1][rank2] = root[0][rank2].inv();
        rrep(i, 0, rank2) {
            root[0][i] = root[0][i + 1] * root[0][i + 1];
            root[1][i] = root[1][i + 1] * root[1][i + 1];
        }
    }
};

void butterfly(std::vector<mint>& a, bool inverse) {
    static ntt_info info;
    int n = a.size();
    int bit_size = 0;
    while ((1 << bit_size) < n) bit_size++;
    assert(1 << bit_size == n);
    for (int i = 0, j = 1; j < n - 1; j++) {
        for (int k = n >> 1; k > (i ^= k); k >>= 1);
        if (j < i) {
            std::swap(a[i], a[j]);
        }
    }
    rep(bit, 0, bit_size) {
        rep(i, 0, n / (1 << (bit + 1))) {
            mint zeta1 = 1;
            mint zeta2 = info.root[inverse][1];
            mint w = info.root[inverse][bit + 1];
            rep(j, 0, 1 << bit) {
                int idx = i * (1 << (bit + 1)) + j;
                int jdx = idx + (1 << bit);
                mint p1 = a[idx];
                mint p2 = a[jdx];
                a[idx] = p1 + zeta1 * p2;
                a[jdx] = p1 + zeta2 * p2;
                zeta1 *= w;
                zeta2 *= w;
            }
        }
    }
    if (inverse) {
        mint inv_n = mint(n).inv();
        rep(i, 0, n) a[i] *= inv_n;
    }
}

std::vector<mint> convolution(const std::vector<mint>& f,
                              const std::vector<mint>& g) {
    int n = 1;
    while (n < int(f.size() + g.size() - 1)) n <<= 1;
    std::vector<mint> a(n), b(n);
    std::copy(f.begin(), f.end(), a.begin());
    std::copy(g.begin(), g.end(), b.begin());
    butterfly(a, false);
    butterfly(b, false);
    rep(i, 0, n) {
        a[i] *= b[i];
    }
    butterfly(a, true);
    a.resize(f.size() + g.size() - 1);
    return a;
}

}  // namespace lib
#line 6 "test/convolution/Convolution_mod.test.cpp"

using mint = lib::modint998244353;

int main() {
    // Speed up bulk iostream I/O: stop synchronizing with C stdio and stop
    // flushing cout before every read from cin. Safe because this program
    // performs all I/O through iostreams.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Input: lengths n and m, then the coefficients of a and b.
    int n, m;
    std::cin >> n >> m;
    std::vector<mint> a(n), b(m);
    rep(i, 0, n) {
        int x;
        std::cin >> x;
        a[i] = x;
    }
    rep(i, 0, m) {
        int x;
        std::cin >> x;
        b[i] = x;
    }
    auto c = lib::convolution(a, b);
    // Space-separated output with a newline after the last coefficient.
    rep(i, 0, c.size()) {
        std::cout << c[i].val() << " \n"[i == int(c.size()) - 1];
    }
}
Back to top page