문제는 여기
특정간선을 포함하는경우, 실제 MST보다 큰 값이 나올 수 있다. (프루스칼로 MST를 구할때는 간선을 정렬한다음에 최소간선부터 시작함을 상기하자)
MST 한 번 하는데 $O(E \times log E)$ 니까, 간선개수만큼 하면 $O(E^2 \times log E)$ 일거고.. E가 10^5보다 크므로 나이브하게 해서는 TLE확정이다.
풀이방법
MST를 구한다음에.. 해당 간선이 이미 포함돼 있으면 MST값을 그냥 찍으면 되고, 만약 포함되어 있지 않으면 포함시키고, 다른 간선중 하나를 빼는식으로 하면 된다.
잘 이해가 가지 않는다면 다음 그래프를 보자. (간선 가중치는 생략돼 있지만 MST를 그린 화면으로 생각하자)
여기서 만약 정점 1과 정점4를 잇는 간선을 포함한 답을 구한다고 생각해보자. (위그림에서는 정점1과 정점4 사이에 간선이 없지만 원래는 있었는데 MST구하면서 제거됐다고 가정)
위 그림에서 알 수 있듯이 MST에 (1,4)간선을 추가한 순간 사이클이 생기기 때문에, (1,6) (6,2) (2,5) (5,3) (3,4) 중 하나를 대신 끊어줘야 한다.
이때 나이브하게 경로를 쭉 따라가면서 가중치 가장 큰 걸 빼주는 식으로 하면, 그래프 구조가 직선적인 경우 O(E)가 걸리게 되어 E개의 쿼리를 처리하는데 O(E^2)이 걸려서 TLE가 나게된다.
코드
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | #include <bits/stdc++.h> using namespace std; using vi = vector<int>; using vvi = vector<vi>; typedef long long ll; const int INF = (int)1e9; const int MAXN = 200001; const int LOG_N = 19; struct node3 { int u, v, w; bool operator < (const node3& t) const { return w < t.w; } }; struct node2 { int v, w; }; bool operator<(node2 l, node2 r) { return l.w > r.w; } vector<vector<node2>> adj_list(MAXN); ll MST_Kruskal(int max_v, vector<node3> v) { struct union_find { vector<int> parent; union_find(int n) { parent.resize(n + 1); for (int i = 0; i <= n; i++) parent[i] = i; } int find(int v) { return v == parent[v] ? v : parent[v] = find(parent[v]); } bool uni(int u, int v) { u = find(u), v = find(v); if (u == v) return 0; parent[u] = v; return 1; } }; union_find uf(max_v); sort(v.begin(), v.end()); ll ans = 0; for (auto a : v) { if (uf.uni(a.u, a.v)) { adj_list[a.u].push_back({ a.v,a.w }); adj_list[a.v].push_back({ a.u,a.w }); ans += a.w; } } return ans; } //무방향그래프를 인풋으로 받아, 트리로 만들어준다.LCA를 위한 작업도 해준다. struct result { int root; int max_level; }; vi parent(MAXN), level(MAXN), dist2(MAXN); // ancestor[y][x] :: x의 2^y번째 조상을 의미 vvi ancestor(LOG_N + 1, vi(MAXN, -INF)); result treefy(vector<vector<node2>>& adj_list, int root=-1,bool base1=false) { int max_v = (int)adj_list.size(); int max_level = LOG_N; function<void(node2, int, int)> dfs = [&](node2 n, int par, int lv) { parent[n.v] = par, level[n.v] = lv, dist2[n.v] = n.w; ancestor[0][n.v] = par; // 첫번째(2^0) 조상은 부모 for (auto adj : adj_list[n.v]) { if (adj.v == par) continue; // 부모방향 간선제거 dfs(adj, n.v, lv + 1); } }; dfs({ root,0 }, -INF, 1); for (int i = 1; i <= max_level; i++) { for (int j = 1; j < max_v; j++) { int tmp = ancestor[i - 1][j]; if (tmp == -INF) continue; ancestor[i][j] = ancestor[i - 1][tmp]; } } return { root, max_level }; } int lca(int u, int v, vi& level, vvi& ancestor) { if (level[u] < level[v]) swap(u, v); // b가 더 깊도록 조정 int diff = level[u] - level[v]; for (int i = 0; diff; i++) { if (diff & 1) u = ancestor[i][u]; //b를 올려서 level을 맞춘다. diff >>= 1; } if (u == v) return u; for (int i = LOG_N; i >= 0; i--) { // a와 b가 다르다면 현재 깊이가 같으니, 깊이를 a,b동시에 계속 올려준다. if (ancestor[i][u] != ancestor[i][v]) u = ancestor[i][u], v = ancestor[i][v]; } return ancestor[0][u]; } int N, M, u, v, w; int32_t main() { ios::sync_with_stdio(0); cin.tie(0); cin >> N >> M; vector<node3> vv; for(int i=0;i<M;i++) { cin >> u >> v >> w; vv.push_back({ u,v,w }); } auto mst = MST_Kruskal(N, vv); result r = treefy(adj_list, 1, true); vector<vector<int>> max_dp(r.max_level + 1, vector<int>(N + 1, 0)); for (int i = 1; i <= N; i++) max_dp[0][i] = dist2[i]; for (int jump = 1; jump <= r.max_level; jump++) { for (int here = 1; here <= N; here++) { int tmp = ancestor[jump - 1][here]; if (tmp == -INF) continue; max_dp[jump][here] = max(max_dp[jump - 1][here], max_dp[jump - 1][tmp]); } } for(int i=0;i<M;i++) { auto [s, e, w] = vv[i]; int l = lca(s, e, level, ancestor); int mx = -1; int diff = level[s] - level[l]; for (int j = 0; diff; j++) { if (diff & 1) mx = max(mx, max_dp[j][s]), s = ancestor[j][s]; diff >>= 1; } diff = level[e] - level[l]; for (int j = 0; diff; j++) { if (diff & 1) mx = max(mx, max_dp[j][e]), e = ancestor[j][e]; diff >>= 1; } cout << mst + w - mx << '\n'; } return 0; } | cs |
반응형
'Programming > Problem Solving' 카테고리의 다른 글
double과 관련된 핸들링 (0) | 2021.12.26 |
---|---|
백준 4103 ATM (0) | 2020.05.05 |
BST 트리 구현 (0) | 2020.04.09 |
인접행렬, 인접리스트 (0) | 2020.04.09 |
[코뽕] AtCoder Beginner Contest 161 - D Lunlun Number (0) | 2020.04.05 |