Hi Codeforces!↵
↵
I have something exciting to tell you guys about today! I have recently come up with a really neat and simple recursive algorithm for multiplying polynomials in $O(n \log n)$ time. It is so neat and simple that I think it might possibly revolutionize the way that fast polynomial multiplication is taught and coded. You don't need to know anything about FFT to understand and implement this algorithm.↵
↵
I've split this blog up into two parts. The first part is intended for anyone to be able to read and understand. The second part is advanced and goes into a ton of interesting ideas and concepts related to this algorithm.↵
↵
Prerequisite: Polynomial quotient and remainder, see this [Wiki article](https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor#Euclidean_division) and this [Stackexchange example](https://math.stackexchange.com/questions/2847682/find-the-quotient-and-remainder).
↵
### Task: ↵
Given two polynomials $P$ and $Q$, an integer $n$ and a non-zero complex number $c$, where $\deg P < n$ and $\deg Q < n$, your task is to calculate the polynomial $P(x) \, Q(x) \% (x^n - c)$ in $O(n \log n)$ time. You may assume that $n$ is a power of two.
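For concreteness, here is a naive $O(n^2)$ reference solution to the task (the helper name `naive_polymult_mod` is mine, just for illustration). It multiplies the polynomials and then folds every coefficient of $x^k$ with $k \ge n$ back onto $x^{k-n}$ with a factor $c$, since $x^n \equiv c \pmod{x^n - c}$:

```python
def naive_polymult_mod(P, Q, n, c):
    # Multiply in O(n^2), then reduce using x^n ≡ c (mod x^n - c)
    res = [0] * (2 * n)
    for i, p in enumerate(P):
        for j, q in enumerate(Q):
            res[i + j] += p * q
    return [r1 + c * r2 for r1, r2 in zip(res[:n], res[n:])]

# (1 + x)(1 + x) = 1 + 2x + x^2 ≡ (1 + 3) + 2x  (mod x^2 - 3)
assert naive_polymult_mod([1, 1], [1, 1], 2, 3) == [4, 2]
```

This is the behavior the fast algorithm below reproduces in $O(n \log n)$ time.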
↵
### Solution:↵
We can create a divide and conquer algorithm for $P(x) \, Q(x) \% (x^n - c)$ based on the difference of squares formula. Assuming $n$ is even, then $(x^n - c) = (x^{n/2} - \sqrt{c}) (x^{n/2} + \sqrt{c})$. The idea behind the algorithm is to calculate $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$ using 2 recursive calls, and then use that result to calculate $P(x) \, Q(x) \% (x^n - c)$.↵
↵
So how do we actually calculate $P(x) \, Q(x) \% (x^n - c)$ using $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$? ↵
↵
Well, we can use the following formula:↵
↵
$$↵
\begin{aligned}↵
A(x) \% (x^n - c) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\↵
&\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})).↵
\end{aligned}↵
$$↵
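As a quick sanity check, the formula can be verified numerically with a naive reduction helper (the names `poly_mod`, `Am`, `Ap` below are mine, chosen just for this check):

```python
import cmath

def poly_mod(A, m, c):
    # Naive reduction of A(x) modulo (x^m - c): x^k ≡ c^(k//m) * x^(k%m)
    res = [0j] * m
    for k, a in enumerate(A):
        res[k % m] += a * c ** (k // m)
    return res

n, c = 4, 2 + 1j
s = cmath.sqrt(c)
A = [3, 1, 4, 1, 5, 9, 2]      # some polynomial with deg A < 2n

lhs = poly_mod(A, n, c)        # A % (x^n - c)
Am = poly_mod(A, n // 2, s)    # A % (x^(n/2) - sqrt(c))
Ap = poly_mod(A, n // 2, -s)   # A % (x^(n/2) + sqrt(c))
# The formula: lower half is (Am + Ap)/2, upper half is (Am - Ap)/(2 sqrt(c))
rhs = [(m + p) / 2 for m, p in zip(Am, Ap)] + \
      [(m - p) / (2 * s) for m, p in zip(Am, Ap)]

assert all(abs(x - y) < 1e-6 for x, y in zip(lhs, rhs))
```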
↵
<spoiler summary="Proof of the formula">↵
Note that↵
\begin{equation}↵
A(x) = \frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) + \frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) A(x).
\end{equation}↵
↵
Let $Q^-(x)$ denote the quotient of $A(x)$ divided by $(x^{n/2} - \sqrt{c})$ and let $Q^+(x)$ denote the quotient of $A(x)$ divided by $(x^{n/2} + \sqrt{c})$. Then
↵
$$↵
\begin{aligned}↵
(1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 + \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} - \sqrt{c})) + Q^-(x) (x^{n/2} - \sqrt{c})) \\↵
&= (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) + \frac{1}{\sqrt{c}} Q^-(x) (x^n - c)
\end{aligned}↵
$$↵
↵
and↵
↵
$$↵
\begin{aligned}↵
(1 - \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 - \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} + \sqrt{c})) + Q^+(x) (x^{n/2} + \sqrt{c})) \\↵
&= (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) - \frac{1}{\sqrt{c}} Q^+(x) (x^n - c).
\end{aligned}↵
$$↵
↵
With this we have shown that↵
$$↵
\begin{aligned}↵
A(x) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\↵
&\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) \, + \\↵
&\frac{1}{\sqrt{c}} \frac{Q^-(x) - Q^+(x)}{2} (x^n - c).↵
\end{aligned}↵
$$↵
↵
Here $A(x)$ is expressed as a remainder plus a quotient times $(x^n - c)$, where the remainder has degree $< n$. This proves the formula.
</spoiler>↵
↵
This formula is very useful. If we substitute $A(x)$ by $P(x) \, Q(x)$, then the formula tells us how to calculate $P(x) \, Q(x) \% (x^n - c)$ from $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$ in linear time. With this we have the recipe for implementing an $O(n \log n)$ divide and conquer algorithm:
↵
Input:↵
↵
- Integer $n$ (power of 2),↵
- Non-zero complex number $c$,↵
- Two polynomials $P(x) \% (x^n - c)$ and $Q(x) \% (x^n - c)$.↵
↵
Output:↵
↵
- The polynomial $P(x) \, Q(x) \% (x^n - c)$.↵
↵
Algorithm:↵
↵
Step 1. (Base case) If $n = 1$, then return $P(0) \cdot Q(0)$. Otherwise:↵
↵
Step 2. Starting from $P(x) \% (x^n - c)$ and $Q(x) \% (x^n - c)$, in $O(n)$ time calculate ↵
↵
$$↵
\begin{align}↵
& P(x) \% (x^{n/2} - \sqrt{c}), \\↵
& Q(x) \% (x^{n/2} - \sqrt{c}), \\↵
& P(x) \% (x^{n/2} + \sqrt{c}) \text{ and} \\↵
& Q(x) \% (x^{n/2} + \sqrt{c}).↵
\end{align}↵
$$↵
↵
Step 3. Make two recursive calls to calculate $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$.↵
↵
Step 4. Using the formula, calculate $P(x) \, Q(x) \% (x^n - c)$ in $O(n)$ time. Return the result.↵
↵
Here is a Python implementation following this recipe:
↵
<spoiler summary="Python solution to the task">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Integer, needs to be power of 2↵
c: Non-zero complex floating point number↵
P: A list of length n representing a polynomial P(x) % (x^n - c)↵
Q: A list of length n representing a polynomial Q(x) % (x^n - c)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod(P, Q, n, c):
    assert len(P) == n and len(Q) == n

    # Base case
    if n == 1:
        return [P[0] * Q[0]]

    assert n % 2 == 0
    import cmath
    sqrtc = cmath.sqrt(c)

    # Calculate P_minus := P mod (x^(n/2) - sqrt(c))
    #           Q_minus := Q mod (x^(n/2) - sqrt(c))
    P_minus = [p1 + sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]
    Q_minus = [q1 + sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]

    # Calculate P_plus := P mod (x^(n/2) + sqrt(c))
    #           Q_plus := Q mod (x^(n/2) + sqrt(c))
    P_plus = [p1 - sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]
    Q_plus = [q1 - sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod(P_minus, Q_minus, n//2, sqrtc)
    PQ_plus = fast_polymult_mod(P_plus, Q_plus, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus
    PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\
         [(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]

    return PQ
```↵
</spoiler>↵
↵
One final thing that I want to mention before going into the advanced section is that this algorithm can also be used for fast unmodded polynomial multiplication, i.e. given polynomials $P(x)$ and $Q(x)$, calculate $P(x) \, Q(x)$. The trick is simply to pick $n$ large enough that $P(x) \, Q(x) = P(x) \, Q(x) \% (x^n - c)$, and then use the exact same algorithm as before. $c$ can be picked arbitrarily (any non-zero complex number works).
↵
<spoiler summary="Python implementation for general Fast polynomial multiplication">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    n = 1
    while n < res_len:
        n *= 2

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
If you want to try out implementing this algorithm yourself, then here is a very simple problem to test out your implementation on: [SPOJ:POLYMUL](https://www.spoj.com/problems/POLYMUL/).↵
↵
### (Advanced) Speeding up the algorithm↵
This section will be about tricks that can be used to speed up the algorithm. The first two tricks will speed up the algorithm by a factor of 2 each. The last trick is advanced, and it has the potential to both speed up the algorithm and also make it more numerically stable.↵
↵
<spoiler summary="$n$ doesn't actually need to be a power of 2">↵
We don't actually need the assumption that $n$ is a power of 2. If $n$ ever becomes odd during the recursion, then we have two choices: either fall back to an $O(n^2)$ algorithm, or fall back to the unmodded $O(n \log{n})$ polynomial multiplication algorithm.
↵
Let us discuss the run time of falling back to the $O(n^2)$ algorithm when $n$ becomes odd. Assume that $n = a \cdot 2^b$, where $a$ is an odd integer and $b$ is a non-negative integer. Think of the recursive algorithm as having layers, one layer for each possible value of $n$.
The first $b$ layers each take $O(n)$ time. In the $(b+1)$-th layer the value of $n$ is $a$. Using the $O(n^2)$ polynomial multiplication algorithm leads to this layer taking $O(n/a \cdot a^2) = O(n \cdot a)$ time. The final time complexity comes out to be $O((a + b) \, n)$.
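The decomposition $n = a \cdot 2^b$ used in this analysis is cheap to compute. A small helper (my own, for illustration) using the lowest set bit:

```python
def decompose(n):
    # Write n = a * 2^b with a odd: 2^b is the lowest set bit of n
    b = (n & -n).bit_length() - 1
    return n >> b, b

a, b = decompose(96)   # 96 = 3 * 2^5
assert (a, b) == (3, 5) and a * 2**b == 96
```

So for $n = 96$ the fallback layer works on polynomials of length $a = 3$, which is cheap even with the quadratic algorithm.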
↵
<spoiler summary="Python implementation that works for both odd and even $n$">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O((a + b) * n) time, where n = a*2^b.↵
↵
Input:↵
n: Integer↵
c: Non-zero complex floating point number↵
P: A list of length n representing a polynomial P(x) % (x^n - c)↵
Q: A list of length n representing a polynomial Q(x) % (x^n - c)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod2(P, Q, n, c):
    assert len(P) == n and len(Q) == n

    # Base case (n is odd)
    if n & 1:
        # Calculate the answer in O(n^2) time
        res = [0] * (2*n)
        for i in range(n):
            for j in range(n):
                res[i + j] += P[i] * Q[j]
        return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]

    assert n % 2 == 0
    import cmath
    sqrtc = cmath.sqrt(c)

    # Calculate P_minus := P mod (x^(n/2) - sqrt(c))
    #           Q_minus := Q mod (x^(n/2) - sqrt(c))
    P_minus = [p1 + sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]
    Q_minus = [q1 + sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]

    # Calculate P_plus := P mod (x^(n/2) + sqrt(c))
    #           Q_plus := Q mod (x^(n/2) + sqrt(c))
    P_plus = [p1 - sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]
    Q_plus = [q1 - sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod2(P_minus, Q_minus, n//2, sqrtc)
    PQ_plus = fast_polymult_mod2(P_plus, Q_plus, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus
    PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\
         [(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]

    return PQ
```↵
</spoiler>↵
↵
The reason why this is super useful is that it allows us to speed up the fast unmodded polynomial multiplication algorithm. As long as we are fine with $a$ being at most, say, $10$, we might be able to choose a significantly smaller $n$ than would be possible if we could only choose powers of two. This trick has the potential to make the fast unmodded polynomial multiplication algorithm run twice as fast.
↵
<spoiler summary="Python implementation for more efficient fast unmodded polynomial multiplication">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult2(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    b = 0
    alim = 10
    while alim * 2**b < res_len:
        b += 1
    a = (res_len - 1) // 2**b + 1
    n = a * 2**b

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod2(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
↵
</spoiler>↵
↵
<spoiler summary="Imaginary-cyclic convolution">↵
Suppose that $P(x)$ and $Q(x)$ are two real polynomials, and that we want to calculate $P(x) \, Q(x)$. As discussed earlier, we can calculate the unmodded polynomial product by picking $n$ large enough such that $(P(x) \, Q(x)) \% (x^n - c) = P(x) \, Q(x)$ (here $c$ is any non-zero complex number), and then running the divide and conquer algorithm. But it turns out there is something smarter that we can do.
↵
If we use $c = \text{i}$ (the imaginary unit) as the initial value of $c$, then this allows us to pick an even smaller value for $n$. The reason for this is that if we get "overflow" from $n$ being too small, then that overflow is placed into the imaginary part of the result $(P(x) \, Q(x)) \% (x^n - \text{i})$. This means that by using $c = \text{i}$ we can pick $n$ half the size compared to not using $c=\text{i}$. So this trick speeds up the fast unmodded polynomial multiplication algorithm by exactly a factor of 2.
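Here is a minimal sketch of this trick (function names are mine; the recursive routine from the first part of the blog is repeated in compact form so the snippet is self-contained). Coefficients $0, \dots, n-1$ of the real product end up in the real parts of the result, and coefficients $n, \dots, 2n-1$ in the imaginary parts:

```python
import cmath

def polymult_mod(P, Q, n, c):
    # Divide-and-conquer P*Q % (x^n - c), same as in the main text
    if n == 1:
        return [P[0] * Q[0]]
    s = cmath.sqrt(c)
    half = lambda A, sign: [a + sign * s * b for a, b in zip(A[:n//2], A[n//2:])]
    Rm = polymult_mod(half(P, +1), half(Q, +1), n // 2, s)
    Rp = polymult_mod(half(P, -1), half(Q, -1), n // 2, -s)
    return [(m + p) / 2 for m, p in zip(Rm, Rp)] + \
           [(m - p) / (2 * s) for m, p in zip(Rm, Rp)]

def real_polymult_i(P, Q):
    # Multiply two real polynomials using c = i and only half the usual n
    res_len = len(P) + len(Q) - 1
    n = 1
    while 2 * n < res_len:  # 2n >= res_len suffices (instead of n >= res_len)
        n *= 2
    P = P + [0] * (2 * n - len(P))
    Q = Q + [0] * (2 * n - len(Q))
    # Reduce mod (x^n - i): the top halves fold into the imaginary parts
    Pr = [p1 + 1j * p2 for p1, p2 in zip(P[:n], P[n:])]
    Qr = [q1 + 1j * q2 for q1, q2 in zip(Q[:n], Q[n:])]
    R = polymult_mod(Pr, Qr, n, 1j)
    # Coefficients 0..n-1 are the real parts, n..2n-1 the imaginary parts
    return ([z.real for z in R] + [z.imag for z in R])[:res_len]
```

For example, `real_polymult_i([1, 2, 3], [4, 5, 6])` should return `[4, 13, 28, 27, 18]` up to floating point error, while only running the recursion with $n = 4$ instead of $n = 8$.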
</spoiler>↵
↵
<spoiler summary="Trick to go from $\% (x^n - c)$ to $\% (x^n - 1)$">↵
There is a somewhat well-known technique called "reweighting" that allows us to switch between working $\% (x^n - c)$ and working $\% (x^n - 1)$. I've previously written a blog explaining this technique, see [here](https://codeforces.me/blog/entry/106983).
↵
So why would we be interested in switching from $\% (x^n - c)$ to $\% (x^n - 1)$? The reason is that with $c=1$ we don't need to bother with multiplying or dividing by $c$ or $\sqrt{c}$ anywhere, since $c=\sqrt{c}=1$. Additionally, if $c=-1$, $c=\text{i}$ or $c=-\text{i}$, then multiplying or dividing by $c$ can be done very efficiently. So whenever $c$ becomes something other than $1,-1,\text{i}$ or $-\text{i}$, it makes sense to use the reweighting trick to switch back to $c=1$. This significantly reduces the number of floating point operations used by the algorithm, which means the algorithm has the potential to be both faster and more numerically stable. So reweighting is definitely something to consider if you want to create a heavily optimized polynomial multiplication implementation.
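A minimal numerical sketch of the reweighting identity (the naive helper and variable names are mine): substituting $x = w y$ with $w^n = c$ turns working $\% (x^n - c)$ into working $\% (y^n - 1)$, i.e. coefficient $k$ is multiplied by $w^k$ before, and divided by $w^k$ after:

```python
def naive_mult_mod(P, Q, n, c):
    # O(n^2) reference for P*Q % (x^n - c)
    res = [0j] * (2 * n)
    for i, p in enumerate(P):
        for j, q in enumerate(Q):
            res[i + j] += p * q
    return [r1 + c * r2 for r1, r2 in zip(res[:n], res[n:])]

n, c = 4, 2 + 1j
w = c ** (1 / n)                 # any n-th root of c works as the weight
P, Q = [1, 2, 3, 4], [5, 6, 7, 8]

# Reweight: coefficient k gets multiplied by w^k, then work with c = 1
Pw = [p * w**k for k, p in enumerate(P)]
Qw = [q * w**k for k, q in enumerate(Q)]
R1 = naive_mult_mod(Pw, Qw, n, 1)
R = [r / w**k for k, r in enumerate(R1)]   # undo the reweighting

expected = naive_mult_mod(P, Q, n, c)
assert all(abs(a - b) < 1e-6 for a, b in zip(R, expected))
```

In a real implementation the `naive_mult_mod` calls would of course be replaced by the fast algorithm running with $c = 1$.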
↵
</spoiler>↵
↵
### (Advanced) [user:-is-this-fft-,2023-07-07]?↵
This algorithm is actually FFT in disguise. But it is also different from any other FFT algorithm that I've seen before (for example the Cooley–Tukey FFT algorithm).
↵
<spoiler summary="Using this algorithm to calculate FFT">↵
In the tail of the recursion (i.e. when $n$ reaches 1), you are calculating $P(x) \, Q(x) \% (x - c)$ for some non-zero complex number $c$. This is in fact the same thing as evaluating the polynomial $P(x) \, Q(x)$ at $x=c$. Furthermore, if you initially started with $c=1$, then the $c$ in the tail will be some $n$-th root of unity. If you analyze it more carefully, you will see that each tail corresponds to a different $n$-th root of unity. So what the algorithm is actually doing is evaluating $P(x) \, Q(x)$ at all $n$-th roots of unity.
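The first claim, that reducing modulo $(x - c)$ is the same as evaluating at $x = c$, is just the polynomial remainder theorem, and is easy to check directly:

```python
# Remainder theorem: A(x) % (x - c) is the constant A(c)
A = [3, 1, 4, 1]    # A(x) = 3 + x + 4x^2 + x^3
c = 2

# Synthetic division of A by (x - c): the final carry is the remainder,
# and computing it is exactly Horner evaluation of A at c
rem = 0
for a in reversed(A):
    rem = rem * c + a

assert rem == sum(coef * c**k for k, coef in enumerate(A))  # A(2) = 29
```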
↵
The $n$-th order FFT of a polynomial is defined as the polynomial evaluated at all $n$-th roots of unity. This means that the algorithm is in fact an FFT algorithm. However, if you want to use it to calculate FFT, then make sure you order the $n$-th roots of unity according to the standard order used by FFT algorithms. The standard order is $\exp{(\frac{2 \pi \text{i}}{n} 0)}, \exp{(\frac{2 \pi \text{i}}{n} 1)}, ..., \exp{(\frac{2 \pi \text{i}}{n} (n-1))}$. To get the ordering correct, you will probably need to do a "bit reversal" at the end.
↵
</spoiler>↵
↵
<spoiler summary="This algorithm is not the same algorithm as Cooley–Tukey">↵
The Cooley–Tukey algorithm is the standard algorithm for calculating FFT. It is for example used in this blog [[Tutorial] FFT](https://codeforces.me/blog/entry/111371) by [user:-is-this-fft-,2023-07-08]. The idea behind the algorithm is to split up the polynomial $P(x)$ into an even part $P_{\text{even}}(x^2)$ and an odd part $x \, P_{\text{odd}}(x^2)$. You can calculate the FFT of $P(x)$ using the FFTs of $P_{\text{even}}(x)$ and $P_{\text{odd}}(x)$. So Cooley–Tukey is an $O(n \log{n})$ divide and conquer algorithm that repeatedly splits up the polynomial into odd and even parts.
↵
The wiki article for [Cooley-Tukey](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) has a nice description of the algorithm:
↵
$$↵
\begin{align}↵
X_k &= E_k + e^{- \frac{2 \pi \text{i}}{n} k} O_k, \\↵
X_{k+\frac{n}{2}} &= E_k - e^{- \frac{2 \pi \text{i}}{n} k} O_k.↵
\end{align}↵
$$↵
↵
If you compare this to calculating FFT using the divide and conquer polynomial mod method you instead get ↵
↵
$$↵
\begin{align}↵
X_k &= E_k + c \, O_k, \\↵
X_{k+\frac{n}{2}} &= E_k - c \, O_k,↵
\end{align}↵
$$↵
↵
where $c$ is an $n$-th root of unity that is independent of $k$. This is very different from Cooley–Tukey, since $c$ doesn't depend on $k$, unlike $e^{- \frac{2 \pi \text{i}}{n} k}$. In fact, $c$ being constant means that the polynomial mod method has the potential to be faster than Cooley–Tukey.
↵
</spoiler>↵
↵
<spoiler summary="FFT implementation in Python based on this algorithm">↵
Here is an FFT implementation. A codegolfed version of the same code can be found on [Pyrival](https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/fft.py).↵
↵
```py↵
"""↵
Calculates FFT(P) in O(n log n) time.↵
↵
This implementation is based on the ↵
polynomial modulo multiplication algorithm.↵
↵
Input:↵
P: A list of length n representing a polynomial P(x).↵
n needs to be a power of 2.↵
Output:↵
A list of length n representing the FFT of the polynomial P,↵
i.e. the list [P(exp(2j * pi / n * i)) for i in range(n)]
"""↵
import cmath

rt = [1] # List used to store roots of unity
def FFT(P):
    n = len(P)
    # Assert n is a power of 2
    assert n and (n - 1) & n == 0
    # Make a copy of P to not modify original P
    P = P[:]

    # Precalculate the roots
    while 2 * len(rt) < n:
        # 4*len(rt)-th root of unity
        root = cmath.exp(2j * cmath.pi / (4 * len(rt)))
        rt.extend([r * root for r in rt])

    # Transform P
    k = n
    while k > 1:
        for i in range(n//k):
            r = rt[i]
            for j1 in range(i*k, i*k + k//2):
                j2 = j1 + k//2
                z = r * P[j2]
                P[j2] = P[j1] - z
                P[j1] += z
        k //= 2

    # Bit reverse P before returning
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2

    return [P[r] for r in rev]

# Inverse of FFT(P) using a standard trick
def inverse_FFT(fft_P):
    n = len(fft_P)
    return FFT([fft_P[-i]/n for i in range(n)])
↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult_using_FFT(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    n = 1
    while n < res_len:
        n *= 2

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Transform P and Q
    fft_P = FFT(P)
    fft_Q = FFT(Q)

    # Calculate FFT of P*Q
    fft_PQ = [p*q for p,q in zip(fft_P,fft_Q)]

    # Inverse FFT
    PQ = inverse_FFT(fft_PQ)

    # Remove padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
<spoiler summary="FFT implementation in C++ based on this algorithm">↵
Here is an FFT implementation. It is coded in the same style as in [KACTL](https://github.com/kth-competitive-programming/kactl/blob/main/content/numerical/FastFourierTransform.h).
↵
```cpp↵
typedef complex<double> C;↵
typedef vector<double> vd;↵
void fft(vector<C>& a) {
    int n = sz(a);
    static vector R{1.L + 0il};
    static vector rt{1. + 0i};
    for (static int k = 2; k < n; k *= 2) {
        R.resize(n/2); rt.resize(n/2);
        rep(i,k/2,k) rt[i] = R[i] = R[i-k/2] * pow(1il, 2./k);
    }
    for (int k = n; k > 1; k /= 2) rep(i,0,n/k) rep(j,i*k,i*k + k/2) {
        C &u = a[j], &v = a[j+k/2], &r = rt[i];
        C z(v.real()*r.real() - v.imag()*r.imag(),
            v.real()*r.imag() + v.imag()*r.real());
        v = u - z;
        u = u + z;
    }
    vi rev(n);
    rep(i,0,n) rev[i] = rev[i / 2] / 2 + (i & 1) * n / 2;
    rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
}

vd conv(const vd& a, const vd& b) {
    if (a.empty() || b.empty()) return {};
    vd res(sz(a) + sz(b) - 1);
    int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
    vector<C> in(n), out(n);
    copy(all(a), begin(in));
    rep(i,0,sz(b)) in[i].imag(b[i]);
    fft(in);
    for (C& x : in) x *= x;
    rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
    fft(out);
    rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
    return res;
}
```↵
</spoiler>↵
↵
### (Advanced) Connection between this algorithm and NTT↵
Just like how there is FFT and NTT, there are two variants of this algorithm too: one using complex floating point numbers, and one working modulo a prime (or, more generally, modulo an odd composite number).
↵
<spoiler summary="Using modulo integers instead of complex numbers">↵
This algorithm requires three properties: first, it must be possible to divide by $2$; second, $\sqrt{c}$ must exist; and third, it must be possible to divide by $\sqrt{c}$. This means that we don't technically need complex numbers, we could also use other number systems (like working modulo a prime or modulo an odd composite number).
↵
Primes that work nicely for this purpose are called "NTT primes": primes $p$ such that $p - 1$ is divisible by a large power of $2$. Common examples of NTT primes are: $998244353 = 119 \cdot 2^{23} + 1$, $167772161 = 5 \cdot 2^{25} + 1$ and $469762049 = 7 \cdot 2^{26} + 1$.
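The "$p - 1$ divisible by a large power of two" property is easy to check with a small snippet (my own, for illustration):

```python
# 2-adic valuation of p - 1 for the NTT primes mentioned above
for p, expected_b in [(998244353, 23), (167772161, 25), (469762049, 26)]:
    b = ((p - 1) & -(p - 1)).bit_length() - 1   # largest b with 2^b | p - 1
    assert b == expected_b
    assert (p - 1) % 2**b == 0 and ((p - 1) // 2**b) % 2 == 1
```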
</spoiler>↵
↵
<spoiler summary="What if $\sqrt{c}$ doesn't exist?">
One of the things I dislike about NTT is that for NTT to be defined, there needs to exist an $n$-th root of unity. Usually problems involving NTT are designed so that this is never an issue. But if you want to use NTT where it hasn't been designed to magically work, then this is a really big issue. The NTT can become undefined!
↵
Note that this algorithm does not share the same drawback of being undefined. The reason is that if $\sqrt{c}$ doesn't exist, then the algorithm can simply choose to either switch over to an $O(n^2)$ polynomial multiplication algorithm, or fall back to fast unmodded polynomial multiplication. The implication is that this algorithm can do fast modded polynomial multiplication even if it is given a relatively bad NTT prime. I just find this property to be really cool!
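Whether $\sqrt{c}$ exists modulo an odd prime can be tested cheaply with Euler's criterion, which is how an implementation could decide when to fall back (the snippet and the name `has_sqrt` are my own):

```python
MOD = 998244353  # prime

def has_sqrt(c):
    # Euler's criterion: for an odd prime p, a nonzero c has a square root
    # mod p iff c^((p-1)/2) ≡ 1 (mod p)
    c %= MOD
    return c == 0 or pow(c, (MOD - 1) // 2, MOD) == 1

assert has_sqrt(4)       # 4 = 2^2 is a square
assert not has_sqrt(3)   # 3 generates the multiplicative group mod 998244353
```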
↵
A good example of when NTT becomes undefined is the yosupo problem [convolution_mod_large](https://judge.yosupo.jp/problem/convolution_mod_large). Here the NTT mod is $998244353 = 119 \cdot 2^{23} + 1$, and the tricky thing about the problem is that $n=2^{24}$. Since $2^{24}$ does not divide $998244352 = 119 \cdot 2^{23}$, there won't exist any $n$-th root of unity, so the NTT of length $n$ is undefined. However, the divide and conquer approach from this blog can easily solve the problem by falling back to the $O(n^2)$ algorithm.
</spoiler>↵
↵
<spoiler summary="NTT implementation in Python based on this algorithm">↵
Here is an NTT implementation. A codegolfed version of the same code can be found on [Pyrival](https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/ntt.py).↵
↵
```py↵
# Mod used for NTT↵
# Requirement: Any odd integer > 2↵
# It is important that MOD - 1 is↵
# divisible by lots of 2s↵
MOD = (119 << 23) + 1↵
assert MOD & 1↵
↵
# Precalc non-quadratic_residue (used by the NTT)↵
non_quad_res = 2↵
while pow(non_quad_res, MOD//2, MOD) != MOD - 1:
    non_quad_res += 1
rt = [1]↵
↵
"""↵
Calculates NTT(P) in O(n log n) time.↵
↵
This implementation is based on the ↵
polynomial modulo multiplication algorithm.↵
↵
Input:↵
P: A list of length n representing a polynomial P(x).↵
n needs to be a power of 2.↵
Output:↵
A list of length n representing the NTT of the polynomial P,↵
i.e. the list [P(root**i) % MOD for i in range(n)]↵
where root is an n-th root of unity mod MOD↵
"""↵
def NTT(P):
    n = len(P)
    # Assert n is a power of 2
    assert n and (n - 1) & n == 0

    # Check that NTT is defined for this n
    assert (MOD - 1) % n == 0

    # Make a copy of P to not modify original P
    P = P[:]

    # Precalculate the roots
    while 2 * len(rt) < n:
        # 4*len(rt)-th root of unity
        root = pow(non_quad_res, MOD//(4 * len(rt)), MOD)
        rt.extend([r * root % MOD for r in rt])

    # Transform P
    k = n
    while k > 1:
        for i in range(n//k):
            r = rt[i]
            for j1 in range(i*k, i*k + k//2):
                j2 = j1 + k//2
                z = r * P[j2]
                P[j2] = (P[j1] - z) % MOD
                P[j1] = (P[j1] + z) % MOD
        k //= 2

    # Bit reverse P before returning
    rev = [0] * n
    for i in range(1, n):
        rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2

    return [P[r] for r in rev]

# Inverse of NTT(P) using a standard trick
def inverse_NTT(ntt_P):
    n = len(ntt_P)
    n_inv = pow(n, -1, MOD) # Requires Python 3.8
    # The following works in any Python version, but requires MOD to be prime
    # n_inv = pow(n, MOD - 2, MOD)
    assert n * n_inv % MOD == 1
    return NTT([ntt_P[-i] * n_inv % MOD for i in range(n)])
↵
"""
Calculates P(x) * Q(x) (where the coefficients are returned % MOD)

Input:
P: A list representing a polynomial P(x)
Q: A list representing a polynomial Q(x)
Output:
A list representing the polynomial P(x) * Q(x) (with coefficients % MOD)
"""
def fast_polymult_using_NTT(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    n = 1
    while n < res_len:
        n *= 2

    # Pad with extra 0s to reach length n
    P = P + [0] * (n - n1)
    Q = Q + [0] * (n - n2)

    # Transform P and Q
    ntt_P = NTT(P)
    ntt_Q = NTT(Q)

    # Calculate NTT of P*Q
    ntt_PQ = [p * q % MOD for p,q in zip(ntt_P,ntt_Q)]

    # Inverse NTT
    PQ = inverse_NTT(ntt_PQ)

    # Remove padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
<spoiler summary="NTT implementation in C++ based on this algorithm">↵
Here is an NTT implementation. It is coded in the same style as in [KACTL](https://github.com/kth-competitive-programming/kactl/blob/main/content/numerical/NumberTheoreticTransform.h).↵
↵
```cpp↵
const ll mod = (119 << 23) + 1;// = 998244353↵
// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21↵
// and 483 << 21 The last two are > 10^9.↵
typedef vector<ll> vl;↵
↵
#include "../number-theory/ModPow.h"↵
↵
void ntt(vl &a) {↵
int n = sz(a);↵
static ll r = 3;↵
while(modpow(r, mod/2) + 1 < mod) ++r;↵
static vl rt{1};↵
for (static int k = 2; k < n; k *= 2) {↵
rt.resize(n/2);↵
rep(i,k/2,k) rt[i] = rt[i-k/2] * modpow(r, mod/2/k) % mod;↵
}↵
for (int k = n; k > 1; k /= 2) rep(i,0,n/k) rep(j,i*k,i*k + k/2) {↵
ll &u = a[j], &v = a[j+k/2], z = rt[i] * v % mod;↵
v = u - z + (u < z ? mod : 0);↵
u = u + z - (u + z >= mod ? mod : 0);↵
}↵
vi rev(n);↵
rep(i,0,n) rev[i] = rev[i / 2] / 2 + (i & 1) * n / 2;↵
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);↵
}↵
vl conv(vl a, vl b) {↵
↵
if (a.empty() || b.empty()) return {};↵
int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), n = 1 << B;↵
int inv = modpow(n, mod - 2);↵
vl out(n);↵
a.resize(n); b.resize(n);↵
ntt(a), ntt(b);↵
rep(i,0,n) out[-i & (n - 1)] = (ll)a[i] * b[i] % mod * inv % mod;↵
ntt(out);↵
return {out.begin(), out.begin() + s};↵
}↵
```↵
</spoiler>↵
↵
### (Advanced) Shorter implementations ("Codegolfed version")↵
It is possible to make really short but slightly less natural implementations of this algorithm. Originally I was thinking of using these shorter versions in the blog, but in the end I didn't. So here they are. If you want to implement this algorithm and use it in practice, then I would recommend taking one of these implementations and porting it to C++.
↵
<spoiler summary="Short Python implementation without any speedup tricks">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Integer, needs to be power of 2↵
c: Non-zero complex floating point number↵
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)↵
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod3(P, Q, n, c):
    assert len(P) == 2*n and len(Q) == 2*n

    # Mod P and Q by (x^n - c)
    P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]
    Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]

    # Base case
    if n == 1:
        return [P[0] * Q[0]]

    assert n % 2 == 0
    import cmath
    sqrtc = cmath.sqrt(c)

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod3(P, Q, n//2, sqrtc)
    PQ_plus = fast_polymult_mod3(P, Q, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus
    PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\
         [(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]

    return PQ
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult3(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    n = 1
    while n < res_len:
        n *= 2

    # Pad with extra 0s to reach length 2*n
    P = P + [0] * (2*n - n1)
    Q = Q + [0] * (2*n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod3(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
<spoiler summary="Short Python implementation supporting odd and even $n$ (making it up to 2 times faster)">↵
↵
```py↵
"""
Calculates P(x) * Q(x) % (x^n - c) in O((a + b) * n) time, where n = a*2^b

Input:
n: Positive integer
c: Non-zero complex floating point number
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)
Output:
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)
"""
def fast_polymult_mod4(P, Q, n, c):
    assert len(P) == 2*n and len(Q) == 2*n

    # Mod P and Q by (x^n - c)
    P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]
    Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]

    # Base case (n is odd)
    if n & 1:
        # Calculate the answer in O(n^2) time
        res = [0] * (2*n)
        for i in range(n):
            for j in range(n):
                res[i + j] += P[i] * Q[j]
        return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]

    assert n % 2 == 0
    import cmath
    sqrtc = cmath.sqrt(c)

    # Recursively calculate PQ_minus := P * Q % (x^(n/2) - sqrt(c))
    #                       PQ_plus  := P * Q % (x^(n/2) + sqrt(c))
    PQ_minus = fast_polymult_mod4(P, Q, n//2, sqrtc)
    PQ_plus = fast_polymult_mod4(P, Q, n//2, -sqrtc)

    # Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus
    PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\
         [(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]

    return PQ
↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult4(P, Q):
    # Calculate length of the list representing P*Q
    n1 = len(P)
    n2 = len(Q)
    res_len = n1 + n2 - 1

    # Pick n sufficiently big
    b = 0
    alim = 10
    while alim * 2**b < res_len:
        b += 1
    a = (res_len - 1) // 2**b + 1
    n = a * 2**b

    # Pad with extra 0s to reach length 2*n
    P = P + [0] * (2*n - n1)
    Q = Q + [0] * (2*n - n2)

    # Pick non-zero c arbitrarily =)
    c = 123.24

    # Calculate P*Q mod x^n - c
    PQ = fast_polymult_mod4(P, Q, n, c)

    # Remove extra 0 padding and return
    return PQ[:res_len]
```↵
</spoiler>↵
↵
<spoiler summary="Short Python implementation supporting odd and even $n$ and imaginary cyclic convolution (making it up to 4 times faster)">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Integer, needs to be power of 2↵
c: Non-zero complex floating point number↵
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)↵
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod4(P, Q, n, c):↵
assert len(P) == 2*n and len(Q) == 2*n↵
↵
# Mod P and Q by (x^n - c)↵
P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]↵
Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]↵
↵
# Base case (n is odd)↵
if n & 1:↵
# Calculate the answer in O(n^2) time↵
res = [0] * (2*n)↵
for i in range(n):↵
for j in range(n):↵
res[i + j] += P[i] * Q[j]↵
return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod4(P, Q, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod4(P, Q, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
↵
"""↵
Calculates P(x) * Q(x) of two real polynomials↵
↵
Input:↵
P: A list representing a real polynomial P(x)↵
Q: A list representing a real polynomial Q(x)↵
Output:↵
A list representing the real polynomial P(x) * Q(x)↵
"""↵
def fast_polymult5(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
b = 1↵
alim = 10↵
while alim * 2**b < res_len:↵
b += 1↵
a = (res_len - 1) // 2**b + 1↵
n = a * 2**b↵
↵
# Pick c = i (imaginary unit)↵
c = 1j↵
# and decrease the size of n by a factor of 2↵
n //= 2↵
↵
# Pad with extra 0s to reach length 2*n↵
P = P + [0] * (2*n - n1)↵
Q = Q + [0] * (2*n - n2)↵
↵
# Calculate P*Q mod x^n - i↵
PQ = fast_polymult_mod4(P, Q, n, c)↵
↵
# The imaginary part contains the "overflow"↵
PQ = [pq.real for pq in PQ] + [pq.imag for pq in PQ]↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
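
To sanity-check the imaginary-cyclic version, here is a condensed, self-contained restatement of `fast_polymult_mod4` / `fast_polymult5` from the spoiler above (the shortened names `polymult_mod` / `polymult` are ours), compared against a schoolbook product. Treat it as a test-harness sketch, not an optimized implementation:

```py
import cmath

def polymult_mod(P, Q, n, c):
    # P and Q have length 2*n; fold them mod (x^n - c) first, using x^n = c
    P = [p1 + c * p2 for p1, p2 in zip(P[:n], P[n:])]
    Q = [q1 + c * q2 for q1, q2 in zip(Q[:n], Q[n:])]
    if n & 1:
        # Odd n: schoolbook multiplication, then fold the overflow
        res = [0] * (2 * n)
        for i in range(n):
            for j in range(n):
                res[i + j] += P[i] * Q[j]
        return [r1 + c * r2 for r1, r2 in zip(res[:n], res[n:])]
    sqrtc = cmath.sqrt(c)
    minus = polymult_mod(P, Q, n // 2, sqrtc)   # P*Q % (x^(n/2) - sqrt(c))
    plus = polymult_mod(P, Q, n // 2, -sqrtc)   # P*Q % (x^(n/2) + sqrt(c))
    return ([(m + p) / 2 for m, p in zip(minus, plus)] +
            [(m - p) / (2 * sqrtc) for m, p in zip(minus, plus)])

def polymult(P, Q):
    # Real product via c = i: the overflow lands in the imaginary parts
    res_len = len(P) + len(Q) - 1
    b, alim = 1, 10                 # b >= 1 keeps n even for the halving below
    while alim * 2**b < res_len:
        b += 1
    n = ((res_len - 1) // 2**b + 1) * 2**b
    n //= 2
    P = P + [0] * (2 * n - len(P))
    Q = Q + [0] * (2 * n - len(Q))
    PQ = polymult_mod(P, Q, n, 1j)
    PQ = [z.real for z in PQ] + [z.imag for z in PQ]
    return PQ[:res_len]

# (1 + 2x + 3x^2)(4 + 5x) = 4 + 13x + 22x^2 + 15x^3
print(polymult([1, 2, 3], [4, 5]))  # approximately [4.0, 13.0, 22.0, 15.0]
```

For real inputs this matches `fast_polymult5` above coefficient for coefficient, up to floating point error.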
↵
↵
↵
↵
↵
↵
I have something exciting to tell you guys about today! I have recently come up with a really neat and simple recursive algorithm for multiplying polynomials in $O(n \log n)$ time. It is so neat and simple that I think it might possibly revolutionize the way that fast polynomial multiplication is taught and coded. You don't need to know anything about FFT to understand and implement this algorithm.↵
↵
I've split this blog up into two parts. The first part is intended for anyone to be able to read and understand. The second part is advanced and goes into a ton of interesting ideas and concepts related to this algorithm.↵
↵
Prerequisite: Polynomial quotient and remainder, see [Wiki article] (https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor#Euclidean_division) and this [Stackexchange example](https://math.stackexchange.com/questions/2847682/find-the-quotient-and-remainder).↵
↵
### Task: ↵
Given two polynomials $P$ and $Q$, an integer $n$ and a non-zero complex number $c$, where degree $P < n$ and degree $Q < n$. Your task is to calculate the polynomial $P(x) \, Q(x) \% (x^n - c)$ in $O(n \log n)$ time. You may assume that $n$ is a power of two.↵
↵
### Solution:↵
We can create a divide and conquer algorithm for $P(x) \, Q(x) \% (x^n - c)$ based on the difference of squares formula. Assuming $n$ is even, then $(x^n - c) = (x^{n/2} - \sqrt{c}) (x^{n/2} + \sqrt{c})$. The idea behind the algorithm is to calculate $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$ using 2 recursive calls, and then use that result to calculate $P(x) \, Q(x) \% (x^n - c)$.↵
↵
So how do we actually calculate $P(x) \, Q(x) \% (x^n - c)$ using $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$? ↵
↵
Well, we can use the following formula:↵
↵
$$↵
\begin{aligned}↵
A(x) \% (x^n - c) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\↵
&\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})).↵
\end{aligned}↵
$$↵
↵
<spoiler summary="Proof of the formula">↵
Note that↵
\begin{equation}↵
A(x) = \frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) + \frac{1}{2} (1 — \frac{x^{n/2}}{\sqrt{c}}) A(x).↵
\end{equation}↵
↵
Let $Q^-(x)$ denote the quotient of $A(x)$ divided by $(x^n/2 - \sqrt{c})$ and let $Q^+(x)$ denote the quotient of $A(x)$ divided by $(x^n/2 + \sqrt{c})$. Then↵
↵
$$↵
\begin{aligned}↵
(1 + \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 + \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} - \sqrt{c})) + Q^-(x) (x^{n/2} - \sqrt{c})) \\↵
&= (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) + \frac{1}{\sqrt{c}} Q^-(x) (x^n - c))↵
\end{aligned}↵
$$↵
↵
and↵
↵
$$↵
\begin{aligned}↵
(1 - \frac{x^{n/2}}{\sqrt{c}}) A(x) &= (1 - \frac{x^{n/2}}{\sqrt{c}}) ((A(x) \% (x^{n/2} + \sqrt{c})) + Q^+(x) (x^{n/2} + \sqrt{c})) \\↵
&= (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) - \frac{1}{\sqrt{c}} Q^+(x) (x^n - c)).↵
\end{aligned}↵
$$↵
↵
With this we have shown that↵
$$↵
\begin{aligned}↵
A(x) = &\frac{1}{2} (1 + \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} - \sqrt{c})) \, + \\↵
&\frac{1}{2} (1 - \frac{x^{n/2}}{\sqrt{c}}) (A(x) \% (x^{n/2} + \sqrt{c})) \, + \\↵
&\frac{1}{\sqrt{c}} \frac{Q^-(x) - Q^+(x)}{2} (x^n - c).↵
\end{aligned}↵
$$↵
↵
Here $A(x)$ is expressed as remainder + quotient times $(x^n - c)$. So we have proven the formula.↵
</spoiler>↵
↵
This formula is very useful. If we substitute $A(x)$ by $P(x) Q(x)$, then the formula tells us how to calculate $P(x) \, Q(x) \% (x^n - c)$ using $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$ in linear time. With this we have the recipie for implementing a $O(n \log n)$ divide and conquer algorithm:↵
↵
Input:↵
↵
- Integer $n$ (power of 2),↵
- Non-zero complex number $c$,↵
- Two polynomials $P(x) \% (x^n - c)$ and $Q(x) \% (x^n - c)$.↵
↵
Output:↵
↵
- The polynomial $P(x) \, Q(x) \% (x^n - c)$.↵
↵
Algorithm:↵
↵
Step 1. (Base case) If $n = 1$, then return $P(0) \cdot Q(0)$. Otherwise:↵
↵
Step 2. Starting from $P(x) \% (x^n - c)$ and $Q(x) \% (x^n - c)$, in $O(n)$ time calculate ↵
↵
$$↵
\begin{align}↵
& P(x) \% (x^{n/2} - \sqrt{c}), \\↵
& Q(x) \% (x^{n/2} - \sqrt{c}), \\↵
& P(x) \% (x^{n/2} + \sqrt{c}) \text{ and} \\↵
& Q(x) \% (x^{n/2} + \sqrt{c}).↵
\end{align}↵
$$↵
↵
Step 3. Make two recursive calls to calculate $P(x) \, Q(x) \% (x^{n/2} - \sqrt{c})$ and $P(x) \, Q(x) \% (x^{n/2} + \sqrt{c})$.↵
↵
Step 4. Using the formula, calculate $P(x) \, Q(x) \% (x^n - c)$ in $O(n)$ time. Return the result.↵
↵
Here is a Python implementation following this recipie:↵
↵
<spoiler summary="Python solution to the task">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Integer, needs to be power of 2↵
c: Non-zero complex floating point number↵
P: A list of length n representing a polynomial P(x) % (x^n - c)↵
Q: A list of length n representing a polynomial Q(x) % (x^n - c)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod(P, Q, n, c):↵
assert len(P) == n and len(Q) == n↵
↵
# Base case↵
if n == 1:↵
return [P[0] * Q[0]]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Calulate P_minus := P mod (x^(n/2) - sqrt(c))↵
# Q_minus := Q mod (x^(n/2) - sqrt(c))↵
↵
P_minus = [p1 + sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]↵
Q_minus = [q1 + sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]↵
↵
# Calulate P_plus := P mod (x^(n/2) + sqrt(c))↵
# Q_plus := Q mod (x^(n/2) + sqrt(c))↵
↵
P_plus = [p1 - sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]↵
Q_plus = [q1 - sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod(P_minus, Q_minus, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod(P_plus, Q_plus, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
```↵
</spoiler>↵
↵
One final thing that I want to mention before going into the advanced section is that this algorithm can also be used to do fast unmodded polynomial multiplication, i.e. given polynomials $P(x)$ and $Q(x)$ calculate $P(x) \, Q(x)$. The trick is simply to pick $n$ large enough such that $P(x) \, Q(x) = P(x) \, Q(x) \% (x^n - c)$, and then use the exact same algorithm as before. $c$ can be arbitrarily picked (any non-zero complex number works).↵
↵
<spoiler summary="Python implementation for general Fast polynomial multiplication">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
n = 1↵
while n < res_len:↵
n *= 2↵
↵
# Pad with extra 0s to reach length n↵
P = P + [0] * (n - n1)↵
Q = Q + [0] * (n - n2)↵
↵
# Pick non-zero c arbitrarily =)↵
c = 123.24↵
↵
# Calculate P*Q mod x^n - c↵
PQ = fast_polymult_mod(P, Q, n, c)↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
↵
If you want to try out implementing this algorithm yourself, then here is a very simple problem to test out your implementation on: [SPOJ:POLYMUL](https://www.spoj.com/problems/POLYMUL/).↵
↵
### (Advanced) Speeding up the algorithm↵
This section will be about tricks that can be used to speed up the algorithm. The first two tricks will speed up the algorithm by a factor of 2 each. The last trick is advanced, and it has the potential to both speed up the algorithm and also make it more numerically stable.↵
↵
<spoiler summary="$n$ doesn't actually need to be a power of 2">↵
We don't actually need the assumption that $n$ is a power of 2. If $n$ ever becomes odd during the recursion, then we have two choices: Either fall back to a $O(n^2)$ algorithm or fall back to the unmodded $O(n \log{n})$ Polynomial multiplication algorithm. ↵
↵
Let us discuss the run time of falling back to the $O(n^2)$ algorithm when $n$ becomes odd. Assume that $n = a \cdot 2^b$, where $a$ is an odd integer and $b$ is an integer. Think of the recursive algorithm as having layers, one layer for each possible value of $n$.↵
The first $b$ layers will all take $O(n)$ time each. In the $(b+1)$-th layer the value of $n$ is $a$. Using the $O(n^2)$ polynomial multiplication algorithm leads to this layer taking $O(n/a \cdot a^2) = O(n \cdot a)$ time. The final time complexity comes out to be $O((a + b) \, n)$.↵
↵
<spoiler summary="Python implementation that works for both odd and even $n$">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O((a + b) * n) time, where n = a*2^b.↵
↵
Input:↵
n: Integer↵
c: Non-zero complex floating point number↵
P: A list of length n representing a polynomial P(x) % (x^n - c)↵
Q: A list of length n representing a polynomial Q(x) % (x^n - c)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod2(P, Q, n, c):↵
assert len(P) == n and len(Q) == n↵
↵
# Base case (n is odd)↵
if n & 1:↵
# Calculate the answer in O(n^2) time↵
res = [0] * (2*n)↵
for i in range(n):↵
for j in range(n):↵
res[i + j] += P[i] * Q[j]↵
return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Calulate P_minus := P mod (x^(n/2) - sqrt(c))↵
# Q_minus := Q mod (x^(n/2) - sqrt(c))↵
↵
P_minus = [p1 + sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]↵
Q_minus = [q1 + sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]↵
↵
# Calulate P_plus := P mod (x^(n/2) + sqrt(c))↵
# Q_plus := Q mod (x^(n/2) + sqrt(c))↵
↵
P_plus = [p1 - sqrtc * p2 for p1,p2 in zip(P[:n//2], P[n//2:])]↵
Q_plus = [q1 - sqrtc * q2 for q1,q2 in zip(Q[:n//2], Q[n//2:])]↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod2(P_minus, Q_minus, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod2(P_plus, Q_plus, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
```↵
</spoiler>↵
↵
The reason why this is super useful is that it allows us to speed up the fast unmodded polynomial multiplication algorithm. As long as we are fine with $a$ being less than say $10$, then we might be able to choose a significantly smaller $n$ compared to what would be possible if we were allowed to only choose powers of two. This trick has the potential of making the fast unmodded polynomial multiplication algorithm run twice as fast.↵
↵
<spoiler summary="Python implementation for more efficient fast unmodded polynomial multiplication">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult2(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
b = 0↵
alim = 10↵
while alim * 2**b < res_len:↵
b += 1↵
a = (res_len - 1) // 2**b + 1↵
n = a * 2**b↵
↵
# Pad with extra 0s to reach length n↵
P = P + [0] * (n - n1)↵
Q = Q + [0] * (n - n2)↵
↵
# Pick non-zero c arbitrarily =)↵
c = 123.24↵
↵
# Calculate P*Q mod x^n - c↵
PQ = fast_polymult_mod2(P, Q, n, c)↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
↵
↵
</spoiler>↵
↵
<spoiler summary="Imaginary-cyclic convolution">↵
Suppose that $P(x)$ and $Q(x)$ are two real polynomial, and that we want to calculate $P(x) \, Q(x)$. As discussed earlier, we can calculate the unmodded polynomial product by picking $n$ large enough such that $(P(x) \, Q(x)) \% (x^n - c) = P(x) \, Q(x)$ (here $c$ is any non-zero complex number), and then running the divide and conquer algorithm. But it turns out there is something smarter that we can do.↵
↵
If we use $c = \text{i}$ (the imaginary unit) as the inital value of $c$, then this will allow us to pick an even smaller value for $n$. The reason for this is that if we get "overflow" from $n$ being too small, then that overflow will be placed into the imaginary part of the result $(P(x) \, Q(x)) \% (x^n - \text{i})$. This means that by using $c = \text{i}$ we are allowed to to pick $n$ as half the size compared to if we weren't using $c=\text{i}$. So this trick speeds the fast unmodded polynomial multiplication algorithm up by exactly a factor of 2.↵
</spoiler>↵
↵
<spoiler summary="Trick to go from $\% (x^n - c)$ to $\% (x^n - 1)$">↵
There is somewhat well known technique called "reweighting" that allows us to switch between working with $\% (x^n - c)$ and working with $\% (x^n - 1)$. I've previously written a blog explaining this technique, see [here](https://codeforces.me/blog/entry/106983).↵
↵
So why would we be interested in switching from $\% (x^n - c)$ to $\% (x^n - 1)$? The reason is that by using $c=1$, we don't need to bother with multiplying or dividing with $c$ or $\sqrt{c}$ anywhere, since $c=\sqrt{c}=1$. Additionally, if $c=-1$ or $c=\text{i}$ or $c=\text{-i}$, then multiplying or dividing by $c$ can be done very efficiently. So whenever $c$ becomes something other than $1,-1,\text{i}$ or $-\text{i}$, then it makes sense to use the reweight trick to switch back to $c=1$. This will significantly reduce the number of floating point operations used by the algorithm. Fewer floating point operations means that the algorithm both has the potential to be faster and more nummerically stable. So reweighting is definitely something to consider if you want to create a heavily optimized polynomial multiplication implementation. ↵
↵
</spoiler>↵
↵
### (Advanced) [user:-is-this-fft-,2023-07-07]?↵
This algorithm is actually FFT in disguise. But it is also different compared to any other FFT algorithm that I've seen in the past (for example the Cooley–Tukey FFT algorithm).↵
↵
<spoiler summary="Using this algorithm to calculate FFT">↵
In the tail of the recursion (i.e. when $n$ reaches 1), you are calculating $P(x) \, Q(x) \% (x - c)$, for some non-zero complex number $c$. This is infact the same thing as evaluating the polynomial $P(x) \, Q(x)$ at $x=c$. Furthermore, if you initially started with $c=1$, then the $c$ in the tail will be some $n$-th root of unity. If you analyze it more carefully, then you will see that each tail corresponds to a different $n$-th root of unity. So what the algorithm is actually doing is evaluating $P(x) \, Q(x)$ in all possible $n$-th roots of unity. ↵
↵
The $n$-th order FFT of a polynomial is defined as the polynomial evaluated in all $n$-th roots of unity. This means that the algorithm is infact an FFT algorithm. However, if you want to use it to calculate FFT, then make sure you order the $n$-th roots of unity according to the standard order used for FFT algorithms. The standard order is $\exp{(\frac{2 \pi \text{i}}{n} 0)}, \exp{(\frac{2 \pi \text{i}}{n} 1)}, ..., \exp{(\frac{2 \pi \text{i}}{n} (n-1))}$. To get the ordering correct, you will probably need to do a "bit reversal" at the end.↵
↵
</spoiler>↵
↵
<spoiler summary="This algorithm is not the same algorithm as Cooley–Tukey">↵
The Cooley-Tukey algorithm is the standard algorithm for calculating FFT. It is for exmple used in this blog [[Tutorial] FFT](https://codeforces.me/blog/entry/111371) by [user:-is-this-fft-,2023-07-08]. The idea behind the algorithm is to split up the polynomial $P(x)$ into an even part $P_{\text{even}}(x^2)$ and an odd part $x \, P_{\text{odd}}(x^2)$. You can calculate the FFT of $P(x)$ using the FFTs of $P_{\text{even}}(x)$ and $P_{\text{odd}}(x)$. So Cooley-Tukey is a $O(n \log{n})$ divide and conquer algorithm that repeatedly splits up the polynomial into odd and even parts.↵
↵
The wiki article for [Cooley-Tukey](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm) has a nice description of the algorithm↵
↵
$$↵
\begin{align}↵
X_k &= E_k + e^{- \frac{2 \pi \text{i}}{n} k} O_k, \\↵
X_{k+\frac{n}{2}} &= E_k - e^{- \frac{2 \pi \text{i}}{n} k} O_k.↵
\end{align}↵
$$↵
↵
If you compare this to calculating FFT using the divide and conquer polynomial mod method you instead get ↵
↵
$$↵
\begin{align}↵
X_k &= E_k + c \, O_k, \\↵
X_{k+\frac{n}{2}} &= E_k - c \, O_k,↵
\end{align}↵
$$↵
↵
where $c$ is an $n$-th root of unity that is independent of $k$. This is very different compared to Cooley-Tukey since $c$ doesn't have a dependence on $k$ unlike $e^{- \frac{2 \pi \text{i}}{n} k}$. Infact, $c$ being constant means that the polynomial mod method has the potential to be faster than Cooley-Tukey.↵
↵
</spoiler>↵
↵
<spoiler summary="FFT implementation in Python based on this algorithm">↵
Here is an FFT implementation. A codegolfed version of the same code can be found on [Pyrival](https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/fft.py).↵
↵
```py↵
"""↵
Calculates FFT(P) in O(n log n) time.↵
↵
This implementation is based on the ↵
polynomial modulo multiplication algorithm.↵
↵
Input:↵
P: A list of length n representing a polynomial P(x).↵
n needs to be a power of 2.↵
Output:↵
A list of length n representing the FFT of the polynomial P,↵
i.e. the list [P(exp(2j pi / n * i) for i in range(n)]↵
"""↵
rt = [1] # List used to store roots of unity↵
def FFT(P):↵
n = len(P)↵
# Assert n is a power of 2↵
assert n and (n - 1) & n == 0↵
# Make a copy of P to not modify original P↵
P = P[:] ↵
↵
# Precalculate the roots↵
while 2 * len(rt) < n:↵
# 4*len(rt)-th root of unity↵
import cmath↵
root = cmath.exp(2j * cmath.pi / (4 * len(rt)))↵
rt.extend([r * root for r in rt])↵
↵
# Transform P↵
k = n↵
while k > 1:↵
for i in range(n//k):↵
r = rt[i]↵
for j1 in range(i*k, i*k + k//2):↵
j2 = j1 + k//2↵
z = r * P[j2]↵
P[j2] = P[j1] - z↵
P[j1] += z↵
k //= 2↵
↵
# Bit reverse P before returning↵
rev = [0] * n↵
for i in range(1, n):↵
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2↵
↵
return [P[r] for r in rev]↵
↵
# Inverse of FFT(P) using a standard trick↵
def inverse_FFT(fft_P):↵
n = len(fft_P)↵
return FFT([fft_P[-i]/n for i in range(n)])↵
↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult_using_FFT(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
n = 1↵
while n < res_len:↵
n *= 2↵
↵
# Pad with extra 0s to reach length n↵
P = P + [0] * (n - n1)↵
Q = Q + [0] * (n - n2)↵
↵
# Transform P and Q↵
fft_P = FFT(P)↵
fft_Q = FFT(Q)↵
↵
# Calculate FFT of P*Q↵
fft_PQ = [p*q for p,q in zip(fft_P,fft_Q)]↵
↵
# Inverse FFT↵
PQ = inverse_FFT(fft_PQ)↵
↵
# Remove padding and return↵
return PQ[:res_len]↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult_using_FFT(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
n = 1↵
while n < res_len:↵
n *= 2↵
↵
# Pad with extra 0s to reach length n↵
P = P + [0] * (n - n1)↵
Q = Q + [0] * (n - n2)↵
↵
# Transform P and Q↵
fft_P = FFT(P)↵
fft_Q = FFT(Q)↵
↵
# Calculate FFT of P*Q↵
fft_PQ = [p*q for p,q in zip(fft_P,fft_Q)]↵
↵
# Inverse FFT↵
PQ = inverse_FFT(fft_PQ)↵
↵
# Remove padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
↵
<spoiler summary="FFT implementation in C++ based on this algorithm">↵
Here is an FTT implementation. It is coded in the same style as in [KACTL](https://github.com/kth-competitive-programming/kactl/blob/main/content/numerical/FastFourierTransform.h).↵
↵
```cpp↵
typedef complex<double> C;↵
typedef vector<double> vd;↵
void fft(vector<C>& a) {↵
int n = sz(a);↵
static vector R{1.L + 0il};↵
static vector rt{1. + 0i};↵
for (static int k = 2; k < n; k *= 2) {↵
R.resize(n/2); rt.resize(n/2);↵
rep(i,k/2,k) rt[i] = R[i] = R[i-k/2] * pow(1il, 2./k);;↵
}↵
for (int k = n; k > 1; k /= 2) rep(i,0,n/k) rep(j,i*k,i*k + k/2) {↵
C &u = a[j], &v = a[j+k/2], &r = rt[i];↵
C z(v.real()*r.real() - v.imag()*r.imag(), ↵
v.real()*r.imag() + v.imag()*r.real());↵
v = u - z;↵
u = u + z;↵
}↵
vi rev(n);↵
rep(i,0,n) rev[i] = rev[i / 2] / 2 + (i & 1) * n / 2;↵
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);↵
}↵
↵
vd conv(const vd& a, const vd& b) {↵
if (a.empty() || b.empty()) return {};↵
vd res(sz(a) + sz(b) - 1);↵
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;↵
vector<C> in(n), out(n);↵
copy(all(a), begin(in));↵
rep(i,0,sz(b)) in[i].imag(b[i]);↵
fft(in);↵
for (C& x : in) x *= x;↵
rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);↵
fft(out);↵
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);↵
return res;↵
}↵
```↵
</spoiler>↵
↵
### (Advanced) Connection between this algorithm and NTT↵
Just like how there is FFT and NTT, there are two variants of this algorithm too. One using complex floating point numbers, and the other using modulo a prime (or more generally modulo an odd composite number).↵
↵
<spoiler summary="Using modulo integers instead of complex numbers">↵
This algorithm requires three properties. Firstly it needs to be possible to divide by $2$, and secondly $\sqrt{c}$ needs to exist, and thirdly it needs to be possible to divide by $\sqrt{c}$. This means that we don't technically need complex numbers, we could also use other number systems (like working modulo a prime or modulo an odd composite number).↵
↵
Primes that work nicely for this purpose are called "NTT primes", which means that the prime — 1 is divisible by a large power of $2$. Common examples of NTT primes are: $998244353 = 119 \cdot 2^{23} + 1$, $167772161 = 5 \cdot 2^{25} + 1$ and $469762049 = 7 \cdot 2^{26} + 1$.↵
</spoiler>↵
↵
<spoiler summary="What if $sqrt(c)$ doesn't exist?">↵
One of the things I dislike about NTT is that for NTT to be defined, there needs to exist a $n$-th root of unity. Usually problems involving NTT are designed so that this is never an issue. But if you want to use NTT where it hasn't been designed to magically work, then this is a really big issue. The NTT can become undefined!↵
↵
Note that this algorithm does not exactly share the same drawback of being undefined. The reason for this is that if $\sqrt{c}$ doesn't exist, then the algorithm can simply choose to either switch over to a $O(n^2)$ polynomial multiplication algorithm, or fall back to fast unmodded polynomial multiplication. The implications from this is that this algorithm can do fast modded polynomial multiplication even if it is given a relatively bad NTT prime. I just find this property to be really cool!↵
↵
A good example of when NTT becomes undefined is this yosup problem [convolution_mod_large](https://judge.yosupo.jp/problem/convolution_mod_large). Here the NTT mod is $998244353 = 119 \cdot 2^{23}$. The tricky thing about the problem is that $n=2^{24}$. Since $998244353 = 119 \cdot 2^{23} + 1$ there wont exist any $n$-th root of unity, so the NTT of length $n$ is undefined. However, the divide and conquer approach from this blog can easily solve the problem by falling back to the $O(n^2)$ algorithm.↵
</spoiler>↵
↵
<spoiler summary="NTT implementation in Python based on this algorithm">↵
Here is an NTT implementation. A codegolfed version of the same code can be found on [Pyrival](https://github.com/cheran-senthil/PyRival/blob/master/pyrival/algebra/ntt.py).↵
↵
```py↵
# Mod used for NTT↵
# Requirement: Any odd integer > 2↵
# It is important that MOD - 1 is↵
# divisible by lots of 2s↵
MOD = (119 << 23) + 1↵
assert MOD & 1↵
↵
# Precalc non-quadratic_residue (used by the NTT)↵
non_quad_res = 2↵
while pow(non_quad_res, MOD//2, MOD) != MOD - 1:↵
non_quad_res += 1↵
rt = [1]↵
↵
"""↵
Calculates NTT(P) in O(n log n) time.↵
↵
This implementation is based on the ↵
polynomial modulo multiplication algorithm.↵
↵
Input:↵
P: A list of length n representing a polynomial P(x).↵
n needs to be a power of 2.↵
Output:↵
A list of length n representing the NTT of the polynomial P,↵
i.e. the list [P(root**i) % MOD for i in range(n)]↵
where root is an n-th root of unity mod MOD↵
"""↵
def NTT(P):↵
n = len(P)↵
# Assert n is a power of 2↵
assert n and (n - 1) & n == 0↵
↵
# Check that NTT is defined for this n↵
assert (MOD - 1) % n == 0↵
↵
# Make a copy of P to not modify original P↵
P = P[:] ↵
↵
# Precalculate the roots↵
while 2 * len(rt) < n:↵
# 4*len(rt)-th root of unity↵
root = pow(non_quad_res, MOD//(4 * len(rt)), MOD)↵
rt.extend([r * root % MOD for r in rt])↵
↵
# Transform P↵
k = n↵
while k > 1:↵
for i in range(n//k):↵
r = rt[i]↵
for j1 in range(i*k, i*k + k//2):↵
j2 = j1 + k//2↵
z = r * P[j2]↵
P[j2] = (P[j1] - z) % MOD↵
P[j1] = (P[j1] + z) % MOD↵
k //= 2↵
↵
# Bit reverse P before returning↵
rev = [0] * n↵
for i in range(1, n):↵
rev[i] = rev[i // 2] // 2 + (i & 1) * n // 2↵
↵
return [P[r] for r in rev]↵
↵
# Inverse of NTT(P) using a standard trick↵
def inverse_NTT(ntt_P):↵
n = len(ntt_P)↵
n_inv = pow(n, -1, MOD) # Requires Python 3.8↵
# The following works in any Python version, but requires MOD to be prime↵
# n_inv = pow(n, MOD - 2, MOD)↵
assert n * n_inv % MOD == 1↵
return NTT([ntt_P[-i] * n_inv % MOD for i in range(n)])↵
↵
"""↵
Calculates P(x) * Q(x) (where the coeffiecents are returned % MOD)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x) (with coeffients % MOD)↵
"""↵
def fast_polymult_using_NTT(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
n = 1↵
while n < res_len:↵
n *= 2↵
↵
# Pad with extra 0s to reach length n↵
P = P + [0] * (n - n1)↵
Q = Q + [0] * (n - n2)↵
↵
# Transform P and Q↵
ntt_P = NTT(P)↵
ntt_Q = NTT(Q)↵
↵
# Calculate NTT of P*Q↵
ntt_PQ = [p * q % MOD for p,q in zip(ntt_P,ntt_Q)]↵
↵
# Inverse NTT↵
PQ = inverse_NTT(ntt_PQ)↵
↵
# Remove padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
↵
<spoiler summary="NTT implementation in C++ based on this algorithm">↵
Here is an NTT implementation. It is coded in the same style as in [KACTL](https://github.com/kth-competitive-programming/kactl/blob/main/content/numerical/NumberTheoreticTransform.h).↵
↵
```cpp↵
const ll mod = (119 << 23) + 1;// = 998244353↵
// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21↵
// and 483 << 21 The last two are > 10^9.↵
typedef vector<ll> vl;↵
↵
#include "../number-theory/ModPow.h"↵
↵
void ntt(vl &a) {↵
int n = sz(a);↵
static ll r = 3;↵
while(modpow(r, mod/2) + 1 < mod) ++r;↵
static vl rt{1};↵
for (static int k = 2; k < n; k *= 2) {↵
rt.resize(n/2);↵
rep(i,k/2,k) rt[i] = rt[i-k/2] * modpow(r, mod/2/k) % mod;↵
}↵
for (int k = n; k > 1; k /= 2) rep(i,0,n/k) rep(j,i*k,i*k + k/2) {↵
ll &u = a[j], &v = a[j+k/2], z = rt[i] * v % mod;↵
v = u - z + (u < z ? mod : 0);↵
u = u + z - (u + z >= mod ? mod : 0);↵
}↵
vi rev(n);↵
rep(i,0,n) rev[i] = rev[i / 2] / 2 + (i & 1) * n / 2;↵
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);↵
}↵
vl conv(vl a, vl b) {↵
↵
if (a.empty() || b.empty()) return {};↵
int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), n = 1 << B;↵
int inv = modpow(n, mod - 2);↵
vl out(n);↵
a.resize(n); b.resize(n);↵
ntt(a), ntt(b);↵
rep(i,0,n) out[-i & (n - 1)] = (ll)a[i] * b[i] % mod * inv % mod;↵
ntt(out);↵
return {out.begin(), out.begin() + s};↵
}↵
```↵
</spoiler>↵
↵
### (Advanced) Shorter implementations ("Codegolfed version")↵
It is possible to make really short but slightly less natural implementations of this algorithm. Originally I was thinking of using this shorter version in the blog, but in the end I didn't do it. So here they are. If you want to implement this algorithm and use it in practice, then I would recommend taking one of these implementations and porting it to C++.↵
↵
<spoiler summary="Short Python implementation without any speedup tricks">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Integer, needs to be power of 2↵
c: Non-zero complex floating point number↵
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)↵
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod3(P, Q, n, c):↵
assert len(P) == 2*n and len(Q) == 2*n↵
↵
# Mod P and Q by (x^n - c)↵
P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]↵
Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]↵
↵
# Base case↵
if n == 1:↵
return [P[0] * Q[0]]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod3(P, Q, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod3(P, Q, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult3(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
n = 1↵
while n < res_len:↵
n *= 2↵
↵
# Pad with extra 0s to reach length 2*n↵
P = P + [0] * (2*n - n1)↵
Q = Q + [0] * (2*n - n2)↵
↵
# Pick non-zero c arbitrarily =)↵
c = 123.24↵
↵
# Calculate P*Q mod x^n - c↵
PQ = fast_polymult_mod3(P, Q, n, c)↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
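Before reading the faster variants, it might help to convince yourself of the folding step that all three implementations start with: reducing a length-$2n$ polynomial modulo $(x^n - c)$ by adding $c$ times the upper half onto the lower half. Here is a tiny standalone check of that identity (the helper names are mine, not from the code above): since $x^n \equiv c$ modulo $(x^n - c)$, the folded polynomial must agree with the original at any root of $x^n - c$.

```python
def poly_mod_xn_minus_c(P, n, c):
    # Fold the upper half onto the lower half, using x^n = c (mod x^n - c)
    return [p1 + c * p2 for p1, p2 in zip(P[:n], P[n:])]

def eval_poly(P, x):
    # Horner evaluation
    res = 0
    for coef in reversed(P):
        res = res * x + coef
    return res

P = [3, 1, 4, 1, 5, 9, 2, 6]   # degree < 2n, with n = 4
n, c = 4, 2.0
R = poly_mod_xn_minus_c(P, n, c)

# At a root x0 of x^n - c, the remainder must take the same value as P
x0 = c ** (1 / n)
assert abs(eval_poly(P, x0) - eval_poly(R, x0)) < 1e-9
```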
↵
<spoiler summary="Short Python implementation supporting odd and even $n$ (making it up to 2 times faster)">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Positive integer (does NOT need to be a power of 2)↵
c: Non-zero complex floating point number↵
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)↵
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod4(P, Q, n, c):↵
assert len(P) == 2*n and len(Q) == 2*n↵
↵
# Mod P and Q by (x^n - c)↵
P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]↵
Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]↵
↵
# Base case (n is odd)↵
if n & 1:↵
# Calculate the answer in O(n^2) time↵
res = [0] * (2*n)↵
for i in range(n):↵
for j in range(n):↵
res[i + j] += P[i] * Q[j]↵
return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod4(P, Q, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod4(P, Q, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
↵
"""↵
Calculates P(x) * Q(x)↵
↵
Input:↵
P: A list representing a polynomial P(x)↵
Q: A list representing a polynomial Q(x)↵
Output:↵
A list representing the polynomial P(x) * Q(x)↵
"""↵
def fast_polymult4(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
b = 0↵
alim = 10↵
while alim * 2**b < res_len:↵
b += 1↵
a = (res_len - 1) // 2**b + 1↵
n = a * 2**b↵
↵
# Pad with extra 0s to reach length 2*n↵
P = P + [0] * (2*n - n1)↵
Q = Q + [0] * (2*n - n2)↵
↵
# Pick non-zero c arbitrarily =)↵
c = 123.24↵
↵
# Calculate P*Q mod x^n - c↵
PQ = fast_polymult_mod4(P, Q, n, c)↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
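The size-selection loop in `fast_polymult4` deserves a closer look: instead of rounding `res_len` up to a power of two, it writes $n = a \cdot 2^b$ with a small odd-ish base $a \le 10$, so the recursion stops early in a cheap $O(a^2)$ base case while $n$ stays close to `res_len`. A standalone sketch of that heuristic (same logic, pulled out into a hypothetical helper `pick_n`):

```python
def pick_n(res_len, alim=10):
    # Find the smallest b with a = ceil(res_len / 2^b) <= alim,
    # so that n = a * 2^b >= res_len and the base case has size a
    b = 0
    while alim * 2**b < res_len:
        b += 1
    a = (res_len - 1) // 2**b + 1   # a = ceil(res_len / 2^b)
    return a, b

for res_len in [1, 7, 100, 1000, 12345]:
    a, b = pick_n(res_len)
    n = a * 2**b
    assert n >= res_len                  # big enough to hold the product
    assert a <= 10                       # base case stays small
    assert (a - 1) * 2**b < res_len      # n is not wastefully large
```

For example `pick_n(100)` gives $a = 7$, $b = 4$, i.e. $n = 112$, whereas rounding up to a power of two would have used $n = 128$.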
↵
<spoiler summary="Short Python implementation supporting odd and even $n$ and imaginary cyclic convolution (making it up to 4 times faster)">↵
↵
```py↵
"""↵
Calculates P(x) * Q(x) % (x^n - c) in O(n log n) time↵
↵
Input:↵
n: Positive integer (does NOT need to be a power of 2)↵
c: Non-zero complex floating point number↵
P: A list of length 2*n representing a polynomial P(x) % (x^2n - c^2)↵
Q: A list of length 2*n representing a polynomial Q(x) % (x^2n - c^2)↵
Output:↵
A list of length n representing the polynomial P(x) * Q(x) % (x^n - c)↵
"""↵
def fast_polymult_mod4(P, Q, n, c):↵
assert len(P) == 2*n and len(Q) == 2*n↵
↵
# Mod P and Q by (x^n - c)↵
P = [p1 + c * p2 for p1,p2 in zip(P[:n], P[n:])]↵
Q = [q1 + c * q2 for q1,q2 in zip(Q[:n], Q[n:])]↵
↵
# Base case (n is odd)↵
if n & 1:↵
# Calculate the answer in O(n^2) time↵
res = [0] * (2*n)↵
for i in range(n):↵
for j in range(n):↵
res[i + j] += P[i] * Q[j]↵
return [r1 + c * r2 for r1,r2 in zip(res[:n], res[n:])]↵
↵
assert n % 2 == 0↵
import cmath↵
sqrtc = cmath.sqrt(c)↵
↵
# Recursively calculate PQ_minus := P * Q % (x^n/2 - sqrt(c)) ↵
# PQ_plus := P * Q % (x^n/2 + sqrt(c))↵
↵
PQ_minus = fast_polymult_mod4(P, Q, n//2, sqrtc)↵
PQ_plus = fast_polymult_mod4(P, Q, n//2, -sqrtc)↵
↵
# Calculate PQ mod (x^n - c) using PQ_minus and PQ_plus↵
PQ = [(m + p)/2 for m,p in zip(PQ_minus, PQ_plus)] +\↵
[(m - p)/(2*sqrtc) for m,p in zip(PQ_minus, PQ_plus)]↵
↵
return PQ↵
↵
"""↵
Calculates P(x) * Q(x) of two real polynomials↵
↵
Input:↵
P: A list representing a real polynomial P(x)↵
Q: A list representing a real polynomial Q(x)↵
Output:↵
A list representing the real polynomial P(x) * Q(x)↵
"""↵
def fast_polymult5(P, Q):↵
# Calculate length of the list representing P*Q↵
n1 = len(P)↵
n2 = len(Q)↵
res_len = n1 + n2 - 1↵
↵
# Pick n sufficiently big↵
b = 1↵
alim = 10↵
while alim * 2**b < res_len:↵
b += 1↵
a = (res_len - 1) // 2**b + 1↵
n = a * 2**b↵
↵
# Pick c = i (imaginary unit)↵
c = 1j↵
# and decrease the size of n by a factor of 2↵
n //= 2↵
↵
# Pad with extra 0s to reach length 2*n↵
P = P + [0] * (2*n - n1)↵
Q = Q + [0] * (2*n - n2)↵
↵
# Calculate P*Q mod x^n - i↵
PQ = fast_polymult_mod4(P, Q, n, c)↵
↵
# The imaginary part contains the "overflow"↵
PQ = [pq.real for pq in PQ] + [pq.imag for pq in PQ]↵
↵
# Remove extra 0 padding and return↵
return PQ[:res_len]↵
```↵
</spoiler>↵
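The $c = i$ trick in `fast_polymult5` relies on one observation: if $P$ and $Q$ have real coefficients, then so does their true product, and reducing that product modulo $(x^n - i)$ folds coefficient $n + k$ into the imaginary part of coefficient $k$. Real and imaginary parts never mix, so a single half-size complex convolution recovers the full real product. A standalone check of this packing identity (the `naive_mult` helper is mine, used only to produce the true product):

```python
def naive_mult(P, Q):
    # Schoolbook O(n^2) multiplication, used as ground truth
    res = [0] * (len(P) + len(Q) - 1)
    for i, p in enumerate(P):
        for j, q in enumerate(Q):
            res[i + j] += p * q
    return res

P = [1.0, 2.0, 3.0]
Q = [4.0, 5.0, 6.0]
full = naive_mult(P, Q)                      # true product, length 5
n = 3                                        # half size; 2*n >= len(full)
full_padded = full + [0.0] * (2 * n - len(full))

# Reduce mod (x^n - i): x^(n+k) = i * x^k, so the upper half folds in times i
R = [lo + 1j * hi for lo, hi in zip(full_padded[:n], full_padded[n:])]

# Because P and Q are real, the real and imaginary parts separate cleanly again
recovered = [z.real for z in R] + [z.imag for z in R]
assert recovered[:len(full)] == full
```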
↵