POJ 3321: Apple Tree

問題

N個のノードからなる木構造が与えられており、初期状態では各ノードにリンゴがなっている。この木構造に対して2種類のクエリが合計M個与えられる。1つ目のクエリは、指定されたノードを根とする部分木に何個のリンゴがなっているかを答えるもの、もう1つは指定されたノードのリンゴの状態を反転させるもの(リンゴがあったら無くなり、もしあれば新しく実る)である。このクエリを高速に処理せよ。

制約条件

1 <= N, M <= 100000

解法

与えられる情報から木構造を隣接エッジリストで持ちます。それができたら深さ優先探索で各ノードに行きがけの時に番号を振ります。

これに加えて戻りがけの時の番号も保存しておくと、自分の子ノードの持つ番号の範囲が分かります(子ノードの番号は必ず連続する)。

この性質を利用すると、部分木の持つリンゴの数をBITを使ってO(logN)で計算することができます。

リンゴの状態を反転させるクエリもBITを使えばO(logN)で処理できるので、この方針で十分計算量は足りそうという気がします。

ただ、komiyamさんのブログにも書かれている通り、結構時間の条件が厳しく、木構造を双方向グラフで持つとTLEになります。

大した差ではない気もするのですが、以下に示すコードのように木のエッジを一方向のみ持つ(コメントアウトしてる行がポイント)ようにしたら、ギリギリ1875msくらいでACできました。

// POJ 3321: Apple Tree
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;

#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl

struct BIT {
    int sz;
    vector<int> nodes;
    BIT(int n) : sz(n), nodes(n+1, 0) {}
    void add(int x, int val) {
        for(; x<=sz; x+=(x&-x)) nodes[x] += val;
    }
    int sum(int x) {
        int ret = 0;
        for(; x>0; x-=(x&-x)) ret += nodes[x];
        return ret;
    }
    int sum(int x, int y) {
        return sum(y) - sum(x-1);
    }
};

int n, m;
vector<int> G[100011];
bool noex[100011];
int st[100011];
int en[100011];

void dfs(int v, int p, int& cnt) {
    st[v] = cnt;
    REP(i,G[v].size()) {
        dfs(G[v][i], v, ++cnt);
    }
    en[v] = cnt;
}

void coding() {
    int s, t;
    scanf("%d", &n);
    REP(i,n-1) {
        scanf("%d %d", &s, &t);
        G[s].push_back(t);
        // G[t].push_back(s);
    }

    int cnt = 1;
    dfs(1, -1, cnt);

    int x;
    char c[2];
    memset(noex, 0, sizeof(noex));
    BIT bit(n);
    REP(i,n) {
        bit.add(i+1, 1);
    }

    scanf("%d", &m);
    REP(i,m) {
        scanf("%s", c);
        if(c[0] == 'Q') {
            scanf("%d", &x);
            printf("%d\n", bit.sum(st[x], en[x]));
        } else {
            scanf("%d", &x);
            if(noex[x]) {
                bit.add(st[x], 1);
            } else {
                bit.add(st[x], -1);
            }
            noex[x] = !noex[x];
        }
    }
}

// #define _LOCAL_TEST

int main() {
#ifdef _LOCAL_TEST
    clock_t startTime = clock();
    freopen("a.in", "r", stdin);
    // freopen("a.out", "w", stdout);
#endif

    coding();

#ifdef _LOCAL_TEST
    clock_t elapsedTime = clock() - startTime;
    cout << endl;
    cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
    cerr << "This is local test" << endl;
    cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}