1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file arg_binder.h |
22 | * \brief Helper utility to match and bind arguments. |
23 | */ |
24 | #ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_ |
25 | #define TVM_TIR_TRANSFORMS_ARG_BINDER_H_ |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/tir/buffer.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include <string> |
32 | #include <unordered_map> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | /*! |
39 | * \brief Helper utility to generate match and bind of arguments. |
40 | * |
41 | * \note There are many places in TVM IR where we need argument bindings. |
42 | * |
43 | * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2))). |
44 | * Here n is an undefined variable that is decided by the outside, tB imposes |
45 | * a constraint such that it can only take a tensor with shape 3, and tC imposes |
46 | * another constraint that its shape must equal n + 2. |
47 | * So if we call it with f(bufferA, bufferB, bufferC), we need to generate |
48 | * the following binding sequence: |
49 | * - define n = bufferA.shape[0] |
50 | * - assert bufferB.shape[0] == 3 |
51 | * - assert bufferC.shape[0] == n + 2 |
52 | * |
53 | * In general, this is a constraint solving problem. We make a simplifying assumption |
54 | * about the binding declaration: every variable that occurs in a constraint |
55 | * must be declared in the argument list. So it is illegal to have the signature |
56 | * f(tA(shape=(n+3))) without any argument variable that corresponds to n, even though |
57 | * the input argument alone is already enough to derive n. |
58 | */ |
59 | class ArgBinder { |
60 | public: |
61 | /*! |
62 | * \brief Constructor |
63 | * \param def_map A definition map that contains definitions of known variables. |
64 | * ArgBinder stores this pointer and updates the map when adding new definitions. |
65 | */ |
66 | explicit ArgBinder(std::unordered_map<const VarNode*, PrimExpr>* def_map) : def_map_(def_map) {} |
67 | /*! |
68 | * \brief Try to bind arg to value, generate constraint if necessary. |
69 | * \param arg The argument to be bound. |
70 | * \param value The target expression value. |
71 | * \param arg_name The argument name. |
72 | * \param with_let Whether to add lets during bind. |
73 | */ |
74 | void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, |
75 | bool with_let = false); |
76 | /*! |
77 | * \brief Bind array to array |
78 | * \param arg The argument array to be bound. |
79 | * \param value The target array of expression values. |
80 | * \param arg_name The argument name. |
81 | */ |
82 | void BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value, |
83 | const std::string& arg_name); |
84 | /*! |
85 | * \brief Bind symbolic buffer to another symbolic buffer |
86 | * \param arg The argument buffer to be bound. |
87 | * \param value The target buffer. |
88 | * \param arg_name The argument name. |
89 | * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg's, as long as |
90 | * arg's higher dimensions are 1. |
91 | */ |
92 | void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, |
93 | bool fuzzy_match); |
94 | /*! |
95 | * \brief Bind symbolic buffer to a DLTensor handle. |
96 | * \param buffer The argument buffer to be bound. |
97 | * \param device_type The device type to be bound. |
98 | * \param device_id The device id to be bound. |
99 | * \param handle The DLTensor handle. |
100 | * \param arg_name The argument name. |
101 | */ |
102 | void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, |
103 | const Var& handle, const std::string& arg_name); |
104 | |
105 | /*! \return The defs generated in binding. */ |
106 | const std::vector<Var>& defs() const { return defs_; } |
107 | /*! \return The asserts generated in binding. */ |
108 | const std::vector<Stmt>& asserts() const { return asserts_; } |
109 | /*! |
110 | * \brief Initialization nest generated |
111 | * This is only non-empty when BindDLTensor is called. |
112 | * |
113 | * \note The binder may choose to generate a let statement |
114 | * and simply put def_map to map Variable to itself, |
115 | * or update def_map to directly map to new value and not generate let statement. |
116 | * |
117 | * A let statement is usually generated when binding to a DLTensor and a memory load is involved. |
118 | * \return The initialization nest generated during binding. |
119 | */ |
120 | const std::vector<Stmt>& init_nest() const { return init_nest_; } |
121 | /*! \return Handle data type of the data */ |
122 | const Map<Var, PrimExpr>& def_handle_dtype() const { return def_handle_dtype_; } |
123 | |
124 | private: |
125 | // Internal bind function |
126 | bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, |
127 | bool with_lets); |
128 | /*! \brief The definition map, can be used to substitute; pointer supplied to the constructor */ |
129 | std::unordered_map<const VarNode*, PrimExpr>* def_map_; |
130 | /*! \brief defs generated in the current binder */ |
131 | std::vector<Var> defs_; |
132 | /*! \brief Initialization nest */ |
133 | std::vector<Stmt> init_nest_; |
134 | /*! \brief handle data type in the definitions */ |
135 | Map<Var, PrimExpr> def_handle_dtype_; |
136 | /*! \brief asserts generated */ |
137 | std::vector<Stmt> asserts_; |
138 | /*! \brief internal analyzer. */ |
139 | arith::Analyzer analyzer_; |
140 | }; |
141 | } // namespace tir |
142 | } // namespace tvm |
143 | #endif // TVM_TIR_TRANSFORMS_ARG_BINDER_H_ |
144 | |