POJ 2892: Tunnel Warfare

問題

街がN個ならんでいて、ある街iはi-1の街とi+1の街と接続されている(両端は除く)。このとき、街を破壊するクエリ、最後に壊された街を復旧するクエリ、そして、指定された街からみて、連続な破壊されていない街の数を答えるクエリの3つがM個与えられる。素早くクエリを処理せよ。

制約条件

1 <= N, M <= 50000

解法

情報の見方を少し変えて、左からi番目の街までに壊された街の数を配列a[1],…,a[N]を計算する。

すると、連続な破壊されていない街の数は、指定された街iが持つ数a[i]の配列中でのlower_boundとupper_boundの差から1を引いたものになる。

もう少し噛み砕いて言うと、a[i]以上のものが初めて配列中に現れる番号がlower_boundで、a[i]より大きいものが初めて配列に現れる番号がupper_boundで、その差から1を引いたものになる。

で、情報をどう持つかというと、BIT上にa[i]を持って、更新のクエリの処理が素早く行えるようにしておく。

// POJ 2892: Tunnel Warfare
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;

#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl

struct BIT {
    int sz;
    vector<int> nodes;
    BIT(int n) : sz(n), nodes(n+1, 0) {}
    void add(int x, int val) {
        for(; x<=sz; x+=(x&-x)) nodes[x] += val;
    }
    int sum(int x) {
        int ret = 0;
        for(; x>0; x-=(x&-x)) ret += nodes[x];
        return ret;
    }
    int sum(int a, int b) {
        return sum(b) - sum(a-1);
    }

    int lb(int x) {
        int lo = 0;
        int hi = sz;
        while(lo < hi) {
            int mid = (lo + hi) /2;
            if(sum(mid) >= x) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    int ub(int x) {
        int lo = 0;
        int hi = sz;
        while(lo < hi) {
            int mid = (lo + hi) / 2;
            if(sum(mid) > x) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }
};

int n, m;
bool flg[50011];

void coding() {
    int x;
    char c[2];
    stack<int> stk;
    memset(flg, 0, sizeof(flg));

    scanf("%d %d", &n, &m);
    BIT bit(n+1);
    REP(i,m) {
        scanf("%s", c);
        if(strcmp(c, "D") == 0) {
            scanf("%d", &x);  
            bit.add(x, 1);
            flg[x] = true;
            stk.push(x);
        } 
        else if(strcmp(c, "Q") == 0) {
            scanf("%d", &x);
            if(flg[x]) printf("0\n");
            else {
                int s = bit.sum(x);
                printf("%d\n", bit.ub(s) - bit.lb(s) - 1);
            }
        } 
        else if(strcmp(c, "R") == 0) {
            bit.add(stk.top(), -1);
            flg[stk.top()] = false;
            stk.pop();
        }
    }
}

// #define _LOCAL_TEST

int main() {
#ifdef _LOCAL_TEST
    clock_t startTime = clock();
    freopen("a.in", "r", stdin);
    // freopen("a.out", "w", stdout);
#endif

    coding();

#ifdef _LOCAL_TEST
    clock_t elapsedTime = clock() - startTime;
    cout << endl;
    cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
    cerr << "This is local test" << endl;
    cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}

別解

上のBITを使ったコードの計算量はO(M(logN)^2)になるのですが、実はBITを使わなくてもsetを使った方法でO(MlogN)の解法があります。

こっちの方が実装も簡単なんだけど、なぜかBITを使ったコードの方が結果は早かったです(BITの方は360ms, setの方は547ms)。

STLの実装の問題で、オーバーヘッドが結構あるんでしょうか?わかる人がいたら教えてください(笑)。

// POJ 2892: Tunnel Warfare
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;

#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl

int n, m;
bool flg[50011];

void coding() {
    int x;
    char c[2];
    stack<int> stk;

    memset(flg, 0, sizeof(flg));

    scanf("%d %d", &n, &m);
    set<int> asc;
    set<int,greater<int> > dec;
    asc.insert(n+1);
    dec.insert(0);

    REP(i,m) {
        scanf("%s", c);
        if(strcmp(c, "D") == 0) {
            scanf("%d", &x);  
            asc.insert(x);
            dec.insert(x);
            flg[x] = true;
            stk.push(x);
        }
        else if(strcmp(c, "Q") == 0) {
            scanf("%d", &x);
            if(flg[x]) printf("0\n");
            else {
                set<int, greater<int> >::iterator lb = dec.lower_bound(x);
                set<int>::iterator ub = asc.lower_bound(x);
                printf("%d\n", *ub - *lb - 1);
            }
        } 
        else if(strcmp(c, "R") == 0) {
            asc.erase(stk.top());
            dec.erase(stk.top());
            flg[stk.top()] = false;
            stk.pop();
        }
    }
}

// #define _LOCAL_TEST

int main() {
#ifdef _LOCAL_TEST
    clock_t startTime = clock();
    freopen("a.in", "r", stdin);
    // freopen("a.out", "w", stdout);
#endif

    coding();

#ifdef _LOCAL_TEST
    clock_t elapsedTime = clock() - startTime;
    cout << endl;
    cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
    cerr << "This is local test" << endl;
    cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}