1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file message_passing.h |
22 | * \brief Common utilities to do message passing |
23 | * on the schedule hyper graph. |
24 | */ |
25 | #ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ |
26 | #define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ |
27 | |
28 | #include <tvm/arith/analyzer.h> |
29 | #include <tvm/te/operation.h> |
30 | #include <tvm/te/schedule.h> |
31 | #include <tvm/tir/expr.h> |
32 | |
33 | #include <unordered_map> |
34 | #include <unordered_set> |
35 | #include <vector> |
36 | |
37 | namespace tvm { |
38 | namespace te { |
39 | /*! |
40 | * \brief Downward inference of domain of each IterVar. |
41 | * Caller set the range of the root, then the function |
42 | * propagates it towards the leaves. |
43 | * |
44 | * \param stage The stage to operate on. |
45 | * \param p_state The state of the message passing. |
46 | * \param analyzer Analyzer context, storing information about bounds in p_state. |
47 | * \param allow_missing Whether allow missing value. |
48 | */ |
49 | void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state, |
50 | arith::Analyzer* analyzer, bool allow_missing = false); |
51 | |
52 | /*! |
53 | * \param Upward inference of index of each IterVar. |
54 | * given index assignement of the leaves, |
55 | * |
56 | * \param stage The stage to operate on. |
57 | * \param dom_map The domain map of each iteration variable's domain. |
58 | * \param p_state The index state of each IterVar. |
59 | * \param allow_missing Whether allow missing value. |
60 | */ |
61 | void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map, |
62 | std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false); |
63 | |
64 | /*! |
65 | * \param Downward inference of index of each IterVar. |
66 | * given index assignement of roots. |
67 | * |
68 | * \param stage The stage to operate on. |
69 | * \param dom_map The domain map of each iteration variable's domain. |
70 | * \param p_state The index state of each IterVar. |
71 | * \param allow_missing Whether allow missing value. |
72 | */ |
73 | void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map, |
74 | std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false); |
75 | |
76 | /*! |
77 | * \param Upward inference of domain set of each IterVar. |
78 | * given domain assignment of the leaves, |
79 | * |
80 | * \param stage The stage to operate on. |
81 | * \param dom_map The domain map of each iteration variable's maximum domain. |
82 | * \param p_state The index state of each IterVar. |
83 | */ |
84 | void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
85 | std::unordered_map<IterVar, IntSet>* p_state); |
86 | |
87 | /*! |
88 | * \brief Upward message passing of bitmask with or relation. |
89 | * \param stage The stage to operate on. |
90 | * \param p_state The index state of each IterVar. |
91 | * \param allow_missing Whether allow missing value. |
92 | */ |
93 | void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, |
94 | bool allow_missing = false); |
95 | |
96 | /*! |
97 | * \brief Downward message passing of bitmask with or relation. |
98 | * \param stage The stage to operate on. |
99 | * \param p_state The index state of each IterVar. |
100 | * \param allow_missing Whether allow missing value. |
101 | */ |
102 | void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, |
103 | bool allow_missing = false); |
104 | |
105 | /*! |
106 | * \brief Create boundary check predicates given remapped value of root |
107 | * \param stage The stage we operate on |
108 | * \param dom_map The domain map of each value. |
109 | * \param value_map The value map of the root iter var. |
110 | * \param skip_ivar_domain Whether we skip check for IterVar's original domain. |
111 | * \param skip_iter The set of variables to skip bound condition. |
112 | * \return List of predicates that we need to check. |
113 | */ |
114 | std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map, |
115 | const std::unordered_map<IterVar, PrimExpr>& value_map, |
116 | bool skip_ivar_domain, |
117 | const std::unordered_set<IterVar>& skip_iter); |
118 | |
119 | } // namespace te |
120 | } // namespace tvm |
121 | #endif // TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ |
122 | |