1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include <ctype.h> |
18 | #include <memory> |
19 | #include <sstream> |
20 | |
21 | #include <glog/logging.h> |
22 | |
23 | #include "glow/Graph/Graph.h" |
24 | #include "glow/Graph/TensorLayout.h" |
25 | #include "glow/Graph/VerifierHelper.h" |
26 | |
27 | using namespace glow; |
28 | |
29 | /// Checks if two layout descriptions \p lhs and \p rhs describe the same layout |
30 | /// for a value of the type \p ty \returns true if layouts are the same. |
31 | bool glow::checkSameLayout(llvm::StringRef srcLayoutStr, |
32 | llvm::StringRef destLayoutStr, TypeRef ty, |
33 | const Node *parent, const std::string &prefix, |
34 | const TensorLayoutCommon &TLC, bool verbose) { |
35 | auto srcLayout = TensorLayoutDescription(srcLayoutStr.str()); |
36 | auto destLayout = TensorLayoutDescription(destLayoutStr.str()); |
37 | // Are layouts literally the same? |
38 | if (srcLayout.isSameLayout(destLayout)) { |
39 | return true; |
40 | } |
41 | // Does the type satisfy the dest layout? |
42 | if (TLC.isSatisfiedBy(ty, destLayout, &srcLayout)) { |
43 | return true; |
44 | } |
45 | if (verbose) { |
46 | report("\n\n\n" ); |
47 | reportContext(parent); |
48 | report("\n" ); |
49 | report(prefix); |
50 | report("\n" ); |
51 | report(parent->getDebugDesc()); |
52 | report("\nMismatching layouts:\n" ); |
53 | report("Provided layout\n" ); |
54 | report(srcLayout.getDebugDesc()); |
55 | report("\n" ); |
56 | report("Expected layout\n" ); |
57 | report(destLayout.getDebugDesc()); |
58 | report("\n" ); |
59 | } |
60 | return false; |
61 | } |
62 | |
63 | /// Verifies the correctness of tensor layouts in the function \p F using layout |
64 | /// requirements interface \p TLC. |
65 | bool glow::verifyLayouts(const Function &F, TensorLayoutCommon &TLC, |
66 | bool verbose) { |
67 | bool isValid = true; |
68 | for (const auto &N : F.getNodes()) { |
69 | for (unsigned idx = 0, e = N.getNumInputs(); idx < e; ++idx) { |
70 | auto input = N.getNthInput(idx); |
71 | auto producerLayout = |
72 | TLC.getNthResultLayoutRequirements(input.getNode(), input.getResNo()); |
73 | auto consumerLayout = TLC.getNthInputLayoutRequirements(&N, idx); |
74 | std::string inputName = strFormat("input %d" , idx); |
75 | isValid &= checkSameLayout(producerLayout, consumerLayout, |
76 | input.getType(), &N, inputName, TLC, verbose); |
77 | } |
78 | } |
79 | return isValid; |
80 | } |
81 | |
82 | TensorLayoutDescription::TensorLayoutDescription(const std::string &layoutStr) { |
83 | if (layoutStr.empty()) { |
84 | // 0-D output |
85 | numDims_ = 0; |
86 | return; |
87 | } |
88 | parse(layoutStr); |
89 | } |
90 | |
91 | static bool isCustomExtension(llvm::StringRef text) { |
92 | auto nsPos = text.find(':'); |
93 | if (nsPos == llvm::StringRef::npos) { |
94 | return false; |
95 | } |
96 | auto bracketPos = text.find(']'); |
97 | assert(bracketPos != llvm::StringRef::npos && "Expected a closing bracket." ); |
98 | return (bracketPos > nsPos); |
99 | } |
100 | |
101 | // Serialization format - |
102 | // The form for each dimension is as follows: |
103 | // 1. (mandatory) one char representing the current dimension. Either an |
104 | // alphabetic letter or '*'. |
105 | // 2. (optional) token for the start of optional dimension information: '[' |
106 | // 3. (optional, must have 2. in place) namespace of the extension followed by |
107 | // ':'. must be provided for non-official backends. example: ocl:<information> |
108 | // 4. (optional, must have 2. in place) end of the current default extension |
109 | // ']' |
110 | // 5. (optional) go to 2. |
111 | // NOTE: To add alignment information, the format is: a=<size_t> |
112 | // Example: N[a=32][namespace_for_unsupported:<bla>]HWC would represent 4-D |
113 | // tensor wherein N needs an alignment of 32 + some closed-backend requirements |
114 | // we don't know about. HWC have no restrictions. |
115 | // NOTES: |
116 | // 1. For each dimension, the identifier can be either a single english alphabet |
117 | // letter, either upper or lower case, or the star symbol. |
118 | // 2. We assume that a single letter is enough for each dimension, it makes |
119 | // parsing easier and avoids adding delimiters in the serialized format, |
120 | // however, we do have a constructor that (theoretically) accepts multi-letter |
121 | // dimensions. If we decide to expand the current support, we will need to add |
122 | // delimiters to the serialized form. |
123 | void TensorLayoutDescription::parse(llvm::StringRef text) { |
124 | unsigned idx = 0; |
125 | while (!text.empty()) { |
126 | char curr = text.front(); |
127 | text = text.drop_front(); |
128 | if (curr == '\0' || isblank(curr)) { |
129 | continue; |
130 | } |
131 | switch (curr) { |
132 | case '[': { |
133 | assert(idx > 0 && "Expected at least one parsed entry." ); |
134 | if (isCustomExtension(text)) { |
135 | parseCustomExtensions(text, idx - 1); |
136 | } else { |
137 | parseOfficialExtensions(text, idx - 1); |
138 | } |
139 | break; |
140 | } |
141 | default: { |
142 | DCHECK(isalpha(curr) || curr == '*') |
143 | << "Expected an alphabetic letter or '*'., got: " << curr |
144 | << " in string: " << text.str(); |
145 | std::string currStr(1, curr); |
146 | dims_[idx].append(currStr); |
147 | serializedLayout_.append(dims_[idx]); |
148 | ++idx; |
149 | assert(idx <= max_tensor_dimensions && "Too many tensor dimensions" ); |
150 | break; |
151 | } |
152 | } |
153 | } |
154 | numDims_ = idx; |
155 | } |
156 | |
157 | void TensorLayoutDescription::parseCustomExtensions(llvm::StringRef &text, |
158 | unsigned idx) { |
159 | char curr = '['; |
160 | dims_[idx].append("[" ); |
161 | for (curr = text.front(); curr != ']' && !text.empty(); curr = text.front()) { |
162 | dims_[idx].append(std::string(1, curr)); |
163 | text = text.drop_front(); |
164 | } |
165 | assert(curr == ']' && "Expected closing ']' bracket." ); |
166 | text = text.drop_front(); |
167 | dims_[idx].append("]" ); |
168 | } |
169 | |
170 | void TensorLayoutDescription::parseOfficialExtensions(llvm::StringRef &text, |
171 | unsigned idx) { |
172 | // Only alignment so far - very simple parser: |
173 | if (!text.consume_front("a=" )) { |
174 | llvm_unreachable("Unsupported layout extension." ); |
175 | } |
176 | size_t align; |
177 | if (text.consumeInteger(10, align)) { |
178 | llvm_unreachable("Expected alignment info." ); |
179 | } |
180 | if (!text.consume_front("]" )) { |
181 | llvm_unreachable("Expected closing ']'" ); |
182 | } |
183 | dims_[idx].append("[a=" ); |
184 | dims_[idx].append(std::to_string(align)); |
185 | dims_[idx].append("]" ); |
186 | } |
187 | |
188 | TensorLayoutDescription::TensorLayoutDescription( |
189 | llvm::ArrayRef<std::string> dims) { |
190 | assert(dims.size() <= max_tensor_dimensions && "Too many tensor dimensions" ); |
191 | numDims_ = dims.size(); |
192 | for (unsigned idx = 0; idx < numDims_; ++idx) { |
193 | dims_[idx] = dims[idx]; |
194 | serializedLayout_.append(dims_[idx]); |
195 | } |
196 | } |
197 | |
198 | const llvm::StringRef |
199 | TensorLayoutDescription::getNthDimDescription(size_t n) const { |
200 | assert(n < numDims_ && "Wrong dimension number" ); |
201 | return dims_[n]; |
202 | } |
203 | |
204 | size_t TensorLayoutDescription::getAlignment(size_t n) const { |
205 | assert(n < numDims_ && "Wrong dimension number" ); |
206 | return getAlignment(dims_[n]); |
207 | } |
208 | |
209 | size_t TensorLayoutDescription::getAlignment(const std::string &s) const { |
210 | std::string alignPrefix = "a=" ; |
211 | size_t pos = s.find(alignPrefix); |
212 | if (pos == std::string::npos) { |
213 | // Default alignment: |
214 | return 1; |
215 | } |
216 | auto align = s.substr(pos + alignPrefix.size()); |
217 | size_t ret; |
218 | std::istringstream(align) >> ret; |
219 | return ret; |
220 | } |
221 | |
222 | /// \returns the position of ']' for extension at \p pos. |
223 | static size_t getEndOFExtension(llvm::StringRef dimStr, size_t pos) { |
224 | size_t posEnd = pos; |
225 | pos = pos - 1; |
226 | assert(dimStr[pos] == '[' && "Expected start of align extension." ); |
227 | while (dimStr[posEnd] != ']') { |
228 | ++posEnd; |
229 | assert(posEnd < dimStr.size() && "Expected to find closing bracket." ); |
230 | } |
231 | return posEnd; |
232 | } |
233 | |
234 | void TensorLayoutDescription::removeAttribute(const std::string &name, |
235 | std::string &dimStr) { |
236 | size_t pos = dimStr.find(name); |
237 | if (pos != std::string::npos) { |
238 | size_t posEnd = getEndOFExtension(dimStr, pos); |
239 | dimStr = dimStr.substr(0, pos - 1) + dimStr.substr(posEnd + 1); |
240 | } |
241 | } |
242 | |
243 | void TensorLayoutDescription::reconstructSerialized() { |
244 | serializedLayout_ = "" ; |
245 | for (size_t i = 0; i < numDims_; ++i) { |
246 | serializedLayout_.append(dims_[i]); |
247 | } |
248 | } |
249 | |
250 | llvm::StringRef TensorLayoutDescription::setAlignment(size_t n, size_t align) { |
251 | assert(n < numDims_ && "Wrong dimension number" ); |
252 | return setAttribute(n, "a=" , std::to_string(align)); |
253 | } |
254 | |
255 | llvm::StringRef TensorLayoutDescription::setAttribute(size_t n, |
256 | llvm::StringRef name, |
257 | llvm::StringRef value) { |
258 | assert(n < numDims_ && "Wrong dimension number" ); |
259 | auto &dimStr = dims_[n]; |
260 | // If we have a current name - remove it. |
261 | removeAttribute(name.str(), dimStr); |
262 | // Add new name information to dim: |
263 | dimStr.append("[" ); |
264 | dimStr.append(name.str()); |
265 | dimStr.append(value.str()); |
266 | dimStr.append("]" ); |
267 | reconstructSerialized(); |
268 | return dimStr; |
269 | } |
270 | |
271 | std::string TensorLayoutDescription::getAttribute(size_t n, |
272 | llvm::StringRef name) const { |
273 | assert(n < numDims_ && "Wrong dimension number" ); |
274 | size_t pos = dims_[n].find(name.str()); |
275 | if (pos == std::string::npos) { |
276 | return "" ; |
277 | } |
278 | size_t posEnd = getEndOFExtension(dims_[n], pos); |
279 | auto nameSZ = name.size(); |
280 | return dims_[n].substr(pos + nameSZ, posEnd - nameSZ - pos); |
281 | } |
282 | |
283 | llvm::ArrayRef<std::string> TensorLayoutDescription::getDims() const { |
284 | return llvm::makeArrayRef(dims_, numDims_); |
285 | } |
286 | |
287 | std::string TensorLayoutDescription::getDebugDesc() const { |
288 | std::string desc = "Layout: " + getSerializedLayout() + " [" ; |
289 | for (unsigned idx = 0; idx < numDims_; idx++) { |
290 | if (idx > 0) { |
291 | desc += ", " ; |
292 | } |
293 | desc += "name = " ; |
294 | desc += dims_[idx]; |
295 | desc += " : alignment = " ; |
296 | desc += std::to_string(getAlignment(idx)); |
297 | desc += " : index = " ; |
298 | desc += std::to_string(idx); |
299 | } |
300 | desc += "]" ; |
301 | return desc; |
302 | } |
303 | |
304 | bool TensorLayoutDescription::isSameLayout( |
305 | const TensorLayoutDescription &rhs) const { |
306 | if (numDims_ != rhs.numDims_) { |
307 | return false; |
308 | } |
309 | if (serializedLayout_ != rhs.serializedLayout_) { |
310 | return false; |
311 | } |
312 | return true; |
313 | } |
314 | |
315 | static bool isAnyHelper(llvm::StringRef layout) { |
316 | for (unsigned idx = 0, e = layout.size(); idx < e; ++idx) { |
317 | if (layout[idx] != '*') { |
318 | return false; |
319 | } |
320 | } |
321 | return true; |
322 | } |
323 | |
324 | bool TensorLayoutDescription::isAnyLayout() { |
325 | return (isAnyHelper(getSerializedLayout())); |
326 | } |
327 | |
328 | /// Definitions of different tensor layouts. |
329 | static std::string dimsNHWC[] = { |
330 | {"N" }, |
331 | {"H" }, |
332 | {"W" }, |
333 | {"C" }, |
334 | }; |
335 | static std::string dimsNCHW[] = { |
336 | {"N" }, |
337 | {"C" }, |
338 | {"H" }, |
339 | {"W" }, |
340 | }; |
341 | static std::string dimsHWNC[] = { |
342 | {"H" }, |
343 | {"W" }, |
344 | {"N" }, |
345 | {"C" }, |
346 | }; |
347 | static std::string dimsCNHW[] = { |
348 | {"C" }, |
349 | {"N" }, |
350 | {"H" }, |
351 | {"W" }, |
352 | }; |
353 | static std::string dims0D[]{ |
354 | {"" }, |
355 | }; |
356 | static std::string dims1D[] = { |
357 | {"N" }, |
358 | }; |
359 | static std::string dims2D[] = { |
360 | {"*" }, |
361 | {"*" }, |
362 | }; |
363 | static std::string dims3D[] = { |
364 | {"*" }, |
365 | {"*" }, |
366 | {"*" }, |
367 | }; |
368 | static std::string dims4D[] = { |
369 | {"*" }, |
370 | {"*" }, |
371 | {"*" }, |
372 | {"*" }, |
373 | }; |
374 | static std::string dims5D[] = { |
375 | {"*" }, {"*" }, {"*" }, {"*" }, {"*" }, |
376 | }; |
377 | static std::string dims6D[] = { |
378 | {"*" }, {"*" }, {"*" }, {"*" }, {"*" }, {"*" }, |
379 | }; |
380 | |
381 | static TensorLayoutDescription layoutNHWC(dimsNHWC); |
382 | static TensorLayoutDescription layoutNCHW(dimsNCHW); |
383 | static TensorLayoutDescription layoutHWNC(dimsHWNC); |
384 | static TensorLayoutDescription layoutCNHW(dimsCNHW); |
385 | static TensorLayoutDescription layout0D(dims0D); |
386 | static TensorLayoutDescription layout1D(dims1D); |
387 | static TensorLayoutDescription layout2D(dims2D); |
388 | static TensorLayoutDescription layout3D(dims3D); |
389 | static TensorLayoutDescription layout4D(dims4D); |
390 | static TensorLayoutDescription layout5D(dims5D); |
391 | static TensorLayoutDescription layout6D(dims6D); |
392 | |
393 | /// Glow layouts for any specific number of dimensions. |
394 | static TensorLayoutDescription layoutsForDims[] = { |
395 | layout0D, layout1D, layout2D, layout3D, layout4D, layout5D, layout6D, |
396 | }; |
397 | |
398 | TensorLayoutCommon::TensorLayoutCommon() : enabled_(false) {} |
399 | |
400 | TensorLayoutCommon::TensorLayoutCommon(TensorLayoutCommon *ctxTensorLayout) |
401 | : TensorLayoutCommon() { |
402 | ctxTensorLayout_ = ctxTensorLayout; |
403 | } |
404 | |
405 | TensorLayoutCommon::~TensorLayoutCommon() {} |
406 | |
407 | LayoutNameToLayoutDescriptionTy & |
408 | TensorLayoutCommon::getLayoutNameToLayoutDescription() const { |
409 | if (ctxTensorLayout_) { |
410 | return ctxTensorLayout_->getLayoutNameToLayoutDescription(); |
411 | } |
412 | return layoutNameToLayoutDescription_; |
413 | } |
414 | |
415 | llvm::ArrayRef<TensorLayoutDescription> |
416 | TensorLayoutCommon::getLayoutsForDims() const { |
417 | if (ctxTensorLayout_) { |
418 | return ctxTensorLayout_->getLayoutsForDims(); |
419 | } |
420 | return llvm::makeArrayRef(layoutsForDims); |
421 | } |
422 | |
423 | static LayoutNameToLayoutDescriptionTy initLayoutNameToDescription() { |
424 | LayoutNameToLayoutDescriptionTy map; |
425 | map.insert(std::make_pair( |
426 | "NCHW" , glow::make_unique<TensorLayoutDescription>("NCHW" ))); |
427 | map.insert(std::make_pair( |
428 | "NHWC" , glow::make_unique<TensorLayoutDescription>("NHWC" ))); |
429 | map.insert(std::make_pair( |
430 | "HWNC" , glow::make_unique<TensorLayoutDescription>("HWNC" ))); |
431 | map.insert(std::make_pair( |
432 | "CNHW" , glow::make_unique<TensorLayoutDescription>("CNHW" ))); |
433 | map.insert( |
434 | std::make_pair("N" , glow::make_unique<TensorLayoutDescription>("N" ))); |
435 | return map; |
436 | } |
437 | |
438 | LayoutNameToLayoutDescriptionTy |
439 | TensorLayoutCommon::layoutNameToLayoutDescription_ = |
440 | initLayoutNameToDescription(); |
441 | |
442 | static TensorLayoutDescription *getLayoutFromName( |
443 | const std::string &name, |
444 | LayoutNameToLayoutDescriptionTy &layoutNameToLayoutDescription) { |
445 | if (isAnyHelper(name)) { |
446 | return nullptr; |
447 | } |
448 | auto it = layoutNameToLayoutDescription.find(name); |
449 | if (it != layoutNameToLayoutDescription.end()) { |
450 | return it->second.get(); |
451 | } |
452 | // Add new layout to map: |
453 | auto *ret = new TensorLayoutDescription(name); |
454 | if (ret->getNumDims() == 0) { |
455 | // empty / any layout. |
456 | delete ret; |
457 | ret = nullptr; |
458 | } |
459 | layoutNameToLayoutDescription.insert( |
460 | std::make_pair(name, std::unique_ptr<TensorLayoutDescription>(ret))); |
461 | return ret; |
462 | } |
463 | |
464 | std::string TensorLayoutCommon::getDefaultNDLayout(unsigned dims) const { |
465 | DCHECK_LE(dims, max_tensor_dimensions) << "Too many dimensions" ; |
466 | return getLayoutsForDims()[dims].getSerializedLayout(); |
467 | } |
468 | |
469 | std::string |
470 | TensorLayoutCommon::getNthInputLayoutRequirementsImpl(const Node *node, |
471 | size_t n) { |
472 | if (ctxTensorLayout_) { |
473 | return ctxTensorLayout_->getNthInputLayoutRequirementsImpl(node, n); |
474 | } |
475 | return getNthInputLayoutRequirements(node, n); |
476 | } |
477 | |
478 | std::string TensorLayoutCommon::getNthInputLayoutRequirements(const Node *node, |
479 | size_t n) { |
480 | DCHECK_LT(n, node->getNumInputs()) << "Wrong input number" ; |
481 | auto dims = node->getNthInput(n).getType()->dims(); |
482 | DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions" ; |
483 | if (const auto *TN = llvm::dyn_cast<TransposeNode>(node)) { |
484 | // The layout for the input of transpose is the same as the layout of the |
485 | // operation's result producing this input. |
486 | auto input = TN->getInput(); |
487 | return getNthResultLayoutRequirementsImpl(input.getNode(), |
488 | input.getResNo()); |
489 | } |
490 | if (const auto *QN = llvm::dyn_cast<QuantizeNode>(node)) { |
491 | auto input = QN->getInput(); |
492 | return getNthResultLayoutRequirementsImpl(input.getNode(), |
493 | input.getResNo()); |
494 | } |
495 | if (const auto *CTN = llvm::dyn_cast<ConvertToNode>(node)) { |
496 | auto input = CTN->getInput(); |
497 | return getNthResultLayoutRequirementsImpl(input.getNode(), |
498 | input.getResNo()); |
499 | } |
500 | if (const auto *QPN = llvm::dyn_cast<QuantizationProfileNode>(node)) { |
501 | switch (n) { |
502 | case QuantizationProfileNode::InputIndices::InputIdx: { |
503 | auto input = QPN->getInput(); |
504 | return getNthResultLayoutRequirementsImpl(input.getNode(), |
505 | input.getResNo()); |
506 | } |
507 | default: |
508 | return getLayoutsForDims()[dims.size()].getSerializedLayout(); |
509 | } |
510 | } |
511 | return getLayoutsForDims()[dims.size()].getSerializedLayout(); |
512 | } |
513 | |
514 | /// \returns The index of node \p N input \p in. NumInputs if not found. |
515 | static unsigned getInputIdx(const Node *N, NodeValue in) { |
516 | for (unsigned idx = 0, e = N->getNumInputs(); idx < e; ++idx) { |
517 | if (N->getNthInput(idx) == in) { |
518 | return idx; |
519 | } |
520 | } |
521 | return N->getNumInputs(); |
522 | } |
523 | |
524 | /// \returns true if getting the input's layout would cause an infinite loop. |
525 | static bool inputDoesNotKnowRequirements(const Node *node) { |
526 | switch (node->getKind()) { |
527 | case Kinded::Kind::TransposeNodeKind: |
528 | case Kinded::Kind::QuantizeNodeKind: |
529 | case Kinded::Kind::QuantizationProfileNodeKind: |
530 | case Kinded::Kind::ConvertToNodeKind: |
531 | return true; |
532 | default: |
533 | return false; |
534 | } |
535 | } |
536 | |
537 | std::string |
538 | TensorLayoutCommon::getNthResultLayoutRequirementsImpl(const Node *node, |
539 | size_t n) { |
540 | if (ctxTensorLayout_) { |
541 | return ctxTensorLayout_->getNthResultLayoutRequirementsImpl(node, n); |
542 | } |
543 | return getNthResultLayoutRequirements(node, n); |
544 | } |
545 | |
546 | std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node, |
547 | size_t n) { |
548 | DCHECK_LT(n, node->getNumResults()) << "Wrong output number" ; |
549 | auto dims = node->getNthResult(n).getType()->dims(); |
550 | DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions" ; |
551 | if (auto *TN = llvm::dyn_cast<TransposeNode>(node)) { |
552 | // If the result of Transpose is a concrete layout, try to use this specific |
553 | // layout. |
554 | if (auto *layout = getLayoutFromName(TN->getLayout(), |
555 | getLayoutNameToLayoutDescription())) { |
556 | return layout->getSerializedLayout(); |
557 | } |
558 | // Dynamically form the layout description for transposes. |
559 | auto input = TN->getInput(); |
560 | while (inputDoesNotKnowRequirements(input)) { |
561 | input = input.getNode()->getNthInput(0); |
562 | } |
563 | auto inputLayout = |
564 | getNthInputLayoutRequirementsImpl(node, TransposeNode::InputIdx); |
565 | auto inputLayoutHelper = TensorLayoutDescription(inputLayout); |
566 | llvm::SmallVector<std::string, max_tensor_dimensions> dims( |
567 | input.dims().size()); |
568 | auto shuffle = TN->getShuffle(); |
569 | for (unsigned idx = 0, e = inputLayoutHelper.getNumDims(); idx < e; ++idx) { |
570 | dims[shuffle[idx]] = inputLayoutHelper.getNthDimDescription(idx).str(); |
571 | } |
572 | TensorLayoutDescription tld(dims); |
573 | return tld.getSerializedLayout(); |
574 | } |
575 | if (auto *C = llvm::dyn_cast<Constant>(node)) { |
576 | if (auto *layout = getLayoutFromName(C->getLayout(), |
577 | getLayoutNameToLayoutDescription())) { |
578 | return layout->getSerializedLayout(); |
579 | } |
580 | } |
581 | if (auto *PH = llvm::dyn_cast<Placeholder>(node)) { |
582 | if (auto *layout = getLayoutFromName(PH->getLayout(), |
583 | getLayoutNameToLayoutDescription())) { |
584 | return layout->getSerializedLayout(); |
585 | } |
586 | } |
587 | if (auto *RN = llvm::dyn_cast<ReshapeNode>(node)) { |
588 | if (auto *layout = getLayoutFromName(RN->getLayout(), |
589 | getLayoutNameToLayoutDescription())) { |
590 | return layout->getSerializedLayout(); |
591 | } |
592 | auto result = node->getNthResult(n); |
593 | auto *user = (*result.getUsers().begin()).getUser(); |
594 | unsigned inputIdx = getInputIdx(user, result); |
595 | if (inputDoesNotKnowRequirements(user) || |
596 | inputIdx >= user->getNumInputs() || llvm::isa<TransposeNode>(user)) { |
597 | return getLayoutsForDims()[dims.size()].getSerializedLayout(); |
598 | } |
599 | auto layout = getNthInputLayoutRequirementsImpl(user, inputIdx); |
600 | if (auto *layoutDesc = |
601 | getLayoutFromName(layout, getLayoutNameToLayoutDescription())) { |
602 | return layoutDesc->getSerializedLayout(); |
603 | } |
604 | } |
605 | return getLayoutsForDims()[dims.size()].getSerializedLayout(); |
606 | } |
607 | |
608 | bool TensorLayoutCommon::isSatisfiedBy( |
609 | TypeRef ty, const TensorLayoutDescription &destLayout, |
610 | const TensorLayoutDescription *srcLayout) const { |
611 | // Strides of the type (in elements). |
612 | auto strides = ty->strides(); |
613 | if (strides.size() != destLayout.getNumDims()) { |
614 | return false; |
615 | } |
616 | unsigned idx = 0; |
617 | for (const auto &dim : destLayout.getDims()) { |
618 | // dim.alignment is in bytes, but strides are in elements. |
619 | if (strides[idx] * ty->getElementSize() % destLayout.getAlignment(dim) != |
620 | 0) { |
621 | return false; |
622 | } |
623 | idx++; |
624 | } |
625 | if (!srcLayout) { |
626 | return true; |
627 | } |
628 | if (destLayout.getNumDims() != srcLayout->getNumDims()) { |
629 | return false; |
630 | } |
631 | // Names should be compatible. * is compatible to anything. |
632 | if (srcLayout->getSerializedLayout().size() != |
633 | destLayout.getSerializedLayout().size()) { |
634 | return false; |
635 | } |
636 | for (unsigned idx = 0, e = destLayout.getSerializedLayout().size(); idx < e; |
637 | ++idx) { |
638 | // '*' is compatible with anything. |
639 | if (destLayout.getSerializedLayout()[idx] == '*' || |
640 | srcLayout->getSerializedLayout()[idx] == '*') { |
641 | continue; |
642 | } |
643 | // Non-'*' are only compatible with themselves. |
644 | if (srcLayout->getSerializedLayout()[idx] == |
645 | destLayout.getSerializedLayout()[idx]) { |
646 | continue; |
647 | } |
648 | return false; |
649 | } |
650 | return true; |
651 | } |
652 | |
653 | static std::string returnBaseReqOrNHWC(std::string baseReq, const Node *node) { |
654 | auto baseReqHelper = TensorLayoutDescription(baseReq); |
655 | if (!baseReqHelper.isSameLayout( |
656 | CanonicalTensorLayout::getInstance().getLayoutsForDims()[4])) { |
657 | return baseReq; |
658 | } |
659 | if (CanonicalTensorLayout::getInstance().acceptsAnyLayout(node)) { |
660 | // These nodes accept any 4-D layout. |
661 | return baseReqHelper.getSerializedLayout(); |
662 | } |
663 | // NHWC is the canonical default |
664 | return CanonicalTensorLayout::getInstance().getDefaultNDLayout(4); |
665 | } |
666 | |
667 | std::string |
668 | CanonicalTensorLayout::getNthInputLayoutRequirements(const Node *node, |
669 | size_t n) { |
670 | auto baseReq = TensorLayoutCommon::getNthInputLayoutRequirements(node, n); |
671 | if (acceptsAnyLayout(node)) { |
672 | return baseReq; |
673 | } |
674 | return returnBaseReqOrNHWC(baseReq, node); |
675 | } |
676 | |
677 | std::string |
678 | CanonicalTensorLayout::getNthResultLayoutRequirements(const Node *node, |
679 | size_t n) { |
680 | auto baseReq = TensorLayoutCommon::getNthResultLayoutRequirements(node, n); |
681 | return returnBaseReqOrNHWC(baseReq, node); |
682 | } |
683 | |
684 | std::string CanonicalTensorLayout::getDefaultNDLayout(unsigned dims) const { |
685 | if (dims == 4) { |
686 | return layoutNHWC.getSerializedLayout(); |
687 | } |
688 | return TensorLayoutCommon::getDefaultNDLayout(dims); |
689 | } |
690 | |
691 | static bool acceptsAnyInputLayout(const glow::Node *node) { |
692 | switch (node->getKind()) { |
693 | case Kinded::Kind::ConcatNodeKind: |
694 | case Kinded::Kind::BatchedReduceMeanNodeKind: |
695 | case Kinded::Kind::BatchedAddNodeKind: |
696 | case Kinded::Kind::BatchedReduceAddNodeKind: |
697 | case Kinded::Kind::BatchedReduceMinNodeKind: |
698 | case Kinded::Kind::BatchedReduceMaxNodeKind: |
699 | case Kinded::Kind::BatchNormalizationNodeKind: |
700 | case Kinded::Kind::InstanceNormalizationNodeKind: |
701 | case Kinded::Kind::BatchNormalizationGradNodeKind: |
702 | case Kinded::Kind::PadNodeKind: |
703 | case Kinded::Kind::ReshapeNodeKind: |
704 | case Kinded::Kind::MeanVarNormalizationNodeKind: |
705 | case Kinded::Kind::MatMulNodeKind: |
706 | case Kinded::Kind::FlipNodeKind: |
707 | case Kinded::Kind::SliceNodeKind: |
708 | case Kinded::Kind::TileNodeKind: |
709 | case Kinded::Kind::InsertTensorNodeKind: |
710 | case Kinded::Kind::SGDNodeKind: |
711 | case Kinded::Kind::BroadcastNodeKind: |
712 | case Kinded::Kind::GaussianFillNodeKind: |
713 | case Kinded::Kind::SpaceToDepthNodeKind: |
714 | case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: |
715 | return true; |
716 | default: |
717 | return false; |
718 | } |
719 | } |
720 | |
721 | bool CanonicalTensorLayout::acceptsAnyLayout(const Node *node) const { |
722 | if (node->isDataParallel()) { |
723 | return true; |
724 | } |
725 | // In the canonical representation, some nodes are input layout agnostic even |
726 | // if they are not necessarily data parallel: |
727 | return acceptsAnyInputLayout(node); |
728 | } |
729 | |