1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file op_utils.h
 * \brief Common utilities used in operator construction.
23 */
24#ifndef TVM_TE_OPERATION_OP_UTILS_H_
25#define TVM_TE_OPERATION_OP_UTILS_H_
26
27#include <tvm/te/schedule.h>
28#include <tvm/tir/expr.h>
29
30#include <unordered_map>
31#include <unordered_set>
32#include <vector>
33
34#include "../../tir/transforms/arg_binder.h"
35#include "../../tir/transforms/ir_utils.h"
36#include "../schedule/message_passing.h"
37
namespace tvm {
namespace te {

// Makes tir::MergeNest available unqualified within tvm::te.
using tir::MergeNest;

/*!
 * \brief Build the loop nest for a stage.
 *
 * \param stage The stage to create a loop nest for.
 * \param dom_map The range of each iter var.
 * \param begin_iter_pos The position in the stage's leaf_iter_vars at which
 *        loop generation begins.
 * \param new_loop_var Whether to create new loop variables.
 * \param skip_iter The set of iter vars whose iteration should be skipped.
 * \param p_value_map Map (passed by non-const pointer) for the result value of
 *        each IterVar; presumably populated by this call — confirm against the
 *        implementation.
 * \param debug_keep_trivial_loop Whether to keep trivial loops with an extent of 1.
 * \return The generated loop nest as a nested list of statements.
 *         NOTE(review): presumably one inner list per loop level, suitable for
 *         combining with MergeNest — confirm against the implementation.
 */
std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
                                            const std::unordered_map<IterVar, Range>& dom_map,
                                            size_t begin_iter_pos, bool new_loop_var,
                                            const std::unordered_set<IterVar>& skip_iter,
                                            std::unordered_map<IterVar, PrimExpr>* p_value_map,
                                            bool debug_keep_trivial_loop);

/*!
 * \brief Create a nest of if statements checking the predicates.
 *
 * \param predicates The predicates to be checked.
 * \return A list of If statements forming a nest that checks the predicates.
 */
std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);

/*!
 * \brief Replace the tensor references (especially in Calls) in a statement
 *        according to the replace map.
 * \param stmt The statement to be processed.
 * \param replace The replacement rule, mapping each original tensor to its
 *        replacement.
 * \return The statement with tensor references replaced according to \p replace.
 */
Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace);

/*!
 * \brief Replace the tensor references (especially in Calls) in a PrimExpr
 *        according to the replace map.
 * \param expr The expression to be processed.
 * \param replace The replacement rule, mapping each original tensor to its
 *        replacement.
 * \return The expression with tensor references replaced according to \p replace.
 */
PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace);

/*!
 * \brief Substitute the variables of a statement using the value map.
 * \param stmt The statement to be processed.
 * \param value_map The value map, giving the substitute value for each IterVar.
 * \return Substituted result.
 */
Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>& value_map);

/*!
 * \brief Substitute the variables of a PrimExpr using the value map.
 * \param expr The expression to be processed.
 * \param value_map The value map, giving the substitute value for each IterVar.
 * \return Substituted result.
 */
PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar, PrimExpr>& value_map);

/*!
 * \brief Converts a (Halide-derived) tir::ForKind to its corresponding IterVarType.
 * \param kind The ForKind to be converted.
 * \return The IterVarType corresponding to \p kind.
 */
IterVarType ForKindToIterVarType(tir::ForKind kind);

/*!
 * \brief Converts an IterVarType to its corresponding (Halide-derived) tir::ForKind.
 * \param iter_type The IterVarType to be converted.
 * \return The tir::ForKind corresponding to \p iter_type.
 */
tir::ForKind IterVarTypeToForKind(IterVarType iter_type);

}  // namespace te
}  // namespace tvm
112#endif // TVM_TE_OPERATION_OP_UTILS_H_
113