본문 바로가기
ps연습장

백준 25638번 : 트리와 경로 개수 쿼리 (C++)

by hwsyl 2023. 5. 5.
반응형

<문제>

백준 25638번

 

<풀이>

q개의 쿼리에 대해 정점 u를 지나는 경로의 개수를 구하는 문제이다.

정점 u에 대해서 다음과 같이 2개의 트리로 쪼개 진다고 가정해보자.

트리 x에대해 빨간점의 개수를 red[x], 파란점의 개수를 blue[x]라고 정의할때,

경로의 개수는 red[1]*blue[2] + blue[1]*red[2]개이다.

따라서 각각의 트리에 대해 red[x], blue[x]를 구하는 것이 핵심이라고 할 수 있다.

 

그러나 모든 쿼리마다 각각의 트리의 red[x], blue[x]를 계산하면 시간 초과가 날 것이 뻔하다.

보다 효율적인 방법을 생각해보자.

 

우선 풀이의 편의성을 위해 루트를 1로 설정하겠다. 그리고 임이의 정점u를 아래 그림으로 나타내었다.

트리

r[x], b[x]를 아래와 같이 정의해주자

  • r[x] : 정점 x를 루트로 하는 서브트리에 속한 빨간 정점의 수
  • b[x] : 정점 x를 루트로 하는 서브트리에 속한 파란 정점의 수

dfs를 이용하면 모든 정점 u에 대해 r[u], b[u]를 구할 수 있다. (백준 15681번 트리와 쿼리 참고)

 

u를 지나는 경로는 크게 2가지 경우가 있다. 2가지 경우를 계산해서 ans에 더해주면 된다.

  1. (u의 서브트리의 점) ->  u  -> (u의 서브트리에 속하지 않는 점)
  2. (u의 서브트리의 점) ->  u  -> (u의 서브트리의 점)

1번 경우 (r[u] - a[u])*(b[1] - b[u]) + (b[u] - !a[u])*(r[1] - r[u])로 구해진다.

2번 경우는 아래와 같다.

경로 개수를 구하는 공식

식이 복잡해 보이는데 그냥 빨간점x파란점 해서 다 더한거다.

이때 2번 경우는 for문을 돌게 되는데 시간 초과가 날 수도 있음으로 dp배열을 선언해 계산을 반복하지 않도록 하자.

 

<코드>

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
#define fio()                     \
    ios_base::sync_with_stdio(0); \
    cin.tie(0)
using namespace std;
 
typedef long long ll;
typedef pair<intint> pii;
typedef pair<ll, ll> pll;
typedef tuple<intintint> tpi;
typedef tuple<ll, ll, ll> tpl;
typedef pair<double, ll> pdl;
 
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const int dx[] = { 010-1 };
const int dy[] = { 10-10 };
 
int a[101010], r[101010], b[101010], pa[101010];
vector<int> g[101010];
ll dp[101010];
 
int rdfs(int x){
    if(r[x] != -1return r[x];
 
    r[x] = a[x];
    for(auto nx : g[x]){
        if(r[nx] < 0){
            r[x] += rdfs(nx);
            pa[nx] = x;
        }
    }
    return r[x];
}
 
int bdfs(int x){
    if(b[x] != -1return b[x];
 
    b[x] = !a[x];
 
    for(auto nx : g[x]){
        if(b[nx] < 0) b[x] += bdfs(nx);
    }
 
     return b[x];
}
 
int main(){
    memset(r, -1sizeof(r));
    memset(b, -1sizeof(b));
    memset(dp, -1L, sizeof(dp));
    int n; scanf("%d"&n);
    for(int i = 1; i <=n ; i++scanf("%d"&a[i]);
    for(int i = 0; i < n-1; i++){
        int u, v; scanf("%d %d"&u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    rdfs(1);
    bdfs(1);
    int q; scanf("%d"&q);
    for(int i = 0; i < q; i++){
        int u; scanf("%d"&u);
        if(dp[u] < 0){ //u가 반복되었을때 다시 계산 방지
            ll ans = (ll) (r[u] - a[u])*(b[1- b[u])+(ll)(b[u] - !a[u])*(r[1- r[u]); // 1번 경우
            for(auto nx : g[u]){
                if(nx == pa[u]) continue;
                ans += (ll) b[nx]*(r[u]-r[nx]-a[u]); // 2번 경우
            }
            dp[u] = ans;
        }
        printf("%lld\n", dp[u]);
    }
}
 
cs

 

<회고>

제법 어려웠지만 dfs와 dp같은 기본 개념을 잘 활용하면 충분히 풀 수 있는 문제였다.