AtCoderの Typical DP Contestを解いてみた (E 数)
今回もTypical DP Contestのアウトプットに関する記事です.
今回解いた問題はE 数です
前回までの記事は以下の通りです.
E 数
今回は「E 数」という問題を解きました.
問題の概要
以下の正整数であって、10進数表記したときの各桁の数の和がの倍数であるものの個数を で求めよ.
制約
解法
愚直にループを回すと、がのケースが考えられるので到底間に合いません.
したがってDPで解くことを考えます.
試行錯誤して漸化式をたてようと試みたのですが、うまくたてられませんでした.
なので、(他の人が解いたこの問題の解法は避けるように)ググってみた結果、桁DPというものが使えそうなので、そちらの勉強から始めました.
(桁DPはよく使う手法らしいのですが恥ずかしながら初めて聞きました... 勉強不足ですね...)
寄り道 (桁DPの勉強)
私は、こちらの記事を参考にさせていただき、一通り実装することで、桁DPの考え方を学びました.
非常にわかりやすくまとまっているので、桁DPってなんぞやって方は、本記事の続きを見る前に一読することをお勧めします.
(桁DPについての解説などは本記事では省きます.)
解法 (本題に戻る)
(ほとんど上記の参考記事に記載されているプログラムの書き方で、アルゴリズムを含め自分で考えたのはほんの少しの部分です.)
桁DPの紹介記事にあった通り、 $$ dp[i][j] : \begin{cases} i : 上からi桁目まで参照している \\ j : N未満であることが確定しているかどうか\\ (j = 1 で確定、j = 0で確定していない) \end{cases} $$
というところから考え始めます.
本問題では各桁の和がである数の総数を求めなければいけないので、追加要素として、で割った余りという状態を持たせます.
つまり、 $$ dp[i][j][k] : \begin{cases} i : 上からi桁目まで参照している \\ j : N未満であることが確定しているかどうか\\ (j = 1 で確定、j = 0で確定していない)\\ k : 各桁の和をDで割った余り \end{cases} $$
とします.
漸化式は、まず桁DPの基本のところから考えると、に対して、 $$ dp[i + 1][j \,|| \, d < lim] += dp[i][j] $$
となります.
ただし、という数字はそれまでにを超えていたら次の桁の数に制限が加わるので、その次の桁に使える数の上限を表しています.
例えば、に対して、現在とみていたら、?には0~9の全ての数を使えません.
を超えていはいけないので、?には0 ~ 3のいずれかしか入れることができないからです.
したがって、この例で言うとです.
次に本問題のために拡張したDPの漸化式をたてます. $$ dp[i + 1][j \,||\, d \,< \, lim][(k + d) \, \% \, D] += dp[i][j][k] $$
3つめの要素では1桁前の総和に使用可能な数を足し合わせ、それをで割った余りで分類しています.
最終出力は、です.
(なぜなら、今回は各桁の総和をで割った余りが0となる正整数の個数を求められているからです.)
ちなみに最後にマイナス1をしているのはもの倍数であるとして数え上げているためです.
ソース
使用言語はC++です.
#include <iostream> #include <string> #include <vector> #include <cstdlib> #include <cstdio> #include <cmath> #include <algorithm> #include <map> #include <stack> #include <queue> #include <set> #include <cstring> using namespace std; // ascending order #define vsort(v) sort(v.begin(), v.end()) // descending order #define vsort_r(v) sort(v.begin(), v.end(), greater<int>()) #define vunique(v) unique(v.begin(), v.end()) #define mp make_pair #define ts(x) to_string(x) #define rep(i, a, b) for(int i = (int)a; i < (int)b; i++) #define repm(i, a, b) for(int i = (int)a; i > (int)b; i--) #define bit(a) bitset<8>(a) typedef long long ll; typedef pair<int, int> P; const ll INF = 1e18; int main(){ cin.tie(0); ios::sync_with_stdio(false); int D; string N; cin >> D >> N; ll n = N.length(); ll mod = 1e9 + 7; int dp[n + 1][2][D]; memset(dp, 0, sizeof(dp)); dp[0][0][0] = 1; rep(i, 0, n) rep(j, 0, 2) rep(k, 0, D) { int lim = j ? 9 : N[i] - '0'; rep(d, 0, lim + 1) (dp[i + 1][j || d < lim][(k + d) % D] += dp[i][j][k]) %= mod; } int ans = 0; rep(j, 0, 2) (ans += dp[n][j][0]) %= mod; cout << (ans - 1) << endl; }