Montgomery/Barret reduction and NTT [Performance optimization]
Разница между en1 и en2, 2 символ(ов) изменены
$\DeclareMathOperator{\ord}{ord}$↵
$\DeclareMathOperator{\mod}{mod}$↵

Montgomery/Barret reduction and NTT↵
==================↵
  ↵
**In this article** will be shown how to apply Montgomery/Barret reduction of modular arithmetic to Number Theoretic Transform and combinatorial tasks. We will implement NTT with minimum of modular operations that beats classic FFT both in precision and speed. We will compare performance of various optimizations of NTT and FFT in complex field. Also will be shown how to add support of negative coefficients to NTT and replace FFT.  ↵
  ↵

1. Why use NTT↵
------------------↵
DFT (Discrete Fourier Transform) is widely used for polynomial multiplication. DFT can be calculated in every field such that there is nth root (element with order n). For instance, we can use complex numbers or numbers modulo some $p$ (finite field). Widely known realization &mdash; Fast Fourier Transform &mdash; calculates DFT in $O(n\log{n})$ both in complex numbers and integers (only arithmacy varies). FFT in $\mathbb{Z}/p\mathbb{Z}$ called NTT (number theoretic transform). The basic algorithm is reviewed in many articles, e.g. <a href="https://cp-algorithms.com/algebra/fft.html">CP-algorithms</a>. NTT is barely used because standard implementation with big modules is very slow without hardware acceleration. There will be shown how to implement fast NTT with implicit modular operations.  ↵
  ↵
Of course, we have Fast Fourier transform on complex doubles that widely used in polynomial multiplication (and not only that), including solving tasks on Codeforces. But FFT has very bad precision and quite slow (even with non-recursive implementation). Therefore, if we are given polynomials with non-negative integer coefficients to multiply, we are able to use NTT: calculate coefficients some modulo $p$ and recover actual answer up to $10^{18}$ (either Chinese Remainder theorem can be used for recovering answer by two modules or single module $~10^{18}$ with 128-bit cast).  ↵
  ↵
**The main advantage** over FFT is precision &mdash; we can get actual product with coefficients in $[0; 10^{18}]$ where with FFT we will get error 10-100 with doubles and 1-10 with long doubles (for instance, [submission:259759436] gets WA43 in polynomial multiplication task with FFT with doubles, but [submission:259759880] with long doubles gets OK).  ↵
  ↵
NTT is not faster out of the box. Majority implementations of NTT uses a lot of modular arithmetic so at many hardware (especiallly without SIMD) it runs a slower that FFT. But if we will be able to get rid of modular operations then NTT will run much faster than FFT with doubles. And there we can use Montgomery or Barret reduction to multiply integers in special form without % operation.  ↵
  ↵

2. Montgomery reduction↵
------------------↵
Montgomery reduction is approach to operate with unsigned integers on some module without % division. The main idea that we can reduce multiplication modulo $p$ to multiplication modulo $2^k$ that can be implemented using bit manipulations. Recommended to see <a href="https://en.algorithmica.org/hpc/number-theory/montgomery/">the guide</a>. Montgomery reduction requires numbers to operate be in special form (Montgomery space). Conversion into this form and vice versa requires using actual modular division but NTT has special case: all base numbers that we will multiply modulo are source coefficients, root and some inverses (n, root). We can do it in $O(n)$ before using NTT and after it (to convert into normal form). Hence, we got rid of all modular division inside NTT. Single nuance there is that if we calculating NTT over big modulo ($~10^{18}$) we have to use __int128. Therefore, this algorithm has big potential to be used with AVX. There's even a <a href="https://networkbuilders.intel.com/docs/networkbuilders/intel-avx-512-fast-modular-multiplication-technique-technology-guide-1710916893.pdf">Intel guide</a> for optimization for Montgomery multiplication via AVX.  ↵
  ↵
P.S. But if we want to reduce 128 bit integers without loss of size of coefficients, we can calculate NTT over two modules using only 64-bit integers (can be calculated in single call in parallel) and then solve system of congruences over two prime modules via Chinese Remainder Theorem. Then __int128 will be used only in CRT.  ↵
  ↵

3. Barret reduction↵
------------------↵
Another way to reduce modular arithmetic is Barret reduction. Barrett is based on approximating the real reciprocal with bounded precision. Certainly, let's transform expression for remainder:  ↵
$$ a \mod p = a - \lfloor {a \over p} \rfloor \cdot p $$  ↵
Choose $m$, such that:  ↵
$$ {1 \over p} = {m \over 2^k} \longleftrightarrow m = {2^k \over 
np} $$↵
If we take big enough $2^k$ (like $p^2$) answer for numbers in our range (<$10^{18}$) will be correct. See full proof <a href="https://www.nayuki.io/page/barrett-reduction-algorithm">here</a>. Barret reduction gives us the same functionality as Montgomery but doesn't require converting to special form. Main nuance that we have to use bit shift to more than $p^2$ so if we use modulo $~10^{18}$ 256-bit integers is required so operations is more exprensive.  ↵
  ↵

4. Good modules↵
------------------↵
Firstly, we have to choose module $p$ such that $n | (p-1): n=2^k$ in order for existence nth root (order of element divides order of cyclic group from Lagrange theorem). Since we have $g$ &mdash; primitive root, consider $g^{p-1 \over n}$ as nth root. For concrete module $p$ such roots can be easily found via binary exponentiation and primitive root searching algorithm in $O(Ans \cdot \log^2(p))$, $Ans$ is always small by theory of roots in finite field. See <a href="https://cp-algorithms.com/algebra/primitive-root.html">proof of correctness</a>. For example, for work with coefficients up to $2 \cdot 10^{18}$ and size of polynomial up to $2^{24}$ we can use module $p=2524775926340780033$. Primitive root is $3$.  ↵

<spoiler summary="Primitive root searching algorithm">↵
```c++↵
// Primitive root modulo n↵
// (generator of cyclic group with n-1 elements)↵
int generator(int n) {↵
    vector<int> fact;↵
    int phi = euler_totient(n); // for prime equals n-1↵
    int m = phi;↵
    for (int d = 2; d*d <= m; ++d)↵
        if (m%d == 0) {↵
            fact.push_back(d);↵
            while (m%d == 0)↵
                m /= d;↵
        }↵
    if (m > 1)↵
        fact.push_back(m);↵
    for (int root = 2; root <= n; ++root) {↵
        bool found = true;↵
        for (auto d : fact)↵
            if (bin_pow(root, phi / d, n) == 1) {↵
                found = false;↵
                break;↵
            }↵
        if (found)↵
            return root;↵
    }↵
    return n == 1 ? 1 : -1;↵
}↵

```↵
</spoiler>↵

  ↵

5. Implementations↵
------------------↵
We will use non-recursive realization of FFT (NTT) with external structure for Montgomery/Barret arithmetic. Convenient to use unsigned integers while dealing with reductions because of trick with overflow: it doesn't lead to undefined behavior but returns answer modulo $2^{64}$.  ↵

<spoiler summary="NTT with Montgomery multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /**    n*m = 1 (mod p)  =>  m = n**(p-2) (mod p)    **/↵
    uint64_t res = 1;↵
    while (p) {↵
        if (p & 1)↵
            res = ((__uint128_t) res * n) % mod;↵
        n = ((__uint128_t) n * n) % mod;↵
        p >>= 1;↵
    }↵
    return res;↵
}↵


struct montgomery {↵
    uint64_t n, nr;↵

    constexpr montgomery(uint64_t n) : n(n), nr(1) {↵
        // log(2^64) = 6↵
        for (int i = 0; i < 6; i++)↵
            nr *= 2 - n * nr;↵
    }↵

    [[nodiscard]]↵
    uint64_t reduce(__uint128_t x) const {↵
        uint64_t q = __uint128_t(x) * nr;↵
        uint64_t m = ((__uint128_t) q * n) >> 64;↵
        uint64_t res = (x >> 64) + n - m;↵
        if (res >= n)↵
            res -= n;↵
        return res;↵
    }↵

    [[nodiscard]]↵
    uint64_t multiply(uint64_t x, uint64_t y) const {↵
        return reduce((__uint128_t) x * y);↵
    }↵

    [[nodiscard]]↵
    uint64_t transform(uint64_t x) const {↵
        return (__uint128_t(x) << 64) % n;↵
    }↵
};↵


vector<int> bit_sort(int n) {↵
    int h = -1;↵
    vector<int> rev(n, 0);↵
    int skip = __lg(n) - 1;↵
    for (int i = 1; i < n; ++i) {↵
        if (!(i & (i - 1)))↵
            ++h;↵
        rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
    }↵
    return rev;↵
}↵


const uint64_t mod = 2524775926340780033, gen = 3;↵
//const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, montgomery& red, ↵
                  uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
    int n = (int)a.size();↵

    for (int i = 0; i < n; ++i)↵
        if (i < rev[i])↵
            swap(a[i], a[rev[i]]);↵

    uint64_t w = invert ? inv_root : root;↵
    vector<uint64_t> W(n >> 1, red.transform(1));↵
    for (int i = 1; i < (n >> 1); ++i)↵
        W[i] = red.multiply(W[i-1], w);↵

    int lim = __lg(n);↵
    for (int i = 0; i < lim; ++i)↵
        for (int j = 0; j < n; ++j)↵
            if (!(j & (1 << i))) {↵
                uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
                a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
                a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
            }↵

    if (invert)↵
        for (int i = 0; i < n; i++)↵
            a[i] = red.multiply(a[i], inv_n);↵
}↵


void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
    montgomery red(mod);↵
    for (auto& x : a)↵
        x = red.transform(x);↵
    for (auto& x : b)↵
        x = red.transform(x);↵
    ↵
    int n = 1;↵
    while (n < a.size() || n < b.size())↵
        n <<= 1;↵
    n <<= 1;↵
    a.resize(n);↵
    b.resize(n);↵
    ↵
    uint64_t inv_n = red.transform(bin_pow(n, mod-2, mod));↵
    uint64_t root = red.transform(bin_pow(gen, (mod-1)/n, mod));↵
    uint64_t inv_root = red.transform(bin_pow(red.reduce(root), mod-2, mod));↵
    auto rev = bit_sort(n);↵
    ↵
    ntt(a, rev, red, inv_n, root, inv_root, false);↵
    ntt(b, rev, red, inv_n, root, inv_root, false);↵
    ↵
    for (int i = 0; i < n; i++)↵
        a[i] = red.multiply(a[i], b[i]);↵
    ntt(a, rev, red, inv_n, root, inv_root, true);↵

    for (auto& x : a)↵
        x = red.reduce(x);↵
}↵

```↵
</spoiler>↵

  ↵
Implementation with Barret reduction can only be used with modules $<2^{32}$ without 256-bit integers. Because Barret reduction a priori has more expensive operations, there isn't implementation with uint256. There is ability to use uint32 instead of uint64 in code but for equality of testing conditions there is same formats both for Montgomery and Barret.  ↵

<spoiler summary="NTT with Barret multiplication">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /**    n*m = 1 (mod p)  =>  m = n**(p-2) (mod p)    **/↵
    uint64_t res = 1;↵
    while (p) {↵
        if (p & 1)↵
            res = ((__uint128_t) res * n) % mod;↵
        n = ((__uint128_t) n * n) % mod;↵
        p >>= 1;↵
    }↵
    return res;↵
}↵


struct barret {↵
    uint64_t n, s;↵
    __uint128_t f;↵
 ↵
    constexpr barret(uint64_t _n) {↵
        n = _n;↵
        s = 64;↵
        f = (__uint128_t(1) << s) / n;↵
    }↵
 ↵
    [[nodiscard]]↵
    uint64_t reduce(__uint128_t x) const {↵
        auto t = (uint64_t)(x - ((x * f) >> s) * n);↵
        if (t < n)↵
            return t;↵
        return t - n;↵
    }↵
 ↵
    [[nodiscard]]↵
    uint64_t multiply(uint64_t x, uint64_t y) const {↵
        return reduce((__uint128_t) x * y);↵
    }↵
};↵


vector<int> bit_sort(int n) {↵
    int h = -1;↵
    vector<int> rev(n, 0);↵
    int skip = __lg(n) - 1;↵
    for (int i = 1; i < n; ++i) {↵
        if (!(i & (i - 1)))↵
            ++h;↵
        rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
    }↵
    return rev;↵
}↵


//const uint64_t mod = 2524775926340780033, gen = 3;↵
const uint64_t mod = 998244353, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev, barret& red, ↵
                  uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
    int n = (int)a.size();↵

    for (int i = 0; i < n; ++i)↵
        if (i < rev[i])↵
            swap(a[i], a[rev[i]]);↵

    uint64_t w = invert ? inv_root : root;↵
    vector<uint64_t> W(n >> 1, 1);↵
    for (int i = 1; i < (n >> 1); ++i)↵
        W[i] = red.multiply(W[i-1], w);↵

    int lim = __lg(n);↵
    for (int i = 0; i < lim; ++i)↵
        for (int j = 0; j < n; ++j)↵
            if (!(j & (1 << i))) {↵
                uint64_t t = red.multiply(a[j ^ (1 << i)], W[(j & ((1 << i) - 1)) * (n >> (i + 1))]);↵
                a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
                a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
            }↵

    if (invert)↵
        for (int i = 0; i < n; i++)↵
            a[i] = red.multiply(a[i], inv_n);↵
}↵


void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
    int n = 1;↵
    while (n < a.size() || n < b.size())↵
        n <<= 1;↵
    n <<= 1;↵
    a.resize(n);↵
    b.resize(n);↵

    barret red(mod);↵
    uint64_t inv_n = bin_pow(n, mod-2, mod);↵
    uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
    uint64_t inv_root = bin_pow(root, mod-2, mod);↵
    auto rev = bit_sort(n);↵

    ntt(a, rev, red, inv_n, root, inv_root, false);↵
    ntt(b, rev, red, inv_n, root, inv_root, false);↵

    for (int i = 0; i < n; i++)↵
        a[i] = red.multiply(a[i], b[i]);↵
    ntt(a, rev, red, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵

  ↵
  ↵
Also there is provided FFT implementation with custom struct for complex numbers. It can be used with doubles / long doubles.  ↵

<spoiler summary="FFT with custom cmpls">↵
```c++↵
struct _cmpl {↵
    double a, b;↵
    _cmpl(double a = 0, double b = 0) : a(a), b(b) {}↵

    const _cmpl operator + (const _cmpl &c) const↵
    { return _cmpl(a + c.a, b + c.b); }↵

    const _cmpl operator - (const _cmpl &c) const↵
    { return _cmpl(a - c.a, b - c.b); }↵

    const _cmpl operator * (const _cmpl &c) const↵
    { return _cmpl(a * c.a - b * c.b, a * c.b + b * c.a); }↵
};↵


vector<int> bit_sort(int n) {↵
    int h = -1;↵
    vector<int> rev(n, 0);↵
    int skip = __lg(n) - 1;↵
    for (int i = 1; i < n; ++i) {↵
        if (!(i & (i - 1)))↵
            ++h;↵
        rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
    }↵
    return rev;↵
}↵


void fft(vector<_cmpl>& a, vector<int>& rev, bool invert) {↵
    int n = a.size(), h = -1;↵
    for (int i = 0; i < n; ++i)↵
        if (i < rev[i])↵
            swap(a[i], a[rev[i]]);↵

    double alpha = 2 * atan2(0.00, -1.00) / n * (invert ? -1 : 1);↵
    _cmpl w1(cos(alpha), sin(alpha));↵
    vector<_cmpl> W(n >> 1, 1);↵
    for (int i = 1; i < (n >> 1); ++i)↵
        W[i] = W[i - 1] * w1;↵

    int lim = __lg(n);↵
    for (int i = 0; i < lim; ++i)↵
        for (int j = 0; j < n; ++j)↵
            if (!(j & (1 << i))) {↵
                _cmpl t = a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))];↵
                a[j ^ (1 << i)] = a[j] - t;↵
                a[j] = a[j] + t;↵
            }↵

    if (invert)↵
        for (int i = 0; i < n; i++)↵
            a[i] = _cmpl(a[i].a / n, a[i].b / n);↵
}↵


void mul(vector<_cmpl>& a, vector<_cmpl>& b, vector<uint64_t>& res) {↵
    int n = 1;↵
    while (n < a.size() || n < b.size())↵
        n <<= 1;↵
    n <<= 1;↵
    a.resize(n);↵
    b.resize(n);↵
    auto rev = bit_sort(n);↵
    fft(a, rev, false);↵
    fft(b, rev, false);↵
    for (int i = 0; i < n; i++)↵
        a[i] = a[i] * b[i];↵
    fft(a, rev, true);↵
    res.resize(n);↵
    for (int i = 0; i < n; i++)↵
        res[i] = (uint64_t)(a[i].a + 0.1);↵
}↵
```↵
</spoiler>↵

  ↵

NTT without optimizations:  ↵

<spoiler summary="NTT without opt">↵
```c++↵
uint64_t bin_pow(uint64_t n, uint64_t p, uint64_t mod) { /**    n*m = 1 (mod p)  =>  m = n**(p-2) (mod p)    **/↵
    uint64_t res = 1;↵
    while (p) {↵
        if (p & 1)↵
            res = ((__uint128_t) res * n) % mod;↵
        n = ((__uint128_t) n * n) % mod;↵
        p >>= 1;↵
    }↵
    return res;↵
}↵

vector<int> bit_sort(int n) {↵
    int h = -1;↵
    vector<int> rev(n, 0);↵
    int skip = __lg(n) - 1;↵
    for (int i = 1; i < n; ++i) {↵
        if (!(i & (i - 1)))↵
            ++h;↵
        rev[i] = rev[i ^ (1 << h)] | (1 << (skip - h));↵
    }↵
    return rev;↵
}↵

//const uint64_t mod = 998244353, gen = 3;↵
const uint64_t mod = 2524775926340780033, gen = 3;↵
void ntt(vector<uint64_t>& a, vector<int>& rev,↵
             uint64_t inv_n, uint64_t root, uint64_t inv_root, bool invert) {↵
    int n = (int)a.size();↵

    for (int i = 0; i < n; ++i)↵
        if (i < rev[i])↵
            swap(a[i], a[rev[i]]);↵

    uint64_t w = invert ? inv_root : root;↵
    vector<uint64_t> W(n >> 1, 1);↵
    for (int i = 1; i < (n >> 1); ++i)↵
        W[i] = (__uint128_t) W[i-1] * w % mod;↵

    int lim = __lg(n);↵
    for (int i = 0; i < lim; ++i)↵
        for (int j = 0; j < n; ++j)↵
            if (!(j & (1 << i))) {↵
                uint64_t t = (__uint128_t) a[j ^ (1 << i)] * W[(j & ((1 << i) - 1)) * (n >> (i + 1))] % mod;↵
                a[j ^ (1 << i)] = a[j] >= t ? a[j] - t : a[j] + mod - t;↵
                a[j] = a[j] + t < mod ? a[j] + t : a[j] + t - mod;↵
            }↵

    if (invert)↵
        for (int i = 0; i < n; i++)↵
            a[i] = (__uint128_t) a[i] * inv_n % mod;↵
}↵


void mul(vector<uint64_t>& a, vector<uint64_t>& b) {↵
    int n = 1;↵
    while (n < a.size() || n < b.size())↵
        n <<= 1;↵
    n <<= 1;↵
    a.resize(n);↵
    b.resize(n);↵

    uint64_t inv_n = bin_pow(n, mod-2, mod);↵
    uint64_t root = bin_pow(gen, (mod-1)/n, mod);↵
    uint64_t inv_root = bin_pow(root, mod-2, mod);↵
    auto rev = bit_sort(n);↵

    ntt(a, rev, inv_n, root, inv_root, false);↵
    ntt(b, rev, inv_n, root, inv_root, false);↵

    for (int i = 0; i < n; i++)↵
        a[i] = (__uint128_t) a[i] * b[i] % mod;↵
    ntt(a, rev, inv_n, root, inv_root, true);↵
}↵
```↵
</spoiler>↵

  ↵

6. Performance↵
------------------↵
We will test 5 algorithms: FFT with doubles and long doubles, NTT with Montgomery and Barret reductions and NTT without optimizations.  ↵
  ↵
Firstly, let's test it on [problem:993E]. It is important to notice that NTT with Barret reduction and FFT with doubles got WA43 (but passed another tests with big input) because FFT with doubles has very bad precision as it was noticed before. NTT with Barret has not passed because module $998244353$ is quite small to store answer in this problem (NTT with Barret and without uint256 can be used only for small coefficients).  ↵

<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>[submission:261258380], 311ms, BAD PRECISION</th>↵
<th>[submission:261256931], 702ms, OK</th>↵
<th>[submission:261256487], 281ms, OK</th>↵
<th>[submission:261257532], 312ms, BAD LIMITS</th>↵
<th>[submission:261257532], 859ms, OK</th>↵
</tr>↵
</table>↵
  ↵
Now let's see results on multiplication random polynomials with $10^6$ size (tested on Codeforces custom test):  ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>1407ms</th>↵
<th>2434ms</th>↵
<th>1374ms</th>↵
<th>1437ms</th>↵
<th>4122ms</th>↵
</tr>↵
<tr>↵
<th>1286ms</th>↵
<th>2426ms</th>↵
<th>1256ms</th>↵
<th>1675ms</th>↵
<th>3683ms</th>↵
</tr>↵
<tr>↵
<th>1052ms</th>↵
<th>2286ms</th>↵
<th>1153s</th>↵
<th>1320ms</th>↵
<th>3518ms</th>↵
</tr>↵
<tr>↵
<th>1079ms</th>↵
<th>2359ms</th>↵
<th>1309ms</th>↵
<th>1380ms</th>↵
<th>3961ms</th>↵
</tr>↵
<tr>↵
<th>1217ms</th>↵
<th>2791ms</th>↵
<th>1476ms</th>↵
<th>1780ms</th>↵
<th>3400ms</th>↵
</tr>↵
<tr>↵
<th>1241ms</th>↵
<th>2191ms</th>↵
<th>1125ms</th>↵
<th>1337ms</th>↵
<th>3512ms</th>↵
</tr>↵
</table>↵
  ↵
Next, local tests on Ryzen 5 5650u (AVX), 8Gb RAM 4266MHz, Debian 12, GCC 12.2.0. Using `g++ -Wall -Wextra -Wconversion -static -DONLINE_JUDODGE -O2 -std=c++20 fftest.cc -o fftest` compilation parameters (like on Codeforces).  ↵
<table>↵
<tr>↵
<th>FFT, doubles</th>↵
<th>FFT, long doubles</th>↵
<th>NTT, Montgomery</th>↵
<th>NTT, Barret</th>↵
<th>NTT, no opt</th>↵
</tr>↵
<tr>↵
<th>691ms</th>↵
<th>2339ms</th>↵
<th>352ms</th>↵
<th>401ms</th>↵
<th>398ms</th>↵
</tr>↵
<tr>↵
<th>700ms</th>↵
<th>2328ms</th>↵
<th>411ms</th>↵
<th>500ms</th>↵
<th>487ms</th>↵
</tr>↵
<tr>↵
<th>729ms</th>↵
<th>2036ms</th>↵
<th>359ms</th>↵
<th>446ms</th>↵
<th>396ms</th>↵
</tr>↵
<tr>↵
<th>698ms</th>↵
<th>2164ms</th>↵
<th>380ms</th>↵
<th>456ms</th>↵
<th>472ms</th>↵
</tr>↵
<tr>↵
<th>734ms</th>↵
<th>2284ms</th>↵
<th>407ms</th>↵
<th>462ms</th>↵
<th>440ms</th>↵
</tr>↵
<tr>↵
<th>737ms</th>↵
<th>2148ms</th>↵
<th>396ms</th>↵
<th>459ms</th>↵
<th>434ms</th>↵
</tr>↵
</tr>↵
</table>↵
  ↵
On Codeforces server Montgomery + NTT and FFT + doubles has similar performance, then NTT + Barret, then FFT + long doubles and then vanilla NTT. On local  with modern CPU with SIMD leader is Mongomery + NTT. Actually, on local system NTT runs much faster **even wihout optimizations**. I can't certainly determine what technology are responsible for fast modular arithmetic, but most likely it is some SIMD. In general, NTT has good potential to hardware acceleration. For example <a href="https://link.springer.com/chapter/10.1007/978-3-030-78713-4_6">here</a> shown how to accelerate NTT via FPGA.  ↵
  ↵
As wee see, performance of NTT with big module rests on platform because of many operations with high-bit integers, exactly, 128-bit modular division. This operation depends on bare metal implementation so on Codeforces NTT without optimization is the slowest but on CPU with SIMD acceleration NTT wins FFT. However, diffirence can be neutralized via **reduction** which uses only multiplication of 128-bit integers. <a href="https://danlark.org/2020/06/14/128-bit-division/">This article</a> shows problems of 128-bit division.  ↵
  ↵
So, NTT + Montgomery reduction is fairly universal choice on all platforms. Single nuance is that NTT can't operate with negative coefficients because answer is modulo $p$.  ↵
  ↵

7. Negative coefficients↵
------------------↵
Consider $A(x)$ and $B(x)$ have any integer coefficients and we want to get coefficients of the product via NTT. Let $C(x)$ is polynomial such that every coefficient of $A(x)+C(x)$ $B(x)+C(x)$ and $A(x)+B(x)+C(x)$ is non-negative. Notice that↵
$$ A \cdot B = (A+C) \cdot (B+C) - C \cdot (A+B+C)  $$↵
Because every term in sum is polynomial with non-negative coefficients, $A(x) \cdot B(x)$ can be found via 2 calls of NTT multiplication. $C(x)$ can be found in $O(n)$ greedily.  ↵
  ↵

8. Class for modular arithmetic↵
------------------↵
It is convenient to make a class for modular arithmetic in $\mathbb{Z}/p\mathbb{Z}$ using some reduction. Below showed class for modular arithmetic with Barret reduction (it has methods for gcd, pow, inverses, roots). Class can be extended to arithmetic in every finite field. P.S. If we want get maximum performance while working with specific set of integers, we can replace Barret reduction with Montgomery, because Montgomery is faster.  ↵

<spoiler summary="Modular arithmetic class">↵
```c++↵
class zpz {↵
public:↵
    static void init(uint32_t m) {↵
        mod = m;↵
        shift = 2*(32 - __builtin_clz(m));↵
        factor = (uint64_t(1) << shift) / mod;↵
        gen = 0;↵
    }↵

    static uint32_t ext_gcd(uint32_t a, uint32_t b, uint64_t& x, uint64_t& y) {↵
        if (a < b)↵
            return ext_gcd(b, a, y, x);↵
        if (b == 0) {↵
            x = 1;↵
            y = 0;↵
            return a;↵
        }↵
        uint64_t x1, y1;↵
        uint32_t g = ext_gcd(b, a%b, x1, y1);↵
        x = y1;↵
        y = x1 - (a/b)*y1;↵
        return g;↵
    }↵

    static zpz pow(zpz a, uint32_t n) {↵
        zpz res = 1;↵
        while (n) {↵
            if (n & 1)↵
                res *= a;↵
            a *= a;↵
            n >>= 1;↵
        }↵
        return res;↵
    }↵

    static zpz inv(zpz a) {↵
        if (inverses.find(a()) == inverses.end()) {↵
            uint64_t x, y;↵
            ext_gcd(a(), mod, x, y);↵
            inverses[a()] = reduce(x + mod);↵
        }↵
        return inverses[a()];↵
    }↵

    static zpz root(zpz a, int n) {↵
        if (gen == 0) {↵
            vector<uint32_t> fact;↵
            // int phi = euler_totient(n);↵
            uint32_t phi = mod-1;↵
            uint32_t m = phi;↵
            for (uint32_t d = 2; d*d <= m; ++d)↵
                if (m%d == 0) {↵
                    fact.push_back(d);↵
                    while (m%d == 0)↵
                        m /= d;↵
                }↵
            if (m > 1)↵
                fact.push_back(m);↵
            for (uint32_t rt = 2; rt < mod; ++rt) {↵
                bool found = true;↵
                for (auto d : fact)↵
                    if (pow(zpz(rt), phi / d) == 1) {↵
                        found = false;↵
                        break;↵
                    }↵
                if (found) {↵
                    gen = rt;↵
                    break;↵
                }↵
            }↵
            gen = mod == 1 ? 1 : throw exception();↵
        }↵
        ↵
        return pow(zpz(gen), (mod-1)/n);↵
    }↵

    static void get_all_modular_inverses() {↵
        inverses[1] = 1;↵
        for (int k = 2; k < mod; ++k)↵
            inverses[k] = -1LL * (mod / k) * inverses[mod % k] % mod + mod;↵
    }↵


    zpz(uint32_t x) : val(reduce(x)) {}↵
    ↵
    uint32_t operator () () const { return val; }↵

    zpz& operator =(uint64_t x) { val = reduce(x); return *this; }↵

    zpz& operator =(const zpz& x) { val = x(); return *this; }↵

    zpz& operator +=(const zpz& x) { val = reduce(val + x()); return *this; }↵
    zpz& operator -=(const zpz& x) { val = reduce(val - x() + mod); return *this; }↵
    zpz& operator *=(const zpz& x) { val = reduce((uint64_t) val * x()); return *this; }↵
    ↵
    zpz& operator +=(uint64_t x) { return *this += zpz(x); }↵
    zpz& operator -=(uint64_t x) { return *this -= zpz(x); }↵
    zpz& operator *=(uint64_t x) { return *this *= zpz(x); }↵
    ↵
    zpz& operator /=(const zpz& x) { return *this *= inv(x); }↵
    zpz& operator /=(uint64_t x) { return *this /= zpz(x); }↵
    zpz operator /(uint64_t x) { zpz cur = *this; return cur /= x; }↵

    zpz& operator ++() { return *this += 1; }↵
    zpz& operator --() { return *this -= 1; }↵

    zpz operator ++(int unused) { zpz z(*this); ++(*this); return z; }↵
    zpz operator --(int unused) { zpz z(*this); --(*this); return z; }↵
    ↵
    friend zpz operator +(zpz x, const zpz& y) { return x += y; }↵
    friend zpz operator *(zpz x, const zpz& y) { return x *= y; }↵
    friend zpz operator -(zpz x, const zpz& y) { return x -= y; }↵
    friend zpz operator /(zpz x, const zpz& y) { return x /= y; }↵

    friend zpz operator +(zpz x, uint32_t y) { return x += y; }↵
    friend zpz operator *(zpz x, uint32_t y) { return x *= y; }↵
    friend zpz operator -(zpz x, uint32_t y) { return x -= y; }↵
    friend zpz operator /(zpz x, uint32_t y) { return x /= y; }↵

    friend zpz operator +(uint32_t x, zpz y) { return y += x; }↵
    friend zpz operator *(uint32_t x, zpz y) { return y *= x; }↵

    friend zpz operator -(uint32_t x, const zpz& y) { zpz z(x); return z -= y; }↵
    friend zpz operator /(uint32_t x, const zpz& y) { zpz z(x); return z /= y; }↵

    bool operator  <(const zpz& x) const { return val < x(); }↵
    bool operator ==(const zpz& x) const { return val == x(); }↵
    bool operator  >(const zpz& x) const { return val > x(); }↵
    bool operator !=(const zpz& x) const { return val != x(); }↵
    bool operator <=(const zpz& x) const { return val <= x(); }↵
    bool operator >=(const zpz& x) const { return val >= x(); }↵

    bool operator  <(uint32_t x) const { return val < x; }↵
    bool operator ==(uint32_t x) const { return val == x; }↵
    bool operator  >(uint32_t x) const { return val > x; }↵
    bool operator !=(uint32_t x) const { return val != x; }↵
    bool operator <=(uint32_t x) const { return val <= x; }↵
    bool operator >=(uint32_t x) const { return val >= x; }↵

    friend istream& operator >> (istream& input, zpz& x)↵
    {↵
        uint32_t z;↵
        input >> z,↵
        x = zpz(z);↵
        return input;↵
    }↵

    friend ostream& operator << (ostream& output, const zpz& x)↵
    {↵
        return output << x();↵
    }↵


private:↵
    static uint32_t mod, shift;↵
    static uint64_t factor;↵
    ↵
    static gp_hash_table<uint32_t, uint32_t> inverses;↵
    static uint32_t gen;↵

    [[nodiscard]]↵
    static uint32_t reduce(uint64_t x) {↵
        auto t = (uint32_t)(x - (((__uint128_t) x * factor) >> shift) * mod);↵
        if (t < mod)↵
            return t;↵
        return t - mod;↵
    }↵

    uint32_t val;↵
};↵

uint32_t zpz::mod, zpz::shift;↵
uint64_t zpz::factor;↵
gp_hash_table<uint32_t, uint32_t> zpz::inverses;↵
uint32_t zpz::gen;↵
```↵
</spoiler>↵

  ↵
9. Conclusion↵
------------------↵
As a result we compared various reductions for modular arithmetic and found out that Montgomery reduction has the best performance if we know set of integers in advance. Hence, NTT with this reduction can be used instead of FFT in tasks with polynomial multiplication (even for polynomials with negative coefficients). So there shown reusable realization of NTT with Montgomery reduction which shows better performance even than FFT, especially with hardware optimizations (SIMD?). Open question here what hardware optimizations makes NTT run faster on quite modern CPUs "out of the box".

История

 
 
 
 
Правки
 
 
  Rev. Язык Кто Когда Δ Комментарий
en3 Английский alexvim 2024-05-19 03:08:30 19
en2 Английский alexvim 2024-05-19 03:00:48 2 Tiny change: '2^k \over n} $$\nIf w' -> '2^k \over p} $$\nIf w'
ru2 Русский alexvim 2024-05-19 03:00:41 61
en1 Английский alexvim 2024-05-19 02:28:27 29844 Initial revision for English translation
ru1 Русский alexvim 2024-05-19 02:27:22 31455 Первая редакция (опубликовано)