torch._dynamo.explainの処理を追っていく。
torch._dynamo.eval_frame.explain
torch/_dynamo/eval_frame.py:1378
def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any: ... def inner(*args: Any, **kwargs: Any) -> ExplainOutput: ... def dynamo_graph_accumulating_compiler( gm: torch.fx.GraphModule, example_inputs: Any ) -> Callable[..., Any]: from .backends.debugging import _explain_graph_detail nonlocal graphs nonlocal op_count nonlocal ops_per_graph nonlocal break_reasons gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail( gm, graphs, op_count, ops_per_graph, break_reasons ) return gm.forward def guard_export_print(guards: Iterable[_guards.Guard]) -> None: nonlocal out_guards out_guards.extend(guards) opt_f = optimize( dynamo_graph_accumulating_compiler, nopython=False, guard_export_fn=guard_export_print, )(f) # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. opt_f(*args, **kwargs) ... return ExplainOutput( graphs, graph_count, graph_break_count, break_reasons, op_count, ops_per_graph, out_guards, compile_time, ) ...
中でoptimize関数を呼んでいて、その戻り値に対して2回関数呼び出しをしている。1回目の呼び出しでは変換対象の関数を与えており、2回目の呼び出しでは変換対象の関数に与える引数を与えている。
optimize関数の引数にはdynamo_graph_accumulating_compiler関数が与えられており、これはtorch.fx.GraphModuleを受け取ってCallableなオブジェクトを返すようである。おそらくコンパイラのバックエンドにあたるものをここで渡すことが想定されているが、今回は単にGraphModuleから必要な情報を取り出してGraphModule.forwardを返している。また、guard_export_fnとしてguard_export_print関数も与えられており、これはGuardの取得のために使われているようである。
torch._dynamo.eval_frame.optimize
torch/_dynamo/eval_frame.py:1251
def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]: def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]: ... return optimize(*args, **kwargs) return _optimize(rebuild_ctx, *args, **kwargs)
optimize関数を呼び出す関数rebuild_ctx関数を引数に与えて、_optimize関数を呼び出している。
torch._dynamo.eval_frame._optimize
torch/_dynamo/eval_frame.py:1266
def _optimize( rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], backend: Union[str, Callable[..., Any]] = "inductor", *, nopython: bool = False, error_on_graph_break: Optional[bool] = None, guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None, guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None, disable: bool = False, dynamic: Optional[bool] = None, package: Optional[CompilePackage] = None, ) -> Union[OptimizeContext, _NullDecorator]: ... check_if_dynamo_supported() check_for_incompatible_configs() ... if ( disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1" or (not justknobs_check("pytorch/compiler:enable_dynamo")) ): return _NullDecorator() ... return _optimize_catch_errors( convert_frame.convert_frame( backend, hooks, package=package, ), hooks, backend_ctx_ctor, fullgraph=False, error_on_graph_break=error_on_graph_break and not config.debug_force_graph_break_on_leaf_return, dynamic=dynamic, compiler_config=( backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None ), rebuild_ctx=rebuild_ctx, package=package, )
まず最初にTorchDynamoが有効な環境であるかを検証し、有効でない場合は_NullDecoratorのインスタンスを返している。その後、convert_frame.convert_frame関数を呼び出し、その戻り値を引数に_optimize_catch_errors関数を呼び出している。
torch._dynamo.eval_frame._optimize_catch_errors
torch/_dynamo/eval_frame.py:1071
def _optimize_catch_errors( compile_fn: convert_frame.ConvertFrameProtocol, hooks: Hooks, backend_ctx_ctor: Callable[ [], contextlib.AbstractContextManager[Any] ] = null_context, fullgraph: bool = False, error_on_graph_break: Optional[bool] = None, export: bool = False, dynamic: Optional[bool] = None, compiler_config: Optional[Any] = None, rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None, package: Optional[CompilePackage] = None, ) -> OptimizeContext: return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, first_ctx=True, fullgraph=fullgraph, error_on_graph_break=error_on_graph_break, export=export, dynamic=dynamic, compiler_config=compiler_config, rebuild_ctx=rebuild_ctx, package=package, hooks=hooks, )
convert_frame.catch_errors_wrapper関数を呼び出し、その戻り値を引数にOptimizeContextのインスタンスを作成して返している。
torch._dynamo.eval_frame.OptimizeContext.__init__
torch/_dynamo/eval_frame.py:924
class OptimizeContext(_TorchDynamoContext): def __init__( self, callback: DynamoCallback, backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]], first_ctx: bool = False, *, fullgraph: bool = False, error_on_graph_break: Optional[bool] = None, export: bool = False, dynamic: Optional[bool] = None, compiler_config: Optional[Any] = None, rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, package: Optional[CompilePackage] = None, hooks: Optional[Hooks] = None, ) -> None: ... super().__init__( callback=callback, on_enter=on_enter, backend_ctx_ctor=backend_ctx_ctor, patch_fn=TorchPatcher.patch, first_ctx=first_ctx, fullgraph=fullgraph, error_on_graph_break=error_on_graph_break, export=export, dynamic=dynamic, compiler_config=compiler_config, package=package, hooks=hooks, ) ...
OptimizeContextの親クラス_TorchDynamoContextの__init__関数を呼び出している。ここでcallback引数が指定されていることを覚えておく。
torch._dynamo.eval_frame._TorchDynamoContext.__init__
torch/_dynamo/eval_frame.py:590
class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, on_enter: Callable[[], Any] = nothing, backend_ctx_ctor: Callable[ [], contextlib.AbstractContextManager[Any] ] = null_context, patch_fn: Callable[[], Any] = nothing, first_ctx: bool = False, *, fullgraph: bool = False, error_on_graph_break: Optional[bool] = None, export: bool = False, dynamic: Optional[bool] = None, compiler_config: Optional[Any] = None, package: Optional[CompilePackage] = None, hooks: Optional[Hooks] = None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback ...
callbackの引数に与えられた値がself.callbackに保持される。
ここまででtorch._dynamo.eval_frame.optimizeの戻り値がOptimizeContextのインスタンスになることがわかった。次に、OptimizeContext.__call__を調べる。
torch._dynamo.eval_frame.OptimizeContext.__call__
OptimizeContextに__call__関数は実装されていないため、実際にはその親クラスの_TorchDynamoContext.__call__が呼ばれる。
torch/_dynamo/eval_frame.py:680
def __call__(self, fn: Any) -> Any: ... fn = innermost_fn(fn) ... callback: Callable[..., Any] = do_nothing if hasattr(self, "callback"): callback = self.callback # type: ignore[assignment] ... @functools.wraps(fn) def compile_wrapper(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: ... _maybe_set_eval_frame(_callback_from_stance(callback)) try: return fn(*args, **kwargs) except Unsupported as e: if config.verbose: raise # strip internal tracebacks from causes cur_exn: BaseException = e while cur_exn.__cause__ is not None: cur_exn.__cause__.with_traceback(None) cur_exn = cur_exn.__cause__ raise e.with_traceback(None) from e.__cause__ # User compiler error except ShortenTraceback as e: # Failures in the backend likely don't have useful # data in the TorchDynamo frames, so we strip them out. raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 finally: # Restore the dynamic layer stack depth if necessary. set_eval_frame(None) ... finally: _maybe_set_eval_frame(prior) ... if callback not in (None, False): always_optimize_code_objects[fn.__code__] = True return compile_wrapper
まずfnに変換対象の関数を、callbackに先ほどのself.callbackの値が代入される。その後compiler_wrapper関数を構築して戻り値としている。つまり、torch._dynamo.eval_frame.explain内での2回目の関数呼び出しではこのcompiler_wrapper関数が実行される。
今回のまとめ
sequenceDiagram
torch._dynamo.eval_frame.explain ->>+ torch._dynamo.eval_frame.optimize: compile_fn
torch._dynamo.eval_frame.optimize ->>+ torch._dynamo.eval_frame._optimize: compile_fn
torch._dynamo.eval_frame._optimize ->>+ torch._dynamo.eval_frame._optimize_catch_errors: convert_frame.convert_frame(compile_fn)
torch._dynamo.eval_frame._optimize_catch_errors ->>+ torch._dynamo.eval_frame.OptimizeContext.__init__: convert_frame.catch_errors_wrapper(compile_fn)
torch._dynamo.eval_frame.OptimizeContext.__init__ ->>- torch._dynamo.eval_frame._optimize_catch_errors: OptimizeContext
torch._dynamo.eval_frame._optimize_catch_errors ->>- torch._dynamo.eval_frame._optimize: OptimizeContext
torch._dynamo.eval_frame._optimize ->>- torch._dynamo.eval_frame.optimize: OptimizeContext
torch._dynamo.eval_frame.optimize ->>- torch._dynamo.eval_frame.explain: OptimizeContext
torch._dynamo.eval_frame.explain ->>+ torch._dynamo.eval_frame.OptimizeContext.__call__: fn
torch._dynamo.eval_frame.OptimizeContext.__call__ ->>- torch._dynamo.eval_frame.explain: compiler_wrapper
torch._dynamo.eval_frame.explain ->>+ compiler_wrapper: args, kwargs
compiler_wrapper ->>- torch._dynamo.eval_frame.explain: a