1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include <stdexcept> |
19 | |
20 | #include "loop_sequencer.hpp" |
21 | #include "utils.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
28 | namespace loop_sequencer { |
29 | |
30 | /*************/ |
31 | /* Utilities */ |
32 | /*************/ |
33 | |
34 | void LoopSequencer::schedule(Requirements reqs, ActionFunc action) { |
35 | schedule({{reqs, action}}); |
36 | } |
37 | |
38 | void LoopSequencer::schedule(std::vector<Item> list) { |
39 | if (!list.empty()) { |
40 | std::vector<CheckedItem> xlist; |
41 | xlist.reserve(list.size()); |
42 | |
43 | for (auto &entry : list) |
44 | xlist.push_back(CheckedItem(entry)); |
45 | |
46 | schedule_if(xlist); |
47 | } |
48 | } |
49 | |
50 | void LoopSequencer::schedule_if( |
51 | Requirements reqs, ActionFunc action, ActionCheckFunc check) { |
52 | schedule_if({{reqs, action, check}}); |
53 | } |
54 | |
55 | void LoopSequencer::schedule_if(std::vector<CheckedItem> list) { |
56 | if (!list.empty()) { |
57 | validate(list); |
58 | actions.push_back({list, NeverScheduled}); |
59 | } |
60 | } |
61 | |
62 | void LoopSequencer::setCallback(CallbackType type, Callback cb) { |
63 | callbacks[static_cast<size_t>(type)] = cb; |
64 | } |
65 | |
66 | void LoopSequencer::setRemainderHandling(RemainderHandling handling) { |
67 | remainderHandling = handling; |
68 | } |
69 | |
70 | void LoopSequencer::callback(CallbackType type, int arg1, int arg2) { |
71 | auto &cb = callbacks[static_cast<size_t>(type)]; |
72 | if (cb) cb(arg1, arg2); |
73 | } |
74 | |
75 | void LoopSequencer::validate(std::vector<CheckedItem> &list) { |
76 | if (list.empty()) throw std::runtime_error("No actions specified." ); |
77 | |
78 | int variants = 1; |
79 | int headPeriod = 0; |
80 | |
81 | auto &headReq = list[0].req; |
82 | |
83 | for (auto &action : list) { |
84 | auto &req = action.req; |
85 | if (req.period <= 0) throw std::runtime_error("Invalid action period." ); |
86 | if (req.phase < 0 || req.phase >= req.period) |
87 | throw std::runtime_error("Invalid action phase." ); |
88 | if (req.duration < 0 || req.duration + req.phase > req.period) |
89 | throw std::runtime_error("Invalid action duration." ); |
90 | if (req.lookahead <= -req.period || req.lookahead > headReq.lookahead) |
91 | throw std::runtime_error("Invalid action lookahead." ); |
92 | if (headPeriod == 0) |
93 | headPeriod = req.period; |
94 | else if (headPeriod % req.period) |
95 | throw std::runtime_error( |
96 | "Backup action's period must evenly divide main action's " |
97 | "period." ); |
98 | if (req.variants > 0) variants = lcm(variants, req.variants); |
99 | req.duration = std::max(req.duration, 1); |
100 | } |
101 | |
102 | for (auto &action : list) |
103 | action.req.variants = variants; |
104 | } |
105 | |
106 | void LoopSequencer::checkAnalyzed() const { |
107 | if (!analyzed) |
108 | throw std::runtime_error("Must call analyze() or materialize() first." ); |
109 | } |
110 | |
111 | int LoopSequencer::getUnroll() const { |
112 | checkAnalyzed(); |
113 | return unroll; |
114 | } |
115 | |
116 | int LoopSequencer::getWarmup() const { |
117 | checkAnalyzed(); |
118 | return warmup; |
119 | } |
120 | |
121 | int LoopSequencer::getCooldown() const { |
122 | checkAnalyzed(); |
123 | return minCooldown; |
124 | } |
125 | |
126 | /**************/ |
127 | /* Main logic */ |
128 | /**************/ |
129 | |
130 | // Analyze action list to determine unroll, warmup, and cooldown. |
131 | // Cooldown will be aligned to unroll. |
132 | void LoopSequencer::analyze() { |
133 | if (analyzed) return; |
134 | |
135 | unroll = 1; |
136 | maxLookahead = 0; |
137 | minCooldown = 0; |
138 | for (const auto &action : actions) { |
139 | const auto &headReq = action.list[0].req; |
140 | |
141 | unroll = lcm(unroll, headReq.period * headReq.variants); |
142 | maxLookahead = std::max(maxLookahead, headReq.lookahead); |
143 | |
144 | int mphase = (headReq.phase - headReq.lookahead) % headReq.period; |
145 | if (mphase < 0) mphase += headReq.period; |
146 | minCooldown = std::max(minCooldown, |
147 | headReq.lookahead - headReq.period + mphase + headReq.duration); |
148 | } |
149 | |
150 | minCooldown = ((minCooldown + unroll - 1) / unroll) * unroll; |
151 | warmup = maxLookahead; |
152 | analyzed = true; |
153 | } |
154 | |
155 | // Sequence the loop. |
156 | void LoopSequencer::materialize(int maxLoops) { |
157 | typedef CallbackType CT; |
158 | |
159 | analyze(); |
160 | |
161 | resetActions(); |
162 | activeChecks.clear(); |
163 | currentBias = 0; |
164 | |
165 | bool unifyRemainder = (remainderHandling != RemainderHandling::Separate) |
166 | && (minCooldown >= unroll) && (unroll > 1); |
167 | int loopBias = minCooldown + unroll - 1; |
168 | |
169 | int labelShort, labelUnite; |
170 | |
171 | if (maxLoops > 0) { |
172 | // Special path: completely unroll loop, handling up to maxLoop iterations. |
173 | callback(CT::NotifyPhase, PhaseFullyUnrolled); |
174 | int lMax = ((maxLoops + unroll - 1) / unroll) * unroll; |
175 | for (int l = -warmup; l < lMax; l++) |
176 | run(l, 0, maxLoops); |
177 | closeChecks(); |
178 | } else { |
179 | // Main path check: main path requires >= minCooldown iterations. |
180 | if (minCooldown > 0) |
181 | callback(CT::JumpIfLT, minCooldown, labelShort = nextLabel++); |
182 | |
183 | if (loopBias != 0) { |
184 | currentBias += loopBias; |
185 | callback(CT::OffsetCounter, -loopBias); |
186 | } |
187 | |
188 | // Warmup. |
189 | if (warmup > 0) callback(CT::NotifyPhase, PhaseWarmup); |
190 | for (int l = -warmup; l < 0; l++) |
191 | run(l, minCooldown); |
192 | |
193 | // Main loop. |
194 | callback(CT::NotifyPhase, PhaseMainLoop); |
195 | callback(CT::LoopStart, unroll); |
196 | |
197 | for (int l = 0; l < unroll; l++) |
198 | run(l, unroll + minCooldown); |
199 | |
200 | callback(CT::LoopEnd, unroll); |
201 | |
202 | if (loopBias != 0) { |
203 | currentBias -= loopBias; |
204 | callback(CT::OffsetCounter, loopBias); |
205 | } |
206 | |
207 | // Cooldown. |
208 | // - remaining loop count in interval [minCooldown, minCooldown + unroll) |
209 | // - if unifying remainder, just do minCooldown loops here and leave the rest for remainder. |
210 | adjustActionTriggers(-unroll); |
211 | callback(CT::NotifyPhase, PhaseCooldown); |
212 | for (int l = 0; |
213 | l < (unifyRemainder ? minCooldown : minCooldown + unroll); l++) |
214 | run(l, minCooldown, minCooldown + unroll - 1); |
215 | closeChecks(); |
216 | |
217 | if (minCooldown > 1 && unifyRemainder) |
218 | callback(CT::OffsetCounter, -minCooldown); |
219 | |
220 | callback(CT::NotifyPhase, PhaseMainPathEnd); |
221 | |
222 | if (minCooldown > 1) callback(CT::Jump, labelUnite = nextLabel++); |
223 | |
224 | // Short loop. |
225 | if (minCooldown > 0) { |
226 | callback(CT::JumpTarget, labelShort); |
227 | callback(CT::NotifyPhase, PhaseShortLoop); |
228 | } |
229 | |
230 | if (minCooldown > 1) { |
231 | resetActions(); |
232 | |
233 | if (unifyRemainder) { |
234 | // If unifying remainder, group loops into chunks of size unroll. |
235 | for (int l = -warmup; l < 0; l++) |
236 | run(l, 0, minCooldown - 1); |
237 | |
238 | int labelNoChunks = nextLabel++; |
239 | |
240 | for (int l0 = 0; l0 < (minCooldown - unroll); l0 += unroll) { |
241 | int chunk = std::min(unroll, minCooldown - l0 - unroll); |
242 | bool needCheck = precheck(chunk); |
243 | |
244 | if (needCheck) callback(CT::JumpIfLT, chunk, labelNoChunks); |
245 | for (int l = 0; l < chunk; l++) |
246 | run(l, chunk, minCooldown - 1 - l0); |
247 | callback(CT::OffsetCounter, -chunk); |
248 | adjustActionTriggers(-chunk); |
249 | closeChecks(); |
250 | } |
251 | if (minCooldown > unroll) |
252 | callback(CT::JumpTarget, labelNoChunks); |
253 | } else { |
254 | for (int l = -warmup; l < minCooldown; l++) |
255 | run(l, 0, minCooldown - 1); |
256 | } |
257 | |
258 | closeChecks(); |
259 | callback(CT::JumpTarget, labelUnite); |
260 | } |
261 | |
262 | // Unified remainder handling. Loop count is unbiased on all paths. |
263 | // TODO: is it always safe to unify main/short remainders when there are actions |
264 | // whose backups have different lookahead? |
265 | if (unifyRemainder |
266 | && (remainderHandling != RemainderHandling::Ignore)) { |
267 | callback(CT::NotifyPhase, PhaseRemainder); |
268 | for (int l = 0; l < unroll - 1; l++) |
269 | run(l, 0, unroll - 1); |
270 | closeChecks(); |
271 | } |
272 | } |
273 | |
274 | nextLabel = 0; |
275 | } |
276 | |
277 | void LoopSequencer::run(int l, int guaranteedMin, int guaranteedMax) { |
278 | typedef CallbackType CT; |
279 | |
280 | for (auto &action : actions) { |
281 | const auto &list = action.list; |
282 | |
283 | // Find the first item in the list that matches trigger criteria (if any) and run it. |
284 | for (size_t i = 0; i < list.size(); i++) { |
285 | const auto &item = list[i]; |
286 | const auto &req = item.req; |
287 | const auto &execute = item.action; |
288 | const auto &check = item.check; |
289 | int lTrigger = l + req.lookahead; |
290 | int minLoops = lTrigger + req.duration; |
291 | bool lastResort = (i + 1 == list.size()); |
292 | |
293 | if (lTrigger < 0) break; |
294 | |
295 | if ((lTrigger + req.period) % req.period == req.phase) { |
296 | // Skip if this action can never be triggered. |
297 | if (guaranteedMax >= 0 && minLoops > guaranteedMax) continue; |
298 | |
299 | // Skip if this action may not be triggered, and there's a backup plan. |
300 | if (minLoops > guaranteedMin && !lastResort) continue; |
301 | |
302 | // Skip if this action's trigger falls within an already-covered section of iteration space. |
303 | if (lTrigger < action.nextTrigger) continue; |
304 | |
305 | // Check if this action has work to do. |
306 | Iteration iteration(lTrigger, |
307 | std::max(0, guaranteedMin - lTrigger), |
308 | currentBias - lTrigger); |
309 | if (!check || check(iteration)) { |
310 | bool unconditional = (req.checkType == Unconditional); |
311 | bool optionalCheck = (req.checkType == OptionalCheck); |
312 | |
313 | if (!optionalCheck) { |
314 | // If no loop count check desired for this action, pretend it doesn't need any loops. |
315 | if (unconditional) minLoops = 0; |
316 | |
317 | // Finish all active checks > minLoops. |
318 | // If minLoops not currently being checked, then add check. |
319 | int thresh = minLoops; |
320 | bool needCheck = precheck(thresh) & !unconditional; |
321 | |
322 | if (guaranteedMin < minLoops && needCheck) { |
323 | int label = nextLabel++; |
324 | callback(CT::JumpIfLT, thresh, label); |
325 | activeChecks.push_back( |
326 | std::make_pair(thresh, label)); |
327 | } |
328 | } |
329 | |
330 | execute(iteration); |
331 | } |
332 | |
333 | action.nextTrigger = lTrigger - req.phase + req.period; |
334 | break; |
335 | } else { |
336 | // Find when this item will be triggered in this period. |
337 | lTrigger = lTrigger - (lTrigger % req.period) + req.phase; |
338 | minLoops = lTrigger + req.duration; |
339 | |
340 | // If it is guaranteed to be triggered, then don't consider backups. |
341 | if (minLoops <= guaranteedMin) break; |
342 | } |
343 | } |
344 | } |
345 | } |
346 | |
347 | void LoopSequencer::closeChecks() { |
348 | for (const auto &val : activeChecks) |
349 | callback(CallbackType::JumpTarget, val.second); |
350 | activeChecks.clear(); |
351 | } |
352 | |
353 | bool LoopSequencer::precheck(int thresh) { |
354 | bool alreadyChecked = false; |
355 | |
356 | for (auto iter = activeChecks.begin(); iter < activeChecks.end();) { |
357 | int thisThresh = iter->first; |
358 | int thisLabel = iter->second; |
359 | |
360 | if (thisThresh > thresh) { |
361 | callback(CallbackType::JumpTarget, thisLabel); |
362 | iter = activeChecks.erase(iter); |
363 | } else { |
364 | alreadyChecked |= (thisThresh == thresh); |
365 | iter++; |
366 | } |
367 | } |
368 | |
369 | return !alreadyChecked; |
370 | } |
371 | |
372 | void LoopSequencer::resetActions() { |
373 | for (auto &action : actions) |
374 | action.nextTrigger = NeverScheduled; |
375 | } |
376 | |
377 | void LoopSequencer::adjustActionTriggers(int shift) { |
378 | for (auto &action : actions) |
379 | if (action.nextTrigger != NeverScheduled) action.nextTrigger += shift; |
380 | } |
381 | |
382 | } /* namespace loop_sequencer */ |
383 | |
384 | } // namespace jit |
385 | } // namespace gpu |
386 | } // namespace impl |
387 | } // namespace dnnl |
388 | |