Hello, Codeforces!
A few days ago MohammadParsaElahimanesh posted a blog titled Can we find each Required node in segment tree in O(1)? Apparently what they meant was to find each node in $$$\mathcal{O}(ans)$$$, according to ecnerwala's explanation. But I was too dumb to realize that and accidentally invented a parallel node resolution method instead, which speeds up the segment tree a lot.
A benchmark for you first, with 30 million range sum queries on a 32-bit integer array of 17 million elements. It was run in a custom test on Codeforces on Apr 6, 2021.
- Classic implementation from cp-algorithms: 7.765 seconds, or 260 ns per query
- Optimized classic implementation (the one I was taught): 4.452 seconds, or 150 ns per query (75% faster than classic)
- Bottom-up implementation: 1.914 seconds, or 64 ns per query (133% faster than optimized)
- Novel parallel implementation: 0.383 seconds, or 13 ns per query (400% faster than bottom-up, or 2000% faster than the classic implementation)
FAQ
Q: Is it really that fast? I shamelessly stole someone's solution for 1355C - Count Triangles which uses prefix sums: 112167743. It runs in 46 ms. Then I replaced the prefix sums with a classic segment tree in 112168469, which runs in 155 ms. The bottom-up implementation runs in 93 ms: 112168530. Finally, my novel implementation runs in only 62 ms: 112168574. Subtracting the 46 ms baseline of the prefix sums solution, the bottom-up segment tree accounts for 47 ms of query time, while the parallel implementation accounts for only 16 ms. Thus, even in such a simple problem with only prefix queries, the novel implementation is about 3x faster than the state of the art in practice!
Q: Why? Maybe you want your $$$\mathcal{O}(n \log^2 n)$$$ solution to pass in a problem with an $$$\mathcal{O}(n \log n)$$$ model solution. Maybe you want to troll problemsetters. Maybe you want to obfuscate your code so that no one realizes you used a segment tree and no one hacks you (just kidding, you'll get FST anyway). Choose an excuse for yourself. I want contribution too.
Q: License? Tough question, because we're in CP. So: you may use it under the MIT license for competitive programming, e.g. on Codeforces, and under GPLv3 otherwise.
Q: Any pitfalls? Yes, sadly. It requires AVX2 instructions, which are supported on Codeforces but may not be supported on other judges.
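If you need to run on judges without AVX2, one option (my sketch, not part of the original code; it assumes the query_parallel and query_bottom_up functions shown below) is to dispatch at runtime using GCC's __builtin_cpu_supports:
// Hypothetical dispatcher: use the AVX2 version only when the CPU supports it,
// falling back to the scalar bottom-up query otherwise.
T query(int l, int r) {
    static const bool has_avx2 = __builtin_cpu_supports("avx2");
    return has_avx2 ? query_parallel(l, r) : query_bottom_up(l, r);
}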
How it works

In a segment tree, a range query is decomposed into 'red' nodes. Classic segment tree implementations don't find these red nodes directly, but recurse through their ancestors, the 'green' nodes. The bottom-up implementation does enumerate red nodes directly, but it also visits a few other unused nodes along the way.
The parallel implementation is an optimization of the bottom-up tree. You probably all know what the bottom-up implementation looks like, but I'll cite the main idea nevertheless to show the difference between the bottom-up and parallel implementations:
In a bottom-up segment tree, we find the node corresponding to the leftmost element of the query, i.e. $$$x[l]$$$, and the node corresponding to the rightmost query element, i.e. $$$x[r]$$$. If we number the nodes in a special way, the leftmost element will correspond to node $$$N+l$$$ and the rightmost to node $$$N+r$$$. After that, the answer is simply the sum of the values of all nodes between $$$N+l$$$ and $$$N+r$$$. Sadly there are $$$\mathcal{O}(n)$$$ of those, but we can apply the following optimization:
If $$$N+l$$$ is the left child of its parent and $$$N+r$$$ is the right child of its parent, then instead of summing up all nodes in the range $$$[N+l, N+r]$$$, we can sum up all nodes in the range $$$[\frac{N+l}{2}, \frac{N+r-1}{2}]$$$. That is, we replace the two boundary nodes with their parents. Otherwise, if $$$N+l$$$ is the right child of its parent, we do ans += a[N+l]; l++;, and if $$$N+r$$$ is the left child of its parent, we do ans += a[N+r]; r--;. Then the condition holds and we can do the replacement.
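In code, that is the whole loop (the same query_bottom_up that appears in the benchmark below):
T query_bottom_up(int l, int r) {
    l += N;
    r += N;
    T ans = identity_element;
    while(l <= r) {
        if(l & 1) { // l is a right child: take it and step right
            ans = reduce(ans, a[l]);
            l++;
        }
        if(!(r & 1)) { // r is a left child: take it and step left
            ans = reduce(ans, a[r]);
            r--;
        }
        l /= 2; // replace both boundaries with their parents
        r /= 2;
    }
    return ans;
}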
In the parallel segment tree, we jump to the $$$i$$$-th ancestors of $$$N+l$$$ and $$$N+r$$$ for all $$$i$$$ simultaneously, and check the is-left-child/is-right-child conditions in parallel as well. The checks are rather simple, so a few bitwise operations do the trick. With AVX2 we can perform these bitwise operations on 8 integers at once, which means that the core of the query should run about 8 times faster.
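To see what each SIMD lane computes, here is a scalar rendition of the left-boundary walk (my sketch, using the node numbering and helpers from the code below; left_walk_scalar is a made-up name, and the right boundary is symmetric):
// Sums the red nodes covering [l, m-1] by walking up from leaf l-1: whenever
// the i-th ancestor is a left child (even index), its right sibling is a red
// node. a[0] permanently holds identity_element, so odd ancestors add nothing.
T left_walk_scalar(int l, int m) {
    T ans = identity_element;
    int ll = l - 1 + N, rr = m - 1 + N;
    int levels = ffs(ll ^ rr) + 1; // climb until the two walks meet
    for(int i = 0; i < levels; i++) {
        int v = ll >> i; // i-th ancestor of leaf l-1
        int idx = ((v & 1) - 1) & (v | 1); // v + 1 if v is even, 0 otherwise
        ans = reduce(ans, a[idx]);
    }
    return ans;
}
The vectorized version evaluates all iterations of this loop at once: a single variable shift produces eight ancestors, the same bit trick turns them into gather indices, and one gather fetches all eight nodes.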
Want code? We have some!
This is the benchmark, along with the four segment tree implementations I checked and a prefix sum for comparison.
Benchmark code:

#include <iostream>
#include <algorithm>
#include <utility>
#include <random>
#include <ctime>
#include <cstdint>
#include <cstdlib>
#include <cassert>
#include <immintrin.h>
const int Q = 30000000;
const int N = 1 << 24;
using T = uint32_t;
T a[2 * N];
T pref[N];
const T identity_element = 0;
T reduce(T a, T b) {
    return a + b;
}
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {
    return _mm_add_epi32(a, b);
}
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {
    return _mm256_add_epi32(a, b);
}
static_assert(sizeof(T) == 4, "Segment tree elements must be 32-bit");
T query_recursive_inner(int v, int vl, int vr, int l, int r) {
    if(l >= r) {
        return identity_element;
    }
    if(l <= vl && vr <= r) {
        return a[v];
    }
    int vm = (vl + vr) / 2;
    return reduce(query_recursive_inner(v * 2, vl, vm, l, std::min(r, vm)), query_recursive_inner(v * 2 + 1, vm, vr, std::max(l, vm), r));
}
T query_recursive_inner(int l, int r) {
    return query_recursive_inner(1, 0, N, l, r + 1);
}
T query_recursive_outer(int v, int vl, int vr, int l, int r) {
    if(vl == l && vr == r) {
        return a[v];
    } else {
        int vm = (vl + vr) / 2;
        if(r <= vm) {
            return query_recursive_outer(v * 2, vl, vm, l, r);
        } else if(l >= vm) {
            return query_recursive_outer(v * 2 + 1, vm, vr, l, r);
        } else {
            return reduce(query_recursive_outer(v * 2, vl, vm, l, vm), query_recursive_outer(v * 2 + 1, vm, vr, vm, r));
        }
    }
}
T query_recursive_outer(int l, int r) {
    return query_recursive_outer(1, 0, N, l, r + 1);
}
T query_bottom_up(int l, int r) {
    l += N;
    r += N;
    T ans = identity_element;
    while(l <= r) {
        if(l & 1) {
            ans = reduce(ans, a[l]);
            l++;
        }
        if(!(r & 1)) {
            ans = reduce(ans, a[r]);
            r--;
        }
        l /= 2;
        r /= 2;
    }
    return ans;
}
int ffs(unsigned int x) {
    // Index of the most significant set bit; unlike the POSIX ffs(), which finds the lowest one
    return sizeof(unsigned int) * 8 - 1 - __builtin_clz(x);
}
__attribute__((target("avx2"))) T query_parallel(int l, int r) {
    if(l == r) {
        return a[l + N];
    }
    int mbit = ffs(l ^ r);
    int reset = ((1 << mbit) - 1);
    int m = r & ~reset;
    using vecint = T __attribute__((vector_size(32)));
    __m256i identity_vec = _mm256_set1_epi32(identity_element);
    vecint vec_ans = (vecint)identity_vec;
    __m256i indexes = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
    if((l & reset) != 0) {
        int ll = l - 1 + N;
        int rr = m - 1 + N;
        int modbit = 0;
        int maxmodbit = ffs(ll ^ rr) + 1;
        vecint ll_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(ll), indexes);
#define LOOP(content) if(modbit + 8 <= maxmodbit) { \
    vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4)); \
    ll_vec >>= 8; \
    modbit += 8; \
    content \
}
        LOOP(LOOP(LOOP(LOOP())))
#undef LOOP
        __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4);
        __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
        vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
    } else {
        vec_ans[0] = reduce(vec_ans[0], a[(l + N) >> mbit]);
    }
    if((r & reset) != reset) {
        int ll = m + N;
        int rr = r + 1 + N;
        int modbit = 0;
        int maxmodbit = ffs(ll ^ rr) + 1;
        vecint rr_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(rr), indexes);
#define LOOP(content) if(modbit + 8 <= maxmodbit) { \
    vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4)); \
    rr_vec >>= 8; \
    modbit += 8; \
    content \
}
        LOOP(LOOP(LOOP(LOOP())))
#undef LOOP
        __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4);
        __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
        vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
    } else {
        vec_ans[0] = reduce(vec_ans[0], a[(r + N) >> mbit]);
    }
    // vec_ans = 7 6 5 4 3 2 1 0
    __m128i low128 = _mm256_castsi256_si128((__m256i)vec_ans); // 3 2 1 0
    __m128i high128 = _mm256_extractf128_si256((__m256i)vec_ans, 1); // 7 6 5 4
    __m128i ans128 = reduce(low128, high128); // 7+3 6+2 5+1 4+0
    T ans = identity_element;
    for(int i = 0; i < 4; i++) {
        ans = reduce(ans, ((T __attribute__((vector_size(16))))ans128)[i]);
    }
    return ans;
}
T query_prefix(int l, int r) {
    return pref[r] - (l == 0 ? 0 : pref[l - 1]);
}
int main() {
    std::pair<int, int>* queries = new std::pair<int, int>[Q];
    for(int i = 0; i < Q; i++) {
        int l = rand() % N;
        int r = rand() % N;
        if(l > r) {
            std::swap(l, r);
        }
        queries[i] = {l, r};
    }
    for(int i = 0; i < N; i++) {
        a[N + i] = rand();
    }
    for(int i = N - 1; i >= 1; i--) {
        a[i] = reduce(a[i * 2], a[i * 2 + 1]);
    }
    a[0] = identity_element;
    for(int i = 0; i < N; i++) {
        pref[i] = (i == 0 ? 0 : pref[i - 1]) + a[N + i];
    }
#define CHECK(func) { \
    auto clock_start = clock(); \
    T checksum = 0; \
    for(int i = 0; i < Q; i++) { \
        checksum += func(queries[i].first, queries[i].second); \
    } \
    std::cout << #func << ": " << (double)(clock() - clock_start) / CLOCKS_PER_SEC << " seconds (checksum: " << checksum << ")" << std::endl; \
}
    CHECK(query_recursive_inner)
    CHECK(query_recursive_outer)
    CHECK(query_bottom_up)
    CHECK(query_parallel)
    CHECK(query_prefix)
    return 0;
}
The core is here:
Main code:

#include <iostream>
#include <random>
#include <ctime>
#include <cstdint>
#include <cassert>
#include <immintrin.h>
const int N = 1 << 24;
using T = uint32_t;
T a[2 * N];
const T identity_element = 0;
T reduce(T a, T b) {
    return a + b;
}
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {
    return _mm_add_epi32(a, b);
}
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {
    return _mm256_add_epi32(a, b);
}
static_assert((N & (N - 1)) == 0, "Segment tree size must be a power of two");
static_assert(sizeof(T) == 4, "Segment tree elements must be 32-bit");
int ffs(unsigned int x) {
    // Index of the most significant set bit; unlike the POSIX ffs(), which finds the lowest one
    return sizeof(unsigned int) * 8 - 1 - __builtin_clz(x);
}
// Returns sum/min/max/etc. in range [l; r], inclusive. The operation is determined by reduce()
__attribute__((target("avx2"))) T query_parallel(int l, int r) {
    if(l == r) {
        return a[l + N];
    }
    // mbit is the most significant bit in which l and r differ;
    // m is the first query element that falls into the right subtree at that level
    int mbit = ffs(l ^ r);
    int reset = ((1 << mbit) - 1);
    int m = r & ~reset;
    using vecint = T __attribute__((vector_size(32)));
    __m256i identity_vec = _mm256_set1_epi32(identity_element);
    vecint vec_ans = (vecint)identity_vec;
    __m256i indexes = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
    if((l & reset) != 0) {
        // Left part [l; m-1]: lane i of ll_vec holds the i-th ancestor of leaf l-1.
        // Whenever that ancestor is a left child (even index), its right sibling is
        // a red node: ((x & 1) - 1) & (x | 1) is x + 1 for even x and 0 for odd x,
        // and a[0] permanently holds identity_element, so gathering it is a no-op.
        int ll = l - 1 + N;
        int rr = m - 1 + N;
        int modbit = 0;
        int maxmodbit = ffs(ll ^ rr) + 1;
        vecint ll_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(ll), indexes);
        // Each LOOP handles 8 tree levels; unrolled 4 times for up to 32 levels
#define LOOP(content) if(modbit + 8 <= maxmodbit) { \
    vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4)); \
    ll_vec >>= 8; \
    modbit += 8; \
    content \
}
        LOOP(LOOP(LOOP(LOOP())))
#undef LOOP
        // Handle the remaining maxmodbit % 8 levels, masking off the unused lanes
        __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4);
        __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
        vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
    } else {
        // The left part is exactly one subtree
        vec_ans[0] = reduce(vec_ans[0], a[(l + N) >> mbit]);
    }
    if((r & reset) != reset) {
        // Right part [m; r]: symmetric, walking up from leaf r+1. Whenever an
        // ancestor is a right child (odd index), its left sibling is a red node:
        // ~((x & 1) - 1) & (x - 1) is x - 1 for odd x and 0 for even x.
        int ll = m + N;
        int rr = r + 1 + N;
        int modbit = 0;
        int maxmodbit = ffs(ll ^ rr) + 1;
        vecint rr_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(rr), indexes);
#define LOOP(content) if(modbit + 8 <= maxmodbit) { \
    vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4)); \
    rr_vec >>= 8; \
    modbit += 8; \
    content \
}
        LOOP(LOOP(LOOP(LOOP())))
#undef LOOP
        __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4);
        __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
        vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
    } else {
        // The right part is exactly one subtree
        vec_ans[0] = reduce(vec_ans[0], a[(r + N) >> mbit]);
    }
    // Horizontal reduction: vec_ans = 7 6 5 4 3 2 1 0
    __m128i low128 = _mm256_castsi256_si128((__m256i)vec_ans); // 3 2 1 0
    __m128i high128 = _mm256_extractf128_si256((__m256i)vec_ans, 1); // 7 6 5 4
    __m128i ans128 = reduce(low128, high128); // 7+3 6+2 5+1 4+0
    T ans = identity_element;
    for(int i = 0; i < 4; i++) {
        ans = reduce(ans, ((T __attribute__((vector_size(16))))ans128)[i]);
    }
    return ans;
}
int main() {
    // ...fill array from a[N] to a[2*N-1]...
    for(int i = N - 1; i >= 1; i--) {
        a[i] = reduce(a[i * 2], a[i * 2 + 1]);
    }
    a[0] = identity_element; // dummy slot used by query_parallel's masked gathers
    // ...your code here...
    return 0;
}
The following line configures the number of elements in the segment tree. It must be a power of two, so instead of using 1e6, use 1 << 20:
const int N = 1 << 24;
The following line sets the type of the elements. It must be a 32-bit integer, either signed or unsigned, at the moment.
using T = uint32_t;
The following line sets the identity element. It's 0 for sum, -inf for max, and inf for min. If you use unsigned integers, I'd recommend 0 for max and (uint32_t)-1 for min.
const T identity_element = 0;
The following function defines the operation itself: sum, min, max, etc.
T reduce(T a, T b) {
    return a + b;
}
The following two functions are like reduce(T, T), but vectorized: one for 128-bit registers and one for 256-bit registers. There are builtins for add: _mm[256]_add_epi32; min (signed): _mm[256]_min_epi32; max (signed): _mm[256]_max_epi32; min (unsigned): _mm[256]_min_epu32; max (unsigned): _mm[256]_max_epu32. You can check the Intel Intrinsics Guide if you're not sure.
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {
return _mm_add_epi32(a, b);
}
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {
return _mm256_add_epi32(a, b);
}
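For example, to turn the tree into a signed-min RMQ (my sketch; you'd also set T to int32_t and identity_element to INT_MAX from <climits>):
T reduce(T a, T b) {
    return a < b ? a : b; // min
}
__attribute__((target("sse4.1"))) __m128i reduce(__m128i a, __m128i b) {
    return _mm_min_epi32(a, b); // signed 32-bit min, SSE4.1
}
__attribute__((target("avx2"))) __m256i reduce(__m256i a, __m256i b) {
    return _mm256_min_epi32(a, b); // signed 32-bit min, AVX2
}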
Finally, these lines in main() are something you should not touch. They build the segment tree. Make sure to fill the array from a[N] to a[N*2-1] before building it.
for(int i = N - 1; i >= 1; i--) {
    a[i] = reduce(a[i * 2], a[i * 2 + 1]);
}
a[0] = identity_element;
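Putting it together, a hypothetical usage might look like this (the element values and query bounds are made up):
// Fill the leaves, build once, then query; indices are 0-based and inclusive
for(int i = 0; i < N; i++) {
    a[N + i] = i; // your data here
}
for(int i = N - 1; i >= 1; i--) {
    a[i] = reduce(a[i * 2], a[i * 2 + 1]);
}
a[0] = identity_element;
std::cout << query_parallel(3, 10) << "\n"; // reduce() over elements 3..10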
Further work
Implement point update queries in a similar way. This should be very fast for a sum segment tree with point += updates, a min segment tree with point min= updates, and the like.
Unfortunately, a BIT/Fenwick tree cannot be optimized this way; it turns out to be 1.5x slower.
Contributions are welcome.