https://atcoder.jp/contests/abc445/tasks/abc445_g
たまには解説を見ようかと思って開いたら「二部グラフ」をいう単語が見えてすぐに分かりました。これは二部グラフになるのですね。例えば、入力例1のようにA=1, B=2なら、一歩動くとi+jの偶奇が変わります。A=2, B=6のようにどちらも2の個数が同じなら、その数をeとして、iだけ見て一歩動くと2e+1で割った余りが2eだけ変わるので、余りが2eより小さいときとそうでないときで二つに分けられます。A=0のとき(このときB=0だと問題の意味がよくわかりませんが、そういうことはないとして)x+yを2Aで割った余りを同様に見ます。本当はグラフを連結成分に分けることができることがありますが(例えば、A=B=1なら、i+jが偶数と奇数のグラフに分けられます)、間に合うので分けなくてもいいことにします。
そうすると、最大マッチングの問題をHopcroft–Karp法で解いて、そこからすぐに最大独立集合が求められます。これがコマの配置です。
Pythonで書きましたが、Rustに書き換える余裕が無いのでそのままでいいことにします。
from __future__ import annotations # coding: utf-8 from itertools import * from collections import defaultdict from math import gcd from typing import Tuple #################### library #################### def read_int() -> int: return int(input()) def read_tuple() -> tuple[int, ...]: return tuple(map(int, input().split())) def read_list() -> list[int]: return list(map(int, input().split())) def YesNo(b: bool) -> str: return 'Yes' if b else 'No' def div_pow(n: int, d: int) -> tuple[int, int]: e = 0 while n % d == 0: e += 1 n //= d return (e, n) #################### BiPartiteGraph #################### from collections import deque class HopcroftKarp: def __init__(self, n_left, n_right): self.n = n_left self.m = n_right self.graph = [[] for _ in range(n_left)] self.pair_u = [-1] * n_left self.pair_v = [-1] * n_right self.dist = [0] * n_left def add_edge(self, u, v): self.graph[u].append(v) def bfs(self): queue = deque() for u in range(self.n): if self.pair_u[u] == -1: self.dist[u] = 0 queue.append(u) else: self.dist[u] = -1 found = False while queue: u = queue.popleft() for v in self.graph[u]: pu = self.pair_v[v] if pu != -1 and self.dist[pu] == -1: self.dist[pu] = self.dist[u] + 1 queue.append(pu) if pu == -1: found = True return found def dfs(self, u): for v in self.graph[u]: pu = self.pair_v[v] if pu == -1 or (self.dist[pu] == self.dist[u] + 1 and self.dfs(pu)): self.pair_u[u] = v self.pair_v[v] = u return True self.dist[u] = -1 return False def max_matching(self) -> list[Edge]: result = 0 while self.bfs(): for u in range(self.n): if self.pair_u[u] == -1: if self.dfs(u): result += 1 return [ (u, v) for u, v in enumerate(self.pair_u) if v != -1 ] #################### BiPartiteGraph #################### Point = Tuple[int, int] Node = int Edge = Tuple[Node, Node] class BiPartiteGraph: def __init__(self, edges1: list[list[Node]], edges2: list[list[Node]], pts1: list[Point], pts2: list[Point]): self.edges1: list[list[Node]] = edges1 self.edges2: list[list[Node]] = edges2 self.pts1: list[Point] = pts1 self.pts2: list[Point] = pts2 def n1(self) -> int: return len(self.edges1) def n2(self) -> int: return len(self.edges2) def solve(self) -> list[Point]: # 最大マッチ hk = HopcroftKarp(len(self.edges1), self.n2()) for pt1, pts2 in enumerate(self.edges1): for pt2 in pts2: hk.add_edge(pt1, pt2) matches = hk.max_matching() set_matches = set(matches) Z1 = set(range(self.n1())) - set(u for u, v in matches) Z2 = set() s1 = set(Z1) while s1: s2 = set() for u in s1: for v in self.edges1[u]: if v not in Z2 and (u, v) not in set_matches: s2.add(v) Z2 |= s2 s1 = set() for v in s2: for u in self.edges2[v]: if u not in Z1 and (u, v) in set_matches: s1.add(u) Z1 |= s1 cZ2 = set(range(self.n2())) - Z2 return [ self.pts1[u] for u in Z1 ] + [ self.pts2[v] for v in cZ2 ] @staticmethod def create_each(pts1: list[Point], pts2: list[Point], A: int, B: int) -> list[list[Node]]: return v @staticmethod def create(points: list[Point], A: int, B: int) -> BiPartiteGraph: if A == 0: pts1 = [ (x, y) for x, y in points if (x+y)%(2*B) < B ] pts2 = [ (x, y) for x, y in points if (x+y)%(2*B) >= B ] else: e1, _ = div_pow(A, 2) e2, _ = div_pow(B, 2) b1 = 1 << min(e1, e2) if e1 == e2: pts1 = [ (x, y) for x, y in points if y//b1%2 == 0 ] pts2 = [ (x, y) for x, y in points if y//b1%2 == 1 ] else: pts1 = [ (x, y) for x, y in points if (x+y)%(2*b1) < b1 ] pts2 = [ (x, y) for x, y in points if (x+y)%(2*b1) >= b1 ] ids2 = { pt: i for i, pt in enumerate(pts2) } edges1: list[list[Node]] = [ [] for _ in pts1 ] edges2: list[list[Node]] = [ [] for _ in pts2 ] for id1, pt1 in enumerate(pts1): x, y = pt1 for pt2 in [(x+A, y+B), (x-A, y+B), (x-A, y-B), (x+A, y-B), (x+B, y+A), (x-B, y+A), (x-B, y-A), (x+B, y-A)]: id2 = ids2.get(pt2, -1) if id2 != -1: edges1[id1].append(id2) edges2[id2].append(id1) return BiPartiteGraph(edges1, edges2, pts1, pts2) #################### Table #################### class Table: def __init__(self, table: list[list[char]]): self.table: list[list[char]] = table def create_bipartite_graph(self, A: int, B: int) -> BiPartiteGraph: N = len(self.table) pts: list[Point] = [] for x, y in product(range(N), repeat=2): if self.table[x][y] == '.': pts.append((x, y)) return BiPartiteGraph.create(pts, A, B) def set_o(self, pt: Point) -> None: x, y = pt self.table[x][y] = 'o' def print(self) -> None: for v in self.table: print(''.join(v)) @staticmethod def read(N: int) -> Table: S = [ input() for _ in range(N) ] table = [ [ c for c in s ] for s in S ] return Table(table) #################### process #################### def read_input() -> tuple[int, int, int]: (N, A, B) = read_tuple() return (N, A, B) def F(N: int, A: int, B: int) -> int: table = Table.read(N) graph = table.create_bipartite_graph(A, B) pts = graph.solve() for pt in pts: table.set_o(pt) table.print() #################### main #################### (N, A, B) = read_input() F(N, A, B)