https://www.acmicpc.net/blog/view/9 의 내용을 정리해왔습니다.
예로 길이 N의 배열 A가 있고 여기서 L~R구간의 합을 M번 구해서 출력해야한다고 가정합시다.
수행해야 하는 연산의 시간복잡도는 O(M(R-L)) 번입니다. 그런데 앞에서부터 차례대로 합을
구해놓는 방식으로 문제를 풀어서 시간복잡도를 줄일수 있습니다.
S[i] = A[1] + ... + A[i] 라고 했을 때, i~j까지 합은 S[j] - S[i-1]라고 할수 있습니다.
만약 A배열의 중간에 어떤 값이 하나 바뀐다고 하면 S[i]값은 A배열의 처음부터 다시 세서
만들어줘야 합니다. 따라서 M과 중간의 값이 변경되는 양이 너무 크게 되면 시간이
너무 오래걸리게 됩니다. (O(M*변경되는양)의 시간복잡도가 걸리겠죠.)
세그먼트 트리를 이용하면 O(lgN)에 수행할 수 있습니다.
1. 세그먼트 트리의 리프 노드와 리프 토드가 아닌 다른 노드는 다음과 같은 의미를 갖습니다.
리프 노드 : 배열의 그 수 자체.
다른 노드 : 왼쪽 자식+오른쪽 자식.
A배열이 0~9까지 있다고 할때 세그먼트 트리는 다음과 같습니다.
만들기.
만약 N이 2의 제곱꼴인 경우에는 완전 이진 트리가 됩니다. 그때 높이는 lgN이고, 2*N-1개의
노드가 필요합니다. 제곱꼴이 아닌 경우에는 높이가 [lgN]이고, 2^(H+1)-1개의 크기가 필요합니다.
아래와 같은 과정으로 세그먼트 트리를 만들수 있다고 합니다.
// a: 배열 a // tree: 세그먼트 트리 // node: 세그먼트 트리 노드 번호 // node가 담당하는 합의 범위가 start ~ end long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) { if (start == end) { return tree[node] = a[start]; } else { return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end); } }
start==end인 경우는 노드가 리프 노드인 경우입니다.
노드의 왼쪽 자식은 노드*2, 오른쪽 자식은 노드*2+1입니다. 또 노드가 담당하는 구간이 [start,end]라면
왼쪽 자식은 [start,(start+end)/2], 오른쪽 자식은 [(start+end)/2+1,end] 입니다.
구간 합 찾기.
구간 left, right가 주어졌을때 합을 찾으려면 루트부터 트리를 순회하면서 각 노드가 담당하는 구간과
left, right사이의 관계를 살펴봐야 합니다.
예로 0~9까지 합을 구하는 경우는 루트 노드 하나만으로 합을 알수 있습니다.
2~4까지 합을 구하는 경우는 다음과 같습니다.
3~9까지 합을 구하는 경우입니다.
노드가 담당하는 구간이 [start,end]이고 합을 구해야하는 구간이 [left, right]라면 다음과 같이
4가지 경우로 나눌 수 있습니다.
1. [left, right]와 [start, end]가 겹치지 않는 경우. 이 경우에는 if(left>end || right<start)로 나타낼 수 있습니다. 이 경우에는 더이상 탐색을 할 필요가 없습니다.
2. [left, right]가 [start,end] 를 완전히 포함하는 경우. 이 경우는 if(left<=start && end<=right)로 나타낼 수 있습니다. 이 경우도 더이상 탐색을 할 필요가 없으며 tree[node]를 리턴하고 종료합니다.
3. [start,end]가 [left, right]를 완전히 포함하는 경우
4. [left, right]와 [start, end]가 겹쳐져 있는 경우.
3,4의 경우에는 왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 다시 탐색을 시작해야 합니다.
// node가 담당하는 구간이 start~end이고, 구해야하는 합의 범위는 left~right long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) { if (left > end || right < start) { return 0; } if (left <= start && end <= right) { return tree[node]; } return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right); }
수 변경.
중간에 어떤 수를 변경한다고 하면 그 숫자가 포함된 구간을 담당하는 노드를 모두 변경해야 합니다.
다음은 3번째 수를 변경할때 변경해야 하는 구간을 나타냅니다.
수 변경은 두가지 경우가 있습니다.
노드가 담당하는 구간에 index가 포함되는 경우, 포함안되는 경우.
노드의 구간에 포함되는 경우에는 diff = (바뀌어야하는수-원래수)의 값만큼 더해주고, 아닌 경우는
탐색을 중단합니다.
void update(vector<long long> &tree, int node, int start, int end, int index, long long diff) { if (index < start || index > end) return; tree[node] = tree[node] + diff; if (start != end) { update(tree,node*2, start, (start+end)/2, index, diff); update(tree,node*2+1, (start+end)/2+1, end, index, diff); } }
리프 노드가 아닌 경우에는 자식도 변경해줘야 합니다.
https://www.acmicpc.net/problem/2042 를 푸는 소스코드.
#include <cstdio> #include <cmath> #include <vector> using namespace std; long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) { if (start == end) { return tree[node] = a[start]; } else { return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end); } } void update(vector<long long> &tree, int node, int start, int end, int index, long long diff) { if (index < start || index > end) return; tree[node] = tree[node] + diff; if (start != end) { update(tree,node*2, start, (start+end)/2, index, diff); update(tree,node*2+1, (start+end)/2+1, end, index, diff); } } long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) { if (left > end || right < start) { return 0; } if (left <= start && end <= right) { return tree[node]; } return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right); } int main() { int n, m, k; scanf("%d %d %d",&n,&m,&k); vector<long long> a(n); int h = (int)ceil(log2(n)); int tree_size = (1 << (h+1)); vector<long long> tree(tree_size); m += k; for (int i=0; i<n; i++) { scanf("%lld",&a[i]); } init(a, tree, 1, 0, n-1); while (m--) { int t1,t2,t3; scanf("%d",&t1); if (t1 == 1) { int t2; long long t3; scanf("%d %lld",&t2,&t3); t2-=1; long long diff = t3-a[t2]; a[t2] = t3; update(tree, 1, 0, n-1, t2, diff); } else if (t1 == 2) { int t2,t3; scanf("%d %d",&t2,&t3); printf("%lld\n",sum(tree, 1, 0, n-1, t2-1, t3-1)); } } return 0; }
파이썬 버전.
# -*- encoding: cp949 -*- import math def init(node,start,end): global a,tree if start==end: tree[node] = a[start] return a[start] else: tree[node] = init(node*2,start,(start+end)/2) + \ init(node*2+1,(start+end)/2+1,end) return init(node*2,start,(start+end)/2) + init(node*2+1,(start+end)/2+1,end) def update(tree,node,start,end,index,diff): if index < start or index > end: return tree[node] = tree[node] + diff if start!=end: update(tree,node*2,start,(start+end)/2,index,diff) update(tree,node*2+1,(start+end)/2+1,end,index,diff) def _sum(tree,node,start,end,left,right): if left > end or right < start: return 0 if left <= start and end <= right: return tree[node] return _sum(tree, node*2, start, (start+end)/2,left,right) +\ _sum(tree, node*2+1,(start+end)/2+1, end, left,right) n,m,k=map(int,raw_input().split()) h = int(math.ceil(math.log(n,2))) tree_size = (1<<(h+1)) tree = [0 for i in xrange(tree_size)] m += k a=[input() for i in xrange(n)] init(1,0,n-1) for i in xrange(m): t1,t2,t3=map(int,raw_input().split()) if t1==1: t2 -= 1 diff = t3-a[t2] a[t2] = t3 update(tree,1,0,n-1,t2,diff) elif t1==2: print _sum(tree,1,0,n-1,t2-1,t3-1)
'algorithm > theory' 카테고리의 다른 글
배낭(Knapsack) 알고리즘 (DP) (0) | 2016.10.30 |
---|---|
Manacher's Algorithm (0) | 2016.10.23 |
최소 비용 최대 유량(MCMF) (2) | 2016.10.08 |
Minimum cut (0) | 2016.10.07 |
SCC, SAT.(boolean satisfiability problem) (0) | 2016.09.30 |