ABC117 D XXOR 解説

問題

atcoder.jp

数列Aの全ての要素とのxorを最大化するようなK以下の整数を求めよ。

解説

公式解説の序盤について、Xの候補となりうる整数とは

Xがなんでも良ければ難易度が下がりますが、K以下というところが難点です。 まずK以下の整数とはどのようなものかを確認します。

例えば、2進数表示で0101という整数Kがあるとします。これはつまり5なのですが、 1<2<3<4<5001<010<011<100<101ということを思い出せば、少なくとも3bit目が0であるような数に対しては、それより下の1,2bitがどんなものであろうと、Kよりも小さい数だと分かります。 このように、整数Kのあるiビット目が1であれば、そのビットをとりあえず0にした整数Xを持ってこれば、iビット目より下のビットはどんなでも良い、というのが公式解説の序盤です。上記はi=3のときの例となります。

001,010,011 < 101
↑1,2ビットをどう変えても3ビット目に応じて大小が決まる例。

このようなことは、Kのあるビットが1であれば、それをiビット目としていつでも言うことができます。言い換えれば、「iビット目」の候補になるのは、Kにおけるiビット目が1であるようなビットのみ、となります。

ビットの決め方

  • iビット目未満のビット

さて、上記の例における1,2ビットのように、どう変えても大小の結果に影響しないビットというのは、自分の都合の良いように変えることができるわけです。(もちろんXの3ビット目は0である、ということは厳守した上で。)

ここで、x(x<i)ビット目を都合の良いように変えるとは、N個の整数のxビット目の0の数と1の数を数えて、少ない方の数字を採用するいうことです。 例えば、8個の整数において、xビット目が0のものが5個、1のものが3個あれば、Xのxビット目を0にしたときは、0 xor 1 = 1より3つの整数について1が立ちます。つまり、f(X)は 3 \times 2^ x増えます。対して、xビット目を1にしたときは、1 xor 0 = 1より、f(X)は 5 \times 2^ x増えます。ということは、より数が少ない方にXのxビット目を設定すればよいです。 さらに、実はX自体が何かを知る必要はなくて、上記のようにf(X)にどれだけ足されるかという値がわかっているので、直接f(X)を求めることができます。

よって、都合の良いように変えていいビットについては、 2^ x\times max(ct, N-ct)となります。ここで、 ctとは、xビット目が立っている整数の数です。

  • iビット目

このビットは、唯一Kのビットが1で、Xのビットが0であるようなところです。つまりXのビットは必ず0にしないといけないビットです。この問題ではxorを取ることを目的としているので、N個の整数の中でiビットが1であるようなものだけf(X)に影響します。よって、iビット目は必ず 2^ i \times ct です。

  • iビット目より上のビット

ここではK以下の整数という制約を守るため、Kのビットに合わせます。

Kのy(y>i)ビット目が1ならXのyビットも1です。つまり、N個の整数のうちyビット目が0のものだけf(X)に影響します。このときは 2^ y(N-ct)だけ影響します。

Kのyビットが0ならXのyビットも0です。つまり、N個の整数のうちyビット目が1のものだけf(X)に影響します。このときは 2^ y\times ctだけ影響します。

これで全てのビットについてf(X)への影響度がわかったので、f(X)を求めることができます。

ただし、これはiビット目が決まっているからこそできることなので、i の候補が40程度なことから、全探索することを考えます。全探索する中で、Kのビットが1であるようなものについて上記の計算をして、一番大きいものを取りましょう。

コードは一部省略(using namespaceとか)

#define rep(i,j,k) for(int i=(int)j;i<(int)k;i++)
signed main (){
    int N,K;
    cin>>N>>K;
    vector<int> A(N);
    rep(i,0,N)cin>>A[i];
    int res=0;
    // X_i < K_i、つまりX_i=0で、K_i=1であるようなiを全探索
    rep(i,-1,41){ 
        if(i!=-1 && !(K&(1LL<<i)))continue; //Kのiビット目が0ならcontinue
        int ans=0;
        //a[i]に対して、jビット目が立っているものを数える = ct
        rep(j,0,41){ 
            int ct=0;
            rep(k,0,N){ 
                if(A[k]&(1LL<<j))ct++;
            }
            // iビット目より上のビット
            if(j>i){
                if(K&(1LL<<j))ans+=(1LL<<j)*(N-ct);
                else ans+=(1LL<<j)*ct;
            // iビット目
            }else if(i==j){
                ans+=(1LL<<j)*ct;
            // iビット目未満のビット
            }else {
                ans+=(1LL<<j)*max(ct,N-ct);
            }
        }
        res=max(res,ans);
    }
    cout<<res<<endl;
}