1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file op_utils.h |
22 | * \brief Common utility used in operator construction. |
23 | */ |
24 | #ifndef TVM_TE_OPERATION_OP_UTILS_H_ |
25 | #define TVM_TE_OPERATION_OP_UTILS_H_ |
26 | |
27 | #include <tvm/te/schedule.h> |
28 | #include <tvm/tir/expr.h> |
29 | |
30 | #include <unordered_map> |
31 | #include <unordered_set> |
32 | #include <vector> |
33 | |
34 | #include "../../tir/transforms/arg_binder.h" |
35 | #include "../../tir/transforms/ir_utils.h" |
36 | #include "../schedule/message_passing.h" |
37 | |
38 | namespace tvm { |
39 | namespace te { |
40 | |
41 | using tir::MergeNest; |
42 | |
43 | /*! |
44 | * \brief Build loop nest for stage. |
45 | * |
46 | * \param stage The stage to create a loop nest. |
47 | * \param dom_map The range of each iter var. |
48 | * \param begin_iter_pos The beginning position of leaf_iter_vars to generate loop. |
49 | * \param new_loop_var Whether create new loop variable. |
50 | * \param skip_iter Whether skip certain iteration. |
51 | * \param p_value_map The result value of each IterVar. |
52 | * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 |
53 | */ |
54 | std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage, |
55 | const std::unordered_map<IterVar, Range>& dom_map, |
56 | size_t begin_iter_pos, bool new_loop_var, |
57 | const std::unordered_set<IterVar>& skip_iter, |
58 | std::unordered_map<IterVar, PrimExpr>* p_value_map, |
59 | bool debug_keep_trivial_loop); |
60 | |
61 | /*! |
62 | * \brief Create a nest of if checking the predicates. |
63 | * |
64 | * \param predicates The predicates to be checked. |
65 | * \return List of If nest that checks the predicates. |
66 | */ |
67 | std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates); |
68 | |
69 | /*! |
70 | * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. |
71 | * \param stmt The statement to be processed. |
72 | * \param replace The replacement rule. |
73 | */ |
74 | Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace); |
75 | /*! |
76 | * \brief Replace the tensor reference (especially in Call's) in primExpr by the replace map. |
77 | * \param expr The expression to be processed. |
78 | * \param replace The replacement rule. |
79 | */ |
80 | PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace); |
81 | |
82 | /*! |
83 | * \brief Substitute the variables of stmt by value map. |
84 | * \param stmt the statment |
85 | * \param value_map The value map. |
86 | * \return Substituted result. |
87 | */ |
88 | Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>& value_map); |
89 | |
90 | /*! |
91 | * \brief Substitute the variables of primExpr by value map. |
92 | * \param expr the expression to be processed. |
93 | * \param value_map The value map. |
94 | * \return Substituted result. |
95 | */ |
96 | PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar, PrimExpr>& value_map); |
97 | |
98 | /*! |
99 | * \brief Converts Halide ForKind to its corresponding IterVarType |
100 | * \param kind The ForKind to be converted |
101 | */ |
102 | IterVarType ForKindToIterVarType(tir::ForKind kind); |
103 | |
104 | /*! |
105 | * \brief Converts IterVarType to its corresponding Halide ForKind |
106 | * \param iter_type The IterVarType to be converted |
107 | */ |
108 | tir::ForKind IterVarTypeToForKind(IterVarType iter_type); |
109 | |
110 | } // namespace te |
111 | } // namespace tvm |
112 | #endif // TVM_TE_OPERATION_OP_UTILS_H_ |
113 | |