// Here is my code
import java.util.*;
import java.io.*;
/**
 * Answers path queries on a weighted tree.
 *
 * <p>Input: a test-case count, then for each test case a node count {@code n}, {@code n - 1}
 * edges {@code a b w}, and a sequence of {@code DIST a b} / {@code KTH a b k} queries terminated
 * by {@code DONE}. Nodes are numbered {@code 1..n} and node 1 is used as the root. Tokens may be
 * separated by any whitespace, including blank lines. One blank line is printed after each test.
 *
 * <p>Queries use binary lifting: {@code parent[v][j]} is the 2^j-th ancestor of {@code v}, or
 * {@code -1} when no such ancestor exists.
 */
class Main {
    /** Minimum number of node slots allocated per test case; grown automatically for larger n. */
    public static int MAX = 10002;
    /** Depth of each node: the root (node 1) has depth 1, the sentinel slot 0 has depth 0. */
    public static int[] depth;
    /** Sum of edge weights on the path from the root to each node. */
    public static long[] dist;
    /** Binary-lifting table: {@code parent[v][j]} is the 2^j-th ancestor of v, or -1. */
    public static int[][] parent;
    /** Immediate parent of each node, or -1 for the root. */
    public static int[] father;
    /** Adjacency lists keyed by node id {@code 0..n}. */
    public static HashMap<Integer, ArrayList<Edge>> adj;

    /**
     * Reads all test cases from standard input and writes one answer per query.
     *
     * <p>Malformed input is no longer swallowed: parse errors and premature end of input propagate
     * to the caller instead of silently truncating the output.
     *
     * @throws IOException if reading standard input fails or it ends in the middle of a test case
     */
    public static void main(String[] args) throws IOException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                PrintWriter out =
                        new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)))) {
            TokenReader in = new TokenReader(br);
            String first = in.next();
            if (first == null) {
                return; // empty input: nothing to answer
            }
            int t = Integer.parseInt(first);
            while (t-- > 0) {
                int n = in.nextInt();
                reset(n);
                for (int i = 0; i < n - 1; i++) {
                    int a = in.nextInt();
                    int b = in.nextInt();
                    int w = in.nextInt();
                    adj.get(a).add(new Edge(b, w));
                    adj.get(b).add(new Edge(a, w));
                }
                dfs(1, 0); // root the tree at node 1; slot 0 acts as a depth-0 sentinel parent
                preprocess(n);
                // A missing DONE at end of input ends the query list instead of crashing.
                for (String cmd = in.next(); cmd != null && !cmd.equals("DONE"); cmd = in.next()) {
                    int a = in.nextInt();
                    int b = in.nextInt();
                    if (cmd.equals("DIST")) {
                        out.println(dist[a] + dist[b] - 2L * dist[lca(a, b)]);
                    } else {
                        int k = in.nextInt();
                        out.println(kth(a, b, k));
                    }
                }
                out.println();
            }
        }
    }

    /** Allocates fresh per-test-case arrays large enough for nodes {@code 0..n}. */
    private static void reset(int n) {
        int rows = Math.max(MAX, n + 1) + 5;
        // One column per power of two up to `rows`, so any jump of at most n levels fits.
        int levels = Math.max(14, 32 - Integer.numberOfLeadingZeros(rows));
        depth = new int[rows];
        dist = new long[rows];
        parent = new int[rows][levels];
        father = new int[rows];
        for (int[] row : parent) {
            Arrays.fill(row, -1);
        }
        Arrays.fill(father, -1);
        adj = new HashMap<>();
        for (int i = 0; i <= n; i++) {
            adj.put(i, new ArrayList<>());
        }
    }

    /**
     * Returns the k-th node (1-based) on the path from {@code p} to {@code q}.
     *
     * @throws IllegalArgumentException if k is not in {@code 1..(number of nodes on the path)}
     */
    public static int kth(int p, int q, int k) {
        int l = lca(p, q);
        int d1 = depth[p] - depth[l] + 1; // nodes on p..l, both ends included
        int d2 = depth[q] - depth[l] + 1; // nodes on q..l, both ends included
        int total = d1 + d2 - 1;
        if (k < 1 || k > total) {
            // The old lifting loop never terminated for such k.
            throw new IllegalArgumentException("k=" + k + " is outside 1.." + total);
        }
        if (k == d1) {
            return l;
        }
        if (k < d1) {
            return ancestor(p, k - 1);
        }
        // Counted from q's end, the k-th node from p is (total - k) steps above q.
        return ancestor(q, total - k);
    }

    /** Returns the lowest common ancestor of {@code p} and {@code q}. */
    public static int lca(int p, int q) {
        if (depth[p] < depth[q]) { // make p the deeper node
            int tmp = p;
            p = q;
            q = tmp;
        }
        p = ancestor(p, depth[p] - depth[q]);
        if (p == q) {
            return p;
        }
        // Climb both nodes as far as possible while staying strictly below the LCA.
        for (int j = floorLog2(depth[p]); j >= 0; j--) {
            int up = parent[p][j];
            if (up != -1 && up != parent[q][j]) {
                p = up;
                q = parent[q][j];
            }
        }
        return father[p];
    }

    /**
     * Fills {@code depth}, {@code dist} and {@code father} for the subtree hanging from
     * {@code currNode}, whose parent is {@code parentNode}.
     *
     * <p>Iterative (explicit stack), so a path-shaped tree cannot overflow the call stack.
     */
    public static void dfs(int currNode, int parentNode) {
        depth[currNode] = depth[parentNode] + 1;
        Deque<int[]> stack = new ArrayDeque<>();
        stack.push(new int[] {currNode, parentNode});
        while (!stack.isEmpty()) {
            int[] frame = stack.pop();
            int node = frame[0];
            int from = frame[1];
            for (Edge e : adj.get(node)) {
                if (e.end == from) {
                    continue; // do not walk back to the node we came from
                }
                father[e.end] = node;
                dist[e.end] = dist[node] + e.weight;
                depth[e.end] = depth[node] + 1;
                stack.push(new int[] {e.end, node});
            }
        }
    }

    /** Builds the binary-lifting table for nodes {@code 1..n} from {@code father}. */
    public static void preprocess(int n) {
        for (int i = 1; i <= n; i++) {
            parent[i][0] = father[i];
        }
        int levels = parent[0].length;
        for (int j = 1; j < levels && (1 << j) < n; j++) {
            for (int i = 1; i <= n; i++) {
                int mid = parent[i][j - 1];
                if (mid != -1) {
                    parent[i][j] = parent[mid][j - 1];
                }
            }
        }
    }

    /** Returns the ancestor {@code steps} levels above {@code node} (steps must be < depth). */
    private static int ancestor(int node, int steps) {
        for (int j = 0; steps > 0; j++, steps >>= 1) {
            if ((steps & 1) != 0) {
                node = parent[node][j];
            }
        }
        return node;
    }

    /** Returns floor(log2(x)) for x >= 1, and 0 for x <= 1. */
    private static int floorLog2(int x) {
        return Math.max(0, 31 - Integer.numberOfLeadingZeros(x));
    }

    /** Whitespace-tolerant token reader that skips blank lines and repeated spaces. */
    private static final class TokenReader {
        private final BufferedReader in;
        private StringTokenizer tokens;

        TokenReader(BufferedReader in) {
            this.in = in;
        }

        /** Returns the next token, or {@code null} at end of input. */
        String next() throws IOException {
            while (tokens == null || !tokens.hasMoreTokens()) {
                String line = in.readLine();
                if (line == null) {
                    return null;
                }
                tokens = new StringTokenizer(line);
            }
            return tokens.nextToken();
        }

        /** Returns the next token as an int; fails on end of input. */
        int nextInt() throws IOException {
            String tok = next();
            if (tok == null) {
                throw new EOFException("unexpected end of input");
            }
            return Integer.parseInt(tok);
        }
    }

    /** Undirected weighted edge, stored once in each endpoint's adjacency list. */
    static class Edge {
        final int end;
        final int weight;

        Edge(int end, int weight) {
            this.end = end;
            this.weight = weight;
        }
    }
}