1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file message_passing.h
 * \brief Common utilities for message passing
 *  on the schedule hypergraph.
24 */
25#ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_
26#define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_
27
28#include <tvm/arith/analyzer.h>
29#include <tvm/te/operation.h>
30#include <tvm/te/schedule.h>
31#include <tvm/tir/expr.h>
32
33#include <unordered_map>
34#include <unordered_set>
35#include <vector>
36
37namespace tvm {
38namespace te {
39/*!
40 * \brief Downward inference of domain of each IterVar.
41 * Caller set the range of the root, then the function
42 * propagates it towards the leaves.
43 *
44 * \param stage The stage to operate on.
45 * \param p_state The state of the message passing.
46 * \param analyzer Analyzer context, storing information about bounds in p_state.
47 * \param allow_missing Whether allow missing value.
48 */
49void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_state,
50 arith::Analyzer* analyzer, bool allow_missing = false);
51
52/*!
53 * \param Upward inference of index of each IterVar.
54 * given index assignement of the leaves,
55 *
56 * \param stage The stage to operate on.
57 * \param dom_map The domain map of each iteration variable's domain.
58 * \param p_state The index state of each IterVar.
59 * \param allow_missing Whether allow missing value.
60 */
61void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
62 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false);
63
64/*!
65 * \param Downward inference of index of each IterVar.
66 * given index assignement of roots.
67 *
68 * \param stage The stage to operate on.
69 * \param dom_map The domain map of each iteration variable's domain.
70 * \param p_state The index state of each IterVar.
71 * \param allow_missing Whether allow missing value.
72 */
73void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
74 std::unordered_map<IterVar, PrimExpr>* p_state, bool allow_missing = false);
75
76/*!
77 * \param Upward inference of domain set of each IterVar.
78 * given domain assignment of the leaves,
79 *
80 * \param stage The stage to operate on.
81 * \param dom_map The domain map of each iteration variable's maximum domain.
82 * \param p_state The index state of each IterVar.
83 */
84void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
85 std::unordered_map<IterVar, IntSet>* p_state);
86
87/*!
88 * \brief Upward message passing of bitmask with or relation.
89 * \param stage The stage to operate on.
90 * \param p_state The index state of each IterVar.
91 * \param allow_missing Whether allow missing value.
92 */
93void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
94 bool allow_missing = false);
95
96/*!
97 * \brief Downward message passing of bitmask with or relation.
98 * \param stage The stage to operate on.
99 * \param p_state The index state of each IterVar.
100 * \param allow_missing Whether allow missing value.
101 */
102void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state,
103 bool allow_missing = false);
104
105/*!
106 * \brief Create boundary check predicates given remapped value of root
107 * \param stage The stage we operate on
108 * \param dom_map The domain map of each value.
109 * \param value_map The value map of the root iter var.
110 * \param skip_ivar_domain Whether we skip check for IterVar's original domain.
111 * \param skip_iter The set of variables to skip bound condition.
112 * \return List of predicates that we need to check.
113 */
114std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Range>& dom_map,
115 const std::unordered_map<IterVar, PrimExpr>& value_map,
116 bool skip_ivar_domain,
117 const std::unordered_set<IterVar>& skip_iter);
118
119} // namespace te
120} // namespace tvm
121#endif // TVM_TE_SCHEDULE_MESSAGE_PASSING_H_
122