This documentation is automatically generated by online-judge-tools/verification-helper
#include "math/mod_sqrt.hpp"
素数 $p$ に対して、 $x^2 \equiv y \pmod p$ となるような $x$ を返す。そのような $x$ が存在しない場合は `std::nullopt` を返す。 $O(\log p)$
#pragma once
#include <cstdint>
#include <optional>
#include "../modint/dynamic_modint.hpp"
namespace ebi {
std::optional<std::int64_t> mod_sqrt(const std::int64_t &a,
const std::int64_t &p) {
if (a == 0 || a == 1) return a;
using mint = dynamic_modint<100>;
mint::set_mod(p);
if (mint(a).pow((p - 1) >> 1) != 1) return std::nullopt;
mint b = 1;
while (b.pow((p - 1) >> 1) == 1) b += 1;
std::int64_t m = p - 1, e = 0;
while (m % 2 == 0) m >>= 1, e++;
mint x = mint(a).pow((m - 1) >> 1);
mint y = mint(a) * x * x;
x *= a;
mint z = b.pow(m);
while (y != 1) {
std::int64_t j = 0;
mint t = y;
while (t != 1) {
j++;
t *= t;
}
z = z.pow(1ll << (e - j - 1));
x *= z;
z *= z;
y *= z;
e = j;
}
return x.val();
}
} // namespace ebi
#line 2 "math/mod_sqrt.hpp"
#include <cstdint>
#include <optional>
#line 2 "modint/dynamic_modint.hpp"
#include <cassert>
#line 2 "modint/base.hpp"
#include <concepts>
#include <iostream>
#include <utility>
namespace ebi {

// A type models Modint when it offers the four arithmetic operators,
// inv(), val(), pow(long long) on instances and a static mod().
template <typename T>
concept Modint = requires(T lhs, T rhs) {
    lhs + rhs;
    lhs - rhs;
    lhs * rhs;
    lhs / rhs;
    lhs.inv();
    lhs.val();
    lhs.pow(std::declval<long long>());
    T::mod();
};

// Reads one integer from the stream and assigns it to `a`.
template <Modint mint>
std::istream &operator>>(std::istream &is, mint &a) {
    long long parsed;
    is >> parsed;
    a = parsed;
    return is;
}

// Writes the value reported by `a.val()` to the stream.
template <Modint mint>
std::ostream &operator<<(std::ostream &os, const mint &a) {
    const auto value = a.val();
    return os << value;
}

}  // namespace ebi
#line 6 "modint/dynamic_modint.hpp"
namespace ebi {

/// Integer modulo a modulus chosen at run time.
///
/// Every distinct `id` has its own static modulus, initially 998244353.
/// The modulus is stored as an `int`; the value is kept in [0, mod).
/// Calling `set_mod` does not re-reduce objects that already exist.
template <int id> struct dynamic_modint {
  private:
    using modint = dynamic_modint;

  public:
    /// Sets the modulus shared by all `dynamic_modint<id>` objects.
    /// Asserts `1 <= p`.
    static void set_mod(int p) {
        assert(1 <= p);
        m = p;
    }

    /// Returns the current modulus.
    static int mod() {
        return m;
    }

    /// Builds a value from `v` without reducing it; the caller must pass
    /// a value already in [0, mod). Static so that it can be called as
    /// `dynamic_modint<id>::raw(v)` without an existing object.
    static modint raw(int v) {
        modint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}

    /// Builds the residue of `v`; negative inputs are mapped into [0, mod).
    dynamic_modint(long long v) {
        v %= (long long)umod();
        if (v < 0) v += (long long)umod();
        _v = (unsigned int)v;
    }

    /// Returns the stored representative.
    unsigned int val() const {
        return _v;
    }

    /// Alias of val().
    unsigned int value() const {
        return val();
    }

    modint &operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }

    modint &operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }

    modint &operator+=(const modint &rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }

    modint &operator-=(const modint &rhs) {
        // On underflow the unsigned subtraction wraps to a value >= mod,
        // which the check below detects and corrects.
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }

    modint &operator*=(const modint &rhs) {
        // Multiply in 64 bits so the product cannot overflow.
        unsigned long long x = _v;
        x *= rhs._v;
        _v = (unsigned int)(x % (unsigned long long)umod());
        return *this;
    }

    modint &operator/=(const modint &rhs) {
        return *this = *this * rhs.inv();
    }

    modint operator+() const {
        return *this;
    }

    modint operator-() const {
        return modint() - *this;
    }

    /// Returns this value raised to the power `n` by binary exponentiation.
    /// Asserts `0 <= n`.
    modint pow(long long n) const {
        assert(0 <= n);
        modint x = *this, res = 1;
        while (n) {
            if (n & 1) res *= x;
            x *= x;
            n >>= 1;
        }
        return res;
    }

    /// Returns the multiplicative inverse, computed as this^(mod - 2).
    /// That formula (Fermat's little theorem) is only an inverse when the
    /// modulus is prime. Asserts the value is non-zero.
    modint inv() const {
        assert(_v);
        return pow(umod() - 2);
    }

    friend modint operator+(const modint &lhs, const modint &rhs) {
        return modint(lhs) += rhs;
    }
    friend modint operator-(const modint &lhs, const modint &rhs) {
        return modint(lhs) -= rhs;
    }
    friend modint operator*(const modint &lhs, const modint &rhs) {
        return modint(lhs) *= rhs;
    }
    friend modint operator/(const modint &lhs, const modint &rhs) {
        return modint(lhs) /= rhs;
    }
    friend bool operator==(const modint &lhs, const modint &rhs) {
        return lhs.val() == rhs.val();
    }
    friend bool operator!=(const modint &lhs, const modint &rhs) {
        return !(lhs == rhs);
    }

  private:
    unsigned int _v = 0;  // representative in [0, mod)
    static int m;         // modulus shared by all objects with this id

    // The modulus as an unsigned value, for arithmetic against _v.
    static unsigned int umod() {
        return m;
    }
};

// Default modulus until set_mod() is called.
template <int id> int dynamic_modint<id>::m = 998244353;

}  // namespace ebi
#line 7 "math/mod_sqrt.hpp"
namespace ebi {
std::optional<std::int64_t> mod_sqrt(const std::int64_t &a,
const std::int64_t &p) {
if (a == 0 || a == 1) return a;
using mint = dynamic_modint<100>;
mint::set_mod(p);
if (mint(a).pow((p - 1) >> 1) != 1) return std::nullopt;
mint b = 1;
while (b.pow((p - 1) >> 1) == 1) b += 1;
std::int64_t m = p - 1, e = 0;
while (m % 2 == 0) m >>= 1, e++;
mint x = mint(a).pow((m - 1) >> 1);
mint y = mint(a) * x * x;
x *= a;
mint z = b.pow(m);
while (y != 1) {
std::int64_t j = 0;
mint t = y;
while (t != 1) {
j++;
t *= t;
}
z = z.pow(1ll << (e - j - 1));
x *= z;
z *= z;
y *= z;
e = j;
}
return x.val();
}
} // namespace ebi