본문 바로가기
Computer Science/Algorithm

[이코테] 14강. 자료구조: 바이너리 인덱스 트리

by 9루트 2022. 6. 1.

 

먼저 아래 문제부터 보자.

1. 데이터 업데이트가 가능한 상황에서 구간 합 (Interval Sum) 문제

 

BOJ '구간 합 구하기' 문제: https://www.acmicpc.net/problem/2042 

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

어떤 N 개의 수가 주어져 있다. 그런데 중간에 수의 변경이 번번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다.

만약에 1, 2, 3, 4, 5라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째까지 합을 구하라고 한다면 12가 될 것이다.

  • 데이터 개수: N(1 ≤ N ≤ 1,000,000)
  • 데이터 변경 횟수: M(1 ≤ M ≤ 10,000)
  • 구간 합 계산 횟수: K(1 ≤ K ≤ 10,000)

이 문제를 어떻게 해결할 수 있을까?

 

 

즉, 데이터가 계속해서 업데이트 되고, 구간 합을 중간에 계속 구해야 하는 상황일 때 어떻게 해야 할까?

 

이 때 사용하는 자료구조가 바이너리 인덱스 트리입니다.

 


2. 바이너리 인덱스 트리 원리

바이너리 인덱스 트리(Binary Indexed Tree, BIT, 펜윅 트리) 

2진법 인덱스 구조를 활용해 구간 합 문제를 효과적으로 해결해 줄 수 있는 트리 자료구조

 

  • 정수에 따른 2진수 표기

음의 정수를 표현할 때 뒤집고 1을 더하면 된다.

  • 특정 숫자 K의 0이 아닌 마지막 비트를 찾는 방법
    • K & -K를 계산한다.
    • K & -K 계산 결과 예시 

  • 파이썬으로 구현
n = 8
for i in range(n + 1):
	print(i, "의 마지막 비트:", (i & -i))

 


3. 바이너리 인덱스 트리로 부분 합 구하기

1. 트리 구조 만들기 

0이 아닌 마지막 비트 = 내가 저장하고 있는 값들의 개수


2. 업데이트(Update) 

특정 값을 변경할 때

0이 아닌 마지막 비트만큼 더하면서 구간들의 값을 변경 (예시 = 3rd)

만약 3번째 인덱스 값이 바뀌었다면,

3번째의 0이 아닌 마지막 비트가 1이므로 3 + 1 = 4번째 인덱스 값을 변경한다.

또 4번째 인덱스 값이 바뀌었으므로, 4 + 4 = 8번째 인덱스 값을 변경한다.

또 8번째 인덱스 값이 바뀌었으므로, 8 + 8 = 16번째 인덱스 값을 변경한다.

 

따라서 3번째 인덱스 값이 바뀌었다면,

각 인덱스의 마지막 비트 만큼 인덱스가 더해지면서

3 → 4 → 8 → 16 번째 인덱스 값이 바뀌게 된다.

 

위 과정은 O( logN )을 보장한다.


3. 누적 합(Prefix Sum) 

1부터 N까지의 합(누적합) 구하기

0이 아닌 마지막 비트 만큼 빼면서 구간들의 값의 합 계산 (예시  = 11rd)

만약 1부터 11번째 인덱스까지의 누적합을 구하려고 한다면

11번째 인덱스 값과

11번째의 0이 아닌 마지막 비트 : 1

11 - 1 = 10번째 인덱스 값과

10번째의 0이 아닌 마지막 비트 : 2

10 - 2 = 8번째 인덱스 값을 

누적해서 더해주면 된다.

 

11번째 인덱스에는 11 하나에 대한 값을 담고 있고,

10번째 인덱스에는 9번째에서 10번째까지의 값의 합을 담고 있고,

8번째 인덱스에는 1번째에서 8번째까지의 값의 합을 담고 있다.

 

11 → 10  → 8번째 인덱스의 값을누적해서 더해주면 1부터 11까지의 총 11개의 원소에 대한 합을 구할 수 있다.

 

 

위 과정은 O( logN )을 보장한다.

 

 


4. 파이썬 코드로 구현

import sys
input = sys.stdin.readline

# 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
n, m, k = map(int, input().split())

# 전체 데이터의 개수는 최대 1,000,000개
arr = [0] * (n + 1)
tree = [0] * (n + 1)

# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        # 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i)
    return result

# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

# start부터 end까지의 구간 합을 계산하는 함수
def interval_sum(start, end):
    return prefix_sum(end) - prefix_sum(start - 1)

for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())
    # 업데이트(update) 연산인 경우
    if a == 1:
        update(b, c - arr[b]) # 바뀐 크기(dif)만큼 적용
        arr[b] = c
    # 구간 합(interval sum) 연산인 경우
    else:
        print(interval_sum(b, c))

 

 

요약

바이너리 인덱스 트리는 

i 가 1 ~ 7까지 a[i]의 합계를 구하라고 한다면

a[7] + sum(a[5 ~ 6]) + sum(a[1 ~ 4]) 식으로 

이진수의 개념을 그대로 차용하여 구간합을 빠르게 계산하도록 만든 자료구조 방법론이다.