以下の内容はhttps://arata-nvm.hatenablog.com/entry/2025/12/06/000000より取得しました。


TorchDynamoの仕組みを調べる (1. 概要)

この文章はPyTorch v2.9.1に基づいています。

導入

PyTorch 2.0 では、PyTorchのコードを高速に実行するための機能としてJITコンパイラが導入されました。この機能は、主に以下のコンポーネントから構成されます。

  • TorchDynamo: PyTorchのコードを仮想的に実行し、その実行トレースからPyTorchの計算グラフであるFXグラフを構築する
  • TorchInductor: FXグラフからOpenMPTritonなど特定のバックエンド向けにコードを生成する

この記事では、特にTorchDynamoの仕組みに焦点を当てて解説していきます。

動作例

実際にTorchDynamoを使用して、どのようなPyTorchのコードからどのようなFXグラフが生成されるのかを確認しましょう。以下のコードを実行すると、関数fに対応するFXグラフが構築され、それをGraphvizで可視化したものが出力されます。

import torch
from torch.fx.passes.graph_drawer import FxGraphDrawer


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.abs(x) + 1


x = torch.randn(3, 4)
output = torch._dynamo.explain(f)(x)

for i, graph in enumerate(output.graphs):
    name = f"graph_{i}"
    FxGraphDrawer(graph, name).get_dot_graph().write_png(f"{name}.png")

実際に出力されたグラフが以下になります。

関数fから構築されたFXグラフ

関数fでは引数xテンソルに対してまず絶対値を取り、1を足すという演算を行なっています。FXグラフもその演算に対応する内容になっており、最初と最後のノードに入力と出力を表すノードがあり、その間には絶対値、加算の演算を表す2ノードがあります。

ちなみにFXグラフの文字列表現も存在していて、先ほどのFXグラフは以下のようにも表現されます。

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

         # File: (snip)/prog1.py:8 in f, code: return torch.abs(x) + 1
        abs_1: "f32[3, 4]" = torch.abs(l_x_);  l_x_ = None
        add: "f32[3, 4]" = abs_1 + 1;  abs_1 = None
        return (add,)

見てわかる通りこれはPythonのコードとなっていて、実際に実行することができます。

このように、PyTorchのコードから、コンパイラの入力により適した形式であるFXグラフを構築することがTorchDynamoの主な仕事となっています。

重要な概念

TorchDynamoの仕組みを理解するにあたっては、いくつか知っていると良い概念があります。ここではGuardとGraph Breakの2つの概念について説明します。

Guard

Guardとは、変換対象のPyTorchのコードが外部のオブジェクトに依存する場合に、そのオブジェクトがコンパイル時から変化していないことを検査するための仕組みです。具体例として、以下のように関数fグローバル変数aに依存するようなコードを考えます。

a = 1
def f(x: torch.Tensor) -> torch.Tensor:
    return x + a

このコードに対応するFXグラフを構築すると以下のようになります。

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

         # File: (snip)/prog1.py:8 in f, code: return x + a
        add: "f32[3, 4]" = l_x_ + 1;  l_x_ = None
        return (add,)

変換前はaを加算する処理になっていましたが、変換後は1を加算する処理に変化しました。これはTorchDynamoがコードを特殊化してFXグラフに変換するために発生する現象です。変換時にグローバル変数aの値が1であったため、TorchDynamoはその値をFXグラフに埋め込むような挙動をとります。

この挙動は、例えば参照していたオブジェクトの値がコンパイル後に変化するような状況で問題になります。具体的には、変換後にグローバル変数aの値を2に変更したとしても、FXグラフにはコンパイル時のaの値である1が埋め込まれているので、その変更が反映されず、計算結果が意図しないものになる可能性があります。

これを防ぐために、TorchDynamoはオブジェクトの値が変化していないことをGuardを用いて検証し、もし値の変化が検出された場合にはFXグラフへの変換を再度行うようになっています。

GuardはFXグラフの構築時に作成されており、以下のコードで作成されたGuard一覧を取得できます。

for i, guard in enumerate(output.out_guards):
    print("Guard", i, ":", guard)

例えば先ほどの関数fの変換時には、グローバル変数aに関して以下のGuardが作成されます。

Guard 5 : Name: "G['a']"
    Source: global
    Create Function: EQUALS_MATCH
    Guard Types: ['EQUALS_MATCH']
    Code List: ["G['a'] == 1"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x1045ca610; to 'type' at 0x106696018 (int)>

Code Listの項目が最も重要で、これはグローバル変数aの値が1から変化していないことを検査する条件式になっています。もしこの条件式が満たされない場合は、FXグラフが依存するオブジェクトの値が変化したとみなされ、再変換の対象となります。

Graph Break

変換対象のPyTorchのコードにFXグラフに変換できない処理が含まれていた場合、複数のFXグラフに分割して構築が行われる場合があります。例えば以下のようなコードを考えます。

@torch._dynamo.disable
def unsupported_op() -> None:
    pass


def f(x: torch.Tensor) -> torch.Tensor:
    v1 = x + 1
    unsupported_op()
    return v1 * 2

unsupported_op関数はFXグラフに変換できない処理を想定した関数です。このコードをFXグラフに変換すると、以下の2つのFXグラフが得られます。

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 4]"):
        l_x_ = L_x_

         # File: (snip)/prog1.py:16 in f, code: v1 = x + 1
        v1: "f32[3, 4]" = l_x_ + 1;  l_x_ = None
        return (v1,)

class GraphModule(torch.nn.Module):
    def forward(self, L_v1_: "f32[3, 4]"):
        l_v1_ = L_v1_

         # File: /(snip)/prog1.py:18 in torch_dynamo_resume_in_f_at_17, code: return v1 * 2
        mul: "f32[3, 4]" = l_v1_ * 2;  l_v1_ = None
        return (mul,)

1つ目のFXグラフをforward_1、2つ目のFXグラフをforward_2とおくと、変換後のコードはおおよそ以下のような形になります。

def f(x: torch.Tensor) -> torch.Tensor:
    v1 = forward_1()
    unsupported_op()
    return forward_2(v1)

FXグラフに変換できる部分のみが変換され、変換できない部分はフォールバックとしてPythonのコードのままになります。このように複数のFXグラフに変換されることをGraph Breakと呼びます。

Graph Breakが発生した場合、Pythonコードが実行される部分でCPUに処理を戻す必要があり、実行パフォーマンスの低下を招く場合があります。

変換の流れ

TorchDynamoは内部でFrame Evaluation API (PEP 523)*1を使用しており、実際に変換対象のコードを実行し、実行されたPythonバイトコードをトレースすることによってFXグラフへの変換を実現しています。

この変換の流れに関してはdepyfのドキュメントの図がわかりやすいので以下に引用します。

TorchDynamoがPyTorchのコードをFXグラフに変換するフロー*2

次回からは、この変換処理をソースコードレベルで調査していきます。




以上の内容はhttps://arata-nvm.hatenablog.com/entry/2025/12/06/000000より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14