こんにちは。そういえばハッシュ関数ってしょっちゅう使われているのに、自分で構成を追ったことがないことに気づきました。そこで今回の記事では、ハッシュ関数として代表的な SHA-256 について、ネットで転がっている記事を読みつつ、書き慣れている Python と Rust で実装していきたいと思います。
注意:この記事に書かれているプログラムは、私が勉強用に書いた質の低いものです。実際に SHA-256 を用いる際は、適切なライブラリの関数を利用してください。その方が安全であり、かつ高いパフォーマンスが得られるはずです。
ハッシュ関数とは
ハッシュ関数はその名の通り関数です。関数なので入力と出力があります。入力は任意の文字列で、出力は固定長の文字列になります。ハッシュ関数の入力のことをメッセージ、出力のことをハッシュ値と呼ぶようです。
今回取り扱う SHA-256 は暗号学的ハッシュ関数と呼ばれるものの一つのようです。暗号学的ハッシュ関数は、以下のような性質を持つことが求められます。
- ハッシュ値から、そのようなハッシュ値となるメッセージを得ることが(事実上)不可能であること(原像計算困難性、弱衝突耐性)。
- 同じハッシュ値となる、異なる2つのメッセージのペアを求めることが(事実上)不可能であること(強衝突耐性)。
- メッセージをほんの少し変えたとき、ハッシュ値は大幅に変わり、元のメッセージのハッシュ値とは相関がないように見えること。
(暗号学的ハッシュ関数 - Wikipedia より引用)
このような性質を持つ関数があるおかげで、例えばパスワードそのものを保存する代わりに、パスワードをハッシュ値にしたものを保存することで、安全性を高めることができます。この記事の目的はハッシュ関数を詳しく解説することではないので、用途を詳しく知りたい方は調べてみてください。
参考にしたページ
概ね二つのサイトを参考にしました。
まずこちらは TypeScript で実際に実装してみた、という記事です。この記事を見て自分も同じような記事を書いてみたくなったのがそもそもの発端です。
zenn.dev
こちらは図や数式を使ってフローを分かりやすく説明しているサイトです。先ほどのサイトは具体的な記述がメインのため大きな流れが分かりづらい部分があったので、こちらのサイトで補完する形になりました。逆に細かい部分は先ほどのサイトに助けてもらいました。
www.ios-net.co.jp
フローを整理
SHA-256 の全体の流れは以下の通りです:
- 入力データの末尾に数ビット付け足して、全体の長さが (512 の倍数) bit になるようにする(パディング)
- パディング済みのメッセージを 512 bit ごとに分割する
- 各 512 bit ごとに 32 bit の "ワード" 64 個 (
) を生成する
- 64個のワードから定まる変換ルール(256 bit → 256 bit)を 512 bit の塊の数だけ作り、それらをある初期値に順に作用させる
- 最終的にできた 256 bit の数値を64桁の16進法の文字列として出力する
そこで以下の機能を実装すればハッシュ関数を完成させることができます。
- 入力データをパディングして 512bit ごとの数値の列にする関数
- 512 bit の入力に対して
を作る関数
に対して変換規則を作る関数
- 変換規則を初期値に順に適用する関数
これらを順番に作っていきます。
ところで、ハッシュ関数は文字列を文字列にする関数でしたが、上記の処理は純粋に数値に対する関数として定義したほうが見通しがよさそうです。そこで、文字列は最初と最後だけ登場するようにして、中間では常に整数を扱うように実装することにします。
各機能の実装
入力データをパディングして 512bit ごとの数値の列にする関数
先ほどの説明では省いていましたが、実はパディング後の文字列は長さが512の倍数になる他にも、以下のルールを満たす必要があります。
- 一番末尾の8バイトは入力データの文字数を格納する
- 入力データが終わった直後は
を付ける
- 隙間は0で埋める
これらを満たすもののうち最も短いものを求める関数を作る必要があります。
なお、全角文字を取り扱う場合は、どの文字コードを使うかで結果が変わります。この記事では話を単純にするため、メッセージは常に半角の英数字+記号で構成されているもののみを扱うことにします。つまり、一文字は常に1バイトに対応します。
実際の流れは以下のようにします。
- 入力の文字列を文字ごとに分解し、それぞれを 8 bit の整数に変換したものを並べる。Python は list, Rust は Vec にする。
- ↑の末尾に
を付け足す。
- 後ろに 8 個以上の 0 を付け足して、長さが 64 (=512 / 8) の倍数になるようにする。
- 末尾の 8 個の部分を 64bit 整数と見て、入力の文字列の bit 数を格納する。
- 64 個ずつの塊を並べたものを返す
# Python def pad(msg: str) -> [[int]]: def _ceil_64(n: int) -> int: """n 以上の最小の64の倍数を返す""" return (n + 63) // 64 * 64 bytes_list: list[int] = list(msg.encode()) # 1 バイトずつの数値のリストを作る total_bits = len(bytes_list) * 8 # 最後に末尾に格納する # 0x80 と 0x00 で埋めて、長さを 64 の倍数にする bytes_list.append(0x80) padded_len = _ceil_64(len(bytes_list) + 8) bytes_list.extend([0] * (padded_len - len(bytes_list))) # 末尾に入力文字列のビット数を格納する for i in range(8): digit = (total_bits >> (8 * i)) & 0xff # 下から i バイト目の値 bytes_list[-1 - i] = digit # 64 バイトずつのブロックに分割する blocks = [ bytes_list[idx:idx + 64] for idx in range(0, len(bytes_list), 64) ] return blocks
// Rust fn pad(msg: &str) -> Vec<[u8; 64]> { let mut bytes_vec = msg.as_bytes().to_vec(); // 1 バイトずつの数値のリストを作る let total_bits = bytes_vec.len() * 8; // 最後に末尾に格納する // 0x80 と 0x00 で埋めて、長さを 64 の倍数にする bytes_vec.push(0x80); let padded_len = (bytes_vec.len() + 8).div_ceil(64) * 64; bytes_vec.resize(padded_len, 0); // 末尾に入力文字列のビット数を格納する for i in 0..8 { let digit = (total_bits >> (8 * i)) & 0xff; // 下から i バイト目の値 bytes_vec[padded_len - i - 1] = digit as u8; } // 64 バイトずつのブロックに分割する let mut blocks = Vec::with_capacity(padded_len / 64); for i in (0..padded_len).step_by(64) { let mut block = [0; 64]; for j in 0..64 { block[j] = bytes_vec[i + j]; } blocks.push(block); } blocks }
512 bit の入力に対して
を作る関数
ここは少々定義が複雑なので、しっかり確認していきましょう。
まずいきなり定義を書いてみます。最初に 512 bit の数を 32 bit ずつ計16個の数に分割し、先頭から順に とします。さらに
に対しては
さて
さて、道具がそろったので実装してみます。まずは
# Python U32MAX: int = 0xffffffff # 二進法で32桁の1が並ぶ数 def rotate_right(x: int, n: int) -> int: return (x >> n) | (x << (32 - n)) & U32MAX # はみ出た部分は捨てる def sigma0(x: int) -> int: return rotate_right(x, 7) ^ rotate_right(x, 18) ^ (x >> 3) def sigma1(x: int) -> int: return rotate_right(x, 17) ^ rotate_right(x, 19) ^ (x >> 10)
// Rust fn sigma0(x: u32) -> u32 { x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) } fn sigma1(x: u32) -> u32 { x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) }
これらを使って を求める関数はこうなります。
# Python U32MAX: int = 0xffffffff # ~~ 中略 ~~ def create_ws(block_64bytes: [int]) -> [int]: def _4bytes_to_32bit(nums: [int]) -> int: n0, n1, n2, n3 = nums return (n0 << 24) | (n1 << 16) | (n2 << 8) | n3 ws = [0] * 64 for t in range(16): sub = block_64bytes[4 * t: 4 * (t + 1)] ws[t] = _4bytes_to_32bit(sub) for t in range(16, 64): ws[t] = (sigma1(ws[t - 2]) + ws[t - 7] + sigma0(ws[t - 15]) + ws[t - 16]) & U32MAX return ws
// Rust // ~~ 前略 ~~ fn create_ws(block: &[u8; 64]) -> [u32; 64] { let mut ws = [0; 64]; for t in 0..16 { ws[t] = u32::from_be_bytes([ block[t * 4], block[t * 4 + 1], block[t * 4 + 2], block[t * 4 + 3], ]); } for t in 16..64 { ws[t] = sigma1(ws[t - 2]) .wrapping_add(ws[t - 7]) .wrapping_add(sigma0(ws[t - 15])) .wrapping_add(ws[t - 16]); } ws }
に対して変換規則を作る関数
ここが一番の難関です。頑張りましょう。
32 bit の整数たち が与えられた状態で、256 bit 整数を 256 bit 整数にする関数を作ります。実際には処理の流れは
- 256 bit の数を 32 bit ずつ 8 つの数に分解する
を使って所定の処理をして、新たな 32 bit 8 つの数を作る
- これを並べた 256 bit の数を作る
となるので、実質 32 bit × 8 → 32bit × 8 の関数を作ることになります。
そこで、引数に 8 つの 32 bit 整数 が与えられたとして、戻り値
を求める手順を説明します。
最初に に対してとある操作を64回行います。操作を行う前のものを
として、
まず一時的に用いる変数 を以下のように定義します。
まず はいずれも 32 bit 整数を引数に取る関数です。次の式で定義されます。
はいずれも 32 bit 整数3つを引数にとる関数です。以下のように定義されます。
ちなみに は choose または choice の略、
は majority の略だそうです。
の方は「
の各 bit について、0 なら
の bit を、1なら
の bit を採用する」という振る舞いから、
の方は「各 bit について、最も多いものを採用する」という振る舞いから来ているらしいです*1。
最後に ですが、これは定数です。どんな定数かというと、素数を小さい方から
としたときの
の小数点以下 32 bitだそうです。なんじゃそりゃ。定数なのでハードコードしてもいいのですが、打ち間違いの確認が面倒くさいので最初に計算させることにします。Rust に関しては const で書いておけばコンパイル時に評価してくれるので、実行時の計算時間は気にしないで済みます。
さて、これで が定義できました。以上を用いて
は次のように定義されます。
以上で から
を得る方法の説明が終わりました。最後に
非常に長い説明でした。お疲れさまでした。以上で計算方法が分かったので、パーツを一つひとつ実装していきましょう。まず簡単な から。
# Python def upper_sigma0(x: int) -> int: return rotate_right(x, 2) ^ rotate_right(x, 13) ^ rotate_right(x, 22) def upper_sigma1(x: int) -> int: return rotate_right(x, 6) ^ rotate_right(x, 11) ^ rotate_right(x, 25)
// Rust fn upper_sigma0(x: u32) -> u32 { x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) } fn upper_sigma1(x: u32) -> u32 { x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) }
同様に と
も定義していきます。
# Python U32MAX: int = 0xffffffff # ~ 中略 ~ def choose(x: int, y: int, z: int) -> int: return (x & y) ^ ((x ^ U32MAX) & z) def majority(x: int, y: int, z: int) -> int: return (x & y) ^ (y & z) ^ (z & x)
// Rust fn choose(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (!x & z) } fn majority(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (y & z) ^ (z & x) }
次に定数 を定義します。立方根の小数点以下の数値を求めるために、ニュートン法で
乗根を求める関数を用意しておきます。
# Python def nth_root(x: int, n: int) -> int: """ニュートン法でn乗根を求める関数""" out = x memory = 0 while out != memory: memory = out out = memory - (memory ** n - x) // (n * memory ** (n - 1)) return out - 1 def is_prime(n: int) -> bool: if n <= 1: return False return all(n % d != 0 for d in range(2, n)) U32MAX = 0xffffffff PRIMES = tuple(n for n in range(2, 312) if is_prime(n)) CBRT_PRIMES = tuple( nth_root(p << 32 * 3, 3) & U32MAX for p in PRIMES ) # K_t
// Rust const PRIMES: [u128; 64] = { let mut primes = [0; 64]; let mut idx = 0; let mut p = 0; while idx < 64 { if is_prime(p) { primes[idx] = p; idx += 1; } p += 1; } primes }; const CBRT_PRIMES: [u32; 64] = { let mut cbrt_primes = [0; 64]; let mut idx = 0; while idx < 64 { cbrt_primes[idx] = nth_root(PRIMES[idx] << 32 * 3, 3) as u32; idx += 1; } cbrt_primes }; // K_t const fn is_prime(n: u128) -> bool { if n <= 1 { return false; } let mut d = 2; while d < n { if n % d == 0 { return false; } d += 1; } true } /// ニュートン法でn乗根を求める関数 const fn nth_root(x: u128, n: u32) -> u128 { if x == 0 { return 0; } let mut out = (x >> ((127 - x.leading_zeros()) * (n - 1) / n)) + 1; let mut memory = 0; while memory != out { memory = out; out = memory - (memory.pow(n) - x) / ((n as u128) * memory.pow(n - 1)); } out - 1 }
準備が整ったので、変換規則を実装していきます。 をセットすると変換規則が生えてほしいので、Python では class を、Rust では struct を使います。インスタンス生成時に
を設定し、メソッド update でハッシュを変換できるようにします。
# Python class HashUpdater: def __init__(self, ws: [int]): self._ws = ws def update(self, hash_: [int]) -> [int]: tmp = hash_ for step in range(64): tmp = self._update_step(tmp, step) return [(x + y) & U32MAX for x, y in zip(hash_, tmp)] def _update_step(self, hash_: [int], step: int) -> [int]: t1 = self._calc_t1(hash_, step) t2 = self._calc_t2(hash_) updated = [0] * 8 for n in range(8): match n: case 0: updated[n] = (t1 + t2) & U32MAX case 4: updated[n] = (hash_[n - 1] + t1) & U32MAX case _: updated[n] = hash_[n - 1] return updated def _calc_t1(self, hash_: [int], step: int) -> int: t1 = ( hash_[7] + upper_sigma1(hash_[4]) + choose(hash_[4], hash_[5], hash_[6]) + self._ws[step] + CBRT_PRIMES[step] ) return t1 & U32MAX @staticmethod def _calc_t2(hash_: [int]) -> int: t2 = upper_sigma0(hash_[0]) + majority(hash_[0], hash_[1], hash_[2]) return t2 & U32MAX
// Rust pub struct HashUpdater { ws: [u32; 64], } impl HashUpdater { pub fn new(ws: [u32; 64]) -> Self { Self { ws } } pub fn update(&self, hash: [u32; 8]) -> [u32; 8] { let mut tmp = hash; for step in 0..64 { tmp = self.update_step(tmp, step); } let mut updated = [0; 8]; for n in 0..8 { updated[n] = hash[n].wrapping_add(tmp[n]); } updated } fn update_step(&self, hash: [u32; 8], step: usize) -> [u32; 8] { let t1 = self.calc_t1(&hash, step); let t2 = Self::calc_t2(&hash); let mut updated = [0; 8]; for n in 0..8 { updated[n] = match n { 0 => t1.wrapping_add(t2), 4 => hash[3].wrapping_add(t1), _ => hash[n - 1], }; } updated } fn calc_t1(&self, hash: &[u32; 8], step: usize) -> u32 { hash[7] .wrapping_add(upper_sigma1(hash[4])) .wrapping_add(choose(hash[4], hash[5], hash[6])) .wrapping_add(self.ws[step]) .wrapping_add(CBRT_PRIMES[step]) } fn calc_t2(hash: &[u32; 8]) -> u32 { upper_sigma0(hash[0]).wrapping_add(majority(hash[0], hash[1], hash[2])) } }
変換規則を初期値に順に適用する関数
ここまででパーツはすべてそろいました。あとは初期値を準備して変換していくだけです。256 bit の初期値を32bit ずつに分解したもの は
の小数点以下32bit です。またお前か。
# Python # ~ 前略 ~ SQRT_PRIMES = tuple( nth_root(p << 32 * 2, 2) & U32MAX for p in PRIMES[:8] ) # H_t # ~ 中略 ~ def calc_hash_sha256(msg: str) -> str: padded = pad(msg) hash_ = SQRT_PRIMES for block in padded: ws = create_ws(block) updater = HashUpdater(ws) hash_ = updater.update(hash_) return "".join(f"{x:08x}" for x in hash_)
// Rust // ~ 前略 ~ const SQRT_PRIMES: [u32; 8] = { let mut sqrt_primes = [0; 8]; let mut idx = 0; while idx < 8 { sqrt_primes[idx] = nth_root(PRIMES[idx] << 32 * 2, 2) as u32; idx += 1; } sqrt_primes }; // H_t // ~ 中略 ~ fn calc_hash_sha256(msg: &str) -> String { let padded = pad(msg); let mut hash = SQRT_PRIMES; for block in padded { let ws = create_ws(&block); let updater = HashUpdater::new(ws); hash = updater.update(hash); } hash.into_iter().map(|h| format!("{:0>8x}", h)).collect() }
実装を整理してまとめたもの
以上で実装が完了しました。お疲れさまでした。
最後に、この記事を通して実装したものを以下にまとめました。全体のバランスを見て適宜関数名や定義する場所を変更しています。また Python も Rust も SHA-256 を実装したライブラリがあるので、それと結果を比較する簡易的なテストを実行するようにしてあります。なお Rust は sha2 というクレートを利用しました*2。
# Python def nth_root(x: int, n: int) -> int: """ニュートン法でn乗根を求める関数""" out = x memory = 0 while out != memory: memory = out out = memory - (memory ** n - x) // (n * memory ** (n - 1)) return out - 1 def is_prime(n: int) -> bool: if n <= 1: return False return all(n % d != 0 for d in range(2, n)) U32MAX = 0xffffffff PRIMES = tuple(n for n in range(2, 312) if is_prime(n)) CBRT_PRIMES = tuple( nth_root(p << 32 * 3, 3) & U32MAX for p in PRIMES ) # K_t SQRT_PRIMES = tuple( nth_root(p << 32 * 2, 2) & U32MAX for p in PRIMES[:8] ) # H_t def rotate_right(x: int, n: int) -> int: return (x >> n) | (x << (32 - n)) & U32MAX # はみ出た部分は捨てる def pad(msg: str) -> [[int]]: def _ceil_64(n: int) -> int: """n 以上の最小の64の倍数を返す""" return (n + 63) // 64 * 64 bytes_list: list[int] = list(msg.encode()) # 1 バイトずつの数値のリストを作る total_bits = len(bytes_list) * 8 # 最後に末尾に格納する # 0x80 と 0x00 で埋めて、長さを 64 の倍数にする bytes_list.append(0x80) padded_len = _ceil_64(len(bytes_list) + 8) bytes_list.extend([0] * (padded_len - len(bytes_list))) # 末尾に入力文字列のビット数を格納する for i in range(8): digit = (total_bits >> (8 * i)) & 0xff # 下から i バイト目の値 bytes_list[-1 - i] = digit # 64 バイトずつのブロックに分割する blocks = [ bytes_list[idx:idx + 64] for idx in range(0, len(bytes_list), 64) ] return blocks def create_ws(block_64bytes: [int]) -> [int]: def _sigma0(x: int) -> int: return rotate_right(x, 7) ^ rotate_right(x, 18) ^ (x >> 3) def _sigma1(x: int) -> int: return rotate_right(x, 17) ^ rotate_right(x, 19) ^ (x >> 10) def _4bytes_to_32bit(nums: [int]) -> int: n0, n1, n2, n3 = nums return (n0 << 24) | (n1 << 16) | (n2 << 8) | n3 ws = [0] * 64 for t in range(16): sub = block_64bytes[4 * t: 4 * (t + 1)] ws[t] = _4bytes_to_32bit(sub) for t in range(16, 64): ws[t] = (_sigma1(ws[t - 2]) + ws[t - 7] + _sigma0(ws[t - 15]) + ws[t - 16]) & U32MAX return ws class HashUpdater: def __init__(self, ws: [int]): self._ws = ws def update(self, hash_: [int]) -> [int]: tmp = hash_ for step in range(64): tmp = self._update_step(tmp, step) return [(x + y) & U32MAX for x, y in zip(hash_, tmp)] def _update_step(self, hash_: [int], step: int) -> [int]: t1 = self._calc_t1(hash_, step) t2 = self._calc_t2(hash_) updated = [0] * 8 for n in range(8): match n: case 0: updated[n] = (t1 + t2) & U32MAX case 4: updated[n] = (hash_[n - 1] + t1) & U32MAX case _: updated[n] = hash_[n - 1] return updated def _calc_t1(self, hash_: [int], step: int) -> int: t1 = ( hash_[7] + self._sigma1(hash_[4]) + self._choose(hash_[4], hash_[5], hash_[6]) + self._ws[step] + CBRT_PRIMES[step] ) return t1 & U32MAX @classmethod def _calc_t2(cls, hash_: [int]) -> int: t2 = cls._sigma0(hash_[0]) + cls._majority(hash_[0], hash_[1], hash_[2]) return t2 & U32MAX @staticmethod def _sigma0(x: int) -> int: return rotate_right(x, 2) ^ rotate_right(x, 13) ^ rotate_right(x, 22) @staticmethod def _sigma1(x: int) -> int: return rotate_right(x, 6) ^ rotate_right(x, 11) ^ rotate_right(x, 25) @staticmethod def _choose(x: int, y: int, z: int) -> int: return (x & y) ^ ((x ^ U32MAX) & z) @staticmethod def _majority(x: int, y: int, z: int) -> int: return (x & y) ^ (y & z) ^ (z & x) def calc_hash_sha256(msg: str) -> str: padded = pad(msg) hash_ = SQRT_PRIMES for block in padded: ws = create_ws(block) updater = HashUpdater(ws) hash_ = updater.update(hash_) return "".join(f"{x:08x}" for x in hash_) def main(): import hashlib msg = "Hello, world!" my_hash = calc_hash_sha256(msg) print(f"My hash: {my_hash}") lib_hash = hashlib.sha256(msg.encode()).hexdigest() print(f"Library hash: {lib_hash}") assert my_hash == lib_hash if __name__ == '__main__': main()
// Rust const PRIMES: [u128; 64] = { let mut primes = [0; 64]; let mut idx = 0; let mut p = 0; while idx < 64 { if is_prime(p) { primes[idx] = p; idx += 1; } p += 1; } primes }; const CBRT_PRIMES: [u32; 64] = { let mut cbrt_primes = [0; 64]; let mut idx = 0; while idx < 64 { cbrt_primes[idx] = nth_root(PRIMES[idx] << 32 * 3, 3) as u32; idx += 1; } cbrt_primes }; // K_t const SQRT_PRIMES: [u32; 8] = { let mut sqrt_primes = [0; 8]; let mut idx = 0; while idx < 8 { sqrt_primes[idx] = nth_root(PRIMES[idx] << 32 * 2, 2) as u32; idx += 1; } sqrt_primes }; // H_t const fn is_prime(n: u128) -> bool { if n <= 1 { return false; } let mut d = 2; while d < n { if n % d == 0 { return false; } d += 1; } true } /// ニュートン法でn乗根を求める関数 const fn nth_root(x: u128, n: u32) -> u128 { if x == 0 { return 0; } let mut out = (x >> ((127 - x.leading_zeros()) * (n - 1) / n)) + 1; let mut memory = 0; while memory != out { memory = out; out = memory - (memory.pow(n) - x) / ((n as u128) * memory.pow(n - 1)); } out - 1 } fn pad(msg: &str) -> Vec<[u8; 64]> { let mut bytes_vec = msg.as_bytes().to_vec(); // 1 バイトずつの数値のリストを作る let total_bits = bytes_vec.len() * 8; // 最後に末尾に格納する // 0x80 と 0x00 で埋めて、長さを 64 の倍数にする bytes_vec.push(0x80); let padded_len = (bytes_vec.len() + 8).div_ceil(64) * 64; bytes_vec.resize(padded_len, 0); // 末尾に入力文字列のビット数を格納する for i in 0..8 { let digit = (total_bits >> (8 * i)) & 0xff; // 下から i バイト目の値 bytes_vec[padded_len - i - 1] = digit as u8; } // 64 バイトずつのブロックに分割する let mut blocks = Vec::with_capacity(padded_len / 64); for i in (0..padded_len).step_by(64) { let mut block = [0; 64]; for j in 0..64 { block[j] = bytes_vec[i + j]; } blocks.push(block); } blocks } fn create_ws(block: &[u8; 64]) -> [u32; 64] { fn sigma0(x: u32) -> u32 { x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) } fn sigma1(x: u32) -> u32 { x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) } let mut ws = [0; 64]; for t in 0..16 { ws[t] = u32::from_be_bytes([ block[t * 4], block[t * 4 + 1], block[t * 4 + 2], block[t * 4 + 3], ]); } for t in 16..64 { ws[t] = sigma1(ws[t - 2]) .wrapping_add(ws[t - 7]) .wrapping_add(sigma0(ws[t - 15])) .wrapping_add(ws[t - 16]); } ws } pub struct HashUpdater { ws: [u32; 64], } impl HashUpdater { pub fn new(ws: [u32; 64]) -> Self { Self { ws } } pub fn update(&self, hash: [u32; 8]) -> [u32; 8] { let mut tmp = hash; for step in 0..64 { tmp = self.update_step(tmp, step); } let mut updated = [0; 8]; for n in 0..8 { updated[n] = hash[n].wrapping_add(tmp[n]); } updated } fn update_step(&self, hash: [u32; 8], step: usize) -> [u32; 8] { let t1 = self.calc_t1(&hash, step); let t2 = Self::calc_t2(&hash); let mut updated = [0; 8]; for n in 0..8 { updated[n] = match n { 0 => t1.wrapping_add(t2), 4 => hash[3].wrapping_add(t1), _ => hash[n - 1], }; } updated } fn calc_t1(&self, hash: &[u32; 8], step: usize) -> u32 { hash[7] .wrapping_add(Self::sigma1(hash[4])) .wrapping_add(Self::choose(hash[4], hash[5], hash[6])) .wrapping_add(self.ws[step]) .wrapping_add(CBRT_PRIMES[step]) } fn calc_t2(hash: &[u32; 8]) -> u32 { Self::sigma0(hash[0]).wrapping_add(Self::majority(hash[0], hash[1], hash[2])) } fn sigma0(x: u32) -> u32 { x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) } fn sigma1(x: u32) -> u32 { x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) } fn choose(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (!x & z) } fn majority(x: u32, y: u32, z: u32) -> u32 { (x & y) ^ (y & z) ^ (z & x) } } fn calc_hash_sha256(msg: &str) -> String { let padded = pad(msg); let mut hash = SQRT_PRIMES; for block in padded { let ws = create_ws(&block); let updater = HashUpdater::new(ws); hash = updater.update(hash); } hash.into_iter().map(|h| format!("{:0>8x}", h)).collect() } fn main() { use sha2::{Digest, Sha256}; // cargo add sha2 で追加 let msg = "Hello, world!"; let my_hash = calc_hash_sha256(msg); println!("My hash: {my_hash}"); let lib_hash: String = Sha256::digest(msg.as_bytes()) .into_iter() .map(|b| format!("{:0>2x}", b)) .collect(); println!("Library hash: {lib_hash}"); assert_eq!(my_hash, lib_hash); }