1 / 1
Nov 2021

I kept getting TLE (time limit exceeded) on test case 27, even though my solution should run in O(k^2 log n) per query and uses fast input. I'd be very grateful for any help!

#include <bits/stdc++.h>
 
using namespace std;

// Value type: all sums use unsigned int arithmetic, so they wrap around
// instead of overflowing.
typedef unsigned int ui;

// Exponents k = 0 .. MAX_K - 1 are supported.
constexpr int MAX_K = 11;

// Binomial table, nck[i][j] = C(i, j); the rows are filled in by solve().
std::vector<std::vector<ui>> nck(MAX_K);

// Random source for treap priorities, seeded from the steady clock.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());
 
struct Node {
    int sza; ui val;
    vector<ui> ans;
    uint32_t priority;
    Node *ln, *rn;
 
    Node(ui _val) : sza(1), val(_val), ans(MAX_K), priority(rng()), ln(NULL), rn(NULL) {}
};
using pNode = Node*;
 
// Size of the subtree rooted at t; an empty (null) tree has size 0.
int sza(pNode t) {
    if (t == nullptr) {
        return 0;
    }
    return t->sza;
}
 
// The k-th power-sum aggregate of subtree t; an empty (null) tree contributes 0.
ui ans(pNode t, int k) {
    if (t == nullptr) {
        return 0;
    }
    return t->ans[k];
}
 
// Recompute t's subtree size and power sums from its children.
// ans[k] = sum over the subtree of value * (1-based in-order rank)^k, with
// unsigned wraparound. Runs in O(MAX_K^2) with no heap allocation; this is the
// hottest function, because every split/merge step calls it.
void pull(pNode t) {
    if (!t) return;
    const pNode l = t->ln;
    const pNode r = t->rn;
    const int left = sza(l);
    t->sza = left + sza(r) + 1;

    // pw[e] = (left + 1)^e: t's own rank inside its subtree is left + 1.
    ui pw[MAX_K];
    pw[0] = 1;
    const ui rank = static_cast<ui>(left + 1);
    for (int e = 1; e < MAX_K; e++) {
        pw[e] = pw[e - 1] * rank;
    }

    for (int k = 0; k < MAX_K; k++) {
        // Left elements keep their ranks; t contributes val * rank^k.
        ui s = ans(l, k) + pw[k] * t->val;
        if (r) {
            // A right element with local rank x has rank x + (left + 1) in t, so
            // expand (x + (left + 1))^k = sum_i C(k, i) * x^i * (left + 1)^(k - i).
            for (int i = 0; i <= k; i++) {
                s += r->ans[i] * pw[k - i] * nck[k][i];
            }
        }
        t->ans[k] = s;
    }
}
 
// Sift the priority stored at t downward: while some child has a strictly
// larger priority, swap priorities with the larger one (the left child wins
// ties) and continue from that child. Only priority fields move; node
// positions, values and sizes are untouched.
void heapify(pNode t) {
    for (pNode cur = t; cur != nullptr; ) {
        pNode top = cur;
        if (cur->ln != nullptr && cur->ln->priority > top->priority) {
            top = cur->ln;
        }
        if (cur->rn != nullptr && cur->rn->priority > top->priority) {
            top = cur->rn;
        }
        if (top == cur) {
            break;
        }
        std::swap(cur->priority, top->priority);
        cur = top;
    }
}
 
// Build a treap over a[0 .. n-1], keeping that order, by rooting at the middle
// element and recursing on both halves. Afterwards the root's priority is
// sifted down so the heap order holds, and its aggregates are recomputed.
// Nodes are allocated with new.
pNode build(int n, ui *a) {
    if (!n) {
        return nullptr;
    }
    const int half = n / 2;
    pNode root = new Node(a[half]);
    root->ln = build(half, a);
    root->rn = build(n - half - 1, a + half + 1);
    heapify(root);
    pull(root);
    return root;
}
 
// Split the treap rooted at t by position: `ln` receives every element whose
// 0-based position is <= pos, and `rn` receives the rest. For example,
// split(t, l, r, i - 1) leaves the first i elements in l.
// `add` is the number of elements that precede subtree t in the sequence being
// split; top-level callers leave it at 0.
// t is taken by value, so a caller may pass the same variable as t and ln/rn.
void split(pNode t, pNode &ln, pNode &rn, int pos, int add = 0) {
    if (!t) {
        ln = rn = NULL;
        return;
    }
    // 0-based position of t itself within the sequence being split.
    int curPos = add + sza(t->ln);
    if (pos < curPos)
        // t and its right subtree go right; split the left subtree further.
        split(t->ln, ln, t->ln, pos, add), rn = t;
    else
        // t and its left subtree go left; split the right subtree further,
        // whose first element sits at position curPos + 1.
        split(t->rn, t->rn, rn, pos, curPos + 1), ln = t;
    // A child pointer of t changed, so refresh its size and power sums.
    pull(t);
}
 
// Concatenate treaps ln and rn into t; every element of ln ends up before every
// element of rn. The root with the larger priority stays on top; on a tie, rn's
// root stays on top. ln and rn are taken by value, so t may alias either one in
// the caller.
void merge(pNode &t, pNode ln, pNode rn) {
    if (!ln || !rn)
        t = ln ? ln : rn;
    else if (ln->priority > rn->priority)
        // ln remains the root; merge rn into ln's right subtree.
        merge(ln->rn, ln->rn, rn), t = ln;
    else
        // rn remains the root; merge ln into rn's left subtree.
        merge(rn->ln, ln, rn->ln), t = rn;
    // Refresh size and power sums of the resulting root (no-op if t is null).
    pull(t);
}
 
// Helper for query(): sum of value_j * (j - l + 1)^k over the positions j in
// [l, r] that lie in subtree t, where `base` is the 0-based position of t's
// leftmost element in the whole sequence. Read-only: the tree is not modified.
// Only O(depth) nodes are visited, and each costs O(k).
static ui query_range(pNode t, int base, int l, int r, int k) {
    if (!t) return 0;
    const int last = base + t->sza - 1;
    if (last < l || base > r) return 0;  // subtree disjoint from [l, r]
    if (l <= base && last <= r) {
        // Whole subtree inside the range. An element of local rank x (1-based)
        // has weight (x + shift) with shift = base - l >= 0, so
        //   sum v * (x + shift)^k = sum_i C(k, i) * shift^(k - i) * ans[i].
        const ui shift = static_cast<ui>(base - l);
        ui res = 0;
        ui p = 1;  // shift^(k - i)
        for (int i = k; i >= 0; i--) {
            res += nck[k][i] * p * t->ans[i];
            p *= shift;
        }
        return res;
    }
    const int pos = base + sza(t->ln);  // global position of t itself
    ui res = query_range(t->ln, base, l, r, k) + query_range(t->rn, pos + 1, l, r, k);
    if (l <= pos && pos <= r) {
        const ui rank = static_cast<ui>(pos - l + 1);
        ui w = 1;
        for (int i = 0; i < k; i++) {
            w *= rank;
        }
        res += t->val * w;
    }
    return res;
}

// Sum of a[j] * (j - l + 1)^k over 0-based positions l <= j <= r, with unsigned
// wraparound; 0 for an empty range. Unlike a split/merge round trip, this only
// reads the tree: it does no O(k^2) pulls, and it cannot change which node is
// the root (t is passed by value, so a restructured root would not reach the
// caller). Requires 0 <= k < MAX_K.
ui query(pNode t, int l, int r, int k) {
    return query_range(t, 0, l, r, k);
}
 
// Skip input until one of the command letters 'I', 'D', 'R', 'Q' is read, and
// return it. Returns '\0' if end of input is reached first, instead of spinning.
inline char gc(){
    int ch = getchar();  // int, not char: keeps EOF distinct from every byte value
    while (ch != EOF && ch != 'I' && ch != 'D' && ch != 'R' && ch != 'Q') ch = getchar();
    return ch == EOF ? '\0' : static_cast<char>(ch);
}
// Read the next decimal integer from stdin. Every character before the first
// digit is skipped; if any skipped character is '-', the result is negated.
// The character right after the last digit is consumed.
// Stops at end of input and returns 0 if no digit was seen.
inline int readi() {
    int x = 0, f = 1;
    int ch = getchar();  // int, not char: keeps EOF distinct from every byte value
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Read the next unsigned decimal integer from stdin, skipping every non-digit
// character before it (a '-' is simply skipped). The character right after
// the number is consumed; the value wraps with unsigned arithmetic.
// Stops at end of input and returns 0 if no digit was seen.
inline ui readu(){
    ui x = 0;
    int ch = getchar();  // int, not char: keeps EOF distinct from every byte value
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * static_cast<ui>(10) + static_cast<ui>(ch - '0');
        ch = getchar();
    }
    return x;
}
 
void solve(int tc = 0) {
    for (int i = 0; i < MAX_K; i++) {
        nck[i].push_back(1);
        for (int j = 1; j <= i; j++) {
            nck[i].push_back(i == 0 ? 1 : nck[i - 1][j - 1] + nck[i - 1][j]);
        }
    }
 
    int n = readi();
    vector<ui> a(n);
    for (ui &x : a) x = readu();
    auto tr = build(n, &a[0]);
    int q = readi();
    while (q--) {
        char c = gc(); int i = readi();
        if (c == 'I') {
            ui val = readu();
            pNode pr;
            split(tr, tr, pr, i - 1);
            merge(tr, tr, new Node(val));
            merge(tr, tr, pr);
        } else if (c == 'D') {
            pNode pm, pr;
            split(tr, tr, pr, i);
            split(tr, tr, pm, i - 1);
            merge(tr, tr, pr);
        } else if (c == 'R') {
            ui val = readu();
            pNode pm, pr;
            split(tr, tr, pr, i);
            split(tr, tr, pm, i - 1);
            merge(tr, tr, new Node(val));
            merge(tr, tr, pr);
        } else {
            int j = readi();
            int k = readi();
            cout << query(tr, i, j, k) << "\n";
        }
    }
}
 
int main() {
    // Unsync and untie the C++ streams for faster cout output; this program
    // reads its input through getchar().
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    const int tc = 1;  // single test case; a test count could be read here instead
    for (int t = 1; t <= tc; ++t) {
        solve(t);
    }
    return 0;
}