1 | /** |
2 | * Copyright (c) 2017-present, Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_GRAPH_TENSORLAYOUT_H |
17 | #define GLOW_GRAPH_TENSORLAYOUT_H |
18 | |
19 | #include <memory> |
20 | #include <string> |
21 | |
22 | #include "glow/Graph/Nodes.h" |
23 | #include "glow/Support/Error.h" |
24 | |
25 | namespace glow { |
26 | |
/// Singleton base for layout-requirement classes. A class \p T that derives
/// from TensorLayoutSingleton<T> and offers a public constructor taking a
/// token_ gets exactly one shared instance, reachable via getInstance().
template <typename T> class TensorLayoutSingleton {
public:
  /// \returns the single shared instance of \p T. This is the entry point
  /// through which the verifier, Backend and post-loading canonicalizer reach
  /// the layout constraints.
  static T &getInstance() {
    // Function-local static: T is constructed on the first call only and the
    // very same object is handed out on every later call.
    static const std::unique_ptr<T> holder{new T{token_{}}};
    return *holder;
  }

protected:
  /// Tag type passed to T's constructor by getInstance(). Being protected, it
  /// can only be named by this base class and by its subclasses.
  struct token_ {};

  /// Only subclasses may construct the base.
  TensorLayoutSingleton() = default;

  /// Only subclasses may destroy the base.
  virtual ~TensorLayoutSingleton() = default;

private:
  /// Copying is disabled.
  TensorLayoutSingleton(const TensorLayoutSingleton &) = delete;
  TensorLayoutSingleton &operator=(const TensorLayoutSingleton &) = delete;

  /// Moving is disabled.
  TensorLayoutSingleton(TensorLayoutSingleton &&) = delete;
  TensorLayoutSingleton &operator=(TensorLayoutSingleton &&) = delete;
};
61 | |
62 | /// TensorLayoutDescription - optional helper class for parsing string-based |
63 | /// layout. |
64 | class TensorLayoutDescription { |
65 | /// Tensor dimensions descriptions for all dimensions. |
66 | std::string dims_[max_tensor_dimensions]; |
67 | /// The serialization of the layout. |
68 | std::string serializedLayout_; |
69 | /// Expected number of dimensions. |
70 | size_t numDims_; |
71 | |
72 | public: |
73 | virtual ~TensorLayoutDescription() = default; |
74 | /// Constructs this helper class from a serialized string representation. |
75 | TensorLayoutDescription(const std::string &layoutStr); |
76 | /// Constructs this helper class from an array of strings representing each |
77 | /// individual / pre-separated dimension. |
78 | TensorLayoutDescription(llvm::ArrayRef<std::string> dims); |
79 | /// \returns the alignment of a dimension \p n. |
80 | size_t getAlignment(size_t n) const; |
81 | /// \returns the alignment by parsing dimension string \p s. |
82 | size_t getAlignment(const std::string &s) const; |
83 | /// sets the alignment of dimension \p n to the value \p align. \returns the |
84 | /// new layout serialization for the current dimension. |
85 | llvm::StringRef setAlignment(size_t n, size_t align); |
86 | /// \returns the value of the attribute \p name of a dimension \p n. |
87 | std::string getAttribute(size_t n, llvm::StringRef name) const; |
88 | /// sets the value of attribute \p name to the value \p value. \returns the |
89 | /// new layout serialization for the current dimension. |
90 | llvm::StringRef setAttribute(size_t n, llvm::StringRef name, |
91 | llvm::StringRef value); |
92 | /// \returns true if both tensor layouts are the same. |
93 | bool isSameLayout(const TensorLayoutDescription &rhs) const; |
94 | /// \returns description of the dimension \p n. |
95 | const llvm::StringRef getNthDimDescription(size_t n) const; |
96 | /// \returns the description of all dimensions. |
97 | llvm::ArrayRef<std::string> getDims() const; |
98 | /// \returns number of dimensions. |
99 | size_t getNumDims() const { return numDims_; } |
100 | /// \returns layout name. |
101 | std::string getSerializedLayout() const { return serializedLayout_; } |
102 | /// \returns true if the layout is "*" in all dimensions. |
103 | bool isAnyLayout(); |
104 | std::string getDebugDesc() const; |
105 | |
106 | protected: |
107 | /// parse helper: get the custom extensions information. the default, virtual, |
108 | /// implementation just ignores all the data until the end token. |
109 | virtual void parseCustomExtensions(llvm::StringRef &text, unsigned idx); |
110 | |
111 | private: |
112 | /// Constructor helper: Parses the serialized string. |
113 | void parse(llvm::StringRef text); |
114 | |
115 | /// parse helper: get the official extensions information. |
116 | void parseOfficialExtensions(llvm::StringRef &text, unsigned idx); |
117 | |
118 | /// Modifies \p dimStr to remove an extension starting with the prefix \p |
119 | /// name. |
120 | void removeAttribute(const std::string &name, std::string &dimStr); |
121 | |
122 | /// Rebuilds serializedLayout_ from scratch. |
123 | void reconstructSerialized(); |
124 | }; |
125 | |
126 | /// A type to map layout names to layout descriptions. |
127 | using LayoutNameToLayoutDescriptionTy = |
128 | std::unordered_map<std::string, std::unique_ptr<TensorLayoutDescription>>; |
129 | |
130 | /// Interface for finding out layout requirements. |
131 | class TensorLayoutCommon { |
132 | public: |
133 | /// \return the default n-D layout for Glow. |
134 | virtual std::string getDefaultNDLayout(unsigned dims) const; |
135 | |
136 | /// \returns layout requirements of the Nth input \p n of a Node \p node. |
137 | virtual std::string getNthInputLayoutRequirements(const Node *node, size_t n); |
138 | |
139 | /// \returns layout requirements of the Nth result \p n of a Node \p node. |
140 | virtual std::string getNthResultLayoutRequirements(const Node *node, |
141 | size_t n); |
142 | |
143 | /// \returns layout requirements of the Nth input \p n of a Node \p node. |
144 | /// Delegates to \p getNthInputLayoutRequirementsImpl from \p ctxTensorLayout_ |
145 | /// if it is non-nullptr, or to \p getNthInputLayoutRequirements from the |
146 | /// current TensorLayoutCommon otherwise. |
147 | virtual std::string getNthInputLayoutRequirementsImpl(const Node *node, |
148 | size_t n); |
149 | |
150 | /// \returns layout requirements of the Nth result \p n of a Node \p node. |
151 | /// Delegates to \p getNthResultLayoutRequirementsImpl from \p |
152 | /// ctxTensorLayout_ if it is non-nullptr, or to \p |
153 | /// getNthResultLayoutRequirements from the current TensorLayoutCommon |
154 | /// otherwise. |
155 | virtual std::string getNthResultLayoutRequirementsImpl(const Node *node, |
156 | size_t n); |
157 | |
158 | /// \returns true if type \p ty satisfies the \p destLayout layout. If \p |
159 | /// srcLayout is provided, it is taken into account as well. |
160 | virtual bool isSatisfiedBy(TypeRef ty, |
161 | const TensorLayoutDescription &destLayout, |
162 | const TensorLayoutDescription *srcLayout) const; |
163 | |
164 | /// \return layouts for all tensor dimensions. |
165 | virtual llvm::ArrayRef<TensorLayoutDescription> getLayoutsForDims() const; |
166 | |
167 | /// \returns mapping from layout names to layout descriptions. |
168 | virtual LayoutNameToLayoutDescriptionTy & |
169 | getLayoutNameToLayoutDescription() const; |
170 | |
171 | /// \returns true if layout equirement verification is enabled. |
172 | bool isEnabled() const { return enabled_; } |
173 | |
174 | protected: |
175 | TensorLayoutCommon(); |
176 | TensorLayoutCommon(TensorLayoutCommon *ctxTensorLayout); |
177 | TensorLayoutCommon(TensorLayoutCommon &&) = delete; |
178 | TensorLayoutCommon &operator=(const TensorLayoutCommon &) = delete; |
179 | TensorLayoutCommon &operator=(TensorLayoutCommon &&) = delete; |
180 | virtual ~TensorLayoutCommon(); |
181 | |
182 | protected: |
183 | bool enabled_; |
184 | |
185 | private: |
186 | /// TensorLayout to be used for recursive calls. |
187 | TensorLayoutCommon *ctxTensorLayout_{nullptr}; |
188 | /// Mapping from layout names to layout descriptions. |
189 | static LayoutNameToLayoutDescriptionTy layoutNameToLayoutDescription_; |
190 | }; |
191 | |
192 | class CanonicalTensorLayout final |
193 | : public TensorLayoutCommon, |
194 | public TensorLayoutSingleton<CanonicalTensorLayout> { |
195 | public: |
196 | CanonicalTensorLayout(token_) {} |
197 | CanonicalTensorLayout(TensorLayoutCommon *ctxTensorLayout) |
198 | : TensorLayoutCommon(ctxTensorLayout) {} |
199 | |
200 | /// \return the default n-D layout for Glow. |
201 | std::string getDefaultNDLayout(unsigned dims) const override; |
202 | |
203 | /// \returns layout requirements of the Nth input \p n of a Node \p node. |
204 | /// NOTE: Certain nodes are layout agnostic. Others expect their |
205 | /// inputs/outputs to have a canonical format. For some layout agnostic nodes |
206 | /// we need to look at the layout of their inputs to determine the layout of |
207 | /// their outputs, e.g. a batch norm. node, in the canonical representation, |
208 | /// accepts any input layout such as NCHW or NHWC, but, the output is a |
209 | /// propoagation of said layout. |
210 | std::string getNthInputLayoutRequirements(const Node *node, |
211 | size_t n) override; |
212 | |
213 | /// \returns layout requirements of the Nth result \p n of a Node \p node. |
214 | std::string getNthResultLayoutRequirements(const Node *node, |
215 | size_t n) override; |
216 | |
217 | /// \returns true of the node accepts any layout. |
218 | bool acceptsAnyLayout(const Node *node) const; |
219 | }; |
220 | |
221 | /// Checks if two layout descriptions \p lhs and \p rhs describe the same layout |
222 | /// for a value of the type \p ty \returns true if layouts are the same. if \p |
223 | /// verbose then print out verbose report. |
224 | bool checkSameLayout(llvm::StringRef srcLayoutStr, |
225 | llvm::StringRef destLayoutStr, TypeRef ty, |
226 | const Node *parent, const std::string &prefix, |
227 | const TensorLayoutCommon &TLC, bool verbose = true); |
228 | |
229 | /// Verifies the correctness of tensor layouts in the function \p F using layout |
230 | /// requirements interface \p TLC. if \p verbose then print out verbose report. |
231 | bool verifyLayouts(const Function &F, TensorLayoutCommon &TLC, |
232 | bool verbose = true); |
233 | |
234 | } // end namespace glow |
235 | |
236 | #endif // GLOW_GRAPH_TENSORLAYOUT_H |
237 | |