以下の内容はhttps://yui-knk.hatenablog.com/より取得しました。


Ruby Parser開発日誌 (24-40) - parse.yが生成するノードを変える ー パターンマッチング その3 (Variable patternと後置if, unless)

40日目: Variable patternと後置if, unlessに対応する

前回はValue patternの対応をしました。 今回はVariable patternと後置if, unlessに取り組みます。

Variable pattern

Variable patternもしくはvariable captureとは、パターンマッチングによるローカル変数の束縛のことです。 以下のコードではa = v相当の処理が行われます。

def m(v)
  case v

  # Before
  #
  # @ NODE_IN (id: 8, line: 3, location: (3,2)-(6,9))
  # +- nd_head:
  # |   @ NODE_LASGN (id: 5, line: 3, location: (3,5)-(3,6))
  # |   +- nd_vid: :a
  # |   +- nd_value:
  # |       (null node)

  # After
  #
  # @ InNode (location: (3,2)-(4,5))
  # +-- pattern:
  # |   @ LocalVariableTargetNode (location: (3,5)-(3,6))
  # |   +-- name: :a
  # |   +-- depth: 0
  in a
    a
  else
    :else
  end
end

対応するノードは書き換え前がNODE_LASGNで、書き換え後がLocalVariableTargetNodeです。 parse.yではアクションで呼び出す関数を変更してTargetNodeが生成されるようにします。

@@ -6372,7 +6372,7 @@ p_variable        : tIDENTIFIER
                     {
                         error_duplicate_pattern_variable(p, $1, &@1);
                     /*% ripper: var_field!($:1) %*/
-                        $$ = assignable(p, $1, 0, &@$);
+                        $$ = assignable_target(p, $1, &@$);
                     }
                 ;

Variable patternの場合のバイトコードはsetlocaljumpからなります。

# `case v`
#
# 0000 putnil                                                           (   3)[LiCa]
# 0001 getlocal_WC_0                          v@0                       (   2)
# 0003 dup                                                              (   3)

# `in a`
#
# 0004 setlocal_WC_0                          a@1
# 0006 jump                                   13

# `else`
#
# 0008 pop                                                              (   6)
# 0009 pop
# 0010 putobject                              :else[Li]
# 0012 leave                                                            (   8)[Re]

# `in a`のbody
#
# 0013 adjuststack                            2                         (   3)
# 0015 getlocal_WC_0                          a@1                       (   4)[Li]
# 0017 leave                                                            (   8)[Re]
def m(v)
  case v
  in a
    a
  else
    :else
  end
end

書き換え前のコンパイラをみると以下の3つを行なっています。

  1. Alternative patternのときに変数の束縛を禁止するためのチェックをする
  2. setlocal命令を生成するためにindexとlevelを解決する
  3. setlocal命令とjump命令を生成する
      case NODE_LASGN: {
        struct rb_iseq_constant_body *const body = ISEQ_BODY(iseq);
        ID id = RNODE_LASGN(node)->nd_vid;
        int idx = ISEQ_BODY(body->local_iseq)->local_table_size - get_local_var_idx(iseq, id);

        if (in_alt_pattern) {
            const char *name = rb_id2name(id);
            if (name && strlen(name) > 0 && name[0] != '_') {
                COMPILE_ERROR(ERROR_ARGS "illegal variable in alternative pattern (%"PRIsVALUE")",
                              rb_id2str(id));
                return COMPILE_NG;
            }
        }

        ADD_SETLOCAL(ret, line_node, idx, get_lvar_level(iseq));
        ADD_INSNL(ret, line_node, jump, matched);
        break;
      }
      case NODE_DASGN: {
        int idx, lv, ls;
        ID id = RNODE_DASGN(node)->nd_vid;

        idx = get_dyna_var_idx(iseq, id, &lv, &ls);

        if (in_alt_pattern) {
            const char *name = rb_id2name(id);
            if (name && strlen(name) > 0 && name[0] != '_') {
                COMPILE_ERROR(ERROR_ARGS "illegal variable in alternative pattern (%"PRIsVALUE")",
                              rb_id2str(id));
                return COMPILE_NG;
            }
        }

        if (idx < 0) {
            COMPILE_ERROR(ERROR_ARGS "NODE_DASGN: unknown id (%"PRIsVALUE")",
                          rb_id2str(id));
            return COMPILE_NG;
        }
        ADD_SETLOCAL(ret, line_node, ls - idx, lv);
        ADD_INSNL(ret, line_node, jump, matched);
        break;
      }

ノードの修正にあわせてcompile_lasgn_lhs関数を使うように書き換えます。

      case RB_LOCAL_VARIABLE_TARGET_NODE: {
        ID id = RB_NODE_LOCAL_VARIABLE_TARGET(node)->name;

        if (in_alt_pattern) {
            const char *name = rb_id2name(id);
            if (name && strlen(name) > 0 && name[0] != '_') {
                COMPILE_ERROR(ERROR_ARGS "illegal variable in alternative pattern (%"PRIsVALUE")",
                              rb_id2str(id));
                return COMPILE_NG;
            }
        }

        CHECK(compile_lasgn_lhs(iseq, ret, node, id));
        ADD_INSNL(ret, line_node, jump, matched);
        break;
      }

buildして動作を確認します。

def m(v)
  case v
  in a
    a
  else
    :else
  end
end

p m(:t)
#=> :t
p m(:f)
#=> :f
p m(:e)
#=> :e

後置ifと後置unless

パターンマッチングではpatternのあとにifやunlessを書くことができます。

def m(v, cond)
  case v

  in a if cond
    a
  else
    :else
  end
end

p m(:t, true)
#=> :t
p m(:t, false)
#=> :else

書き換え前のノードはNODE_IF、書き換え後のノードはIfNodeです。 書き換え前後で構造が変化しないのでparse.yの修正は不要です。

  # Before
  #
  # @ NODE_IN (id: 11, line: 3, location: (3,2)-(6,9))
  # +- nd_head:
  # |   @ NODE_IF (id: 8, line: 3, location: (3,5)-(3,14))*
  # |   +- nd_cond:
  # |   |   @ NODE_LVAR (id: 7, line: 3, location: (3,10)-(3,14))
  # |   |   +- nd_vid: :cond
  # |   +- nd_body:
  # |   |   @ NODE_LASGN (id: 6, line: 3, location: (3,5)-(3,6))
  # |   |   +- nd_vid: :a
  # |   |   +- nd_value:
  # |   |       (null node)

  # After
  #
  # @ InNode (location: (3,2)-(4,5))
  # +-- pattern:
  # |   @ IfNode (location: (3,5)-(3,14))
  # |   +-- predicate:
  # |   |   @ LocalVariableReadNode (location: (3,10)-(3,14))
  # |   |   +-- name: :cond
  # |   +-- statements:
  # |   |   @ StatementsNode (location: (3,5)-(3,6))
  # |   |   +-- body: (length: 1)
  # |   |       +-- @ LocalVariableTargetNode (location: (3,5)-(3,6))
  # |   |           +-- name: :a
  in a if cond

生成されるバイトコードはifのbodyに当たるsetlocal aを実行してから条件部分のgetlocal condbranchifを行うようになっています。

# `case v`
#
# 0000 putnil                                                           (   3)[LiCa]
# 0001 getlocal_WC_0                          v@0                       (   2)
# 0003 dup                                                              (   3)

# `in a if cond`
#
# 0004 setlocal_WC_0                          a@2
# 0006 getlocal_WC_0                          cond@1
# 0008 branchif                               15

# `else ...`
#
# 0010 pop                                                              (   6)
# 0011 pop
# 0012 putobject                              :else[Li]
# 0014 leave                                                            (   8)[Re]

# `in a if cond`のbody
#
# 0015 adjuststack                            2                         (   3)
# 0017 getlocal_WC_0                          a@2                       (   4)[Li]
# 0019 leave                                                            (   8)[Re]
def m(v, cond)
  case v
  in a if cond
    a
  else
    :else
  end
end

compile.cの変更は構造体に合わせて修正をするだけです。

      case RB_IF_NODE:
      case RB_UNLESS_NODE: {
        LABEL *match_failed;
        match_failed = unmatched;
        CHECK(iseq_compile_pattern_match(iseq, ret, (NODE *)RB_NODE_IF(node)->statements, unmatched, in_single_pattern, in_alt_pattern, base_index, use_deconstructed_cache));
        CHECK(COMPILE(ret, "case in if", RB_NODE_IF(node)->predicate));
        if (in_single_pattern) {
            LABEL *match_succeeded;
            match_succeeded = NEW_LABEL(line);

            ADD_INSN(ret, line_node, dup);
            if (nd_type_p(node, NODE_IF)) {
                ADD_INSNL(ret, line_node, branchif, match_succeeded);
            }
            else {
                ADD_INSNL(ret, line_node, branchunless, match_succeeded);
            }

            ADD_INSN1(ret, line_node, putobject, rb_fstring_lit("guard clause does not return true")); // (1)
            ADD_INSN1(ret, line_node, setn, INT2FIX(base_index + CASE3_BI_OFFSET_ERROR_STRING + 1 /* (1) */)); // (2)
            ADD_INSN1(ret, line_node, putobject, Qfalse);
            ADD_INSN1(ret, line_node, setn, INT2FIX(base_index + CASE3_BI_OFFSET_KEY_ERROR_P + 2 /* (1), (2) */));

            ADD_INSN(ret, line_node, pop);
            ADD_INSN(ret, line_node, pop);

            ADD_LABEL(ret, match_succeeded);
        }
        if (nd_type_p(node, RB_IF_NODE)) {
            ADD_INSNL(ret, line_node, branchunless, match_failed);
        }
        else {
            ADD_INSNL(ret, line_node, branchif, match_failed);
        }
        ADD_INSNL(ret, line_node, jump, matched);
        break;
      }

実際にコードを実行して挙動を確認します。

def m(v, cond)
  case v
  in a if cond
    a
  else
    :else
  end
end

p m(:t, true)
#=> :t
p m(:t, false)
#=> :else

まとめ

  • Variable patternに対応した
  • 後置ifと後置unlessに対応した

パターンマッチング全体の進捗は以下の通りです。

  • Value pattern
    • p_primitive ("str", 1, :symなど)
    • range_expr (1...3など)
    • p_var_ref (^varなど)
    • p_expr_ref (^(cmd 1, 2)など)
    • p_const (A, ::A, A::Bなど)
  • Variable pattern
  • Array pattern
  • Hash pattern
  • Find pattern
  • Alternative pattern
  • As pattern
  • 後置ifと後置unless

Ruby Parser開発日誌 (24-39) - parse.yが生成するノードを変える ー パターンマッチング その2 (Value patternとVariable pinning)

39日目: Value patternに対応する

前回はcase ... in String ... endの対応をしました。 今回はString以外のValue patternに取り組みます。

文法の面からみるとValue patternは以下のように分類できます。 このうちp_const (A, ::A, A::Bなど)は前回対応したので、残りを実装していきましょう。

  • Value pattern
    • p_primitive ("str", 1, :symなど)
    • range_expr (1...3など)
    • p_var_ref (^varなど)
    • p_expr_ref (^(cmd 1, 2)など)
    • p_const (A, ::A, A::Bなど)

p_primitiveとrange_expr

p_primitiveには"str", 1, :symが、range_exprには1...3などが含まれます。 parse.yでは以下の生成規則が対応しています。 ここではとくにアクションが設定されていないので、inline_primaryという生成規則が返す値がそのままp_primitiveの値になります。

p_primitive     : inline_primary
                ...
                ;

書き換え後のInNodeにおけるpatternではStringNodeIntegerNodeがそのまま設定されています。 そのためparse.yのアクションを変更する必要はありません。

case v
# @ InNode (location: (2,0)-(3,7))
# +-- pattern:
# |   @ StringNode (location: (2,3)-(2,8))
# |   +-- unescaped: "str"
in "str"
  :expr

# @ InNode (location: (4,0)-(5,7))
# +-- pattern:
# |   @ IntegerNode (location: (4,3)-(4,4))
# |   +-- value: 1
in 1
  :expr

# @ InNode (location: (6,0)-(7,7))
# +-- pattern:
# |   @ SymbolNode (location: (6,3)-(6,7))
# |   +-- unescaped: "sym"
in :sym
  :expr
end

次にcompile.cについてです。 これらのvalue patternのときに生成される命令をみてみるとputobjectにコンパイルされていることがわかります。

# 0000 putnil                                                           (   2)[Li]
# 0001 putself                                                          (   1)
# 0002 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>

# in "str"
#
# 0004 dup                                                              (   2)
# 0005 putchilledstring                       "str"
# 0007 checkmatch                             2
# 0009 branchif                               36

# in 1
#
# 0011 dup                                                              (   4)
# 0012 putobject_INT2FIX_1_
# 0013 checkmatch                             2
# 0015 branchif                               41

# in :sym
#
# 0017 dup                                                              (   6)
# 0018 putobject                              :sym
# 0020 checkmatch                             2
# 0022 branchif                               46
# ...

これは通常のコンパイルに任せればいいので、iseq_compile_pattern_each関数ではcase ...を追加するだけですみます。

static int
iseq_compile_pattern_each(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, LABEL *matched, LABEL *unmatched, bool in_single_pattern, bool in_alt_pattern, int base_index, bool use_deconstructed_cache)
{
    const int line = nd_line(node);
    const NODE *line_node = node;

    switch (nd_type(node)) {
      case NODE_ARYPTN: {
        ...
        break;
      }
      case NODE_FNDPTN: {
        ...
        break;
      }
      case NODE_HSHPTN: {
        ...
        break;
      }
      case RB_SYMBOL_NODE:
      case RB_REGULAR_EXPRESSION_NODE:
      case RB_SOURCE_LINE_NODE:
      case RB_INTEGER_NODE:
      case RB_FLOAT_NODE:
      case RB_RATIONAL_NODE:
      case RB_IMAGINARY_NODE:
      case RB_SOURCE_FILE_NODE:
      case RB_SOURCE_ENCODING_NODE:
      case RB_STRING_NODE:
      case RB_X_STRING_NODE:
      case RB_INTERPOLATED_STRING_NODE:
      case RB_INTERPOLATED_SYMBOL_NODE:
      case RB_INTERPOLATED_REGULAR_EXPRESSION_NODE:
      case RB_ARRAY_NODE:
      case RB_LAMBDA_NODE:
      case RB_RANGE_NODE:
      case RB_CONSTANT_READ_NODE:
      case RB_LOCAL_VARIABLE_READ_NODE:
      case RB_INSTANCE_VARIABLE_READ_NODE:
      case RB_CLASS_VARIABLE_READ_NODE:
      case RB_GLOBAL_VARIABLE_READ_NODE:
      case RB_TRUE_NODE:
      case RB_FALSE_NODE:
      case RB_SELF_NODE:
      case RB_NIL_NODE:
      case RB_CONSTANT_PATH_NODE:
        CHECK(COMPILE(ret, "case in literal", node)); // (1)
        if (in_single_pattern) {
            ADD_INSN1(ret, line_node, dupn, INT2FIX(2));
        }
        ADD_INSN1(ret, line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); // (2)
        if (in_single_pattern) {
            CHECK(iseq_compile_pattern_set_eqq_errmsg(iseq, ret, node, base_index + 2 /* (1), (2) */));
        }
        ADD_INSNL(ret, line_node, branchif, matched);
        ADD_INSNL(ret, line_node, jump, unmatched);
        break;
      default:
        UNKNOWN_NODE("NODE_IN", node, COMPILE_NG);
    }
    return COMPILE_OK;
}

動作を確認しておきます。

def m(v)
  case v
  in "str"
    :expr_str
  in 1
    :expr_int
  in :sym
    :expr_sym
  in 1...10
    :expr_range
  end
end

p m("str")
#=> :expr_str
p m(1)
#=> :expr_int
p m(:sym)
#=> :expr_sym
p m(5)
#=> :expr_range

p_var_refとvariable pinning

p_var_refは2つの生成規則からなります。

p_var_ref  : '^' tIDENTIFIER
                    {
                        NODE *n = gettable(p, $2, &@$);
                        if (!n) {
                            n = NEW_ERROR(&@$);
                        }
                        else if (!(nd_type_p(n, NODE_LVAR) || nd_type_p(n, NODE_DVAR))) {
                            compile_error(p, "%"PRIsVALUE": no such local variable", rb_id2str($2));
                        }
                        $$ = n;
                    /*% ripper: var_ref!($:2) %*/
                    }
                | '^' nonlocal_var
                    {
                        if (!($$ = gettable(p, $2, &@$))) $$ = NEW_ERROR(&@$);
                    /*% ripper: var_ref!($:2) %*/
                    }
                ;

パターンマッチングのパターンに変数の値を使うときは^をつける必要があります。 in a ...と書くと、それはaに対する代入になるためです。

def m(v)
  a = 0

  case v
  in a
    a
  else
    :else
  end
end

def m2(v)
  a = 0

  case v
  in ^a
    a
  else
    :else
  end
end

p m(0)
#=> 0
p m(1)
#=> 1

p m2(0)
#=> 0
p m2(1)
#=> :else

ノードとしてはPinnedVariableNodeでラップすることになります。

def m(v)
  a = 0

  case v
  # @ InNode (location: (5,2)-(6,5))
  # +-- pattern:
  # |   @ PinnedVariableNode (location: (5,5)-(5,7))
  # |   +-- variable:
  # |   |   @ LocalVariableReadNode (location: (5,6)-(5,7))
  # |   |   +-- name: :a
  # |   |   +-- depth: 0
  # |   +-- operator_loc: (5,5)-(5,6) = "^"
  in ^a
    a

  # @ InNode (location: (15,2)-(16,6))
  # +-- pattern:
  # |   @ PinnedVariableNode (location: (15,5)-(15,8))
  # |   +-- variable:
  # |   |   @ InstanceVariableReadNode (location: (15,6)-(15,8))
  # |   |   +-- name: :@b
  # |   +-- operator_loc: (15,5)-(15,6) = "^"
  in ^@b
    @b

  else
    :else
  end
end

parse.yでは素直にPinnedVariableNodeを生成するだけです。

@@ -6380,15 +6380,24 @@ p_var_ref       : '^' tIDENTIFIER
                         if (!n) {
                             n = NEW_ERROR(&@$);
                         }
-                        else if (!(nd_type_p(n, NODE_LVAR) || nd_type_p(n, NODE_DVAR))) {
+                        else if (!(nd_type_p(n, RB_LOCAL_VARIABLE_READ_NODE))) {
                             compile_error(p, "%"PRIsVALUE": no such local variable", rb_id2str($2));
                         }
+                        else {
+                            n = NEW_RB_PINNED_VARIABLE(n, &@1, &@$);
+                        }
                         $$ = n;
                     /*% ripper: var_ref!($:2) %*/
                     }
                 | '^' nonlocal_var
                     {
-                        if (!($$ = gettable(p, $2, &@$))) $$ = NEW_ERROR(&@$);
+                        NODE *n = gettable(p, $2, &@$);
+                        if (!n) {
+                            $$ = NEW_ERROR(&@$);
+                        }
+                        else {
+                            $$ = NEW_RB_PINNED_VARIABLE(n, &@1, &@$);
+                        }
                     /*% ripper: var_ref!($:2) %*/
                     }
                 ;

compile.cではPinnedVariableNodevariableを取り出してコンパイルします。

@@ -8339,6 +8339,12 @@ iseq_compile_pattern_each(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *c
       //   ADD_INSNL(ret, line_node, jump, unmatched);
       //   break;
       // }
+      case RB_PINNED_VARIABLE_NODE:
+        node = RB_NODE_PINNED_VARIABLE(node)->variable;
+        goto compile_value;
       case RB_SYMBOL_NODE:
       case RB_REGULAR_EXPRESSION_NODE:
       case RB_SOURCE_LINE_NODE:
       ...
      compile_value:
        CHECK(COMPILE(ret, "case in literal", node)); // (1)
        if (in_single_pattern) {
            ADD_INSN1(ret, line_node, dupn, INT2FIX(2));
        }
        ADD_INSN1(ret, line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); // (2)
        if (in_single_pattern) {
            CHECK(iseq_compile_pattern_set_eqq_errmsg(iseq, ret, node, base_index + 2 /* (1), (2) */));
        }
        ADD_INSNL(ret, line_node, branchif, matched);
        ADD_INSNL(ret, line_node, jump, unmatched);
        break;

^varを含むコードを実行してみます。

def m(v)
  a = 0
  @b = 1

  case v
  in ^a
    :a
  in ^@b
    :@b
  else
    :else
  end
end

p m(0)
# #=> :a
p m(1)
# #=> :@b
p m(2)
# #=> :else

p_expr_ref

p_expr_refという生成規則は^(cmd 1, 2)のように^()の中にexprを書くことができるというルールです。 書き換え前はNODE_BLOCKで、書き換え後はPinnedExpressionNodeで表現しています。

def m(v)
  c = true

  case v
  # Before
  #
  # @ NODE_IN (id: 14, line: 5, location: (5,2)-(8,9))
  # +- nd_head:
  # |   @ NODE_BLOCK (id: 11, line: 5, location: (5,5)-(5,19))
  # |   +- nd_head (1):
  # |       @ NODE_IF (id: 10, line: 5, location: (5,7)-(5,18))*

  # After
  #
  # @ InNode (location: (5,2)-(6,9))
  # +-- pattern:
  # |   @ PinnedExpressionNode (location: (5,5)-(5,19))
  # |   +-- expression:
  # |   |   @ IfNode (location: (5,7)-(5,18))
  in ^(c ? :t : :f)
    :expr
  else
    :else
  end
end

parse.yでは生成するノードを変更します。

 p_expr_ref     : '^' tLPAREN expr_value rparen
                     {
-                        $$ = NEW_BLOCK($3, &@$);
+                        $$ = NEW_RB_PINNED_EXPRESSION($3, &@$, &@1, &@2, &@4);
                     /*% ripper: begin!($:3) %*/
                     }
                 ;

compile.cではPinnedVariableNodeと同様にexpressionを取り出してコンパイルするようにします。

@@ -8339,6 +8339,12 @@ iseq_compile_pattern_each(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *c
       //   ADD_INSNL(ret, line_node, jump, unmatched);
       //   break;
       // }
       case RB_PINNED_VARIABLE_NODE:
         node = RB_NODE_PINNED_VARIABLE(node)->variable;
         goto compile_value;
+      case RB_PINNED_EXPRESSION_NODE:
+        node = RB_NODE_PINNED_EXPRESSION(node)->expression;
+        goto compile_value;
       case RB_SYMBOL_NODE:
       case RB_REGULAR_EXPRESSION_NODE:
       case RB_SOURCE_LINE_NODE:

動作確認します。

def m(v)
  c = true

  case v
  in ^(c ? :t : :f)
    :expr
  else
    :else
  end
end

p m(:t)
# # #=> :expr
p m(:f)
# # #=> :else
p m(:e)
# # #=> :else

まとめ

今日の成果です。

  • Value patternのうちConst以外の残りに対応した

パターンマッチング全体の進捗は以下の通りです。

  • Value pattern
    • p_primitive ("str", 1, :symなど)
    • range_expr (1...3など)
    • p_var_ref (^varなど)
    • p_expr_ref (^(cmd 1, 2)など)
    • p_const (A, ::A, A::Bなど)
  • Variable pattern
  • Array pattern
  • Hash pattern
  • Find pattern
  • Alternative pattern
  • As pattern
  • 後置ifと後置unless

Ruby Parser開発日誌 (24-38) - parse.yが生成するノードを変える ー パターンマッチング その1 (定数とのマッチ)

38日目: パターンマッチングはじめました

前回はcase whenをやったので、今回からパターンマッチングに取り組みたいと思います。 おそらく1回で終わらないと思うので、数回に分けて対応していければいいなと考えています。

パターンマッチングの文法を洗い出す

パターンマッチングは豊富な文法要素をもっています。 まずはそれを種類ごとに整理しておきましょう。

パターンマッチング全体の構造として3つの異なる文法が用意されています。

  1. case expr in pattern ...
  2. expr in pattern
  3. expr => pattern

それぞれ具体的なコードは以下の通りです。

# その1
case ary
in [0, 1]
  expr1
in [1, 2]
  expr2
else
  expr_else
end

# その2
ary in [1, 2]

# その3
ary => [1, 2]
# その2

パターンについても複数の異なるパターンが用意されています。 多くない...?

  1. Value pattern
  2. Variable pattern
  3. Array pattern
  4. Hash pattern
  5. Find pattern
  6. Alternative pattern
  7. As pattern

それぞれ具体的には以下のようなコードになります。

a = 1

case obj
in String # Value pattern
  expr
in ^a # Value pattern (variable pinning)
  expr
in v # Variable pattern
  expr
in [1, 2] # Array pattern
  expr
in {k: :v} # Hash pattern
  expr
in [*a, b, c, *d] # Find pattern
  expr
in [1, 2] | {k: :v} # Alternative pattern
  expr
in Integer => a, Integer # As pattern
  expr
else
  expr_else
end

一度に全部をやるのは難しそうなので、今回はつぎの2つの点に絞って実装していきます。

  • case expr in pattern ...の形式
  • Value pattern (の一部)

まずは以下のコードをコンパイルするところまで進めてみます。

case "str"
in String
  expr_str
in Integer
  expr_int
else
  expr_else
end

parserの変更

書き換え前後のノードを見比べてみます。

# Before
#
# @ NODE_CASE3
# +- nd_head:
# |   @ NODE_STR ("str")
# +- nd_body:
# |   @ NODE_IN (String)
# |   |   @ NODE_IN (Integer)
# |   |   |   @ NODE_VCALL (expr_else)

# After
#
# @ CaseMatchNode (location: (1,0)-(8,3))
# +-- predicate:
# |   @ StringNode ("str")
# +-- conditions: (length: 2)
# |   +-- @ InNode (String)
# |   +-- @ InNode (Integer)
# +-- else_clause:
# |   @ ElseNode (location: (6,0)-(8,3))
# |   +-- statements:
# |   |   @ StatementsNode (expr_else)

case "str"
in String
  expr_str
in Integer
  expr_int
else
  expr_else
end

ここで3つの差異があることがわかります。

  1. NODE_CASE3の代わりにCaseMatchNodeを使う
  2. NODE_INがリスト構造ではなくなる
  3. elseにあたるノードが直接CaseMatchNodeに紐づく

これは前回取り組んだcase whenのケースと同じような構造の変化です。

parse.yのアクションに関していうと、p_case_bodyInNodeを常に配列の先頭に入れるようにして、CaseMatchNodeをつくるときにElseNodeがあれば取り出してelse_clauseに設定します。 このあたりは前回と同様の変更です。

p_case_body     : keyword_in
                  p_in_kwarg[ctxt] p_pvtbl p_pktbl
                  p_top_expr[expr] then
                    {
                        pop_pktbl(p, $p_pktbl);
                        pop_pvtbl(p, $p_pvtbl);
                        p->ctxt.in_kwarg = $ctxt.in_kwarg;
                        p->ctxt.in_alt_pattern = $ctxt.in_alt_pattern;
                        p->ctxt.capture_in_pattern = $ctxt.capture_in_pattern;
                    }
                  compstmt(stmts)
                  p_cases[cases]
                    {
                        $$ = NEW_RB_IN($expr, $compstmt, &@$, &@keyword_in, &@then);
                        if ($cases) {
                            $$ = node_array_prepend(p, $cases, $$, &@$);
                        }
                        else {
                            $$ = NEW_RB_ARRAY($$, &@$);
                        }
                    /*% ripper: in!($:expr, $:compstmt, $:cases) %*/
                    }
                ;

p_cases         : opt_else
                    {
                        if ($1) {
                            $$ = NEW_RB_ARRAY($1, $$);
                        }
                    }
                | p_case_body
                ;

大枠の構造ができたので、次にパターンの部分についてみていきます。 今回のケースでいうと、in Stringin Integerの部分はConstantReadNodeとして表現されています。

# conditions: (length: 2)
# +-- @ InNode (location: (2,0)-(3,10))
# |   +-- pattern:
# |   |   @ ConstantReadNode (location: (2,3)-(2,9))
# |   |   +-- name: :String
# +-- @ InNode (location: (4,0)-(5,10))
#     +-- pattern:
#     |   @ ConstantReadNode (location: (4,3)-(4,10))
#     |   +-- name: :Integer

このinのあとのStringは生成規則を辿っていくとp_constというルールで定義されています。

p_case_body     : keyword_in
                  p_in_kwarg[ctxt] p_pvtbl p_pktbl
                  p_top_expr[expr] then
                  compstmt(stmts)
                  p_cases[cases]

p_top_expr      : p_top_expr_body
                | p_top_expr_body modifier_if expr_value
                | p_top_expr_body modifier_unless expr_value
                ;

p_top_expr_body : p_expr
                | p_expr ','
                | p_expr ',' p_args
                | p_find
                | p_args_tail
                | p_kwargs
                ;

p_expr          : p_as
                ;

p_as            : p_expr tASSOC p_variable
                | p_alt
                ;


p_alt           : p_alt[left] '|'[alt] p_expr_basic[right]
                | p_expr_basic
                ;


p_expr_basic    : p_value
                | p_variable
                | p_const p_lparen[p_pktbl] p_args rparen
                | p_const p_lparen[p_pktbl] p_find rparen
                | p_const p_lparen[p_pktbl] p_kwargs rparen
                ...
                ;

p_value         : p_primitive
                | range_expr(p_primitive)
                | p_var_ref
                | p_expr_ref
                | p_const
                ;

p_const         : tCOLON3 cname
                    {
                        $$ = NEW_COLON3($2, &@$, &@1, &@2);
                    }
                | p_const tCOLON2 cname
                    {
                        $$ = NEW_COLON2($1, $3, &@$, &@2, &@3);
                    }
                | tCONSTANT
                   {
                        $$ = gettable(p, $1, &@$);
                   }
                ;

定数に対するgettable関数はConstantReadNodeを返すので期待するノードが生成されていることがわかります。 in Stringの他にもin ::Integerin A::Bを書くこともできます。 それぞれ期待されるノードはConstantPathNodeConstantPathNode(ConstantReadNode)(ネストしている)ですが、これらは今のアクションが生成するノードと一致しているため、value patternにおける定数のケースではとくにアクションを修正する必要がないことがわかります。

生成されるバイトコードを確認する

サンプルコードから生成されるバイトコードを先にみておきましょう。 基本的な構造は

  • case ...の部分のバイトコード
  • in ...のチェックとジャンプをするバイトコード
  • else bodyのbodyの部分のバイトコード
  • in ... bodyのbodyの部分のバイトコード

となっており、これも前回やったcase whenとほぼ同じ構造になっています。 一番最初にputnilしているのは実際にこれを使うときに説明しようとおもいます。

# `case "str"`の部分
#
# 0000 putnil                                                           (   2)[Li]
# 0001 putchilledstring                       "str"                     (   1)

# `String === "str"`によるチェック
#
# 0003 dup                                                              (   2)
# 0004 opt_getconstant_path                   <ic:0 String>
# 0006 checkmatch                             2
# 0008 branchif                               23

# `Integer === "str"`によるチェック
#
# 0010 dup                                                              (   4)
# 0011 opt_getconstant_path                   <ic:1 Integer>
# 0013 checkmatch                             2
# 0015 branchif                               29

# `else`のケース
#
# 0017 pop                                                              (   7)
# 0018 pop
# 0019 putself                                [Li]
# 0020 opt_send_without_block                 <calldata!mid:expr_else, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0022 leave

# `String === "str"`のケース
#
# 0023 adjuststack                            2                         (   2)
# 0025 putself                                                          (   3)[Li]
# 0026 opt_send_without_block                 <calldata!mid:expr_str, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0028 leave                                                            (   7)

# ` === "str"`のケース
#
# 0029 adjuststack                            2                         (   4)
# 0031 putself                                                          (   5)[Li]
# 0032 opt_send_without_block                 <calldata!mid:expr_int, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0034 leave

case "str"
in String
  expr_str
in Integer
  expr_int
else
  expr_else
end

compile.cを変更する

compile.cではcompile_case3という関数があるので、それを修正していくことにします。

      case RB_CASE_MATCH_NODE: {
        CHECK(compile_case3(iseq, ret, node, popped));
        break;
      }

compile_case3関数ではhead, body_seq, cond_seqという3つのアンカーを用意してバイトコードを生成していきます。 この辺はcase whenのときと同じです。

static int
compile_case3(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    DECL_ANCHOR(head);
    DECL_ANCHOR(body_seq);
    DECL_ANCHOR(cond_seq);

    INIT_ANCHOR(head);
    INIT_ANCHOR(body_seq);
    INIT_ANCHOR(cond_seq);

    ADD_SEQ(ret, head);  /* case VAL */

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

まずcase "str"の部分のコンパイルですが、これは最初にいくつかスタックを確保しておきます。 確保する数はパターンが一個かどうかで変わりますが、おそらく一個のときだけは決め打ちで確保できるのでしょう。 その後nd_headをコンパイルしています。

static int
compile_case3(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    const NODE *node = orig_node;
    bool single_pattern;

    node = RNODE_CASE3(node)->nd_body;
    EXPECT_NODE("NODE_CASE3", node, NODE_IN, COMPILE_NG);
    single_pattern = !RNODE_IN(node)->nd_next;

    if (single_pattern) {
        /* allocate stack for ... */
        ADD_INSN(head, line_node, putnil); /* key_error_key */
        ADD_INSN(head, line_node, putnil); /* key_error_matchee */
        ADD_INSN1(head, line_node, putobject, Qfalse); /* key_error_p */
        ADD_INSN(head, line_node, putnil); /* error_string */
    }
    ADD_INSN(head, line_node, putnil); /* allocate stack for cached #deconstruct value */

    CHECK(COMPILE(head, "case base", RNODE_CASE3(orig_node)->nd_head));

in ...の部分がconditionsフィールドに移動したことを踏まえて書き換えると以下のようになります。

static int
compile_case3(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    const NODE *node = orig_node;
    const rb_node_list2_t *conditions = &RB_NODE_CASE_MATCH(node)->conditions;
    bool single_pattern;

    single_pattern = RB_NODE_LIST_LEN(conditions) == 1;

    if (single_pattern) {
        /* allocate stack for ... */
        ADD_INSN(head, line_node, putnil); /* key_error_key */
        ADD_INSN(head, line_node, putnil); /* key_error_matchee */
        ADD_INSN1(head, line_node, putobject, Qfalse); /* key_error_p */
        ADD_INSN(head, line_node, putnil); /* error_string */
    }
    ADD_INSN(head, line_node, putnil); /* allocate stack for cached #deconstruct value */

    CHECK(COMPILE(head, "case base", RB_NODE_CASE_MATCH(orig_node)->predicate));

続いてin pattern bodyを1つずつ順番にコンパイルしていきます。 bodybody_seqに、patterncond_seqにそれぞれ追記していきます。

パターンをコンパイルするのはiseq_compile_pattern_each関数なので、あとでiseq_compile_pattern_each関数も修正します。

    while (type == NODE_IN) {
        LABEL *l1;

        // `in pattern body`のうち`body`をコンパイルする
        if (branch_id) {
            ADD_INSN(body_seq, line_node, putnil);
        }
        l1 = NEW_LABEL(line);
        ADD_LABEL(body_seq, l1);
        ADD_INSN1(body_seq, line_node, adjuststack, INT2FIX(single_pattern ? 6 : 2));

        const NODE *const coverage_node = RNODE_IN(node)->nd_body ? RNODE_IN(node)->nd_body : node;
        add_trace_branch_coverage(
            iseq,
            body_seq,
            nd_code_loc(coverage_node),
            nd_node_id(coverage_node),
            branch_id++,
            "in",
            branches);

        CHECK(COMPILE_(body_seq, "in body", RNODE_IN(node)->nd_body, popped));
        ADD_INSNL(body_seq, line_node, jump, endlabel);

        // `in pattern body`のうち`pattern`をコンパイルする
        pattern = RNODE_IN(node)->nd_head;
        if (pattern) {
            int pat_line = nd_line(pattern);
            LABEL *next_pat = NEW_LABEL(pat_line);
            ADD_INSN (cond_seq, pattern, dup); /* dup case VAL */
            // NOTE: set base_index (it's "under" the matchee value, so it's position is 2)
            CHECK(iseq_compile_pattern_each(iseq, cond_seq, pattern, l1, next_pat, single_pattern, false, 2, true));
            ADD_LABEL(cond_seq, next_pat);
            LABEL_UNREMOVABLE(next_pat);
        }
        else {
            COMPILE_ERROR(ERROR_ARGS "unexpected node");
            return COMPILE_NG;
        }

        node = RNODE_IN(node)->nd_next;
        if (!node) {
            break;
        }
        type = nd_type(node);
        line = nd_line(node);
        line_node = node;
    }

NODE_INがリンク構造から配列の要素に変わったことを踏まえて、forによるループに変更します。

    for (size_t i = 0; i < RB_NODE_LIST_LEN(conditions); i++) {
        LABEL *l1;
        const NODE *nd_cond = conditions->nodes[i];
        EXPECT_NODE("NODE_CASE3", nd_cond, RB_IN_NODE, COMPILE_NG);
        type = nd_type(nd_cond);
        line = nd_line(nd_cond);
        line_node = nd_cond;

        if (branch_id) {
            ADD_INSN(body_seq, line_node, putnil);
        }
        l1 = NEW_LABEL(line);
        ADD_LABEL(body_seq, l1);
        ADD_INSN1(body_seq, line_node, adjuststack, INT2FIX(single_pattern ? 6 : 2));

        const NODE *const coverage_node = RB_NODE_IN(nd_cond)->statements ? (const NODE *const)RB_NODE_IN(nd_cond)->statements : nd_cond;
        add_trace_branch_coverage(
            iseq,
            body_seq,
            nd_code_loc(coverage_node),
            nd_node_id(coverage_node),
            branch_id++,
            "in",
            branches);

        CHECK(COMPILE_(body_seq, "in body", RB_NODE_IN(nd_cond)->statements, popped));
        ADD_INSNL(body_seq, line_node, jump, endlabel);

        pattern = RB_NODE_IN(nd_cond)->pattern;
        if (pattern) {
            int pat_line = nd_line(pattern);
            LABEL *next_pat = NEW_LABEL(pat_line);
            ADD_INSN (cond_seq, pattern, dup); /* dup case VAL */
            // NOTE: set base_index (it's "under" the matchee value, so it's position is 2)
            CHECK(iseq_compile_pattern_each(iseq, cond_seq, pattern, l1, next_pat, single_pattern, false, 2, true));
            ADD_LABEL(cond_seq, next_pat);
            LABEL_UNREMOVABLE(next_pat);
        }
        else {
            COMPILE_ERROR(ERROR_ARGS "unexpected node");
            return COMPILE_NG;
        }
    }

最後にelseがあればその部分をコンパイルして終了です。 elseがないときについてはあとで見ることにします。

    /* else */
    if (node) {
        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, line_node, pop);
        ADD_INSN(cond_seq, line_node, pop); /* discard cached #deconstruct value */
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(node), nd_node_id(node), branch_id, "else", branches);
        CHECK(COMPILE_(cond_seq, "else", node, popped));
        ADD_INSNL(cond_seq, line_node, jump, endlabel);
        ADD_INSN(cond_seq, line_node, putnil);
        if (popped) {
            ADD_INSN(cond_seq, line_node, putnil);
        }
    }
    else {
        ...
    }

NODE_INの最後の要素として存在していたelseに相当するノードがelse_clauseに分離されたことを踏まえて書き直します。

    const rb_else_node_t *const nd_else = RB_NODE_CASE_MATCH(node)->else_clause;

    /* else */
    if (nd_else) {
        node = (NODE *)nd_else;
        line_node = node;

        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, line_node, pop);
        ADD_INSN(cond_seq, line_node, pop); /* discard cached #deconstruct value */
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(node), nd_node_id(node), branch_id, "else", branches);
        CHECK(COMPILE_(cond_seq, "else", node, popped));
        ADD_INSNL(cond_seq, line_node, jump, endlabel);
        ADD_INSN(cond_seq, line_node, putnil);
        if (popped) {
            ADD_INSN(cond_seq, line_node, putnil);
        }
    }
    else {
        ...
    }

さて大枠を書き換えたのでiseq_compile_pattern_each関数にいきましょう。 この関数はin pattern ...patternの部分をコンパイルするための関数で、内部はpatternのノードで分岐しています。

static int
iseq_compile_pattern_each(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, LABEL *matched, LABEL *unmatched, bool in_single_pattern, bool in_alt_pattern, int base_index, bool use_deconstructed_cache)
{
    const int line = nd_line(node);
    const NODE *line_node = node;

    switch (nd_type(node)) {
      case NODE_ARYPTN: {
        ...
        break;
      }
      case NODE_FNDPTN: {
        ...
        break;
      }
      case NODE_HSHPTN: {
        ...
        break;
      }
      case NODE_SYM:
      case NODE_REGX:
      case NODE_LINE:
      case NODE_INTEGER:
      ...
      default:
        UNKNOWN_NODE("NODE_IN", node, COMPILE_NG);
    }
    return COMPILE_OK;
}

今回は定数に関するノードをコンパイルできるようにしましょう。 といっても、基本的には既存のコンパイルの枠組み(COMPILE)に載せるだけなので、作業としてはcase ...のところを書き換えるだけです。

static int
iseq_compile_pattern_each(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, LABEL *matched, LABEL *unmatched, bool in_single_pattern, bool in_alt_pattern, int base_index, bool use_deconstructed_cache)
{
    const int line = nd_line(node);
    const NODE *line_node = node;

    switch (nd_type(node)) {
      ...
      // case NODE_FALSE:
      // case NODE_SELF:
      // case NODE_NIL:
      case RB_CONSTANT_READ_NODE:
      case RB_CONSTANT_PATH_NODE:  
      // case NODE_BEGIN:
      // case NODE_BLOCK:
      // case NODE_ONCE:
        CHECK(COMPILE(ret, "case in literal", node)); // (1)
        if (in_single_pattern) {
            ADD_INSN1(ret, line_node, dupn, INT2FIX(2));
        }
        ADD_INSN1(ret, line_node, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_CASE)); // (2)
        if (in_single_pattern) {
            CHECK(iseq_compile_pattern_set_eqq_errmsg(iseq, ret, node, base_index + 2 /* (1), (2) */));
        }
        ADD_INSNL(ret, line_node, branchif, matched);
        ADD_INSNL(ret, line_node, jump, unmatched);
        break;
      ...
      default:
        UNKNOWN_NODE("NODE_IN", node, COMPILE_NG);
    }
    return COMPILE_OK;
}

ここまでで一度minirubyをbuildして動作確認します。

def m(val)
  case val
  in String
    :expr_str
  in Integer
    :expr_int
  else
    :expr_else
  end
end

p m("str")
#=> :expr_str
p m(1)
#=> :expr_int
p m([])
#=> :expr_else

よさそうですね。

elseがないケース

ここで一旦case .... inの全体の構造に戻って、elseがないときのことを考えてみましょう。 elseがなくて、いずれのパターンにもマッチしないときはNoMatchingPatternError例外が投げられます。

case :sym
in String
  :expr_str
in Integer
  :expr_int
end
#=> test.rb:1:in '<main>': sym (NoMatchingPatternError)

このコードに対応するバイトコードは以下のようになっています。 これはRubyVMFrozenCore#raise(NoMatchingPatternError, :sym)を実行して例外を投げていると言えます。

# 0000 putnil                                                           (   2)[Li]
# 0001 putobject                              :sym                      (   1)
# 0003 dup                                                              (   2)
# 0004 opt_getconstant_path                   <ic:0 String>
# 0006 checkmatch                             2
# 0008 branchif                               29
# 0010 dup                                                              (   4)
# 0011 opt_getconstant_path                   <ic:1 Integer>
# 0013 checkmatch                             2
# 0015 branchif                               34

# どのパターンにもマッチしないとき
#
# 0017 putspecialobject                       1                         (   1)
# 0019 putobject                              NoMatchingPatternError
# 0021 topn                                   2
# 0023 opt_send_without_block                 <calldata!mid:core#raise, argc:2, ARGS_SIMPLE>
# 0025 adjuststack                            3
# 0027 putnil
# 0028 leave                                                            (   5)

# `String`にマッチしたとき
#
# 0029 adjuststack                            2                         (   2)
# 0031 putobject                              :expr_str                 (   3)[Li]
# 0033 leave                                                            (   5)

# `Integer`にマッチしたとき
#
# 0034 adjuststack                            2                         (   4)
# 0036 putobject                              :expr_int                 (   5)[Li]
# 0038 leave

compile.cの該当する部分は以下のようになっています。 ここはノードの種類に依存していないので書き換えなしで動くはずです。

static int
compile_case3(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    /* else */
    if (nd_else) {
        ...
    }
    else {
        debugs("== else (implicit)\n");
        ADD_LABEL(cond_seq, elselabel);
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(orig_node), nd_node_id(orig_node), branch_id, "else", branches);
        ADD_INSN1(cond_seq, orig_node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE));

        if (single_pattern) {
            ...
        }
        else {
            ADD_INSN1(cond_seq, orig_node, putobject, rb_eNoMatchingPatternError);
            ADD_INSN1(cond_seq, orig_node, topn, INT2FIX(2));
            ADD_SEND(cond_seq, orig_node, id_core_raise, INT2FIX(2));
        }
        ADD_INSN1(cond_seq, orig_node, adjuststack, INT2FIX(single_pattern ? 7 : 3));
        if (!popped) {
            ADD_INSN(cond_seq, orig_node, putnil);
        }
        ADD_INSNL(cond_seq, orig_node, jump, endlabel);
        ADD_INSN1(cond_seq, orig_node, dupn, INT2FIX(single_pattern ? 5 : 1));
        if (popped) {
            ADD_INSN(cond_seq, line_node, putnil);
        }
    }

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

実行してみると期待したとおりの例外が発生することがわかります。

case :sym
in String
  :expr_str
in Integer
  :expr_int
end
#=> ../../test.rb:1:in '<main>': sym (NoMatchingPatternError)

single_patternのとき

compile_case3関数にはsingle_pattern(パターンが1つ)のときの分岐が存在します。 まずはバイトコードを眺めてみましょう。

case :sym
in String
  :expr_str
end

このコードからバイトコードを生成すると以下のようなバイトコードが生成されます。

# 0000 putnil                                                           (   2)[Li]
# 0001 putnil
# 0002 putobject                              false
# 0004 putnil
# 0005 putnil
# 0006 putobject                              :sym                      (   1)
# 0008 dup                                                              (   2)
# 0009 opt_getconstant_path                   <ic:0 String>
# 0011 dupn                                   2
# 0013 checkmatch                             2
# 0015 dup
# 0016 branchif                               36
# 0018 putspecialobject                       1
# 0020 putobject                              "%p === %p does not return true"
# 0022 topn                                   3
# 0024 topn                                   5
# 0026 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0028 setn                                   6
# 0030 putobject                              false
# 0032 setn                                   8
# 0034 pop
# 0035 pop
# 0036 setn                                   2
# 0038 pop
# 0039 pop
# 0040 branchif                               88
# 0042 putspecialobject                       1                         (   1)
# 0044 topn                                   4
# 0046 branchif                               64
# 0048 putobject                              NoMatchingPatternError
# 0050 putspecialobject                       1
# 0052 putobject                              "%p: %s"
# 0054 topn                                   4
# 0056 topn                                   7
# 0058 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0060 opt_send_without_block                 <calldata!mid:core#raise, argc:2, ARGS_SIMPLE>
# 0062 jump                                   84
# 0064 putobject                              NoMatchingPatternKeyError
# 0066 putspecialobject                       1
# 0068 putobject                              "%p: %s"
# 0070 topn                                   4
# 0072 topn                                   7
# 0074 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0076 topn                                   7
# 0078 topn                                   9
# 0080 opt_send_without_block                 <calldata!mid:new, argc:3, kw:[#<Symbol:0x000000000023310c>,#<Symbol:0x000000000021f10c>], KWARG>
# 0082 opt_send_without_block                 <calldata!mid:core#raise, argc:1, ARGS_SIMPLE>
# 0084 adjuststack                            7
# 0086 putnil
# 0087 leave                                                            (   3)
# 0088 adjuststack                            6                         (   2)
# 0090 putobject                              :expr_str                 (   3)[Li]
# 0092 leave

なんか異様に長いんだが?? パターンが2つのときのほうがバイトコードが短くて面食らいますね。 順を追ってみていきましょう。

まず最初にスタックに5つの領域を確保します。 これは後々使うタイミングで説明します。

# stackに領域を確保する
#
# 0000 putnil # key_error_key
# 0001 putnil # key_error_matchee
# 0002 putobject false # key_error_p
# 0004 putnil # error_string
# 0005 putnil # cached #deconstruct value

次にString === :symを実行して結果に応じてjumpします。

# `String === :sym`をチェックする
#
# 0006 putobject                              :sym                      (   1)
# 0008 dup                                                              (   2)
# 0009 opt_getconstant_path                   <ic:0 String>
# 0011 dupn                                   2
# 0013 checkmatch                             2
# 0015 dup
# 0016 branchif                               36

このバイトコードによってスタックがどうなるかを確認しておきましょう(最初に確保した部分は変わらないので一旦無視します)。

# `0011 dupn 2`まで
String
:sym
String
:sym
:sym

# `0015 dup`まで
false
false
String
:sym
:sym

# `0016 branchif 36`まで
false
String
:sym
:sym

今回のケースではStirng === :symfalseなので、ここではjumpしません。

# マッチしなかったとき
#
# 0018 putspecialobject                       1
# 0020 putobject                              "%p === %p does not return true"
# 0022 topn                                   3
# 0024 topn                                   5
# 0026 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0028 setn                                   6
# 0030 putobject                              false
# 0032 setn                                   8
# 0034 pop
# 0035 pop

続く0018から0035では主にエラーメッセージの構築を行います。

# `0024 topn 5`まで
:sym
String
"%p === %p does not return true"
RubyVMFrozenCore

false
String
:sym
:sym

# `0028 setn 6`まで
"String === :sym does not return true"

false
String
:sym
:sym

nil # cached #deconstruct value
"String === :sym does not return true" # error_string
false # key_error_p
nil # key_error_matchee
nil # key_error_key

#sprintfをつかってエラーメッセージを構築したら、最初に確保したスタックのうち下から4つ目の領域にメッセージをコピーします。 この領域はerror_stringとあるように、例外に渡すメッセージのための領域のようです。

その後key_error_pfalseにするために一時的にスタックにfalseを積んで、popでエラーメッセージの構築に使った領域を捨てます。

# `0032 setn 8`まで
false
"String === :sym does not return true"

false
String
:sym
:sym

nil # cached #deconstruct value
"String === :sym does not return true" # error_string
false # key_error_p <- ここを`0032 setn 8`で更新する
nil # key_error_matchee
nil # key_error_key

# `0035 pop`まで
false
String
:sym
:sym

0036以降の命令はStirng === :symがマッチしたときも、マッチしていないときも実行されます。 そのため0035 popまで終わった時点でもともとジャンプしてきた0016 branchif 36の時点とスタックが同じになるように調整されています(最初に確保した5つの領域の値は変わっていることもありますが)。

0036から0046をみていきましょう。

#
# マッチしたときはここに飛んでくる
#
# 0036 setn                                   2
# 0038 pop
# 0039 pop
# 0040 branchif                               88
# 0042 putspecialobject                       1                         (   1)
# 0044 topn                                   4
# 0046 branchif                               64
...
# `in String :expr_str`のbodyの部分
#
# 0088 adjuststack                            6                         (   2)
# 0090 putobject                              :expr_str                 (   3)[Li]
# 0092 leave

setnpopをつかってString:symを捨てます。

# `0039 pop`まで
false # String === :sym の結果
:sym

nil # cached #deconstruct value
"String === :sym does not return true" # error_string
false # key_error_p
nil # key_error_matchee
nil # key_error_key

0040 branchifString === :symの結果をみて分岐することを意味しています。 0088以降のバイトコードはinのbodyに当たる命令列なので、0088にジャンプしたあとはbodyを評価してcase全体を抜けることになります。

マッチが成功しない場合をみていきましょう。 RubyVMFrozenCoreをスタックに積んだのち、最初に確保した領域からkey_error_pをコピーして0046 branchif 64を行います。

# `0044 topn 4`まで
false # key_error_p
RubyVMFrozenCore
:sym

nil # cached #deconstruct value
"String === :sym does not return true" # error_string
false # key_error_p
nil # key_error_matchee
nil # key_error_key

key_error_pの値によってNoMatchingPatternErrorを投げるかNoMatchingPatternKeyErrorを投げるかが変わります。

# key_error_p == false
#
# 0048 putobject                              NoMatchingPatternError
# 0050 putspecialobject                       1
# 0052 putobject                              "%p: %s"
# 0054 topn                                   4
# 0056 topn                                   7
# 0058 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0060 opt_send_without_block                 <calldata!mid:core#raise, argc:2, ARGS_SIMPLE>
# 0062 jump                                   84

# key_error_p == true
#
# 0064 putobject                              NoMatchingPatternKeyError
# 0066 putspecialobject                       1
# 0068 putobject                              "%p: %s"
# 0070 topn                                   4
# 0072 topn                                   7
# 0074 opt_send_without_block                 <calldata!mid:core#sprintf, argc:3, ARGS_SIMPLE>
# 0076 topn                                   7
# 0078 topn                                   9
# 0080 opt_send_without_block                 <calldata!mid:new, argc:3, kw:[#<Symbol:0x000000000023310c>,#<Symbol:0x000000000021f10c>], KWARG>
# 0082 opt_send_without_block                 <calldata!mid:core#raise, argc:1, ARGS_SIMPLE>
#
# key_error_p == falseのときは最後にここに飛んでくる
#
# 0084 adjuststack                            7
# 0086 putnil
# 0087 leave                                                            (   3)

single_patternのときもin pattern bodypatternbodyに依存せず例外を発生させる部分のバイトコードを生成することができます。 そのためcompile_case3関数の修正は特に必要ないでしょう。

patternが1つのケースを実行してみます。

case :sym
in String
  :expr_str
end
#=> ../../test.rb:1:in '<main>': :sym: String === :sym does not return true (NoMatchingPatternError)

良さそうですね

まとめ

今日の成果です。

  • パターンマッチングの外観を眺めて整理した
  • Value patternのうちConstの対応をした(in String ...)

しばらくパターンマッチングが続くと思うので、ブレイクダウンした結果と現在の進捗をまとめておきます。

  • case expr in pattern ...
  • expr in pattern
  • expr => pattern

  • Value pattern

    • p_primitive ("str", 1, :symなど)
    • range_expr (1...3など)
    • p_var_ref (^varなど)
    • p_expr_ref (^(cmd 1, 2)など)
    • p_const (A, ::A, A::Bなど)
  • Variable pattern
  • Array pattern
  • Hash pattern
  • Find pattern
  • Alternative pattern
  • As pattern
  • 後置ifと後置unless

今回はここまで!

Ruby Parser開発日誌 (24-37) - parse.yが生成するノードを変える ー case when

37日目: case whenの対応

今回はcase whenに取り組みます。 Rubyには3種類のcaseがありますが、今回は1・2つめのcase ... when ...をやることにし、3つめのcase ... in ...についてはパターンマッチングのときに取り組むことにします。

# その1
case a
when :a
when :b
else
end

# その2
case
when a
when b
else
end

# その3
case a
in [0]
in [0, 1]
else
end

書き換え前後のノードの変更点

書き換えの前後で大きく3つ変わります。

  1. NODE_CASENODE_CASE2が統合される
  2. NODE_WHENがリスト構造ではなくなる
  3. elseにあたるノードが直接caseのノードに紐づく

それぞれ詳しく説明します。

まず1つめのNODE_CASENODE_CASE2が統合されるについてです。 書き換え前はcase a ... whenにはNODE_CASEcase when ...にはNODE_CASE2と別のノードが割り当てられていました。 書き換え後はどちらもCaseNodeで表現することになり、それらの差はpredicateというフィールドの値の有無で判別するようになります。

つぎは2つめのNODE_WHENがリスト構造ではなくなるについてです。 ノードの書き換え前後を比べるとわかるのですが、書き換え前はNODE_WHENの要素として次のNODE_WHENが保存されていました。 つまりNODE_WHENそのものがリスト構造になっているというわけです。 一方で書き換え後はCaseNodeconditionsというフィールドが配列になっていて、そこに複数のWhenNodeが保存されるようになります。

# Before
#
# @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))*
# +- nd_body:
# |   @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7))
# |   +- nd_next:
# |       @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7))
# |       +- nd_next:
# |           (null node)

# After
#
# @ CaseNode (location: (11,0)-(14,3))
# +-- conditions: (length: 2)
# |   +-- @ WhenNode (location: (12,0)-(12,7))
# |   +-- @ WhenNode (location: (13,0)-(13,7))

case a
when :a
when :b
end

最後に3つめのelseにあたるノードが直接caseのノードに紐づくです。 書き換え前はNODE_WHENの最後の要素としてelseのbodyにあたるノードが保存されている構造になっていました。 書き換え後はCaseNodeelse_clauseというフィールドで管理することになります。

# Before
#
# @ NODE_CASE (id: 9, line: 11, location: (11,0)-(14,3))*
# +- nd_body:
# |   @ NODE_WHEN (id: 8, line: 12, location: (12,0)-(13,7))
# |   +- nd_next:
# |       @ NODE_WHEN (id: 7, line: 13, location: (13,0)-(13,7))
# |       +- nd_next:
# |           @ NODE_SYM (id: 7, line: 15, location: (15,2)-(15,7))*

# After
#
# @ CaseNode (location: (11,0)-(14,3))
# +-- conditions: (length: 2)
# |   +-- @ WhenNode (location: (12,0)-(12,7))
# |   +-- @ WhenNode (location: (13,0)-(13,7))
# +-- else_clause:
# |   @ ElseNode (location: (14,0)-(16,3))

case a
when :a
when :b
else
  :else
end

parse.yの変更

when ... else ...の部分の生成規則は以下のようになっており、後ろのwhenelseから順番にノードを組み立てるようになっています。

primary     | k_case expr_value terms? // case a
              case_body                // when :a ... else ...
              k_end                    // end

case_body   : k_when case_args then    // when :a
              compstmt(stmts)          //   expr
              cases                    // when :b ... else ...

cases       : opt_else                 // else ...
            | case_body                // when :b ... else ...

今回は生成規則に手を入れないようにしたいので、常に配列の先頭にノードを追加するようにします。

@@ -5911,13 +5918,25 @@ case_body       : k_when case_args then
                   compstmt(stmts)
                   cases
                     {
-                        $$ = NEW_WHEN($2, $4, $5, &@$, &@1, &@3);
+                        $$ = NEW_RB_WHEN($2, $4, &@$, &@1, &@3);
                         fixpos($$, $2);
+                        if ($5) {
+                            $$ = node_array_prepend(p, $5, $$, &@$);
+                        }
+                        else {
+                            $$ = NEW_RB_ARRAY($$, &@$);
+                        }
                     /*% ripper: when!($:2, $:4, $:5) %*/
                     }
                 ;

 cases          : opt_else
+                    {
+                        if ($1) {
+                            $$ = NEW_RB_ARRAY($1, $$);
+                        }
+                    /*% ripper: when!($:2, $:4, $:5) %*/
+                    }
                 | case_body
                 ;

またこのままだとWhenNodeElseNodeも1つの配列に入ったままになってしまうので、CaseNodeを作成するときに配列の最後のノードの種類をチェックして、それがElseNodeの場合には配列から取り出してelse_clauseにセットするようにします。

static rb_case_node_t *
rb_new_node_case_new(struct parser_params *p, rb_node_t *nd_head, rb_array_node_t *nd_conds, const YYLTYPE *loc, const YYLTYPE *case_keyword_loc, const YYLTYPE *end_keyword_loc)
{
    rb_case_node_t *n = RB_NEW_NODE_NEWNODE((enum rb_node_type)RB_CASE_NODE, rb_case_node_t, loc);
    rb_node_list2_t *list = &nd_conds->elements;
    rb_node_t *nd_last = rb_node_list_last(list);
    rb_else_node_t *nd_else = NULL;

    if (nd_last && nd_type_p(nd_last, RB_ELSE_NODE)) {
        nd_else = rb_node_list_pop(list);
    }

    n->predicate = nd_head;
    rb_node_list_init_with_src(&n->conditions, list);
    n->else_clause = nd_else;
    n->case_keyword_loc = *case_keyword_loc;
    n->end_keyword_loc = *end_keyword_loc;

    return n;
}

バイトコードを眺める(case ... when ...の場合)

以下のコードを例にどのようなバイトコードが生成されるか確認しておきましょう。

case v
when a
  :a_body
when b
  :b_body
else
  :else_body
end

このコードをcase ... whenを使わずに書くと以下のようなコードになります。

tmp = v
if (a === v)
  :a_body
elsif (b === v)
  :b_body
else
  :else_body
end

生成されるバイトコードは大きく2つの部分からなります。 最初にマッチするwhen ...を決定するバイトコードが並びます。 ここではwhen ...を上からに試し、マッチした時点で後半のバイトコードにジャンプします。

後半はwhen ... bodybodyに当たるバイトコードが並んでいて、それぞれbodyの部分を実行するとleave(もしくはjump)でcase ... when ...全体から抜けるようになっています。

# `a === v`や`b === v`にあたる部分
#
# 0000 putself                                                          (  11)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 putself                                                          (  12)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               25
# 0012 putself                                                          (  14)
# 0013 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               29

# `else`のbody
#
# 0021 pop                                                              (  17)
# 0022 putobject                              :else_body[Li]
# 0024 leave

# `when a`のbody
#
# 0025 pop                                                              (  12)
# 0026 putobject                              :a_body                   (  13)[Li]
# 0028 leave                                                            (  17)

# `when b`のbody
#
# 0029 pop                                                              (  14)
# 0030 putobject                              :b_body                   (  15)[Li]
# 0032 leave                                                            (  17)

バイトコードを眺める(case when ...の場合)

次にcaseのあとに式を取らない場合のバイトコードをみていきます。

case
when a
  :a_body
when b
  :b_body
else
  :else_body
end

case whenを使わずに書くと以下のようなコードになります。

if a
  :a_body
elsif b
  :b_body
else
  :else_body
end

チェックの方法が変わるだけで、基本的には先ほどのケースと同じようなバイトコードが生成されます。

# `if a`や`if b`にあたる部分
#
# 0000 putself                                                          (  45)[Li]
# 0001 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 branchif                               13
# 0005 putself                                                          (  47)
# 0006 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0008 branchif                               16

# `else`のbody
#
# 0010 putobject                              :else_body                (  50)[Li]
# 0012 leave

# `when a`のbody
#
# 0013 putobject                              :a_body                   (  46)[Li]
# 0015 leave                                                            (  50)

# `when b`のbody
#
# 0016 putobject                              :b_body                   (  48)[Li]
# 0018 leave                                                            (  50)

compile.cを変更する

もともとcompile.cではNODE_CASEcompile_case関数で、NODE_CASE2compile_case2関数で処理していました。 引き続きそれら2つの関数を使用するようにします。

      case RB_CASE_NODE: {
        if (RB_NODE_CASE(node)->predicate) {
            CHECK(compile_case(iseq, ret, node, popped));
        }
        else {
            CHECK(compile_case2(iseq, ret, node, popped));
        }
        break;
      }

compile_case関数

case ... when ...のケースを処理するcompile_case関数からみていきましょう。 生成されるバイトコードは条件をチェックして必要に応じてjumpする部分と、条件にマッチした場合に実行されるbodyの部分に大きく分かれているのでした。 そこでcompile_case関数ではbody_seqcond_seqというアンカーを用意して、WhenNodeをコンパイルするときに条件の部分とbodyの部分を異なるアンカーに追記していき、最後に1つのバイトコード列へmergeします。

static int
compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    DECL_ANCHOR(head);
    DECL_ANCHOR(body_seq);
    DECL_ANCHOR(cond_seq);

    INIT_ANCHOR(head);
    INIT_ANCHOR(body_seq);
    INIT_ANCHOR(cond_seq);

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    return COMPILE_OK;
}

compile_case関数のうち、おもにバイトコードの生成に関するロジックを抜き出すと以下のような構造になっています。

  1. まずpredicate(case v)をコンパイルする
  2. 次にconditions(when a ... when b ...)を1つずつコンパイルする。このときsplat(when *a)かどうかで条件の部分に関して、生成するバイトコードが変化する
  3. さいごにelse_clauseをコンパイルする
static int
compile_case(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    const rb_node_list2_t *vals;
    const rb_node_list2_t *list = &RB_NODE_CASE(node)->conditions;
    const rb_else_node_t *nd_else = RB_NODE_CASE(node)->else_clause;

    CHECK(COMPILE(head, "case base", RB_NODE_CASE(node)->predicate));

    /* when */
    for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) {
        const NODE *n = list->nodes[i];

        CHECK(COMPILE_(body_seq, "when body", RB_NODE_WHEN(n)->statements, popped));
        vals = &RB_NODE_WHEN(n)->conditions;
        if (!RB_NODE_LIST_EMPTY_P(vals)) {
            /* when a, b, c */
            for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) {
                const NODE *val = vals->nodes[j];

                if (nd_type_p(val, RB_SPLAT_NODE)) {
                    only_special_literals = 0;
                    CHECK(when_splat_vals(iseq, cond_seq, val, l1, only_special_literals, literals));
                }
                else {
                    only_special_literals = when_vals(iseq, cond_seq, node, val, l1, only_special_literals, literals);
                    if (only_special_literals < 0) return COMPILE_NG;
                }
            }
        }
        else {
            EXPECT_NODE_NONULL("NODE_CASE", n, NODE_LIST, COMPILE_NG);
        }
    }
    /* else */
    if (nd_else) {
        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, nd_else, pop);
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(nd_else), nd_node_id(nd_else), branch_id, "else", branches);
        CHECK(COMPILE_(cond_seq, "else", nd_else, popped));
        ADD_INSNL(cond_seq, nd_else, jump, endlabel);
    }
    else {
        debugs("== else (implicit)\n");
        ADD_LABEL(cond_seq, elselabel);
        ADD_INSN(cond_seq, orig_node, pop);
        add_trace_branch_coverage(iseq, cond_seq, nd_code_loc(orig_node), nd_node_id(orig_node), branch_id, "else", branches);
        if (!popped) {
            ADD_INSN(cond_seq, orig_node, putnil);
        }
        ADD_INSNL(cond_seq, orig_node, jump, endlabel);
    }

    ADD_SEQ(ret, cond_seq);
    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

想定したバイトコードが出力されることを確認します。

# == disasm: #<ISeq:<main>@../../test.rb:1 (1,0)-(8,3)>
# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 putself                                                          (   2)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               25
# 0012 putself                                                          (   4)
# 0013 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               29
# 0021 pop                                                              (   7)
# 0022 putobject                              :else_body[Li]
# 0024 leave
# 0025 pop                                                              (   2)
# 0026 putobject                              :a_body                   (   3)[Li]
# 0028 leave                                                            (   7)
# 0029 pop                                                              (   4)
# 0030 putobject                              :b_body                   (   5)[Li]
# 0032 leave                                                            (   7)

case v 
when a
  :a_body
when b
  :b_body
else
  :else_body
end

whenの条件が複数あるとき

case ... when ...にはいくつか細かい注意点があるので、それらをみておきます。 まずはwhen a1, a2のようにwhenのあとに2つ以上の条件がある場合です。

このときは左から順にマッチするかチェックするため、その順番に===を呼び出して条件付きジャンプをするバイトコードが生成されます。 when a1, when a2, when a3の全てでbranchif 43と飛び先の命令が同じになっています。

# `case v`
#
# 0000 putself                                                          (   2)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>

# `when a1`
#
# 0003 putself                                                          (   3)
# 0004 opt_send_without_block                 <calldata!mid:a1, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               43

# `when a2`
#
# 0012 putself
# 0013 opt_send_without_block                 <calldata!mid:a2, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0015 topn                                   1
# 0017 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0019 branchif                               43

# `when a3`
#
# 0021 putself
# 0022 opt_send_without_block                 <calldata!mid:a3, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0024 topn                                   1
# 0026 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0028 branchif                               43

# `when b`
#
# 0030 putself                                                          (   5)
# 0031 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0033 topn                                   1
# 0035 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0037 branchif                               47

# 以下body
#
# 0039 pop                                                              (   8)
# 0040 putobject                              :else_body[Li]
# 0042 leave
# 0043 pop                                                              (   3)
# 0044 putobject                              :a_body                   (   4)[Li]
# 0046 leave                                                            (   8)
# 0047 pop                                                              (   5)
# 0048 putobject                              :b_body                   (   6)[Li]
# 0050 leave                                                            (   8)
case v
when a1, a2, a3
  :a_body
when b
  :b_body
else
  :else_body
end

splatがあるとき

when *bのようにsplatがある場合はマッチのためのバイトコードが変わります。

# `case v`
#
# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>

# `when a`
#
# 0003 putself                                                          (   2)
# 0004 opt_send_without_block                 <calldata!mid:a, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0006 topn                                   1
# 0008 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0010 branchif                               26

# `when *b`
#
# 0012 dup                                                              (   4)
# 0013 putself
# 0014 opt_send_without_block                 <calldata!mid:b, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0016 splatarray                             false
# 0018 checkmatch                             6
# 0020 branchif                               30

# 以下はwhenやelseのbody
#
# 0022 pop                                                              (   7)
# 0023 putobject                              :else_body[Li]
# 0025 leave
# 0026 pop                                                              (   2)
# 0027 putobject                              :a_body                   (   3)[Li]
# 0029 leave                                                            (   7)
# 0030 pop                                                              (   4)
# 0031 putobject                              :b_body                   (   5)[Li]
# 0033 leave                                                            (   7)

case v
when a
  :a_body
when *b
  :b_body
else
  :else_body
end

実行時のスタックを考えてみましょう。

まずはwhen aの場合からみてみます。 このときはa === vを実行するためtopnvがスタックトップにくるようにしてから#===メソッドを呼び出します。

# `0006 topn 1`まで
v
a
v

# `0008 opt_send_without_block <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>`まで
a === v # の結果
v

次にwhen *bの場合です。 このときはcheckmatch(VALUE target, VALUE pattern)を実行するため、v, bの順番(右がスタックトップ)にスタックに積まれている必要があります。

# `0012 dup`まで
v
v

# `0016 splatarray false`まで
b # をdupしたもの
v
v

最適化されたバイトコード

最後に最適化についてみておきます。 when ...のすべての要素が実行時評価を必要としない場合には、条件をkey、飛び先をvalueとしたhashを用意して、そこから直接飛び先を引くという最適化が入ります。 以下のバイトコードでいうと0004 opt_case_dispatch <cdhash>, 23が追加されています。

# 0000 putself                                                          (   1)[Li]
# 0001 opt_send_without_block                 <calldata!mid:v, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0003 dup
# 0004 opt_case_dispatch                      <cdhash>, 23
# 0007 putobject                              :a                        (   2)
# 0009 topn                                   1
# 0011 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0013 branchif                               27
# 0015 putobject                              :b                        (   4)
# 0017 topn                                   1
# 0019 opt_send_without_block                 <calldata!mid:===, argc:1, FCALL|ARGS_SIMPLE>
# 0021 branchif                               31
# 0023 pop                                                              (   7)
# 0024 putobject                              :body_else[Li]
# 0026 leave
# ...

case v
when :a
  :body_a
when :b
  :body_b
else
  :body_else
end

ここまで踏まえて、簡単なコードで動作確認しておきましょう。

def m(v)
  a = :a
  b = :b

  case v
  when a
    :body_a
  when b
    :body_b
  else
    :body_else
  end
end

p m(:a)
#=> :body_a
p m(:else)
#=> :body_else

compile_case2関数

次にcase when ...をコンパイルするcompile_case2関数を新しいノードに対応させます。 といっても大まかな流れは変更前、およびcompile_case関数と変わりません。

case when ...casepredicateが存在しないため、いきなりwhenをコンパイルすることになります。 whenの条件の部分と、bodyの部分を別々のアンカーに追記していくのはcompile_case関数と変わりません。 その後、elseに当たる部分をコンパイルして終了です。

static int
compile_case2(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const orig_node, int popped)
{
    /* when */
    for (size_t i = 0; i < RB_NODE_LIST_LEN(list); i++) {
        const NODE *n = list->nodes[i];

        // when ... bodyのうち、bodyの部分のコンパイル
        CHECK(COMPILE_(body_seq, "when", RB_NODE_WHEN(n)->statements, popped));
        ADD_INSNL(body_seq, n, jump, endlabel);

        // when ...のうち、...の部分のコンパイル
        vals = &RB_NODE_WHEN(n)->conditions;
        if (RB_NODE_LIST_EMPTY_P(vals)) {
            EXPECT_NODE_NONULL("NODE_WHEN", n, NODE_LIST, COMPILE_NG);
        }
        /* when a, b, c */
        for (size_t j = 0; j < RB_NODE_LIST_LEN(vals); j++) {
            const NODE *val = vals->nodes[j];

            if (nd_type_p(val, RB_SPLAT_NODE)) {
                ADD_INSN(ret, val, putnil);
                CHECK(COMPILE(ret, "when2/cond splat", RB_NODE_SPLAT(val)->expression));
                ADD_INSN1(ret, val, splatarray, Qtrue);
                ADD_INSN1(ret, val, checkmatch, INT2FIX(VM_CHECKMATCH_TYPE_WHEN | VM_CHECKMATCH_ARRAY));
                ADD_INSNL(ret, val, branchif, l1);
            }
            else {
                LABEL *lnext;
                lnext = NEW_LABEL(nd_line(val));
                debug_compile("== when2\n", (void)0);
                CHECK(compile_branch_condition(iseq, ret, val, l1, lnext));
                ADD_LABEL(ret, lnext);
            }
        }
    }
    /* else */
    CHECK(COMPILE_(ret, "else", nd_else, popped));
    ADD_INSNL(ret, orig_node, jump, endlabel);

    ADD_SEQ(ret, body_seq);
    ADD_LABEL(ret, endlabel);
    return COMPILE_OK;
}

動作確認をして終わりにしましょう。

def m(a: false, b1: false, b2: false, c: [false])
  case
  when a
    :body_a
  when b1, b2
    :body_b
  when *c
    :body_c
  else
    :body_else
  end
end

p m()
#=> :body_else

p m(a: true)
#=> :body_a

p m(b1: true)
#=> :body_b

p m(c: [true])
#=> :body_c

まとめ

今日の成果です。

  • case v when ....に対応した
  • case when ....に対応した

Ruby Parser開発日誌 (24-36) - parse.yが生成するノードを変える ー constant declaration with operator

36日目: A::B ||= 1

前回はa += 1の形式の代入をやりました。 今回は左辺が定数のときの代入をやっていきます。

これまでと同様に演算子の種類によって生成するノードが変わります。

  • A += 1: ConstantOperatorWriteNode
  • A ||= 1: ConstantOrWriteNode
  • A &&= 1: ConstantAndWriteNode

また定数に特有の話としてConstant(A += 1)とConstantPath(::A += 1A::B += 1)を使いわける必要もあります。

parse.yを変更する

まず変更するのはnew_const_op_assign関数の実装です。 この関数はA::B += 1::A += 1を解析するときに呼ばれます。

%rule op_asgn(rhs) <node>
                | primary_value tCOLON2 tCONSTANT tOP_ASGN lex_ctxt rhs
                    {
                        YYLTYPE loc = code_loc_gen(&@primary_value, &@tCONSTANT);
                        $$ = new_const_op_assign(p, NEW_COLON2($primary_value, $tCONSTANT, &loc, &@tCOLON2, &@tCONSTANT), $tOP_ASGN, $rhs, $lex_ctxt, &@$);
                    /*% ripper: opassign!(const_path_field!($:1, $:3), $:4, $:6) %*/
                    }
                | tCOLON3 tCONSTANT tOP_ASGN lex_ctxt rhs
                    {
                        YYLTYPE loc = code_loc_gen(&@tCOLON3, &@tCONSTANT);
                        $$ = new_const_op_assign(p, NEW_COLON3($tCONSTANT, &loc, &@tCOLON3, &@tCONSTANT), $tOP_ASGN, $rhs, $lex_ctxt, &@$);
                    /*% ripper: opassign!(top_const_field!($:2), $:3, $:5) %*/
                    }

前回、前々回と同様に演算子の種類とノードの種類に応じて生成するノードを変えます。 またshareable_constant_valueマジックコメントがあるときのために最後にnew_shareable_constant関数を呼び出して、必要に応じてShareableConstantNodeでラップするようにしておきます。

static rb_node_t *
new_const_op_assign(struct parser_params *p, rb_node_t *lhs, ID op, rb_node_t *rhs, struct lex_context ctxt, const YYLTYPE *loc)
{
    rb_node_t *asgn;

    if (lhs) {
        if (op == tOROP) {
            switch (nd_type(lhs)) {
              case RB_CONSTANT_READ_NODE: {
                ...
              }
              case RB_CONSTANT_PATH_NODE: {
                ...
              }
              ...
            }
        }
        else if (op == tANDOP) {
            ...
        }
        else {
            ...
        }

        asgn = new_shareable_constant(p, asgn);
    }
    else {
        asgn = NEW_ERROR(loc);
    }
    return asgn;
}

そのほかnew_op_assign関数も修正しておきます。 これはA += 1などのときに以下の生成規則のアクションが適用されるからです。

%rule op_asgn(rhs) <node>
                : var_lhs tOP_ASGN lex_ctxt rhs
                    {
                        $$ = new_op_assign(p, $var_lhs, $tOP_ASGN, $rhs, $lex_ctxt, &@$);
                    /*% ripper: opassign!($:1, $:2, $:4) %*/
                    }

shareable_constant_valueマジックコメントがある場合、new_op_assign関数にはShareableConstantNodeが渡ってくるので、ShareableConstantNodeについてもケアする必要があります。

@@ -17274,6 +17352,16 @@ new_op_assign(struct parser_params *p, rb_node_t *lhs, ID op, rb_node_t *rhs, st
                 asgn = NEW_RB_GLOBAL_VARIABLE_OR_WRITE(cast->name, rhs, loc);
                 break;
               }
+              case RB_CONSTANT_WRITE_NODE: {
+                rb_constant_write_node_t *cast = (rb_constant_write_node_t *)lhs;
+                asgn = NEW_RB_CONSTANT_OR_WRITE(cast->name, rhs, loc);
+                break;
+              }
+              case RB_SHAREABLE_CONSTANT_NODE: {
+                asgn = lhs;
+                RB_NODE_SHAREABLE_CONSTANT(lhs)->write = new_op_assign(p, RB_NODE_SHAREABLE_CONSTANT(lhs)->write, op, rhs, ctxt, loc);
+                break;
+              }
               default:

compile.cを変更する

ここで一度、書き換え前後のノードと書き換え前のcompile.cで使っている関数を整理しておきます。

Before After compile (Before)
A += 1 NODE_CDECL ConstantOperatorWriteNode iseq_compile_each0
A ||= 1 NODE_OP_ASGN_OR ConstantOrWriteNode compile_op_log
A &&= 1 NODE_OP_ASGN_AND ConstantAndWriteNode compile_op_log
::A += 1 NODE_OP_CDECL ConstantPathOperatorWriteNode compile_op_cdecl
::A ||= 1 NODE_OP_CDECL ConstantPathOrWriteNode compile_op_cdecl
::A &&= 1 NODE_OP_CDECL ConstantPathAndWriteNode compile_op_cdecl
A::B += 1 NODE_OP_CDECL ConstantPathOperatorWriteNode compile_op_cdecl
A::B ||= 1 NODE_OP_CDECL ConstantPathOrWriteNode compile_op_cdecl
A::B &&= 1 NODE_OP_CDECL ConstantPathAndWriteNode compile_op_cdecl

使用している関数に注目して、以下の3つのグループに分けて書き換えを考えることにします。

  • ConstantOperatorWriteNode
  • ConstantOrWriteNodeConstantAndWriteNode
  • その他

A += 1をコンパイルする

ノードの書き換え前は以下のようにiseq_compile_each0関数のなかに直接コンパイルするためのロジックが書かれていました。

case NODE_CDECL:{
  if (RNODE_CDECL(node)->nd_vid) {
      CHECK(compile_shareable_constant_value(iseq, ret, RNODE_CDECL(node)->shareability, node, RNODE_CDECL(node)->nd_value));

      if (!popped) {
          ADD_INSN(ret, node, dup);
      }

      ADD_INSN1(ret, node, putspecialobject,
                INT2FIX(VM_SPECIAL_OBJECT_CONST_BASE));
      ADD_INSN1(ret, node, setconstant, ID2SYM(RNODE_CDECL(node)->nd_vid));
  }
  else {
      compile_cpath(ret, iseq, RNODE_CDECL(node)->nd_else);
      CHECK(compile_shareable_constant_value(iseq, ret, RNODE_CDECL(node)->shareability, node, RNODE_CDECL(node)->nd_value));
      ADD_INSN(ret, node, swap);

      if (!popped) {
          ADD_INSN1(ret, node, topn, INT2FIX(1));
          ADD_INSN(ret, node, swap);
      }

      ADD_INSN1(ret, node, setconstant, ID2SYM(get_node_colon_nd_mid(RNODE_CDECL(node)->nd_else)));
  }
  break;
}

今回対象としているケースはA += 1であり、この場合常にnd_vidが設定されているため、実際はifの方の分岐だけを考えればよいでしょう。 shareable_constant_valueマジックコメントがあるときのことも考えて、今回はcompile_constant_operator_write関数として処理を切り出します。

まずshareable_constant_valueマジックコメントがない、もしくはnoneのときのことを考えましょう。

== disasm: #<ISeq:<main>@test.rb:1 (1,0)-(2,6)>
# A + 1の部分
#
# 0000 opt_getconstant_path                   <ic:0 A>                  (   2)[Li]
# 0002 putobject_INT2FIX_1_
# 0003 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]
# 0005 dup

# 代入の部分
#
# 0006 putspecialobject                       3
# 0008 setconstant                            :A
# 0010 leave

# shareable_constant_value: none
A += 1

他の+=と同様に右辺を計算する部分と代入をする部分からなるバイトコードが生成されます。 さてこれまではノードのレベルで右辺がA + 1の形になっていたので、そのまま右辺をコンパイルすれば期待するバイトコードが生成されていました。

# @ NODE_CDECL (id: 0, line: 2, location: (2,0)-(2,6))*
# +- nd_vid: :A
# +- nd_else: not used
# +- shareability: none
# +- nd_value:
#     @ NODE_CALL (id: 4, line: 2, location: (2,0)-(2,6))
#     +- nd_mid: :+
#     +- nd_recv:
#     |   @ NODE_CONST (id: 2, line: 2, location: (2,0)-(2,1))
#     |   +- nd_vid: :A
#     +- nd_args:
#         @ NODE_LIST (id: 3, line: 2, location: (2,5)-(2,6))
#         +- as.nd_alen: 1
#         +- nd_head:
#         |   @ NODE_INTEGER (id: 1, line: 2, location: (2,5)-(2,6))
#         |   +- val: 1
#         +- nd_next:
#             (null node)

しかしノードの書き換え後はvalue: 1になっているため、自前でA + 1を組み立てる必要があります。

# @ ConstantOperatorWriteNode (location: (2,0)-(2,6))
# +-- name: :A
# +-- name_loc: (2,0)-(2,1) = "A"
# +-- binary_operator_loc: (2,2)-(2,4) = "+="
# +-- value:
# |   @ IntegerNode (location: (2,5)-(2,6))
# |   +-- IntegerBaseFlags: decimal
# |   +-- value: 1
# +-- binary_operator: :+

これらを踏まえてcompile_constant_operator_write関数を実装すると以下のようになります。

static int
compile_constant_operator_write(rb_iseq_t *iseq, LINK_ANCHOR *const ret, enum rb_parser_shareability shareable, const NODE *const node, int popped)
{
    CHECK(compile_constant_read(iseq, ret, node, RB_NODE_CONSTANT_OPERATOR_WRITE(node)->name));
    CHECK(COMPILE(ret, "const op asgn value", RB_NODE_CONSTANT_OPERATOR_WRITE(node)->value));
    ADD_SEND_R(ret, node, RB_NODE_CONSTANT_OPERATOR_WRITE(node)->binary_operator, INT2FIX(1), NULL, INT2FIX(0), NULL);

    if (!popped) {
        ADD_INSN(ret, node, dup);
    }

    ADD_INSN1(ret, node, putspecialobject,
              INT2FIX(VM_SPECIAL_OBJECT_CONST_BASE));
    ADD_INSN1(ret, node, setconstant, ID2SYM(RB_NODE_CONSTANT_OPERATOR_WRITE(node)->name));
    return COMPILE_OK;
}

それではshareable_constant_valueが指定されているケースを考えてみましょう。 shareable_constant_valueにはliteral, experimental_copy, experimental_everythingがあるので、それぞれ生成されるバイトコードを確認します。

まずliteralの場合です。 このときA = RubyVMFrozenCore#ensure_shareable(A + 1, "A")に相当する命令が生成されます。 そのため0000 putspecialobject 1RubyVMFrozenCoreをスタックに積み、0007 putobject "A"0009 opt_send_without_block <calldata!mid:ensure_shareable, argc:2, ARGS_SIMPLE>#ensure_shareableを呼び出す命令が追加されています。

# == disasm: #<ISeq:<main>@test.rb:1 (1,0)-(2,6)>
# 0000 putspecialobject                       1                         (   2)[Li]
# 0002 opt_getconstant_path                   <ic:0 A>
# 0004 putobject_INT2FIX_1_
# 0005 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]
# 0007 putobject                              "A"
# 0009 opt_send_without_block                 <calldata!mid:ensure_shareable, argc:2, ARGS_SIMPLE>
# 0011 dup
# 0012 putspecialobject                       3
# 0014 setconstant                            :A
# 0016 leave

# shareable_constant_value: literal
A += 1

experimental_copyの場合はA = RubyVMFrozenCore#make_shareable_copy(A + 1)に相当する命令が生成されます。

# == disasm: #<ISeq:<main>@test.rb:1 (1,0)-(2,6)>
# 0000 putspecialobject                       1                         (   2)[Li]
# 0002 opt_getconstant_path                   <ic:0 A>
# 0004 putobject_INT2FIX_1_
# 0005 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]
# 0007 opt_send_without_block                 <calldata!mid:make_shareable_copy, argc:1, ARGS_SIMPLE>
# 0009 dup
# 0010 putspecialobject                       3
# 0012 setconstant                            :A
# 0014 leave

# shareable_constant_value: experimental_copy
A += 1

同様にexperimental_everythingの場合はA = RubyVMFrozenCore#make_shareable(A + 1)に相当する命令が生成されます。

# == disasm: #<ISeq:<main>@test.rb:1 (1,0)-(2,6)>
# 0000 putspecialobject                       1                         (   2)[Li]
# 0002 opt_getconstant_path                   <ic:0 A>
# 0004 putobject_INT2FIX_1_
# 0005 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]
# 0007 opt_send_without_block                 <calldata!mid:make_shareable, argc:1, ARGS_SIMPLE>
# 0009 dup
# 0010 putspecialobject                       3
# 0012 setconstant                            :A
# 0014 leave

# shareable_constant_value: experimental_everything
A += 1

これまではノードのnd_value(右辺)がA + 1の形をしていたのでcompile_shareable_constant_value関数を呼び出すだけで、マジックコメントの値に応じたバイトコードを生成することができました。 しかし今回のノードの書き換えでvalue(右辺)が1になっているため、compile_shareable_constant_value関数にメソッド呼び出しのノードを渡したときと同じ挙動を実装する必要があります。

これらを踏まえてcompile_constant_operator_write関数を修正すると以下のようになります1

 static int
 compile_constant_operator_write(rb_iseq_t *iseq, LINK_ANCHOR *const ret, enum rb_parser_shareability shareable, const NODE *const node, int popped)
 {
+    if (shareable != rb_parser_shareable_none) {
+      ADD_INSN1(ret, node, putspecialobject, INT2FIX(VM_SPECIAL_OBJECT_VMCORE));
+    }
     CHECK(compile_constant_read(iseq, ret, node, RB_NODE_CONSTANT_OPERATOR_WRITE(node)->name));
     CHECK(COMPILE(ret, "const op asgn value", RB_NODE_CONSTANT_OPERATOR_WRITE(node)->value));
     ADD_SEND_R(ret, node, RB_NODE_CONSTANT_OPERATOR_WRITE(node)->binary_operator, INT2FIX(1), NULL, INT2FIX(0), NULL);

+    switch (shareable) {
+      case rb_parser_shareable_none:
+        break;
+      case rb_parser_shareable_literal: {
+        VALUE path = const_decl_path(node);
+        ADD_INSN1(ret, node, putobject, path);
+        RB_OBJ_WRITTEN(iseq, Qundef, path);
+        ADD_SEND_WITH_FLAG(ret, node, rb_intern("ensure_shareable"), INT2FIX(2), INT2FIX(VM_CALL_ARGS_SIMPLE));
+        break;
+      }
+      case rb_parser_shareable_copy:
+        ADD_SEND_WITH_FLAG(ret, node, rb_intern("make_shareable_copy"), INT2FIX(1), INT2FIX(VM_CALL_ARGS_SIMPLE));
+        break;
+      case rb_parser_shareable_everything:
+        ADD_SEND_WITH_FLAG(ret, node, rb_intern("make_shareable"), INT2FIX(1), INT2FIX(VM_CALL_ARGS_SIMPLE));
+        break;
+    }
+
     if (!popped) {
         ADD_INSN(ret, node, dup);
     }

A ||= 1A &&= 1をコンパイルする

次にA ||= 1A &&= 1のコンパイルですが、ここでは引き続きcompile_op_log関数に処理を任せる方針で実装してみます。 定数のアクセスの仕方やマジックコメントが有効な際に右辺の値に対してRubyVMFrozenCoreのメソッドを呼ぶといった違いはあれど、生成されるバイトコードのフロー自体はa ||= 1などと同じ構造になるためです。

compile_op_log関数の変更点を1つずつみていきましょう。

まずは関数の引数です。 ShareableConstantNodeのときを考慮してenum rb_parser_shareability shareableを1つ増やします。

 static int
-compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, ID name, const NODE *const nd_value, int popped, bool op_and)
+compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, enum rb_parser_shareability shareable, const NODE *const node, ID name, const NODE *const n
d_value, int popped, bool op_and)

A ||= 1A = A || 1であることを踏まえて、右辺のAに対応するバイトコードを生成するようにします。

       case RB_GLOBAL_VARIABLE_OR_WRITE_NODE:
         ADD_INSN1(ret, node, getglobal, ID2SYM(name));
         break;
+      case RB_CONSTANT_AND_WRITE_NODE:
+      case RB_CONSTANT_OR_WRITE_NODE:
+        CHECK(compile_constant_read(iseq, ret, node, name));
+        break;
       default:
         UNKNOWN_NODE("compile_op_log", node, COMPILE_NG);
     }

右辺の演算子の後の値については必要に応じてRubyVMFrozenCoreのメソッドを呼ぶ必要があるため、ノードの種類に応じて分岐します。

-    CHECK(COMPILE(ret, "NODE_OP_ASGN_AND/OR#nd_value", nd_value));
+
+    if (nd_type_p(node, RB_CONSTANT_AND_WRITE_NODE) || nd_type_p(node, RB_CONSTANT_OR_WRITE_NODE)) {
+        CHECK(compile_shareable_constant_value(iseq, ret, shareable, node, nd_value));
+    }
+    else {
+        CHECK(COMPILE(ret, "NODE_OP_ASGN_AND/OR#nd_value", nd_value));
+    }
+

最後に定数へ代入するためのバイトコードを生成します。

       case RB_GLOBAL_VARIABLE_OR_WRITE_NODE:
         ADD_INSN1(ret, node, setglobal, ID2SYM(name));
         break;
+      case RB_CONSTANT_AND_WRITE_NODE:
+      case RB_CONSTANT_OR_WRITE_NODE:
+        ADD_INSN1(ret, node, putspecialobject,
+                  INT2FIX(VM_SPECIAL_OBJECT_CONST_BASE));
+        ADD_INSN1(ret, node, setconstant, ID2SYM(name));
+        break;
       default:
         UNKNOWN_NODE("compile_op_log", node, COMPILE_NG);

その他をコンパイルする

ConstantPathOperatorWriteNode, ConstantPathOrWriteNode, ConstantPathAndWriteNodeについてはcompile_op_cdecl関数を修正することで対応します。

複数種類のノードを扱うためにcompile_op_cdecl関数の呼び出し元で以下の3つの要素を取り出してcompile_op_cdecl関数に渡すようにします。

  • 左辺を表すnd_head
  • 演算子を表すnd_aid
  • 右辺を表すnd_value
 static int
-compile_op_cdecl(rb_iseq_t *iseq, LINK_ANCHOR *const ret, enum rb_parser_shareability shareability, const NODE *const node, ID nd_aid, int popped)
+compile_op_cdecl(rb_iseq_t *iseq, LINK_ANCHOR *const ret, enum rb_parser_shareability shareability, const NODE *const node, const rb_constant_path_node_t *const nd_head, ID nd_aid, const NODE *const nd_value, int popped)

関数の内部の修正については特筆すべきこともないので省略します。

最後にいくつか簡単な例を実行して動作を確認しておきます。

A = 1
p A
#=> 1

A += 1
p A
#=> 2

A ||= 10
p A
#=> 2

A &&= 12
p A
#=> 12

::A = 1
p ::A
#=> 1

::A += 1
p ::A
#=> 2

::A ||= 10
p ::A
#=> 2

::A &&= 12
p ::A
#=> 12

module A
end

A::B = 1
p A::B
#=> 1

A::B += 1
p A::B
#=> 2

A::B ||= 10
p A::B
#=> 2

A::B &&= 12
p A::B
#=> 12

良さそうです。

まとめ

今日の成果です。

  • constant declaration with operator (A::B ||= 1)に対応した

これでメソッド呼び出しに関する構文は全て終わったので、次回は制御構文のうち未対応のものをやっていきたいと思います。


  1. このあたりcompile_ensure_shareable_node関数とcompile_make_shareable_node関数をうまく整理すると共通化できるような気がしなくもないが...

Ruby Parser開発日誌 (24-35) - parse.yが生成するノードを変える ー assignment with operator

35日目: a += 1

前回はary[1] += fooおよびs.f += 1といった形式の代入に対応しました。 今回はa += 1の形式の代入に取り組みたいとおもいます。

前回と同様に以下の3種類の形式を分けて考える必要があります。

  1. a += foo
  2. a ||= foo
  3. a &&= foo

ここで1つ面白いのは、ノードの書き換え前はa += fooがローカル変数への代入を表すNODE_LASGNを用いて表現されていることです。

# @ NODE_LASGN (id: 0, line: 1, location: (1,0)-(1,6))*
# +- nd_vid: :a
# +- nd_value:
#     @ NODE_CALL (id: 4, line: 1, location: (1,0)-(1,6))
#     +- nd_mid: :+
#     +- nd_recv:
#     |   @ NODE_LVAR (id: 2, line: 1, location: (1,0)-(1,1))
#     |   +- nd_vid: :a
#     +- nd_args:
#         @ NODE_LIST (id: 3, line: 1, location: (1,5)-(1,6))
#         +- as.nd_alen: 1
#         +- nd_head:
#         |   @ NODE_INTEGER (id: 1, line: 1, location: (1,5)-(1,6))
#         |   +- val: 1
#         +- nd_next:
#             (null node)
a += 1

# @ NODE_LASGN (id: 5, line: 2, location: (2,0)-(2,5))*
# +- nd_vid: :b
# +- nd_value:
#     @ NODE_INTEGER (id: 6, line: 2, location: (2,4)-(2,5))
#     +- val: 2
b = 2

a += 1からa = a + 1を表すノードを生成しているので、通常のローカル変数への代入とおなじようにコンパイルすれば、期待したバイトコードが生成されるというわけです。

書き換え前後におけるノードをまとめると以下のようになります。

a += foo NODE_LASGN LocalVariableOperatorWriteNode
a ||= foo NODE_OP_ASGN_OR LocalVariableOrWriteNode
a &&= foo NODE_OP_ASGN_AND LocalVariableAndWriteNode

もう1つ注意することとして、LocalVariableOperatorWriteNodeという名前から想像できるように変数の種類に応じて、対応するOperatorWriteNode, OrWriteNode, AndWriteNodeが存在します。

  • a += foo: LocalVariableOperatorWriteNode
  • @a += foo: InstanceVariableOperatorWriteNode
  • @@a += foo: ClassVariableOperatorWriteNode
  • $a += foo: GlobalVariableOperatorWriteNode

parserを修正する

以上を意識しながらparse.yを変更していきます。 parse.yでは左辺の変数の種類によらず以下の生成規則が対応しています。 そのためnew_op_assign関数のなかで分岐して、適切なノードを生成するようにしましょう。

%rule op_asgn(rhs) <node>
                : var_lhs tOP_ASGN lex_ctxt rhs
                    {
                        $$ = new_op_assign(p, $var_lhs, $tOP_ASGN, $rhs, $lex_ctxt, &@$);
                    /*% ripper: opassign!($:1, $:2, $:4) %*/
                    }

new_op_assign関数ではまず&&=, ||=, それ以外で分岐をして、その中でさらに左辺の変数の種類に応じて分岐をすればよいでしょう。

static rb_node_t *
new_op_assign(struct parser_params *p, rb_node_t *lhs, ID op, rb_node_t *rhs, struct lex_context ctxt, const YYLTYPE *loc)
{
    rb_node_t *asgn;

    if (lhs) {
        if (op == tOROP) {
            switch (nd_type(lhs)) {
              case RB_LOCAL_VARIABLE_WRITE_NODE: {
                rb_local_variable_write_node_t *cast = (rb_local_variable_write_node_t *)lhs;
                asgn = NEW_RB_LOCAL_VARIABLE_OR_WRITE(cast->name, rhs, cast->depth, loc);
                break;
              }
              case RB_INSTANCE_VARIABLE_WRITE_NODE: {
                rb_instance_variable_write_node_t *cast = (rb_instance_variable_write_node_t *)lhs;
                asgn = NEW_RB_INSTANCE_VARIABLE_OR_WRITE(cast->name, rhs, loc);
                break;
              }
              ...
        }
        else if (op == tANDOP) {
            switch (nd_type(lhs)) {
              ...
            }
        }
        else {
            switch (nd_type(lhs)) {
              ...
            }
        }
    }
    else {
        asgn = NEW_ERROR(loc);
    }
    return asgn;
}

compilerを修正する

operatorに応じて3種類のノードがあるので、それぞれ見ていきましょう。

まずはOperatorWriteNode、つまりa += fooのときです。 このときa = a + foo相当のバイトコードが生成されます。

# `a + foo`に相当する部分
#
# 0000 getlocal_WC_0                          a@0                       (   1)[Li]
# 0002 putself
# 0003 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0005 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]

# `a = ...`に相当する部分
#
# 0007 dup
# 0008 setlocal_WC_0                          a@0
# 0010 leave
a += foo

コンパイラでいうと変数の読み書きをする部分は変数の種類によって命令が変わり、+ fooの部分は変数の種類によらず共通です。 ここでは特に関数に切り出さず愚直にそれぞれのノードをコンパイルするようにします。

      case RB_LOCAL_VARIABLE_OPERATOR_WRITE_NODE: {
        rb_local_variable_operator_write_node_t *cast = (rb_local_variable_operator_write_node_t *)node;
        CHECK(compile_lvar(iseq, ret, node, cast->name));
        CHECK(COMPILE(ret, "op asgn value", cast->value));
        ADD_SEND_R(ret, node, cast->binary_operator, INT2FIX(1), NULL, INT2FIX(0), NULL);
        if (!popped) {
            ADD_INSN(ret, node, dup);
        }
        CHECK(compile_lasgn_lhs(iseq, ret, node, cast->name));
        break;
      }
      case RB_INSTANCE_VARIABLE_OPERATOR_WRITE_NODE: {
        rb_instance_variable_operator_write_node_t *cast = (rb_instance_variable_operator_write_node_t *)node;
        ADD_INSN2(ret, node, getinstancevariable, ID2SYM(cast->name), get_ivar_ic_value(iseq, cast->name));
        CHECK(COMPILE(ret, "op asgn value", cast->value));
        ADD_SEND_R(ret, node, cast->binary_operator, INT2FIX(1), NULL, INT2FIX(0), NULL);
        if (!popped) {
            ADD_INSN(ret, node, dup);
        }
        ADD_INSN2(ret, node, setinstancevariable, ID2SYM(cast->name), get_ivar_ic_value(iseq, cast->name));
        break;
      }

つぎにAndWriteNode、つまりa &&= fooの場合です。 このときはa = a && foo相当の命令が生成されます。

# 0000 getlocal_WC_0                          a@0                       (   1)[Li]
# 0002 dup
#
# `&&`なので`a`がfalsyのときは最後の命令へjumpする
#
# 0003 branchunless                           12
# 0005 pop
# 0006 putself
# 0007 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0009 dup
# 0010 setlocal_WC_0                          a@0
# 0012 leave
a &&= foo

いままではcompile_op_log関数を呼んでいたので、引数を調整しつつ引き続きcompile_op_log関数を呼ぶようにします。

      case RB_LOCAL_VARIABLE_AND_WRITE_NODE: {
        CHECK(compile_op_log(iseq, ret, node, RB_NODE_LOCAL_VARIABLE_AND_WRITE(node)->name, RB_NODE_LOCAL_VARIABLE_AND_WRITE(node)->value, popped, TRUE));
        break;
      }
      case RB_INSTANCE_VARIABLE_AND_WRITE_NODE: {
        CHECK(compile_op_log(iseq, ret, node, RB_NODE_INSTANCE_VARIABLE_AND_WRITE(node)->name, RB_NODE_INSTANCE_VARIABLE_AND_WRITE(node)->value, popped, TRUE));
        break;
      }

compile_op_log関数はおおまかに以下の4つのステップからなります。

  1. a(変数へのアクセス)をコンパイルする
  2. &&なのでbranchunlessを追加する
  3. fooコンパイルする
  4. a =(変数への代入)をコンパイルする
static int
compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, ID name, const NODE *const nd_value, int popped, bool op_and)
{
    switch (nd_type(node)) {
      case RB_LOCAL_VARIABLE_AND_WRITE_NODE:
        CHECK(compile_lvar(iseq, ret, node, name));
        break;
      case RB_INSTANCE_VARIABLE_AND_WRITE_NODE:
        ADD_INSN2(ret, node, getinstancevariable, ID2SYM(name), get_ivar_ic_value(iseq, name));
        break;
      ...
    }

    if (!popped) {
        ADD_INSN(ret, node, dup);
    }

    if (op_and) {
        ADD_INSNL(ret, node, branchunless, lfin);
    }
    else {
        ADD_INSNL(ret, node, branchif, lfin);
    }

    if (!popped) {
        ADD_INSN(ret, node, pop);
    }

    ADD_LABEL(ret, lassign);
    CHECK(COMPILE(ret, "NODE_OP_ASGN_AND/OR#nd_value", nd_value));
    if (!popped) {
        ADD_INSN(ret, node, dup);
    }

    switch (nd_type(node)) {
      case RB_LOCAL_VARIABLE_AND_WRITE_NODE:
        CHECK(compile_lasgn_lhs(iseq, ret, node, name));
        break;
      case RB_INSTANCE_VARIABLE_AND_WRITE_NODE:
        ADD_INSN2(ret, node, setinstancevariable, ID2SYM(name), get_ivar_ic_value(iseq, name));
        break;
      ...
    }

    ADD_LABEL(ret, lfin);
    return COMPILE_OK;
}

最後にOrWriteNode、つまりa ||= fooの場合です。 このときはa = a || foo相当の命令が生成されます。

# 0000 getlocal_WC_0                          a@0                       (   1)[Li]
# 0002 dup
#
# `||`なので`a`がtruthyのときは最後の命令へjumpする
#
# 0003 branchif                               12
# 0005 pop
# 0006 putself
# 0007 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0009 dup
# 0010 setlocal_WC_0                          a@0
# 0012 leave
a ||= foo

この場合もcompile_op_log関数に処理を任せるようにします。 &&の場合も||の場合も変数を読み書きする部分は変わらないので、大枠での処理は変わりません。

static int
compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, ID name, const NODE *const nd_value, int popped, bool op_and)
{
    switch (nd_type(node)) {
      case RB_LOCAL_VARIABLE_AND_WRITE_NODE:
      case RB_LOCAL_VARIABLE_OR_WRITE_NODE:
        CHECK(compile_lvar(iseq, ret, node, name));
        break;
      case RB_INSTANCE_VARIABLE_AND_WRITE_NODE:
      case RB_INSTANCE_VARIABLE_OR_WRITE_NODE:
        ADD_INSN2(ret, node, getinstancevariable, ID2SYM(name), get_ivar_ic_value(iseq, name));
        break;
      ...
    }

defined? ?

これでおしまいと言いたいところですが、実は||=のケースはすこしだけ特殊です。 左辺の変数の種類を変えながらバイトコードを見てみると、クラス変数とグローバル変数のケースではdefinedという命令をつかって変数の有無を確認しています。

# 0000 getlocal_WC_0                          a@0                       (   1)[Li]
# 0002 branchif                               9
# 0004 putself
# 0005 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0007 setlocal_WC_0                          a@0
a ||= foo

# 0009 getinstancevariable                    :@a, <is:0>               (   2)[Li]
# 0012 branchif                               20
# 0014 putself
# 0015 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0017 setinstancevariable                    :@a, <is:1>
@a ||= foo

# 0020 putnil                                                           (   3)[Li]
# 0021 defined                                class variable, :@@a, true
# 0025 branchunless                           32
# 0027 getclassvariable                       :@@a, <is:2>
# 0030 branchif                               38
# 0032 putself
# 0033 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0035 setclassvariable                       :@@a, <is:2>
@@a ||= foo

# 0038 putnil                                                           (   4)[Li]
# 0039 defined                                global-variable, :$a, true
# 0043 branchunless                           51
# 0045 getglobal                              :$a
# 0047 dup
# 0048 branchif                               57
# 0050 pop
# 0051 putself
# 0052 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0054 dup
# 0055 setglobal                              :$a
# 0057 leave
$a ||= foo

もともとのcompile_op_log関数の実装でもORかつインスタンス変数以外のときにdefined_expr関数を呼び出しています。

static int
compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, int popped, const enum node_type type)
{
    const int line = nd_line(node);
    LABEL *lfin = NEW_LABEL(line);
    LABEL *lassign;

    if (type == NODE_OP_ASGN_OR && !nd_type_p(RNODE_OP_ASGN_OR(node)->nd_head, NODE_IVAR)) {
        LABEL *lfinish[2];
        lfinish[0] = lfin;
        lfinish[1] = 0;
        defined_expr(iseq, ret, RNODE_OP_ASGN_OR(node)->nd_head, lfinish, Qfalse, false);
        lassign = lfinish[1];
        if (!lassign) {
            lassign = NEW_LABEL(line);
        }
        ADD_INSNL(ret, node, branchunless, lassign);
    }
    else {
        lassign = NEW_LABEL(line);
    }

ちなみにローカル変数のケースではdefined命令が生成されていないように見えますが、これはバイトコードの最適化により消えているだけです。 debugモードをonにしてバイトコードを生成してみるとわかります。

# -- raw disasm--------
#   trace: 1
#   0000 putobject            true                                        (   8)
#   0002 branchunless         <L001>                                      (   8)
#   0004 getlocal             3, 0                                        (   8)
#   0007 dup                                                              (   8)
#   0008 branchif             <L000>                                      (   8)
#   0010 pop                                                              (   8)
# <L001> [sp: -1, unremovable: 0, refcnt: 1]
#   0011 putself                                                          (   8)
#   0012 send                 <calldata:foo, 0>, nil                      (   8)
#   0015 dup                                                              (   8)
#   0016 setlocal             3, 0                                        (   8)
# <L000> [sp: -1, unremovable: 0, refcnt: 1]
#   0019 leave                                                            (   8)
# ---------------------
# [compile step 3.1 (iseq_optimize)]
# -- raw disasm--------
#   trace: 1
#   0000 getlocal_WC_0        3                                           (   8)
#   0002 dup                                                              (   8)
#   0003 branchif             <L000>                                      (   8)
#   0005 pop                                                              (   8)
#   0006 putself                                                          (   8)
#   0007 opt_send_without_block <calldata:foo, 0>                         (   8)
#   0009 dup                                                              (   8)
#   0010 setlocal_WC_0        3                                           (   8)
# <L000> [sp: -1, unremovable: 0, refcnt: 1]
#   0012 leave                                                            (   8)
# ---------------------
a ||= foo

defined_exprによってputobject truebranchunlessが生成されますが、この場合は絶対にjumpしないので、2つまとめて削除されています。

defined命令の部分についていくつかの実装案が思い浮かびます。

  1. defined_expr関数でうまくClassVariableOrWriteNodeGlobalVariableOrWriteNodeをハンドリングする
  2. 呼び出し元のcompile_op_log関数でClassVariableWriteNodeGlobalVariableWriteNodeといった代入を表すノードを生成してdefined_expr関数に渡すようにする
  3. defined_expr関数を使わずにcompile_op_log関数内部で直接バイトコードを生成する

1番目の方法はdefined? @@a ||= fooのための実装とぶつかるので現実的ではありません。 2番目の方法はノードの作成にはノードのメモリ空間を管理している構造体を引き摺り回してくるなどの手間があります。

生成されるバイトコードがそこまで多くないことと、ノードの種類以外の要素で分岐しなくていいことを踏まえて3番目の実装案にします。

static int
compile_op_log(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, ID name, const NODE *const nd_value, int popped, bool op_and)
{
    const int line = nd_line(node);
    LABEL *lfin = NEW_LABEL(line);
    LABEL *lassign = NEW_LABEL(line);

    switch (nd_type(node)) {
      case RB_CLASS_VARIABLE_OR_WRITE_NODE:
        ADD_INSN(ret, node, putnil);
        ADD_INSN3(ret, node, defined, INT2FIX(DEFINED_CVAR), ID2SYM(name), Qtrue);
        ADD_INSNL(ret, node, branchunless, lassign);
        break;
      case RB_GLOBAL_VARIABLE_OR_WRITE_NODE:
        ADD_INSN(ret, node, putnil);
        ADD_INSN3(ret, node, defined, INT2FIX(DEFINED_GVAR), ID2SYM(name), Qtrue);
        ADD_INSNL(ret, node, branchunless, lassign);
        break;
    }
    ...

ここまでを踏まえて、簡単なコード例で動作確認してみます。

a = 0

a += 1
p a
#=> 1

a ||= 2
p a
#=> 1

a &&= 3
p a
#=> 3

$a = 0

$a += 1
p $a
#=> 1

$a ||= 2
p $a
#=> 1

$a &&= 3
p $a
#=> 3

良さそうですね。

まとめ

今日の成果です。

  • assignment with operator (a += foo)に対応した

Ruby Parser開発日誌 (24-34) - parse.yが生成するノードを変える ー array assignment with operatorとattr assignment with operator

34日目: ary[1] += foostruct.field += foo

前回は匿名引数とforwardingというメソッド呼び出しに関係する部分をやりました。 今回はary[1] += foostruct.field += fooという形式のメソッド呼び出しについて取り組んでいきたいと思います。

ary[1] += fooの場合

この形式のメソッド呼び出しでは最終的に以下の3つを異なるものとして扱う必要があると思います。

  1. ary[1] += foo
  2. ary[1] ||= foo
  3. ary[1] &&= foo

というのもこれらは以下のような意味になるわけですが、とくに1と2 & 3では生成されるバイトコードが異なるはずだからです。 1はメソッド呼び出しと代入ですが、2と3ではメソッド呼び出しの結果をみて分岐する必要があります。

  1. ary[1] = ary[1] + foo
  2. ary[1] = ary[1] || foo
  3. ary[1] = ary[1] && foo

そのあたりに注意しながら進めていきましょう。

parse.yの変更

ノードの書き換え前はいずれのケースもNODE_OP_ASGN1で表現していて、それぞれの違いはnd_midの部分で区別をしています。

# @ NODE_OP_ASGN1 (id: 4, line: 1, location: (1,0)-(1,13))*
# +- nd_recv:
# |   @ NODE_VCALL (id: 0, line: 1, location: (1,0)-(1,3))
# |   +- nd_mid: :ary
# +- nd_mid: :+
# +- nd_index:
# |   @ NODE_LIST (id: 2, line: 1, location: (1,4)-(1,5))
# |   +- as.nd_alen: 1
# |   +- nd_head:
# |   |   @ NODE_INTEGER (id: 1, line: 1, location: (1,4)-(1,5))
# |   |   +- val: 1
# +- nd_rvalue:
# |   @ NODE_VCALL (id: 3, line: 1, location: (1,10)-(1,13))
# |   +- nd_mid: :foo
ary[1] += foo

# @ NODE_OP_ASGN1 (id: 9, line: 2, location: (2,0)-(2,13))*
# +- nd_recv:
# |   @ NODE_VCALL (id: 5, line: 2, location: (2,0)-(2,3))
# |   +- nd_mid: :ary
# +- nd_mid: :-
# +- nd_index:
# |   @ NODE_LIST (id: 7, line: 2, location: (2,4)-(2,5))
# |   +- as.nd_alen: 1
# |   +- nd_head:
# |   |   @ NODE_INTEGER (id: 6, line: 2, location: (2,4)-(2,5))
# |   |   +- val: 1
# +- nd_rvalue:
# |   @ NODE_VCALL (id: 8, line: 2, location: (2,10)-(2,13))
# |   +- nd_mid: :foo
ary[1] -= foo

# @ NODE_OP_ASGN1 (id: 16, line: 3, location: (3,0)-(3,14))*
# +- nd_recv:
# |   @ NODE_VCALL (id: 12, line: 3, location: (3,0)-(3,3))
# |   +- nd_mid: :ary
# +- nd_mid: :||
# +- nd_index:
# |   @ NODE_LIST (id: 14, line: 3, location: (3,4)-(3,5))
# |   +- as.nd_alen: 1
# |   +- nd_head:
# |   |   @ NODE_INTEGER (id: 13, line: 3, location: (3,4)-(3,5))
# |   |   +- val: 1
# +- nd_rvalue:
# |   @ NODE_VCALL (id: 15, line: 3, location: (3,11)-(3,14))
# |   +- nd_mid: :foo
ary[1] ||= foo

# @ NODE_OP_ASGN1 (id: 22, line: 4, location: (4,0)-(4,14))*
# +- nd_recv:
# |   @ NODE_VCALL (id: 18, line: 4, location: (4,0)-(4,3))
# |   +- nd_mid: :ary
# +- nd_mid: :&&
# +- nd_index:
# |   @ NODE_LIST (id: 20, line: 4, location: (4,4)-(4,5))
# |   +- as.nd_alen: 1
# |   +- nd_head:
# |   |   @ NODE_INTEGER (id: 19, line: 4, location: (4,4)-(4,5))
# |   |   +- val: 1
# +- nd_rvalue:
# |   @ NODE_VCALL (id: 21, line: 4, location: (4,11)-(4,14))
# |   +- nd_mid: :foo
ary[1] &&= foo

ノードの書き換え後は+=-=IndexOperatorWriteNode||=IndexOrWriteNode&&=IndexAndWriteNodeで表すようになります。

parse.yではnew_ary_op_assign関数でNODE_OP_ASGN1を生成しています。

%rule op_asgn(rhs) <node> | primary_value '['[lbracket] opt_call_args rbracket tOP_ASGN lex_ctxt rhs
                              {
                                  $$ = new_ary_op_assign(p, $primary_value, $opt_call_args, $tOP_ASGN, $rhs, &@opt_call_args, &@$, &NULL_LOC, &@lbracket, &@rbracket, &@tOP_ASGN);
                              /*% ripper: opassign!(aref_field!($:1, $:3), $:5, $:7) %*/
                              }

この関数を変更して、$tOP_ASGNに応じて生成するノードを変えればよいでしょう。

static rb_node_t *
new_ary_op_assign(struct parser_params *p, rb_node_t *ary,
                  rb_arguments_node_t *args, ID op, rb_node_t *rhs, const YYLTYPE *args_loc, const YYLTYPE *loc,
                  const YYLTYPE *call_operator_loc, const YYLTYPE *opening_loc, const YYLTYPE *closing_loc, const YYLTYPE *binary_operator_loc)
{
    rb_node_t *asgn;

    aryset_check(p, args);

    switch (op) {
      case idOROP:
        asgn = NEW_RB_INDEX_OR_WRITE(ary, args, rhs, loc, call_operator_loc, opening_loc, closing_loc, binary_operator_loc);
        break;
      case idANDOP:
        asgn = NEW_RB_INDEX_AND_WRITE(ary, args, rhs, loc, call_operator_loc, opening_loc, closing_loc, binary_operator_loc);
        break;
      default :
        asgn = NEW_RB_INDEX_OPERATOR_WRITE(ary, op, args, rhs, loc, call_operator_loc, opening_loc, closing_loc, binary_operator_loc);
        break;
    }
    fixpos(asgn, ary);
    return asgn;
}

バイトコードを眺める

コンパイラに変更を加える前にバイトコードを見ておきましょう。

# == disasm: #<ISeq:<main>@test.rb:1 (1,0)-(4,13)>
# 0000 putnil                                                           (   4)[Li]
# 0001 putself
# 0002 opt_send_without_block                 <calldata!mid:ary, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0004 putobject_INT2FIX_1_
# 0005 dupn                                   2
# 0007 opt_aref                               <calldata!mid:[], argc:1, ARGS_SIMPLE>[CcCr]
# 0009 putself
# 0010 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0012 opt_plus                               <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]
# 0014 setn                                   3
# 0016 opt_aset                               <calldata!mid:[]=, argc:2, ARGS_SIMPLE>[CcCr]
# 0018 pop
# 0019 leave
ary[1] += foo

それぞれのステップにおけるスタックの状態を考えていきます。 まず初めにnilをスタックに積み、その上に左辺に当たるary1を積みます。 このnilはあとで式全体の値を保持するのに使います。

# `0004 putobject_INT2FIX_1_`まで
1
ary
nil

ary[1] += fooというのはary[1] = ary[1] + fooのことであり、この右辺にあるary[1]を評価します。 後にary[1] =のメソッド呼び出しをすることになるので、ここではdupnをつかってスタックをコピーしておきます。

# `0005 dupn 2`まで
1
ary
1
ary
nil

# `0007 opt_aref <calldata!mid:[], argc:1, ARGS_SIMPLE>[CcCr]`まで
ary[1]
1
ary
nil

ary[1] = ary[1] + foofooary[1] + fooの評価をします。 この時点で式全体の戻り値が確定するので、最初にスタックにおいたnilsetnを用いて置き換えます。

# `0010 opt_send_without_block <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>`まで
foo
ary[1]
1
ary
nil

# `0012 opt_plus <calldata!mid:+, argc:1, ARGS_SIMPLE>[CcCr]`まで
ary[1] + foo
1
ary
nil

# `0014 setn 3`まで
ary[1] + foo
1
ary
ary[1] + foo

最後にary[1] =の部分を評価し、popを用いてスタックの状態を調整して完了です。

# `0016 opt_aset <calldata!mid:[]=, argc:2, ARGS_SIMPLE>[CcCr]`まで
# ary.[]=(1, ary[1] + foo) と同等
ary.[]=(1, ary[1] + foo)
ary[1] + foo

# `0018 pop`まで
# 右辺が最終的な式の評価値になる
ary[1] + foo

ではary[1] ||= fooの場合はどうなるでしょうか。

# == disasm: #<ISeq:<main>@test.rb:1 (1,0)-(1,14)>
# 0000 putnil                                                           (   1)[Li]
# 0001 putself
# 0002 opt_send_without_block                 <calldata!mid:ary, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0004 putobject_INT2FIX_1_
# 0005 dupn                                   2
# 0007 opt_aref                               <calldata!mid:[], argc:1, ARGS_SIMPLE>[CcCr]
# 0009 dup
# 0010 branchif                               22
# 0012 pop
# 0013 putself
# 0014 opt_send_without_block                 <calldata!mid:foo, argc:0, FCALL|VCALL|ARGS_SIMPLE>
# 0016 setn                                   3
# 0018 opt_aset                               <calldata!mid:[]=, argc:2, ARGS_SIMPLE>[CcCr]
# 0020 pop
# 0021 leave
# 0022 setn                                   3
# 0024 adjuststack                            3
# 0026 leave
ary[1] ||= foo

||なのでary[1]の値がnilfalseのときだけ代入が行われます。 そのためary[1]を評価したあと0010 branchif 22でその値をチェックして分岐しています。 もしary[1]がfalsyならばそのまま0012 popに進み、先ほど同じような命令を実行して0021 leaveで抜けます。

ary[1]がtruthyの場合は0022 setn 3以降の命令が実行されます。 0010 branchif 22の時点でのスタックは以下のようになっています。

ary[1]
1
ary
nil

ary[1] = ary[1] || fooary[1]がtruthyの場合、式全体の値はary[1]になります。 そこでsetnary[1]nilの場所にコピーし、adjuststackでスタックの高さを調整します。

# `0022 setn 3`まで
ary[1]
1
ary
ary[1]

# `0024 adjuststack 3`まで
ary[1]

compile.cの変更

バイトコードはそれなりに複雑ですがcompile.cの変更は局所的です。 今までは3つのパターン全てがNODE_OP_ASGN1というノードで表現されてきました。 なのでそのノードをコンパイルするcompile_op_asgn1関数の中に&&=||=などidに基づいた分岐があります。

ノードの書き換え後は3つのパターンで異なるノードを使いわけるようになるので、compile_op_asgn1関数の外でidの解決を行うようにします。

 static int
-compile_op_asgn1(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, int popped)
+compile_op_asgn1(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const node, const NODE *const nd_recv, ID id, const rb_arguments_node_t *cons
t nd_index, const NODE *const nd_rvalue, int popped)
 {
     const int line = nd_line(node);
     VALUE argc;
     unsigned int flag = 0;
     int asgnflag = 0;
-    ID id = RNODE_OP_ASGN1(node)->nd_mid;

     /*
      * a[x] (op)= y

+      case RB_INDEX_OPERATOR_WRITE_NODE: {
+        rb_index_operator_write_node_t *cast = (rb_index_operator_write_node_t *)node;
+        CHECK(compile_op_asgn1(iseq, ret, node, cast->receiver, cast->binary_operator, cast->arguments, cast->value, popped));
+        break;
+      }
+      case RB_INDEX_OR_WRITE_NODE: {
+        rb_index_or_write_node_t *cast = (rb_index_or_write_node_t *)node;
+        CHECK(compile_op_asgn1(iseq, ret, node, cast->receiver, idOROP, cast->arguments, cast->value, popped));
+        break;
+      }
+      case RB_INDEX_AND_WRITE_NODE: {
+        rb_index_and_write_node_t *cast = (rb_index_and_write_node_t *)node;
+        CHECK(compile_op_asgn1(iseq, ret, node, cast->receiver, idANDOP, cast->arguments, cast->value, popped));
+        break;
+      }

minirubyをビルドして実行してみます。

a = [0, 1]
a[0] += 2
p a
#=> [2, 1]

b = [true, false]
b[0] ||= 0
b[1] ||= 1
p b
#=> [true, 1]

c = [true, false]
c[0] &&= 0
c[1] &&= 1
p c
#=> [0, false]

良さそうです。

struct.field += fooの場合

続いてstruct.field += fooのケースをやっていきましょう。 具体的には以下の3つのパターンがあります。

s.f += foo
s.f &&= foo
s.f ||= foo

このケースもary[1] += fooと同様に書き換え前は3つのパターン全てをNODE_OP_ASGN2という1つのノードで表現していました。 また書き換え後はary[1] += fooと同様に、CallOperatorWriteNode, CallAndWriteNode, CallOrWriteNodeの3種類のノードを使い分けることになります。

parseの変更

NODE_OP_ASGN2new_attr_op_assign関数で生成しているので、この関数を修正してidの種類に応じて生成するノードを切り替えます。

static rb_node_t *
new_attr_op_assign(struct parser_params *p, rb_node_t *lhs,
                   ID atype, ID attr, ID op, rb_node_t *rhs, const YYLTYPE *loc,
                   const YYLTYPE *call_operator_loc, const YYLTYPE *message_loc, const YYLTYPE *binary_operator_loc)
{
    rb_node_t *asgn;

    switch (op) {
      case idOROP:
        asgn = NEW_RB_CALL_OR_WRITE(lhs, attr, rhs, loc, call_operator_loc, message_loc, binary_operator_loc);
        break;
      case idANDOP:
        asgn = NEW_RB_CALL_AND_WRITE(lhs, attr, rhs, loc, call_operator_loc, message_loc, binary_operator_loc);
        break;
      default: {
        int flags = CALL_Q_P(atype) ? RB_CALL_NODE_FLAGS_SAFE_NAVIGATION : 0;
        asgn = NEW_RB_CALL_OPERATOR_WRITE(lhs, flags, attr, op, rhs, loc, call_operator_loc, message_loc, binary_operator_loc);
        break;
      }
    }

    fixpos(asgn, lhs);
    return asgn;
}

ここで注意点が2つあります。

1つはwrite_nameというフィールドについてです。 例えばs.f += fooというコードがあるとき、そのバイトコードではs.fの呼び出しとs.f=の呼び出しを行うことになります。 ノードの書き換え前はnd_vid: :fとしてfのシンボルだけをノードに持たせて、:f=コンパイル時に計算していました。 ノードの書き換え後はread_name: :fwrite_name: :f=として両方のシンボルをノードに持たせることになります。

もうひとつはs.f += foos&.f += fooをどのようにして区別するかです。 ノードの書き換え前はbool nd_aidというフィールドがノードにあり、それをみて.なのか&.なのかを判断していました。 ノードの書き換え後はCallOperatorWriteNodeのフラグにsafe_navigationが立っているかどうかで判断するようになります。

コンパイラを修正する

コンパイラの修正はarray assignment with operatorのときと同様にcompile_op_asgn2関数の呼び出し元で必要な情報を渡すように書き換えます。

@@ -11897,6 +11901,21 @@ iseq_compile_each0(rb_iseq_t *iseq, LINK_ANCHOR *const ret, const NODE *const no
         CHECK(compile_op_asgn1(iseq, ret, node, cast->receiver, idANDOP, cast->arguments, cast->value, popped));
         break;
       }
+      case RB_CALL_OPERATOR_WRITE_NODE: {
+        rb_call_operator_write_node_t *cast = (rb_call_operator_write_node_t *)node;
+        CHECK(compile_op_asgn2(iseq, ret, node, cast->receiver, cast->read_name, cast->write_name, cast->binary_operator, cast->value, popped));
+        break;
+      }
+      case RB_CALL_OR_WRITE_NODE: {
+        rb_call_or_write_node_t *cast = (rb_call_or_write_node_t *)node;
+        CHECK(compile_op_asgn2(iseq, ret, node, cast->receiver, cast->read_name, cast->write_name, idOROP, cast->value, popped));
+        break;
+      }
+      case RB_CALL_AND_WRITE_NODE: {
+        rb_call_and_write_node_t *cast = (rb_call_and_write_node_t *)node;
+        CHECK(compile_op_asgn2(iseq, ret, node, cast->receiver, cast->read_name, cast->write_name, idANDOP, cast->value, popped));
+        break;
+      }

いくつかコードを実行して確認してみましょう。

class S
  attr_accessor :f
end

s = nil
s&.f += 1
p s
#=> nil

s = S.new
p s.f
#=> nil

s.f = 0

s.f += 1
p s.f
#=> 1

s&.f += 2
p s.f
#=> 3

&&=||=も確認します。

class S
  attr_accessor :f
end

s = S.new
p s.f
#=> nil

s.f &&= 1
p s.f
#=> nil

s.f ||= 1
p s.f
#=> 1

良さそうです。

まとめ

今日の成果です。

  • array assignment with operator (ary[1] += foo)に対応した
  • attr assignment with operator (s.f += foo)に対応した



以上の内容はhttps://yui-knk.hatenablog.com/より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14