POJ 3378: Crazy Thairs

問題

長さがNの数列が与えられる。その数列の長さ5の部分列で狭義単調増加するものの個数を答えよ。

制約条件

1 <= N <= 50000

解法

まず、答えがどのくらい大きくなるかを考えます。答えが一番大きくなるのは1, 2, 3, …, 50000の時で、その時の解は$_{50000} C_5 \approx 2 \times 10^{31}$になるので、64bit整数でも収まりません。

なのでおとなしく多倍長演算を書きます。今回は足し算だけあればいいと思います。

つづいて、問題の本題に入っていきますが、解き方は端的に言うとDPです。

DP表は
dp[i][j] = i番目の数が最後に来る狭義単調増加部分列で長さがjになるものの数
というように持ちます。

更新するときは、自分より小さな番号かつ、自分より小さな数が最後に来る部分列の長さが1増えることになるので、
dp[i][j+1] = sum { dp[k][j] where Ak < Ai and k < i }
となります。

このように2種類の順序を扱う時はBITが定番なので、今回もBITを使います。

与えられた数列を数が小さいものから順に処理していきます。処理の過程では数が自分より小さいものしかBITに入っていないので、添え字番号の大小だけを考慮し、BITを用いた和を計算します。

このときポイントは、数が同じものは添え字が大きい順に処理しないといけないということです。

例えば 1, 2, 3, 4, 5, 4, 5 という数列の答えは3ですが、添え字が小さい順に処理してしまうと、一番後ろの5を処理するときに、その前に出現している5の時に計算した値が数えられてしまいます。

多倍長演算があるのでかなり実装は重めでした。Submit何回したかなぁ(笑)

// POJ 3378: Crazy Thairs
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;

#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl

typedef long long lint;
typedef pair<lint,int> P;

struct LargeInteger {
    static const int N = 3;
    static const int mod = 100000000;
    int v[N];
    LargeInteger() : v() { memset(v, 0, sizeof(int) * N); }
    LargeInteger& operator+=(const LargeInteger& li) {
        int up = 0;
        for(int i=0; i<N; i++) {
            int a = v[i] + li.v[i] + up;
            v[i] = a % mod;
            up = a / mod;
        }
        return *this;
    }

    LargeInteger(const LargeInteger& li) : v() {
        memcpy(v, li.v, sizeof(int) * N);
    }

    LargeInteger& operator=(const LargeInteger& li) {
        memcpy(v, li.v, sizeof(int) * N);
        return *this;
    }

    LargeInteger(long long x) : v() {
        for(int i=0; i<N; i++) {
            v[i] = x % mod;
            x /= mod;
        }
    }

    LargeInteger& operator=(long long x) {
        memset(v, 0, sizeof(int) * N);
        for(int i=0; i<N; i++) {
            v[i] = x % mod;
            x /= mod;
        }
        return *this;
    }

    void print() {
        bool ok = false;
        for(int i=N-1; i>=0; i--) {
            if(v[i] != 0) {
                printf("%d", v[i]);
                ok = true;
            }
            else if(ok) {
                printf("00000000");
            }
        }
        if(!ok) {
            printf("0");
        }
        printf("\n");
    }
};

struct BIT {
    int sz;
    vector<LargeInteger> nodes;
    BIT() : sz(0), nodes() {} 
    void init(int n) { sz = n; nodes = vector<LargeInteger>(n+1, 0); }
    void add(int x, const LargeInteger& val) { for(; x<=sz; x+=(x&-x)) nodes[x] += val; }
    LargeInteger sum(int x) {
        LargeInteger ret = 0;
        for(; x>0; x-=(x&-x)) ret += nodes[x];
        return ret;
    }
};

int n;
lint a[50011];
const int m = 5;
BIT bit[m];

void solve() {
    vector<P> v(n);
    REP(i,n) {
        v[i] = P(a[i], -(i+1));
    }
    sort(ALL(v));

    REP(i,m) bit[i].init(n);
    REP(i,n) {
        int id = -v[i].second;
        for(int k=m-2; k>=0; k--) {
            LargeInteger nxt = bit[k].sum(id);
            bit[k+1].add(id, nxt);
        }
        bit[0].add(id, 1);
    }
    bit[m-1].sum(n).print();
}

void coding() {
    while(scanf("%d", &n) != EOF) {
        REP(i,n) scanf("%I64d", a+i);
        solve();
    }
}

// #define _LOCAL_TEST

int main() {
#ifdef _LOCAL_TEST
    clock_t startTime = clock();
    freopen("a.in", "r", stdin);
    // freopen("a.out", "w", stdout);
#endif

    coding();

#ifdef _LOCAL_TEST
    clock_t elapsedTime = clock() - startTime;
    cout << endl;
    cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
    cerr << "This is local test" << endl;
    cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}