1 / 1
May 2024

I am trying to solve the problem Gao on a Tree (https://www.spoj.com/problems/GOT/) using an Euler tour and a segment tree, but I'm getting TLE (time limit exceeded). Can someone help me understand how I can optimise it?

#include <bits/stdc++.h>

using namespace std;

// shortforms
// Token-level shorthands used throughout the file. They are plain textual
// macros: the function-like ones (all, srt, rev, lb, ub, cpy, maxv, minv)
// paste their container argument twice, so an argument with side effects is
// evaluated twice.

#define pb push_back
#define pf push_front
#define ppb pop_back
#define ppf pop_front
#define all(x) (x).begin(),(x).end()
#define srt(v) sort(v.begin(),v.end())
#define rev(v) reverse(v.begin(),v.end())
#define lb(v,x) lower_bound(v.begin(),v.end(),x)
#define ub(v,x) upper_bound(v.begin(),v.end(),x)
#define cpy(v2,v1) v2.assign(v1.begin(),v1.end())
#define maxv(a) *max_element(a.begin(), a.end())
#define minv(a) *min_element(a.begin(), a.end())
// Object-like macros: every later identifier spelled `ff` or `ss` is rewritten
// to `first` / `second`, so neither name can be used for variables.
#define ff first
#define ss second
// `<< endl` becomes a plain newline, so it never flushes the stream.
// NOTE(review): after this line `std::endl` no longer compiles (it expands
// to `std::"\n"`).
#define endl "\n"

//type definitions (short aliases used throughout the file)

using ll = long long;                 // default integer type
using ull = unsigned long long;
using lld = long double;              // extended-precision floating point
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using vll = std::vector<ll>;

//constants

constexpr long long inf = 1000000000000000000LL; // 10^18
constexpr int mod = 1000000007;                  // the prime 10^9 + 7
#define pi 3.141592653589793238462

// Hash functor for 64-bit integer keys (e.g. for unordered_map). Each key is
// offset by a value taken from the clock once per run and then scrambled with
// the splitmix64 finalizer — presumably so adversarial inputs cannot target a
// fixed hash function.
struct custom_hash {
    // splitmix64 mixing step (http://xorshift.di.unimi.it/splitmix64.c).
    static uint64_t splitmix64(uint64_t x) {
        constexpr uint64_t kGamma = 0x9e3779b97f4a7c15ULL;
        constexpr uint64_t kMul1 = 0xbf58476d1ce4e5b9ULL;
        constexpr uint64_t kMul2 = 0x94d049bb133111ebULL;
        uint64_t z = x + kGamma;
        z ^= z >> 30;
        z *= kMul1;
        z ^= z >> 27;
        z *= kMul2;
        z ^= z >> 31;
        return z;
    }

    size_t operator()(uint64_t x) const {
        // Function-local static: initialized once, on the first call.
        static const uint64_t offset = std::chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(offset + x);
    }
};

//graphs

//const int N=1e5+2;
//std::vector<int> vis(N,0);
//std::vector<int> adj[N];

// debugger

// debug(x) writes "x <value>" and a newline to cerr via the _print overloads.
// It compiles to an empty statement when ONLINE_JUDGE is defined.
// The do { ... } while (0) wrapper turns the macro into a single statement.
// As a result, `if (cond) debug(x);` guards all of it, and an `else` may
// follow. With the former bare three-statement expansion, only the first
// statement was conditional.
#ifndef ONLINE_JUDGE
#define debug(x) do { cerr << #x << " "; _print(x); cerr << endl; } while (0)
#else
#define debug(x) do { } while (0)
#endif

// _print: writes a human-readable form of a value to std::cerr (used by the
// debug() macro). Formats:
//   - scalars are written as-is;
//   - pairs as "{a,b}";
//   - containers as "[ e1 e2 ... ]".
// Pairs and containers are taken by const reference, so printing never copies
// them.

void _print(long long t) {std::cerr << t;}
void _print(int t) {std::cerr << t;}
void _print(const std::string& t) {std::cerr << t;}
void _print(char t) {std::cerr << t;}
void _print(long double t) {std::cerr << t;}
void _print(double t) {std::cerr << t;}
void _print(unsigned long long t) {std::cerr << t;}

// Forward declarations. Inside a template, unqualified lookup only finds
// overloads declared before the definition; argument-dependent lookup searches
// namespace std, not the global namespace. Every container overload is
// therefore declared up front so they can nest arbitrarily. multimap was
// missing here before, so a multimap nested in another container did not
// compile.
template <class T, class V> void _print(const std::pair<T, V>& p);
template <class T> void _print(const std::vector<T>& v);
template <class T> void _print(const std::set<T>& v);
template <class T, class V> void _print(const std::map<T, V>& v);
template <class T> void _print(const std::multiset<T>& v);
template <class T, class V> void _print(const std::multimap<T, V>& v);

template <class T, class V> void _print(const std::pair<T, V>& p) {std::cerr << "{"; _print(p.first); std::cerr << ","; _print(p.second); std::cerr << "}";}
// Elements are bound as `const T&` (not `const auto&`): for vector<bool> the
// proxy is converted to a real bool, which resolves to _print(int) as before.
template <class T> void _print(const std::vector<T>& v) {std::cerr << "[ "; for (const T& i : v) {_print(i); std::cerr << " ";} std::cerr << "]";}
template <class T> void _print(const std::set<T>& v) {std::cerr << "[ "; for (const T& i : v) {_print(i); std::cerr << " ";} std::cerr << "]";}
template <class T> void _print(const std::multiset<T>& v) {std::cerr << "[ "; for (const T& i : v) {_print(i); std::cerr << " ";} std::cerr << "]";}
template <class T, class V> void _print(const std::map<T, V>& v) {std::cerr << "[ "; for (const auto& i : v) {_print(i); std::cerr << " ";} std::cerr << "]";}
template <class T, class V> void _print(const std::multimap<T, V>& v) {std::cerr << "[ "; for (const auto& i : v) {_print(i); std::cerr << " ";} std::cerr << "]";}

// Global pseudo-random generator, seeded from the clock (differs between runs).
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Greatest common divisor via the iterative Euclidean algorithm (intended for
// non-negative inputs). The iterative form also terminates on mixed signs. The
// former swap-then-recurse version recursed forever on e.g. gcd(-4, 6).
long long gcd(long long a, long long b) {
    if (b > a) std::swap(a, b);
    while (b != 0) {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// a^b mod `mod` by binary exponentiation; requires 0 < mod <= ~3e9 so the
// products fit in long long. The base is reduced first, so a >= mod (or a < 0)
// no longer overflows. 1 % mod makes the result 0 when mod == 1.
long long binexp(long long a, long long b, long long mod) {
    long long res = 1 % mod;
    a %= mod;
    if (a < 0) a += mod;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Extended Euclid. Fills v[0] = x, v[1] = y, v[2] = g such that
// a*x + b*y == g == gcd(a, b). v must point to an array of size 3.
void extendgcd(long long a, long long b, long long* v) {
    if (b == 0) {v[0] = 1; v[1] = 0; v[2] = a; return;}
    extendgcd(b, a % b, v);
    const long long x = v[1];
    v[1] = v[0] - v[1] * (a / b);
    v[0] = x;
}

// Modular inverse of a mod b for any modulus b (prime or not), assuming
// gcd(a, b) == 1. extendgcd's coefficient may be negative, so for b > 0 the
// result is normalized into [0, b). Previously mminv(3, 7) returned -2
// instead of 5.
long long mminv(long long a, long long b) {
    long long arr[3];
    extendgcd(a, b, arr);
    long long x = arr[0];
    if (b > 0) {
        x %= b;
        if (x < 0) x += b;
    }
    return x;
}

// Modular inverse by Fermat's little theorem; b must be prime and a % b != 0.
long long mminvprime(long long a, long long b) {return binexp(a, b - 2, b);}

// Comparator for sorting in decreasing order.
bool revsort(long long a, long long b) {return a > b;}
// nCr mod m from precomputed tables. Expects:
//   - fact[i] == i! mod m;
//   - ifact[i] == the modular inverse of i! (mod m);
//   - 0 <= r <= n < fact.size().
ll combination(ll n, ll r, ll m, vector<ll>& fact, vector<ll>& ifact) {
    ll res = fact[n] * ifact[n - r] % m;
    res = res * ifact[r] % m;
    return res;
}
// Writes the "Case #t: " prefix (no newline) — presumably for Google-contest output.
void google(int t) {
    cout << "Case #";
    cout << t;
    cout << ": ";
}
// Sieve of Eratosthenes: returns every prime in [2, n] in increasing order
// (empty for n < 2). The composite flags live in a std::vector, so nothing is
// leaked; the old version allocated with new[] and never freed it. Crossing
// out starts at i*i, since smaller multiples of i were already removed.
std::vector<long long> sievefn(int n) {
    std::vector<long long> primes;
    if (n < 2) return primes;
    std::vector<char> composite(n + 1, 0);
    for (int i = 2; i <= n; ++i) {
        if (composite[i]) continue;
        primes.push_back(i);
        for (long long j = 1LL * i * i; j <= n; j += i) composite[j] = 1;
    }
    return primes;
}
// Modular addition / multiplication / subtraction. Each one reduces both
// operands modulo m and then normalizes the result into [0, m) (for m > 0),
// so negative inputs are handled.
// NOTE(review): mod_mul multiplies the reduced operands in ll, which can
// overflow once m exceeds roughly 3e9.
ll mod_add(ll a, ll b, ll m) {
    const ll x = a % m;
    const ll y = b % m;
    const ll r = (x + y) % m;
    return (r + m) % m;
}
ll mod_mul(ll a, ll b, ll m) {
    const ll x = a % m;
    const ll y = b % m;
    const ll r = (x * y) % m;
    return (r + m) % m;
}
ll mod_sub(ll a, ll b, ll m) {
    const ll x = a % m;
    const ll y = b % m;
    const ll r = (x - y) % m;
    return (r + m) % m;
}
// Modular division a / b mod m, computed as a * b^(m-2) (Fermat's little
// theorem). Only valid for prime m with b not divisible by m.
ll mod_div(ll a, ll b, ll m) {
    a %= m;
    b %= m;
    const ll inv_b = mminvprime(b, m);
    return (mod_mul(a, inv_b, m) + m) % m;
}
// Euler's totient phi(n), the count of integers in [1, n] coprime to n, by
// trial division in O(sqrt(n)). Returns 0 for n <= 0; for n == 0 the old
// `while (n % 2 == 0) n /= 2` loop never terminated.
long long phin(long long n) {
    if (n <= 0) return 0;
    long long result = n;
    if (n % 2 == 0) {
        result /= 2;
        while (n % 2 == 0) n /= 2;
    }
    // `i <= n / i` is an exact integer form of i*i <= n. It avoids calling
    // floating sqrt() on every iteration and cannot overflow. n shrinks as its
    // factors are divided out.
    for (long long i = 3; i <= n / i; i += 2) {
        if (n % i == 0) {
            while (n % i == 0) n /= i;
            result = result / i * (i - 1);
        }
    }
    if (n > 1) result = result / n * (n - 1);  // one prime factor > sqrt remains
    return result;
}
// Uniformly distributed random integer in the closed range [l, r], drawn from rng.
ll getRandomNumber(ll l, ll r) {
    uniform_int_distribution<ll> dist(l, r);
    return dist(rng);
}

// Input helpers. read(x) extracts one value from std::cin. The overloads
// fill both members of a pair, and every element of an already-sized vector.
template<typename T>
void read(T &x){
    std::cin >> x;
}

template<typename T,typename T1>
void read(std::pair<T,T1> &p){
    std::cin >> p.first >> p.second;
}

template<typename T>
void read(std::vector<T> &a){
    for(auto &i:a)read(i);
}

template<typename T,typename T1>
void read(std::vector<std::pair<T,T1>>&a){
    for(auto &i:a)read(i);
}

// Output helpers. print(x) writes a value to std::cout. Parameters are const
// references, so temporaries and const objects can be printed. The former
// non-const `T&` parameters rejected both.
// Output format is unchanged: a pair is written as "first second", and vector
// elements are written back to back with no separator.
template<typename T>
void print(const T &x){
    std::cout << x;
}

template<typename T,typename T1>
void print(const std::pair<T,T1> &p){
    std::cout << p.first << " " << p.second;
}

template<typename T>
void print(const std::vector<T> &a){
    for(const auto &i:a)print(i);
}

template<typename T,typename T1>
void print(const std::vector<std::pair<T,T1>>&a){
    for(const auto &i:a)print(i);
}

// Strict-weak-ordering comparator: orders vectors by their third element
// (index 2). Takes const references, as required of std::sort's Compare, so
// it also accepts const vectors. Each vector must have at least 3 elements.
bool cmp(const std::vector<long long> & a, const std::vector<long long> & b){
    return a[2] < b[2];
}

// Lowest-common-ancestor queries on a rooted tree by binary lifting.
// Preprocessing is O(n log n); kthpar() and lca() are O(log n).
class LCA_binarylifting{
private:
    long long n, mlog;
    // par[j][v] = the 2^j-th ancestor of v, or -1 if it does not exist.
    // Stored level-major (mlog rows of n ints) instead of one small vector per
    // node. This means far fewer allocations, and each row is filled from the
    // previous contiguous row.
    std::vector<std::vector<int> >par;
    // level[v] = depth of v below the root (root = 0; unreachable nodes stay 0).
    std::vector<int> level;

public:
    // no:   number of nodes, ids 0..no-1
    // adj:  adjacency lists; expected to describe a tree
    // root: the node the tree is rooted at
    LCA_binarylifting(int no, std::vector<std::vector<long long> >&adj, int root){
        this->n=no;
        // floor(log2(no)) + 2 rows, computed with integers (defined for no == 0).
        mlog=2;
        for(int v=no; v>1; v>>=1) mlog++;
        par.assign(mlog, std::vector<int>(n, -1));
        level.assign(n, 0);
        filltable(root, adj);
    }

    // Records the direct parent (par[0]) and depth of every node reachable
    // from `node`; `node` itself gets depth l. Uses an explicit stack instead
    // of recursion, so a path-shaped tree with ~1e5 nodes cannot overflow the
    // call stack. For a tree, parents and depths do not depend on visit order.
    void pardfs(int node, std::vector<std::vector<long long> >& adj, std::vector<bool> & vis, int l){
        std::vector<int> todo;
        vis[node]=true;
        level[node]=l;
        todo.push_back(node);
        while(!todo.empty()){
            const int u=todo.back();
            todo.pop_back();
            for(long long child: adj[u]){
                if(!vis[child]){
                    vis[child]=true;
                    par[0][child]=u;
                    level[child]=level[u]+1;
                    todo.push_back((int)child);
                }
            }
        }
    }

    // Runs pardfs from root, then fills par[j][v] = par[j-1][par[j-1][v]].
    void filltable(int root, std::vector<std::vector<long long> >&adj){
        if(root<0 || root>=n) return;   // nothing to do for an empty tree
        std::vector<bool> vis(n, false);
        pardfs(root, adj, vis, 0);
        for(int j=1;j<mlog;j++){
            const std::vector<int>& prev=par[j-1];
            std::vector<int>& cur=par[j];
            for(int v=0;v<n;v++){
                const int mid=prev[v];
                cur[v]=(mid==-1) ? -1 : prev[mid];
            }
        }
    }

    // Returns the k-th ancestor of node (k == 0 gives node itself).
    // Returns -1 if node is -1, k < 0, or node has fewer than k ancestors.
    // Previously, bits of k at or above mlog were silently ignored, so e.g. a
    // large k could return node itself.
    int kthpar(int node, int k){
        if(node<0 || k<0 || k>level[node]) return -1;
        for(int j=0;j<mlog && node!=-1;j++){
            if((k>>j)&1){
                node=par[j][node];
            }
        }
        return node;
    }

    // Returns the lowest common ancestor of a and b.
    int lca(int a, int b){
        if(level[a]>level[b]){
            std::swap(a,b);
        }
        b=kthpar(b, level[b]-level[a]);   // lift b to a's depth
        if(a==b)return a;
        // Jump both nodes up while their ancestors differ; they end up as
        // children of the LCA.
        for(int j=mlog-1;j>=0;j--){
            const int pa=par[j][a];
            const int pb=par[j][b];
            if(pa!=pb && pa!=-1 && pb!=-1){
                a=pa;
                b=pb;
            }
        }
        return par[0][a];
    }
};

// Euler tour of the subtree rooted at `ind`, whose parent is `par` (-1 for the
// root).
//   - Each node is appended to `eulertour` twice: once on entry and once after
//     all its children.
//   - index[v] receives those two positions {entry, exit}.
//   - Children are visited in adjacency-list order, skipping neighbours equal
//     to the parent.
// Uses an explicit stack, so deep (path-shaped) trees cannot overflow the call
// stack. The output is identical to the recursive formulation.
void dfs_euler_tour(int ind, int par, std::vector<std::vector<long long> >& adj, std::vector<long long> & eulertour, std::vector<std::vector<long long>>& index){
    struct Frame { int node; int parent; std::size_t next; };
    auto visit = [&](int v){
        index[v].push_back((long long)eulertour.size());
        eulertour.push_back(v);
    };
    std::vector<Frame> frames;
    visit(ind);
    frames.push_back({ind, par, 0});
    while(!frames.empty()){
        Frame& top = frames.back();
        if(top.next < adj[top.node].size()){
            const long long ch = adj[top.node][top.next++];
            if(ch == top.parent) continue;
            const int from = top.node;   // read before push_back may invalidate `top`
            visit((int)ch);
            frames.push_back({(int)ch, from, 0});
        } else {
            visit(top.node);             // exit record
            frames.pop_back();
        }
    }
}

// Generic recursive segment tree with lazy propagation over arr[0 .. n-1].
//
// Node must provide:
//   - a default constructor, the identity element returned for empty ranges;
//   - a constructor from ll, used to build leaves;
//   - merge(left, right).
// Update must provide:
//   - a default constructor, used to reset a node's pending update;
//   - a constructor from ll;
//   - apply(node, start, end), which applies it to a node covering [start, end];
//   - combine(newer, start, end), which folds a newer update into a pending one.
//
// Layout: 1-based heap indexing, so the children of i are 2i and 2i+1.
//   - tree[i] holds node i's aggregate, already including updates applied to i.
//   - lazy[i] marks that updates[i] still has to be pushed to i's children.
//
// NOTE(review): requires a_len >= 1. With a_len == 0, build(0, -1, 1) reaches
// build(0, 0, 2) and reads arr[0] of an empty vector.
// NOTE(review): when every update covers a single position (left == right),
// only leaves are ever fully covered. apply() then never sets a lazy tag, and
// a Fenwick tree (point add + prefix sum) would do the same job with much less
// overhead.
template<typename Node, typename Update>
struct LazySGT {
    vector<Node> tree;       // aggregate of each tree node's segment
    vector<bool> lazy;       // true => updates[i] not yet pushed to the children
    vector<Update> updates;  // pending update of each tree node
    vector<ll> arr; // type may change  (a copy of the initial array)
    int n;  // number of array elements (a_len)
    int s;  // allocated tree size
    // a_len is trusted as the number of elements of a to use.
    LazySGT(int a_len, vector<ll> &a) { // change if type updated
        arr = a;
        n = a_len;
        // Smallest power of two >= 2n. Recursive-layout indices stay below
        // 2^(ceil(log2 n) + 1), which is exactly this s.
        s = 1;
        while(s < 2 * n){
            s = s << 1;
        }
        tree.resize(s); fill(all(tree), Node());
        lazy.resize(s); fill(all(lazy), false);
        updates.resize(s); fill(all(updates), Update());
        build(0, n - 1, 1);
    }
    // Builds node `index` covering [start, end] from arr.
    void build(int start, int end, int index) { // Never change this
        if (start == end)   {
            tree[index] = Node(arr[start]);
            return;
        }
        int mid = (start + end) / 2;
        build(start, mid, 2 * index);
        build(mid + 1, end, 2 * index + 1);
        tree[index].merge(tree[2 * index], tree[2 * index + 1]);
    }
    // If node `index` has a pending update, applies it to both children and
    // clears it.
    void pushdown(int index, int start, int end){
        if(lazy[index]){
            int mid = (start + end) / 2;
            apply(2 * index, start, mid, updates[index]);
            apply(2 * index + 1, mid + 1, end, updates[index]);
            updates[index] = Update();
            lazy[index] = 0;
        }
    }
    // Applies u to node `index` (covering [start, end]) immediately. For an
    // internal node it also records u as pending for its children. Leaves
    // never carry a pending update.
    void apply(int index, int start, int end, Update& u){
        if(start != end){
            lazy[index] = 1;
            updates[index].combine(u, start, end);
        }
        u.apply(tree[index], start, end);
    }
    // Applies u to every position in [left, right], within node `index`
    // which covers [start, end].
    void update(int start, int end, int index, int left, int right, Update& u) {  // Never Change this
        if(start > right || end < left)
            return;
        if(start >= left && end <= right){
            apply(index, start, end, u);
            return;
        }
        pushdown(index, start, end);
        int mid = (start + end) / 2;
        update(start, mid, 2 * index, left, right, u);
        update(mid + 1, end, 2 * index + 1, left, right, u);
        tree[index].merge(tree[2 * index], tree[2 * index + 1]);
    }
    // Returns the merge of all positions in [left, right] that lie inside
    // node `index`'s segment [start, end]; returns Node() if they do not overlap.
    Node query(int start, int end, int index, int left, int right) { // Never change this
        if (start > right || end < left)
            return Node();
        if (start >= left && end <= right){
            // This pushdown is not needed for the returned value: apply()
            // already updated tree[index].
            pushdown(index, start, end);
            return tree[index];
        }
        pushdown(index, start, end);
        int mid = (start + end) / 2;
        Node l, r, ans;
        l = query(start, mid, 2 * index, left, right);
        r = query(mid + 1, end, 2 * index + 1, left, right);
        ans.merge(l, r);
        return ans;
    }
    // Public entry: applies Update(val) to positions [left, right].
    void make_update(int left, int right, ll val) {  // pass in as many parameters as required
        Update new_update = Update(val); // may change
        update(0, n - 1, 1, left, right, new_update);
    }
    // Public entry: aggregate of positions [left, right].
    Node make_query(int left, int right) {
        return query(0, n - 1, 1, left, right);
    }
};

struct Node1 {
    ll val; // may change
    Node1() { // Identity element
        val = 0;    // may change
    }
    Node1(ll p1) {  // Actual Node
        val = p1; // may change
    }
    void merge(Node1 &l, Node1 &r) { // Merge two child nodes
        val = l.val + r.val;  // may change
    }
};

struct Update1 {
    ll val; // may change
    // Here the identity element ensures that no update is done if the val=identity element
    // so set it to a value that never occurs in the seg tree
    Update1(){ // Identity update
        val = 0;
    }
    Update1(ll val1) { // Actual Update
        val = val1;
    }
    // How this change will affect any node in the seg tree
    void apply(Node1 &a, int start, int end) { // apply update to given node
        a.val = val * (end - start + 1); // may change
    }
    // how 2 updates are to be combined
    void combine(Update1& new_update, int start, int end){
        val = new_update.val;
    }
};

int main() {
#ifndef ONLINE_JUDGE
freopen("inputf.in", "r", stdin);
freopen("outputf.in", "w", stdout);
freopen("Error.txt", "w", stderr);
auto begin = std::chrono::high_resolution_clock::now();
#endif

    //fast io
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    //main code starts here
    ll n,q;
    while(cin>>n){
        cin>>q;

        vll num(n);
        for(int i=0;i<n;i++){
            cin>>num[i];
        }

        vector<vector<ll >> adj(n);

        for(int i=0;i<n-1;i++){
            ll a,b;
            cin>>a>>b;
            a--;b--;
            adj[a].pb(b);
            adj[b].pb(a);
        }

        map<ll, vll> mp;
        for(int i=0;i<n;i++){
            mp[num[i]].pb(i);
        }
        vector<vector<ll>> queries;

        for(int i=0;i<q;i++){
            ll a,b,c;
            cin>>a>>b>>c;
            a--;b--;
            queries.pb({a,b,c,i});
        }

        sort(all(queries), cmp);

        LCA_binarylifting bf(n, adj, 0);
        vll eulertour;
        vector<vector<ll>> index(n);

        dfs_euler_tour(0,-1,adj,eulertour, index);

        vll treeList(2*n, 0);

        LazySGT<Node1, Update1> tree(2*n, treeList);

        int last=-1;

        vector<string> ans;

        for(int i=0;i<queries.size();i++){
            if(queries[i][2]==last){
                int lca=bf.lca(queries[i][0], queries[i][1]);
                if(tree.make_query(index[lca][0], index[queries[i][0]][0]).val + tree.make_query(index[lca][0], index[queries[i][1]][0]).val - tree.make_query(index[lca][0], index[lca][0]).val >0){
                    ans.pb("Find");
                }
                else{
                    ans.pb("NotFind");
                }
            }
            else{
                
                for(auto ele: mp[last]){
                    tree.make_update(index[ele][0], index[ele][0], 0);
                    tree.make_update(index[ele][1], index[ele][1], 0);
                }
                
                last=queries[i][2];

                for(auto ele: mp[last]){
                    tree.make_update(index[ele][0], index[ele][0], 1);
                    tree.make_update(index[ele][1], index[ele][1], -1);
                }
                i--;
            }
        }

        vector<string> fans(q);
        for(int i=0;i<q;i++){
            fans[queries[i][3]]=ans[i];
        }

        for(int i=0;i<fans.size();i++){
            cout<<fans[i]<<endl;
        }
    }

    #ifndef ONLINE_JUDGE
    auto end = std::chrono::high_resolution_clock::now();
    cerr << setprecision(4) << fixed;
    cerr << "Execution time: " << std::chrono::duration_cast<std::chrono::duration<double>>(end - begin).count() << " seconds" << endl;
    #endif
    return 0;
}

How can I optimise this code? Any help will be appreciated.