algorithm/theory

세그먼트 트리.

qkqhxla1 2016. 10. 11. 15:27

https://www.acmicpc.net/blog/view/9 의 내용을 정리해왔습니다.


예로 길이 N의 배열 A가 있고 여기서 L~R구간의 합을 M번 구해서 출력해야한다고 가정합시다.


수행해야 하는 연산의 시간복잡도는 O(M(R-L)) 번입니다. 그런데 앞에서부터 차례대로 합을


구해놓는 방식으로 문제를 풀어서 시간복잡도를 줄일수 있습니다.


S[i] = A[1] + ... + A[i] 라고 했을 때, i~j까지 합은 S[j] - S[i-1]라고 할수 있습니다.


만약 A배열의 중간에 어떤 값이 하나 바뀐다고 하면 S[i]값은 A배열의 처음부터 다시 세서


만들어줘야 합니다. 따라서 M과 중간의 값이 변경되는 양이 너무 크게 되면 시간이


너무 오래걸리게 됩니다. (O(M*변경되는양)의 시간복잡도가 걸리겠죠.)


세그먼트 트리를 이용하면 O(lgN)에 수행할 수 있습니다.


1. 세그먼트 트리의 리프 노드와 리프 토드가 아닌 다른 노드는 다음과 같은 의미를 갖습니다.


리프 노드 : 배열의 그 수 자체.

다른 노드 : 왼쪽 자식+오른쪽 자식.


A배열이 0~9까지 있다고 할때 세그먼트 트리는 다음과 같습니다.

만들기.


만약 N이 2의 제곱꼴인 경우에는 완전 이진 트리가 됩니다. 그때 높이는 lgN이고, 2*N-1개의 


노드가 필요합니다. 제곱꼴이 아닌 경우에는 높이가 [lgN]이고, 2^(H+1)-1개의 크기가 필요합니다.


아래와 같은 과정으로 세그먼트 트리를 만들수 있다고 합니다.

// a: 배열 a
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위가 start ~ end
long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        return tree[node] = a[start];
    } else {
        return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end);
    }
}

start==end인 경우는 노드가 리프 노드인 경우입니다. 


노드의 왼쪽 자식은 노드*2, 오른쪽 자식은 노드*2+1입니다. 또 노드가 담당하는 구간이 [start,end]라면


왼쪽 자식은 [start,(start+end)/2], 오른쪽 자식은 [(start+end)/2+1,end] 입니다.


구간 합 찾기.


구간 left, right가 주어졌을때 합을 찾으려면 루트부터 트리를 순회하면서 각 노드가 담당하는 구간과


left, right사이의 관계를 살펴봐야 합니다.


예로 0~9까지 합을 구하는 경우는 루트 노드 하나만으로 합을 알수 있습니다.


2~4까지 합을 구하는 경우는 다음과 같습니다.

3~9까지 합을 구하는 경우입니다.

노드가 담당하는 구간이 [start,end]이고 합을 구해야하는 구간이 [left, right]라면 다음과 같이


4가지 경우로 나눌 수 있습니다.


1. [left, right]와 [start, end]가 겹치지 않는 경우. 이 경우에는 if(left>end || right<start)로 나타낼 수 있습니다. 이 경우에는 더이상 탐색을 할 필요가 없습니다.


2. [left, right]가 [start,end] 를 완전히 포함하는 경우. 이 경우는 if(left<=start && end<=right)로 나타낼 수 있습니다. 이 경우도 더이상 탐색을 할 필요가 없으며 tree[node]를 리턴하고 종료합니다.


3. [start,end]가 [left, right]를 완전히 포함하는 경우

4. [left, right]와 [start, end]가 겹쳐져 있는 경우.


3,4의 경우에는 왼쪽 자식과 오른쪽 자식을 루트로 하는 트리에서 다시 탐색을 시작해야 합니다.

// node가 담당하는 구간이 start~end이고, 구해야하는 합의 범위는 left~right
long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right);
}

수 변경.


중간에 어떤 수를 변경한다고 하면 그 숫자가 포함된 구간을 담당하는 노드를 모두 변경해야 합니다.


다음은 3번째 수를 변경할때 변경해야 하는 구간을 나타냅니다.

수 변경은 두가지 경우가 있습니다.


노드가 담당하는 구간에 index가 포함되는 경우, 포함안되는 경우.


노드의 구간에 포함되는 경우에는 diff = (바뀌어야하는수-원래수)의 값만큼 더해주고, 아닌 경우는 


탐색을 중단합니다.

void update(vector<long long> &tree, int node, int start, int end, int index, long long diff) {
    if (index < start || index > end) return;
    tree[node] = tree[node] + diff;
    if (start != end) {
        update(tree,node*2, start, (start+end)/2, index, diff);
        update(tree,node*2+1, (start+end)/2+1, end, index, diff);
    }
}

리프 노드가 아닌 경우에는 자식도 변경해줘야 합니다.


https://www.acmicpc.net/problem/2042 를 푸는 소스코드.

#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;
long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        return tree[node] = a[start];
    } else {
        return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end);
    }
}
void update(vector<long long> &tree, int node, int start, int end, int index, long long diff) {
    if (index < start || index > end) return;
    tree[node] = tree[node] + diff;
    if (start != end) {
        update(tree,node*2, start, (start+end)/2, index, diff);
        update(tree,node*2+1, (start+end)/2+1, end, index, diff);
    }
}
long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right);
}
int main() {
    int n, m, k;
    scanf("%d %d %d",&n,&m,&k);
    vector<long long> a(n);
    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1));
    vector<long long> tree(tree_size);
    m += k;
    for (int i=0; i<n; i++) {
        scanf("%lld",&a[i]);
    }
    init(a, tree, 1, 0, n-1);
    while (m--) {
        int t1,t2,t3;
        scanf("%d",&t1);
        if (t1 == 1) {
            int t2;
            long long t3;
            scanf("%d %lld",&t2,&t3);
            t2-=1;
            long long diff = t3-a[t2];
            a[t2] = t3;
            update(tree, 1, 0, n-1, t2, diff);
        } else if (t1 == 2) {
            int t2,t3;
            scanf("%d %d",&t2,&t3);
            printf("%lld\n",sum(tree, 1, 0, n-1, t2-1, t3-1));
        }
    }
    return 0;
}

파이썬 버전.

# -*- encoding: cp949 -*-
import math
def init(node,start,end):
    global a,tree
    if start==end:
        tree[node] = a[start]
        return a[start]
    else:
        tree[node] = init(node*2,start,(start+end)/2) + \
                     init(node*2+1,(start+end)/2+1,end)
        return init(node*2,start,(start+end)/2) + init(node*2+1,(start+end)/2+1,end)

def update(tree,node,start,end,index,diff):
    if index < start or index > end:
        return
    tree[node] = tree[node] + diff
    if start!=end:
        update(tree,node*2,start,(start+end)/2,index,diff)
        update(tree,node*2+1,(start+end)/2+1,end,index,diff)

def _sum(tree,node,start,end,left,right):
    if left > end or right < start:
        return 0
    if left <= start and end <= right:
        return tree[node]
    return _sum(tree, node*2,  start,           (start+end)/2,left,right) +\
           _sum(tree, node*2+1,(start+end)/2+1, end,          left,right)

n,m,k=map(int,raw_input().split())
h = int(math.ceil(math.log(n,2)))
tree_size = (1<<(h+1))
tree = [0 for i in xrange(tree_size)]
m += k
a=[input() for i in xrange(n)]

init(1,0,n-1)
for i in xrange(m):
    t1,t2,t3=map(int,raw_input().split())
    if t1==1:
        t2 -= 1
        diff = t3-a[t2]
        a[t2] = t3
        update(tree,1,0,n-1,t2,diff)
    elif t1==2:
        print _sum(tree,1,0,n-1,t2-1,t3-1)

'algorithm > theory' 카테고리의 다른 글

배낭(Knapsack) 알고리즘 (DP)  (0) 2016.10.30
Manacher's Algorithm  (0) 2016.10.23
최소 비용 최대 유량(MCMF)  (2) 2016.10.08
Minimum cut  (0) 2016.10.07
SCC, SAT.(boolean satisfiability problem)  (0) 2016.09.30