This documentation is automatically generated by online-judge-tools/verification-helper
#include "library/math/sum_i^d_r^i.hpp"$\displaystyle \dfrac{1}{(1 - x) ^ {k + 1}} = \sum _ {i = 0} ^ \infty \binom{i + k}{k} x ^ i$ を思い出すと、$\displaystyle \sum _ {i = 0} ^ \infty i ^ d (rx) ^ i = \dfrac{f}{(1 - rx) ^ {d + 1}}$ なる高々 $d$ 次の $f$ が存在することが分かる。いま計算したいのは $[x ^ {n - 1}] \dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}}$ である。
以下のような、定数 $c$ と高々 $d$ 次の $g$ による部分分数分解を考える。
\[\dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}} = \dfrac{c}{1 - x} + \dfrac{g}{(1 - rx) ^ {d + 1}}.\]両辺に $(1 - rx) ^ {d + 1}$ を掛けて次を得る。
\[(1 - rx) ^ {d + 1}\cdot \dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}} = c\cdot \dfrac{(1 - rx) ^ {d + 1}}{1 - x} + g.\]$g$ は高々 $d$ 次であるから、両辺の $d + 1$ 次の項を比較することで $c$ を得ることができる。
左辺の $d + 1$ 次の項の係数の計算に関しては、$\dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}}$ の $d + 1$ 次以下の項は実際に $\displaystyle \sum _ {i = 0} ^ j i ^ d r ^ i$ を計算することで得られ、$(1 - rx) ^ {d + 1}$ に関しては二項定理で展開すればよいので、$O(d)$ 時間。
右辺の $d + 1$ 次の項の係数の計算に関しても、$\dfrac{1}{1 - x}$ を掛けることが累積和を取る操作と対応しているため $(1 - rx) ^ {d + 1}$ を同様に二項定理で展開することで $O(d)$ 時間。
以上で定数 $c$ を $O(d)$ 時間で得ることができた。また、これにより $\dfrac{g}{(1 - rx) ^ {d + 1}} = \dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}} - \dfrac{c}{1 - x}$ の $d + 1$ 次以下の係数を求めることができる。
$g$ の $i$ 次の項の係数を $g _ i$ とおくと、$g$ は高々 $d$ 次であるから、次が成り立つ。
\[\begin{aligned} \dfrac{g}{(1 - rx) ^ {d + 1}} &= \sum _ {i = 0} ^ \infty x ^ i \sum _ {j = 0} ^ {d}g _ j \cdot \binom{i - j + d}{d} \cdot r ^ {i - j} \newline &= \sum _ {i = 0} ^ \infty x ^ i r ^ i\sum _ {j = 0} ^ d \dfrac{g _ j}{r ^ j \cdot d!} \prod _ {k = 1} ^ d (i - j + k). \end{aligned}\]即ち、$\displaystyle h(x) = \sum _ {j = 0} ^ d \dfrac{g _ j}{r ^ j \cdot d!} \prod _ {k = 1} ^ d (x - j + k)$ とおくと、$[x ^ i] \dfrac{g}{(1 - rx) ^ {d + 1}} = r ^ i \cdot h(i)$ が成り立つ。いま $h(0),\ldots,h(d)$ の値は既知であるから、ラグランジュの補間公式により $h(n - 1)$ を $O(d)$ 時間で計算することができる。計算すべき値は $[x ^ {n - 1}] \dfrac{f}{(1 - x)(1 - rx) ^ {d + 1}} = c + r ^ {n - 1} h(n - 1)$ であるから、結局全体 $O(d + \log n)$ 時間で $\displaystyle \sum _ {i = 0} ^ {n - 1} i ^ d r ^ i$ を計算することができた。
($r\neq 1$ の仮定の下で) 上で求めた $c$ が答えになるらしいが理解していない。
#ifndef SUISEN_SUM_I_D_R_I
#define SUISEN_SUM_I_D_R_I
#include "library/sequence/powers.hpp"
#include "library/math/factorial.hpp"
#include "library/math/pow_mods.hpp"
#include "library/polynomial/lagrange_interpolation.hpp"
#include "library/polynomial/shift_of_sampling_points.hpp"
namespace suisen {
template <typename mint>
struct sum_i_i_pow_d_r_pow_i {
sum_i_i_pow_d_r_pow_i(int d, mint r) : d(d), r(r), i_pow_d(powers<mint>(d + 1, d)), r_pow_i(r, d + 1), fac(d), c(calc_c()) {}
mint sum() const {
assert(r != 1);
return c;
}
mint sum(long long n) {
if (r == 0) return n > 0 and d == 0 ? 1 : 0;
prepare();
return lagrange_interpolation_arithmetic_progression<mint>(ys, n) * r.pow(n) + c;
}
std::vector<mint> sum(long long t, int m) {
if (r == 0) {
std::vector<mint> res(m);
for (long long n = t; n < t + m; ++n) res[n - t] = sum(n);
return res;
}
prepare();
auto res = shift_of_sampling_points<mint>(ys, t, m);
mint pr = r.pow(r);
for (auto &e : res) e *= pr, e += c, pr *= r;
return res;
}
private:
int d;
mint r;
std::vector<mint> i_pow_d;
pow_mods<mint> r_pow_i;
factorial<mint> fac;
mint c;
std::vector<mint> ys;
bool prepared = false;
mint calc_c() {
if (r == 1) return 0;
mint num = 0, den = 0, sum = 0;
for (int i = 0; i <= d + 1; ++i) {
sum += i_pow_d[i] * r_pow_i[i];
den += (i & 1 ? -1 : +1) * fac.binom(d + 1, i) * r_pow_i[i];
num += ((d + 1 - i) & 1 ? -1 : +1) * fac.binom(d + 1, d + 1 - i) * r_pow_i[d + 1 - i] * sum;
}
return num / den;
}
void prepare() {
if (prepared) return;
prepared = true;
ys.resize(d + 2);
for (int i = 0; i <= d; ++i) ys[i + 1] = ys[i] + r_pow_i[i] * i_pow_d[i];
if (r == 1) return;
for (auto& e : ys) e -= c;
mint inv_r = r.inv();
mint pow_inv_r = inv_r.pow(d + 1);
for (int i = d + 1; i >= 0; --i) {
ys[i] *= pow_inv_r;
pow_inv_r *= r;
}
}
};
} // namespace suisen
#endif // SUISEN_SUM_I_D_R_I#line 1 "library/math/sum_i^d_r^i.hpp"
#line 1 "library/sequence/powers.hpp"
#include <cstdint>
#line 1 "library/number/linear_sieve.hpp"
#include <cassert>
#include <numeric>
#include <vector>
namespace suisen {
// reference: https://37zigen.com/linear-sieve/
class LinearSieve {
public:
LinearSieve(const int n) : _n(n), min_prime_factor(std::vector<int>(n + 1)) {
std::iota(min_prime_factor.begin(), min_prime_factor.end(), 0);
prime_list.reserve(_n / 20);
for (int d = 2; d <= _n; ++d) {
if (min_prime_factor[d] == d) prime_list.push_back(d);
const int prime_max = std::min(min_prime_factor[d], _n / d);
for (int prime : prime_list) {
if (prime > prime_max) break;
min_prime_factor[prime * d] = prime;
}
}
}
int prime_num() const noexcept { return prime_list.size(); }
/**
* Returns a vector of primes in [0, n].
* It is guaranteed that the returned vector is sorted in ascending order.
*/
const std::vector<int>& get_prime_list() const noexcept {
return prime_list;
}
const std::vector<int>& get_min_prime_factor() const noexcept { return min_prime_factor; }
/**
* Returns a vector of `{ prime, index }`.
* It is guaranteed that the returned vector is sorted in ascending order.
*/
std::vector<std::pair<int, int>> factorize(int n) const noexcept {
assert(0 < n and n <= _n);
std::vector<std::pair<int, int>> prime_powers;
while (n > 1) {
int p = min_prime_factor[n], c = 0;
do { n /= p, ++c; } while (n % p == 0);
prime_powers.emplace_back(p, c);
}
return prime_powers;
}
private:
const int _n;
std::vector<int> min_prime_factor;
std::vector<int> prime_list;
};
} // namespace suisen
#line 6 "library/sequence/powers.hpp"
namespace suisen {
// returns { 0^k, 1^k, ..., n^k }
template <typename mint>
std::vector<mint> powers(uint32_t n, uint64_t k) {
const auto mpf = LinearSieve(n).get_min_prime_factor();
std::vector<mint> res(n + 1);
res[0] = k == 0;
for (uint32_t i = 1; i <= n; ++i) res[i] = i == 1 ? 1 : uint32_t(mpf[i]) == i ? mint(i).pow(k) : res[mpf[i]] * res[i / mpf[i]];
return res;
}
} // namespace suisen
#line 1 "library/math/factorial.hpp"
#line 6 "library/math/factorial.hpp"
namespace suisen {
// 引数として与える値に対して、法が十分大きいことを仮定する
template <typename T, typename U = T>
struct factorial {
factorial() = default;
factorial(int n) { ensure(n); }
static void ensure(const int n) {
int sz = _fac.size();
if (n + 1 <= sz) return;
int new_size = std::max(n + 1, sz * 2);
_fac.resize(new_size), _fac_inv.resize(new_size);
for (int i = sz; i < new_size; ++i) _fac[i] = _fac[i - 1] * i;
_fac_inv[new_size - 1] = U(1) / _fac[new_size - 1];
for (int i = new_size - 1; i > sz; --i) _fac_inv[i - 1] = _fac_inv[i] * i;
}
T fac(const int i) {
ensure(i);
return _fac[i];
}
T operator()(int i) {
return fac(i);
}
U fac_inv(const int i) {
ensure(i);
return _fac_inv[i];
}
// i の逆数
// i = 0 の場合は assert 違反となる
U inv(const int i) {
assert(i > 0);
ensure(i);
return _fac_inv[i] * _fac[i - 1];
}
U binom(const int n, const int r) {
if (n < 0 or r < 0 or n < r) return 0;
ensure(n);
return _fac[n] * _fac_inv[r] * _fac_inv[n - r];
}
// binom(n, r) の逆数
// binom(n, r) = 0 の場合は assert 違反となる
U binom_inv(const int n, const int r) {
assert(r >= 0 and n >= r);
ensure(n);
return _fac_inv[n] * _fac[r] * _fac[n - r];
}
// n 種類から重複を許して r 個選ぶ場合の数
// x_1+x_2+...+x_n=r(x_i は非負整数)となる x の個数でもある
// multichoose(n, r) = binom(n + r - 1, r)
U multichoose(const int n, const int r) {
if (n < 0 or r < 0) return 0;
return r > 0 ? binom(n + r - 1, r) : U(1);
}
// n 種類から重複を許して r 個選ぶ場合の数 multichoose(n, r) の逆数
// x_1+x_2+...+x_n=r(x_i は非負整数)となる x の個数の逆数でもある
// multichoose(n, r) = binom(n + r - 1, r)
// multichoose(n, r) = 0 の場合は assert 違反となる
U multichoose_inv(const int n, const int r) {
assert(n >= 0 and r >= 0);
return r > 0 ? binom_inv(n + r - 1, r) : U(1);
}
template <typename ...Ds, std::enable_if_t<std::conjunction_v<std::is_integral<Ds>...>, std::nullptr_t> = nullptr>
U polynom(const int n, const Ds& ...ds) {
if (n < 0) return 0;
ensure(n);
int sumd = 0;
U res = _fac[n];
for (int d : { ds... }) {
if (d < 0 or d > n) return 0;
sumd += d;
res *= _fac_inv[d];
}
if (sumd > n) return 0;
res *= _fac_inv[n - sumd];
return res;
}
U perm(const int n, const int r) {
if (n < 0 or r < 0 or n < r) return 0;
ensure(n);
return _fac[n] * _fac_inv[n - r];
}
// perm(n, r) の逆数
// perm(n, r) = 0 の場合は assert 違反となる
U perm_inv(const int n, const int r) {
assert(r >= 0 and n >= r);
ensure(n);
return _fac_inv[n] * _fac[n - r];
}
private:
static std::vector<T> _fac;
static std::vector<U> _fac_inv;
};
template <typename T, typename U>
std::vector<T> factorial<T, U>::_fac{ 1 };
template <typename T, typename U>
std::vector<U> factorial<T, U>::_fac_inv{ 1 };
} // namespace suisen
#line 1 "library/math/pow_mods.hpp"
#line 5 "library/math/pow_mods.hpp"
namespace suisen {
template <int base_as_int, typename mint>
struct static_pow_mods {
static_pow_mods() = default;
static_pow_mods(int n) { ensure(n); }
const mint& operator[](int i) const {
ensure(i);
return pows[i];
}
static void ensure(int n) {
int sz = pows.size();
if (sz > n) return;
pows.resize(n + 1);
for (int i = sz; i <= n; ++i) pows[i] = base * pows[i - 1];
}
private:
static inline std::vector<mint> pows { 1 };
static inline mint base = base_as_int;
static constexpr int mod = mint::mod();
};
template <typename mint>
struct pow_mods {
pow_mods() = default;
pow_mods(mint base, int n) : base(base) { ensure(n); }
const mint& operator[](int i) const {
ensure(i);
return pows[i];
}
void ensure(int n) const {
int sz = pows.size();
if (sz > n) return;
pows.resize(n + 1);
for (int i = sz; i <= n; ++i) pows[i] = base * pows[i - 1];
}
private:
mutable std::vector<mint> pows { 1 };
mint base;
static constexpr int mod = mint::mod();
};
}
#line 1 "library/polynomial/lagrange_interpolation.hpp"
#line 1 "library/math/product_of_differences.hpp"
#include <deque>
#line 1 "library/polynomial/multi_point_eval.hpp"
#line 5 "library/polynomial/multi_point_eval.hpp"
namespace suisen {
template <typename FPSType, typename T>
std::vector<typename FPSType::value_type> multi_point_eval(const FPSType& f, const std::vector<T>& xs) {
int n = xs.size();
if (n == 0) return {};
std::vector<FPSType> seg(2 * n);
for (int i = 0; i < n; ++i) seg[n + i] = FPSType{ -xs[i], 1 };
for (int i = n - 1; i > 0; --i) seg[i] = seg[i * 2] * seg[i * 2 + 1];
seg[1] = f % seg[1];
for (int i = 2; i < 2 * n; ++i) seg[i] = seg[i / 2] % seg[i];
std::vector<typename FPSType::value_type> ys(n);
for (int i = 0; i < n; ++i) ys[i] = seg[n + i].size() ? seg[n + i][0] : 0;
return ys;
}
} // namespace suisen
#line 6 "library/math/product_of_differences.hpp"
namespace suisen {
/**
* O(N(logN)^2)
* return the vector p of length xs.size() s.t. p[i]=Π[j!=i](x[i]-x[j])
*/
template <typename FPSType, typename T>
std::vector<typename FPSType::value_type> product_of_differences(const std::vector<T>& xs) {
// f(x):=Π_i(x-x[i])
// => f'(x)=Σ_i Π[j!=i](x-x[j])
// => f'(x[i])=Π[j!=i](x[i]-x[j])
const int n = xs.size();
std::deque<FPSType> dq;
for (int i = 0; i < n; ++i) dq.push_back(FPSType{ -xs[i], 1 });
while (dq.size() >= 2) {
auto f = std::move(dq.front());
dq.pop_front();
auto g = std::move(dq.front());
dq.pop_front();
dq.push_back(f * g);
}
auto f = std::move(dq.front());
f.diff_inplace();
return multi_point_eval<FPSType, T>(f, xs);
}
} // namespace suisen
#line 5 "library/polynomial/lagrange_interpolation.hpp"
namespace suisen {
// O(N^2+NlogP)
template <typename T>
T lagrange_interpolation_naive(const std::vector<T>& xs, const std::vector<T>& ys, const T t) {
const int n = xs.size();
assert(int(ys.size()) == n);
T p{ 1 };
for (int i = 0; i < n; ++i) p *= t - xs[i];
T res{ 0 };
for (int i = 0; i < n; ++i) {
T w = 1;
for (int j = 0; j < n; ++j) if (j != i) w *= xs[i] - xs[j];
res += ys[i] * (t == xs[i] ? 1 : p / (w * (t - xs[i])));
}
return res;
}
// O(N(logN)^2+NlogP)
template <typename FPSType, typename T>
typename FPSType::value_type lagrange_interpolation(const std::vector<T>& xs, const std::vector<T>& ys, const T t) {
const int n = xs.size();
assert(int(ys.size()) == n);
std::vector<FPSType> seg(2 * n);
for (int i = 0; i < n; ++i) seg[n + i] = FPSType {-xs[i], 1};
for (int i = n - 1; i > 0; --i) seg[i] = seg[i * 2] * seg[i * 2 + 1];
seg[1] = seg[1].diff() % seg[1];
for (int i = 2; i < 2 * n; ++i) seg[i] = seg[i / 2] % seg[i];
using mint = typename FPSType::value_type;
mint p{ 1 };
for (int i = 0; i < n; ++i) p *= t - xs[i];
mint res{ 0 };
for (int i = 0; i < n; ++i) {
mint w = seg[n + i][0];
res += ys[i] * (t == xs[i] ? 1 : p / (w * (t - xs[i])));
}
return res;
}
// xs[i] = ai + b
// requirement: for all 0≤i<j<n, ai+b ≢ aj+b mod p
template <typename T>
T lagrange_interpolation_arithmetic_progression(T a, T b, const std::vector<T>& ys, const T t) {
const int n = ys.size();
T fac = 1;
for (int i = 1; i < n; ++i) fac *= i;
std::vector<T> fac_inv(n), suf(n);
fac_inv[n - 1] = T(1) / fac;
suf[n - 1] = 1;
for (int i = n - 1; i > 0; --i) {
fac_inv[i - 1] = fac_inv[i] * i;
suf[i - 1] = suf[i] * (t - (a * i + b));
}
T pre = 1, res = 0;
for (int i = 0; i < n; ++i) {
T val = ys[i] * pre * suf[i] * fac_inv[i] * fac_inv[n - i - 1];
if ((n - 1 - i) & 1) res -= val;
else res += val;
pre *= t - (a * i + b);
}
return res / a.pow(n - 1);
}
// x = 0, 1, ...
template <typename T>
T lagrange_interpolation_arithmetic_progression(const std::vector<T>& ys, const T t) {
return lagrange_interpolation_arithmetic_progression(T{1}, T{0}, ys, t);
}
} // namespace suisen
#line 1 "library/polynomial/shift_of_sampling_points.hpp"
#line 5 "library/polynomial/shift_of_sampling_points.hpp"
#include <atcoder/convolution>
#line 8 "library/polynomial/shift_of_sampling_points.hpp"
namespace suisen {
template <typename mint, typename Convolve,
std::enable_if_t<std::is_invocable_r_v<std::vector<mint>, Convolve, std::vector<mint>, std::vector<mint>>, std::nullptr_t> = nullptr>
std::vector<mint> shift_of_sampling_points(const std::vector<mint>& ys, mint t, int m, const Convolve &convolve) {
const int n = ys.size();
factorial<mint> fac(std::max(n, m));
std::vector<mint> b = [&] {
std::vector<mint> f(n), g(n);
for (int i = 0; i < n; ++i) {
f[i] = ys[i] * fac.fac_inv(i);
g[i] = (i & 1 ? -1 : 1) * fac.fac_inv(i);
}
std::vector<mint> b = convolve(f, g);
b.resize(n);
return b;
}();
std::vector<mint> e = [&] {
std::vector<mint> c(n);
mint prd = 1;
std::reverse(b.begin(), b.end());
for (int i = 0; i < n; ++i) {
b[i] *= fac.fac(n - i - 1);
c[i] = prd * fac.fac_inv(i);
prd *= t - i;
}
std::vector<mint> e = convolve(b, c);
e.resize(n);
return e;
}();
std::reverse(e.begin(), e.end());
for (int i = 0; i < n; ++i) {
e[i] *= fac.fac_inv(i);
}
std::vector<mint> f(m);
for (int i = 0; i < m; ++i) f[i] = fac.fac_inv(i);
std::vector<mint> res = convolve(e, f);
res.resize(m);
for (int i = 0; i < m; ++i) res[i] *= fac.fac(i);
return res;
}
template <typename mint>
std::vector<mint> shift_of_sampling_points(const std::vector<mint>& ys, mint t, int m) {
auto convolve = [&](const std::vector<mint> &f, const std::vector<mint> &g) { return atcoder::convolution(f, g); };
return shift_of_sampling_points(ys, t, m, convolve);
}
} // namespace suisen
#line 9 "library/math/sum_i^d_r^i.hpp"
namespace suisen {
template <typename mint>
struct sum_i_i_pow_d_r_pow_i {
sum_i_i_pow_d_r_pow_i(int d, mint r) : d(d), r(r), i_pow_d(powers<mint>(d + 1, d)), r_pow_i(r, d + 1), fac(d), c(calc_c()) {}
mint sum() const {
assert(r != 1);
return c;
}
mint sum(long long n) {
if (r == 0) return n > 0 and d == 0 ? 1 : 0;
prepare();
return lagrange_interpolation_arithmetic_progression<mint>(ys, n) * r.pow(n) + c;
}
std::vector<mint> sum(long long t, int m) {
if (r == 0) {
std::vector<mint> res(m);
for (long long n = t; n < t + m; ++n) res[n - t] = sum(n);
return res;
}
prepare();
auto res = shift_of_sampling_points<mint>(ys, t, m);
mint pr = r.pow(r);
for (auto &e : res) e *= pr, e += c, pr *= r;
return res;
}
private:
int d;
mint r;
std::vector<mint> i_pow_d;
pow_mods<mint> r_pow_i;
factorial<mint> fac;
mint c;
std::vector<mint> ys;
bool prepared = false;
mint calc_c() {
if (r == 1) return 0;
mint num = 0, den = 0, sum = 0;
for (int i = 0; i <= d + 1; ++i) {
sum += i_pow_d[i] * r_pow_i[i];
den += (i & 1 ? -1 : +1) * fac.binom(d + 1, i) * r_pow_i[i];
num += ((d + 1 - i) & 1 ? -1 : +1) * fac.binom(d + 1, d + 1 - i) * r_pow_i[d + 1 - i] * sum;
}
return num / den;
}
void prepare() {
if (prepared) return;
prepared = true;
ys.resize(d + 2);
for (int i = 0; i <= d; ++i) ys[i + 1] = ys[i] + r_pow_i[i] * i_pow_d[i];
if (r == 1) return;
for (auto& e : ys) e -= c;
mint inv_r = r.inv();
mint pow_inv_r = inv_r.pow(d + 1);
for (int i = d + 1; i >= 0; --i) {
ys[i] *= pow_inv_r;
pow_inv_r *= r;
}
}
};
} // namespace suisen