I keep getting TLE (time limit exceeded) on test case 27, even though my solution should run in O(k^2 log n) per query and uses fast input. I would be very grateful for any help!
#include <bits/stdc++.h>
using namespace std;
// Unsigned integer type used for all element values and aggregates;
// unsigned arithmetic wraps around on overflow instead of being undefined.
using ui = unsigned int;
// Number of aggregate slots kept per node (exponents 0 .. MAX_K - 1).
const int MAX_K = 11;
// Binomial-coefficient table with MAX_K rows; the rows are filled in by solve().
vector<vector<ui>> nck(MAX_K);
// Random generator that supplies node priorities, seeded from the steady clock.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
struct Node {
int sza; ui val;
vector<ui> ans;
uint32_t priority;
Node *ln, *rn;
Node(ui _val) : sza(1), val(_val), ans(MAX_K), priority(rng()), ln(NULL), rn(NULL) {}
};
// Shorthand for a (possibly null) pointer to a treap node.
using pNode = Node*;
// Size of the subtree rooted at t; a null pointer is an empty subtree of size 0.
int sza(pNode t) {
    if (t == nullptr) {
        return 0;
    }
    return t->sza;
}
// Aggregate number k of the subtree rooted at t; an empty subtree contributes 0.
ui ans(pNode t, int k) {
    if (t == nullptr) {
        return 0;
    }
    return t->ans[k];
}
// Recomputes t->sza and every t->ans[k] from t's value and its children.
// Does nothing when t is null.
//
// Let s = size(left) + 1 be t's 1-based position inside its own subtree.
// An element at position p of the right child sits at position p + s in
// t's subtree, and (p + s)^k = sum_{i=0..k} C(k, i) * p^i * s^(k-i), so
//   ans[k] = left.ans[k] + s^k * val
//          + sum_{i=0..k} right.ans[i] * s^(k-i) * C(k, i).
// All arithmetic is in `ui` and wraps on overflow.
//
// This runs on every split/merge step, so the powers of s live in a stack
// array (no heap allocation per call) and the O(MAX_K^2) binomial loop is
// skipped when there is no right child, where every term would be zero.
void pull(pNode t) {
    if (!t) return;

    const int left_size = sza(t->ln);
    t->sza = left_size + sza(t->rn) + 1;

    // pw[j] = s^j.
    const ui s = static_cast<ui>(left_size + 1);
    ui pw[MAX_K];
    pw[0] = 1;
    for (int j = 1; j < MAX_K; j++) {
        pw[j] = pw[j - 1] * s;
    }

    const pNode right = t->rn;
    for (int k = 0; k < MAX_K; k++) {
        ui total = ans(t->ln, k) + pw[k] * t->val;
        if (right) {
            for (int i = 0; i <= k; i++) {
                total += right->ans[i] * pw[k - i] * nck[k][i];
            }
        }
        t->ans[k] = total;
    }
}
// Sift-down on priorities: while one of the children carries a larger
// priority than the current node, swap the two priority values (the nodes
// themselves stay where they are) and continue from that child.
// When both children are larger, the right child is chosen only if its
// priority exceeds the left child's.
void heapify(pNode t) {
    while (t != nullptr) {
        pNode top = t;
        if (t->ln != nullptr && t->ln->priority > top->priority) {
            top = t->ln;
        }
        if (t->rn != nullptr && t->rn->priority > top->priority) {
            top = t->rn;
        }
        if (top == t) {
            break;
        }
        swap(t->priority, top->priority);
        t = top;
    }
}
// Builds a treap over a[0 .. n-1], preserving the element order, and returns
// its root (null when n == 0).
// The middle element becomes the root and both halves are built recursively;
// heapify() then pushes the root's priority down so that the priorities form
// a max-heap, and pull() computes the root's aggregates.
pNode build(int n, ui *a) {
    if (n == 0) {
        return nullptr;
    }
    const int rootIdx = n / 2;
    const int rightCount = n - rootIdx - 1;

    pNode root = new Node(a[rootIdx]);
    root->ln = build(rootIdx, a);
    root->rn = build(rightCount, a + rootIdx + 1);

    heapify(root);
    pull(root);
    return root;
}
// Splits the treap rooted at t by position: nodes whose 0-based position is
// <= pos end up in ln, all remaining nodes in rn. Either output may be null.
// `add` is the number of elements that precede t's subtree in the whole
// sequence; outside callers leave it at its default of 0.
void split(pNode t, pNode &ln, pNode &rn, int pos, int add = 0) {
    if (t == nullptr) {
        rn = nullptr;
        ln = nullptr;
        return;
    }

    // Position of t itself within the whole sequence.
    const int rootPos = add + sza(t->ln);
    if (pos < rootPos) {
        // t belongs to the right part; only its left subtree can still be cut.
        split(t->ln, ln, t->ln, pos, add);
        rn = t;
    } else {
        // t belongs to the left part; continue cutting in its right subtree.
        split(t->rn, t->rn, rn, pos, rootPos + 1);
        ln = t;
    }
    pull(t);
}
// Concatenates two treaps, every element of ln preceding every element of
// rn, and stores the resulting root in t. The root with the strictly larger
// priority stays on top; on a tie rn's root wins.
void merge(pNode &t, pNode ln, pNode rn) {
    if (ln != nullptr && rn != nullptr) {
        if (ln->priority > rn->priority) {
            merge(ln->rn, ln->rn, rn);
            t = ln;
        } else {
            merge(rn->ln, ln, rn->ln);
            t = rn;
        }
    } else {
        // At most one side is non-empty: the result is that side (or null).
        t = (ln != nullptr) ? ln : rn;
    }
    pull(t);
}
// Returns, in wrapping unsigned arithmetic, the sum over the 0-based
// positions p in [l, r] of element[p] * (p - l + 1)^k.
//
// The range is cut out with two splits, its aggregate is read, and the three
// pieces are merged back in the original order.
//
// Returns 0 when the selected range is empty (the middle piece is null) or
// when k lies outside [0, MAX_K); in the latter case the tree is not touched.
//
// NOTE(review): t is received by value, so the root produced by the final
// merge is not handed back to the caller. This relies on the re-merged treap
// having the same root as before - confirm this is acceptable for the
// priority scheme in use.
ui query(pNode t, int l, int r, int k) {
    if (k < 0 || k >= MAX_K) {
        return 0;
    }
    pNode pl, pm, pr;
    split(t, pm, pr, r);      // pm: positions <= r, pr: the rest
    split(pm, pl, t, l - 1);  // pl: positions < l,  t: positions l..r
    const ui res = ans(t, k); // null-safe: an empty range yields 0
    merge(pm, pl, t);
    merge(t, pm, pr);
    return res;
}
// Reads characters from stdin until one of the command letters 'I', 'D',
// 'R' or 'Q' is found and returns it; everything else is skipped.
// Returns '\0' if the input ends before a command letter is seen.
//
// The character is held in an int so that EOF can be told apart from real
// input; storing it in a char made the skip loop spin forever at end of input.
inline char gc() {
    int ch = getchar();
    while (ch != EOF && ch != 'I' && ch != 'D' && ch != 'R' && ch != 'Q') {
        ch = getchar();
    }
    return ch == EOF ? '\0' : static_cast<char>(ch);
}
// Reads the next decimal integer from stdin.
// All non-digit characters in front of the number are skipped; if a '-' is
// among them the result is negated. The first character after the digits is
// consumed as well.
// Returns 0 if the input ends before any digit is found.
//
// The character is held in an int so that EOF can be told apart from real
// input; storing it in a char made the skip loop spin forever at end of input.
inline int readi() {
    int x = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Reads the next unsigned decimal integer from stdin.
// All non-digit characters in front of the number are skipped and the first
// character after the digits is consumed as well. The value is accumulated
// in `ui`, so it wraps on overflow.
// Returns 0 if the input ends before any digit is found.
//
// The character is held in an int so that EOF can be told apart from real
// input; storing it in a char made the skip loop spin forever at end of input.
inline ui readu() {
    ui x = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * static_cast<ui>(10) + static_cast<ui>(ch - '0');
        ch = getchar();
    }
    return x;
}
void solve(int tc = 0) {
for (int i = 0; i < MAX_K; i++) {
nck[i].push_back(1);
for (int j = 1; j <= i; j++) {
nck[i].push_back(i == 0 ? 1 : nck[i - 1][j - 1] + nck[i - 1][j]);
}
}
int n = readi();
vector<ui> a(n);
for (ui &x : a) x = readu();
auto tr = build(n, &a[0]);
int q = readi();
while (q--) {
char c = gc(); int i = readi();
if (c == 'I') {
ui val = readu();
pNode pr;
split(tr, tr, pr, i - 1);
merge(tr, tr, new Node(val));
merge(tr, tr, pr);
} else if (c == 'D') {
pNode pm, pr;
split(tr, tr, pr, i);
split(tr, tr, pm, i - 1);
merge(tr, tr, pr);
} else if (c == 'R') {
ui val = readu();
pNode pm, pr;
split(tr, tr, pr, i);
split(tr, tr, pm, i - 1);
merge(tr, tr, new Node(val));
merge(tr, tr, pr);
} else {
int j = readi();
int k = readi();
cout << query(tr, i, j, k) << "\n";
}
}
}
// Entry point: unties the C++ streams from C stdio, then runs solve() once
// per test case. The number of test cases is fixed at 1; reading it from the
// input and printing a "Case #" prefix are both disabled.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    const int tc = 1;
    int t = 1;
    while (t <= tc) {
        solve(t);
        ++t;
    }
}