Блог пользователя platelet

Автор platelet, история, 2 года назад, По-английски

Here is the speed test code for this.

#include <bits/stdc++.h>

using namespace std;

const int N = 5e4, P = 998244353;

int a[N];

// Throughput benchmark for three ways of computing a[i] * a[j] % P over all
// (i, j) pairs:
//   1. the compiler's signed 64-bit modulo,
//   2. the compiler's unsigned 64-bit modulo,
//   3. a multiply-and-shift reduction using a constant precomputed per a[i].
// Each result is only XOR-ed into a checksum, so no modulo result is an operand
// of a later multiplication. Prints the elapsed time of every variant in
// milliseconds and asserts that the three checksums are equal.
void ThroughputTest() {
    // Rows are processed in pairs (a[i] and a[i + 1]), so an odd N would read
    // one element past the end of a[].
    static_assert(N % 2 == 0, "N must be even: rows are processed in pairs");

    int checkSum1 = 0, checkSum2 = 0, checkSum3 = 0;

    // Prints one result line. Going through duration<double, milli> gives
    // milliseconds whatever the native tick period of steady_clock is, instead
    // of assuming that count() is in nanoseconds.
    auto report = [](const char* label,
                     std::chrono::steady_clock::time_point from,
                     std::chrono::steady_clock::time_point to) {
        std::cout << label
                  << std::chrono::duration<double, std::milli>(to - from).count()
                  << " ms" << std::endl;
    };

    // Variant 1: operands widened to int64_t, so % is a signed 64-bit modulo.
    auto start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2)
        for(int j = 0; j < N; j++) {
            checkSum1 ^= (int64_t)a[i] * a[j] % P;
            checkSum1 ^= (int64_t)a[i + 1] * a[j] % P;
        }
    auto end = std::chrono::steady_clock::now();
    report("Compiler's signed modulo:   ", start, end);

    // Variant 2: operands go through uint32_t/uint64_t, so % is unsigned.
    start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2)
        for(int j = 0; j < N; j++) {
            checkSum2 ^= (uint64_t)(uint32_t)a[i] * (uint32_t)a[j] % P;
            checkSum2 ^= (uint64_t)(uint32_t)a[i + 1] * (uint32_t)a[j] % P;
        }
    end = std::chrono::steady_clock::now();
    report("Compiler's unsigned modulo: ", start, end);

    // Variant 3: no division inside the inner loop.
    start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2) {
        // x = ceil(a[i] * 2^64 / P), y likewise for a[i + 1]; computed once per
        // row pair. NOTE(review): the cast to __uint128_t assumes a[] holds
        // non-negative values.
        uint64_t x = (((__uint128_t)a[i] << 64) + P - 1) / P;
        uint64_t y = (((__uint128_t)a[i + 1] << 64) + P - 1) / P;
        for(int j = 0; j < N; j++) {
            // a[j] * x is a 64-bit product that wraps modulo 2^64; it is then
            // multiplied by P in 128 bits and the upper 64 bits are kept. The
            // assert below checks this against the results of the % variants.
            checkSum3 ^= (uint32_t)a[j] * x * (__uint128_t)P >> 64;
            checkSum3 ^= (uint32_t)a[j] * y * (__uint128_t)P >> 64;
        }
    }
    end = std::chrono::steady_clock::now();
    report("My modulo:                  ", start, end);

    // All three variants must have produced the same checksum.
    assert(checkSum1 == checkSum2 && checkSum2 == checkSum3);
}
// Latency benchmark for the same three modular-multiplication variants as the
// throughput test. Here every result is stored back into the checksum and
// XOR-ed into the next operand, so each multiplication depends on the previous
// modulo result. Prints the elapsed time of every variant in milliseconds and
// asserts that the three final checksums are equal.
void LatencyTest() {
    // Rows are processed in pairs (a[i] and a[i + 1]), so an odd N would read
    // one element past the end of a[].
    static_assert(N % 2 == 0, "N must be even: rows are processed in pairs");

    int checkSum1 = 0, checkSum2 = 0, checkSum3 = 0;

    // Prints one result line. Going through duration<double, milli> gives
    // milliseconds whatever the native tick period of steady_clock is, instead
    // of assuming that count() is in nanoseconds.
    auto report = [](const char* label,
                     std::chrono::steady_clock::time_point from,
                     std::chrono::steady_clock::time_point to) {
        std::cout << label
                  << std::chrono::duration<double, std::milli>(to - from).count()
                  << " ms" << std::endl;
    };

    // Variant 1: signed 64-bit modulo on a dependent chain.
    auto start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2)
        for(int j = 0; j < N / 2; j++) {
            checkSum1 = (int64_t)a[i] * (a[j] ^ checkSum1) % P;
            checkSum1 = (int64_t)a[i + 1] * (a[j] ^ checkSum1) % P;
        }
    auto end = std::chrono::steady_clock::now();
    report("Compiler's signed modulo:   ", start, end);

    // Variant 2: unsigned 64-bit modulo on a dependent chain.
    start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2)
        for(int j = 0; j < N / 2; j++) {
            checkSum2 = (uint64_t)(uint32_t)a[i] * (uint32_t)(a[j] ^ checkSum2) % P;
            checkSum2 = (uint64_t)(uint32_t)a[i + 1] * (uint32_t)(a[j] ^ checkSum2) % P;
        }
    end = std::chrono::steady_clock::now();
    report("Compiler's unsigned modulo: ", start, end);

    // Variant 3: multiply-and-shift reduction on a dependent chain.
    start = std::chrono::steady_clock::now();
    for(int i = 0; i < N; i += 2) {
        // x = ceil(a[i] * 2^64 / P), y likewise for a[i + 1]; computed once per
        // row pair. NOTE(review): the cast to __uint128_t assumes a[] holds
        // non-negative values.
        uint64_t x = (((__uint128_t)a[i] << 64) + P - 1) / P;
        uint64_t y = (((__uint128_t)a[i + 1] << 64) + P - 1) / P;
        for(int j = 0; j < N / 2; j++) {
            // The 64-bit product with x wraps modulo 2^64; it is then
            // multiplied by P in 128 bits and the upper 64 bits are kept. The
            // assert below checks this against the results of the % variants.
            checkSum3 = (uint32_t)(a[j] ^ checkSum3) * x * (__uint128_t)P >> 64;
            checkSum3 = (uint32_t)(a[j] ^ checkSum3) * y * (__uint128_t)P >> 64;
        }
    }
    end = std::chrono::steady_clock::now();
    report("My modulo:                  ", start, end);

    // All three variants must have produced the same checksum.
    assert(checkSum1 == checkSum2 && checkSum2 == checkSum3);
}
// Fills a[] with pseudo-random values reduced modulo P, then runs the
// throughput and latency benchmarks, printing a header before each.
int main() {
    // A default-constructed mt19937 uses the engine's fixed default seed, so
    // every run benchmarks the same data.
    mt19937 gen;
    for(int i = 0; i < N; i++) a[i] = gen() % P;
    // The iteration counts in the headers are derived from N (and N / 2 for
    // the latency test's inner loop) so they stay correct if N is changed.
    cout << "Throughput test (" << N << " * " << N << "):" << endl;
    ThroughputTest();
    cout << endl;
    cout << "Latency test (" << N << " * " << N / 2 << "):" << endl;
    LatencyTest();
}

Possible output:

Throughput test (50000 * 50000):
Compiler's signed modulo:   1954.83 ms
Compiler's unsigned modulo: 1746.73 ms
My modulo:                  1160.47 ms

Latency test (50000 * 25000):
Compiler's signed modulo:   4329.33 ms
Compiler's unsigned modulo: 3945.29 ms
My modulo:                  2397.97 ms
  • Проголосовать: нравится
  • +67
  • Проголосовать: не нравится

»
2 года назад, # |
Rev. 2   Проголосовать: нравится 0 Проголосовать: не нравится

Wow, awesome. Could you explain a bit about how this works (or provide a link if this is some common technique)? Seems like the modulo operation was somehow done by an unsigned 64-bit integer overflow.

upd. Sorry. Somehow I didn't see this yesterday.