1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Helper utilities to implement compute_op.
22 * \file compute_op.h
23 */
24#ifndef TVM_TE_OPERATION_COMPUTE_OP_H_
25#define TVM_TE_OPERATION_COMPUTE_OP_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/tir/expr.h>
29
30#include <unordered_map>
31#include <vector>
32
33namespace tvm {
34namespace te {
35// loop nest structure for general compute
36// This the loop nest structured used in compute.
37// Does not include the loop body.
38struct ComputeLoopNest {
39 // The common number of loops between init and main
40 size_t num_common_loop;
41 // predicates for the initialize loop
42 std::vector<PrimExpr> init_predicates;
43 // Initialization nest involved.
44 std::vector<std::vector<Stmt>> init_nest;
45 // Value map for the init code
46 std::unordered_map<IterVar, PrimExpr> init_vmap;
47 // Predicates for the main update loop
48 std::vector<PrimExpr> main_predicates;
49 // The general loop nest
50 std::vector<std::vector<Stmt>> main_nest;
51 // Value map for the IterVar.
52 std::unordered_map<IterVar, PrimExpr> main_vmap;
53
54 /*!
55 * \brief constructor to build ComputeOpNest
56 * \param self The pointer to compute op.
57 * \param stage The scxhedule stage.
58 * \param dom_map The domain map.
59 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
60 * \return The constructed loop nest
61 */
62 static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage,
63 const std::unordered_map<IterVar, Range>& dom_map,
64 bool debug_keep_trivial_loop);
65};
66
67/*!
68 * \brief Build body of compute for cross thread reduction pattern.
69 * \param self The pointer to ComputeOpNode
70 * \param stage The schedule stage.
71 * \param dom_map The domain map.
72 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
73 * \return The created statement.
74 */
75Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
76 const std::unordered_map<IterVar, Range>& dom_map,
77 bool debug_keep_trivial_loop);
78
79/*!
80 * \brief Build body of compute for tensorization.
81 * \param self The pointer to ComputeOpNode
82 * \param stage The schedule stage.
83 * \param dom_map The domain map.
84 * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
85 * \return The created statement.
86 */
87Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
88 const std::unordered_map<IterVar, Range>& dom_map, bool debug_keep_trivial_loop);
89
90/*!
91 * \brief Transform the update part when there is no init func in tensorizing
92 * \param stage The stage for tensorizing.
93 * \param dom_map The range of each iter var.
94 * \param n The loop nest structured used in compute.
95 * \param body The body func in tensorize intrin
96 * \param update The update func in tensorize intrin
97 * \return Transformed result.
98 */
99Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
100 const ComputeLoopNest& n, Stmt body, Stmt update);
101} // namespace te
102} // namespace tvm
103
104#endif // TVM_TE_OPERATION_COMPUTE_OP_H_
105