/** \brief Fuses linear patterns into a single aten::linear to simplify
 * pattern matching in later passes.
 */
4#pragma once
5
#include <memory>

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
7
8namespace torch {
9namespace jit {
10
11/** \brief Match the at::linear pattern and fuse it into a single at::linear
12 * This pass fuse the addmm or matmul + add generated by JIT back to linear
13 * This pass can be deleted once the JIT can emit the aten::linear in the future
14 */
15TORCH_API void FuseLinear(std::shared_ptr<Graph>& graph);
16
17/** Swap functional linear CallFunctions to aten::linear
18 */
19TORCH_API void SwapFunctionalLinear(std::shared_ptr<Graph>& graph);
20/** Swap all functional linear CallFunctions in module
21 */
22TORCH_API void SwapFunctionalLinear(Module& module);
23} // namespace jit
24} // namespace torch
25