37日目: case whenの対応
今回はcase whenに取り組みます。
Rubyには3種類のcaseがありますが、今回は1・2つめのcase ... when ...をやることにし、3つめのcase ... in ...についてはパターンマッチングのときに取り組むことにします。
# その1 case a when :a when :b else end # その2 case when a when b else end # その3 case a in [0] in [0, 1] else end
書き換え前後のノードの変更点
書き換えの前後で大きく3つ変わります。
NODE_CASEとNODE_CASE2が統合されるNODE_WHENがリスト構造ではなくなるelseにあたるノードが直接caseのノードに紐づく
それぞれ詳しく説明します。
まず1つめのNODE_CASEとNODE_CASE2が統合されるについてです。
書き換え前はcase a ... whenにはNODE_CASE、case when ...にはNODE_CASE2と別のノードが割り当てられていました。
書き換え後はどちらもCaseNodeで表現することになり、それらの差はpredicateというフィールドの値の有無で判別するようになります。
つぎは2つめのNODE_WHENがリスト構造ではなくなるについてです。
ノードの書き換え前後を比べるとわかるのですが、書き換え前はNODE_WHENの要素として次のNODE_WHENが保存されていました。
つまりNODE_WHENそのものがリスト構造になっているというわけです。
一方で書き換え後はCaseNodeのconditionsというフィールドが配列になっていて、そこに複数のWhenNodeが保存されるようになります。
# Before # # @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))* # +- nd_body: # | @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7)) # | +- nd_next: # | @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7)) # | +- nd_next: # | (null node) # After # # @ CaseNode (location: (11,0)-(14,3)) # +-- conditions: (length: 2) # | +-- @ WhenNode (location: (12,0)-(12,7)) # | +-- @ WhenNode (location: (13,0)-(13,7)) case a when :a when :b end
最後に3つめのelseにあたるノードが直接caseのノードに紐づくです。
書き換え前はNODE_WHENの最後の要素としてelseのbodyにあたるノードが保存されている構造になっていました。
書き換え後はCaseNodeのelse_clauseというフィールドで管理することになります。
# Before # # @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))* # +- nd_body: # | @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7)) # | +- nd_next: # | @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7)) # | +- nd_next: # | @ NODE_SYM (id: 7, line: 15, location: (15,2)-(15,7))* # After # # @ CaseNode (location: (11,0)-(14,3)) # +-- conditions: (length: 2) # | +-- @ WhenNode (location: (12,0)-(12,7)) # | +-- @ WhenNode (location: (13,0)-(13,7)) # +-- else_clause: # | @ ElseNode (location: (14,0)-(16,3)) case a when :a when :b else :else end
parse.yの変更
when ... else ...の部分の生成規則は以下のようになっており、後ろのwhenやelseから順番にノードを組み立てるようになっています。
primary | k_case expr_value terms? // case a case_body // when :a ... else ... k_end // end case_body : k_when case_args then // when :a compstmt(stmts) // expr cases // when :b ... else ... cases : opt_else // else ... | case_body // when :b ... else ...
今回は生成規則に手を入れないようにしたいので、常に配列の先頭にノードを追加するようにします。
@@ -5911,13 +5918,25 @@ case_body : k_when case_args then compstmt(stmts) cases { - $$ = NEW_WHEN($2, $4, $5, &@$, &@1, &@3); + $$ = NEW_RB_WHEN($2, $4, &@$, &@1, &@3); fixpos($$, $2); + if ($5) { + $$ = node_array_prepend(p, $5, $$, &@$); + } + else { + $$ = NEW_RB_ARRAY($$, &@$); + } /*% ripper: when!($:2, $:4, $:5) %*/ } ; cases : opt_else + { + if ($1) { + $$ = NEW_RB_ARRAY($1, $$); + } + /*% ripper: when!($:2, $:4, $:5) %*/ + } | case_body ;
またこのままだとWhenNodeもElseNodeも1つの配列に入ったままになってしまうので、CaseNodeを作成するときに配列の最後のノードの種類をチェックして、それがElseNodeの場合には配列から取り出してelse_clauseにセットするようにします。
static rb_case_node_t * rb_new_node_case_new(struct parser_params *p, rb_node_t *nd_head, rb_array_node_t *nd_conds, const YYLTYPE *loc, const YYLTYPE *case_keyword_loc, const YYLTYPE *end_keyword_loc) { rb_case_node_t *n = RB_NEW_NODE_NEWNODE((enum rb_node_type)RB_CASE_NODE, rb_case_node_t, loc); rb_node_list2_t *list = &nd_conds->elements; rb_node_t *nd_last = rb_node_list_last(list); rb_else_node_t *nd_else = NULL; if (nd_last && nd_type_p(nd_last, RB_ELSE_NODE)) { nd_else = rb_node_list_pop(list); } n->predicate = nd_head; rb_node_list_init_with_src(&n->conditions, list); n->else_clause = nd_else; n->case_keyword_loc = *case_keyword_loc; n->end_keyword_loc = *end_keyword_loc; return n; }
バイトコードを眺める(case ... when ...の場合)
以下のコードを例にどのようなバイトコードが生成されるか確認しておきましょう。
case v when a :a_body when b :b_body else :else_body end
このコードをcase ... whenを使わずに書くと以下のようなコードになります。
tmp = v if (a === v) :a_body elsif (b === v) :b_body else :else_body end
生成されるバイトコードは大きく2つの部分からなります。
最初にマッチするwhen ...を決定するバイトコードが並びます。
ここではwhen ...を上からに試し、マッチした時点で後半のバイトコードにジャンプします。
後半はwhen ... bodyのbodyに当たるバイトコードが並んでいて、それぞれbodyの部分を実行するとleave(もしくはjump)でcase ... when ...全体から抜けるようになっています。
# `a === v`や`b === v`にあたる部分 # # 0000 putself ( 11)[Li] # 0001 opt_send_without_block <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0003 putself ( 12) # 0004 opt_send_without_block <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0006 topn 1 # 0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0010 branchif 25 # 0012 putself ( 14) # 0013 opt_send_without_block <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0015 topn 1 # 0017 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0019 branchif 29 # `else`のbody # # 0021 pop ( 17) # 0022 putobject :else_body[Li] # 0024 leave # `when a`のbody # # 0025 pop ( 12) # 0026 putobject :a_body ( 13)[Li] # 0028 leave ( 17) # `when b`のbody # # 0029 pop ( 14) # 0030 putobject :b_body ( 15)[Li] # 0032 leave ( 17)
バイトコードを眺める(case when ...の場合)
次にcaseのあとに式を取らない場合のバイトコードをみていきます。
case when a :a_body when b :b_body else :else_body end
case whenを使わずに書くと以下のようなコードになります。
if a :a_body elsif b :b_body else :else_body end
チェックの方法が変わるだけで、基本的には先ほどのケースと同じようなバイトコードが生成されます。
# `if a`や`if b`にあたる部分 # # 0000 putself ( 45)[Li] # 0001 opt_send_without_block <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0003 branchif 13 # 0005 putself ( 47) # 0006 opt_send_without_block <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0008 branchif 16 # `else`のbody # # 0010 putobject :else_body ( 50)[Li] # 0012 leave # `when a`のbody # # 0013 putobject :a_body ( 46)[Li] # 0015 leave ( 50) # `when b`のbody # # 0016 putobject :b_body ( 48)[Li] # 0018 leave ( 50)
compile.cを変更する
もともとcompile.cではNODE_CASEをcompile_case関数で、NODE_CASE2をcompile_case2関数で処理していました。
引き続きそれら2つの関数を使用するようにします。
case RB_CASE_NODE: { if (RB_NODE_CASE(node)->predicate) { CHECK(compile_case(iseq, ret, node, popped)); } else { CHECK(compile_case2(iseq, ret, node, popped)); } break; }
compile_case関数
case ... when ...のケースを処理するcompile_case関数からみていきましょう。
生成されるバイトコードは条件をチェックして必要に応じてjumpする部分と、条件にマッチした場合に実行されるbodyの部分に大きく分かれているのでした。
そこでcompile_case関数ではbody_seqとcond_seqというアンカーを用意して、WhenNodeをコンパイルするときに条件の部分とbodyの部分を異なるアンカーに追記していき、最後に1つのバイトコード列へmergeします。
static int compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped) { DECL_ANCHOR(head); DECL_ANCHOR(body_seq); DECL_ANCHOR(cond_seq); INIT_ANCHOR(head); INIT_ANCHOR(body_seq); INIT_ANCHOR(cond_seq); ADD_SEQ(ret, cond_seq); ADD_SEQ(ret, body_seq); return COMPILE_OK; }
compile_case関数のうち、おもにバイトコードの生成に関するロジックを抜き出すと以下のような構造になっています。
- まず
predicate(case v)をコンパイルする - 次に
conditions(when a ... when b ...)を1つずつコンパイルする。このときsplat(when *a)かどうかで条件の部分に関して、生成するバイトコードが変化する - さいごに
else_clauseをコンパイルする
static int compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped) { const rb_node_list2_t *vals; const rb_node_list2_t *list = &RB_NODE_CASE(node)->conditions; const rb_else_node_t *nd_else = RB_NODE_CASE(node)->else_clause; CHECK(COMPILE(head, "case base", RB_NODE_CASE(node)->predicate)); /* when */ for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) { const NODE *n = list->nodes[i]; CHECK(COMPILE_(body_seq, "when body", RB_NODE_WHEN(n)->statements, popped)); vals = &RB_NODE_WHEN(n)->conditions; if (!RB_NODE_LIST_EMPTY_P(vals)) { /* when a, b, c */ for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) { const NODE *val = vals->nodes[j]; if (nd_type_p(val, RB_SPLAT_NODE)) { only_special_literals = 0; CHECK(when_splat_vals(iseq, cond_seq, val, l1, only_special_literals, literals)); } else { only_special_literals = when_vals(iseq, cond_seq, node, val, l1, only_special_literals, literals); if (only_special_literals < 0) return COMPILE_NG; } } } else { EXPECT_NODE_NONULL("NODE_CASE", n, NODE_LIST, COMPILE_NG); } } /* else */ if (nd_else) { ADD_LABEL(cond_seq, elselabel); ADD_INSN(cond_seq, nd_else, pop); add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(nd_else), nd_node_id(nd_else), branch_id, "else", branches); CHECK(COMPILE_(cond_seq, "else", nd_else, popped)); ADD_INSNL(cond_seq, nd_else, jump, endlabel); } else { debugs("== else (implicit)\n"); ADD_LABEL(cond_seq, elselabel); ADD_INSN(cond_seq, orig_node, pop); add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(orig_node), nd_node_id(orig_node), branch_id, "else", branches); if (!popped) { ADD_INSN(cond_seq, orig_node, putnil); } ADD_INSNL(cond_seq, orig_node, jump, endlabel); } ADD_SEQ(ret, cond_seq); ADD_SEQ(ret, body_seq); ADD_LABEL(ret, endlabel); return COMPILE_OK; }
想定したバイトコードが出力されることを確認します。
# == disasm: #<ISeq:<main>@../../test.rb:1 (1,0)-(8,3)> # 0000 putself ( 1)[Li] # 0001 opt_send_without_block <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0003 putself ( 2) # 0004 opt_send_without_block <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0006 topn 1 # 0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0010 branchif 25 # 0012 putself ( 4) # 0013 opt_send_without_block <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0015 topn 1 # 0017 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0019 branchif 29 # 0021 pop ( 7) # 0022 putobject :else_body[Li] # 0024 leave # 0025 pop ( 2) # 0026 putobject :a_body ( 3)[Li] # 0028 leave ( 7) # 0029 pop ( 4) # 0030 putobject :b_body ( 5)[Li] # 0032 leave ( 7) case v when a :a_body when b :b_body else :else_body end
whenの条件が複数あるとき
case ... when ...にはいくつか細かい注意点があるので、それらをみておきます。
まずはwhen a1, a2のようにwhenのあとに2つ以上の条件がある場合です。
このときは左から順にマッチするかチェックするため、その順番に===を呼び出して条件付きジャンプをするバイトコードが生成されます。
when a1, when a2, when a3の全てでbranchif 43と飛び先の命令が同じになっています。
# `case v` # # 0000 putself ( 2)[Li] # 0001 opt_send_without_block <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE> # `when a1` # # 0003 putself ( 3) # 0004 opt_send_without_block <calldata!mid:a1, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0006 topn 1 # 0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0010 branchif 43 # `when a2` # # 0012 putself # 0013 opt_send_without_block <calldata!mid:a2, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0015 topn 1 # 0017 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0019 branchif 43 # `when a3` # # 0021 putself # 0022 opt_send_without_block <calldata!mid:a3, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0024 topn 1 # 0026 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0028 branchif 43 # `when b` # # 0030 putself ( 5) # 0031 opt_send_without_block <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0033 topn 1 # 0035 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0037 branchif 47 # 以下body # # 0039 pop ( 8) # 0040 putobject :else_body[Li] # 0042 leave # 0043 pop ( 3) # 0044 putobject :a_body ( 4)[Li] # 0046 leave ( 8) # 0047 pop ( 5) # 0048 putobject :b_body ( 6)[Li] # 0050 leave ( 8) case v when a1, a2, a3 :a_body when b :b_body else :else_body end
splatがあるとき
when *bのようにsplatがある場合はマッチのためのバイトコードが変わります。
# `case v` # # 0000 putself ( 1)[Li] # 0001 opt_send_without_block <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE> # `when a` # # 0003 putself ( 2) # 0004 opt_send_without_block <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0006 topn 1 # 0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0010 branchif 26 # `when *b` # # 0012 dup ( 4) # 0013 putself # 0014 opt_send_without_block <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0016 splatarray false # 0018 checkmatch 6 # 0020 branchif 30 # 以下はwhenやelseのbody # # 0022 pop ( 7) # 0023 putobject :else_body[Li] # 0025 leave # 0026 pop ( 2) # 0027 putobject :a_body ( 3)[Li] # 0029 leave ( 7) # 0030 pop ( 4) # 0031 putobject :b_body ( 5)[Li] # 0033 leave ( 7) case v when a :a_body when *b :b_body else :else_body end
実行時のスタックを考えてみましょう。
まずはwhen aの場合からみてみます。
このときはa === vを実行するためtopnでvがスタックトップにくるようにしてから#===メソッドを呼び出します。
# `0006 topn 1`まで v a v # `0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>`まで a === v # の結果 v
次にwhen *bの場合です。
このときはcheckmatch(VALUE target, VALUE pattern)を実行するため、v, bの順番(右がスタックトップ)にスタックに積まれている必要があります。
# `0012 dup`まで v v # `0016 splatarray false`まで b # をdupしたもの v v
最適化されたバイトコード
最後に最適化についてみておきます。
when ...のすべての要素が実行時評価を必要としない場合には、条件をkey、飛び先をvalueとしたhashを用意して、そこから直接飛び先を引くという最適化が入ります。
以下のバイトコードでいうと0004 opt_case_dispatch <cdhash>, 23が追加されています。
# 0000 putself ( 1)[Li] # 0001 opt_send_without_block <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE> # 0003 dup # 0004 opt_case_dispatch <cdhash>, 23 # 0007 putobject :a ( 2) # 0009 topn 1 # 0011 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0013 branchif 27 # 0015 putobject :b ( 4) # 0017 topn 1 # 0019 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE> # 0021 branchif 31 # 0023 pop ( 7) # 0024 putobject :body_else[Li] # 0026 leave # ... case v when :a :body_a when :b :body_b else :body_else end
ここまで踏まえて、簡単なコードで動作確認しておきましょう。
def m(v) a = :a b = :b case v when a :body_a when b :body_b else :body_else end end p m(:a) #=> :body_a p m(:else) #=> :body_else
compile_case2関数
次にcase when ...をコンパイルするcompile_case2関数を新しいノードに対応させます。
といっても大まかな流れは変更前、およびcompile_case関数と変わりません。
case when ...はcaseのpredicateが存在しないため、いきなりwhenをコンパイルすることになります。
whenの条件の部分と、bodyの部分を別々のアンカーに追記していくのはcompile_case関数と変わりません。
その後、elseに当たる部分をコンパイルして終了です。
static int compile_case2(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped) { /* when */ for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) { const NODE *n = list->nodes[i]; // when ... bodyのうち、bodyの部分のコンパイル CHECK(COMPILE_(body_seq, "when", RB_NODE_WHEN(n)->statements, popped)); ADD_INSNL(body_seq, n, jump, endlabel); // when ...のうち、...の部分のコンパイル vals = &RB_NODE_WHEN(n)->conditions; if (RB_NODE_LIST_EMPTY_P(vals)) { EXPECT_NODE_NONULL("NODE_WHEN", n, NODE_LIST, COMPILE_NG); } /* when a, b, c */ for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) { const NODE *val = vals->nodes[j]; if (nd_type_p(val, RB_SPLAT_NODE)) { ADD_INSN(ret, val, putnil); CHECK(COMPILE(ret, "when2/cond splat", RB_NODE_SPLAT(val)->expression)); ADD_INSN1(ret, val, splatarray, Qtrue); ADD_INSN1(ret, val, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_WHEN | VM_CHECKMATCH_ARRAY)); ADD_INSNL(ret, val, branchif, l1); } else { LABEL *lnext; lnext = NEW_LABEL(nd_line(val)); debug_compile("== when2\n", (void)0); CHECK(compile_branch_condition(iseq, ret, val, l1, lnext)); ADD_LABEL(ret, lnext); } } } /* else */ CHECK(COMPILE_(ret, "else", nd_else, popped)); ADD_INSNL(ret, orig_node, jump, endlabel); ADD_SEQ(ret, body_seq); ADD_LABEL(ret, endlabel); return COMPILE_OK; }
動作確認をして終わりにしましょう。
def m(a: false, b1: false, b2: false, c: [false]) case when a :body_a when b1, b2 :body_b when *c :body_c else :body_else end end p m() #=> :body_else p m(a: true) #=> :body_a p m(b1: true) #=> :body_b p m(c: [true]) #=> :body_c
まとめ
今日の成果です。
case v when ....に対応したcase when ....に対応した