/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file arg_binder.h
 * \brief Helper utility to match and bind arguments.
 */
#ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_
#define TVM_TIR_TRANSFORMS_ARG_BINDER_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>

#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace tir {

/*!
 * \brief Helper utility to generate match and bind of arguments.
 *
 * \note There are many places in the TVM IR where we need argument bindings.
 *
 * Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2))).
 * Here n is an undefined variable that is decided by the outside, tB imposes
 * a constraint such that it can only take a tensor with shape 3, and tC imposes
 * another constraint that its shape must equal n + 2.
 * So if we call it as f(bufferA, bufferB, bufferC), we need to generate
 * the following binding sequence:
 * - define n = bufferA.shape[0]
 * - assert bufferB.shape[0] == 3
 * - assert bufferC.shape[0] == n + 2
 *
 * In general, this is a constraint solving problem. We make a simplifying assumption
 * about the binding declaration: every variable that occurs in a constraint
 * must be declared in the argument list. So it is illegal to have a signature
 * f(tA(shape=(n+3))) without an argument variable corresponding to n, even though
 * n could already be derived from the input argument.
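 *
 * A minimal usage sketch (illustrative only; the buffer names and the
 * surrounding lowering pass are hypothetical, not part of this API):
 *
 * \code
 *   std::unordered_map<const VarNode*, PrimExpr> def_map;
 *   ArgBinder binder(&def_map);
 *   // Bind the symbolic buffer declared in the signature to the buffer that
 *   // is actually passed in; shape constraints become definitions or asserts.
 *   binder.BindBuffer(decl_buffer_a, passed_buffer_a, "bufferA", false);  // fuzzy_match=false
 *   // binder.defs() now lists the newly defined vars (e.g. n), and
 *   // binder.asserts() holds the generated shape checks.
 * \endcode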
 */
class ArgBinder {
 public:
  /*!
   * \brief Constructor
   * \param def_map A definition map that contains definitions of known variables.
   *                ArgBinder will update this def_map when adding new definitions.
   */
  explicit ArgBinder(std::unordered_map<const VarNode*, PrimExpr>* def_map) : def_map_(def_map) {}
  /*!
   * \brief Try to bind arg to value, generating constraints if necessary.
   * \param arg The argument to be bound.
   * \param value The target expression value.
   * \param arg_name The argument name.
   * \param with_let Whether to add a let binding during bind.
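   *
   * A small usage sketch (hedged example; `binder`, `n`, `m`, and `k` are
   * hypothetical names made up for illustration):
   *
   * \code
   *   Var n("n"), m("m"), k("k");
   *   // n is not yet defined: records n -> m in def_map and in defs().
   *   binder.Bind(n, m, "bufferA.shape[0]");
   *   // n + 2 is not a plain variable, so this emits an assert that
   *   // n + 2 == k (collected in asserts()) unless the analyzer can
   *   // already prove it.
   *   binder.Bind(n + 2, k, "bufferC.shape[0]");
   * \endcode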
   */
  void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
            bool with_let = false);
  /*!
   * \brief Bind array to array.
   * \param arg The argument to be bound.
   * \param value The target expression value.
   * \param arg_name The argument name.
   */
  void BindArray(const Array<PrimExpr>& arg, const Array<PrimExpr>& value,
                 const std::string& arg_name);
  /*!
   * \brief Bind a symbolic buffer to another symbolic buffer.
   * \param arg The argument to be bound.
   * \param value The target expression value.
   * \param arg_name The argument name.
   * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg's, as long as
   *        arg's higher dimensions are of size 1.
   */
  void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name,
                  bool fuzzy_match);
  /*!
   * \brief Bind a symbolic buffer to a DLTensor handle.
   * \param buffer The argument buffer to be bound.
   * \param device_type The device type to be bound.
   * \param device_id The device id to be bound.
   * \param handle The DLTensor handle.
   * \param arg_name The argument name.
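   *
   * A hedged sketch of how this is typically driven (the handle name and the
   * surrounding pass are illustrative, not part of this API):
   *
   * \code
   *   Var arg_handle("arg0", DataType::Handle());
   *   binder.BindDLTensor(buffer, device_type, device_id, arg_handle, "arg0");
   *   // init_nest() now contains let statements that load shape, strides and
   *   // the data pointer from the DLTensor fields, and asserts() contains the
   *   // checks that they match `buffer`.
   * \endcode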
   */
  void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id,
                    const Var& handle, const std::string& arg_name);

  /*! \return The definitions generated during binding. */
  const std::vector<Var>& defs() const { return defs_; }
  /*! \return The asserts generated during binding. */
  const std::vector<Stmt>& asserts() const { return asserts_; }
  /*!
   * \brief Initialization nest generated during binding.
   *        This is only non-empty when BindDLTensor is called.
   *
   * \note The binder may choose to generate a let statement
   *       and simply update def_map to map the variable to itself,
   *       or update def_map to map directly to the new value without generating a let statement.
   *
   *       A let statement is usually generated when binding to a DLTensor and a memory load is involved.
   * \return The initialization nest generated during binding.
   */
  const std::vector<Stmt>& init_nest() const { return init_nest_; }
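  // A sketch of how a lowering pass typically consumes these results
  // (assumptions: `body` is the statement being wrapped, and MergeNest comes
  // from ir_utils.h; neither is declared in this header):
  //   Stmt stmt = MergeNest(binder.init_nest(), body);
  //   stmt = MergeNest(binder.asserts(), stmt);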
  /*! \return Handle data type of the data. */
  const Map<Var, PrimExpr>& def_handle_dtype() const { return def_handle_dtype_; }

 private:
  // Internal bind function.
  bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name,
             bool with_lets);
  /*! \brief The definition map, which can be used for substitution. */
  std::unordered_map<const VarNode*, PrimExpr>* def_map_;
  /*! \brief Definitions generated in the current binder. */
  std::vector<Var> defs_;
  /*! \brief Initialization nest. */
  std::vector<Stmt> init_nest_;
  /*! \brief Handle data type in the definitions. */
  Map<Var, PrimExpr> def_handle_dtype_;
  /*! \brief Asserts generated. */
  std::vector<Stmt> asserts_;
  /*! \brief Internal analyzer. */
  arith::Analyzer analyzer_;
};
}  // namespace tir
}  // namespace tvm
#endif  // TVM_TIR_TRANSFORMS_ARG_BINDER_H_