hihoCoder1646: Rikka with String II

「爆ぜろリアル!弾けろシナプス!」

头图是贴来“激励” YZ 的。

传送门

题目大意

勇太定义:把一些 01 串加入一棵 $\text{Trie}$ 中,这棵 $\text{Trie}$ 的结点数即为这些字符串的价值。

勇太给了六花 $n$ 个 01? 串,其中 ? 既可以是 0 也可以是 1。显然如果一共有 $k$ 个 ?,那么这些字符串一共有 $2^k$ 种情况。

现在,勇太想让六花算出所有情况字符串的价值和,答案对 $998244353$ 取模。

其中 $n\le 20,|s|\le 50$

题解

首先,我们可以想到的是枚举每个 ? 是什么。。。好吧说实话,一开始我什么都没想到。

不过可以发现,这道题多半不会用到字符串相关的知识,像什么 $\text{sa}$ 或者 $\text{ac}$ 自动机多半不会出现。【经验之谈。

我们先考虑对于一些给定的字符串,价值该怎么求。【不要说建一棵 $\text{Trie}$!

也就是说,什么会让 $\text{Trie}$ 增加节点?

如果一个字符串是另一个字符串的前缀,那么它对答案贡献显然为 $0$。

那么如果一个字符串不是另一个字符串的前缀,它对答案的贡献是多少呢?

是不是 $|s|-|LCP|$?($LCP$:最长公共前缀)

但是 $LCP$ 太难算了,还要写 $sa$,不会,怎么办?

这时候我们再想一想,$\text{Trie}$ 上的一个节点到底是什么?

是一棵树上的节点。。。

是一个/些字符串的前缀,对吧?

那么问题来了:$\text{Trie}$ 上不同节点代表着什么?

是不是不同的前缀

所以需要算的是不同的前缀的数量之和。

这道题就从一道字符串变成了一道数论。

可以发现很难单独计算不同前缀的数量,所以想到用容斥

那么一个字符串的前缀数量是多少呢?

$|s|$?

其实是 $|s|\times 2^k$。

看吧,这样就把进阶的问题一起处理了。

接下来需要做的就是统计两个字符串的公共前缀,三个字符串……

那么这个时候,只需要枚举一下每一个字符串集合再计算就可以了。

每个前缀对答案的贡献为 $2$ 的自由 ? 数次方。注意如果这一位上集合中所有字符串都是 ?,那么只算一个自由 ?,因为这些 ? 都是一样的。

注意:$\text{Trie}$ 有根节点,答案加上 $2$ 的 ? 个数次方。【记得取模。。。

代码

/*"-Aria on the Planets-"*/
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double D;
typedef pair<int,int> Pii;
#define mp(a,b) make_pair(a,b)
#define fir first
#define sec second
inline int min(int _a,int _b){return _a<_b?_a:_b;}
inline int max(int _a,int _b){return _a>_b?_a:_b;}
template<class _T>inline void rd(_T &_a){int _f=0,_ch=getchar();_a=0;while(_ch<'0'||_ch>'9'){if(_ch=='-')_f=1;_ch=getchar();}while(_ch>='0'&&_ch<='9')_a=(_a<<1)+(_a<<3)+_ch-'0',_ch=getchar();if(_f)_a=-_a;}
const int inf=0x3f3f3f3f;
const D eps=1e-8;
const int mod=998244353,N=55;
int p[N*N],l[N];char s[N][N],*st[N];
int main(){
    //clock_t start=clock();
    //freopen("test.in","r",stdin);
    //freopen("test.out","w",stdout);
    int n,cnt=0,ans=0;rd(n);
    for(int i=0;i<n;i++){scanf("%s",s[i]);l[i]=strlen(s[i]);for(int j=0;j<l[i];j++)cnt+=s[i][j]=='?',s[i][j]=(s[i][j]!='?')*(s[i][j]-'0'+1);}
    for(int i=p[0]=1;i<=cnt;i++)(p[i]=p[i-1]<<1)%=mod;
    for(int i=1,t=1<<n;i<t;i++){
        int top=0,sum=cnt,tmp=0,len=inf;for(int j=0;j<n;j++)if(i>>j&1)st[top++]=s[j],len=min(len,l[j]);
        for(int j=0;j<len;j++){int fl=0,ct=0;for(int k=0;k<top&&(~fl);k++)if(!st[k][j])ct++;else if(!fl)fl=st[k][j];else if(fl!=st[k][j])fl=-1;if(fl==-1)break;(tmp+=p[(sum-=ct-!fl)])%=mod;}
        (ans+=top&1?tmp:mod-tmp)%=mod;
    }
    printf("%d",(ans+p[cnt])%mod);
    //printf("\n%dms",(int)((D)(clock()-start)/CLOCKS_PER_SEC*1000));
    return 0;
}