[백준] [C++] 1761번 정점들의 거리 (LCA, segment tree)

2021. 7. 20. 17:22알고리즘/백준

1. 문제

 

1761번: 정점들의 거리 (acmicpc.net)

 

1761번: 정점들의 거리

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩

www.acmicpc.net

 

 

2. 풀이

이 문제는 주어진 입력을 트리로 표현하고

 

모든 노드를 루트까지 거리를 구해

 

주어진 질의 a , b 에 대하여 다음을 답해야한다.

depth(a) + depth(b) - 2depth(lca)

(이 식은 구간합과 유사한 형태로, 증명또한 비슷하게 할 수 있다.)

 

3. 트리 재구성 : 구간트리

LCA는 최소 공통 조상이다.

 

특정 구간이 주어질때 그에 대해 질의를 답해줄 수 있는 구간 트리, 

 

두 점이 주어질때 그 점 두개의 최소 공통 조상을 구간트리로 표현하려면

 

트리의 형태를 재구성해야한다.

 

이 트리를 재구성하면서 생성해야하는 배열은 다음과 같다. 

1) 순회하면서 생성하는 발자취 배열,

2) 현재 위치에서 루트까지 cost를 저장하는 깊이 배열

3) 새로운번호와 기존의 번호 각각 매칭하는 두개의 배열

   (깊이 배열은 기존번호, lca는 새로운번호이기 때문에 복원해주는 배열이 필요하다)

4) 발자취 배열에서 새로운번호의 시작 위치를 알려주는 배열

 

알고리즘은 다음과 같다.

1) 트리의 루트부터 끝까지 dfs로 탐색하면서 오름차순으로 번호를 붙인다.

2) 탐색중인 노드는 초기에 배열들을 설정해주고, 발자취 배열에 담는다.

3) 자손 노드의 탐색이 끝났을때  따로 발자취 배열에 또 저장한다.

 

이렇게 생성된 발자취 배열은 dfs의 특성때문에

 

두 점 사이에 LCA가 포함되게 되는데, 

 

LCA는 이 구간에서 가장 최소인 번호를 가지게 된다.

- u->v 의 경로가 방문하는 최상위 노드는 항상 LCA.
- u를 포함하는 서브트리에서 v를 포함하는 서브트리로 넘어가려면 LCA 를 항상 지나쳐야한다.
- LCA 그 위의 부모노드로 올라가려면 LCA 가 포함된 모든 서브트리를 탐색해야한다.
- 따라서 LCA 의 조상 노드들은 u -> v 경로 사이에 존재하지 않는다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void traverse(int here, int parent, int d, vector<int>& trip) {
    no2serial[here] = nextSerial;
    serial2no[nextSerial] = here;
    ++nextSerial;
 
    depth[here] = d;
    locInTrip[here] = trip.size();
    trip.push_back(no2serial[here]);
    
    for(int i = 0; i < child[here].size(); ++i) {
        int there = child[here][i].first;
        int dist = child[here][i].second;
        if(parent == there) {
            continue;
        }
        traverse(there, here, d + dist, trip);
        trip.push_back(no2serial[here]);
    }
}
cs

 

4. 구간 트리 : RMQ

발자취 배열 구간에서의 최소값을 리턴하는 구간트리로 형성하여

 

주어진 두 점에 대해 질의를 답할 수 있도록 구현해야한다.

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <bits/stdc++.h>
 
using namespace std;
 
struct RMQ {
    int n;
    vector<int> rangeMin;
    RMQ(const vector<int>& array) {
        n = array.size();
        rangeMin.resize(n*4);
        init(array, 0, n-11);
    }
    int init(const vector<int>& array, int left, int right, int node) {
        if(left == right)
            return rangeMin[node] = array[left];
        int mid = (left + right) / 2;
        int leftMin = init(array, left, mid, node*2);
        int rightMin = init(array, mid + 1, right, node*2 + 1);
        return rangeMin[node] = min(leftMin, rightMin);
    }
    int query(int left, int right, int node, int nodeLeft, int nodeRight) {
        if( right < nodeLeft || nodeRight < left) return INT32_MAX;
        if( left <= nodeLeft && nodeRight <= right) return rangeMin[node];
        int mid = (nodeLeft + nodeRight)/2;
        return min(query(left, right, node*2, nodeLeft, mid),
                   query(left, right, node*2+1, mid+1, nodeRight));
     }
    int query(int left, int right) {
        return query(left, right, 10, n-1);
    }
};
 
const int MAX_N = 40001;
//tree number <-> new number serial
int no2serial[MAX_N], serial2no[MAX_N];
// first visit location, 
int locInTrip[MAX_N], depth[MAX_N];
int nextSerial;
int N;
int M;
vector<vector<pair<intint> > > child;
 
void traverse(int here, int parent, int d, vector<int>& trip) {
    no2serial[here] = nextSerial;
    serial2no[nextSerial] = here;
    ++nextSerial;
 
    depth[here] = d;
    locInTrip[here] = trip.size();
    trip.push_back(no2serial[here]);
    
    for(int i = 0; i < child[here].size(); ++i) {
        int there = child[here][i].first;
        int dist = child[here][i].second;
        if(parent == there) {
            continue;
        }
        traverse(there, here, d + dist, trip);
        trip.push_back(no2serial[here]);
    }
}
RMQ* prepareRMQ(int root) {
    nextSerial = 0;
    vector<int> trip;
    traverse(root, root, 0, trip);
    return new RMQ(trip);
}
 
int distance(RMQ* rmq, int u, int v) {
    int lu = locInTrip[u], lv = locInTrip[v];
    if(lu > lv) swap(lu, lv);
    int lca = serial2no[rmq->query(lu, lv)];
    return depth[u] + depth[v] - 2*depth[lca];
}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
 
    cin>>N;
    child = vector<vector<pair<intint> > >(N+1);
    int root = 1;
    for(int i = 0; i < N - 1; i++) {
        int a, b, d;
        cin>>a>>b>>d;
        child[a].push_back({b, d});
        child[b].push_back({a, d});
        if(child[root].size() < child[a].size()){
            root = a;
        }
        if(child[root].size() < child[b].size()) {
            root = b;
        }
    }
    RMQ* rmq = prepareRMQ(root);
    cin>>M;
    for(int i = 0; i < M; i++) {
        int a, b;
        cin>>a>>b;
        if(a == b) {
            cout<<0<<"\n";
        }
        else {
            cout<<distance(rmq, a, b)<<"\n";
        }
    }
}
cs

 

유사한 문제 : 

[트리] [시도x] FAMILYTREE 족보 탐험 (tistory.com)