1 | #pragma once |
2 | |
3 | #include <torch/csrc/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <memory> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | struct Graph; |
11 | |
12 | // Run TensorExpressions-based fuser. |
13 | // If add_composed_op is true, creates a single operation that |
14 | // performs both the runtime check that types align |
15 | // and then the dispatch to the kernel/unoptimized graph |
16 | TORCH_API void FuseTensorExprs( |
17 | std::shared_ptr<Graph>& graph, |
18 | size_t min_group_size = 2, |
19 | bool add_composed_op = false, |
20 | bool fuse_to_dynamic_shapes = false); |
21 | |
22 | TORCH_API void setTensorExprFuserEnabled(bool val); |
23 | TORCH_API bool tensorExprFuserEnabled(); |
24 | TORCH_API void setTensorExprDynamicShapeFusionEnabled(bool val); |
25 | TORCH_API bool tensorExprDynamicShapeFusionEnabled(); |
26 | TORCH_API bool setTexprReductionsEnabled(bool value); |
27 | TORCH_API bool texprReductionsEnabled(); |
28 | |
29 | TORCH_API void RemoveProfileNodesAndSpecializeTypes( |
30 | std::shared_ptr<Graph>& graph); |
31 | TORCH_API bool hasTensorTypeSpecialization(Value* v); |
32 | TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph); |
33 | TORCH_API void removeTensorTypeSpecializations(Block* block); |
34 | |
35 | using tensor_type_converter_t = |
36 | c10::function_ref<TensorTypePtr(const TensorTypePtr& t)>; |
37 | |
38 | // inserts a TypeCheck pattern |
39 | // |
40 | // around the guarded node that has a Subgraph attribute, this inserts a pattern |
41 | // |
42 | // if TypeCheck(...): |
43 | // guarded_node |
44 | // else: |
45 | // FallbackGraph(...) |
46 | // |
47 | // The TypeCheck includes the types of all Tensor inputs to the guarded_node, |
48 | // as processed by the type_converter, a lambda |
49 | // TensorTypePtr(const TensorTypePtr& t). This allows to erase irrelevant |
50 | // aspects of the type. |
51 | // |
52 | // The Fallback graph will have the same subgraph as the guarded node (with the |
53 | // expectation that the guarded_node's subgraph will then be optimized. |
54 | TORCH_API void insertTypeGuard( |
55 | Node* guarded_node, |
56 | tensor_type_converter_t type_converter, |
57 | c10::Symbol kind); |
58 | |
59 | TORCH_API bool usedOnlyInSize(Value* v); |
60 | TORCH_API Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db); |
61 | |
62 | namespace tensorexpr { |
63 | TORCH_API bool isSupported(Node* node); |
64 | |
65 | /// Get the modifiable custom operator set object. |
66 | /// |
67 | /// For static shapes, if a custom operator has been added to the custom |
68 | /// operator set, it will be pulled into the NNC fusion group. But it doesn't |
69 | /// work with dynamic shapes unless explicitly register the shape function via |
70 | /// `torch::jit::RegisterShapeComputeGraphForSchema` for the custom operator. |
71 | /// |
72 | /// @return Reference of the custome operator set |
73 | /// |
74 | TORCH_API OperatorSet& getCustomOperatorSet(); |
75 | } // namespace tensorexpr |
76 | } // namespace jit |
77 | } // namespace torch |
78 | |