/**
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_GRAPH_TENSORLAYOUT_H
#define GLOW_GRAPH_TENSORLAYOUT_H

#include <memory>
#include <string>
#include <unordered_map>

#include "glow/Graph/Nodes.h"
#include "glow/Support/Error.h"

namespace glow {

/// Singleton for layout requirements.
template <typename T> class TensorLayoutSingleton {
public:
  /// This is how the verifier, the backend, and the post-loading canonicalizer
  /// can access layout constraints.
  static T &getInstance() {
    // The Ctor will only be called once.
    static const std::unique_ptr<T> instance{new T{token_{}}};
    return *instance;
  }

protected:
  /// Allow the base class to call any subclass's constructor.
  struct token_ {};

  /// Default Ctor.
  TensorLayoutSingleton() {}

  /// Dtor.
  virtual ~TensorLayoutSingleton() {}

private:
  /// Delete copy constructor.
  TensorLayoutSingleton(const TensorLayoutSingleton &) = delete;

  /// Delete move constructor.
  TensorLayoutSingleton(TensorLayoutSingleton &&) = delete;

  /// Delete copy assignment.
  TensorLayoutSingleton &operator=(const TensorLayoutSingleton &) = delete;

  /// Delete move assignment.
  TensorLayoutSingleton &operator=(TensorLayoutSingleton &&) = delete;
};
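
/// Illustrative usage sketch (not part of the interface): a concrete
/// layout-requirements class derived from this singleton, such as
/// CanonicalTensorLayout declared later in this header, is obtained and
/// queried like so:
/// \code
///   auto &canonical = CanonicalTensorLayout::getInstance();
///   std::string layout4D = canonical.getDefaultNDLayout(4);
/// \endcode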

/// TensorLayoutDescription - optional helper class for parsing string-based
/// layout.
class TensorLayoutDescription {
  /// Tensor dimensions descriptions for all dimensions.
  std::string dims_[max_tensor_dimensions];
  /// The serialization of the layout.
  std::string serializedLayout_;
  /// Expected number of dimensions.
  size_t numDims_;

public:
  virtual ~TensorLayoutDescription() = default;
  /// Constructs this helper class from a serialized string representation.
  TensorLayoutDescription(const std::string &layoutStr);
  /// Constructs this helper class from an array of strings representing each
  /// individual, pre-separated dimension.
  TensorLayoutDescription(llvm::ArrayRef<std::string> dims);
  /// \returns the alignment of a dimension \p n.
  size_t getAlignment(size_t n) const;
  /// \returns the alignment by parsing dimension string \p s.
  size_t getAlignment(const std::string &s) const;
  /// Sets the alignment of dimension \p n to the value \p align. \returns the
  /// new layout serialization for the current dimension.
  llvm::StringRef setAlignment(size_t n, size_t align);
  /// \returns the value of the attribute \p name of a dimension \p n.
  std::string getAttribute(size_t n, llvm::StringRef name) const;
  /// Sets the value of attribute \p name to the value \p value. \returns the
  /// new layout serialization for the current dimension.
  llvm::StringRef setAttribute(size_t n, llvm::StringRef name,
                               llvm::StringRef value);
  /// \returns true if both tensor layouts are the same.
  bool isSameLayout(const TensorLayoutDescription &rhs) const;
  /// \returns description of the dimension \p n.
  const llvm::StringRef getNthDimDescription(size_t n) const;
  /// \returns the description of all dimensions.
  llvm::ArrayRef<std::string> getDims() const;
  /// \returns number of dimensions.
  size_t getNumDims() const { return numDims_; }
  /// \returns the serialized layout name.
  std::string getSerializedLayout() const { return serializedLayout_; }
  /// \returns true if the layout is "*" in all dimensions.
  bool isAnyLayout();
  /// \returns a human-readable description of the layout for debugging.
  std::string getDebugDesc() const;

protected:
  /// Parse helper: gets the custom extensions information. The default,
  /// virtual implementation just ignores all the data until the end token.
  virtual void parseCustomExtensions(llvm::StringRef &text, unsigned idx);

private:
  /// Constructor helper: parses the serialized string.
  void parse(llvm::StringRef text);

  /// Parse helper: gets the official extensions information.
  void parseOfficialExtensions(llvm::StringRef &text, unsigned idx);

  /// Modifies \p dimStr to remove an extension starting with the prefix \p
  /// name.
  void removeAttribute(const std::string &name, std::string &dimStr);

  /// Rebuilds serializedLayout_ from scratch.
  void reconstructSerialized();
};
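
/// A minimal usage sketch (illustrative only; it assumes a plain 4-D image
/// layout string such as "NHWC" with no per-dimension extensions):
/// \code
///   TensorLayoutDescription nhwc("NHWC");
///   size_t numDims = nhwc.getNumDims();              // expected: 4
///   llvm::StringRef firstDim = nhwc.getNthDimDescription(0);
///   bool same = nhwc.isSameLayout(TensorLayoutDescription("NHWC"));
/// \endcode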

/// A type to map layout names to layout descriptions.
using LayoutNameToLayoutDescriptionTy =
    std::unordered_map<std::string, std::unique_ptr<TensorLayoutDescription>>;

/// Interface for finding out layout requirements.
class TensorLayoutCommon {
public:
  /// \returns the default n-D layout for Glow.
  virtual std::string getDefaultNDLayout(unsigned dims) const;

  /// \returns layout requirements of the Nth input \p n of a Node \p node.
  virtual std::string getNthInputLayoutRequirements(const Node *node, size_t n);

  /// \returns layout requirements of the Nth result \p n of a Node \p node.
  virtual std::string getNthResultLayoutRequirements(const Node *node,
                                                     size_t n);

  /// \returns layout requirements of the Nth input \p n of a Node \p node.
  /// Delegates to \p getNthInputLayoutRequirementsImpl from \p
  /// ctxTensorLayout_ if it is non-nullptr, or to \p
  /// getNthInputLayoutRequirements from the current TensorLayoutCommon
  /// otherwise.
  virtual std::string getNthInputLayoutRequirementsImpl(const Node *node,
                                                        size_t n);

  /// \returns layout requirements of the Nth result \p n of a Node \p node.
  /// Delegates to \p getNthResultLayoutRequirementsImpl from \p
  /// ctxTensorLayout_ if it is non-nullptr, or to \p
  /// getNthResultLayoutRequirements from the current TensorLayoutCommon
  /// otherwise.
  virtual std::string getNthResultLayoutRequirementsImpl(const Node *node,
                                                         size_t n);

  /// \returns true if type \p ty satisfies the \p destLayout layout. If \p
  /// srcLayout is provided, it is taken into account as well.
  virtual bool isSatisfiedBy(TypeRef ty,
                             const TensorLayoutDescription &destLayout,
                             const TensorLayoutDescription *srcLayout) const;

  /// \returns layouts for all tensor dimensions.
  virtual llvm::ArrayRef<TensorLayoutDescription> getLayoutsForDims() const;

  /// \returns mapping from layout names to layout descriptions.
  virtual LayoutNameToLayoutDescriptionTy &
  getLayoutNameToLayoutDescription() const;

  /// \returns true if layout requirement verification is enabled.
  bool isEnabled() const { return enabled_; }

protected:
  TensorLayoutCommon();
  TensorLayoutCommon(TensorLayoutCommon *ctxTensorLayout);
  TensorLayoutCommon(TensorLayoutCommon &&) = delete;
  TensorLayoutCommon &operator=(const TensorLayoutCommon &) = delete;
  TensorLayoutCommon &operator=(TensorLayoutCommon &&) = delete;
  virtual ~TensorLayoutCommon();

protected:
  bool enabled_;

private:
  /// TensorLayout to be used for recursive calls.
  TensorLayoutCommon *ctxTensorLayout_{nullptr};
  /// Mapping from layout names to layout descriptions.
  static LayoutNameToLayoutDescriptionTy layoutNameToLayoutDescription_;
};
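
/// A hypothetical backend-specific specialization (sketch only; the class name
/// and the chosen overrides are illustrative) would combine this interface
/// with the singleton above, mirroring CanonicalTensorLayout below:
/// \code
///   class MyBackendTensorLayout final
///       : public TensorLayoutCommon,
///         public TensorLayoutSingleton<MyBackendTensorLayout> {
///   public:
///     MyBackendTensorLayout(token_) {}
///     std::string getNthInputLayoutRequirements(const Node *node,
///                                               size_t n) override;
///     std::string getNthResultLayoutRequirements(const Node *node,
///                                                size_t n) override;
///   };
/// \endcode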

/// Layout requirements for Glow's canonical tensor representation.
class CanonicalTensorLayout final
    : public TensorLayoutCommon,
      public TensorLayoutSingleton<CanonicalTensorLayout> {
public:
  CanonicalTensorLayout(token_) {}
  CanonicalTensorLayout(TensorLayoutCommon *ctxTensorLayout)
      : TensorLayoutCommon(ctxTensorLayout) {}

  /// \returns the default n-D layout for Glow.
  std::string getDefaultNDLayout(unsigned dims) const override;

  /// \returns layout requirements of the Nth input \p n of a Node \p node.
  /// NOTE: Certain nodes are layout agnostic. Others expect their
  /// inputs/outputs to have a canonical format. For some layout-agnostic nodes
  /// we need to look at the layout of their inputs to determine the layout of
  /// their outputs. For example, a batch normalization node in the canonical
  /// representation accepts any input layout, such as NCHW or NHWC, but its
  /// output is a propagation of that layout.
  std::string getNthInputLayoutRequirements(const Node *node,
                                            size_t n) override;

  /// \returns layout requirements of the Nth result \p n of a Node \p node.
  std::string getNthResultLayoutRequirements(const Node *node,
                                             size_t n) override;

  /// \returns true if the node accepts any layout.
  bool acceptsAnyLayout(const Node *node) const;
};
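
/// Usage sketch for querying canonical layout requirements (illustrative only;
/// \p node is assumed to be a node taken from an existing Function):
/// \code
///   const Node *node = ...; // some node in the graph
///   auto &canonical = CanonicalTensorLayout::getInstance();
///   if (!canonical.acceptsAnyLayout(node)) {
///     std::string in0 = canonical.getNthInputLayoutRequirements(node, 0);
///     std::string res0 = canonical.getNthResultLayoutRequirements(node, 0);
///   }
/// \endcode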

/// Checks if the two layout descriptions \p srcLayoutStr and \p destLayoutStr
/// describe the same layout for a value of the type \p ty. \returns true if
/// the layouts are the same. If \p verbose, prints out a verbose report.
bool checkSameLayout(llvm::StringRef srcLayoutStr,
                     llvm::StringRef destLayoutStr, TypeRef ty,
                     const Node *parent, const std::string &prefix,
                     const TensorLayoutCommon &TLC, bool verbose = true);

/// Verifies the correctness of tensor layouts in the function \p F using the
/// layout requirements interface \p TLC. If \p verbose, prints out a verbose
/// report.
bool verifyLayouts(const Function &F, TensorLayoutCommon &TLC,
                   bool verbose = true);
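
/// Typical verification sketch (illustrative only): check a whole function
/// against the canonical layout requirements.
/// \code
///   static bool verifyCanonicalLayouts(const Function &F) {
///     return verifyLayouts(F, CanonicalTensorLayout::getInstance(),
///                          /* verbose */ true);
///   }
/// \endcode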

} // end namespace glow

#endif // GLOW_GRAPH_TENSORLAYOUT_H