Problem Solving

이분 탐색(Binary Search) 구현

limdef 2023. 6. 17. 17:50

 

https://www.acmicpc.net/blog/view/109

 

이분 탐색(Binary Search) 헷갈리지 않게 구현하기

개요 이분 탐색은 off-by-one error가 발생하기 쉬워서 늘 헷갈립니다. 이분 탐색 문제를 풀다보면 탈출 조건으로 lo <= hi, lo < hi, lo + 1 < hi 중 어느 걸 선택해야 할 지, 정답이 lo인지 hi인지 (lo + hi) / 2

www.acmicpc.net

 

이분 탐색 구현할 때마다 항상 헷갈려서 시간이 걸리게 되는데, 위의 글 읽고 정리해본다.

 

구현 방법 

먼저 결정 문제 설정과 분포 파악해야한다.

 

문제마다 결정 문제가 다른데 이를 먼저 설정하고, 결정 문제에 따른 True, False 결과의 분포를 생각해본다.

 

예를 들면 1~50번의 카드 중 28번 카드를 찾는다고 하자.

1~50번의 카드를 v[i] 찾는 카드를 val 이라 할 때, 결정 문제는 val <= v[i]가 되고 분포는 F, F, ... F, T(i == 28), T, ... T 가 된다.

 

lo+1 < hi 동안 범위를 줄여나가는데 그 동안 check(lo) (=F)와 check(hi) (=T)의 값은 바뀌지 않게 한다.

check(lo) == check(mid) 라면 lo = mid를, check(hi) == check(mid)이면 hi = mid를 해주면 된다.

또한 lo+1 < hi 이기 때문에 lo와 hi 사이에는 무조건 한 개 이상의 칸이 존재, lo < mid < hi를 만족한다.

lo+1 == hi가 되면 반복문을 탈출한다.

 

반복문을 탈출했다면 lo+1 >= hi 인데, 항상 lo < mid < hi 이기 때문에 lo < hi 이면서 lo+1 >= hi를 만족하는 경우는 lo+1 == hi 밖에 없음.

 

이분 탐색이 끝나고 lo, hi 는 F, T의 경계에 위치하게 되고 

가장 큰 F를 찾는 문제라면 lo, 가장 작은 T를 찾는 문제라면 hi를 반환한다.

 

1. 결정 문제를 check() 함수라 할 때, check(lo) != check(hi)가 되도록 lo, hi 설정

 

2. check(lo) == check(mid) 면 lo = mid , 아니면 hi = mid

 

3. lo + 1== hi 가 되면서 탈출. lo, hi 는 경계에 위치

 

4. 문제에 따라 lo , hi 반환 

 

 

bool check(int mid, int target) {
	if(v[mid] >= target) return true;
	return false;
}


int binary_search(int lo, int hi) {

	// check(lo) == false != check(hi) == true
    // F, ... , F, T, ... , T

	while (lo + 1 < hi) {
    	int mid = (lo + hi) / 2; // lo < mid < hi 를 항상 만족
        if(!check(mid)) lo = mid;
        else hi = mid;
    }
    
    // lo+1 == hi 가 되며 탈출하고 lo, hi는 경계에 위치
    // 문제에 따라 lo, hi 반환
    return hi;
}

 

 

lower_bound, upper_bound

lower_bound는 v[i-1] < k <= v[i]인 i를 찾아주는 함수로, k <= v[i]인 i의 최솟값을 반환.

만약 v의 모든 원소가 k보다 작다면 v의 마지막 다음칸의 위치를 반환.

 

upper_bound는 v[i-1] <= k < v[i]인 i를 찾아주는 함수로, k < v[i] 인 최솟값을 반환.

만약 모든 원소가 k보다 작거나 같다면 v의 마지막 다음칸의 위치를 반환.

#include <bits/stdc++.h>
#define fastio cin.tie(0)->sync_with_stdio(0)
using namespace std;

int LowerBound(const vector<int>& v, int x) {
    const int n = v.size();
    int lo = -1, hi = n;
    while (lo + 1 < hi) {
        int mid = (lo + hi) / 2;
        if (!(v[mid] >= x)) lo = mid;
        else hi = mid;
    }
    return hi;
}

int UpperBound(const vector<int>& v, int x) {
    const int n = v.size();
    int lo = -1, hi = n;
    while (lo + 1 < hi) {
        int mid = (lo + hi) / 2;
        if (!(v[mid] > x)) lo = mid;
        else hi = mid;
    }
    return hi;
}

int main() {
    fastio;
    vector<int> v = { 1, 2, 3, 3, 4 };
    cout << LowerBound(v, 3) << '\n'; // 2
    cout << UpperBound(v, 3) << '\n'; // 4
    cout << UpperBound(v, 3) - LowerBound(v, 3) << '\n'; // 2
}

(hi는 v의 모든 원소가 k보다 작은(작거나 같은) 경우 n을 반환해야 하기 때문에 처음에 n이상으로 설정해야 합니다. 또한 hi는 최소 0까지 감소할 수 있어야 하기 때문에 lo = -1로 설정해야 합니다)