POJ 2892: Tunnel Warfare
問題
街がN個ならんでいて、ある街iはi-1の街とi+1の街と接続されている(両端は除く)。このとき、街を破壊するクエリ、最後に壊された街を復旧するクエリ、そして、指定された街からみて、連続な破壊されていない街の数を答えるクエリの3つがM個与えられる。素早くクエリを処理せよ。
制約条件
1 <= N, M <= 50000
解法
情報の見方を少し変えて、左からi番目の街までに壊された街の数を配列a[1],…,a[N]を計算する。
すると、連続な破壊されていない街の数は、指定された街iが持つ数a[i]の配列中でのlower_boundとupper_boundの差から1を引いたものになる。
もう少し噛み砕いて言うと、a[i]以上のものが初めて配列中に現れる番号がlower_boundで、a[i]より大きいものが初めて配列に現れる番号がupper_boundで、その差から1を引いたものになる。
で、情報をどう持つかというと、BIT上にa[i]を持って、更新のクエリの処理が素早く行えるようにしておく。
// POJ 2892: Tunnel Warfare
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;
#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl
struct BIT {
int sz;
vector<int> nodes;
BIT(int n) : sz(n), nodes(n+1, 0) {}
void add(int x, int val) {
for(; x<=sz; x+=(x&-x)) nodes[x] += val;
}
int sum(int x) {
int ret = 0;
for(; x>0; x-=(x&-x)) ret += nodes[x];
return ret;
}
int sum(int a, int b) {
return sum(b) - sum(a-1);
}
int lb(int x) {
int lo = 0;
int hi = sz;
while(lo < hi) {
int mid = (lo + hi) /2;
if(sum(mid) >= x) {
hi = mid;
} else {
lo = mid + 1;
}
}
return lo;
}
int ub(int x) {
int lo = 0;
int hi = sz;
while(lo < hi) {
int mid = (lo + hi) / 2;
if(sum(mid) > x) {
hi = mid;
} else {
lo = mid + 1;
}
}
return lo;
}
};
int n, m;
bool flg[50011];
void coding() {
int x;
char c[2];
stack<int> stk;
memset(flg, 0, sizeof(flg));
scanf("%d %d", &n, &m);
BIT bit(n+1);
REP(i,m) {
scanf("%s", c);
if(strcmp(c, "D") == 0) {
scanf("%d", &x);
bit.add(x, 1);
flg[x] = true;
stk.push(x);
}
else if(strcmp(c, "Q") == 0) {
scanf("%d", &x);
if(flg[x]) printf("0\n");
else {
int s = bit.sum(x);
printf("%d\n", bit.ub(s) - bit.lb(s) - 1);
}
}
else if(strcmp(c, "R") == 0) {
bit.add(stk.top(), -1);
flg[stk.top()] = false;
stk.pop();
}
}
}
// #define _LOCAL_TEST
int main() {
#ifdef _LOCAL_TEST
clock_t startTime = clock();
freopen("a.in", "r", stdin);
// freopen("a.out", "w", stdout);
#endif
coding();
#ifdef _LOCAL_TEST
clock_t elapsedTime = clock() - startTime;
cout << endl;
cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
cerr << "This is local test" << endl;
cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}
別解
上のBITを使ったコードの計算量はO(M(logN)^2)になるのですが、実はBITを使わなくてもsetを使った方法でO(MlogN)の解法があります。
こっちの方が実装も簡単なんだけど、なぜかBITを使ったコードの方が結果は早かったです(BITの方は360ms, setの方は547ms)。
STLの実装の問題で、オーバーヘッドが結構あるんでしょうか?わかる人がいたら教えてください(笑)。
// POJ 2892: Tunnel Warfare
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;
#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl
int n, m;
bool flg[50011];
void coding() {
int x;
char c[2];
stack<int> stk;
memset(flg, 0, sizeof(flg));
scanf("%d %d", &n, &m);
set<int> asc;
set<int,greater<int> > dec;
asc.insert(n+1);
dec.insert(0);
REP(i,m) {
scanf("%s", c);
if(strcmp(c, "D") == 0) {
scanf("%d", &x);
asc.insert(x);
dec.insert(x);
flg[x] = true;
stk.push(x);
}
else if(strcmp(c, "Q") == 0) {
scanf("%d", &x);
if(flg[x]) printf("0\n");
else {
set<int, greater<int> >::iterator lb = dec.lower_bound(x);
set<int>::iterator ub = asc.lower_bound(x);
printf("%d\n", *ub - *lb - 1);
}
}
else if(strcmp(c, "R") == 0) {
asc.erase(stk.top());
dec.erase(stk.top());
flg[stk.top()] = false;
stk.pop();
}
}
}
// #define _LOCAL_TEST
int main() {
#ifdef _LOCAL_TEST
clock_t startTime = clock();
freopen("a.in", "r", stdin);
// freopen("a.out", "w", stdout);
#endif
coding();
#ifdef _LOCAL_TEST
clock_t elapsedTime = clock() - startTime;
cout << endl;
cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
cerr << "This is local test" << endl;
cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}