1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <utility>

#include "loop_sequencer.hpp"
#include "utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace jit {
27
28namespace loop_sequencer {
29
30/*************/
31/* Utilities */
32/*************/
33
34void LoopSequencer::schedule(Requirements reqs, ActionFunc action) {
35 schedule({{reqs, action}});
36}
37
38void LoopSequencer::schedule(std::vector<Item> list) {
39 if (!list.empty()) {
40 std::vector<CheckedItem> xlist;
41 xlist.reserve(list.size());
42
43 for (auto &entry : list)
44 xlist.push_back(CheckedItem(entry));
45
46 schedule_if(xlist);
47 }
48}
49
50void LoopSequencer::schedule_if(
51 Requirements reqs, ActionFunc action, ActionCheckFunc check) {
52 schedule_if({{reqs, action, check}});
53}
54
55void LoopSequencer::schedule_if(std::vector<CheckedItem> list) {
56 if (!list.empty()) {
57 validate(list);
58 actions.push_back({list, NeverScheduled});
59 }
60}
61
62void LoopSequencer::setCallback(CallbackType type, Callback cb) {
63 callbacks[static_cast<size_t>(type)] = cb;
64}
65
66void LoopSequencer::setRemainderHandling(RemainderHandling handling) {
67 remainderHandling = handling;
68}
69
70void LoopSequencer::callback(CallbackType type, int arg1, int arg2) {
71 auto &cb = callbacks[static_cast<size_t>(type)];
72 if (cb) cb(arg1, arg2);
73}
74
75void LoopSequencer::validate(std::vector<CheckedItem> &list) {
76 if (list.empty()) throw std::runtime_error("No actions specified.");
77
78 int variants = 1;
79 int headPeriod = 0;
80
81 auto &headReq = list[0].req;
82
83 for (auto &action : list) {
84 auto &req = action.req;
85 if (req.period <= 0) throw std::runtime_error("Invalid action period.");
86 if (req.phase < 0 || req.phase >= req.period)
87 throw std::runtime_error("Invalid action phase.");
88 if (req.duration < 0 || req.duration + req.phase > req.period)
89 throw std::runtime_error("Invalid action duration.");
90 if (req.lookahead <= -req.period || req.lookahead > headReq.lookahead)
91 throw std::runtime_error("Invalid action lookahead.");
92 if (headPeriod == 0)
93 headPeriod = req.period;
94 else if (headPeriod % req.period)
95 throw std::runtime_error(
96 "Backup action's period must evenly divide main action's "
97 "period.");
98 if (req.variants > 0) variants = lcm(variants, req.variants);
99 req.duration = std::max(req.duration, 1);
100 }
101
102 for (auto &action : list)
103 action.req.variants = variants;
104}
105
106void LoopSequencer::checkAnalyzed() const {
107 if (!analyzed)
108 throw std::runtime_error("Must call analyze() or materialize() first.");
109}
110
111int LoopSequencer::getUnroll() const {
112 checkAnalyzed();
113 return unroll;
114}
115
116int LoopSequencer::getWarmup() const {
117 checkAnalyzed();
118 return warmup;
119}
120
121int LoopSequencer::getCooldown() const {
122 checkAnalyzed();
123 return minCooldown;
124}
125
126/**************/
127/* Main logic */
128/**************/
129
130// Analyze action list to determine unroll, warmup, and cooldown.
131// Cooldown will be aligned to unroll.
132void LoopSequencer::analyze() {
133 if (analyzed) return;
134
135 unroll = 1;
136 maxLookahead = 0;
137 minCooldown = 0;
138 for (const auto &action : actions) {
139 const auto &headReq = action.list[0].req;
140
141 unroll = lcm(unroll, headReq.period * headReq.variants);
142 maxLookahead = std::max(maxLookahead, headReq.lookahead);
143
144 int mphase = (headReq.phase - headReq.lookahead) % headReq.period;
145 if (mphase < 0) mphase += headReq.period;
146 minCooldown = std::max(minCooldown,
147 headReq.lookahead - headReq.period + mphase + headReq.duration);
148 }
149
150 minCooldown = ((minCooldown + unroll - 1) / unroll) * unroll;
151 warmup = maxLookahead;
152 analyzed = true;
153}
154
155// Sequence the loop.
156void LoopSequencer::materialize(int maxLoops) {
157 typedef CallbackType CT;
158
159 analyze();
160
161 resetActions();
162 activeChecks.clear();
163 currentBias = 0;
164
165 bool unifyRemainder = (remainderHandling != RemainderHandling::Separate)
166 && (minCooldown >= unroll) && (unroll > 1);
167 int loopBias = minCooldown + unroll - 1;
168
169 int labelShort, labelUnite;
170
171 if (maxLoops > 0) {
172 // Special path: completely unroll loop, handling up to maxLoop iterations.
173 callback(CT::NotifyPhase, PhaseFullyUnrolled);
174 int lMax = ((maxLoops + unroll - 1) / unroll) * unroll;
175 for (int l = -warmup; l < lMax; l++)
176 run(l, 0, maxLoops);
177 closeChecks();
178 } else {
179 // Main path check: main path requires >= minCooldown iterations.
180 if (minCooldown > 0)
181 callback(CT::JumpIfLT, minCooldown, labelShort = nextLabel++);
182
183 if (loopBias != 0) {
184 currentBias += loopBias;
185 callback(CT::OffsetCounter, -loopBias);
186 }
187
188 // Warmup.
189 if (warmup > 0) callback(CT::NotifyPhase, PhaseWarmup);
190 for (int l = -warmup; l < 0; l++)
191 run(l, minCooldown);
192
193 // Main loop.
194 callback(CT::NotifyPhase, PhaseMainLoop);
195 callback(CT::LoopStart, unroll);
196
197 for (int l = 0; l < unroll; l++)
198 run(l, unroll + minCooldown);
199
200 callback(CT::LoopEnd, unroll);
201
202 if (loopBias != 0) {
203 currentBias -= loopBias;
204 callback(CT::OffsetCounter, loopBias);
205 }
206
207 // Cooldown.
208 // - remaining loop count in interval [minCooldown, minCooldown + unroll)
209 // - if unifying remainder, just do minCooldown loops here and leave the rest for remainder.
210 adjustActionTriggers(-unroll);
211 callback(CT::NotifyPhase, PhaseCooldown);
212 for (int l = 0;
213 l < (unifyRemainder ? minCooldown : minCooldown + unroll); l++)
214 run(l, minCooldown, minCooldown + unroll - 1);
215 closeChecks();
216
217 if (minCooldown > 1 && unifyRemainder)
218 callback(CT::OffsetCounter, -minCooldown);
219
220 callback(CT::NotifyPhase, PhaseMainPathEnd);
221
222 if (minCooldown > 1) callback(CT::Jump, labelUnite = nextLabel++);
223
224 // Short loop.
225 if (minCooldown > 0) {
226 callback(CT::JumpTarget, labelShort);
227 callback(CT::NotifyPhase, PhaseShortLoop);
228 }
229
230 if (minCooldown > 1) {
231 resetActions();
232
233 if (unifyRemainder) {
234 // If unifying remainder, group loops into chunks of size unroll.
235 for (int l = -warmup; l < 0; l++)
236 run(l, 0, minCooldown - 1);
237
238 int labelNoChunks = nextLabel++;
239
240 for (int l0 = 0; l0 < (minCooldown - unroll); l0 += unroll) {
241 int chunk = std::min(unroll, minCooldown - l0 - unroll);
242 bool needCheck = precheck(chunk);
243
244 if (needCheck) callback(CT::JumpIfLT, chunk, labelNoChunks);
245 for (int l = 0; l < chunk; l++)
246 run(l, chunk, minCooldown - 1 - l0);
247 callback(CT::OffsetCounter, -chunk);
248 adjustActionTriggers(-chunk);
249 closeChecks();
250 }
251 if (minCooldown > unroll)
252 callback(CT::JumpTarget, labelNoChunks);
253 } else {
254 for (int l = -warmup; l < minCooldown; l++)
255 run(l, 0, minCooldown - 1);
256 }
257
258 closeChecks();
259 callback(CT::JumpTarget, labelUnite);
260 }
261
262 // Unified remainder handling. Loop count is unbiased on all paths.
263 // TODO: is it always safe to unify main/short remainders when there are actions
264 // whose backups have different lookahead?
265 if (unifyRemainder
266 && (remainderHandling != RemainderHandling::Ignore)) {
267 callback(CT::NotifyPhase, PhaseRemainder);
268 for (int l = 0; l < unroll - 1; l++)
269 run(l, 0, unroll - 1);
270 closeChecks();
271 }
272 }
273
274 nextLabel = 0;
275}
276
277void LoopSequencer::run(int l, int guaranteedMin, int guaranteedMax) {
278 typedef CallbackType CT;
279
280 for (auto &action : actions) {
281 const auto &list = action.list;
282
283 // Find the first item in the list that matches trigger criteria (if any) and run it.
284 for (size_t i = 0; i < list.size(); i++) {
285 const auto &item = list[i];
286 const auto &req = item.req;
287 const auto &execute = item.action;
288 const auto &check = item.check;
289 int lTrigger = l + req.lookahead;
290 int minLoops = lTrigger + req.duration;
291 bool lastResort = (i + 1 == list.size());
292
293 if (lTrigger < 0) break;
294
295 if ((lTrigger + req.period) % req.period == req.phase) {
296 // Skip if this action can never be triggered.
297 if (guaranteedMax >= 0 && minLoops > guaranteedMax) continue;
298
299 // Skip if this action may not be triggered, and there's a backup plan.
300 if (minLoops > guaranteedMin && !lastResort) continue;
301
302 // Skip if this action's trigger falls within an already-covered section of iteration space.
303 if (lTrigger < action.nextTrigger) continue;
304
305 // Check if this action has work to do.
306 Iteration iteration(lTrigger,
307 std::max(0, guaranteedMin - lTrigger),
308 currentBias - lTrigger);
309 if (!check || check(iteration)) {
310 bool unconditional = (req.checkType == Unconditional);
311 bool optionalCheck = (req.checkType == OptionalCheck);
312
313 if (!optionalCheck) {
314 // If no loop count check desired for this action, pretend it doesn't need any loops.
315 if (unconditional) minLoops = 0;
316
317 // Finish all active checks > minLoops.
318 // If minLoops not currently being checked, then add check.
319 int thresh = minLoops;
320 bool needCheck = precheck(thresh) & !unconditional;
321
322 if (guaranteedMin < minLoops && needCheck) {
323 int label = nextLabel++;
324 callback(CT::JumpIfLT, thresh, label);
325 activeChecks.push_back(
326 std::make_pair(thresh, label));
327 }
328 }
329
330 execute(iteration);
331 }
332
333 action.nextTrigger = lTrigger - req.phase + req.period;
334 break;
335 } else {
336 // Find when this item will be triggered in this period.
337 lTrigger = lTrigger - (lTrigger % req.period) + req.phase;
338 minLoops = lTrigger + req.duration;
339
340 // If it is guaranteed to be triggered, then don't consider backups.
341 if (minLoops <= guaranteedMin) break;
342 }
343 }
344 }
345}
346
347void LoopSequencer::closeChecks() {
348 for (const auto &val : activeChecks)
349 callback(CallbackType::JumpTarget, val.second);
350 activeChecks.clear();
351}
352
353bool LoopSequencer::precheck(int thresh) {
354 bool alreadyChecked = false;
355
356 for (auto iter = activeChecks.begin(); iter < activeChecks.end();) {
357 int thisThresh = iter->first;
358 int thisLabel = iter->second;
359
360 if (thisThresh > thresh) {
361 callback(CallbackType::JumpTarget, thisLabel);
362 iter = activeChecks.erase(iter);
363 } else {
364 alreadyChecked |= (thisThresh == thresh);
365 iter++;
366 }
367 }
368
369 return !alreadyChecked;
370}
371
372void LoopSequencer::resetActions() {
373 for (auto &action : actions)
374 action.nextTrigger = NeverScheduled;
375}
376
377void LoopSequencer::adjustActionTriggers(int shift) {
378 for (auto &action : actions)
379 if (action.nextTrigger != NeverScheduled) action.nextTrigger += shift;
380}
381
382} /* namespace loop_sequencer */
383
384} // namespace jit
385} // namespace gpu
386} // namespace impl
387} // namespace dnnl
388