以下の内容はhttps://natsugiri.hatenablog.com/より取得しました。


Advent of Code 2025 の感想

Advent of Code 2025

入力データをダウンロードして内容を目視するのも問題の一部である。コードを書く前に確認したほうがいい。
adventofcode.com

毎年12月にプログラミングの問題が出題される。今年は12日間で、毎日Part1 と Part2の2問。合計24問。

Uiua 0.17.2 ですべて解いた。

コード

github.com

傾向

問題文からは入力データのサイズが不明なのでデータをダウンロードしてから推定する必要がある。それはいいのだが、入力データの特徴を見破って効率的なアルゴリズムを書く必要があることに気が付かなければならない。問題文が示す内容を解くための一般的な解法ではなく、データに特化したコードを書く。

特に Day9 は多角形の2頂点を選んで長方形を探す問題であるが、辺が接近しないデータになっているので効率的に解くことができる。気が付かなかったので時間がかかるコードになってしまった。

さらにDay12 は以下の6種類のピースの個数が指定されて長方形領域に入りきるかどうか判定する問題だが、自明な枝刈りとして

  • 3x3のピースが入りきるならば、明らかにピースが入る
  • ピースのマスの個数が長方形領域より大きいならば、明らかに入らない

これらが考えられる。その中間の、どちらでもない場合がとても難しいが、そのようなデータは入力データに含まれていないので高速に解くことができる。

線形RMQ 分割幅34のSparseテーブル

参考

おそらく、以下の Codeforces 解説が最も簡単な線形な静的RMQ実装である。これを自分で使いやすいようにインデックスを逆順にして実装したい。
codeforces.com

分割の幅はO(log n)なら良いが、34の定数にできるのでそれで実装する。

幅33以下の場合

左閉じ右開きの区間(L, R) が R-L ≦ 33 の場合を答えるために小さいRMQを作る。
インデックス i から、連続33要素を見たときに、接頭最小値になっているインデックスに印をつける。

この i より後ろの32要素を32-bitの整数で保存する。このビット表現を全てのiについて求めたいが、スライド最小値のアルゴリズムによって線形時間で求めることができる(ただし、各ビット演算は定数時間でできるとする)。
33よりも小さい範囲でRMQを求めたい場合は、ビットマスクで外側の範囲を0で上書きする。

L, Rが共に34の倍数の場合

34要素ごとに最小値を取り出した数列を作るとこれの長さはn/34である。この数列で単純なSparseテーブルを構築するとL,Rが34の倍数の場合にRMQを答えることができる。
実装は定数34を用いているが、log nで分割されたSparseテーブルの構築の時間計算量・空間計算量はO(n)である。

一般のL, Rの場合

R-L ≦ 33の場合は前述の通り計算できる。
そうでない場合は

// 34x は L を34の倍数に切り上げた値。
int x = (L + 33) / 34;
// 34y は R を34の倍数に切り下げた値。
int y = R / 34;

最小値の候補は3つ

  • 区間(L, L+33) の最小値を小さいRMQで求める
  • 区間(34x, 34y) の最小値を34で分割したSparseテーブルで求める
  • 区間(R-33, R)の最小値を小さいRMQで求める

この3つのうち、最小値が区間(L, R)の最小値である。

コード

judge.yosupo.jp

template<class T, class Comp=less<T> > struct LinearRmq {
    static const int WIDTH = 34;
    static const int BIT_LEN = sizeof(unsigned) * CHAR_BIT;
    static_assert(BIT_LEN == 32);

    vector<T> a;
    vector<unsigned> mask;
    vector<int> table;
    Comp comp;

    LinearRmq(const vector<T> &a_): LinearRmq(a_.begin(), a_.end()) {}

    template<class Iter> LinearRmq(Iter begin, Iter end): a(begin, end), mask(a.size()) {
	int n = a.size();
	unsigned cur = 0;
	for (int i=n, b; i--;) {
	    while (cur && select(i, i + 1 + (b = __builtin_ctz(cur))) == i) { cur ^= 1u << b; }
	    mask[i] = cur;
	    cur = cur << 1 | 1;
	}
	if (int len = n / WIDTH) {
	    int levels = log2plus1(len);
	    table.resize(len * levels, -1);
	    for (int i=0; i<len; i++) { table[i] = select(i*WIDTH, small(i*WIDTH+1)); }
	    for (int s=0; (2<<s)<=len; s++) for (int i=0; i+(2<<s)<=len; i++) {
		table[len*(s+1)+i] = select(table[len*s+i], table[len*s+i+(1<<s)]);
	    }
	}
    }

    int find_index(int l, int r) const {
	assert(0 <= l && l < r && r <= (int)a.size());
	if (r - l < WIDTH) {
	    return l + log2plus1(mask[l] & bit_mask(r - l - 1));
	}
	int ans = small(l);
	int lx = (l + WIDTH - 1) / WIDTH, rx = r / WIDTH;
	if (lx != rx) {
	    int len = (int)a.size() / WIDTH;
	    int s = BIT_LEN - 1 - __builtin_clz(rx - lx);
	    ans = select(ans, select(table[len*s+lx], table[len*s+rx-(1<<s)]));
	}
	ans = select(ans, small(r - (WIDTH - 1)));
	return ans;
    }

    T find_value(int l, int r) const {
	return a[find_index(l, r)];
    }

    int select(int i, int j) const {
	return comp(a[i], a[j])? i: j;
    }

private:

    static inline constexpr unsigned bit_mask(int n) {
	return n? ~0u >> (BIT_LEN - n): 0;
    }

    // x == 0 => 0;
    // o.w.   => floor(log2(x)) + 1;
    static inline constexpr int log2plus1(unsigned x) {
	return x? BIT_LEN - __builtin_clz(x): 0;
    }

    // argmin_{i| l <= i < l+WIDTH} a[i];
    int small(int l) const {
	return l + log2plus1(mask[l]);
    }
};

注意

線形RMQは構成の計算量がスパーステーブルより優れているが、出力命令の中での比較をする回数が異なる。

  • 単純なスパーステーブル: 1回
  • この記事の線形RMQ: 3回

RMQを問題解決のパーツとして使う時に、出力が高速であるスパーステーブルの方が有利になりやすい。

セグツリーをパトリシアツリーで実装する

概要

固定長二進パトリシアツリー

目標

  • 初期化:長さ 2^64 の理想配列 a, 全ての要素は値型の最大値で初期化される
  • 更新 (key, value) : a[key] を value に書き換える
  • 質問 (left, right) :min(a[left], ..., a[right]) を出力する

keyが符号なし64bitであるような、大きなセグツリーを作る
トライで実装するとその頂点数は更新の回数と高さの積になってしまう。パトリシアツリーならば頂点数は更新の回数の2倍で抑えられる。
そのような長さが64固定の二進パトリシアツリーを作り、分枝ノードには min を乗せる。

実装

頂点
struct Node {
    Node *left;
    Node *right;
    u64 key; // minimum key;
    Seg seg;
};

left, right はそれぞれ子ノードへのポインタ、key は部分木の中の最小のkeyを持つ。seg は部分木のvalue の最小値を持つ。

left と right が共にNULLのとき、そのときに限り node は葉である。更新ではkeyが葉のkeyと比べて大きい・小さい・等しいの3通りで場合分けする。

分枝

部分木に含まれる key の最小値 left_key と、右の部分木の right_key について、異なる最大のビットを branch_bit と定義する。これはXORの最大set bitであり、floor(log2(left_key ^ right_key)) で求められるが、GCCでは次のように書ける。

int branch_bit = 63 - __builtin_clzll(left_key ^ right_key);
// C++ のバージョンが新しいなら 63 - std::countl_zero(left_key ^ right_key) でも良い

分枝ノードは、次の条件を満たすような構造にする「左の部分木に含まれる key は branch_bit が 0」「右の部分木に含まれる key は branch_bit が 1」。

この条件を常に満たすように insert 関数を実装する。新しい key と left_key と right_key を branch_bit で切り落とすと4通りの場合分けをすれば良い。

unsigned b = 63 - __builtin_clzll(node->key ^ node->right->key);
new_branch = key >> b;
left_branch = node->key >> b;
right_branch = node->right->key >> b;

if (new_branch < left_branch) {
    node = new Node(nullptr, node);
    insert_rec(node->left, key, seg);
} else if (new_branch == left_branch) {
    insert_rec(node->left, key, seg);
} else if (new_branch == right_branch) {
    insert_rec(node->right, key, seg);
} else {
    node = new Node(node, nullptr);
    insert_rec(node->right, key, seg);
}
node->key = node->left->key;
node->seg = node->left->seg + node->right->seg;

特徴

最大値を見積らなくてもトライと同等以下の高さになる。最悪ケースで簡単に高さ64になるが、これはトライも同じなので許容できるものとする。平衡二分木と比べると回転がないので実装は簡単になる可能性がある。
ビット演算 clz を要求しているのが小さな欠点。

コード

提出結果 judge.yosupo.jp

クラス全体

struct Seg {
    LL value;

    Seg(): value(0x3FFFFFFFFFFFFFFFLL) {}
    Seg(LL x): value(x) {}

    Seg operator+(const Seg &o) const {
	return Seg(min(value, o.value));
    }
    static const Seg IDENTITY;
};
const Seg Seg::IDENTITY = Seg();

template<class Seg> struct BinaryPatriciaTree {
    using u64 = unsigned long long;

    BinaryPatriciaTree() {}
    ~BinaryPatriciaTree() { destructor(root); }

    Seg get(u64 key) const {
	if (!root) { return Seg::IDENTITY; }
	Node *node = root;
	while (!node->is_leaf()) {
	    node = (key < node->right->key? node->left: node->right);
	}
	return node->key == key? node->seg: Seg::IDENTITY;
    }

    void insert(u64 key, const Seg &seg) {
	insert_rec(root, key, seg);
    }

    Seg sum(u64 left_key, u64 right_key_exclusive) const {
	if (right_key_exclusive <= left_key) { return Seg::IDENTITY; }
	return sum_rec(root, ~u64(0), left_key, right_key_exclusive - 1);
    }

    Seg sum_inclusive(u64 left_key, u64 right_key_inclusive) const {
	if (right_key_inclusive < left_key) { return Seg::IDENTITY; }
	return sum_rec(root, ~u64(0), left_key, right_key_inclusive);
    }

private:
    struct Node {
	Node *left = nullptr, *right = nullptr;
	u64 key; // minimum key;
	Seg seg;

	Node() {}
	Node(Node *left_, Node *right_): left(left_), right(right_) {}
	Node(u64 key_, const Seg &seg_): key(key_), seg(seg_) {}

	bool is_leaf() const {
	    // assert(bool(left) == bool(right));
	    return !left;
	}
    };

    Node *root = nullptr;

    static void destructor(Node *node) {
	if (node) {
	    destructor(node->left);
	    destructor(node->right);
	    delete node;
	}
    }

    static void insert_rec(Node *&node, u64 key, const Seg &seg) {
	if (!node) {
	    node = new Node(key, seg);
	} else if (node->is_leaf() && node->key == key) {
	    node->seg = seg;
	} else {
	    u64 new_branch, left_branch, right_branch;
	    if (node->is_leaf()) {
		new_branch = key;
		left_branch = right_branch = node->key;
	    } else {
		unsigned b = 63 - __builtin_clzll(node->key ^ node->right->key);
		new_branch = key >> b;
		left_branch = node->key >> b;
		right_branch = node->right->key >> b;
	    }
	    if (new_branch < left_branch) {
		node = new Node(nullptr, node);
		insert_rec(node->left, key, seg);
	    } else if (new_branch == left_branch) {
		insert_rec(node->left, key, seg);
	    } else if (new_branch == right_branch) {
		insert_rec(node->right, key, seg);
	    } else {
		node = new Node(node, nullptr);
		insert_rec(node->right, key, seg);
	    }
	    node->key = node->left->key;
	    node->seg = node->left->seg + node->right->seg;
	}
    }

    static Seg sum_rec(const Node *node, u64 node_bound, u64 left_key, u64 right_key_inclusive) {
	if (!node || right_key_inclusive < node->key || node_bound < left_key) {
	    return Seg::IDENTITY;
	} else if (left_key <= node->key && node_bound <= right_key_inclusive) {
	    return node->seg;
	} else if (node->is_leaf()) {
	    return left_key <= node->key && node->key <= right_key_inclusive? node->seg: Seg::IDENTITY;
	} else {
	    return sum_rec(node->left, node->right->key - 1, left_key, right_key_inclusive)
		+ sum_rec(node->right, node_bound, left_key, right_key_inclusive);
	}
    }
};

Good Bye 2024 問 I1, I2 Affectionate Arrays 解法

始めに

I1 は最短の長さを答えるEasy版。I2 はそれが何通りできるか数え上げるHard版。
codeforces.com

数列 a が与えられる。a は以下の条件を満たしていることが保証されている。

  • 任意の i について、 |a_i| \le \sum a

次のような数列 b を作りたい。

  • b は部分列として a を含む(連続でなくても良い)
  •  \sum a = \sum b
  • b における任意の連続部分列の和は \sum b 以下である。つまり  \forall{(l, r)} \{ \sum_{i=l}^r b_i \le \sum b\}

I1: b の最短の長さを求めよ。
I2: 最短の長さのbは何通りあるか998244353で割った余りを求めよ。(同じ b でも異なるインデックスの選び方の部分列でaが存在するなら、異なるものとしてカウントする)

解法

連続部分列の和が \sum a となるような b ができる。さらに b のprefix sum は 0 以上  \sum a 以下であることが分かる。
動的計画法を考えて、その特徴からスピードアップできることが分かる。

dp(i, s, p) := 
- a の接頭のi要素を使って bの接頭を作っていて
- bの接頭の和はsで
- pが0ならi要素を見た直後、pが1なら追加で新しい要素を入れるか検討した後
- そのときの最短の長さ

s < 0 または  \sum a < s の場合は最短は存在しないので無限大として良い。
dp(i, s, 0) から dp(i, s', 1) の遷移を考えると、末尾に適切な1要素を追加することで dp(i, s', 1) は dp(i, *, 0)の最小値プラス1以下に出来る。さらにdp(i, *, 1) の最小値の範囲は連続している。

解 I1: このdpテーブルのdp(n, sumA, 1)が求めるべき解であるが、求めるには途中の最小値の値とその範囲だけ計算すれば良い。

解 I2: このdpテーブルで最短を計算していたが(最短, 場合の数) を計算することにする。dp(i, *, 1) を Run-length で持てば良い。区間に加算する必要があるので差分を持つことにして回避する。

======

コード

// I1
#include<stdio.h>
#include<iostream>
#include<vector>
#include<algorithm>
#include<string>
#include<string.h>

#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#else
#define NDEBUG
#define eprintf(...) do {} while (0)
#endif
#include<cassert>

using namespace std;

typedef long long LL;
typedef vector<int> VI;

#define REP(i,n) for(int i=0, i##_len=(n); i<i##_len; ++i)
#define EACH(i,c) for(__typeof((c).begin()) i=(c).begin(),i##_end=(c).end();i!=i##_end;++i)

template<class T> inline void amin(T &x, const T &y) { if (y<x) x=y; }
template<class T> inline void amax(T &x, const T &y) { if (x<y) x=y; }
#define rprintf(fmt, begin, end) do { const auto end_rp = (end); auto it_rp = (begin); for (bool sp_rp=0; it_rp!=end_rp; ++it_rp) { if (sp_rp) putchar(' '); else sp_rp = true; printf(fmt, *it_rp); } putchar('\n'); } while(0)

int N;
int A[3000011];

void MAIN() {
    scanf("%d", &N);
    REP (i, N) scanf("%d", A+i);

    LL target = 0;
    REP (i, N) target += A[i];

    LL l = 0, r = 0;
    int cost = 0;
    REP (i, N) {
	l += A[i];
	r += A[i];
	if (l <= target && 0 <= r) {
	    amax(l, 0LL);
	    amin(r, target);
	} else if (r < 0) {
	    l = 0;
	    r = target + A[i];
	    cost++;
	} else {
	    assert(target < l);
	    l = A[i];
	    r = target;
	    cost++;
	}
    }
    cost += N;
    if (r < target) {
	cost++;
    }
    printf("%d\n", cost);
}

int main() {
    int TC = 1;
    scanf("%d", &TC);
    REP (tc, TC) MAIN();
    return 0;
}
// I2
#include<stdio.h>
#include<iostream>
#include<vector>
#include<algorithm>
#include<string>
#include<string.h>
#include<queue>

#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#else
#define NDEBUG
#define eprintf(...) do {} while (0)
#endif
#include<cassert>

using namespace std;

typedef long long LL;
typedef vector<int> VI;

#define REP(i,n) for(int i=0, i##_len=(n); i<i##_len; ++i)
#define EACH(i,c) for(__typeof((c).begin()) i=(c).begin(),i##_end=(c).end();i!=i##_end;++i)

template<class T> inline void amin(T &x, const T &y) { if (y<x) x=y; }
template<class T> inline void amax(T &x, const T &y) { if (x<y) x=y; }
#define rprintf(fmt, begin, end) do { const auto end_rp = (end); auto it_rp = (begin); for (bool sp_rp=0; it_rp!=end_rp; ++it_rp) { if (sp_rp) putchar(' '); else sp_rp = true; printf(fmt, *it_rp); } putchar('\n'); } while(0)
template<unsigned MOD_> struct ModInt {
    static constexpr unsigned MOD = MOD_;
    unsigned x;
    void undef() { x = (unsigned)-1; }
    bool isnan() const { return x == (unsigned)-1; }
    inline int geti() const { return (int)x; }
    ModInt() { x = 0; }
    ModInt(int y) { if (y<0 || (int)MOD<=y) y %= (int)MOD; if (y<0) y += MOD; x=y; }
    ModInt(unsigned y) { if (MOD<=y) x = y % MOD; else x = y; }
    ModInt(long long y) { if (y<0 || MOD<=y) y %= MOD; if (y<0) y += MOD; x=y; }
    ModInt(unsigned long long y) { if (MOD<=y) x = y % MOD; else x = y; }
    ModInt &operator+=(const ModInt y) { if ((x += y.x) >= MOD) x -= MOD; return *this; }
    ModInt &operator-=(const ModInt y) { if ((x -= y.x) & (1u<<31)) x += MOD; return *this; }
    ModInt &operator*=(const ModInt y) { x = (unsigned long long)x * y.x % MOD; return *this; }
    ModInt &operator/=(const ModInt y) { x = (unsigned long long)x * y.inv().x % MOD; return *this; }
    ModInt operator-() const { return (x ? MOD-x: 0); }

    ModInt inv() const { return pow(MOD-2); }
    ModInt pow(long long y) const {
	ModInt b = *this, r = 1;
	if (y < 0) { b = b.inv(); y = -y; }
	for (; y; y>>=1) {
	    if (y&1) r *= b;
	    b *= b;
	}
	return r;
    }

    friend ModInt operator+(ModInt x, const ModInt y) { return x += y; }
    friend ModInt operator-(ModInt x, const ModInt y) { return x -= y; }
    friend ModInt operator*(ModInt x, const ModInt y) { return x *= y; }
    friend ModInt operator/(ModInt x, const ModInt y) { return x *= y.inv(); }
    friend bool operator<(const ModInt x, const ModInt y) { return x.x < y.x; }
    friend bool operator==(const ModInt x, const ModInt y) { return x.x == y.x; }
    friend bool operator!=(const ModInt x, const ModInt y) { return x.x != y.x; }
};

constexpr LL MOD = 998244353;
using Mint = ModInt<MOD>;

struct Elem {
    LL len;
    int cost;
    Mint value;
    Elem() {}
    Elem(LL len_, int cost_, Mint value_): len(len_), cost(cost_), value(value_) {}
};
deque<Elem> dq;
Mint C[3000011];
Mint offset[3000011];

void right_delete(LL n) {
    while (n) {
	int cost = dq.back().cost;
	Mint value = dq.back().value;
	LL g = min(n, dq.back().len);
	C[cost] -= g * (value + offset[cost]);

	dq.back().len -= g;
	n -= g;

	if (dq.back().len == 0) {
	    dq.pop_back();
	}
    }
}

void left_delete(LL n) {
    while (n) {
	int cost = dq.front().cost;
	Mint value = dq.front().value;
	LL g = min(n, dq.front().len);
	C[cost] -= g * (value + offset[cost]);

	dq.front().len -= g;
	n -= g;

	if (dq.front().len == 0) {
	    dq.pop_front();
	}
    }
}


int N;
int A[3000011];

void MAIN() {
    dq.clear();
    scanf("%d", &N);
    REP (i, N) scanf("%d", A+i);

    LL target = 0;
    REP (i, N) target += A[i];

    LL l = 0, r = 0;
    int cost = 0;
    dq.emplace_back(1, 0, Mint(1) - offset[0]);
    C[0] = 1;
    if (target) {
	dq.emplace_back(target, 1, Mint(1) - offset[1]);
	C[1] = target;
    }

    REP (i, N) {
	if (A[i] > 0) {
	    LL a = A[i];
	    right_delete(a);
	    if (target < l + a) {
		assert(C[cost] == 0);
		cost++;
		C[cost+1] = 0;
		l = a;
		r = target;
	    } else {
		l += a;
		r = min(r + a, target);
	    }
	    dq.emplace_front(a, cost + 1, -offset[cost + 1]);
	}
	if (A[i] < 0) {
	    LL a = -A[i];
	    left_delete(a);
	    if (r - a < 0) {
		cost++;
		C[cost+1] = 0;
		l = 0;
		r = target - a;
	    } else {
		r -= a;
		l = max(0LL, l - a);
	    }
	    dq.emplace_back(a, cost + 1, -offset[cost + 1]);
	}

	offset[cost + 1] += C[cost];
	C[cost + 1] += C[cost] * ((target + 1) - (r - l + 1));
    }

    Mint ans = dq.back().value + offset[dq.back().cost];
    printf("%d\n", ans.geti());
}

int main() {
    int TC = 1;
    scanf("%d", &TC);
    REP (tc, TC) MAIN();
    return 0;
}

JavaのModIntを考える

序論

プログラミングコンテストでは「答えの数を1000000007で割った余りを求めよ」という問題が大変よく出題される。1000000007は素数で、ほかにも 998244353 が代わりに出題されることもよくある。これらは真の解は大きくなりすぎて計算できない場合でも、加減乗算をするたびにその素数で割った余りを求めてあげれば解が変わらない。

ModInt

ソースコード中では問題で与えられた大きな素数は定数変数の mod と書くことにする。
a * (b + c) を求めたいときに、演算の度に余りを求めると

ans = a * ((b + c) % mod) % mod;

複雑な式になると %mod を書き忘れてオーバーフローで誤答になる危険があるので何とか記法を工夫したい。このモチベーションはとても大きい。

課題点

演算子オーバーロードがない

+ - * などの二項演算子オーバーロードできるプログラミング言語も存在するが Java はできない。そのためどう工夫しても数式的な中置記法をあきらめなければならない。

new のコスト

(a + b) % mod を求めたいだけなのに new で Object を作るのは実行時間に影響があるかもしれないため、new を使う場合はそのオーバーヘッドを知っておきたい。

結果

適当に加算と乗算を1億回するコードを、mod の行を工夫して実行時間を測る。
テストコードはこれ
ModIntTest · GitHub

for (int i = 0; i < a.length; i++) {
    for (int j = 0; j < b.length; j++) {
        ans = (ans * ((a[i] + b[j]) % mod)) % mod;
    }
}

単位はms。

Test AtCoder Codeforces Ideone
1 simple 439 1898 406
2 non final 803 1978 2053
3 function 444 1944 661
4 function2 486 2048 416
5 mint 582 3843 1317
6 mutable 560 2519 700

なぜか Codeforcesが非常に遅い。Codeforcesはそもそも除算・剰余算が驚くほど遅いため、記法の工夫の以前の問題として、演算回数をモンゴメリ乗算などで減らす方がいい。
AtCoder は new が予想よりも早い。正直、どの実装でもいい。
Ideone が最も予想に近い比率の時間になる。関数にすれば遅くなるし、newを使う回数だけもっと遅くなる。

テスト

Test1 simple

for (int i = 0; i < a.length; i++) {
    for (int j = 0; j < b.length; j++) {
        ans = (ans * ((a[i] + b[j]) % mod)) % mod;
    }
}

modを全て手打ちする。typoをしない限りは簡単かつ高速。
加算は mod の2倍を超えないので「剰余算」の代わりに「条件分岐と減算」でもできるが AtCoder では実行時間はほぼ変わらない。遅くなることも有る。

Test2 non final

// final を書き忘れる
long mod = 1000000007; // NG

final にするだけで速くなるので final は付けよう。もしくは

long mod() { return 1000000007; } // GOOD

のような定数関数にしても高速。オーバーライドしても高速なので 998244353 と共存させることもできる。

Test3 functions

ans = mul(ans, add(a[i], b[j]));

ただしstatic関数を定義する。

class ModIntStatic {
    static final long mod = 1000000007;

    static long add(long x, long y) {
        return (x + y) % mod;
    }

    static long mul(long x, long y) {
        return x * y % mod;
    }
}

関数記法になってしまうが実装が軽く実行速度が速く、可読性もよい。long のままなので関数を通し忘れるとオーバーフローの危険があるが、% mod を書き忘れるよりは防ぎやすい気がする。

Test4 function2

interface ModInt {
    public long mod();

    class ModInt998244353 implements ModInt { public long mod() { return 998244353; } }
    class ModInt1000000007 implements ModInt { public long mod() { return 1000000007; } }

    default long add(long x, long y) {
        return (x + y) % mod();
    }

    default long mul(long x, long y) {
        return x * y % mod();
    }
}

static ではなくなったが同様に高速。オーバーライドで複数の素数に対応できる。

Test5 Mint

ans = ans.mul(a[i].add(b[j]));

ただし、ans, a[i], b[j] は Mint オブジェクト。

class Mint {
    static final long mod = 1000000007;

    final long x;
    Mint(long x) { this.x = x; }

    Mint add(Mint y) {
        return new Mint((x + y.x) % mod);
    }

    Mint mul(Mint y) {
        return new Mint(x * y.x % mod);
    }
}

メソッドにしたので語順が中置記法になった。クラス同士でしか演算できないため、最も安全だと思われる。AtCoder は new がなぜか高速なのでこれも十分良いが、ほかのコンテストサイトでは遅いかもしれない。BigInteger と同じ API にできるのも良い。

Test6 Mutable Mint

tmp.assign(a[i]).addAssign(b[j]);
ans.mulAssign(tmp);

ただし ans, tmp, a[i], b[j] は Mint オブジェクト。

class Mint {
    static final long mod = 1000000007;

    long x;

    Mint(long x) { this.x = x; }

    Mint assign(Mint y) {
        x = y.x;
        return this;
    }

    Mint addAssign(Mint y) {
        x = (x + y.x) % mod;
        return this;
    }

    Mint mulAssign(Mint y) {
        x = x * y.x % mod;
        return this;
    }
}

new がなくなったのでどんな環境でも速いはず。ただし、数式が複雑になるだけ tmp 変数を用意しなければならない。また、変更可能なオブジェクトは関数の引数に渡したときに関数内で書き換えられることを心配する必要があるので、慎重に使わなければならない。

プリミティブタイプはJavaのジェネリクスに入らない

int, long, doubleなどのプリミティブな型を使うことは当然だ。これらに対して同じデータ構造やアルゴリズムを適用したいことがあるのも当然だ。
ジェネリクスがあるのだからint, long, doubleで共通な計算は実装できそうなものだが、できない。

例えば配列の和を求めるクラスを作りたいと思う。

// error
class Sum<T> {
    T sum(T[] a) {
        T ret = 0;
        for (int i = 0; i < a.length; i++) {
            ret += a[i];
        }
        return ret;
    }
}

困難① ジェネリクスにはプリミティブ型を入れられない

そのため `new Sum<long>().sum(array)` のように書けない
では long の代わりにラッパーである Long を使えばいいかと言うとそうではない。

困難② ジェネリクス型は `new` が使えない、`+=` が使えない

Integer型やLong型では演算子が使えるが、ジェネリクスになると利用できない。Javaジェネリクスはそういうもの。

解決

そんな不便なことがあるだろうか?じゃあ`Arrays.sort`はどう実装してるんだろう?JDKソースコードを見てみる。

public static void sort(int[] a)

プリミティブ型の`sort`や`binarySearch`などはそれぞれの型に対して全く同じ実装が(型だけ変えて)繰り返し記述されている。これをマネすれば良さそう。

プログラミングコンテストで使う場合はどうせソースコードをコピー&ペーストする。よって long 型だけ用意しておいて double 型が必要ならエディタで置換するのが最も良い方法だと思う。

疑似乱数メルセンヌツイスタは連続 624 項が分かれば完全に再現できる

標準ライブラリのメルセンヌツイスタは実装と定数が公開されている。すると、出力された疑似乱数列から内部状態をコピーすることができる。

メルセンヌツイスタの実装

英語版の wikipedia の実装を見る。
en.wikipedia.org

定数

標準ライブラリは周期が 2^19937 - 1 になるような定数があらかじめ設定されている。

関数

配列 unsigned mt[624] は長い周期で変化するカウンタの役割をする。mt[624] は実質リングバッファで、index でずれていく。
wikipedia の実装では「初めて乱数を生成するまで初期化しない」「twistは 624 回に 1 回だけまとめて行う」ような効率化がされていた。

乱数を1つ生成する関数の中で tempering (調律)をしている

unsigned y = mt[index];
y ^= (y >> U) & D;
y ^= (y << S) & B;
y ^= (y << T) & C;
y ^= (y >> L);

最終的にこの y を出力するので、tempering の逆関数を作ると配列 mt をコピーできるはず。

逆関数

定数に依存すれば結構簡単な逆関数ができる。

static unsigned inverse_tempering(unsigned y) {
    // inverse y ^= y >> L;
    y ^= y >> L;

    // inverse y ^= (y << T) & C; 
    y ^= (y << T) & C;

    // inverse y ^= (y << S) & B;
    y ^= (y << S) & 0x00001680;
    y ^= (y << S) & 0x000C4000;
    y ^= (y << S) & 0x0D200000;
    y ^= (y << S) & 0x90000000;

    // inverse y ^= (y >> U) & D;
    y ^= (y >> U) & 0x001FFC00;
    y ^= (y >> U) & 0x000003FF;

    return y;
}

コード全体

struct Random {
private:
    static const int W = 32; // word size;
    static const int N = 624; // degree of recurrence;
    static const int M = 397; // shift size;
    static const int R = 31; // mask bits;
    static const unsigned A = 0x9908B0DF; // XOR mask;
    
    // tempering;
    static const int U = 11;
    static const unsigned D = 0xFFFFFFFF;
    static const int S = 7;
    static const unsigned B = 0x9D2C5680;
    static const int T = 15;
    static const unsigned C = 0xEFC60000;
    static const int L = 18;

    static const unsigned F = 1812433253; // initialization;

    static const unsigned LOWER_MASK = (1u << R) - 1u;
    static const unsigned UPPER_MASK = ~0u ^ LOWER_MASK;

    static const unsigned DEFAULT_SEED = 5489;

    unsigned mt[N];
    int index;

    void twist(int i) {
	unsigned x = (mt[i] & UPPER_MASK) | (mt[(i + 1) % N] & LOWER_MASK);
	unsigned xA = x >> 1;
	if (x % 2 != 0) {
	    xA ^= A;
	}
	mt[i] = mt[(i + M) % N] ^ xA;
    }

    void initialize(unsigned seed) {
	mt[0] = seed;
	for (int i=1; i<N; i++) {
	    mt[i] = F * (mt[i-1] ^ (mt[i-1] >> (W - 2))) + i;
	}
	for (int i=0; i<N; i++) {
	    twist(i);
	}
	index = 0;
    }

    unsigned extract_number() {
	if (index >= N) {
	    index = 0;
	}

	unsigned y = mt[index];
	y ^= (y >> U) & D;
	y ^= (y << S) & B;
	y ^= (y << T) & C;
	y ^= (y >> L);

	twist(index);

	index++;
	return y;
    }

    static unsigned inverse_tempering(unsigned y) {
	// inverse y ^= y >> L;
	y ^= y >> L;

	// inverse y ^= (y << T) & C; 
	y ^= (y << T) & C;

	// inverse y ^= (y << S) & B;
	y ^= (y << S) & 0x00001680;
	y ^= (y << S) & 0x000C4000;
	y ^= (y << S) & 0x0D200000;
	y ^= (y << S) & 0x90000000;

	// inverse y ^= (y >> U) & D;
	y ^= (y >> U) & 0x001FFC00;
	y ^= (y >> U) & 0x000003FF;

	return y;
    }

public:
    Random(int seed=DEFAULT_SEED) {
	initialize(seed);
    }

    Random(const vector<unsigned> &prv) {
	assert(prv.size() == N);
	for (int i=0; i<N; i++) {
	    mt[i] = inverse_tempering(prv[i]);
	}
	for (int i=0; i<N; i++) {
	    twist(i);
	}
	index = 0;
    }

    unsigned operator()() {
	return extract_number();
    }
};

void MAIN() {
    mt19937 engine(100);

    vector<unsigned> prv;
    REP (i, 624) {
	prv.push_back(engine());
    }
    Random random(prv);

    assert(engine() == random());
}



以上の内容はhttps://natsugiri.hatenablog.com/より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14