プログラミング初心者の勉強日記

情報科学専攻です. 機械学習と競技プログラミングについて日々勉強しています.

MENU

AtCoderの Typical DP Contestを解いてみた (F 準急)

今回もTypical DP Contest のアウトプットに関する記事です.

今回解いた問題はF 準急です

前回までの記事は以下の通りです.

  1. A コンテスト
  2. B ゲーム
  3. C トーナメント
  4. D サイコロ
  5. E 数

F 準急

今回は「F 準急」という問題を解きました.

(かなり難しくなってきたように感じ、自分では解くことが困難になってきました.

修行が必要ですね...)

問題の概要

ある路線には駅1 \sim N駅までのN個の駅がある.

A君は、下の制約のもと、この路線に準急を走らせることにした.

制約

  •  1に止まり、{2, \cdots, N - 1}の部分集合に止まり、駅Nに止まる.
    • 連続するK個以上の駅に止まることはない.

(自分が思いついた)解法

まずはじめに、計算量を考慮せずに、自分が思いつける範囲でDPに落としこむことを考えました.

少し考えて思いついたのが、

$$ dp[i][j] : i駅までで現在j連続で止まっているときの場合の数 $$

です. (明らかにこれだと間に合いませんが、一応実装したのでアウトプットします.)

このとき、漸化式は

$$ \begin{cases} dp[i + 1][j + 1] += dp[i][j] & (i + 1駅目で止まる) \\ dp[i + 1][0] += dp[i][k] & (i + 1駅目は通過する) \end{cases} $$

となります.

通過する場合は、連続して停車した駅の数が0にリセットされます.

最終出力は、

$$ ans += dp[N][k] \, (k = 0, \cdots K - 1) $$

です.

このDPだと、状態数が  N \times Kですので、もう少し状態を減らすことが考えます.

(現在いる駅の数  0 \sim K - 1回連続で止まれる、遷移数が2なので、計算量はO(NK)となって、到底間に合いません.

(調べてわかった)解法

DPでは状態数を減らすことを考えるのが定石らしいので、状態数を減らすことを考えます.

先ほどは何駅連続で止まったかと言う状態を保持していましたが、今回はi駅を通過するか停車するかの2つだけを考えることにします.

したがって、 $$ \begin{cases} dp[i][0] : 駅iで止まる \\ dp[i][1] : 駅iを通過する \end{cases} $$

となります.

漸化式は、 $$ dp[i][j] = dp[i - 1][0] + dp[i - 1][1] $$ です.

この式が表しているのは、i駅で通過または停車するときの場合の数は、i-1駅通過または停車したときの場合の数を足し合わせたものです.

しかし、i駅で停車したときはもう少し考える必要があります.

それは、i駅でK-1連続停車した回数は、i - K駅を通過したときに等しいことになります. (その区間だけ各駅停車になるので場合の数は変わらない.)

重複した分を引く必要があるので、 $$ dp[i][0] -= dp[i - K][1] $$ とする必要があります.

これで、状態数はN \times 2なので、計算量はO(N)で、間に合います.

その他にも状態数を減らすテクニックとして累積和を用いる方法があるそうです.

この記事が非常にわかりやすかったので、リンクをあげておきます.

(内部を見る限り上で説明した手法とほぼ同じことをやっていると思います. (多分))

shindannin.hatenadiary.com

実装上の注意

modの扱い方には少し注意が必要です.

dp[i][0]からdp[i - K][1]を引くとき、ふになってしまう可能性があり、負の値をmodで割った余りを更新する値としては不具合が生じます.

したがって、dp[i][0] = (dp[i][0] - do[i - K][0] + mod) % modと、modを足した後で割ってあげます.

ソース

使用言語はC++です.

#include <iostream>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <map>
#include <stack>
#include <queue>
#include <set>
#include <cstring>

using namespace std;
// ascending order
#define vsort(v) sort(v.begin(), v.end())
// descending order
#define vsort_r(v) sort(v.begin(), v.end(), greater<int>())
// ascending order
#define asort(array, N) sort(array, array + N)
// descending order
#define asort_r(array, N) sort(array, array + N, greater<int>())
#define vunique(v) v.erase(unique(v.begin(), v.end()), v.end())
#define mp make_pair
#define ts(x) to_string(x)
#define rep(n, init, end) for(int n = init; n < end; n++)
typedef long long ll;
typedef pair<int, int> P;
const ll INF = 1e18;

int main(){
cin.tie(0);
ios::sync_with_stdio(false);

int N, K;
cin >> N >> K;
int mod = 1e9 + 7;

int dp[N + 1][2];
dp[0][0] = dp[0][1] = 1;

rep(i, 1, N + 1) {
if(i == 1) {
dp[i][0] = 1;
dp[i][1] = 0;
continue;
}

rep(j, 0, 2) dp[i][j] = (dp[i - 1][0] + dp[i - 1][1]) % mod;

if(i - K >= 0) dp[i][0] = (dp[i][0] - dp[i - K][1] + mod) % mod;
}
cout << dp[N][0] << endl;

}