POJ 2452: Sticks Problem

問題

長さがNの同じ要素を含まない数列Sがある。この数列の中で2つの番号i, j (i<j)を選んだ時、その間に含まれる数がSiより大きくSjより小さくなるような連続した部分列のうち最大の長さの物を求めよ。もしそのような部分列が存在しないときには-1を返しなさい。

制約条件

1 <= N <= 50000
1 <= Si <= 1000000

解法

まず単純に考えると、適当に番号iとjを選んで、その間の数がSiとSjの間の数になっているかどうかをbrute forceに調べる方法が考えられます。

ですが、当然ながら計算量はO(N^3)であるため、TLEです。

そこでiを固定して、jをもっと早く求める方法を考えます。もし、[i,N]の中で、最初にiより小さい値の現れる番号xを求めることができるならば、[i,x)の中で最大の大きさを持つ物がjになります。

このような番号xは二分探索で簡単に求めることができます。さらに、[i,x)の間にある数は、あらかじめ数列をデータ構造を使って整理しておけば、速く計算することができます。

ここで「速く」といったのは、どのデータ構造を使うかによって、計算時間が変わるからです。僕は最初、セグメント木で実装をしました。

セグメント木は構築の計算量がO(N)であり、探索の計算量がO(logN)なので、上の方法を使ったとすると、全体の計算量はO(N(logN)^2)です。

これで何とか入るでしょ、と思ってSubmitしたところ、TLEになりました。

この問題は少し探索が行われる回数が多いので、探索がO(1)になるようにSparse Tableを使った実装に変更します。

一応、復習しておくと、セグメント木が空間計算量O(N)、構築計算量O(N)、探索計算量O(logN)であるのに対して、Sparse Tableは空間計算量がO(NlogN)、構築計算量がO(NlogN)、探索計算量がO(1)です。

なのでSparse Tableに変更すると全体の計算量がO(NlogN)になります。これでめでたくACできます。


コード

// POJ 2452: Sticks Problem
#include <iostream>
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include <stack>
#include <vector>
#include <queue>
#include <bitset>
#include <string>
using namespace std;

#define REP(i,n) for(int i=0; i<n; i++)
#define REPP(i,s,e) for(int i=s; i<=e; i++)
#define PB(a) push_back(a)
#define MP(f,s) make_pair(f,s)
#define SZ(a) (int)(a).size()
#define ALL(a) (a).begin(), (a).end()
#define PRT(a) cerr << #a << " = " << (a) << endl
#define PRT2(a,b) cerr << #a << " = " << (a) << ", " << #b << " = " << (b) << endl
#define PRT3(a,b,c) cerr << (a) << ", " << (b) << ", " << (c) << endl

typedef pair<int,int> P;

int n;
int s[50011];

const int INF = 1 << 30;
const int SIZE = 1 << 16;
int id[1000011];
int ma[SIZE][16];
int mi[SIZE][16];
int sz;

void init() {
    REP(i,n) {
        ma[i][0] = s[i];
        mi[i][0] = s[i];
    }

    int w = 1;
    for(int k=1; k<16; k++) {
        for(int i=0; i+w<n; i++) {
            ma[i][k] = max(ma[i][k-1], ma[i+w][k-1]);
            mi[i][k] = min(mi[i][k-1], mi[i+w][k-1]);
        }
        w <<= 1;
    }
}

int max_query(int l, int r) {
    int w = 1;
    int k = 0;
    while((w << 1) < r - l) {
        w <<= 1;
        k++;
    }
    return max(ma[l][k], ma[r-w][k]);
}

int min_query(int l, int r) {
    int w = 1;
    int k = 0;
    while((w << 1) < r - l) {
        w <<= 1;
        k++;
    }
    return min(mi[l][k], mi[r-w][k]);
}

int calc(int x) {
    int lo = x;
    int hi = n;
    while(lo+1 < hi) {
        int mid = (lo + hi) / 2;
        if(id[min_query(x, mid)] == x) {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    return lo;
}

void solve() {
    REP(i,n) {
        id[s[i]] = i;
    }

    int ans = -1;
    init();
    REP(i,n) {
        if(n - i - 1 < ans) break;
        int f = calc(i);
        if(i != f) {
            int j = id[max_query(i+1, f+1)];
            if(s[i] < s[j]) {
                ans = max(ans, j - i);
            }
        }
    }
    printf("%d\n", ans);
}

void coding() {
    while(scanf("%d", &n) != -1) {
        REP(i,n) scanf("%d", s+i);
        solve();
    }
}

// #define _LOCAL_TEST

int main() {
#ifdef _LOCAL_TEST
    clock_t startTime = clock();
    freopen("a.in", "r", stdin);
    // freopen("a.out", "w", stdout);
#endif

    coding();

#ifdef _LOCAL_TEST
    clock_t elapsedTime = clock() - startTime;
    cout << endl;
    cerr << (elapsedTime / 1000.0) << " sec elapsed." << endl;
    cerr << "This is local test" << endl;
    cerr << "Do not forget to comment out _LOCAL_TEST" << endl << endl;
#endif
}