Быстрое умножение по модулю
Difference between ru4 and ru5, changed 0 character(s)
Рассмотрим такую задачу: даны три целых числа $x$, $y$ и $m$, $0 \leqslant x,y < m < 2^{32}$, найти $xy\bmod m$. По-хорошему хотелось бы просто перемножить эти два числа, а потом применить операцию остатка:↵
~~~~~↵
uint32_t prod(const uint32_t x, const uint32_t y, const uint32_t m)↵
{↵
return x * y % m;↵
}↵
~~~~~↵
Как вы, возможно, догадываетесь, это решение неверное. Всё дело в том, что в такой процедуре возможно переполнение: операция `x * y` выполняется в типе `uint32_t`, и на самом деле промежуточным результатом выполнения этой операции будет не $xy$, а $xy\bmod2^{32}$. Если после этого взять результат по модулю $m$, он может отличаться от правильного:↵
$$↵
\left(xy\bmod2^{32}\right)\bmod m\ne xy\bmod m.↵
$$↵
Выход прост — перемножать необходимо в большем типе:↵
~~~~~↵
uint64_t prod_uint64(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
return x * y % m;↵
}↵
~~~~~↵
Если так делать, то, поскольку $xy<2^{64}$, это произведение точно не переполнится, и после взятия результата по модулю получится правильный ответ.↵

Вопрос: а что делать, если $x$, $y$ и $m$ могут быть больше?↵

[cut]↵

Предлагаю следующее.↵

1. _Бинарное умножение._ Так же, как и бинарное возведение в степень, существует бинарное умножение: чтобы посчитать $xy$, посчитаем $x\left\lfloor\frac y2\right\rfloor$, сложим это число само с собой и, возможно, добавим ещё $x$. Это потратит $\mathcal O(\log y)$ действий, но среди них не будет ничего, кроме сложения и вычитания!↵
~~~~~↵
uint64_t sum(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
uint64_t ans = x + y;↵
if (ans < x || ans >= m)↵
ans -= m;↵
return ans;↵
}↵
uint64_t prod_binary(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
if (y <= 1)↵
return y ? x : 0;↵
uint64_t ans = prod_binary(x, y >> 1, m);↵
ans = sum(ans, ans, m);↵
if (y & 1)↵
ans = sum(ans, x, m);↵
return ans;↵
}↵
~~~~~↵
2. _Умножение с помощью `int128`._ Чтобы перемножить два 32-битных числа, нужна 64-битная промежуточная переменная. А чтобы перемножить два 64-битных числа, нужна 128-битная переменная! В современных 64-битных компиляторах C++ ([кроме разве что Microsoft® Visual C++®](https://stackoverflow.com/questions/6759592/how-to-enable-int128-on-visual-studio)) есть специальный тип `__int128`, позволяющий осуществлять операции над 128-битными числами.↵
~~~~~↵
int64_t prod_uint128(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
return (unsigned __int128)x * y % m;↵
}↵
~~~~~↵
3. _Умножение с помощью вещественного типа._ Что такое $xy\bmod m$? Это на самом деле $xy-cm$, где $c=\left\lfloor\frac{xy}m\right\rfloor$. Давайте тогда попробуем посчитать $c$, а отсюда найдём $xy\bmod m$. При этом заметим, что нам не требуется находить $c$ прямо точно. Что будет, если мы случайно посчитаем, скажем, $c-4$? Тогда при подсчёте остатка мы вычислим $xy-(c-4)m=xy-cm+4m=xy\bmod m+4m$. На первый взгляд, это не то, что нам надо. Но если $m$ не слишком велико и $xy\bmod m+4m$ не вылезло из 64-битного типа, то после этого можно по-честному взять остаток и получить ответ.<br/><br/>↵
   Получается такой код:↵
~~~~~↵
uint64_t prod_double(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
uint64_t c = (double)x * y / m;↵
int64_t ans = int64_t(x * y - c * m) % int64_t(m);↵
if (ans < 0)↵
ans += m;↵
return ans;↵
}↵
~~~~~  ↵
~~~~~↵
uint64_t prod_long_double(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
uint64_t c = (long double)x * y / m;↵
int64_t ans = int64_t(x * y - c * m) % int64_t(m);↵
if (ans < 0)↵
ans += m;↵
return ans;↵
}↵
~~~~~↵
   `double` достаточно точный для этой задачи, если $x$, $y$ и $m$ меньше $2^{57}$. `long double` же хватает на числа, меньшие $2^{63}$, однако стоит помнить, что `long double` для этого должен быть 80-битным, а это верно не на всех компиляторах: например, в Microsoft® Visual C++® [`long double` — то же самое, что `double`](https://stackoverflow.com/questions/7120710/why-did-microsoft-abandon-long-double-data-type).<br/><br/>↵
   Обратите внимание, что этот способ неприменим, если $m>2^{63}$: в этом случае `ans` нельзя хранить в `int64_t`, так как, возможно, $\mathtt{ans}\geqslant 2^{63}$ и произойдёт переполнение, из-за которого выполнится ветка `(ans < 0)` и мы получим неверный ответ.<br/><br/>↵
   Видно, что Microsoft® Visual C++® отстаёт от остальных компиляторов ~~в развитии~~ в наличии технических средств для умножения больших чисел по модулю, поэтому, если мы хотим, чтобы функция работала на всех компиляторах быстро, ей необходима какая-то свежая идея. К счастью, именно такую идею в 1960-м году изобрёл [Анатолий Карацуба](https://ru.wikipedia.org/wiki/%D0%9A%D0%B0%D1%80%D0%B0%D1%86%D1%83%D0%B1%D0%B0,_%D0%90%D0%BD%D0%B0%D1%82%D0%BE%D0%BB%D0%B8%D0%B9_%D0%90%D0%BB%D0%B5%D0%BA%D1%81%D0%B5%D0%B5%D0%B2%D0%B8%D1%87).↵
4. _Умножение Карацубы._ Идея изначально использовалась для быстрого умножения длинных чисел. А именно, пусть $x$ и $y$ — два неотрицательных целых числа, меньших $N^2$. Поделим их с остатком на $N$: $x=Nx_1+x_0$, $y=Ny_1+y_0$. Тогда искомое $xy$ можно найти как $N^2x_1y_1+Nx_1y_0+Nx_0y_1+x_0y_0=N\cdot\bigl(N\cdot x_1y_1+\left(x_0+x_1\right)\left(y_0+y_1\right)-x_1y_1-x_0y_0\bigr)+x_0y_0$. Как видно, это преобразование свело умножение $x$ и $y$ 1) к $\mathcal O(1)$ сложениям и вычитаниям чисел, не превосходящих $N^4$; 2) к трём умножениям чисел, не превосходящих $2N$ (а именно $x_0y_0$, $x_1y_1$ и $\left(x_0+x_1\right)\left(y_0+y_1\right)$); 3) к двум умножениям на $N$ чисел, не превосходящих $2N^2$.<br/><br/>↵
   Пункт 1) практически всегда очень прост. В случае с длинными числами пункт 3) также прост: можно взять $N$, равное степени двойки, и тогда он осуществляется как обычный двоичный сдвиг (оператор `<<` в C++). Поэтому по существу Карацуба свёл одно умножение чисел, меньших $N^2$, к трём умножениям чисел, меньших $2N$. Если эти умножения также свести методом Карацубы, по [мастер-теореме](https://ru.wikipedia.org/wiki/%D0%9E%D1%81%D0%BD%D0%BE%D0%B2%D0%BD%D0%B0%D1%8F_%D1%82%D0%B5%D0%BE%D1%80%D0%B5%D0%BC%D0%B0_%D0%BE_%D1%80%D0%B5%D0%BA%D1%83%D1%80%D1%80%D0%B5%D0%BD%D1%82%D0%BD%D1%8B%D1%85_%D1%81%D0%BE%D0%BE%D1%82%D0%BD%D0%BE%D1%88%D0%B5%D0%BD%D0%B8%D1%8F%D1%85) асимптотика этого метода составит $\Theta\left(\log^{\log_23}N\right)$ вместо наивного $\Theta\left(\log^2N\right)$.<br/><br/>↵
   Но нам не потребуется использовать рекурсию, ведь, если длину $x$ и $y$ уменьшить вдвое, мы уже сможем воспользоваться `prod_uint64` или `prod_double`. Сложность в нашем случае составляет пункт 3): подобрать такое $N$, чтобы, во-первых, оно было меньше $2^{32}$ либо в крайнем случае чуть-чуть больше, во-вторых, чтобы на него можно было быстро умножать числа порядка $N^2$. Оба требования выполнятся, если взять $N=\mathrm{round}\left(\sqrt m\right)$: действительно, тогда при $m<2^{64}$ верно $N<2^{32}$, а $\left|m_0\right|=\left|m-N^2\right|\leqslant N<2^{32}$; тогда $xN=\left(x_1N+x_0\right)N=x_1N^2+x_0N\equiv x_0N-x_1m_0\pmod m$, и оба умножения здесь осуществляются над числами порядка $N$.<br/><br/> ↵
   Внимательный читатель заметит, что здесь мы обзавелись серьёзной проблемой: извлечение квадратного корня из целого числа. Если вы умеете применять в этой задаче умножение Карацубы, обходя данную проблему (в том числе достаточно быстро находить квадратный корень, написать более быстрый или более короткий код, чем у меня), напишите, пожалуйста, в комментариях!<br/><br/>↵
   Поскольку находить произведение $\left(x_0+x_1\right)\left(y_0+y_1\right)$ оказалось крайне неприятно (напомню, `prod_double` не работает при $m>2^{63}$), я всё же решил просто вычислить $x_0y_1$ и $x_1y_0$ по-честному — таким образом, это не метод Карацубы в истинном смысле, так как я трачу четыре умножения чисел порядка $N$.↵
~~~~~↵
uint64_t dif(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
uint64_t ans = x - y;↵
if (ans > x)↵
ans += m;↵
return ans;↵
}↵
bool check_ge_rounded_sqrt(const uint64_t m, const uint64_t r)↵
{↵
return ((r >= 1ull << 32) || r * (r + 1) >= m);↵
}↵
bool check_le_rounded_sqrt(const uint64_t m, const uint64_t r)↵
{↵
return (r == 0 || ((r <= 1ull << 32) && r * (r - 1) < m));↵
}↵
bool check_rounded_sqrt(const uint64_t m, const uint64_t r)↵
{↵
return check_ge_rounded_sqrt(m, r) && check_le_rounded_sqrt(m, r);↵
}↵
uint64_t rounded_sqrt(const uint64_t m)↵
{↵
uint64_t r = floorl(.5 + sqrtl(m));↵
if (!check_ge_rounded_sqrt(m, r))↵
while (!check_ge_rounded_sqrt(m, ++r));↵
else if (!check_le_rounded_sqrt(m, r))↵
while (!check_le_rounded_sqrt(m, --r));↵
return r;↵
}↵
uint64_t prod_karatsuba_aux(const uint64_t x, const uint64_t N, const int64_t m0, const uint64_t m)↵
{↵
uint64_t x1 = x / N;↵
uint64_t x0N = (x - N * x1) * N;↵
if (m0 >= 0)↵
return dif(x0N, x1 * (uint64_t)m0, m);↵
else↵
return sum(x0N, x1 * (uint64_t)-m0, m);↵
}↵
uint64_t prod_karatsuba(const test& t)↵
{↵
uint64_t x = t.x, y = t.y, m = t.modulo;↵
uint64_t N = rounded_sqrt(t.modulo);↵
int64_t m0 = m - N * N;↵
uint64_t x1 = t.x / N;↵
uint64_t x0 = t.x - N * x1;↵
uint64_t y1 = t.y / N;↵
uint64_t y0 = t.y - N * y1;↵
uint64_t x0y0 = sum(x0 * y0, 0, m);↵
uint64_t x0y1 = sum(x0 * y1, 0, m);↵
uint64_t x1y0 = sum(x1 * y0, 0, m);↵
uint64_t x1y1 = sum(x1 * y1, 0, m);↵
return sum(prod_karatsuba_aux(sum(prod_karatsuba_aux(x1y1, N, m0, m), sum(x0y1, x1y0, m), m), N, m0, m), x0y0, m);↵
}↵
~~~~~↵

Видно, что на самом деле единственное, что даёт нам тут метод Карацубы — что если вы найдёте большое число $N$, на которое вы умеете быстро умножать, то вы сможете умножать любые два числа по модулю. Фактически, если бы модуль $m$ был фиксированным, и было бы много запросов умножения по этому фиксированному модулю, то метод Карацубы был бы молниеносным, поскольку самая затратная операция в нём — это квадратный корень. Хотелось бы, таким образом, взять, например, $N=2^{32}$ и сделать всё то же самое, что в прошлом пункте, но без квадратного корня. Увы, но я не придумал, как умножать на $2^{32}$. Можно было бы написать примерно такую функцию:↵
~~~~~↵
uint64_t prod_double_small(const uint64_t x, const uint64_t y, const uint64_t m)↵
{↵
uint64_t c = (double)x * y / m;↵
uint64_t ans = (x * y - c * m) % m;↵
return ans;↵
}↵
~~~~~↵
Она вычисляет произведение по модулю при условии, что `uint64_t c = (double)x * y / m` посчиталось абсолютно точно. Но гарантировать, что оно точно посчитается, не представляется возможным, поскольку $\frac{xy}m$ вполне может оказаться на $10^{-18}$ меньше, чем какое-то целое число, и типа `double` не хватит, чтобы это заметить. Именно эту проблему обходит функция `prod_karatsuba_aux`. Если вы смогли её обойти хитрее, добро пожаловать в комментарии.↵

---↵

Ниже приведены три таблицы по разным компиляторам (да славится имя [user:MikeMirzayanov,2021-11-09], ведь именно благодаря нецелевому использованию его [Polygon](polygon.codeforces.com)'а я это сделал), в каждой таблице строки соответствуют различным функциям, столбцы соответствуют максимальной разрешённой битности $x$, $y$ и $m$. Если указано **<span style="color:red">CE</span>**, значит, с данным компилятором программа не скомпилируется, а если **<span style="color:red">WA</span>** — может выдать неверный ответ. В противном случае указано время работы функции на Intel® Core™ i3-8100 CPU @ 3.60GHz. Погрешность приблизительно равна одной-двум наносекундам, на самых медленных функциях может доходить до десяти наносекунд.↵

1. **Microsoft® Visual C++® 2010**↵
<table class="tg">↵
<thead>↵
  <tr>↵
    <th>Метод</th>↵
    <th>32 бита</th>↵
    <th>57 битов</th>↵
    <th>63 бита</th>↵
    <th>64 бита</th>↵
  </tr>↵
</thead>↵
<tbody>↵
  <tr>↵
    <th><tt>prod_uint64</tt></th>↵
    <th>7 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_binary</tt></th>↵
    <th>477 ns</th>↵
    <th>847 ns</th>↵
    <th>889 ns</th>↵
    <th>870 ns</th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_uint128</tt></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_double</tt></th>↵
    <th>66 ns</th>↵
    <th>95 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_long_double</tt></th>↵
    <th>66 ns</th>↵
    <th>98 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_karatsuba</tt></th>↵
    <th>128 ns</th>↵
    <th>125 ns</th>↵
    <th>138 ns</th>↵
    <th>139 ns</th>↵
  </tr>↵
</tbody>↵
</table>↵

2. **GNU G++17**↵
<table class="tg">↵
<thead>↵
  <tr>↵
    <th>Метод</th>↵
    <th>32 бита</th>↵
    <th>57 битов</th>↵
    <th>63 бита</th>↵
    <th>64 бита</th>↵
  </tr>↵
</thead>↵
<tbody>↵
  <tr>↵
    <th><tt>prod_uint64</tt></th>↵
    <th>4 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_binary</tt></th>↵
    <th>455 ns</th>↵
    <th>774 ns</th>↵
    <th>841 ns</th>↵
    <th>845 ns</th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_uint128</tt></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
    <th><span style="color:red">CE</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_double</tt></th>↵
    <th>26 ns</th>↵
    <th>36 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_long_double</tt></th>↵
    <th>29 ns</th>↵
    <th>20 ns</th>↵
    <th>19 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_karatsuba</tt></th>↵
    <th>82 ns</th>↵
    <th>81 ns</th>↵
    <th>91 ns</th>↵
    <th>88 ns</th>↵
  </tr>↵
</tbody>↵
</table>↵

3. **GNU G++17 (64 bit)**↵
<table class="tg">↵
<thead>↵
  <tr>↵
    <th>Метод</th>↵
    <th>32 бита</th>↵
    <th>57 битов</th>↵
    <th>63 бита</th>↵
    <th>64 бита</th>↵
  </tr>↵
</thead>↵
<tbody>↵
  <tr>↵
    <th><tt>prod_uint64</tt></th>↵
    <th>8 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_binary</tt></th>↵
    <th>313 ns</th>↵
    <th>550 ns</th>↵
    <th>604 ns</th>↵
    <th>630 ns</th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_uint128</tt></th>↵
    <th>17 ns</th>↵
    <th>34 ns</th>↵
    <th>30 ns</th>↵
    <th>30 ns</th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_double</tt></th>↵
    <th>23 ns</th>↵
    <th>22 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_long_double</tt></th>↵
    <th>23 ns</th>↵
    <th>24 ns</th>↵
    <th>23 ns</th>↵
    <th><span style="color:red">WA</span></th>↵
  </tr>↵
  <tr>↵
    <th><tt>prod_karatsuba</tt></th>↵
    <th>65 ns</th>↵
    <th>65 ns</th>↵
    <th>69 ns</th>↵
    <th>66 ns</th>↵
  </tr>↵
</tbody>↵
</table>↵

Поэтому базовый рецепт такой: если доступен `unsigned __int128`, то использовать его, если доступен 80-битный `long double`, то его тоже должно всегда хватать, а в противном случае надо, если хватает `double`, использовать `double`, иначе применить метод Карацубы.↵

При желании можете попробовать применить эти идеи на задачах [специального контеста](https://codeforces.me/gym/103399).

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
ru8 Russian orz 2022-06-16 20:17:53 4 Мелкая правка: '%D1%87).\n4. _Умно' -> '%D1%87).\n\n4. _Умно'
en7 English orz 2022-02-24 05:26:02 16
ru7 Russian orz 2022-02-24 05:24:33 21 Мелкая правка: 'остатка:\n~~~~~\nu' -> 'остатка:\n\n~~~~~\nu'
en6 English orz 2021-11-10 21:54:33 1 Tiny change: 't|=\left|mN^2\right|' -> 't|=\left|m-N^2\right|'
en5 English orz 2021-11-09 15:13:55 8 Tiny change: ' Karatsuba](https://e' -> ' Karatsuba(https://e'
en4 English orz 2021-11-09 08:42:08 21 Tiny change: 'tal delay<strike> la' -> 'tal delay</strike> la'
ru6 Russian orz 2021-11-09 08:41:10 21
en3 English orz 2021-11-09 08:20:46 0 (published)
ru5 Russian orz 2021-11-09 08:10:07 0 (опубликовано)
ru4 Russian orz 2021-11-09 08:09:54 27
ru3 Russian orz 2021-11-09 08:07:56 23 Мелкая правка: '0y0, m);\n}\n~~~~~\n' -> '0y0, m);\n }\n~~~~~\n'
en2 English orz 2021-11-09 08:07:52 418 Tiny change: 'rm{round}\!\left(\sqr' -> 'rm{round}\left(\sqr'
ru2 Russian orz 2021-11-09 08:00:41 533
en1 English orz 2021-11-09 07:51:02 15661 Initial revision for English translation (saved to drafts)
ru1 Russian orz 2021-11-09 07:33:41 15901 Первая редакция (сохранено в черновиках)