チームlazy_oracleで参加しました。(一人チーム)
ブロック暗号問があったら解きたいなと思っていたところ、cerberusが該当問だったので解いてみました。
---追記(2021-12-13): 作問者のkurenaifさんからコメントを頂きました。
---追記終わりwriteupありがとうございます!実はc1,c2,c3を3つ繰り返した後、c1を更にもう一度付け足すと少し楽にできます!
— kurenaif@VTuber (@fwarashi) 2021年12月12日
まず、チャレンジャーには以下の情報が与えれます。
- ソースコード
- 問題サーバーのurlとポート番号
ソースコードを読むと、以下のことがわかりました。
詳細を見ると、
c = EncAES-128-PCBC(flag, key, iv)
のように鍵長128bitのAESのPCBCモードでflagが暗号化されているようです。
また、クライアントから受け付けたivと暗号文を使用してサーバーは復号処理を行い、最終ブロックのパディングチェックの結果を返します。
サーバーから与えられるパラメータは
- 暗号文c
- iv
また、サーバーのソースコードを見てみると入力に以下の条件があることがわかります。
- クライアントから受けとった暗号文がもとの暗号文と前方一致しているかチェック
- 復号した平文に対してパッディングチェックをし、成否を出力
つまり、今回自由に操作できるのは以下の2つであることがわかります。
- iv
- もとの暗号文c以降に任意のバイナリを連結すること
条件1,2の該当部分のソースコード
while True:
c = base64.b64decode(input("spell:"))
iv = c[:16]
c = c[16:]
if not c.startswith(ref_c):
print("Grrrrrrr!!!!")
continue
m = decrypt(iv, c)
try:
unpad(m, block_size)
except:
print("little different :(")
continue
print("Great :)")
そして、条件2からPadding Oracle攻撃が狙えそうです。
ただし、PCBCモードは下の図のように2ブロック目以降は暗号文ブロックと平文ブロックが後ろのブロックの復号に影響を与えます。
なので、Padding Oracleをするために任意の入力を作るのは工夫が必要です。

ここで、1ブロック目に着目します。 1ブロック目であればIVを制御するだけでPadding Oracle攻撃ができそうです。 ただし、条件1の制限があるため、オリジナルの暗号文cのうち1ブロック目だけを素朴に切り出してサーバーに入力することはできません。 なので、2ブロック目以降を打ち消す方法を考えます。
ここで、排他的論理和の性質を思い出します。 排他的論理和の真理値表は次のとおりです。 排他的論理和はA==Bのときに0を出力します。 つまり、同じ値を持つもの同士を打ち消す性質があります。
| A | B | A xor B |
|---|---|---|
| 0 | 0 | 0 |
| 0 | 1 | 1 |
| 1 | 0 | 1 |
| 1 | 1 | 0 |
つまり、以下の図のように暗号文cに2ブロック以降の暗号文C3, C2, C1を連結すれば2ブロック目以降を打ち消して1ブロック目の出力P0の値のみを最終ブロックにもっていけそうです。

しかし、このままだとC0とC1の影響が残っているので、ivを操作して打ち消します。

ここまでくれば、ivを操作してPadding Oracle攻撃で1ブロック目のDec(C0)(図赤丸部分)を求めて
P0 = Dec(C0) xor iv
とすることで、平文P0を復元できます。

また、2ブロック目は復元したDec(C0)を使い同じ方針で3ブロック目以降を打ち消した後にPadding Oracle攻撃でDec(C1)を復元します。
その後、P1 = Dec(C1) xor P0 xor C0
とすることで平文P1を復元できます。
3ブロック目以降も同様です。

Solverは以下のとおりです。
SECCON{v._.^v-_-v^._.^_S0und_oF_0rpHeUs_Aha~~}
import socket
import base64
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from Crypto.Util.strxor import strxor
block_size = 16
#----------
# Netcat class: ref https://scrapbox.io/progfay-pub/netcat.py
class Netcat:
""" Python 'netcat like' module """
def __init__(self, ip, port):
self.buff = ""
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.connect((ip, port))
def read(self, length = 1024):
""" Read 1024 bytes off the socket """
return self.socket.recv(length)
def write(self, data):
self.socket.sendall(((str(data)+'\n')).encode('utf-8'))
def close(self):
self.socket.close()
#----------
def unpack_ct(b64string):
a = base64.b64decode(b64string.encode('utf-8'))
iv = a[:16]
c = a[16:]
return iv, c
def pack_ct(iv, c):
return base64.b64encode(iv+c).decode('utf-8')
def recover_plaintext_block(iv, c, mask0=b"\x00"*block_size, mask1=b"\x00"*block_size):
attack_query = c
attack_query_iv = iv
iv0 = b'\x00'*block_size
iv0 = [iv0[i : i + 1] for i in range(0, len(iv0))]
state = [None]*block_size
for b in range(block_size):
for j in range(b):
# ターゲットより下位バイトをPadding
iv0[block_size-1-j] = strxor(state[15-j], bytes([b+1]))
for i in range(256):
iv0[block_size-1-b] = bytes([i])
iv0_t = strxor(b"".join(iv0), attack_query_iv)
nc.write(pack_ct(iv0_t, attack_query))
lines = nc.read().decode('utf-8').split('\n')
if -1!=lines[0].find("Great"):
print('found!')
state[15-b] = strxor(bytes([i]), bytes([b+1]))
break
if 255==i:
print('cannot find')
exit(-1)
plaintext=strxor(b"".join(state), mask0)
plaintext=strxor(plaintext, mask1)
return plaintext, b"".join(state)
nc = Netcat('cerberus.quals.seccon.jp', 8080)
lines = nc.read().decode('utf-8').split('\n')
for line in lines:
print(line)
iv, c = unpack_ct(lines[-2])
num_of_block = len(c)//block_size
print("num of block: %d" %num_of_block)
print("recover p0...")
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_ciphertext = c+c[48:64]+c[32:48]+c[16:32]
p0, state_c0 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=iv)
print("p0 is:")
print(p0)
print("---")
print("recover p1...")
attack_query_ciphertext = c+c[48:64]+c[32:48]
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, state_c0)
p1, state_c1 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[0:16], mask1=p0)
print("p1 is:")
print(p1)
print("---")
print("recover p2...")
attack_query_ciphertext = c+c[48:64]
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, c[48:64])
attack_query_iv = strxor(attack_query_iv, state_c0)
attack_query_iv = strxor(attack_query_iv, state_c1)
p2, state_c2 = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[16:32], mask1=p1)
print("p2 is:")
print(p2)
print("---")
print("recover p3...")
attack_query_ciphertext = c
attack_query_iv = strxor(c[:16], c[16:32])
attack_query_iv = strxor(attack_query_iv, c[32:48])
attack_query_iv = strxor(attack_query_iv, state_c0)
attack_query_iv = strxor(attack_query_iv, state_c1)
attack_query_iv = strxor(attack_query_iv, state_c2)
p3, _ = recover_plaintext_block(attack_query_iv, attack_query_ciphertext, mask0=c[32:48], mask1=p2)
print("p3 is:")
print(p3)
print("---")
p_all = p0+p1+p2+p3
flag = unpad(p_all, block_size)[16:].decode('utf-8')
print(flag)