삼성 소프트웨어 (SW) 역량테스트 B형 준비! ① (BST)

2019. 9. 14. 17:53알고리즘/삼성SW역량테스트 B형 준비

문제

 

https://www.acmicpc.net/problem/2042

 

풀이

 

N => 수의 개수, 10^6

M => 업데이트 횟수, 10^4

K  => 구간합을 구하는 횟수, 10^4

 

일반적인 방법으로, 구간합을 구할 떄 마다 배열의 합을 계산할 경우,

한번의 구간합마다 O(N) 이 걸리게 될 것입니다.

이 경우, 최대 K*N = 10^10 이 걸리게 되므로, 문제를 제한시간 내에 풀 수 없습니다.

 

따라서, 구간합을 구할 때 O(logN) 이하의 알고리즘이 필요합니다.

이를 구현하기 위하여 Segment tree를 아래와 같이 구현합니다.

 

https://www.hackerearth.com/practice/data-structures/advanced-data-structures/segment-trees/tutorial/ 참조

트리의 Node가 begin~end 까지의 구간합을 갖고 있다고 가정할 때, 왼쪽 Subtree는 begin~1/2*end 오른쪽 Subtree는 1/2*end+1 ~ end 까지의 구간합을 갖는 트리를 갖도록 구현합니다.

이 때 Root Node 는 1~N까지의 구간합을 갖고 있게 구현합니다.

 

트리의 구현부 코드입니다.

* Input

 

 

* Node

/*
    begin: Node 가 표현하는 구간합의 시작 index,
    end: Node가 표현하는 구간합의 마지막 index,
    value: begin~end 까지의 구간합
    left: 왼쪽자식노드
    right: 오른쪽 자식노드
    parent: 부모노드
*/
class Node {
    public: 
        int begin, end;
        long long value;
        Node* parent;
        Node* left;
        Node* right;
        Node();
        Node(int left, int right);
};

 

- Tree 의 생성

// arr 는 초기에 input 으로 받은 N개의 수의 배열입니다.
Node* BST::create(int begin, int end) {

    Node* node = new Node(begin, end);
    if ( begin < end ) {
        int center = (begin + end) / 2;
        node->left = this->create(begin, center);
        node->left->parent = node;
        node->right = this->create(center+1, end);
        node->right->parent = node;
        node->value = node->left->value + node->right->value;
    } else {
        if ( begin != end ) {
            cout << "invalid case, please check logic!" << endl;
        } else {
        	// if begin==end, we can specify value.
            int idx = begin;
            node->value = this->arr[idx];
        }
    }  
    return node;
}

 

이제, 트리의 생성이 완료되었습니다.

트리의 생성 이후, 우리는 두가지 액션을 지원해야합니다.

 

1. 수의 변경

주어진 수의 변경의 경우, 우리는 트리의 노드를 업데이트 해주어야합니다.

이 때 바뀌는 수가 속해있는 모든 노드들의 업데이트가 필요합니다.

Root node 부터 실제 수와 업데이트 될 숫자의 차이(delta) 만큼 더해주는 작업을 바뀌는 수가 속해있는 노드들에 진행해줍니다.  (begin~end 사이에 바뀔 index가 속해있다면 업데이트 진행)

void BST::update(Node* node, int idx, int delta) {
    node->value += delta;
    if ( node->begin == node->end ) return ;
    else {
        if(node->left->end >= idx) {
            update(node->left, idx, delta);
        } else {
            update(node->right, idx, delta);
        }
    }
}

 

 

 

2. 구간합 구하기

A~B까지의 구간합을 구하고 싶다고 가정해봅시다.

 

만들어진 Tree의 Node들을 방문하며 구간합을 구한다고 할 때 아래와 같은 case들 중 하나를 만나게 됩니다.

 

1) 해당 Node의 begin~end 가 A~B 와 정확히 일치하는경우

=> 바로 이 값을 구간합으로 사용하면 됩니다.

 

2) 해당 Node의 begin~end 가 A~B 사이에 있는 경우.

=> 해당 Node의 값을 결과로 사용할 값에 더해줍니다.

이 때, 해당 Node의 하위노드는 탐색할 필요가 없습니다.

 

3) 해당 Node의 begin~end 가 A~B에 걸쳐있는 경우

=>걸쳐있는 경우, 실제로 구간합에 사용될 구간과 사용되지 않을 구간을 구분하는 것이 필요합니다.

하위노드들을 검색하여 더 자세하게 구간을 나눌 수 있도록 합니다.

 

5) 해당 Node의 begin~end 가 A~B 어디에도 걸쳐있지 않은 경우

=> 해당 Node는 구간합에 사용되지 않을 Node입니다. 무시 해 줍시다.

 

long long BST::get(Node* node, int begin, int end){
    if (begin <= node->begin && node->end <= end) {
        return node->value;
    } else if( end < node->begin || node->end < begin) {
        return 0;
    } else {
        return get(node->left, begin, end) + get(node->right, begin, end);
    }
}

 

3. 전체코드

 

/*
Problem : https://www.acmicpc.net/problem/2042
Approach :

N => 수의 개수, 10^6
M => 업데이트 횟수, 10^4
K => 구간합을 구하는 횟수, 10^4

1초당 10^8 개의 연산을 수행할 수 있다고 가정했을 때,
업데이트 치는 함수인 doUpdate() 는 O(logN) 안에 수행되어야 한다.
구간합을 구하는 함수인 getSummation() 는 O(logN) 안에 수행되어야 한다

Summation tree를 만들어서,
update는 leaf node 부터 치자
summation은 재귀함수로 구현하자.
*/

#include <iostream>
#include <fstream>

using namespace std;

ifstream fin;
ofstream fout; 

void init();

/*
    begin: Node 가 표현하는 구간합의 시작 index,
    end: Node가 표현하는 구간합의 마지막 index,
    value: begin~end 까지의 구간합
    left: 왼쪽자식노드
    right: 오른쪽 자식노드
    parent: 부모노드
*/
class Node {
    public: 
        int begin, end;
        long long value;
        Node* parent;
        Node* left;
        Node* right;
        Node();
        Node(int left, int right);
};

Node::Node(int begin, int end){
    this->begin = begin;
    this->end = end;
}

class BST { 
    public:
        BST(long long* arr);
        long long *arr;
        Node* root;
        Node* create(int begin, int end);
        void update(Node* node, int idx, int delta);
        long long get(Node* node, int begin, int end);
};

BST::BST(long long* arr) {
    this->arr = arr;
}

Node* BST::create(int begin, int end) {

    Node* node = new Node(begin, end);
    if ( begin < end ) {
        int center = (begin + end) / 2;
        node->left = this->create(begin, center);
        node->left->parent = node;
        node->right = this->create(center+1, end);
        node->right->parent = node;
        node->value = node->left->value + node->right->value;
    } else {
        if ( begin != end ) {
            cout << "invalid case, please check logic!" << endl;
        } else {
            int idx = begin;
            node->value = this->arr[idx];
        }
    }  
    return node;
}

void BST::update(Node* node, int idx, int delta) {
    node->value += delta;
    if ( node->begin == node->end ) return ;
    else {
        if(node->left->end >= idx) {
            update(node->left, idx, delta);
        } else {
            update(node->right, idx, delta);
        }
    }
}

long long BST::get(Node* node, int begin, int end){
    if (begin <= node->begin && node->end <= end) {
        return node->value;
    } else if( end < node->begin || node->end < begin) {
        return 0;
    } else {
        return get(node->left, begin, end) + get(node->right, begin, end);
    }
}

int main () {
    int N, M, K;
    init();
    // input part
    cin >> N >> M >> K;
    long long arr[N];

    for(int i = 0 ; i < N; i++){
        cin >> arr[i];
    }

    BST* bst = new BST(arr);
    bst->root = bst->create(0, N-1);

    for(int i = 0 ; i < M + K; i++){
        int cmd;
        int a;
        int b;
        cin >> cmd >> a >> b;
        switch(cmd){
            case 1:
                bst->update(bst->root, a-1, b - arr[a-1]);
                arr[a-1] = b;
                break;
            case 2:
                long long result = bst->get(bst->root, a-1, b-1);
                cout << result << endl;
                break;
        }
    }

    return 0;
}

void init() {
    fin.open("resource/input.txt");
    fout.open("resource/output.txt");
}

 

 

자, 이제 모든 구현이 완료되었습니다.

축하합니다 !