問題概要
2 つの文字列の最長共通接頭辞 (Longest Common Prefix; LCP) を求める関数を $\mathrm{ LCP }$ とする.英小文字からなるすべての文字列からなる集合を $\mathcal S$ として,関数 $f \mathpunct{:} \mathcal S \times \mathcal S \to \mathbb Z_{ \geq 0 }$ を
\begin{equation*}
f( x, y ) = | \mathrm{ LCP }( x, y ) |
\end{equation*}
で定義する.
英小文字からなる文字列 $n$ 個からなる列 $S = \langle S_1, S_2, \dots, S_n \rangle$ が与えられる.式
\begin{equation*}
\sum_{ i = 1 }^{ n - 1 }\sum_{ j = i + 1 }^{ n } f( S_i, S_j )
\end{equation*}
を求めよ.
制約
- $2 \leq n \leq 3 \times 10^5$
- $1 \leq | S_i |$
- $\sum_{ i = 1 }^{ n } | S_i | \leq 3 \times 10^5$
解法
制約から $\Omega( n^2 )$ 時間はかけられないので,工夫する必要があります.
式で与えられた範囲で許容される範囲の $i, j$ が固定されたとき,求めたい値への寄与はそのまま $| \mathrm{ LCP }( S_i, S_j ) |$ なわけですが,これを逐一計算する時間はありません.天下り的かもしれませんが,これを $1$ を $| \mathrm{ LCP }( S_i, S_j ) |$ 回足すと考えます.(最長とは限らない)共通接頭辞の個数とも言えます.
知識の問題になりますが Trie というデータ構造があって,Prefix Tree とも言います.これはその名の通り文字列集合のすべての接頭辞に対してデータを紐付けることができるデータ構造です.各接頭辞に対してそれを達成できる添字の個数を紐付ける Trie を実装することで,問題を解くことができます.具体的には,ある接頭辞をもつ文字列の個数を $k$ とすれば,その接頭辞による答えへの寄与は二項係数
\begin{equation*}
\binom k 2 = \frac{ k ( k - 1 ) }{ 2 }
\end{equation*}
と求まります.この値を Trie のすべてのノードに渡って足し上げることで,答えが得られます.
Trie への文字列の挿入はトータルで $\Theta\left( \sum\limits_{ i = 1 }^{ n } | S_i | \right)$ 時間かかり,必要な空間は $O\left( \sum\limits_{ i = 1 }^{ n } | S_i | \right)$ スペースです.答えの計算には Trie の全ノードをトラバースする必要がありますが,これは $O\left( \sum\limits_{ i = 1 }^{ n } | S_i | \right)$ 時間で行えます.よって,問題に正答することができます.
コード
#include <memory> struct Trie { struct Node { int num = 0; map< char, shared_ptr< Node > > children; Node() = default; void insert( const string &S, const int i = 0 ) { if ( SZ( S ) <= i ) { return; } auto &child = children[ S[i] ]; if ( !child ) { child = make_shared< Node >(); } ++child->num; return child->insert( S, i + 1 ); } LL sum() { LL res = LL( num ) * ( num - 1 ) / 2; for_each( ALL( children ), [&]( const auto &p ){ res += p.snd->sum(); } ); return res; } }; shared_ptr< Node > root = make_shared< Node >(); Trie() = default; void insert( const string &S ) { return root->insert( S ); } LL sum() { return root->sum(); } }; int main() { IN( int, N ); VS SS( N ); cin >> SS; Trie trie; for_each( ALL( SS ), [&]( const string &S ){ trie.insert( S ); } ); cout << trie.sum() << endl; return 0; }