1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef LOOP_SEQUENCER_HPP |
18 | #define LOOP_SEQUENCER_HPP |
19 | |
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <limits>
#include <utility>
#include <vector>
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | namespace loop_sequencer { |
32 | |
// Describes the loop iteration an action is being emitted for. Instances with
// meaningful values are created by LoopSequencer (a friend, which can use the
// private constructor). A default-constructed Iteration has all fields zero.
class Iteration {
    friend class LoopSequencer;

public:
    // Iteration counter.
    constexpr operator int() const { return value; }
    constexpr int get() const { return value; }
    // # of guaranteed iterations after this one.
    constexpr int remaining() const { return rem; }
    // Offset between current counter value and the iteration when the action triggers.
    constexpr int counterOffset() const { return offset; }

    // Members carry default initializers, so reading a default-constructed
    // Iteration is well-defined (all accessors return 0).
    constexpr Iteration() = default;

private:
    constexpr Iteration(int value_, int rem_, int offset_)
        : value(value_), rem(rem_), offset(offset_) {}

    int value = 0;
    int rem = 0;
    int offset = 0;
};
55 | |
56 | class LoopSequencer { |
57 | public: |
58 | struct Requirements { |
59 | int period = 0; // # of loops between repetitions of the action. |
60 | int phase = 0; // Action triggers when (loop #) mod period = phase. |
61 | int duration |
62 | = 0; // Action only triggered if at least _duration_ loops remain (including current loop) |
63 | int lookahead |
64 | = 0; // Action will run _lookahead_ loops before trigger condition. |
65 | int variants = 0; // # of variants of the action. |
66 | // variants = n is equivalent to scheduling n copies of the action at n times the period, with equally spaced phases. |
67 | int checkType = 0; // See CheckType enum. |
68 | |
69 | friend Requirements operator|( |
70 | const Requirements &req1, const Requirements &req2) { |
71 | auto result = req1; |
72 | result.period |= req2.period; |
73 | result.phase |= req2.phase; |
74 | result.duration |= req2.duration; |
75 | result.lookahead |= req2.lookahead; |
76 | result.variants |= req2.variants; |
77 | result.checkType |= req2.checkType; |
78 | return result; |
79 | } |
80 | |
81 | Requirements delay(int delay_) const { |
82 | auto result = *this; |
83 | result.phase += delay_; |
84 | result.duration -= delay_; |
85 | result.duration = std::max(result.duration, 0); |
86 | return result; |
87 | } |
88 | }; |
89 | |
90 | using ActionFunc = std::function<void( |
91 | Iteration)>; // void action(Iteration iteration); |
92 | using ActionCheckFunc = std::function<bool( |
93 | Iteration)>; // bool actionCheck(Iteration iteration); |
94 | using Callback = std::function<void( |
95 | int, int)>; // void callback(int arg1, int arg2); |
96 | |
97 | struct Item { |
98 | Requirements req; |
99 | ActionFunc action; |
100 | |
101 | Item(Requirements req_, ActionFunc action_) |
102 | : req(req_), action(action_) {} |
103 | }; |
104 | |
105 | struct CheckedItem : public Item { |
106 | ActionCheckFunc check; |
107 | |
108 | CheckedItem( |
109 | Requirements req_, ActionFunc action_, ActionCheckFunc check_) |
110 | : Item(req_, action_), check(check_) {} |
111 | CheckedItem( |
112 | Item item_, ActionCheckFunc check_ = ActionCheckFunc(nullptr)) |
113 | : Item(item_), check(check_) {} |
114 | }; |
115 | |
116 | enum class CallbackType { |
117 | OffsetCounter, // Add offset to loop counter. (arg1 = offset) |
118 | LoopStart, // Mark top of loop. Jump to bottom if counter <= 0. (arg1 = unroll) |
119 | LoopEnd, // Decrement counter by unroll and jump to top of loop if > 0. (arg1 = unroll) |
120 | Jump, // Jump unconditionally. (arg1 = label) |
121 | JumpIfLT, // Jump if counter < arg. (arg1 = threshold, arg2 = label) |
122 | JumpTarget, // Mark jump target. (arg1 = label) |
123 | NotifyPhase, // Notify of change in phase. (arg1 = phase) |
124 | _end_ |
125 | }; |
126 | |
127 | enum class RemainderHandling { |
128 | Separate, // Full remainder handling for both main and cooldown loops. |
129 | Unified, // Combine main and cooldown loop remainder handling. |
130 | Ignore // No remainder handling; assume loop count is multiple of unroll. |
131 | }; |
132 | |
133 | enum Phase { |
134 | PhaseWarmup, // Warmup for main loop. |
135 | PhaseMainLoop, // Inside main loop. |
136 | PhaseCooldown, // Cooldown after main loop. |
137 | PhaseMainPathEnd, // End of main path. |
138 | PhaseShortLoop, // Short loop, if not enough iterations for main loop. |
139 | PhaseRemainder, // Unified remainder loop. |
140 | PhaseFullyUnrolled, // Fully unrolled loop sequence. |
141 | }; |
142 | |
143 | enum CheckType { |
144 | StandardCheck |
145 | = 0, // Loop counter should be checked to see if action can be run. |
146 | OptionalCheck |
147 | = 1, // Loop counter may be checked to see if action can be run, but it is not required. |
148 | Unconditional |
149 | = 3, // Loop counter must not be checked to see if action can be run. |
150 | }; |
151 | |
152 | void schedule(Requirements reqs, ActionFunc action); |
153 | void schedule(std::vector<Item> list); |
154 | void schedule_if( |
155 | Requirements reqs, ActionFunc action, ActionCheckFunc check); |
156 | void schedule_if(std::vector<CheckedItem> list); |
157 | void analyze(); |
158 | void materialize(int maxLoops = -1); |
159 | |
160 | void setCallback(CallbackType type, Callback cb); |
161 | void setRemainderHandling(RemainderHandling handling); |
162 | |
163 | int getUnroll() const; |
164 | int getWarmup() const; |
165 | int getCooldown() const; |
166 | |
167 | protected: |
168 | struct Action { |
169 | std::vector<CheckedItem> list; |
170 | int nextTrigger; |
171 | }; |
172 | |
173 | enum { NeverScheduled = std::numeric_limits<int>::min() }; |
174 | |
175 | std::vector<Action> actions; |
176 | std::array<Callback, static_cast<size_t>(CallbackType::_end_)> callbacks; |
177 | std::vector<std::pair<int, int>> activeChecks; |
178 | RemainderHandling remainderHandling = RemainderHandling::Separate; |
179 | int nextLabel = 0; |
180 | int currentBias = 0; |
181 | |
182 | int unroll = 1; |
183 | int maxLookahead = 0; |
184 | int minCooldown = 0; |
185 | int warmup = 0; |
186 | |
187 | bool analyzed = false; |
188 | |
189 | void validate(std::vector<CheckedItem> &list); |
190 | void callback(CallbackType type, int arg1, int arg2 = 0); |
191 | void run(int l, int guaranteedMin, int guaranteedMax = -1); |
192 | void closeChecks(); |
193 | bool precheck(int thresh); |
194 | void resetActions(); |
195 | void adjustActionTriggers(int shift); |
196 | void checkAnalyzed() const; |
197 | }; |
198 | |
199 | static inline LoopSequencer::Requirements every(int period) { |
200 | LoopSequencer::Requirements result; |
201 | result.period = period; |
202 | return result; |
203 | } |
204 | |
205 | static inline LoopSequencer::Requirements every(int ph, int period) { |
206 | LoopSequencer::Requirements result; |
207 | result.phase = ph; |
208 | result.period = period; |
209 | return result; |
210 | } |
211 | |
212 | static inline LoopSequencer::Requirements phase(int ph) { |
213 | LoopSequencer::Requirements result; |
214 | result.phase = ph; |
215 | return result; |
216 | } |
217 | |
218 | static inline LoopSequencer::Requirements duration(int dur) { |
219 | LoopSequencer::Requirements result; |
220 | result.duration = dur; |
221 | return result; |
222 | } |
223 | |
224 | static inline LoopSequencer::Requirements lookahead(int ahead) { |
225 | LoopSequencer::Requirements result; |
226 | result.lookahead = ahead; |
227 | return result; |
228 | } |
229 | |
230 | static inline LoopSequencer::Requirements variants(int vars) { |
231 | LoopSequencer::Requirements result; |
232 | result.variants = vars; |
233 | return result; |
234 | } |
235 | |
236 | static inline LoopSequencer::Requirements unconditional() { |
237 | LoopSequencer::Requirements result; |
238 | result.checkType = LoopSequencer::Unconditional; |
239 | return result; |
240 | } |
241 | |
242 | static inline LoopSequencer::Requirements checkOptional() { |
243 | LoopSequencer::Requirements result; |
244 | result.checkType = LoopSequencer::OptionalCheck; |
245 | return result; |
246 | } |
247 | |
248 | } /* namespace loop_sequencer */ |
249 | |
250 | using loop_sequencer::LoopSequencer; |
251 | |
252 | } // namespace jit |
253 | } // namespace gpu |
254 | } // namespace impl |
255 | } // namespace dnnl |
256 | |
257 | #endif /* LOOP_SEQUENCER_HPP */ |
258 | |