以下の内容はhttps://yui-knk.hatenablog.com/entry/2026/02/22/135412より取得しました。


Ruby Parser開発日誌 (24-37) - parse.yが生成するノードを変える ー case when

37日目: case whenの対応

今回はcase whenに取り組みます。 Rubyには3種類のcaseがありますが、今回は1・2つめのcase ... when ...をやることにし、3つめのcase ... in ...についてはパターンマッチングのときに取り組むことにします。

# その1
case a
when :a
when :b
else
end

# その2
case
when a
when b
else
end

# その3
case a
in [0]
in [0, 1]
else
end

書き換え前後のノードの変更点

書き換えの前後で大きく3つ変わります。

  1. NODE_CASENODE_CASE2が統合される
  2. NODE_WHENがリスト構造ではなくなる
  3. elseにあたるノードが直接caseのノードに紐づく

それぞれ詳しく説明します。

まず1つめのNODE_CASENODE_CASE2が統合されるについてです。 書き換え前はcase a ... whenにはNODE_CASEcase when ...にはNODE_CASE2と別のノードが割り当てられていました。 書き換え後はどちらもCaseNodeで表現することになり、それらの差はpredicateというフィールドの値の有無で判別するようになります。

つぎは2つめのNODE_WHENがリスト構造ではなくなるについてです。 ノードの書き換え前後を比べるとわかるのですが、書き換え前はNODE_WHENの要素として次のNODE_WHENが保存されていました。 つまりNODE_WHENそのものがリスト構造になっているというわけです。 一方で書き換え後はCaseNodeconditionsというフィールドが配列になっていて、そこに複数のWhenNodeが保存されるようになります。

# Before
#
# @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))*
# +- nd_body:
# |   @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7))
# |   +- nd_next:
# |       @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7))
# |       +- nd_next:
# |           (null node)

# After
#
# @ CaseNode (location: (11,0)-(14,3))
# +-- conditions: (length: 2)
# |   +-- @ WhenNode (location: (12,0)-(12,7))
# |   +-- @ WhenNode (location: (13,0)-(13,7))

case a
when :a
when :b
end

最後に3つめのelseにあたるノードが直接caseのノードに紐づくです。 書き換え前はNODE_WHENの最後の要素としてelseのbodyにあたるノードが保存されている構造になっていました。 書き換え後はCaseNodeelse_clauseというフィールドで管理することになります。

# Before
#
# @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))*
# +- nd_body:
# |   @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7))
# |   +- nd_next:
# |       @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7))
# |       +- nd_next:
# |           @ NODE_SYM (id: 7, line: 15, location: (15,2)-(15,7))*

# After
#
# @ CaseNode (location: (11,0)-(14,3))
# +-- conditions: (length: 2)
# |   +-- @ WhenNode (location: (12,0)-(12,7))
# |   +-- @ WhenNode (location: (13,0)-(13,7))
# +-- else_clause:
# |   @ ElseNode (location: (14,0)-(16,3))

case a
when :a
when :b
else
  :else
end

parse.yの変更

when ... else ...の部分の生成規則は以下のようになっており、後ろのwhenelseから順番にノードを組み立てるようになっています。

primary     | k_case expr_value terms? // case a
              case_body                // when :a ... else ...
              k_end                    // end

case_body   : k_when case_args then    // when :a
              compstmt(stmts)          //   expr
              cases                    // when :b ... else ...

cases       : opt_else                 // else ...
            | case_body                // when :b ... else ...

今回は生成規則に手を入れないようにしたいので、常に配列の先頭にノードを追加するようにします。

@@ -5911,13 +5918,25 @@ case_body       : k_when case_args then
                   compstmt(stmts)
                   cases
                     {
-                        $$ = NEW_WHEN($2, $4, $5, &@$, &@1, &@3);
+                        $$ = NEW_RB_WHEN($2, $4, &@$, &@1, &@3);
                         fixpos($$, $2);
+                        if ($5) {
+                            $$ = node_array_prepend(p, $5, $$, &@$);
+                        }
+                        else {
+                            $$ = NEW_RB_ARRAY($$, &@$);
+                        }
                     /*% ripper: when!($:2, $:4, $:5) %*/
                     }
                 ;

 cases          : opt_else
+                    {
+                        if ($1) {
+                            $$ = NEW_RB_ARRAY($1, $$);
+                        }
+                    /*% ripper: when!($:2, $:4, $:5) %*/
+                    }
                 | case_body
                 ;

またこのままだとWhenNodeElseNodeも1つの配列に入ったままになってしまうので、CaseNodeを作成するときに配列の最後のノードの種類をチェックして、それがElseNodeの場合には配列から取り出してelse_clauseにセットするようにします。

static rb_case_node_t *
rb_new_node_case_new(struct parser_params *p, rb_node_t *nd_head, rb_array_node_t *nd_conds, const YYLTYPE *loc, const YYLTYPE *case_keyword_loc, const YYLTYPE *end_keyword_loc)
{
    rb_case_node_t *n = RB_NEW_NODE_NEWNODE((enum rb_node_type)RB_CASE_NODE, rb_case_node_t, loc);
    rb_node_list2_t *list = &nd_conds->elements;
    rb_node_t *nd_last = rb_node_list_last(list);
    rb_else_node_t *nd_else = NULL;

    if (nd_last && nd_type_p(nd_last, RB_ELSE_NODE)) {
        nd_else = rb_node_list_pop(list);
    }

    n->predicate = nd_head;
    rb_node_list_init_with_src(&n->conditions, list);
    n->else_clause = nd_else;
    n->case_keyword_loc = *case_keyword_loc;
    n->end_keyword_loc = *end_keyword_loc;

    return n;
}

バイトコードを眺める(case ... when ...の場合)

以下のコードを例にどのようなバイトコードが生成されるか確認しておきましょう。

case v
when a
  :a_body
when b
  :b_body
else
  :else_body
end

このコードをcase ... whenを使わずに書くと以下のようなコードになります。

tmp = v
if (a === v)
  :a_body
elsif (b === v)
  :b_body
else
  :else_body
end

生成されるバイトコードは大きく2つの部分からなります。 最初にマッチするwhen ...を決定するバイトコードが並びます。 ここではwhen ...を上からに試し、マッチした時点で後半のバイトコードにジャンプします。

後半はwhen ... bodybodyに当たるバイトコードが並んでいて、それぞれbodyの部分を実行するとleave(もしくはjump)でcase ... when ...全体から抜けるようになっています。

# `a === v`や`b === v`にあたる部分
#
# 0000 putself                                                          (  11)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 putself                                                          (  12)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               25
# 0012 putself                                                          (  14)
# 0013 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               29

# `else`のbody
#
# 0021 pop                                                              (  17)
# 0022 putobject                              :else_body[Li]
# 0024 leave

# `when a`のbody
#
# 0025 pop                                                              (  12)
# 0026 putobject                              :a_body                   (  13)[Li]
# 0028 leave                                                            (  17)

# `when b`のbody
#
# 0029 pop                                                              (  14)
# 0030 putobject                              :b_body                   (  15)[Li]
# 0032 leave                                                            (  17)

バイトコードを眺める(case when ...の場合)

次にcaseのあとに式を取らない場合のバイトコードをみていきます。

case
when a
  :a_body
when b
  :b_body
else
  :else_body
end

case whenを使わずに書くと以下のようなコードになります。

if a
  :a_body
elsif b
  :b_body
else
  :else_body
end

チェックの方法が変わるだけで、基本的には先ほどのケースと同じようなバイトコードが生成されます。

# `if a`や`if b`にあたる部分
#
# 0000 putself                                                          (  45)[Li]
# 0001 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 branchif                               13
# 0005 putself                                                          (  47)
# 0006 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0008 branchif                               16

# `else`のbody
#
# 0010 putobject                              :else_body                (  50)[Li]
# 0012 leave

# `when a`のbody
#
# 0013 putobject                              :a_body                   (  46)[Li]
# 0015 leave                                                            (  50)

# `when b`のbody
#
# 0016 putobject                              :b_body                   (  48)[Li]
# 0018 leave                                                            (  50)

compile.cを変更する

もともとcompile.cではNODE_CASEcompile_case関数で、NODE_CASE2compile_case2関数で処理していました。 引き続きそれら2つの関数を使用するようにします。

      case RB_CASE_NODE: {
        if (RB_NODE_CASE(node)->predicate) {
            CHECK(compile_case(iseq, ret, node, popped));
        }
        else {
            CHECK(compile_case2(iseq, ret, node, popped));
        }
        break;
      }

compile_case関数

case ... when ...のケースを処理するcompile_case関数からみていきましょう。 生成されるバイトコードは条件をチェックして必要に応じてjumpする部分と、条件にマッチした場合に実行されるbodyの部分に大きく分かれているのでした。 そこでcompile_case関数ではbody_seqcond_seqというアンカーを用意して、WhenNodeをコンパイルするときに条件の部分とbodyの部分を異なるアンカーに追記していき、最後に1つのバイトコード列へmergeします。

static int
compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    DECL_ANCHOR(head);
    DECL_ANCHOR(body_seq);
    DECL_ANCHOR(cond_seq);

    INIT_ANCHOR(head);
    INIT_ANCHOR(body_seq);
    INIT_ANCHOR(cond_seq);

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    return COMPILE_OK;
}

compile_case関数のうち、おもにバイトコードの生成に関するロジックを抜き出すと以下のような構造になっています。

  1. まずpredicate(case v)をコンパイルする
  2. 次にconditions(when a ... when b ...)を1つずつコンパイルする。このときsplat(when *a)かどうかで条件の部分に関して、生成するバイトコードが変化する
  3. さいごにelse_clauseをコンパイルする
static int
compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    const rb_node_list2_t *vals;
    const rb_node_list2_t *list = &RB_NODE_CASE(node)->conditions;
    const rb_else_node_t *nd_else = RB_NODE_CASE(node)->else_clause;

    CHECK(COMPILE(head, "case base", RB_NODE_CASE(node)->predicate));

    /* when */
    for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) {
        const NODE *n = list->nodes[i];

        CHECK(COMPILE_(body_seq, "when body", RB_NODE_WHEN(n)->statements, popped));
        vals = &RB_NODE_WHEN(n)->conditions;
        if (!RB_NODE_LIST_EMPTY_P(vals)) {
            /* when a, b, c */
            for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) {
                const NODE *val = vals->nodes[j];

                if (nd_type_p(val, RB_SPLAT_NODE)) {
                    only_special_literals = 0;
                    CHECK(when_splat_vals(iseq, cond_seq, val, l1, only_special_literals, literals));
                }
                else {
                    only_special_literals = when_vals(iseq, cond_seq, node, val, l1, only_special_literals, literals);
                    if (only_special_literals < 0) return COMPILE_NG;
                }
            }
        }
        else {
            EXPECT_NODE_NONULL("NODE_CASE", n, NODE_LIST, COMPILE_NG);
        }
    }
    /* else */
    if (nd_else) {
        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, nd_else, pop);
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(nd_else), nd_node_id(nd_else), branch_id, "else", branches);
        CHECK(COMPILE_(cond_seq, "else", nd_else, popped));
        ADD_INSNL(cond_seq, nd_else, jump, endlabel);
    }
    else {
        debugs("== else (implicit)\n");
        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, orig_node, pop);
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(orig_node), nd_node_id(orig_node), branch_id, "else", branches);
        if (!popped) {
            ADD_INSN(cond_seq, orig_node, putnil);
        }
        ADD_INSNL(cond_seq, orig_node, jump, endlabel);
    }

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

想定したバイトコードが出力されることを確認します。

# == disasm: #<ISeq:<main>@../../test.rb:1 (1,0)-(8,3)>
# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 putself                                                          (   2)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               25
# 0012 putself                                                          (   4)
# 0013 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               29
# 0021 pop                                                              (   7)
# 0022 putobject                              :else_body[Li]
# 0024 leave
# 0025 pop                                                              (   2)
# 0026 putobject                              :a_body                   (   3)[Li]
# 0028 leave                                                            (   7)
# 0029 pop                                                              (   4)
# 0030 putobject                              :b_body                   (   5)[Li]
# 0032 leave                                                            (   7)

case v 
when a
  :a_body
when b
  :b_body
else
  :else_body
end

whenの条件が複数あるとき

case ... when ...にはいくつか細かい注意点があるので、それらをみておきます。 まずはwhen a1, a2のようにwhenのあとに2つ以上の条件がある場合です。

このときは左から順にマッチするかチェックするため、その順番に===を呼び出して条件付きジャンプをするバイトコードが生成されます。 when a1, when a2, when a3の全てでbranchif 43と飛び先の命令が同じになっています。

# `case v`
#
# 0000 putself                                                          (   2)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>

# `when a1`
#
# 0003 putself                                                          (   3)
# 0004 opt_send_without_block                 <calldata!mid:a1, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               43

# `when a2`
#
# 0012 putself
# 0013 opt_send_without_block                 <calldata!mid:a2, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               43

# `when a3`
#
# 0021 putself
# 0022 opt_send_without_block                 <calldata!mid:a3, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0024 topn                                   1
# 0026 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0028 branchif                               43

# `when b`
#
# 0030 putself                                                          (   5)
# 0031 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0033 topn                                   1
# 0035 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0037 branchif                               47

# 以下body
#
# 0039 pop                                                              (   8)
# 0040 putobject                              :else_body[Li]
# 0042 leave
# 0043 pop                                                              (   3)
# 0044 putobject                              :a_body                   (   4)[Li]
# 0046 leave                                                            (   8)
# 0047 pop                                                              (   5)
# 0048 putobject                              :b_body                   (   6)[Li]
# 0050 leave                                                            (   8)
case v
when a1, a2, a3
  :a_body
when b
  :b_body
else
  :else_body
end

splatがあるとき

when *bのようにsplatがある場合はマッチのためのバイトコードが変わります。

# `case v`
#
# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>

# `when a`
#
# 0003 putself                                                          (   2)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               26

# `when *b`
#
# 0012 dup                                                              (   4)
# 0013 putself
# 0014 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0016 splatarray                             false
# 0018 checkmatch                             6
# 0020 branchif                               30

# 以下はwhenやelseのbody
#
# 0022 pop                                                              (   7)
# 0023 putobject                              :else_body[Li]
# 0025 leave
# 0026 pop                                                              (   2)
# 0027 putobject                              :a_body                   (   3)[Li]
# 0029 leave                                                            (   7)
# 0030 pop                                                              (   4)
# 0031 putobject                              :b_body                   (   5)[Li]
# 0033 leave                                                            (   7)

case v
when a
  :a_body
when *b
  :b_body
else
  :else_body
end

実行時のスタックを考えてみましょう。

まずはwhen aの場合からみてみます。 このときはa === vを実行するためtopnvがスタックトップにくるようにしてから#===メソッドを呼び出します。

# `0006 topn 1`まで
v
a
v

# `0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>`まで
a === v # の結果
v

次にwhen *bの場合です。 このときはcheckmatch(VALUE target, VALUE pattern)を実行するため、v, bの順番(右がスタックトップ)にスタックに積まれている必要があります。

# `0012 dup`まで
v
v

# `0016 splatarray false`まで
b # をdupしたもの
v
v

最適化されたバイトコード

最後に最適化についてみておきます。 when ...のすべての要素が実行時評価を必要としない場合には、条件をkey、飛び先をvalueとしたhashを用意して、そこから直接飛び先を引くという最適化が入ります。 以下のバイトコードでいうと0004 opt_case_dispatch <cdhash>, 23が追加されています。

# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 dup
# 0004 opt_case_dispatch                      <cdhash>, 23
# 0007 putobject                              :a                        (   2)
# 0009 topn                                   1
# 0011 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0013 branchif                               27
# 0015 putobject                              :b                        (   4)
# 0017 topn                                   1
# 0019 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0021 branchif                               31
# 0023 pop                                                              (   7)
# 0024 putobject                              :body_else[Li]
# 0026 leave
# ...

case v
when :a
  :body_a
when :b
  :body_b
else
  :body_else
end

ここまで踏まえて、簡単なコードで動作確認しておきましょう。

def m(v)
  a = :a
  b = :b

  case v
  when a
    :body_a
  when b
    :body_b
  else
    :body_else
  end
end

p m(:a)
#=> :body_a
p m(:else)
#=> :body_else

compile_case2関数

次にcase when ...をコンパイルするcompile_case2関数を新しいノードに対応させます。 といっても大まかな流れは変更前、およびcompile_case関数と変わりません。

case when ...casepredicateが存在しないため、いきなりwhenをコンパイルすることになります。 whenの条件の部分と、bodyの部分を別々のアンカーに追記していくのはcompile_case関数と変わりません。 その後、elseに当たる部分をコンパイルして終了です。

static int
compile_case2(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    /* when */
    for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) {
        const NODE *n = list->nodes[i];

        // when ... bodyのうち、bodyの部分のコンパイル
        CHECK(COMPILE_(body_seq, "when", RB_NODE_WHEN(n)->statements, popped));
        ADD_INSNL(body_seq, n, jump, endlabel);

        // when ...のうち、...の部分のコンパイル
        vals = &RB_NODE_WHEN(n)->conditions;
        if (RB_NODE_LIST_EMPTY_P(vals)) {
            EXPECT_NODE_NONULL("NODE_WHEN", n, NODE_LIST, COMPILE_NG);
        }
        /* when a, b, c */
        for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) {
            const NODE *val = vals->nodes[j];

            if (nd_type_p(val, RB_SPLAT_NODE)) {
                ADD_INSN(ret, val, putnil);
                CHECK(COMPILE(ret, "when2/cond splat", RB_NODE_SPLAT(val)->expression));
                ADD_INSN1(ret, val, splatarray, Qtrue);
                ADD_INSN1(ret, val, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_WHEN | VM_CHECKMATCH_ARRAY));
                ADD_INSNL(ret, val, branchif, l1);
            }
            else {
                LABEL *lnext;
                lnext = NEW_LABEL(nd_line(val));
                debug_compile("== when2\n", (void)0);
                CHECK(compile_branch_condition(iseq, ret, val, l1, lnext));
                ADD_LABEL(ret, lnext);
            }
        }
    }
    /* else */
    CHECK(COMPILE_(ret, "else", nd_else, popped));
    ADD_INSNL(ret, orig_node, jump, endlabel);

    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

動作確認をして終わりにしましょう。

def m(a: false, b1: false, b2: false, c: [false])
  case
  when a
    :body_a
  when b1, b2
    :body_b
  when *c
    :body_c
  else
    :body_else
  end
end

p m()
#=> :body_else

p m(a: true)
#=> :body_a

p m(b1: true)
#=> :body_b

p m(c: [true])
#=> :body_c

まとめ

今日の成果です。

  • case v when ....に対応した
  • case when ....に対応した



以上の内容はhttps://yui-knk.hatenablog.com/entry/2026/02/22/135412より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14