diff --git a/en-wordlist.txt b/en-wordlist.txt index ea2ed6f77d..9f318cb99f 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -691,3 +691,7 @@ XPU XPUs impl overrideable +TorchServe +Inductor’s +onwards +recompilations diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index 632efebb5c..f136c4b9c6 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -348,11 +348,20 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu .. customcarditem:: :header: Compile Time Caching in ``torch.compile`` - :card_description: Learn how to configure compile time caching in ``torch.compile`` + :card_description: Learn how to use compile time caching in ``torch.compile`` :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png :link: ../recipes/torch_compile_caching_tutorial.html :tags: Model-Optimization +.. Compile Time Caching Configurations + +.. customcarditem:: + :header: Compile Time Caching Configurations + :card_description: Learn how to configure compile time caching in ``torch.compile`` + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../recipes/torch_compile_caching_configuration_tutorial.html + :tags: Model-Optimization + .. Reducing Cold Start Compilation Time with Regional Compilation .. customcarditem:: diff --git a/recipes_source/torch_compile_caching_configuration_tutorial.rst b/recipes_source/torch_compile_caching_configuration_tutorial.rst new file mode 100644 index 0000000000..21565d0562 --- /dev/null +++ b/recipes_source/torch_compile_caching_configuration_tutorial.rst @@ -0,0 +1,78 @@ +Compile Time Caching Configuration +========================================================= +**Authors:** `Oguz Ulgen `_ and `Sam Larsen `_ + +Introduction +------------------ + +PyTorch Compiler implements several caches to reduce compilation latency. +This recipe demonstrates how you can configure various parts of the caching in ``torch.compile``. + +Prerequisites +------------------- + +Before starting this recipe, make sure that you have the following: + +* Basic understanding of ``torch.compile``. See: + + * `torch.compiler API documentation `__ + * `Introduction to torch.compile `__ + * `Compile Time Caching in torch.compile `__ + +* PyTorch 2.4 or later + +Inductor Cache Settings +---------------------------- + +Most of these caches are in-memory, only used within the same process, and are transparent to the user. An exception is caches that store compiled FX graphs (``FXGraphCache``, ``AOTAutogradCache``). These caches allow Inductor to avoid recompilation across process boundaries when it encounters the same graph with the same Tensor input shapes (and the same configuration). The default implementation stores compiled artifacts in the system temp directory. An optional feature also supports sharing those artifacts within a cluster by storing them in a Redis database. + +There are a few settings relevant to caching and to FX graph caching in particular. +The settings are accessible via environment variables listed below or can be hard-coded in the Inductor’s config file. + +TORCHINDUCTOR_FX_GRAPH_CACHE +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This setting enables the local FX graph cache feature, which stores artifacts in the host’s temp directory. Setting it to ``1`` enables the feature while any other value disables it. By default, the disk location is per username, but users can enable sharing across usernames by specifying ``TORCHINDUCTOR_CACHE_DIR`` (below). + +TORCHINDUCTOR_AUTOGRAD_CACHE +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This setting extends ``FXGraphCache`` to store cached results at the ``AOTAutograd`` level, rather than at the Inductor level. Setting it to ``1`` enables this feature, while any other value disables it. +By default, the disk location is per username, but users can enable sharing across usernames by specifying ``TORCHINDUCTOR_CACHE_DIR`` (below). +``TORCHINDUCTOR_AUTOGRAD_CACHE`` requires ``TORCHINDUCTOR_FX_GRAPH_CACHE`` to work. The same cache dir stores cache entries for ``AOTAutogradCache`` (under ``{TORCHINDUCTOR_CACHE_DIR}/aotautograd``) and ``FXGraphCache`` (under ``{TORCHINDUCTOR_CACHE_DIR}/fxgraph``). + +TORCHINDUCTOR_CACHE_DIR +~~~~~~~~~~~~~~~~~~~~~~~~ +This setting specifies the location of all on-disk caches. By default, the location is in the system temp directory under ``torchinductor_``, for example, ``/tmp/torchinductor_myusername``. + +Note that if ``TRITON_CACHE_DIR`` is not set in the environment, Inductor sets the ``Triton`` cache directory to this same temp location, under the Triton sub-directory. + +TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This setting enables the remote FX graph cache feature. The current implementation uses ``Redis``. ``1`` enables caching, and any other value disables it. The following environment variables configure the host and port of the Redis server: + +``TORCHINDUCTOR_REDIS_HOST`` (defaults to ``localhost``) +``TORCHINDUCTOR_REDIS_PORT`` (defaults to ``6379``) + +.. note:: + + Note that if Inductor locates a remote cache entry, it stores the compiled artifact in the local on-disk cache; that local artifact would be served on subsequent runs on the same machine. + +TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Similar to ``TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE``, this setting enables the remote ``AOTAutogradCache`` feature. The current implementation uses Redis. Setting it to ``1`` enables caching, while any other value disables it. The following environment variables are used to configure the host and port of the ``Redis`` server: +* ``TORCHINDUCTOR_REDIS_HOST`` (defaults to ``localhost``) +* ``TORCHINDUCTOR_REDIS_PORT`` (defaults to ``6379``) + +`TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE`` requires ``TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE`` to be enabled in order to function. The same Redis server can be used to store both AOTAutograd and FXGraph cache results. + +TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This setting enables a remote cache for ``TorchInductor``’s autotuner. Similar to remote FX graph cache, the current implementation uses Redis. Setting it to ``1`` enables caching, while any other value disables it. The same host / port environment variables mentioned above apply to this cache. + +TORCHINDUCTOR_FORCE_DISABLE_CACHES +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Set this value to ``1`` to disable all Inductor caching. This setting is useful for tasks like experimenting with cold-start compile times or forcing recompilation for debugging purposes. + +Conclusion +------------- +In this recipe, we have learned how to configure PyTorch Compiler's caching mechanisms. Additionally, we explored the various settings and environment variables that allow users to configure and optimize these caching features according to their specific needs. + diff --git a/recipes_source/torch_compile_caching_tutorial.rst b/recipes_source/torch_compile_caching_tutorial.rst index 3c024828f9..ebc831cdb9 100644 --- a/recipes_source/torch_compile_caching_tutorial.rst +++ b/recipes_source/torch_compile_caching_tutorial.rst @@ -1,12 +1,16 @@ Compile Time Caching in ``torch.compile`` ========================================================= -**Authors:** `Oguz Ulgen `_ and `Sam Larsen `_ +**Author:** `Oguz Ulgen `_ Introduction ------------------ -PyTorch Inductor implements several caches to reduce compilation latency. -This recipe demonstrates how you can configure various parts of the caching in ``torch.compile``. +PyTorch Compiler provides several caching offerings to reduce compilation latency. +This recipe will explain these offerings in detail to help users pick the best option for their use case. + +Check out `Compile Time Caching Configurations `__ for how to configure these caches. + +Also check out our caching benchmark at `PT CacheBench Benchmarks `__. Prerequisites ------------------- @@ -17,60 +21,83 @@ Before starting this recipe, make sure that you have the following: * `torch.compiler API documentation `__ * `Introduction to torch.compile `__ + * `Triton language documentation `__ * PyTorch 2.4 or later -Inductor Cache Settings ----------------------------- +Caching Offerings +--------------------- + +``torch.compile`` provides the following caching offerings: + +* End to end caching (also known as ``Mega-Cache``) +* Modular caching of ``TorchDynamo``, ``TorchInductor``, and ``Triton`` + +It is important to note that caching validates that the cache artifacts are used with the same PyTorch and Triton version, as well as, same GPU when device is set to be cuda. + +``torch.compile`` end-to-end caching (``Mega-Cache``) +------------------------------------------------------------ + +End to end caching, from here onwards referred to ``Mega-Cache``, is the ideal solution for users looking for a portable caching solution that can be stored in a database and can later be fetched possibly on a separate machine. + +``Mega-Cache`` provides two compiler APIs: + +* ``torch.compiler.save_cache_artifacts()`` +* ``torch.compiler.load_cache_artifacts()`` + +The intended use case is after compiling and executing a model, the user calls ``torch.compiler.save_cache_artifacts()`` which will return the compiler artifacts in a portable form. Later, potentially on a different machine, the user may call ``torch.compiler.load_cache_artifacts()`` with these artifacts to pre-populate the ``torch.compile`` caches in order to jump-start their cache. + +Consider the following example. First, compile and save the cache artifacts. + +.. code-block:: python + + @torch.compile + def fn(x, y): + return x.sin() @ y + + a = torch.rand(100, 100, dtype=dtype, device=device) + b = torch.rand(100, 100, dtype=dtype, device=device) + + result = fn(a, b) + + artifacts = torch.compiler.save_cache_artifacts() + + # Now, potentially store these artifacts in a database + +Later, you can jump-start the cache by the following: -Most of these caches are in-memory, only used within the same process, and are transparent to the user. An exception is caches that store compiled FX graphs (FXGraphCache, AOTAutogradCache). These caches allow Inductor to avoid recompilation across process boundaries when it encounters the same graph with the same Tensor input shapes (and the same configuration). The default implementation stores compiled artifacts in the system temp directory. An optional feature also supports sharing those artifacts within a cluster by storing them in a Redis database. +.. code-block:: python -There are a few settings relevant to caching and to FX graph caching in particular. -The settings are accessible via environment variables listed below or can be hard-coded in Inductor’s config file. + # Potentially download/fetch the artifacts from the database + assert artifacts is not None + artifact_bytes, cache_info = artifacts -TORCHINDUCTOR_FX_GRAPH_CACHE -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This setting enables the local FX graph cache feature, i.e., by storing artifacts in the host’s temp directory. ``1`` enables, and any other value disables it. By default, the disk location is per username, but users can enable sharing across usernames by specifying ``TORCHINDUCTOR_CACHE_DIR`` (below). + torch.compiler.load_cache_artifacts(artifact_bytes) -TORCHINDUCTOR_AUTOGRAD_CACHE -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This setting extends FXGraphCache to store cached results at the AOTAutograd level, instead of at the Inductor level. ``1`` enables, and any other value disables it. -By default, the disk location is per username, but users can enable sharing across usernames by specifying ``TORCHINDUCTOR_CACHE_DIR`` (below). -`TORCHINDUCTOR_AUTOGRAD_CACHE` requires `TORCHINDUCTOR_FX_GRAPH_CACHE` to work. The same cache dir stores cache entries for AOTAutogradCache (under `{TORCHINDUCTOR_CACHE_DIR}/aotautograd`) and FXGraphCache (under `{TORCHINDUCTOR_CACHE_DIR}/fxgraph`). +This operation populates all the modular caches that will be discussed in the next section, including ``PGO``, ``AOTAutograd``, ``Inductor``, ``Triton``, and ``Autotuning``. -TORCHINDUCTOR_CACHE_DIR -~~~~~~~~~~~~~~~~~~~~~~~~ -This setting specifies the location of all on-disk caches. By default, the location is in the system temp directory under ``torchinductor_``, for example, ``/tmp/torchinductor_myusername``. -Note that if ``TRITON_CACHE_DIR`` is not set in the environment, Inductor sets the Triton cache directory to this same temp location, under the Triton subdirectory. +Modular caching of ``TorchDynamo``, ``TorchInductor``, and ``Triton`` +----------------------------------------------------------- -TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This setting enables the remote FX graph cache feature. The current implementation uses Redis. ``1`` enables caching, and any other value disables it. The following environment variables configure the host and port of the Redis server: +The aforementioned ``Mega-Cache`` is composed of individual components that can be used without any user intervention. By default, PyTorch Compiler comes with local on-disk caches for ``TorchDynamo``, ``TorchInductor``, and ``Triton``. These caches include: -``TORCHINDUCTOR_REDIS_HOST`` (defaults to ``localhost``) -``TORCHINDUCTOR_REDIS_PORT`` (defaults to ``6379``) +* ``FXGraphCache``: A cache of graph-based IR components used in compilation. +* ``TritonCache``: A cache of Triton-compilation results, including ``cubin`` files generated by ``Triton`` and other caching artifacts. +* ``InductorCache``: A bundle of ``FXGraphCache`` and ``Triton`` cache. +* ``AOTAutogradCache``: A cache of joint graph artifacts. +* ``PGO-cache``: A cache of dynamic shape decisions to reduce number of recompilations. -Note that if Inductor locates a remote cache entry, it stores the compiled artifact in the local on-disk cache; that local artifact would be served on subsequent runs on the same machine. +All these cache artifacts are written to ``TORCHINDUCTOR_CACHE_DIR`` which by default will look like ``/tmp/torchinductor_myusername``. -TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Like TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE, this setting enables the remote AOT AutogradCache feature. The current implementation uses Redis. ``1`` enables caching, and any other value disables it. The following environment variables configure the host and port of the Redis server: -``TORCHINDUCTOR_REDIS_HOST`` (defaults to ``localhost``) -``TORCHINDUCTOR_REDIS_PORT`` (defaults to ``6379``) -`TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE`` depends on `TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE` to be enabled to work. The same Redis server can store both AOTAutograd and FXGraph cache results. +Remote Caching +---------------- -TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This setting enables a remote cache for Inductor’s autotuner. As with the remote FX graph cache, the current implementation uses Redis. ``1`` enables caching, and any other value disables it. The same host / port environment variables listed above apply to this cache. +We also provide a remote caching option for users who would like to take advantage of a Redis based cache. Check out `Compile Time Caching Configurations `__ to learn more about how to enable the Redis-based caching. -TORCHINDUCTOR_FORCE_DISABLE_CACHES -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Set this value to ``1`` to disable all Inductor caching. This setting is useful for tasks like experimenting with cold-start compile times or forcing recompilation for debugging purposes. Conclusion ------------- In this recipe, we have learned that PyTorch Inductor's caching mechanisms significantly reduce compilation latency by utilizing both local and remote caches, which operate seamlessly in the background without requiring user intervention. -Additionally, we explored the various settings and environment variables that allow users to configure and optimize these caching features according to their specific needs.