以下の内容はhttps://arata-nvm.hatenablog.com/entry/2025/12/07/000000より取得しました。


TorchDynamoの仕組みを調べる (2. torch._dynamo.eval_frame.explain)

torch._dynamo.explainの処理を追っていく。

torch._dynamo.eval_frame.explain

torch/_dynamo/eval_frame.py:1378

def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any:
...
    def inner(*args: Any, **kwargs: Any) -> ExplainOutput:
...
        def dynamo_graph_accumulating_compiler(
            gm: torch.fx.GraphModule, example_inputs: Any
        ) -> Callable[..., Any]:
            from .backends.debugging import _explain_graph_detail

            nonlocal graphs
            nonlocal op_count
            nonlocal ops_per_graph
            nonlocal break_reasons

            gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
                gm, graphs, op_count, ops_per_graph, break_reasons
            )

            return gm.forward

        def guard_export_print(guards: Iterable[_guards.Guard]) -> None:
            nonlocal out_guards
            out_guards.extend(guards)

        opt_f = optimize(
            dynamo_graph_accumulating_compiler,
            nopython=False,
            guard_export_fn=guard_export_print,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
        opt_f(*args, **kwargs)
...
        return ExplainOutput(
            graphs,
            graph_count,
            graph_break_count,
            break_reasons,
            op_count,
            ops_per_graph,
            out_guards,
            compile_time,
        )
...

中でoptimize関数を呼んでいて、その戻り値に対して2回関数呼び出しをしている。1回目の呼び出しでは変換対象の関数を与えており、2回目の呼び出しでは変換対象の関数に与える引数を与えている。

optimize関数の引数にはdynamo_graph_accumulating_compiler関数が与えられており、これはtorch.fx.GraphModuleを受け取ってCallableなオブジェクトを返すようである。おそらくコンパイラのバックエンドにあたるものをここで渡すことが想定されているが、今回は単にGraphModuleから必要な情報を取り出してGraphModule.forwardを返している。また、guard_export_fnとしてguard_export_print関数も与えられており、これはGuardの取得のために使われているようである。

torch._dynamo.eval_frame.optimize

torch/_dynamo/eval_frame.py:1251

def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]:
    def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]:
...
        return optimize(*args, **kwargs)

    return _optimize(rebuild_ctx, *args, **kwargs)

optimize関数を呼び出す関数rebuild_ctx関数を引数に与えて、_optimize関数を呼び出している。

torch._dynamo.eval_frame._optimize

torch/_dynamo/eval_frame.py:1266

def _optimize(
    rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
    backend: Union[str, Callable[..., Any]] = "inductor",
    *,
    nopython: bool = False,
    error_on_graph_break: Optional[bool] = None,
    guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None,
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None,
    disable: bool = False,
    dynamic: Optional[bool] = None,
    package: Optional[CompilePackage] = None,
) -> Union[OptimizeContext, _NullDecorator]:
...
    check_if_dynamo_supported()
    check_for_incompatible_configs()
...
    if (
        disable
        or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
        or (not justknobs_check("pytorch/compiler:enable_dynamo"))
    ):
        return _NullDecorator()
...
    return _optimize_catch_errors(
        convert_frame.convert_frame(
            backend,
            hooks,
            package=package,
        ),
        hooks,
        backend_ctx_ctor,
        fullgraph=False,
        error_on_graph_break=error_on_graph_break
        and not config.debug_force_graph_break_on_leaf_return,
        dynamic=dynamic,
        compiler_config=(
            backend.get_compiler_config()
            if hasattr(backend, "get_compiler_config")
            else None
        ),
        rebuild_ctx=rebuild_ctx,
        package=package,
    )

まず最初にTorchDynamoが有効な環境であるかを検証し、有効でない場合は_NullDecoratorインスタンスを返している。その後、convert_frame.convert_frame関数を呼び出し、その戻り値を引数に_optimize_catch_errors関数を呼び出している。

torch._dynamo.eval_frame._optimize_catch_errors

torch/_dynamo/eval_frame.py:1071

def _optimize_catch_errors(
    compile_fn: convert_frame.ConvertFrameProtocol,
    hooks: Hooks,
    backend_ctx_ctor: Callable[
        [], contextlib.AbstractContextManager[Any]
    ] = null_context,
    fullgraph: bool = False,
    error_on_graph_break: Optional[bool] = None,
    export: bool = False,
    dynamic: Optional[bool] = None,
    compiler_config: Optional[Any] = None,
    rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None,
    package: Optional[CompilePackage] = None,
) -> OptimizeContext:
    return OptimizeContext(
        convert_frame.catch_errors_wrapper(compile_fn, hooks),
        backend_ctx_ctor=backend_ctx_ctor,
        first_ctx=True,
        fullgraph=fullgraph,
        error_on_graph_break=error_on_graph_break,
        export=export,
        dynamic=dynamic,
        compiler_config=compiler_config,
        rebuild_ctx=rebuild_ctx,
        package=package,
        hooks=hooks,
    )

convert_frame.catch_errors_wrapper関数を呼び出し、その戻り値を引数にOptimizeContextインスタンスを作成して返している。

torch._dynamo.eval_frame.OptimizeContext.__init__

torch/_dynamo/eval_frame.py:924

class OptimizeContext(_TorchDynamoContext):
    def __init__(
        self,
        callback: DynamoCallback,
        backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]],
        first_ctx: bool = False,
        *,
        fullgraph: bool = False,
        error_on_graph_break: Optional[bool] = None,
        export: bool = False,
        dynamic: Optional[bool] = None,
        compiler_config: Optional[Any] = None,
        rebuild_ctx: Optional[
            Callable[[], Union[OptimizeContext, _NullDecorator]]
        ] = None,
        package: Optional[CompilePackage] = None,
        hooks: Optional[Hooks] = None,
    ) -> None:
...
        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
            fullgraph=fullgraph,
            error_on_graph_break=error_on_graph_break,
            export=export,
            dynamic=dynamic,
            compiler_config=compiler_config,
            package=package,
            hooks=hooks,
        )
...

OptimizeContextの親クラス_TorchDynamoContext__init__関数を呼び出している。ここでcallback引数が指定されていることを覚えておく。

torch._dynamo.eval_frame._TorchDynamoContext.__init__

torch/_dynamo/eval_frame.py:590

class _TorchDynamoContext:
    def __init__(
        self,
        callback: DynamoCallback,
        on_enter: Callable[[], Any] = nothing,
        backend_ctx_ctor: Callable[
            [], contextlib.AbstractContextManager[Any]
        ] = null_context,
        patch_fn: Callable[[], Any] = nothing,
        first_ctx: bool = False,
        *,
        fullgraph: bool = False,
        error_on_graph_break: Optional[bool] = None,
        export: bool = False,
        dynamic: Optional[bool] = None,
        compiler_config: Optional[Any] = None,
        package: Optional[CompilePackage] = None,
        hooks: Optional[Hooks] = None,
    ) -> None:
        super().__init__()
        assert callable(callback) or callback is False or callback is None
        self.callback: DynamoCallback = callback
...

callbackの引数に与えられた値がself.callbackに保持される。

ここまででtorch._dynamo.eval_frame.optimizeの戻り値がOptimizeContextインスタンスになることがわかった。次に、OptimizeContext.__call__を調べる。

torch._dynamo.eval_frame.OptimizeContext.__call__

OptimizeContext__call__関数は実装されていないため、実際にはその親クラスの_TorchDynamoContext.__call__が呼ばれる。

torch/_dynamo/eval_frame.py:680

    def __call__(self, fn: Any) -> Any:
...
        fn = innermost_fn(fn)
...
        callback: Callable[..., Any] = do_nothing
        if hasattr(self, "callback"):
            callback = self.callback  # type: ignore[assignment]
...
        @functools.wraps(fn)
        def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
            prior = set_eval_frame(None)
            try:
...
                _maybe_set_eval_frame(_callback_from_stance(callback))

                try:
                    return fn(*args, **kwargs)
                except Unsupported as e:
                    if config.verbose:
                        raise
                    # strip internal tracebacks from causes
                    cur_exn: BaseException = e
                    while cur_exn.__cause__ is not None:
                        cur_exn.__cause__.with_traceback(None)
                        cur_exn = cur_exn.__cause__
                    raise e.with_traceback(None) from e.__cause__  # User compiler error
                except ShortenTraceback as e:
                    # Failures in the backend likely don't have useful
                    # data in the TorchDynamo frames, so we strip them out.
                    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
                finally:
                    # Restore the dynamic layer stack depth if necessary.
                    set_eval_frame(None)
...
            finally:
                _maybe_set_eval_frame(prior)
...
        if callback not in (None, False):
            always_optimize_code_objects[fn.__code__] = True

        return compile_wrapper

まずfnに変換対象の関数を、callbackに先ほどのself.callbackの値が代入される。その後compiler_wrapper関数を構築して戻り値としている。つまり、torch._dynamo.eval_frame.explain内での2回目の関数呼び出しではこのcompiler_wrapper関数が実行される。

今回のまとめ

sequenceDiagram
    torch._dynamo.eval_frame.explain ->>+ torch._dynamo.eval_frame.optimize: compile_fn
    torch._dynamo.eval_frame.optimize ->>+ torch._dynamo.eval_frame._optimize: compile_fn
    torch._dynamo.eval_frame._optimize ->>+ torch._dynamo.eval_frame._optimize_catch_errors: convert_frame.convert_frame(compile_fn)
    torch._dynamo.eval_frame._optimize_catch_errors ->>+ torch._dynamo.eval_frame.OptimizeContext.__init__: convert_frame.catch_errors_wrapper(compile_fn)
    torch._dynamo.eval_frame.OptimizeContext.__init__ ->>- torch._dynamo.eval_frame._optimize_catch_errors: OptimizeContext
    torch._dynamo.eval_frame._optimize_catch_errors ->>- torch._dynamo.eval_frame._optimize: OptimizeContext
    torch._dynamo.eval_frame._optimize ->>- torch._dynamo.eval_frame.optimize: OptimizeContext
    torch._dynamo.eval_frame.optimize ->>- torch._dynamo.eval_frame.explain: OptimizeContext
    torch._dynamo.eval_frame.explain ->>+ torch._dynamo.eval_frame.OptimizeContext.__call__: fn
    torch._dynamo.eval_frame.OptimizeContext.__call__ ->>- torch._dynamo.eval_frame.explain: compiler_wrapper
    torch._dynamo.eval_frame.explain ->>+ compiler_wrapper: args, kwargs
    compiler_wrapper ->>- torch._dynamo.eval_frame.explain: a



以上の内容はhttps://arata-nvm.hatenablog.com/entry/2025/12/07/000000より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14