1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef LOOP_SEQUENCER_HPP
18#define LOOP_SEQUENCER_HPP
19
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <limits>
#include <utility>
#include <vector>
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
31namespace loop_sequencer {
32
// Describes one loop iteration as seen by a scheduled action.
// Only LoopSequencer (a friend) can construct an Iteration with explicit
// values, because the value constructor is private. A default-constructed
// Iteration has every field set to zero.
class Iteration {
    friend class LoopSequencer;

public:
    // Iteration counter (also available via implicit conversion to int).
    constexpr operator int() const { return value; }
    constexpr int get() const { return value; }
    // # of guaranteed iterations after this one.
    constexpr int remaining() const { return rem; }
    // Offset between current counter value and the iteration when the action
    // triggers.
    constexpr int counterOffset() const { return offset; }

    // Uses the default member initializers below. A default Iteration
    // therefore never exposes indeterminate values through the accessors.
    Iteration() = default;

private:
    constexpr Iteration(int value_, int rem_, int offset_)
        : value(value_), rem(rem_), offset(offset_) {}

    int value = 0;
    int rem = 0;
    int offset = 0;
};
55
56class LoopSequencer {
57public:
58 struct Requirements {
59 int period = 0; // # of loops between repetitions of the action.
60 int phase = 0; // Action triggers when (loop #) mod period = phase.
61 int duration
62 = 0; // Action only triggered if at least _duration_ loops remain (including current loop)
63 int lookahead
64 = 0; // Action will run _lookahead_ loops before trigger condition.
65 int variants = 0; // # of variants of the action.
66 // variants = n is equivalent to scheduling n copies of the action at n times the period, with equally spaced phases.
67 int checkType = 0; // See CheckType enum.
68
69 friend Requirements operator|(
70 const Requirements &req1, const Requirements &req2) {
71 auto result = req1;
72 result.period |= req2.period;
73 result.phase |= req2.phase;
74 result.duration |= req2.duration;
75 result.lookahead |= req2.lookahead;
76 result.variants |= req2.variants;
77 result.checkType |= req2.checkType;
78 return result;
79 }
80
81 Requirements delay(int delay_) const {
82 auto result = *this;
83 result.phase += delay_;
84 result.duration -= delay_;
85 result.duration = std::max(result.duration, 0);
86 return result;
87 }
88 };
89
90 using ActionFunc = std::function<void(
91 Iteration)>; // void action(Iteration iteration);
92 using ActionCheckFunc = std::function<bool(
93 Iteration)>; // bool actionCheck(Iteration iteration);
94 using Callback = std::function<void(
95 int, int)>; // void callback(int arg1, int arg2);
96
97 struct Item {
98 Requirements req;
99 ActionFunc action;
100
101 Item(Requirements req_, ActionFunc action_)
102 : req(req_), action(action_) {}
103 };
104
105 struct CheckedItem : public Item {
106 ActionCheckFunc check;
107
108 CheckedItem(
109 Requirements req_, ActionFunc action_, ActionCheckFunc check_)
110 : Item(req_, action_), check(check_) {}
111 CheckedItem(
112 Item item_, ActionCheckFunc check_ = ActionCheckFunc(nullptr))
113 : Item(item_), check(check_) {}
114 };
115
116 enum class CallbackType {
117 OffsetCounter, // Add offset to loop counter. (arg1 = offset)
118 LoopStart, // Mark top of loop. Jump to bottom if counter <= 0. (arg1 = unroll)
119 LoopEnd, // Decrement counter by unroll and jump to top of loop if > 0. (arg1 = unroll)
120 Jump, // Jump unconditionally. (arg1 = label)
121 JumpIfLT, // Jump if counter < arg. (arg1 = threshold, arg2 = label)
122 JumpTarget, // Mark jump target. (arg1 = label)
123 NotifyPhase, // Notify of change in phase. (arg1 = phase)
124 _end_
125 };
126
127 enum class RemainderHandling {
128 Separate, // Full remainder handling for both main and cooldown loops.
129 Unified, // Combine main and cooldown loop remainder handling.
130 Ignore // No remainder handling; assume loop count is multiple of unroll.
131 };
132
133 enum Phase {
134 PhaseWarmup, // Warmup for main loop.
135 PhaseMainLoop, // Inside main loop.
136 PhaseCooldown, // Cooldown after main loop.
137 PhaseMainPathEnd, // End of main path.
138 PhaseShortLoop, // Short loop, if not enough iterations for main loop.
139 PhaseRemainder, // Unified remainder loop.
140 PhaseFullyUnrolled, // Fully unrolled loop sequence.
141 };
142
143 enum CheckType {
144 StandardCheck
145 = 0, // Loop counter should be checked to see if action can be run.
146 OptionalCheck
147 = 1, // Loop counter may be checked to see if action can be run, but it is not required.
148 Unconditional
149 = 3, // Loop counter must not be checked to see if action can be run.
150 };
151
152 void schedule(Requirements reqs, ActionFunc action);
153 void schedule(std::vector<Item> list);
154 void schedule_if(
155 Requirements reqs, ActionFunc action, ActionCheckFunc check);
156 void schedule_if(std::vector<CheckedItem> list);
157 void analyze();
158 void materialize(int maxLoops = -1);
159
160 void setCallback(CallbackType type, Callback cb);
161 void setRemainderHandling(RemainderHandling handling);
162
163 int getUnroll() const;
164 int getWarmup() const;
165 int getCooldown() const;
166
167protected:
168 struct Action {
169 std::vector<CheckedItem> list;
170 int nextTrigger;
171 };
172
173 enum { NeverScheduled = std::numeric_limits<int>::min() };
174
175 std::vector<Action> actions;
176 std::array<Callback, static_cast<size_t>(CallbackType::_end_)> callbacks;
177 std::vector<std::pair<int, int>> activeChecks;
178 RemainderHandling remainderHandling = RemainderHandling::Separate;
179 int nextLabel = 0;
180 int currentBias = 0;
181
182 int unroll = 1;
183 int maxLookahead = 0;
184 int minCooldown = 0;
185 int warmup = 0;
186
187 bool analyzed = false;
188
189 void validate(std::vector<CheckedItem> &list);
190 void callback(CallbackType type, int arg1, int arg2 = 0);
191 void run(int l, int guaranteedMin, int guaranteedMax = -1);
192 void closeChecks();
193 bool precheck(int thresh);
194 void resetActions();
195 void adjustActionTriggers(int shift);
196 void checkAnalyzed() const;
197};
198
199static inline LoopSequencer::Requirements every(int period) {
200 LoopSequencer::Requirements result;
201 result.period = period;
202 return result;
203}
204
205static inline LoopSequencer::Requirements every(int ph, int period) {
206 LoopSequencer::Requirements result;
207 result.phase = ph;
208 result.period = period;
209 return result;
210}
211
212static inline LoopSequencer::Requirements phase(int ph) {
213 LoopSequencer::Requirements result;
214 result.phase = ph;
215 return result;
216}
217
218static inline LoopSequencer::Requirements duration(int dur) {
219 LoopSequencer::Requirements result;
220 result.duration = dur;
221 return result;
222}
223
224static inline LoopSequencer::Requirements lookahead(int ahead) {
225 LoopSequencer::Requirements result;
226 result.lookahead = ahead;
227 return result;
228}
229
230static inline LoopSequencer::Requirements variants(int vars) {
231 LoopSequencer::Requirements result;
232 result.variants = vars;
233 return result;
234}
235
236static inline LoopSequencer::Requirements unconditional() {
237 LoopSequencer::Requirements result;
238 result.checkType = LoopSequencer::Unconditional;
239 return result;
240}
241
242static inline LoopSequencer::Requirements checkOptional() {
243 LoopSequencer::Requirements result;
244 result.checkType = LoopSequencer::OptionalCheck;
245 return result;
246}
247
248} /* namespace loop_sequencer */
249
250using loop_sequencer::LoopSequencer;
251
252} // namespace jit
253} // namespace gpu
254} // namespace impl
255} // namespace dnnl
256
257#endif /* LOOP_SEQUENCER_HPP */
258