ABC117 D XXOR 解説
問題
数列Aの全ての要素とのxorを最大化するようなK以下の整数を求めよ。
解説
公式解説の序盤について、Xの候補となりうる整数とは
Xがなんでも良ければ難易度が下がりますが、K以下というところが難点です。 まずK以下の整数とはどのようなものかを確認します。
例えば、2進数表示で0101
という整数Kがあるとします。これはつまり5なのですが、
1<2<3<4<5
→001<010<011<100<101
ということを思い出せば、少なくとも3bit目が0であるような数に対しては、それより下の1,2bitがどんなものであろうと、Kよりも小さい数だと分かります。
このように、整数Kのあるiビット目が1であれば、そのビットをとりあえず0にした整数Xを持ってこれば、iビット目より下のビットはどんなでも良い、というのが公式解説の序盤です。上記はi=3のときの例となります。
001,010,011 < 101
↑1,2ビットをどう変えても3ビット目に応じて大小が決まる例。
このようなことは、Kのあるビットが1であれば、それをiビット目としていつでも言うことができます。言い換えれば、「iビット目」の候補になるのは、Kにおけるiビット目が1であるようなビットのみ、となります。
ビットの決め方
- iビット目未満のビット
さて、上記の例における1,2ビットのように、どう変えても大小の結果に影響しないビットというのは、自分の都合の良いように変えることができるわけです。(もちろんXの3ビット目は0である、ということは厳守した上で。)
ここで、x(x<i)ビット目を都合の良いように変えるとは、N個の整数のxビット目の0の数と1の数を数えて、少ない方の数字を採用するいうことです。
例えば、8個の整数において、xビット目が0のものが5個、1のものが3個あれば、Xのxビット目を0にしたときは、0 xor 1 = 1
より3つの整数について1が立ちます。つまり、f(X)は増えます。対して、xビット目を1にしたときは、1 xor 0 = 1
より、f(X)は増えます。ということは、より数が少ない方にXのxビット目を設定すればよいです。
さらに、実はX自体が何かを知る必要はなくて、上記のようにf(X)にどれだけ足されるかという値がわかっているので、直接f(X)を求めることができます。
よって、都合の良いように変えていいビットについては、となります。ここで、とは、xビット目が立っている整数の数です。
- iビット目
このビットは、唯一Kのビットが1で、Xのビットが0であるようなところです。つまりXのビットは必ず0にしないといけないビットです。この問題ではxorを取ることを目的としているので、N個の整数の中でiビットが1であるようなものだけf(X)に影響します。よって、iビット目は必ず です。
- iビット目より上のビット
ここではK以下の整数という制約を守るため、Kのビットに合わせます。
Kのy(y>i)ビット目が1ならXのyビットも1です。つまり、N個の整数のうちyビット目が0のものだけf(X)に影響します。このときはだけ影響します。
Kのyビットが0ならXのyビットも0です。つまり、N個の整数のうちyビット目が1のものだけf(X)に影響します。このときはだけ影響します。
これで全てのビットについてf(X)への影響度がわかったので、f(X)を求めることができます。
ただし、これはiビット目が決まっているからこそできることなので、i の候補が40程度なことから、全探索することを考えます。全探索する中で、Kのビットが1であるようなものについて上記の計算をして、一番大きいものを取りましょう。
コードは一部省略(using namespaceとか)
#define rep(i,j,k) for(int i=(int)j;i<(int)k;i++) signed main (){ int N,K; cin>>N>>K; vector<int> A(N); rep(i,0,N)cin>>A[i]; int res=0; // X_i < K_i、つまりX_i=0で、K_i=1であるようなiを全探索 rep(i,-1,41){ if(i!=-1 && !(K&(1LL<<i)))continue; //Kのiビット目が0ならcontinue int ans=0; //a[i]に対して、jビット目が立っているものを数える = ct rep(j,0,41){ int ct=0; rep(k,0,N){ if(A[k]&(1LL<<j))ct++; } // iビット目より上のビット if(j>i){ if(K&(1LL<<j))ans+=(1LL<<j)*(N-ct); else ans+=(1LL<<j)*ct; // iビット目 }else if(i==j){ ans+=(1LL<<j)*ct; // iビット目未満のビット }else { ans+=(1LL<<j)*max(ct,N-ct); } } res=max(res,ans); } cout<<res<<endl; }