Ccmmutty logo
Commutty IT
14 min read

青色コーダーが ABC 全問解説する【ABC293】

https://cdn.magicode.io/media/notebox/blob_sRz7Ze0

AtCoder Beginner Contest 293

5 完…。停滞してきてる…(´・ω・`)

A - Swap Odd and Even

問題のとおり、入れ替えて出力するだけ。swap 関数を使えば楽。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

int main() {
    string s;

    cin >> s;

    for (int i = 0; i < s.length() / 2; ++i) {
        swap(s[i * 2], s[i * 2 + 1]);
    }

    cout << s;

    return 0;
}

B - Call the ID Number

bool 配列とかで、呼ばれたかどうかを管理しておけば OK。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

int main() {
    int n;

    cin >> n;

    vector<bool> x(n + 1);  // 呼ばれたかどうか
    for (int i = 1, a; i <= n; ++i) {
        cin >> a;
        if (!x[i]) x[a] = true;  // i が呼ばれていないなら A_i を呼ぶ
    }

    int k = 0;
    for (int i = 1; i <= n; ++i) if (!x[i]) ++k;  // 呼ばれていない人数をカウント
    
    cout << k << '\n';
    for (int i = 1; i <= n; ++i) if (!x[i]) cout << i << ' ';

    return 0;
}

C - Make Takahashi Happy

重複順列でも next_permutation 関数使えるのか (公式解説)。
経路の長さが H+W2H+W-2 なので、経路は高々 2H+W22^{H+W-2} 通り (正確には H+W2CH1_{H+W-2} C_{H-1} 通り) しかない。全探索で O(2H+W(H+W)log(H+W))O(2^{H+W}(H+W) \log (H+W)) とかでも間に合う。バックトラックなどで H+WH+W 落とせる。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

int dfs(int i, int j, vector<vector<int>> &a) {
    static set<int> s;

    if (s.find(a[i][j]) != s.end()) return 0;
    if (i == 0 && j == 0) return 1;

    s.insert(a[i][j]);

    int ret = 0;
    if (i > 0) ret += dfs(i - 1, j, a);
    if (j > 0) ret += dfs(i, j - 1, a);

    s.erase(a[i][j]);

    return ret;
}

int main() {
    int h, w;

    cin >> h >> w;

    vector<vector<int>> a(h, vector<int>(w));
    for (int i = 0; i < h; ++i) for (int j = 0; j < w; ++j) cin >> a[i][j];

    cout << dfs(h - 1, w - 1, a);

    return 0;
}

D - Tying Rope

同じ色の端が複数回結ばれることがないので、環状にする場合以外に、同じ組に属するロープの端同士を結ぶことはない。
よって、同じ組に属するロープの端同士を結んだタイミングで XX11 増やせばいい。
これは、Union-Find などで実装できる。(Union-Find は、AtCoder Library にもあります。)
また、YY は、最初は NN で、ロープの端を結ぶたびに 11 減る (ロープの組が 11 減る or ロープの組が 11 つ環状になる)。つまり、Y=NMY=N-M
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

// Union-Find
class UnionFind {
    vector<int> p;

   public:
    UnionFind(int n) : p(n, -1) {}

    void unite(int x, int y) {
        x = find(x); y = find(y);
        if (x == y) return;

        if (p[x] < p[y]) {
            p[x] += p[y];
            p[y] = x;
        } else {
            p[y] += p[x];
            p[x] = y;
        }
        return;
    }

    int find(int x) {
        if (p[x] < 0) return x;
        while (p[p[x]] >= 0) {
            p[x] = p[p[x]];
            x = p[x];
            if (p[x] < 0) return x;
        }
        return p[x];
    }

    bool isSame(int x, int y) { return find(x) == find(y); }

    int size(int x) { return -p[find(x)]; }
};

int main() {
    int n, m;

    cin >> n >> m;

    int a, c;
    char b, d;
    UnionFind uni(n);
    int x = 0;
    for (int i = 0; i < m; ++i) {
        cin >> a >> b >> c >> d;
        --a; --c;
        if (uni.isSame(a, c)) ++x;
        else uni.unite(a, c);
    }

    cout << x << ' ' << n - m;

    return 0;
}

E - Geometric Progression

MM が素数でない場合、mod\bmod の割り算ができるとは限らないので注意。
f(x)=i=0n1Aif(x)=\displaystyle\sum_{i=0}^{n-1} A^i とおくと、
f(x+y)=Ayf(x)+f(y)f(2x)=(Ax+1)f(x)\begin{align*} f(x+y) &= A^y f(x) + f(y) \\ f(2x) &= (A^x+1)f(x) \end{align*}
であるから、ダブリングできる。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

ll cal(ll a, ll x, ll mod) {
    ll ans = 0;
    ll ax = a, fx = 1;
    while (x) {
        if (x & 1) ans = (ax * ans + fx) % mod;
        fx = (ax + 1) * fx % mod;
        ax = ax * ax % mod;
        x >>= 1;
    }
    return ans;
}

int main() {
    ll a, x, m;

    cin >> a >> x >> m;

    cout << cal(a, x, m);

    return 0;
}
別解として、128-bit 整数が必要だけど、整数 a,b(0b<M)a,b(0 \leq b<M) を用いて、f(x)=aM+bf(x)=aM+b とおくと、
Ax1=(A1)f(x)=α(aM+b)=a(αM)+αb(α=A1)\begin{align*} A^x-1&=(A-1)f(x) \\ &=\alpha(aM+b)=a (\alpha M)+\alpha b&(\alpha=A-1) \end{align*}
であるから、(Ax1)modαM=αb(0b<M)(A^x-1) \bmod {\alpha M} = \alpha b(\because 0 \leq b < M)
よって、
f(X)modM=(AX1)modαMαf(X) \bmod M = \displaystyle \frac{(A^X-1) \bmod \alpha M}{\alpha}
でも求められる。(A=1A=1 のときは 00 除算になるので注意 (そのときの解は、XmodMX \bmod M)。)
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

ll pow(__int128_t n, __int128_t m, ll mod) {
    __int128_t ans = 1;
    while (m) {
        if (m & 1) ans = ans * n % mod;
        n = n * n % mod;
        m >>= 1;
    }
    return ans;
}

int main() {
    ll a, x, m;

    cin >> a >> x >> m;

    if (a == 1) cout << x % m;
    else cout << ((pow(a, x, m * (a - 1)) - 1) / (a - 1) + m) % m;

    return 0;
}

F - Zero or One

桁の数字を固定して、それと NNbb 進数表記 が一致するような bb が存在するかを考える。
これは、二分探索で求められる。桁数を dd とおくと、11 回の探索に O(d)O(d) かかるので、O(dlogN)O(d \log N)
桁数は logN\log N 桁程度であるから、全探索すると、d=logNd=\log N とおいて、O(2ddlogN)O(2^d d \log N) となり、間に合わない。dd を縮められるとうれしい。
今度は、bb について全探索することを考える。d=logNd=\log N 程度で、オーダーは、O(Nd)O(Nd) である。NNを縮められるとうれしい。
上の 22 つを複合して考えてみる。2b<M2 \leq b < M では「bb について全探索」、MbM \leq b では「桁を固定して全探索」とすると、オーダーは、d=logMNd=\log_M N 程度で、O(MlogN+2ddlogN)O(M \log N + 2^d d \log N) となる。
M=1024M=1024 程度にすれば、d=logMN=logN/logM6d = \log_M N=\log N / \log M \fallingdotseq 6 となり、MlogN6×104,2ddlogN2×104M \log N \fallingdotseq 6 \times 10^4, 2^d d \log N \fallingdotseq 2 \times 10^4 から、T=103T=10^3 でも間に合う。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// n の b 進数表記が 0 か 1 のみか
bool binary(ll n, ll b) {
    while (n) {
        if (n % b > 1) return false;
        n /= b;
    }

    return true;
}

// n の b 進数表記と桁を固定した値 s が一致するかどうか
bool binary(ll n, ll b, int s) {
    int t = 0, x = 1;
    while (n) {
        if (n % b > 1)
            return false;
        else if (n % b == 1)
            t |= x;
        n /= b;
        x <<= 1;
    }

    return s == t;
}

// n と 桁を固定した値 s のどちらが大きいか比較
// b 進数、最大 d 桁
bool cond(ll n, ll b, int s, int d) {
    vector<ll> j;
    while (n) {
        j.push_back(n % b);
        n /= b;
    }
    j.resize(d);
    for (int k = d - 1; k >= 0; --k) {
        int sk = (s >> k) & 1;
        if (j[k] > sk) return true;
        if (j[k] < sk) return false;
    }

    return false;
}

int main() {
    const int m = 1024, d = 6;
    int t;
    ll n;

    cin >> t;

    while (t--) {
        ll ans = 0;

        cin >> n;

        for (int i = 2; i < m; ++i) {  // m まで全探索
            ans += binary(n, i);
        }

        if (n >= m) {  // m から桁固定全探索
            for (int i = 1; i < (1 << d); ++i) {
                ll ok = m - 1, ng = n + 1;
                for (ll mid = (ok + ng) / 2; abs(ok - ng) > 1; mid = (ok + ng) / 2) {
                    if (cond(n, mid, i, d)) ok = mid;
                    else ng = mid;
                }
                ans += binary(n, ng, i);
            }
        }

        cout << ans << '\n';
    }

    return 0;
}

G - Triple Index

Mo's algorithm を使えば OK。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Mo's algorithm
template <typename T, typename Push, typename Pop>
vector<T> mo(vector<pair<int, int>> &query, Push push, Pop pop) {
    return mo<T>(query, push, pop, push, pop);
}

template <typename T, typename Push, typename Pop>
vector<T> mo(const vector<pair<int, int>> &query, Push pushL, Pop popL,
             Push pushR, Pop popR) {
    T x{};
    int sz = query.size();
    vector<T> ret(sz);
    vector<int> idx(sz);
    iota(idx.begin(), idx.end(), 0);
    int sqrtQ = sqrt(sz);
    vector query4sort(query);
    for (auto &a : query4sort) a.first /= sqrtQ;

    sort(idx.begin(), idx.end(),
         [&query4sort](int a, int b) { return query4sort[a] < query4sort[b]; });

    int l = 0, r = 0;
    for (auto i : idx) {
        auto [L, R] = query[i];
        while (L < l) x = pushL(x, --l);
        while (r < R) x = pushR(x, r++);
        while (l < L) x = popL(x, l++);
        while (R < r) x = popR(x, --r);
        ret[i] = x;
    }

    return ret;
}

int main() {
    int n, q;

    cin >> n >> q;

    vector<int> a(n);
    for (int i = 0; i < n; ++i) cin >> a[i];

    vector<ll> ptn(max(4, n + 1));  // iC3
    ptn[3] = 6;
    for (int i = 4; i <= n; ++i) ptn[i] = ptn[i - 1] / (i - 3) * i;
    for (int i = 3; i <= n; ++i) ptn[i] /= 6;

    vector<int> cnt(200005);
    auto push = [&](ll &x, int i) {
        return x + ptn[++cnt[a[i]]] - ptn[cnt[a[i]] - 1];
    };
    auto pop = [&](ll &x, int i) {
        return x + ptn[--cnt[a[i]]] - ptn[cnt[a[i]] + 1];
    };

    vector<pair<int, int>> query(q);
    for (int i = 0; i < q; ++i) {
        cin >> query[i].first >> query[i].second;
        --query[i].first;
    }

    for (auto a : mo<ll>(query, push, pop)) cout << a << '\n';

    return 0;
}

Ex - Optimal Path Decomposition

公式解説のとおりに実装しました。
↓ こんな感じ。
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

int dfs(int v, int p, vector<vector<int>> &g, int k, vector<int> &dp) {
    static int ans = 1;
    if (p == -1) ans = 1;  // 初期化

    vector<int> x(5), y(3);  // 大きい順に定数個
    for (int a : g[v]) {
        if (a != p) {
            if (dfs(a, v, g, k, dp)) {
                y[2] = dp[a];
                for (int i = 1; i >= 0; --i)
                    if (y[i] < y[i + 1]) swap(y[i], y[i + 1]);
            } else {
                x[4] = dp[a];
                for (int i = 3; i >= 0; --i)
                    if (x[i] < x[i + 1]) swap(x[i], x[i + 1]);
            }
        }
    }

    bool flag = false;  // 選択肢 3 を選ぶかどうか

    int s2 = max({x[0], x[1] + 1, y[0] + 1}), s3 = max({x[0], x[2] + 1, y[0] + 1});

    vector<int> c(5);
    c[0] = x[0] - 1, c[1] = x[1], c[2] = x[2], c[3] = y[0], c[4] = y[1];
    sort(c.rbegin(), c.rend());

    if (s2 > s3 || c[0] + c[1] + 1 > k) {
        flag = true;

        vector<int> d(6);
        d[0] = x[0] - 1, d[1] = x[1] - 1, d[2] = x[2], d[3] = x[3], d[4] = y[0], d[5] = y[1];
        sort(d.rbegin(), d.rend());

        ans = max(ans, d[0] + d[1] + 1);
        dp[v] = s3;
    } else {
        ans = max(ans, c[0] + c[1] + 1);
        dp[v] = s2;
    }

    return (p == -1 ? ans : flag);
}

int main() {
    int n;

    cin >> n;

    vector<vector<int>> g(n);
    for (int i = 0, u, v; i < n - 1; ++i) {
        cin >> u >> v;
        --u; --v;
        g[u].push_back(v); g[v].push_back(u);
    }

    int ok = n + 1, ng = 0;
    vector<int> dp(n);
    for (int mid = (ok + ng) / 2; abs(ok - ng) > 1; mid = (ok + ng) / 2) {
        if (dfs(0, -1, g, mid, dp) <= mid) ok = mid;
        else ng = mid;
    }

    cout << ok;

    return 0;
}

Discussion

コメントにはログインが必要です。