/**
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ctype.h>
#include <memory>
#include <sstream>

#include <glog/logging.h>

#include "glow/Graph/Graph.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/Graph/VerifierHelper.h"

using namespace glow;

/// Checks if two layout descriptions \p srcLayoutStr and \p destLayoutStr
/// describe the same layout for a value of type \p ty. \returns true if the
/// layouts are the same.
bool glow::checkSameLayout(llvm::StringRef srcLayoutStr,
                           llvm::StringRef destLayoutStr, TypeRef ty,
                           const Node *parent, const std::string &prefix,
                           const TensorLayoutCommon &TLC, bool verbose) {
  auto srcLayout = TensorLayoutDescription(srcLayoutStr.str());
  auto destLayout = TensorLayoutDescription(destLayoutStr.str());
  // Are layouts literally the same?
  if (srcLayout.isSameLayout(destLayout)) {
    return true;
  }
  // Does the type satisfy the dest layout?
  if (TLC.isSatisfiedBy(ty, destLayout, &srcLayout)) {
    return true;
  }
  if (verbose) {
    report("\n\n\n");
    reportContext(parent);
    report("\n");
    report(prefix);
    report("\n");
    report(parent->getDebugDesc());
    report("\nMismatching layouts:\n");
    report("Provided layout\n");
    report(srcLayout.getDebugDesc());
    report("\n");
    report("Expected layout\n");
    report(destLayout.getDebugDesc());
    report("\n");
  }
  return false;
}

/// Verifies the correctness of tensor layouts in the function \p F using the
/// layout requirements interface \p TLC.
bool glow::verifyLayouts(const Function &F, TensorLayoutCommon &TLC,
                         bool verbose) {
  bool isValid = true;
  for (const auto &N : F.getNodes()) {
    for (unsigned idx = 0, e = N.getNumInputs(); idx < e; ++idx) {
      auto input = N.getNthInput(idx);
      auto producerLayout =
          TLC.getNthResultLayoutRequirements(input.getNode(), input.getResNo());
      auto consumerLayout = TLC.getNthInputLayoutRequirements(&N, idx);
      std::string inputName = strFormat("input %u", idx);
      isValid &= checkSameLayout(producerLayout, consumerLayout,
                                 input.getType(), &N, inputName, TLC, verbose);
    }
  }
  return isValid;
}
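
// Illustrative usage (a sketch, not part of this file): given a Function *F
// built elsewhere, a backend might check all layouts under the canonical
// layout rules and print diagnostics on mismatch:
//   bool ok = verifyLayouts(*F, CanonicalTensorLayout::getInstance(),
//                           /* verbose */ true);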

TensorLayoutDescription::TensorLayoutDescription(const std::string &layoutStr) {
  if (layoutStr.empty()) {
    // 0-D output.
    numDims_ = 0;
    return;
  }
  parse(layoutStr);
}

static bool isCustomExtension(llvm::StringRef text) {
  auto nsPos = text.find(':');
  if (nsPos == llvm::StringRef::npos) {
    return false;
  }
  auto bracketPos = text.find(']');
  assert(bracketPos != llvm::StringRef::npos && "Expected a closing bracket.");
  return (bracketPos > nsPos);
}

// Serialization format:
// The form for each dimension is as follows:
// 1. (mandatory) one char representing the current dimension. Either an
// alphabetic letter or '*'.
// 2. (optional) token for the start of optional dimension information: '['
// 3. (optional, requires 2.) namespace of the extension followed by ':'.
// Must be provided for non-official backends. Example: ocl:<information>
// 4. (optional, requires 2.) end of the current extension: ']'
// 5. (optional) go back to 2.
// NOTE: To add alignment information, the format is: a=<size_t>
// Example: N[a=32][namespace_for_unsupported:<bla>]HWC represents a 4-D
// tensor in which N needs an alignment of 32 plus some closed-backend
// requirements we don't know about, while H, W, and C have no restrictions.
// NOTES:
// 1. For each dimension, the identifier can be either a single English
// alphabet letter, upper or lower case, or the star symbol.
// 2. We assume that a single letter is enough for each dimension; it makes
// parsing easier and avoids adding delimiters to the serialized format.
// However, we do have a constructor that (theoretically) accepts multi-letter
// dimensions. If we decide to expand the current support, we will need to add
// delimiters to the serialized form.
void TensorLayoutDescription::parse(llvm::StringRef text) {
  unsigned idx = 0;
  while (!text.empty()) {
    char curr = text.front();
    text = text.drop_front();
    if (curr == '\0' || isblank(curr)) {
      continue;
    }
    switch (curr) {
    case '[': {
      assert(idx > 0 && "Expected at least one parsed entry.");
      if (isCustomExtension(text)) {
        parseCustomExtensions(text, idx - 1);
      } else {
        parseOfficialExtensions(text, idx - 1);
      }
      break;
    }
    default: {
      DCHECK(isalpha(curr) || curr == '*')
          << "Expected an alphabetic letter or '*', got: " << curr
          << " in string: " << text.str();
      std::string currStr(1, curr);
      dims_[idx].append(currStr);
      serializedLayout_.append(dims_[idx]);
      ++idx;
      assert(idx <= max_tensor_dimensions && "Too many tensor dimensions");
      break;
    }
    }
  }
  numDims_ = idx;
}
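
// Illustrative example (not exercised here): parsing "N[a=32]HWC" produces a
// 4-D description whose first dimension is "N[a=32]" (name "N", alignment 32)
// and whose remaining dimensions are "H", "W", and "C" with the default
// alignment of 1.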

void TensorLayoutDescription::parseCustomExtensions(llvm::StringRef &text,
                                                    unsigned idx) {
  char curr = '[';
  dims_[idx].append("[");
  // Copy everything up to (but not including) the closing bracket; check for
  // emptiness before peeking at the front character.
  while (!text.empty() && (curr = text.front()) != ']') {
    dims_[idx].append(std::string(1, curr));
    text = text.drop_front();
  }
  assert(curr == ']' && "Expected closing ']' bracket.");
  text = text.drop_front();
  dims_[idx].append("]");
}

void TensorLayoutDescription::parseOfficialExtensions(llvm::StringRef &text,
                                                      unsigned idx) {
  // Only alignment so far - very simple parser:
  if (!text.consume_front("a=")) {
    llvm_unreachable("Unsupported layout extension.");
  }
  size_t align;
  if (text.consumeInteger(10, align)) {
    llvm_unreachable("Expected alignment info.");
  }
  if (!text.consume_front("]")) {
    llvm_unreachable("Expected closing ']'");
  }
  dims_[idx].append("[a=");
  dims_[idx].append(std::to_string(align));
  dims_[idx].append("]");
}

TensorLayoutDescription::TensorLayoutDescription(
    llvm::ArrayRef<std::string> dims) {
  assert(dims.size() <= max_tensor_dimensions && "Too many tensor dimensions");
  numDims_ = dims.size();
  for (unsigned idx = 0; idx < numDims_; ++idx) {
    dims_[idx] = dims[idx];
    serializedLayout_.append(dims_[idx]);
  }
}

const llvm::StringRef
TensorLayoutDescription::getNthDimDescription(size_t n) const {
  assert(n < numDims_ && "Wrong dimension number");
  return dims_[n];
}

size_t TensorLayoutDescription::getAlignment(size_t n) const {
  assert(n < numDims_ && "Wrong dimension number");
  return getAlignment(dims_[n]);
}

size_t TensorLayoutDescription::getAlignment(const std::string &s) const {
  std::string alignPrefix = "a=";
  size_t pos = s.find(alignPrefix);
  if (pos == std::string::npos) {
    // Default alignment:
    return 1;
  }
  auto align = s.substr(pos + alignPrefix.size());
  size_t ret;
  std::istringstream(align) >> ret;
  return ret;
}

/// \returns the position of the closing ']' for the extension whose attribute
/// starts at \p pos; the character at \p pos - 1 must be '['.
static size_t getEndOfExtension(llvm::StringRef dimStr, size_t pos) {
  size_t posEnd = pos;
  pos = pos - 1;
  assert(dimStr[pos] == '[' && "Expected start of align extension.");
  while (dimStr[posEnd] != ']') {
    ++posEnd;
    assert(posEnd < dimStr.size() && "Expected to find closing bracket.");
  }
  return posEnd;
}

void TensorLayoutDescription::removeAttribute(const std::string &name,
                                              std::string &dimStr) {
  size_t pos = dimStr.find(name);
  if (pos != std::string::npos) {
    size_t posEnd = getEndOfExtension(dimStr, pos);
    // Remove the whole "[<name><value>]" extension.
    dimStr = dimStr.substr(0, pos - 1) + dimStr.substr(posEnd + 1);
  }
}

void TensorLayoutDescription::reconstructSerialized() {
  serializedLayout_ = "";
  for (size_t i = 0; i < numDims_; ++i) {
    serializedLayout_.append(dims_[i]);
  }
}

llvm::StringRef TensorLayoutDescription::setAlignment(size_t n, size_t align) {
  assert(n < numDims_ && "Wrong dimension number");
  return setAttribute(n, "a=", std::to_string(align));
}

llvm::StringRef TensorLayoutDescription::setAttribute(size_t n,
                                                      llvm::StringRef name,
                                                      llvm::StringRef value) {
  assert(n < numDims_ && "Wrong dimension number");
  auto &dimStr = dims_[n];
  // If the attribute is already present - remove it.
  removeAttribute(name.str(), dimStr);
  // Append the new attribute to the dimension:
  dimStr.append("[");
  dimStr.append(name.str());
  dimStr.append(value.str());
  dimStr.append("]");
  reconstructSerialized();
  return dimStr;
}
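
// Illustrative usage (not exercised here): on a description parsed from
// "NHWC", setAlignment(0, 64) rewrites dimension 0 to "N[a=64]" and the
// serialized layout to "N[a=64]HWC".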

std::string TensorLayoutDescription::getAttribute(size_t n,
                                                  llvm::StringRef name) const {
  assert(n < numDims_ && "Wrong dimension number");
  size_t pos = dims_[n].find(name.str());
  if (pos == std::string::npos) {
    return "";
  }
  size_t posEnd = getEndOfExtension(dims_[n], pos);
  auto nameSZ = name.size();
  // Return only the attribute's value, between the name and the closing ']'.
  return dims_[n].substr(pos + nameSZ, posEnd - nameSZ - pos);
}

llvm::ArrayRef<std::string> TensorLayoutDescription::getDims() const {
  return llvm::makeArrayRef(dims_, numDims_);
}

std::string TensorLayoutDescription::getDebugDesc() const {
  std::string desc = "Layout: " + getSerializedLayout() + " [";
  for (unsigned idx = 0; idx < numDims_; idx++) {
    if (idx > 0) {
      desc += ", ";
    }
    desc += "name = ";
    desc += dims_[idx];
    desc += " : alignment = ";
    desc += std::to_string(getAlignment(idx));
    desc += " : index = ";
    desc += std::to_string(idx);
  }
  desc += "]";
  return desc;
}
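
// Illustrative output for the layout "NHWC" (derived from the code above):
//   Layout: NHWC [name = N : alignment = 1 : index = 0, name = H : alignment
//   = 1 : index = 1, name = W : alignment = 1 : index = 2, name = C :
//   alignment = 1 : index = 3]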

bool TensorLayoutDescription::isSameLayout(
    const TensorLayoutDescription &rhs) const {
  if (numDims_ != rhs.numDims_) {
    return false;
  }
  if (serializedLayout_ != rhs.serializedLayout_) {
    return false;
  }
  return true;
}

/// \returns true if \p layout consists solely of '*' characters.
static bool isAnyHelper(llvm::StringRef layout) {
  for (unsigned idx = 0, e = layout.size(); idx < e; ++idx) {
    if (layout[idx] != '*') {
      return false;
    }
  }
  return true;
}

bool TensorLayoutDescription::isAnyLayout() {
  return (isAnyHelper(getSerializedLayout()));
}

/// Definitions of different tensor layouts.
static std::string dimsNHWC[] = {
    {"N"},
    {"H"},
    {"W"},
    {"C"},
};
static std::string dimsNCHW[] = {
    {"N"},
    {"C"},
    {"H"},
    {"W"},
};
static std::string dimsHWNC[] = {
    {"H"},
    {"W"},
    {"N"},
    {"C"},
};
static std::string dimsCNHW[] = {
    {"C"},
    {"N"},
    {"H"},
    {"W"},
};
static std::string dims0D[] = {
    {""},
};
static std::string dims1D[] = {
    {"N"},
};
static std::string dims2D[] = {
    {"*"},
    {"*"},
};
static std::string dims3D[] = {
    {"*"},
    {"*"},
    {"*"},
};
static std::string dims4D[] = {
    {"*"},
    {"*"},
    {"*"},
    {"*"},
};
static std::string dims5D[] = {
    {"*"}, {"*"}, {"*"}, {"*"}, {"*"},
};
static std::string dims6D[] = {
    {"*"}, {"*"}, {"*"}, {"*"}, {"*"}, {"*"},
};

static TensorLayoutDescription layoutNHWC(dimsNHWC);
static TensorLayoutDescription layoutNCHW(dimsNCHW);
static TensorLayoutDescription layoutHWNC(dimsHWNC);
static TensorLayoutDescription layoutCNHW(dimsCNHW);
static TensorLayoutDescription layout0D(dims0D);
static TensorLayoutDescription layout1D(dims1D);
static TensorLayoutDescription layout2D(dims2D);
static TensorLayoutDescription layout3D(dims3D);
static TensorLayoutDescription layout4D(dims4D);
static TensorLayoutDescription layout5D(dims5D);
static TensorLayoutDescription layout6D(dims6D);

/// Glow layouts for any specific number of dimensions.
static TensorLayoutDescription layoutsForDims[] = {
    layout0D, layout1D, layout2D, layout3D, layout4D, layout5D, layout6D,
};
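
// Note: layoutsForDims is indexed by tensor rank: layoutsForDims[k] is the
// default layout for a k-D tensor, e.g., layoutsForDims[4] serializes to
// "****".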

TensorLayoutCommon::TensorLayoutCommon() : enabled_(false) {}

TensorLayoutCommon::TensorLayoutCommon(TensorLayoutCommon *ctxTensorLayout)
    : TensorLayoutCommon() {
  ctxTensorLayout_ = ctxTensorLayout;
}

TensorLayoutCommon::~TensorLayoutCommon() {}

LayoutNameToLayoutDescriptionTy &
TensorLayoutCommon::getLayoutNameToLayoutDescription() const {
  if (ctxTensorLayout_) {
    return ctxTensorLayout_->getLayoutNameToLayoutDescription();
  }
  return layoutNameToLayoutDescription_;
}

llvm::ArrayRef<TensorLayoutDescription>
TensorLayoutCommon::getLayoutsForDims() const {
  if (ctxTensorLayout_) {
    return ctxTensorLayout_->getLayoutsForDims();
  }
  return llvm::makeArrayRef(layoutsForDims);
}

static LayoutNameToLayoutDescriptionTy initLayoutNameToDescription() {
  LayoutNameToLayoutDescriptionTy map;
  map.insert(std::make_pair(
      "NCHW", glow::make_unique<TensorLayoutDescription>("NCHW")));
  map.insert(std::make_pair(
      "NHWC", glow::make_unique<TensorLayoutDescription>("NHWC")));
  map.insert(std::make_pair(
      "HWNC", glow::make_unique<TensorLayoutDescription>("HWNC")));
  map.insert(std::make_pair(
      "CNHW", glow::make_unique<TensorLayoutDescription>("CNHW")));
  map.insert(
      std::make_pair("N", glow::make_unique<TensorLayoutDescription>("N")));
  return map;
}

LayoutNameToLayoutDescriptionTy
    TensorLayoutCommon::layoutNameToLayoutDescription_ =
        initLayoutNameToDescription();

/// \returns the cached layout description for \p name, creating and caching a
/// new one if needed, or nullptr if \p name denotes an empty / "any" layout.
static TensorLayoutDescription *getLayoutFromName(
    const std::string &name,
    LayoutNameToLayoutDescriptionTy &layoutNameToLayoutDescription) {
  if (isAnyHelper(name)) {
    return nullptr;
  }
  auto it = layoutNameToLayoutDescription.find(name);
  if (it != layoutNameToLayoutDescription.end()) {
    return it->second.get();
  }
  // Add the new layout to the map:
  auto *ret = new TensorLayoutDescription(name);
  if (ret->getNumDims() == 0) {
    // Empty / any layout.
    delete ret;
    ret = nullptr;
  }
  layoutNameToLayoutDescription.insert(
      std::make_pair(name, std::unique_ptr<TensorLayoutDescription>(ret)));
  return ret;
}

std::string TensorLayoutCommon::getDefaultNDLayout(unsigned dims) const {
  DCHECK_LE(dims, max_tensor_dimensions) << "Too many dimensions";
  return getLayoutsForDims()[dims].getSerializedLayout();
}

std::string
TensorLayoutCommon::getNthInputLayoutRequirementsImpl(const Node *node,
                                                      size_t n) {
  if (ctxTensorLayout_) {
    return ctxTensorLayout_->getNthInputLayoutRequirementsImpl(node, n);
  }
  return getNthInputLayoutRequirements(node, n);
}

std::string TensorLayoutCommon::getNthInputLayoutRequirements(const Node *node,
                                                              size_t n) {
  DCHECK_LT(n, node->getNumInputs()) << "Wrong input number";
  auto dims = node->getNthInput(n).getType()->dims();
  DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions";
  if (const auto *TN = llvm::dyn_cast<TransposeNode>(node)) {
    // The layout requirement for the input of a transpose is whatever layout
    // the producer of that input provides.
    auto input = TN->getInput();
    return getNthResultLayoutRequirementsImpl(input.getNode(),
                                              input.getResNo());
  }
  if (const auto *QN = llvm::dyn_cast<QuantizeNode>(node)) {
    // Quantize is layout-preserving; propagate the producer's layout.
    auto input = QN->getInput();
    return getNthResultLayoutRequirementsImpl(input.getNode(),
                                              input.getResNo());
  }
  if (const auto *CTN = llvm::dyn_cast<ConvertToNode>(node)) {
    // ConvertTo is layout-preserving; propagate the producer's layout.
    auto input = CTN->getInput();
    return getNthResultLayoutRequirementsImpl(input.getNode(),
                                              input.getResNo());
  }
  if (const auto *QPN = llvm::dyn_cast<QuantizationProfileNode>(node)) {
    switch (n) {
    case QuantizationProfileNode::InputIndices::InputIdx: {
      auto input = QPN->getInput();
      return getNthResultLayoutRequirementsImpl(input.getNode(),
                                                input.getResNo());
    }
    default:
      return getLayoutsForDims()[dims.size()].getSerializedLayout();
    }
  }
  return getLayoutsForDims()[dims.size()].getSerializedLayout();
}

/// \returns the index of input \p in for node \p N, or N->getNumInputs() if
/// not found.
static unsigned getInputIdx(const Node *N, NodeValue in) {
  for (unsigned idx = 0, e = N->getNumInputs(); idx < e; ++idx) {
    if (N->getNthInput(idx) == in) {
      return idx;
    }
  }
  return N->getNumInputs();
}

/// \returns true if querying \p node's input layout would recurse back into
/// this logic and cause an infinite loop.
static bool inputDoesNotKnowRequirements(const Node *node) {
  switch (node->getKind()) {
  case Kinded::Kind::TransposeNodeKind:
  case Kinded::Kind::QuantizeNodeKind:
  case Kinded::Kind::QuantizationProfileNodeKind:
  case Kinded::Kind::ConvertToNodeKind:
    return true;
  default:
    return false;
  }
}

std::string
TensorLayoutCommon::getNthResultLayoutRequirementsImpl(const Node *node,
                                                       size_t n) {
  if (ctxTensorLayout_) {
    return ctxTensorLayout_->getNthResultLayoutRequirementsImpl(node, n);
  }
  return getNthResultLayoutRequirements(node, n);
}

std::string TensorLayoutCommon::getNthResultLayoutRequirements(const Node *node,
                                                               size_t n) {
  DCHECK_LT(n, node->getNumResults()) << "Wrong output number";
  auto dims = node->getNthResult(n).getType()->dims();
  DCHECK_LE(dims.size(), max_tensor_dimensions) << "Too many dimensions";
  if (auto *TN = llvm::dyn_cast<TransposeNode>(node)) {
    // If the transpose specifies a concrete result layout, use it.
    if (auto *layout = getLayoutFromName(TN->getLayout(),
                                         getLayoutNameToLayoutDescription())) {
      return layout->getSerializedLayout();
    }
    // Otherwise, dynamically form the layout description by applying the
    // transpose's shuffle to the input's layout.
    auto input = TN->getInput();
    while (inputDoesNotKnowRequirements(input)) {
      input = input.getNode()->getNthInput(0);
    }
    auto inputLayout =
        getNthInputLayoutRequirementsImpl(node, TransposeNode::InputIdx);
    auto inputLayoutHelper = TensorLayoutDescription(inputLayout);
    llvm::SmallVector<std::string, max_tensor_dimensions> resultDims(
        input.dims().size());
    auto shuffle = TN->getShuffle();
    for (unsigned idx = 0, e = inputLayoutHelper.getNumDims(); idx < e; ++idx) {
      resultDims[shuffle[idx]] =
          inputLayoutHelper.getNthDimDescription(idx).str();
    }
    TensorLayoutDescription tld(resultDims);
    return tld.getSerializedLayout();
  }
  if (auto *C = llvm::dyn_cast<Constant>(node)) {
    if (auto *layout = getLayoutFromName(C->getLayout(),
                                         getLayoutNameToLayoutDescription())) {
      return layout->getSerializedLayout();
    }
  }
  if (auto *PH = llvm::dyn_cast<Placeholder>(node)) {
    if (auto *layout = getLayoutFromName(PH->getLayout(),
                                         getLayoutNameToLayoutDescription())) {
      return layout->getSerializedLayout();
    }
  }
  if (auto *RN = llvm::dyn_cast<ReshapeNode>(node)) {
    if (auto *layout = getLayoutFromName(RN->getLayout(),
                                         getLayoutNameToLayoutDescription())) {
      return layout->getSerializedLayout();
    }
    // No explicit layout: derive the requirement from the reshape's first
    // user, falling back to the generic layout if the user cannot tell us.
    auto result = node->getNthResult(n);
    auto *user = (*result.getUsers().begin()).getUser();
    unsigned inputIdx = getInputIdx(user, result);
    if (inputDoesNotKnowRequirements(user) ||
        inputIdx >= user->getNumInputs() || llvm::isa<TransposeNode>(user)) {
      return getLayoutsForDims()[dims.size()].getSerializedLayout();
    }
    auto layout = getNthInputLayoutRequirementsImpl(user, inputIdx);
    if (auto *layoutDesc =
            getLayoutFromName(layout, getLayoutNameToLayoutDescription())) {
      return layoutDesc->getSerializedLayout();
    }
  }
  return getLayoutsForDims()[dims.size()].getSerializedLayout();
}

bool TensorLayoutCommon::isSatisfiedBy(
    TypeRef ty, const TensorLayoutDescription &destLayout,
    const TensorLayoutDescription *srcLayout) const {
  // Strides of the type (in elements).
  auto strides = ty->strides();
  if (strides.size() != destLayout.getNumDims()) {
    return false;
  }
  unsigned idx = 0;
  for (const auto &dim : destLayout.getDims()) {
    // The alignment requirement is in bytes, but strides are in elements, so
    // scale the stride by the element size before checking it.
    if (strides[idx] * ty->getElementSize() % destLayout.getAlignment(dim) !=
        0) {
      return false;
    }
    idx++;
  }
  if (!srcLayout) {
    return true;
  }
  if (destLayout.getNumDims() != srcLayout->getNumDims()) {
    return false;
  }
  // Names should be compatible; '*' is compatible with anything.
  if (srcLayout->getSerializedLayout().size() !=
      destLayout.getSerializedLayout().size()) {
    return false;
  }
  for (unsigned idx = 0, e = destLayout.getSerializedLayout().size(); idx < e;
       ++idx) {
    // '*' is compatible with anything.
    if (destLayout.getSerializedLayout()[idx] == '*' ||
        srcLayout->getSerializedLayout()[idx] == '*') {
      continue;
    }
    // Non-'*' characters are only compatible with themselves.
    if (srcLayout->getSerializedLayout()[idx] ==
        destLayout.getSerializedLayout()[idx]) {
      continue;
    }
    return false;
  }
  return true;
}
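
// Illustrative compatibility examples (derived from the rules above):
// "NHWC" vs. "****" is compatible since '*' matches anything, while "NHWC"
// vs. "NCHW" is not, since non-'*' characters must match exactly.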

static std::string returnBaseReqOrNHWC(std::string baseReq, const Node *node) {
  auto baseReqHelper = TensorLayoutDescription(baseReq);
  if (!baseReqHelper.isSameLayout(
          CanonicalTensorLayout::getInstance().getLayoutsForDims()[4])) {
    return baseReq;
  }
  if (CanonicalTensorLayout::getInstance().acceptsAnyLayout(node)) {
    // This node accepts any 4-D layout.
    return baseReqHelper.getSerializedLayout();
  }
  // NHWC is the canonical default.
  return CanonicalTensorLayout::getInstance().getDefaultNDLayout(4);
}

std::string
CanonicalTensorLayout::getNthInputLayoutRequirements(const Node *node,
                                                     size_t n) {
  auto baseReq = TensorLayoutCommon::getNthInputLayoutRequirements(node, n);
  if (acceptsAnyLayout(node)) {
    return baseReq;
  }
  return returnBaseReqOrNHWC(baseReq, node);
}

std::string
CanonicalTensorLayout::getNthResultLayoutRequirements(const Node *node,
                                                      size_t n) {
  auto baseReq = TensorLayoutCommon::getNthResultLayoutRequirements(node, n);
  return returnBaseReqOrNHWC(baseReq, node);
}

std::string CanonicalTensorLayout::getDefaultNDLayout(unsigned dims) const {
  if (dims == 4) {
    return layoutNHWC.getSerializedLayout();
  }
  return TensorLayoutCommon::getDefaultNDLayout(dims);
}

/// \returns true if \p node is layout agnostic with respect to its inputs in
/// the canonical representation.
static bool acceptsAnyInputLayout(const glow::Node *node) {
  switch (node->getKind()) {
  case Kinded::Kind::ConcatNodeKind:
  case Kinded::Kind::BatchedReduceMeanNodeKind:
  case Kinded::Kind::BatchedAddNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
  case Kinded::Kind::BatchedReduceMinNodeKind:
  case Kinded::Kind::BatchedReduceMaxNodeKind:
  case Kinded::Kind::BatchNormalizationNodeKind:
  case Kinded::Kind::InstanceNormalizationNodeKind:
  case Kinded::Kind::BatchNormalizationGradNodeKind:
  case Kinded::Kind::PadNodeKind:
  case Kinded::Kind::ReshapeNodeKind:
  case Kinded::Kind::MeanVarNormalizationNodeKind:
  case Kinded::Kind::MatMulNodeKind:
  case Kinded::Kind::FlipNodeKind:
  case Kinded::Kind::SliceNodeKind:
  case Kinded::Kind::TileNodeKind:
  case Kinded::Kind::InsertTensorNodeKind:
  case Kinded::Kind::SGDNodeKind:
  case Kinded::Kind::BroadcastNodeKind:
  case Kinded::Kind::GaussianFillNodeKind:
  case Kinded::Kind::SpaceToDepthNodeKind:
  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    return true;
  default:
    return false;
  }
}

bool CanonicalTensorLayout::acceptsAnyLayout(const Node *node) const {
  if (node->isDataParallel()) {
    return true;
  }
  // In the canonical representation, some nodes are input layout agnostic even
  // if they are not necessarily data parallel:
  return acceptsAnyInputLayout(node);
}