Hello codeforcers.
I'm learning segment trees, and for now I'm focused on "standard" range query/point update segment trees. I've tried to solve CSES 1144 : Salary Queries, but I keep getting TLE.
I've written a simple segment tree struct, and I'm using coordinate compression. I suppose that I should find a constant factor optimization to pass the time limit. Is there a generic optimization I should include in my implementation to make it faster ?
I guess I could consider using more advanced segment tree variants, but I thought a generic segment tree should be enough here. Am I wrong?
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef vector<long long> vll;
// Loop helpers: forn counts up inclusively, rofn counts down inclusively.
#define forn(i,j,k) for(ll i=(j); i<=(k); i++)
#define rofn(i,j,k) for(ll i=(j); i>=(k); i--)
#define forv(b,a) for(auto &b : a)
// NOTE: '\n' instead of endl. std::endl flushes the stream on every call,
// which is a large constant-factor cost (and a classic cause of TLE) when
// printing up to 2e5 answers; plain '\n' produces identical output bytes.
#define yon(t) cout << (t ? "YES" : "NO") << '\n';
#define out(x) cout << x << " ";
#define outln(x) cout << x << '\n';
#define in(x) cin >> x;
// Speed up iostream: stop syncing with C stdio and untie cin from cout.
void fastIO() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
// Generic segment tree over a monoid: `combine` merges two values and
// `defaultValue()` is its identity. Positions are 1-based (1.._n).
// The root is node 0; the children of node i are 2i+1 and 2i+2.
template<typename T, T (*combine)(T, T), T (*defaultValue)()>
struct SegmentTree {
public:
    // Tree of n positions, all initialized to the identity element.
    SegmentTree(int n) : SegmentTree(std::vector<T>(n+1, defaultValue())) {}
    // Tree built from v[1..v.size()-1]; v[0] is ignored.
    SegmentTree(const std::vector<T> &v) : _n(v.size()-1) {
        _cap = 1;
        while(_cap < _n) _cap <<= 1;          // round capacity up to a power of two
        tree.assign(2*_cap, defaultValue());
        build(v, 0, 1, _n);
    }
    // Value stored at a single position.
    T get(int pos) {
        assert(1 <= pos && pos <= _n);
        return query(pos, pos);
    }
    // Combined value over the inclusive range [start, end].
    T query(int start, int end) {
        assert(1 <= start && start <= end && end <= _n);
        return query(0, start, end, 1, _n);
    }
    // Recursive worker: node covers [left, right], the query is [start, end].
    T query(int node, int start, int end, int left, int right) {
        if(start <= left && right <= end) return tree[node];   // fully covered
        int mid = left + (right - left) / 2;
        T res = defaultValue();
        if(start <= mid) res = combine(res, query(2*node+1, start, end, left, mid));
        if(mid < end)    res = combine(res, query(2*node+2, start, end, mid+1, right));
        return res;
    }
    // Point assignment: position pos takes value val.
    void set(int pos, T val) {
        assert(1 <= pos && pos <= _n);
        set(pos, val, 0, 1, _n);
    }
    // Recursive worker for set; recombines ancestors on the way back up.
    void set(int pos, T val, int node, int left, int right) {
        if(left == right) {
            tree[node] = val;
            return;
        }
        int mid = left + (right - left) / 2;
        if(pos <= mid) set(pos, val, 2*node+1, left, mid);
        else           set(pos, val, 2*node+2, mid+1, right);
        tree[node] = combine(tree[2*node+1], tree[2*node+2]);
    }
    // For sum trees: first position whose prefix sum reaches k
    // (0 if the total is below k). Call as find(k, 0, 1, _n).
    // NOTE: assumes the monoid is addition-like (compares/subtracts k).
    T find(long long k, int node, int left, int right) {
        if(tree[node] < k) return 0;          // not enough in this subtree
        if(left == right) return left;
        int mid = left + (right - left) / 2;
        T leftPart = tree[2*node+1];
        return (leftPart >= k) ? find(k, 2*node+1, left, mid)
                               : find(k - leftPart, 2*node+2, mid+1, right);
    }
private:
    std::vector<T> tree;                      // implicit binary-heap layout
    int _n, _cap;
    // Fill tree[node] for the segment [left, right] from the source array.
    void build(const std::vector<T> &v, int node, int left, int right) {
        if(left == right) {
            tree[node] = v[left];
            return;
        }
        int mid = left + (right - left) / 2;
        build(v, 2*node+1, left, mid);
        build(v, 2*node+2, mid+1, right);
        tree[node] = combine(tree[2*node+1], tree[2*node+2]);
    }
};
// Sum monoid used by main(): plain addition with identity 0.
long long combine(long long a, long long b) { return a + b; }
long long defaultValue() { return 0; }
// Hash functor for unordered containers, hardened against adversarial
// (anti-hash) inputs by mixing in a per-run random offset.
struct custom_hash {
    // splitmix64 finalizer: a strong, invertible 64-bit bit mixer.
    // Reference: http://xorshift.di.unimi.it/splitmix64.c
    static uint64_t splitmix64(uint64_t z) {
        z += 0x9e3779b97f4a7c15ULL;
        z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
        z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
        return z ^ (z >> 31);
    }
    size_t operator()(uint64_t key) const {
        // Chosen once per process; makes collision-crafting infeasible.
        static const uint64_t FIXED_RANDOM = std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + FIXED_RANDOM);
    }
};
int main() {
fastIO();
ll n, q;
in(n) in(q)
vll v(n+1);
set<ll> values;
forn(i,1,n) {
in(v[i]);
values.insert(v[i]);
}
vector<pair<bool, pair<ll,ll>>> queries(q+1);
char c; ll a, b;
forn(i,1,q) {
in(c) in(a) in(b)
if(c=='?') {
queries[i] = {true, {a, b}};
values.insert(a);
values.insert(b);
} else {
queries[i] = {false, {a, b}};
values.insert(b);
}
}
unordered_map<ll, int, custom_hash> valueToIndex;
int index = 1;
for(const ll &value : values) valueToIndex[value] = index++;
SegmentTree<ll, combine, defaultValue> seg(index-1);
forn(i,1,n) seg.set(valueToIndex[v[i]], seg.get(valueToIndex[v[i]])+1);
forn(i,1,q) {
const auto &query = queries[i];
a = query.second.first;
b = query.second.second;
if(query.first) {
outln(seg.query(valueToIndex[a], valueToIndex[b]));
} else {
seg.set(valueToIndex[v[a]], seg.get(valueToIndex[v[a]])-1);
v[a] = b;
seg.set(valueToIndex[v[a]], seg.get(valueToIndex[v[a]])+1);
}
}
return 0;
}
I don't think you are supposed to use segment tree here. Have you tried using Order Statistics Tree? Basically you need a DS that supports finding how many indices below a certain point (like
lower_bound
). Also, I saw that each node in your segtree holds the sum of its children. I don't see how this info helps solve the problem.

I'm indeed using a segment tree where each node holds the sum of its children. The leaves represent a frequency array of salaries (compressed to only consider the salaries actually used in the queries).
I haven't tried Order Statistics Tree, I will look into it.
fbrunodr's solution can be implemented using a PBDS that supports that type of query. I solved it here for your reference. I am also adding basic PBDS operations below for you to look up:
Yeah, you can also implement a BST yourself or use a Fenwick tree (although with a Fenwick tree you would have to solve the problem offline, compressing the salary values first).
You can use a policy-based data structure that supports these operations in $$$O(\log N)$$$, as explained below by fellow programmers.
Here is my solution https://cses.fi/paste/409056f70f117a41b33be3/