I was recently up-solving AtCoder Beginner Contest 339 and came across Problem G — Smaller Sum, which can be solved with a persistent segment tree.
While implementing my own persistent segment tree, I tried to make it as generic as possible.
Code:
#include <bits/stdc++.h>
using namespace std;
/**
 * Bump allocator over a single static buffer.
 *
 * Adapted from KACTL:
 * https://github.com/kth-competitive-programming/kactl/blob/main/content/various/BumpAllocator.h
 *
 * Use it when you need to dynamically allocate many objects and don't care
 * about freeing them; "new X" otherwise has an overhead of something like
 * 0.05us + 16 bytes per allocation. Memory is handed out from the end of
 * `buf` towards its start and is never returned individually.
 *
 * All Alloc instances share ONE cursor into `buf`, so several objects that
 * each own an Alloc (e.g. several PSegTree instances) never receive
 * overlapping storage. Consequently reset() releases the whole buffer for
 * every instance at once.
 */
const size_t SZ = (450 << 20); // 450 MiB
alignas(alignof(std::max_align_t)) static char buf[SZ];
class Alloc {
private:
  // Offset of the lowest byte handed out so far; starts at the end of buf.
  // Function-local static so that every Alloc instance shares it.
  static size_t &cursor() {
    static size_t pos = sizeof(buf);
    return pos;
  }

public:
  Alloc() = default;

  /**
   * Returns `s` bytes of uninitialized storage aligned to `align`.
   * @param s     number of bytes
   * @param align power of two <= alignof(std::max_align_t); when 0 (default)
   *              the largest power of two dividing `s` is used (capped at
   *              alignof(std::max_align_t)). That is always enough for an
   *              object whose sizeof is `s`, because sizeof(T) is a multiple
   *              of alignof(T).
   * The caller must construct an object in the storage (e.g. placement new)
   * before using it as one.
   */
  void *alloc(size_t s, size_t align = 0) {
    const size_t max_align = alignof(std::max_align_t);
    if (align == 0) {
      align = s & (~s + 1); // lowest set bit of s
      if (align == 0 || align > max_align)
        align = max_align;
    }
    assert(align <= max_align && (align & (align - 1)) == 0);
    size_t &ptr = cursor();
    assert(s <= ptr);
    size_t p = ptr - s;
    p -= p % align; // buf is max-aligned, so an aligned offset is an aligned address
    ptr = p;
    return static_cast<void *>(buf + p);
  }

  /** Makes the whole buffer available again (for every Alloc instance). */
  void reset() { cursor() = sizeof(buf); }
};
/**
 * One node of a pointer-based persistent segment tree.
 * `info` is the value stored at this node; `left` and `right` point to its
 * children.
 */
template <class Info>
class Node {
public:
  Info info;   // aggregated value for this node
  Node *left;  // left child
  Node *right; // right child
};
template <typename Info> class PSegTree {
public:
typedef Node<Info> node_t;
node_t *root;
vector<node_t *> time;
int n;
Alloc ar;
PSegTree(size_t sz) : PSegTree(vector<Info>(sz, Info())) {}
PSegTree(const vector<Info> &info) {
root = nullptr;
ar = Alloc();
n = (int)info.size();
function<node_t *(int, int, int)> build = [&](int p, int l,
int r) -> node_t * {
if (l == r) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = info[l];
node->left = node->right = nullptr;
return node;
}
int m = l + (r - l) / 2;
return pull(build(2 * p + 1, l, m), build(2 * p + 2, m + 1, r));
};
root = build(0, 0, n - 1);
time.push_back(root);
}
node_t *pull(node_t *left, node_t *right) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = (left->info + right->info);
node->left = left, node->right = right;
return node;
}
void modify(int p, const Info &v) {
function<node_t *(node_t *, int, int)> _modify = [&](node_t *c, int l,
int r) -> node_t * {
if (l == r) {
node_t *node = (node_t *)ar.alloc(sizeof(node_t));
node->info = v;
node->left = node->right = nullptr;
return node;
}
int m = l + (r - l) / 2;
node_t *left = c->left, *right = c->right;
if (p <= m) {
left = _modify(left, l, m);
} else {
right = _modify(right, m + 1, r);
}
return pull(left, right);
};
root = _modify(root, 0, n - 1);
time.push_back(root);
}
Info rangeQuery(int t, int x, int y) {
function<Info(node_t *, int, int)> query = [&](node_t *c, int l,
int r) -> Info {
if (y < l or r < x or c == nullptr) {
return Info();
}
if (x <= l and r <= y) {
return c->info;
}
int m = l + (r - l) / 2;
return query(c->left, l, m) + query(c->right, m + 1, r);
};
return query(time[t], 0, n - 1);
}
};
/** Additive monoid over 64-bit integers: the identity is Sum() == 0, combine with +. */
class Sum {
public:
  int64_t x = 0; // accumulated value

  // Serves both as the default constructor (identity element) and as an
  // implicit conversion from int64_t.
  Sum(int64_t value = 0) : x(value) {}
};

/** Adds two sums. */
Sum operator+(const Sum &lf, const Sum &rt) {
  Sum combined(lf);
  combined.x += rt.x;
  return combined;
}
/**
 * Solves ABC339 G (Smaller Sum) by reading stdin and writing stdout.
 *
 * Version i of the persistent tree holds, at each value's rank, the total of
 * that value over A[0..i-1]. The answer for (l, r, x) is the rank-prefix sum
 * of values <= x in version r minus the same prefix in version l - 1.
 * Queries are XOR-encoded with the previous answer.
 */
void solve() {
  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (int &value : A)
    std::cin >> value;

  // Coordinate compression: rank 0 is the smallest distinct value. The map is
  // ordered descending, so lower_bound(x) finds the largest value <= x.
  std::vector<int> ascending(A);
  std::sort(ascending.begin(), ascending.end());
  std::map<int, int, std::greater<int>> id;
  for (int value : ascending)
    id.emplace(value, (int)id.size()); // no-op for repeated values

  PSegTree<Sum> seg(id.size());
  for (int i = 0; i < N; ++i) {
    const int rank = id[A[i]];
    const std::int64_t current = seg.rangeQuery(i, rank, rank).x;
    seg.modify(rank, Sum(current + A[i])); // creates version i + 1
  }

  int Q;
  std::cin >> Q;
  std::int64_t last = 0;
  for (int q = 0; q < Q; ++q) {
    std::int64_t l, r, x;
    std::cin >> l >> r >> x;
    l ^= last;
    r ^= last;
    x ^= last;

    std::int64_t answer = 0;
    if (x != 0) {
      auto it = id.lower_bound(x);
      if (it != id.end()) {
        const int rank = it->second;
        answer = seg.rangeQuery(r, 0, rank).x - seg.rangeQuery(l - 1, 0, rank).x;
      }
    }
    last = answer;
    std::cout << last << '\n';
  }
}
/** Entry point: switches to fast stream I/O, then runs the solver. */
int main() {
  // Decouple C++ streams from C stdio and untie cin from cout for bulk I/O.
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);
  solve();
}
Here are a few things I would like your help and opinion on:
- Is there a more generic and efficient way to implement a persistent segment tree, and how can the above code be improved?
- How can you tell whether a problem calls for a persistent segment tree?
- Is there a better way to implement a custom allocator for competitive programming (or for C/C++ projects in general)?
- How do you implement a lazy persistent segment tree?
- How should I add documentation to code in my personal library?
Update: I rewrote the persistent segment tree with documentation and without the custom allocator. Thanks to lrvideckis for the help.
Code:
#include <bits/stdc++.h>
using namespace std;
/**
 * Persistent segment tree over positions [0, n - 1], stored in parallel arrays.
 *
 * Version 0 is the tree built by the constructor. Every modify()/modifyTime()
 * call appends exactly one new version (its root is pushed onto `time`). The
 * new version shares every untouched node with the version it was derived
 * from, so all earlier versions remain queryable.
 *
 * Requirements on Info:
 *   - Info() is the identity of operator+ (it is returned for empty ranges);
 *   - operator+(const Info &, const Info &) is associative.
 *
 * Node c stores info[c] and its children left[c] / right[c]; -1 means "none"
 * (leaves have no children). 8 * n node slots are preallocated; the arrays
 * grow on demand once those are used up.
 */
template <typename Info> class PSegTree {
public:
  int root;                           // root of the most recently created version (-1 if n == 0)
  std::vector<Info> info;             // info[c]: aggregate over node c's range
  std::vector<int> time, left, right; // time[t]: root of version t; left/right: children (-1 = none)
  int n, index, size;                 // n: positions; index: node slots used; size: node slots allocated

  /**
   * @brief Builds version 0 over [0, sz - 1] with every position set to Info().
   * @param sz number of positions
   * @time O(sz)
   */
  PSegTree(size_t sz) : PSegTree(std::vector<Info>(sz, Info())) {}

  /**
   * @brief Builds version 0 over [0, a.size() - 1] with position i set to a[i].
   * @param a initial values
   * @time O(n); the build creates 2n - 1 nodes
   * @space 8n node slots are preallocated
   */
  PSegTree(const std::vector<Info> &a) {
    root = -1;
    index = 0;
    n = (int)a.size();
    size = 8 * n;
    info.assign(size, Info());
    left.assign(size, -1);
    right.assign(size, -1);
    if (n > 0) // an empty range has no nodes; recursing on [0, -1] never terminates
      root = build(a, 0, n - 1);
    time.push_back(root);
  }

  /**
   * @brief Creates a childless node holding v.
   * @return index of the new node
   * @time amortized O(1)
   */
  int add_leaf(const Info &v) { return new_node(v, -1, -1); }

  /**
   * @brief Creates a parent of left_idx and right_idx holding their combined info.
   * @return index of the new node
   * @time amortized O(1)
   */
  int pull(int left_idx, int right_idx) {
    return new_node(info[left_idx] + info[right_idx], left_idx, right_idx);
  }

  /**
   * @brief Sets a[p] = v on the latest version (the one `root` refers to) and
   *        records the result as a new version.
   * @param p position in [0, n - 1]
   * @param v new value
   * @time O(log n)
   * @space O(log n) new nodes
   */
  void modify(int p, const Info &v) { commit(root, p, v); }

  /**
   * @brief Sets a[p] = v on version t and records the result as a new version
   *        (appended at the end of `time`).
   * @param t source version in [0, time.size() - 1]
   * @param p position in [0, n - 1]
   * @param v new value
   * @time O(log n)
   * @space O(log n) new nodes
   */
  void modifyTime(int t, int p, const Info &v) {
    assert(0 <= t && t < (int)time.size());
    commit(time[t], p, v);
  }

  /**
   * @brief Combined info a[x] + a[x + 1] + ... + a[y] on version t.
   * @param t version in [0, time.size() - 1]
   * @param x, y inclusive range; positions outside [0, n - 1] contribute nothing
   * @return Info() if the range is empty
   * @time O(log n)
   */
  Info rangeQuery(int t, int x, int y) {
    assert(0 <= t && t < (int)time.size());
    return query(time[t], 0, n - 1, x, y);
  }

private:
  // Stores a node, reusing a preallocated slot while any is free, and returns its index.
  int new_node(const Info &v, int l, int r) {
    if (index < size) {
      info[index] = v;
      left[index] = l;
      right[index] = r;
    } else {
      assert(index == size && info.size() == left.size() && left.size() == right.size());
      info.push_back(v);
      left.push_back(l);
      right.push_back(r);
      ++size;
    }
    return index++;
  }

  // Builds the subtree for [l, r] from a and returns its root.
  int build(const std::vector<Info> &a, int l, int r) {
    if (l == r)
      return add_leaf(a[l]);
    int m = l + (r - l) / 2;
    int lc = build(a, l, m);
    int rc = build(a, m + 1, r);
    return pull(lc, rc);
  }

  // Shared by modify() and modifyTime(): path-copies from `base` with a[p] = v
  // and records the new root as the latest version.
  void commit(int base, int p, const Info &v) {
    assert(0 <= p && p < n);
    root = update(base, 0, n - 1, p, v);
    time.push_back(root);
  }

  // Returns the root of a copy of subtree c (range [l, r]) with position p set to v.
  int update(int c, int l, int r, int p, const Info &v) {
    if (l == r)
      return add_leaf(v);
    int m = l + (r - l) / 2;
    int lc = left[c], rc = right[c];
    if (p <= m)
      lc = update(lc, l, m, p, v);
    else
      rc = update(rc, m + 1, r, p, v);
    return pull(lc, rc);
  }

  // Combined info over the positions of [x, y] that lie in subtree c (range [l, r]).
  Info query(int c, int l, int r, int x, int y) const {
    if (c == -1 || y < l || r < x)
      return Info();
    if (x <= l && r <= y)
      return info[c];
    int m = l + (r - l) / 2;
    return query(left[c], l, m, x, y) + query(right[c], m + 1, r, x, y);
  }
};
/** Additive monoid over 64-bit integers: the identity is Sum() == 0, combine with +. */
class Sum {
public:
  int64_t x = 0; // accumulated value

  // Serves both as the default constructor (identity element) and as an
  // implicit conversion from int64_t.
  Sum(int64_t value = 0) : x(value) {}
};

/** Adds two sums. */
Sum operator+(const Sum &lf, const Sum &rt) {
  Sum combined(lf);
  combined.x += rt.x;
  return combined;
}
/**
 * Solves ABC339 G (Smaller Sum) by reading stdin and writing stdout.
 *
 * Version i of the persistent tree holds, at each value's rank, the total of
 * that value over A[0..i-1]. The answer for (l, r, x) is the rank-prefix sum
 * of values <= x in version r minus the same prefix in version l - 1.
 * Queries are XOR-encoded with the previous answer.
 */
int main() {
  // Decouple C++ streams from C stdio and untie cin from cout for bulk I/O.
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int N;
  std::cin >> N;
  std::vector<int> A(N);
  for (int &value : A)
    std::cin >> value;

  // Coordinate compression: rank 0 is the smallest distinct value. The map is
  // ordered descending, so lower_bound(x) finds the largest value <= x.
  std::vector<int> ascending(A);
  std::sort(ascending.begin(), ascending.end());
  std::map<int, int, std::greater<int>> id;
  for (int value : ascending)
    id.emplace(value, (int)id.size()); // no-op for repeated values

  PSegTree<Sum> seg(id.size());
  for (int i = 0; i < N; ++i) {
    const int rank = id[A[i]];
    const std::int64_t current = seg.rangeQuery(i, rank, rank).x;
    seg.modifyTime(i, rank, Sum(current + A[i])); // derives version i + 1 from version i
  }

  int Q;
  std::cin >> Q;
  std::int64_t last = 0;
  for (int q = 0; q < Q; ++q) {
    std::int64_t l, r, x;
    std::cin >> l >> r >> x;
    l ^= last;
    r ^= last;
    x ^= last;

    std::int64_t answer = 0;
    if (x != 0) {
      auto it = id.lower_bound(x);
      if (it != id.end()) {
        const int rank = it->second;
        answer = seg.rangeQuery(r, 0, rank).x - seg.rangeQuery(l - 1, 0, rank).x;
      }
    }
    last = answer;
    std::cout << last << '\n';
  }
  return 0;
}