Рассмотрим такую задачу: даны три целых числа $$$x$$$, $$$y$$$ и $$$m$$$, $$$0 \leqslant x,y < m < 2^{32}$$$, найти $$$xy\bmod m$$$. По-хорошему хотелось бы просто перемножить эти два числа, а потом применить операцию остатка:
uint32_t prod(const uint32_t x, const uint32_t y, const uint32_t m)
{
return x * y % m;
}
Как вы, возможно, догадываетесь, это решение неверное. Всё дело в том, что в такой процедуре возможно переполнение: операция x * y
выполняется в типе uint32_t
, и на самом деле промежуточным результатом выполнения этой операции будет не $$$xy$$$, а $$$xy\bmod2^{32}$$$. Если после этого взять результат по модулю $$$m$$$, он может отличаться от правильного:
Выход прост — перемножать необходимо в большем типе:
uint64_t prod_uint64(const uint64_t x, const uint64_t y, const uint64_t m)
{
return x * y % m;
}
Если так делать, то, поскольку $$$xy<2^{64}$$$, это произведение точно не переполнится, и после взятия результата по модулю получится правильный ответ.
Вопрос: а что делать, если $$$x$$$, $$$y$$$ и $$$m$$$ могут быть больше? Предлагаю следующее.
- Бинарное умножение. Так же, как и бинарное возведение в степень, существует бинарное умножение: чтобы посчитать $$$xy$$$, посчитаем $$$x\left\lfloor\frac y2\right\rfloor$$$, сложим это число само с собой и, возможно, добавим ещё $$$x$$$. Это потратит $$$\mathcal O(\log y)$$$ действий, но среди них не будет ничего, кроме сложения и вычитания!
uint64_t sum(const uint64_t x, const uint64_t y, const uint64_t m)
{
uint64_t ans = x + y;
if (ans < x || ans >= m)
ans -= m;
return ans;
}
uint64_t prod_binary(const uint64_t x, const uint64_t y, const uint64_t m)
{
if (y <= 1)
return y ? x : 0;
uint64_t ans = prod_binary(x, y >> 1, m);
ans = sum(ans, ans, m);
if (y & 1)
ans = sum(ans, x, m);
return ans;
}
- Умножение с помощью
int128
. Чтобы перемножить два 32-битных числа, нужна 64-битная промежуточная переменная. А чтобы перемножить два 64-битных числа, нужна 128-битная переменная! В современных 64-битных компиляторах C++ (кроме разве что Microsoft® Visual C++®) есть специальный тип__int128
, позволяющий осуществлять операции над 128-битными числами.
int64_t prod_uint128(const uint64_t x, const uint64_t y, const uint64_t m)
{
return (unsigned __int128)x * y % m;
}
- Умножение с помощью вещественного типа. Что такое $$$xy\bmod m$$$? Это на самом деле $$$xy-cm$$$, где $$$c=\left\lfloor\frac{xy}m\right\rfloor$$$. Давайте тогда попробуем посчитать $$$c$$$, а отсюда найдём $$$xy\bmod m$$$. При этом заметим, что нам не требуется находить $$$c$$$ прямо точно. Что будет, если мы случайно посчитаем, скажем, $$$c-4$$$? Тогда при подсчёте остатка мы вычислим $$$xy-(c-4)m=xy-cm+4m=xy\bmod m+4m$$$. На первый взгляд, это не то, что нам надо. Но если $$$m$$$ не слишком велико и $$$xy\bmod m+4m$$$ не вылезло из 64-битного типа, то после этого можно по-честному взять остаток и получить ответ.
Получается такой код:
uint64_t prod_double(const uint64_t x, const uint64_t y, const uint64_t m)
{
uint64_t c = (double)x * y / m;
int64_t ans = int64_t(x * y - c * m) % int64_t(m);
if (ans < 0)
ans += m;
return ans;
}
uint64_t prod_long_double(const uint64_t x, const uint64_t y, const uint64_t m)
{
uint64_t c = (long double)x * y / m;
int64_t ans = int64_t(x * y - c * m) % int64_t(m);
if (ans < 0)
ans += m;
return ans;
}
double
достаточно точный для этой задачи, если $$$x$$$, $$$y$$$ и $$$m$$$ меньше $$$2^{57}$$$. long double
же хватает на числа, меньшие $$$2^{63}$$$, однако стоит помнить, что long double
для этого должен быть 80-битным, а это верно не на всех компиляторах: например, в Microsoft® Visual C++® long double
— то же самое, что double
.
Обратите внимание, что этот способ неприменим, если $$$m>2^{63}$$$: в этом случае ans
нельзя хранить в int64_t
, так как, возможно, $$$\mathtt{ans}\geqslant 2^{63}$$$ и произойдёт переполнение, из-за которого выполнится ветка (ans < 0)
и мы получим неверный ответ.
Видно, что Microsoft® Visual C++® отстаёт от остальных компиляторов в развитии в наличии технических средств для умножения больших чисел по модулю, поэтому, если мы хотим, чтобы функция работала на всех компиляторах быстро, ей необходима какая-то свежая идея. К счастью, именно такую идею в 1960-м году изобрёл Анатолий Карацуба.
- Умножение Карацубы. Идея изначально использовалась для быстрого умножения длинных чисел. А именно, пусть $$$x$$$ и $$$y$$$ — два неотрицательных целых числа, меньших $$$N^2$$$. Поделим их с остатком на $$$N$$$: $$$x=Nx_1+x_0$$$, $$$y=Ny_1+y_0$$$. Тогда искомое $$$xy$$$ можно найти как $$$N^2x_1y_1+Nx_1y_0+Nx_0y_1+x_0y_0=N\cdot\bigl(N\cdot x_1y_1+\left(x_0+x_1\right)\left(y_0+y_1\right)-x_1y_1-x_0y_0\bigr)+x_0y_0$$$. Как видно, это преобразование свело умножение $$$x$$$ и $$$y$$$ 1) к $$$\mathcal O(1)$$$ сложениям и вычитаниям чисел, не превосходящих $$$N^4$$$; 2) к трём умножениям чисел, не превосходящих $$$2N$$$ (а именно $$$x_0y_0$$$, $$$x_1y_1$$$ и $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$); 3) к двум умножениям на $$$N$$$ чисел, не превосходящих $$$2N^2$$$.
Пункт 1) практически всегда очень прост. В случае с длинными числами пункт 3) также прост: можно взять $$$N$$$, равное степени двойки, и тогда он осуществляется как обычный двоичный сдвиг (оператор<<
в C++). Поэтому по существу Карацуба свёл одно умножение чисел, меньших $$$N^2$$$, к трём умножениям чисел, меньших $$$2N$$$. Если эти умножения также свести методом Карацубы, по мастер-теореме асимптотика этого метода составит $$$\Theta\left(\log^{\log_23}N\right)$$$ вместо наивного $$$\Theta\left(\log^2N\right)$$$.
Но нам не потребуется использовать рекурсию, ведь, если длину $$$x$$$ и $$$y$$$ уменьшить вдвое, мы уже сможем воспользоватьсяprod_uint64
илиprod_double
. Сложность в нашем случае составляет пункт 3): подобрать такое $$$N$$$, чтобы, во-первых, оно было меньше $$$2^{32}$$$ либо в крайнем случае чуть-чуть больше, во-вторых, чтобы на него можно было быстро умножать числа порядка $$$N^2$$$. Оба требования выполнятся, если взять $$$N=\mathrm{round}\left(\sqrt m\right)$$$: действительно, тогда при $$$m<2^{64}$$$ верно $$$N<2^{32}$$$, а $$$\left|m_0\right|=\left|m-N^2\right|\leqslant N<2^{32}$$$; тогда $$$xN=\left(x_1N+x_0\right)N=x_1N^2+x_0N\equiv x_0N-x_1m_0\pmod m$$$, и оба умножения здесь осуществляются над числами порядка $$$N$$$.
Внимательный читатель заметит, что здесь мы обзавелись серьёзной проблемой: извлечение квадратного корня из целого числа. Если вы умеете применять в этой задаче умножение Карацубы, обходя данную проблему (в том числе достаточно быстро находить квадратный корень, написать более быстрый или более короткий код, чем у меня), напишите, пожалуйста, в комментариях!
Поскольку находить произведение $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$ оказалось крайне неприятно (напомню,prod_double
не работает при $$$m>2^{63}$$$), я всё же решил просто вычислить $$$x_0y_1$$$ и $$$x_1y_0$$$ по-честному — таким образом, это не метод Карацубы в истинном смысле, так как я трачу четыре умножения чисел порядка $$$N$$$.
uint64_t dif(const uint64_t x, const uint64_t y, const uint64_t m)
{
uint64_t ans = x - y;
if (ans > x)
ans += m;
return ans;
}
bool check_ge_rounded_sqrt(const uint64_t m, const uint64_t r)
{
return ((r >= 1ull << 32) || r * (r + 1) >= m);
}
bool check_le_rounded_sqrt(const uint64_t m, const uint64_t r)
{
return (r == 0 || ((r <= 1ull << 32) && r * (r - 1) < m));
}
bool check_rounded_sqrt(const uint64_t m, const uint64_t r)
{
return check_ge_rounded_sqrt(m, r) && check_le_rounded_sqrt(m, r);
}
uint64_t rounded_sqrt(const uint64_t m)
{
uint64_t r = floorl(.5 + sqrtl(m));
if (!check_ge_rounded_sqrt(m, r))
while (!check_ge_rounded_sqrt(m, ++r));
else if (!check_le_rounded_sqrt(m, r))
while (!check_le_rounded_sqrt(m, --r));
return r;
}
uint64_t prod_karatsuba_aux(const uint64_t x, const uint64_t N, const int64_t m0, const uint64_t m)
{
uint64_t x1 = x / N;
uint64_t x0N = (x - N * x1) * N;
if (m0 >= 0)
return dif(x0N, x1 * (uint64_t)m0, m);
else
return sum(x0N, x1 * (uint64_t)-m0, m);
}
uint64_t prod_karatsuba(const test& t)
{
uint64_t x = t.x, y = t.y, m = t.modulo;
uint64_t N = rounded_sqrt(t.modulo);
int64_t m0 = m - N * N;
uint64_t x1 = t.x / N;
uint64_t x0 = t.x - N * x1;
uint64_t y1 = t.y / N;
uint64_t y0 = t.y - N * y1;
uint64_t x0y0 = sum(x0 * y0, 0, m);
uint64_t x0y1 = sum(x0 * y1, 0, m);
uint64_t x1y0 = sum(x1 * y0, 0, m);
uint64_t x1y1 = sum(x1 * y1, 0, m);
return sum(prod_karatsuba_aux(sum(prod_karatsuba_aux(x1y1, N, m0, m), sum(x0y1, x1y0, m), m), N, m0, m), x0y0, m);
}
Видно, что на самом деле единственное, что даёт нам тут метод Карацубы — что если вы найдёте большое число $$$N$$$, на которое вы умеете быстро умножать, то вы сможете умножать любые два числа по модулю. Фактически, если бы модуль $$$m$$$ был фиксированным, и было бы много запросов умножения по этому фиксированному модулю, то метод Карацубы был бы молниеносным, поскольку самая затратная операция в нём — это квадратный корень. Хотелось бы, таким образом, взять, например, $$$N=2^{32}$$$ и сделать всё то же самое, что в прошлом пункте, но без квадратного корня. Увы, но я не придумал, как умножать на $$$2^{32}$$$. Можно было бы написать примерно такую функцию:
uint64_t prod_double_small(const uint64_t x, const uint64_t y, const uint64_t m)
{
uint64_t c = (double)x * y / m;
uint64_t ans = (x * y - c * m) % m;
return ans;
}
Она вычисляет произведение по модулю при условии, что uint64_t c = (double)x * y / m
посчиталось абсолютно точно. Но гарантировать, что оно точно посчитается, не представляется возможным, поскольку $$$\frac{xy}m$$$ вполне может оказаться на $$$10^{-18}$$$ меньше, чем какое-то целое число, и типа double
не хватит, чтобы это заметить. Именно эту проблему обходит функция prod_karatsuba_aux
. Если вы смогли её обойти хитрее, добро пожаловать в комментарии.
Ниже приведены три таблицы по разным компиляторам (да славится имя MikeMirzayanov, ведь именно благодаря нецелевому использованию его Polygon'а я это сделал), в каждой таблице строки соответствуют различным функциям, столбцы соответствуют максимальной разрешённой битности $$$x$$$, $$$y$$$ и $$$m$$$. Если указано CE, значит, с данным компилятором программа не скомпилируется, а если WA — может выдать неверный ответ. В противном случае указано время работы функции на Intel® Core™ i3-8100 CPU @ 3.60GHz. Погрешность приблизительно равна одной-двум наносекундам, на самых медленных функциях может доходить до десяти наносекунд.
Microsoft® Visual C++® 2010
Метод 32 бита 57 битов 63 бита 64 бита prod_uint64 7 ns WA WA WA prod_binary 477 ns 847 ns 889 ns 870 ns prod_uint128 CE CE CE CE prod_double 66 ns 95 ns WA WA prod_long_double 66 ns 98 ns WA WA prod_karatsuba 128 ns 125 ns 138 ns 139 ns GNU G++17
Метод 32 бита 57 битов 63 бита 64 бита prod_uint64 4 ns WA WA WA prod_binary 455 ns 774 ns 841 ns 845 ns prod_uint128 CE CE CE CE prod_double 26 ns 36 ns WA WA prod_long_double 29 ns 20 ns 19 ns WA prod_karatsuba 82 ns 81 ns 91 ns 88 ns GNU G++17 (64 bit)
Метод 32 бита 57 битов 63 бита 64 бита prod_uint64 8 ns WA WA WA prod_binary 313 ns 550 ns 604 ns 630 ns prod_uint128 17 ns 34 ns 30 ns 30 ns prod_double 23 ns 22 ns WA WA prod_long_double 23 ns 24 ns 23 ns WA prod_karatsuba 65 ns 65 ns 69 ns 66 ns
Поэтому базовый рецепт такой: если доступен unsigned __int128
, то использовать его, если доступен 80-битный long double
, то его тоже должно всегда хватать, а в противном случае надо, если хватает double
, использовать double
, иначе применить метод Карацубы.
При желании можете попробовать применить эти идеи на задачах специального контеста.