1#pragma once
2
3#include <torch/csrc/Export.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <memory>
6
7namespace torch {
8namespace jit {
9
struct Graph;

// Run TensorExpressions-based fuser.
//
// \param graph                  the graph to transform in place.
// \param min_group_size         minimum number of nodes needed to keep a
//                               fusion group (default: 2).
// \param add_composed_op        if true, creates a single operation that
//                               performs both the runtime check that types
//                               align and then the dispatch to the
//                               kernel/unoptimized graph.
// \param fuse_to_dynamic_shapes if true, fuse with dynamic-shape support
//                               (presumably via generalized shape guards --
//                               TODO confirm against the implementation).
TORCH_API void FuseTensorExprs(
    std::shared_ptr<Graph>& graph,
    size_t min_group_size = 2,
    bool add_composed_op = false,
    bool fuse_to_dynamic_shapes = false);
21
// Globally enable/disable the TensorExpr fuser, and query its current state.
TORCH_API void setTensorExprFuserEnabled(bool val);
TORCH_API bool tensorExprFuserEnabled();
// Globally enable/disable dynamic-shape fusion, and query its current state.
TORCH_API void setTensorExprDynamicShapeFusionEnabled(bool val);
TORCH_API bool tensorExprDynamicShapeFusionEnabled();
// Enable/disable fusion of reductions; the setter returns a bool (presumably
// the previous value of the flag -- confirm in the implementation).
TORCH_API bool setTexprReductionsEnabled(bool value);
TORCH_API bool texprReductionsEnabled();
28
// Remove profiling nodes from `graph` and attach ("specialize") the recorded
// tensor types onto the corresponding values.
TORCH_API void RemoveProfileNodesAndSpecializeTypes(
    std::shared_ptr<Graph>& graph);
// Returns whether `v` carries a tensor-type specialization.
TORCH_API bool hasTensorTypeSpecialization(Value* v);
// Strip tensor-type specializations from a whole graph / a single block
// (the inverse of RemoveProfileNodesAndSpecializeTypes for value types).
TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph);
TORCH_API void removeTensorTypeSpecializations(Block* block);
34
// Callback used to post-process the TensorType of each tensor input before it
// is baked into a TypeCheck guard (e.g. to erase aspects of the type that the
// guard should not depend on).
using tensor_type_converter_t =
    c10::function_ref<TensorTypePtr(const TensorTypePtr& t)>;

// Inserts a TypeCheck pattern around the guarded node (which must have a
// Subgraph attribute):
//
// if TypeCheck(...):
//   guarded_node
// else:
//   FallbackGraph(...)
//
// The TypeCheck includes the types of all Tensor inputs to the guarded_node,
// as processed by the type_converter, a lambda
// TensorTypePtr(const TensorTypePtr& t). This makes it possible to erase
// irrelevant aspects of the type.
//
// The Fallback graph will have the same subgraph as the guarded node (with
// the expectation that the guarded_node's subgraph will then be optimized).
TORCH_API void insertTypeGuard(
    Node* guarded_node,
    tensor_type_converter_t type_converter,
    c10::Symbol kind);
58
// Returns whether `v` is used only to query its size (presumably only by
// aten::size-style consumers -- confirm in the implementation).
TORCH_API bool usedOnlyInSize(Value* v);
// Emits nodes computing the broadcasted shape of `sizes` and returns the
// resulting Value; `db` is the alias database to keep up to date.
TORCH_API Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db);
61
namespace tensorexpr {
// Returns whether `node` can be handled by the TensorExpr (NNC) fuser.
TORCH_API bool isSupported(Node* node);

/// Get the modifiable custom operator set object.
///
/// For static shapes, if a custom operator has been added to the custom
/// operator set, it will be pulled into the NNC fusion group. But it doesn't
/// work with dynamic shapes unless the shape function is explicitly
/// registered via `torch::jit::RegisterShapeComputeGraphForSchema` for the
/// custom operator.
///
/// @return Reference to the custom operator set
///
TORCH_API OperatorSet& getCustomOperatorSet();
} // namespace tensorexpr
76} // namespace jit
77} // namespace torch
78