Быстрое умножение по модулю

Revision ru8, by orz, 2022-06-16 20:17:53

Рассмотрим такую задачу: даны три целых числа $$$x$$$, $$$y$$$ и $$$m$$$, $$$0 \leqslant x,y < m < 2^{32}$$$, найти $$$xy\bmod m$$$. По-хорошему хотелось бы просто перемножить эти два числа, а потом применить операцию остатка:

uint32_t prod(const uint32_t x, const uint32_t y, const uint32_t m)
{
	return x * y % m;
}

Как вы, возможно, догадываетесь, это решение неверное. Всё дело в том, что в такой процедуре возможно переполнение: операция x * y выполняется в типе uint32_t, и на самом деле промежуточным результатом выполнения этой операции будет не $$$xy$$$, а $$$xy\bmod2^{32}$$$. Если после этого взять результат по модулю $$$m$$$, он может отличаться от правильного:

$$$ \left(xy\bmod2^{32}\right)\bmod m\ne xy\bmod m. $$$

Выход прост — перемножать необходимо в большем типе:

uint64_t prod_uint64(const uint64_t x, const uint64_t y, const uint64_t m)
{
	return x * y % m;
}

Если так делать, то, поскольку $$$xy<2^{64}$$$, это произведение точно не переполнится, и после взятия результата по модулю получится правильный ответ.

Вопрос: а что делать, если $$$x$$$, $$$y$$$ и $$$m$$$ могут быть больше? Предлагаю следующее.

  1. Бинарное умножение. Так же, как и бинарное возведение в степень, существует бинарное умножение: чтобы посчитать $$$xy$$$, посчитаем $$$x\left\lfloor\frac y2\right\rfloor$$$, сложим это число само с собой и, возможно, добавим ещё $$$x$$$. Это потратит $$$\mathcal O(\log y)$$$ действий, но среди них не будет ничего, кроме сложения и вычитания!
	uint64_t sum(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		uint64_t ans = x + y;
		if (ans < x || ans >= m)
			ans -= m;
		return ans;
	}
	uint64_t prod_binary(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		if (y <= 1)
			return y ? x : 0;
		uint64_t ans = prod_binary(x, y >> 1, m);
		ans = sum(ans, ans, m);
		if (y & 1)
			ans = sum(ans, x, m);
		return ans;
	}
  1. Умножение с помощью int128. Чтобы перемножить два 32-битных числа, нужна 64-битная промежуточная переменная. А чтобы перемножить два 64-битных числа, нужна 128-битная переменная! В современных 64-битных компиляторах C++ (кроме разве что Microsoft® Visual C++®) есть специальный тип __int128, позволяющий осуществлять операции над 128-битными числами.
	int64_t prod_uint128(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		return (unsigned __int128)x * y % m;
	}
  1. Умножение с помощью вещественного типа. Что такое $$$xy\bmod m$$$? Это на самом деле $$$xy-cm$$$, где $$$c=\left\lfloor\frac{xy}m\right\rfloor$$$. Давайте тогда попробуем посчитать $$$c$$$, а отсюда найдём $$$xy\bmod m$$$. При этом заметим, что нам не требуется находить $$$c$$$ прямо точно. Что будет, если мы случайно посчитаем, скажем, $$$c-4$$$? Тогда при подсчёте остатка мы вычислим $$$xy-(c-4)m=xy-cm+4m=xy\bmod m+4m$$$. На первый взгляд, это не то, что нам надо. Но если $$$m$$$ не слишком велико и $$$xy\bmod m+4m$$$ не вылезло из 64-битного типа, то после этого можно по-честному взять остаток и получить ответ.

    Получается такой код:
	uint64_t prod_double(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		uint64_t c = (double)x * y / m;
		int64_t ans = int64_t(x * y - c * m) % int64_t(m);
		if (ans < 0)
			ans += m;
		return ans;
	}
	uint64_t prod_long_double(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		uint64_t c = (long double)x * y / m;
		int64_t ans = int64_t(x * y - c * m) % int64_t(m);
		if (ans < 0)
			ans += m;
		return ans;
	}

double достаточно точный для этой задачи, если $$$x$$$, $$$y$$$ и $$$m$$$ меньше $$$2^{57}$$$. long double же хватает на числа, меньшие $$$2^{63}$$$, однако стоит помнить, что long double для этого должен быть 80-битным, а это верно не на всех компиляторах: например, в Microsoft® Visual C++® long double — то же самое, что double.

Обратите внимание, что этот способ неприменим, если $$$m>2^{63}$$$: в этом случае ans нельзя хранить в int64_t, так как, возможно, $$$\mathtt{ans}\geqslant 2^{63}$$$ и произойдёт переполнение, из-за которого выполнится ветка (ans < 0) и мы получим неверный ответ.

Видно, что Microsoft® Visual C++® отстаёт от остальных компиляторов в развитии в наличии технических средств для умножения больших чисел по модулю, поэтому, если мы хотим, чтобы функция работала на всех компиляторах быстро, ей необходима какая-то свежая идея. К счастью, именно такую идею в 1960-м году изобрёл Анатолий Карацуба.

  1. Умножение Карацубы. Идея изначально использовалась для быстрого умножения длинных чисел. А именно, пусть $$$x$$$ и $$$y$$$ — два неотрицательных целых числа, меньших $$$N^2$$$. Поделим их с остатком на $$$N$$$: $$$x=Nx_1+x_0$$$, $$$y=Ny_1+y_0$$$. Тогда искомое $$$xy$$$ можно найти как $$$N^2x_1y_1+Nx_1y_0+Nx_0y_1+x_0y_0=N\cdot\bigl(N\cdot x_1y_1+\left(x_0+x_1\right)\left(y_0+y_1\right)-x_1y_1-x_0y_0\bigr)+x_0y_0$$$. Как видно, это преобразование свело умножение $$$x$$$ и $$$y$$$ 1) к $$$\mathcal O(1)$$$ сложениям и вычитаниям чисел, не превосходящих $$$N^4$$$; 2) к трём умножениям чисел, не превосходящих $$$2N$$$ (а именно $$$x_0y_0$$$, $$$x_1y_1$$$ и $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$); 3) к двум умножениям на $$$N$$$ чисел, не превосходящих $$$2N^2$$$.

    Пункт 1) практически всегда очень прост. В случае с длинными числами пункт 3) также прост: можно взять $$$N$$$, равное степени двойки, и тогда он осуществляется как обычный двоичный сдвиг (оператор << в C++). Поэтому по существу Карацуба свёл одно умножение чисел, меньших $$$N^2$$$, к трём умножениям чисел, меньших $$$2N$$$. Если эти умножения также свести методом Карацубы, по мастер-теореме асимптотика этого метода составит $$$\Theta\left(\log^{\log_23}N\right)$$$ вместо наивного $$$\Theta\left(\log^2N\right)$$$.

    Но нам не потребуется использовать рекурсию, ведь, если длину $$$x$$$ и $$$y$$$ уменьшить вдвое, мы уже сможем воспользоваться prod_uint64 или prod_double. Сложность в нашем случае составляет пункт 3): подобрать такое $$$N$$$, чтобы, во-первых, оно было меньше $$$2^{32}$$$ либо в крайнем случае чуть-чуть больше, во-вторых, чтобы на него можно было быстро умножать числа порядка $$$N^2$$$. Оба требования выполнятся, если взять $$$N=\mathrm{round}\left(\sqrt m\right)$$$: действительно, тогда при $$$m<2^{64}$$$ верно $$$N<2^{32}$$$, а $$$\left|m_0\right|=\left|m-N^2\right|\leqslant N<2^{32}$$$; тогда $$$xN=\left(x_1N+x_0\right)N=x_1N^2+x_0N\equiv x_0N-x_1m_0\pmod m$$$, и оба умножения здесь осуществляются над числами порядка $$$N$$$.

    Внимательный читатель заметит, что здесь мы обзавелись серьёзной проблемой: извлечение квадратного корня из целого числа. Если вы умеете применять в этой задаче умножение Карацубы, обходя данную проблему (в том числе достаточно быстро находить квадратный корень, написать более быстрый или более короткий код, чем у меня), напишите, пожалуйста, в комментариях!

    Поскольку находить произведение $$$\left(x_0+x_1\right)\left(y_0+y_1\right)$$$ оказалось крайне неприятно (напомню, prod_double не работает при $$$m>2^{63}$$$), я всё же решил просто вычислить $$$x_0y_1$$$ и $$$x_1y_0$$$ по-честному — таким образом, это не метод Карацубы в истинном смысле, так как я трачу четыре умножения чисел порядка $$$N$$$.
	uint64_t dif(const uint64_t x, const uint64_t y, const uint64_t m)
	{
		uint64_t ans = x - y;
		if (ans > x)
			ans += m;
		return ans;
	}
	bool check_ge_rounded_sqrt(const uint64_t m, const uint64_t r)
	{
		return ((r >= 1ull << 32) || r * (r + 1) >= m);
	}
	bool check_le_rounded_sqrt(const uint64_t m, const uint64_t r)
	{
		return (r == 0 || ((r <= 1ull << 32) && r * (r - 1) < m));
	}
	bool check_rounded_sqrt(const uint64_t m, const uint64_t r)
	{
		return check_ge_rounded_sqrt(m, r) && check_le_rounded_sqrt(m, r);
	}
	uint64_t rounded_sqrt(const uint64_t m)
	{
		uint64_t r = floorl(.5 + sqrtl(m));
		if (!check_ge_rounded_sqrt(m, r))
			while (!check_ge_rounded_sqrt(m, ++r));
		else if (!check_le_rounded_sqrt(m, r))
			while (!check_le_rounded_sqrt(m, --r));
		return r;
	}
	uint64_t prod_karatsuba_aux(const uint64_t x, const uint64_t N, const int64_t m0, const uint64_t m)
	{
		uint64_t x1 = x / N;
		uint64_t x0N = (x - N * x1) * N;
		if (m0 >= 0)
			return dif(x0N, x1 * (uint64_t)m0, m);
		else
			return sum(x0N, x1 * (uint64_t)-m0, m);
	}
	uint64_t prod_karatsuba(const test& t)
	{
		uint64_t x = t.x, y = t.y, m = t.modulo;
		uint64_t N = rounded_sqrt(t.modulo);
		int64_t m0 = m - N * N;
		uint64_t x1 = t.x / N;
		uint64_t x0 = t.x - N * x1;
		uint64_t y1 = t.y / N;
		uint64_t y0 = t.y - N * y1;
		uint64_t x0y0 = sum(x0 * y0, 0, m);
		uint64_t x0y1 = sum(x0 * y1, 0, m);
		uint64_t x1y0 = sum(x1 * y0, 0, m);
		uint64_t x1y1 = sum(x1 * y1, 0, m);
		return sum(prod_karatsuba_aux(sum(prod_karatsuba_aux(x1y1, N, m0, m), sum(x0y1, x1y0, m), m), N, m0, m), x0y0, m);
	}

Видно, что на самом деле единственное, что даёт нам тут метод Карацубы — что если вы найдёте большое число $$$N$$$, на которое вы умеете быстро умножать, то вы сможете умножать любые два числа по модулю. Фактически, если бы модуль $$$m$$$ был фиксированным, и было бы много запросов умножения по этому фиксированному модулю, то метод Карацубы был бы молниеносным, поскольку самая затратная операция в нём — это квадратный корень. Хотелось бы, таким образом, взять, например, $$$N=2^{32}$$$ и сделать всё то же самое, что в прошлом пункте, но без квадратного корня. Увы, но я не придумал, как умножать на $$$2^{32}$$$. Можно было бы написать примерно такую функцию:

uint64_t prod_double_small(const uint64_t x, const uint64_t y, const uint64_t m)
{
	uint64_t c = (double)x * y / m;
	uint64_t ans = (x * y - c * m) % m;
	return ans;
}

Она вычисляет произведение по модулю при условии, что uint64_t c = (double)x * y / m посчиталось абсолютно точно. Но гарантировать, что оно точно посчитается, не представляется возможным, поскольку $$$\frac{xy}m$$$ вполне может оказаться на $$$10^{-18}$$$ меньше, чем какое-то целое число, и типа double не хватит, чтобы это заметить. Именно эту проблему обходит функция prod_karatsuba_aux. Если вы смогли её обойти хитрее, добро пожаловать в комментарии.


Ниже приведены три таблицы по разным компиляторам (да славится имя MikeMirzayanov, ведь именно благодаря нецелевому использованию его Polygon'а я это сделал), в каждой таблице строки соответствуют различным функциям, столбцы соответствуют максимальной разрешённой битности $$$x$$$, $$$y$$$ и $$$m$$$. Если указано CE, значит, с данным компилятором программа не скомпилируется, а если WA — может выдать неверный ответ. В противном случае указано время работы функции на Intel® Core™ i3-8100 CPU @ 3.60GHz. Погрешность приблизительно равна одной-двум наносекундам, на самых медленных функциях может доходить до десяти наносекунд.

  1. Microsoft® Visual C++® 2010

    Метод 32 бита 57 битов 63 бита 64 бита
    prod_uint64 7 ns WA WA WA
    prod_binary 477 ns 847 ns 889 ns 870 ns
    prod_uint128 CE CE CE CE
    prod_double 66 ns 95 ns WA WA
    prod_long_double 66 ns 98 ns WA WA
    prod_karatsuba 128 ns 125 ns 138 ns 139 ns

  2. GNU G++17

    Метод 32 бита 57 битов 63 бита 64 бита
    prod_uint64 4 ns WA WA WA
    prod_binary 455 ns 774 ns 841 ns 845 ns
    prod_uint128 CE CE CE CE
    prod_double 26 ns 36 ns WA WA
    prod_long_double 29 ns 20 ns 19 ns WA
    prod_karatsuba 82 ns 81 ns 91 ns 88 ns

  3. GNU G++17 (64 bit)

    Метод 32 бита 57 битов 63 бита 64 бита
    prod_uint64 8 ns WA WA WA
    prod_binary 313 ns 550 ns 604 ns 630 ns
    prod_uint128 17 ns 34 ns 30 ns 30 ns
    prod_double 23 ns 22 ns WA WA
    prod_long_double 23 ns 24 ns 23 ns WA
    prod_karatsuba 65 ns 65 ns 69 ns 66 ns

Поэтому базовый рецепт такой: если доступен unsigned __int128, то использовать его, если доступен 80-битный long double, то его тоже должно всегда хватать, а в противном случае надо, если хватает double, использовать double, иначе применить метод Карацубы.

При желании можете попробовать применить эти идеи на задачах специального контеста.

Tags умножение

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
ru8 Russian orz 2022-06-16 20:17:53 4 Мелкая правка: '%D1%87).\n4. _Умно' -> '%D1%87).\n\n4. _Умно'
en7 English orz 2022-02-24 05:26:02 16
ru7 Russian orz 2022-02-24 05:24:33 21 Мелкая правка: 'остатка:\n~~~~~\nu' -> 'остатка:\n\n~~~~~\nu'
en6 English orz 2021-11-10 21:54:33 1 Tiny change: 't|=\left|mN^2\right|' -> 't|=\left|m-N^2\right|'
en5 English orz 2021-11-09 15:13:55 8 Tiny change: ' Karatsuba](https://e' -> ' Karatsuba(https://e'
en4 English orz 2021-11-09 08:42:08 21 Tiny change: 'tal delay<strike> la' -> 'tal delay</strike> la'
ru6 Russian orz 2021-11-09 08:41:10 21
en3 English orz 2021-11-09 08:20:46 0 (published)
ru5 Russian orz 2021-11-09 08:10:07 0 (опубликовано)
ru4 Russian orz 2021-11-09 08:09:54 27
ru3 Russian orz 2021-11-09 08:07:56 23 Мелкая правка: '0y0, m);\n}\n~~~~~\n' -> '0y0, m);\n }\n~~~~~\n'
en2 English orz 2021-11-09 08:07:52 418 Tiny change: 'rm{round}\!\left(\sqr' -> 'rm{round}\left(\sqr'
ru2 Russian orz 2021-11-09 08:00:41 533
en1 English orz 2021-11-09 07:51:02 15661 Initial revision for English translation (saved to drafts)
ru1 Russian orz 2021-11-09 07:33:41 15901 Первая редакция (сохранено в черновиках)