/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_GRAPH_NODES_H
#define GLOW_GRAPH_NODES_H

#include "glow/Base/Tensor.h"
#include "glow/Base/Traits.h"
#include "glow/Graph/Grad.h"
#include "glow/Graph/Node.h"

#include "llvm/ADT/Hashing.h"
#include "llvm/Support/Casting.h"

#include <tuple>

namespace glow {

/// Storage is the base class for Constants, which are bound to tensors, and
/// Placeholder nodes, which are unbound.
class Storage : public Node {
public:
  enum ResultIndices {
    OutputIdx = 0,
  };

  Storage(Kinded::Kind k, llvm::StringRef name, const std::string &layout)
      : Node(k, name), layout_(layout) {}

  /// \returns the single output value of the node.
  NodeValue getOutput() { return getNthResult(0); }
  const NodeValue getOutput() const { return getNthResult(0); }

  /// Declare the standard Node methods.
  /// @{
  void visit(Node *parent, NodeWalker *visitor);
  void visit(const Node *parent, NodeWalker *visitor) const;
  bool isEqual(const Storage &other) const;
  unsigned getNumInputs() const;
  std::string getInputName(unsigned idx) const;
  NodeValue getNthInput(unsigned idx);
  llvm::StringRef getOutputName(unsigned idx) const;
  bool hasSideEffects() const;
  bool isCanonical() const { return true; }
  bool isDataParallel() const { return false; }
  Node *clone() const;
  /// @}

  /// \returns the result type of the storage.
  TypeRef getType() const { return Node::getType(0); }

  /// Methods that forward to the result type (which must be valid):
  /// @{
  ElemKind getElementType() const { return getType()->getElementType(); };
  llvm::ArrayRef<dim_t> dims() const { return getType()->dims(); };
  /// @}

  static bool classof(const Kinded *k) {
    return k->getKind() == Kinded::Kind::ConstantKind ||
           k->getKind() == Kinded::Kind::PlaceholderKind;
  }

  /// \returns the layout of the storage.
  const std::string &getLayout() const { return layout_; }

private:
  /// Specifies the Storage's layout.
  const std::string layout_;
};
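
// The classof implementation above enables LLVM-style RTTI over Storage
// nodes. As an illustrative sketch (the variable name `someNode` is
// invented), one would typically discriminate the two subclasses like this:
//
//   if (auto *C = llvm::dyn_cast<Constant>(someNode)) {
//     // someNode is a Constant; its payload tensor is available here.
//   } else if (llvm::isa<Placeholder>(someNode)) {
//     // someNode is a Placeholder; a tensor is bound only at runtime.
//   }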

class Constant : public Storage {
  /// The tensor payload that the constant holds.
  Tensor payload_;

public:
  /// Create a new constant and initialize its payload.
  Constant(llvm::StringRef name, TypeRef Ty, const std::string &layout)
      : Storage(Kinded::Kind::ConstantKind, name, layout) {
    addResult(Ty);
    payload_.reset(*Ty);
  }

  Constant(llvm::StringRef name, Tensor &&payload, const std::string &layout)
      : Storage(Kinded::Kind::ConstantKind, name, layout),
        payload_(std::move(payload)) {
    addResult(&payload_.getType());
  }

  static bool classof(const Kinded *k) {
    return k->getKind() == Kinded::Kind::ConstantKind;
  }

  /// If the payload is unowned, make an owned copy of it so that it can be
  /// modified.
  void ensureIsOwned() {
    if (payload_.isUnowned()) {
      payload_ = payload_.clone();
    }
  }

  /// \returns a mutable reference to the payload tensor. If the payload tensor
  /// is unowned then it will be converted to an owned copy before returning.
  Tensor &getPayloadMutable() {
    // Make sure the payload is owned before handing out a mutable reference.
    ensureIsOwned();

    assert(!payload_.isUnowned() &&
           "Can only modify Constants with owned payloads");
    return payload_;
  }

  /// \returns an immutable reference to the payload tensor.
  const Tensor &getPayload() const { return payload_; }

  template <class ElemTy = float> Handle<ElemTy> getHandle() {
    return getPayload().getHandle<ElemTy>();
  }

  void assign(const Tensor *t) {
    // Make sure the type of the assigned tensor matches the constant's
    // payload (and therefore its output type).
    assert(t->getType().isEqual(payload_.getType()));
    payload_.assign(t);
  }

  void setPayloadType(TypeRef ty) { payload_.setType(ty); }

  bool isDataParallel() const { return false; }

  std::string getDebugDesc(bool skipUsers = false) const;

  llvm::hash_code getHash() const;

  void clearPayload() { payload_.release(); }

  bool verify() const;
};
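
// Illustrative sketch of constructing a Constant (names and shapes are
// hypothetical; assumes the usual glow::Tensor/Handle API from
// Base/Tensor.h):
//
//   Tensor data(ElemKind::FloatTy, {2, 2});
//   data.getHandle<float>().clear(1.0f); // fill the payload with ones
//   Constant weights("weights", std::move(data), "NHWC");
//   // weights.dims() == {2, 2}; weights now owns the moved-in payload.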

/// Placeholder nodes are unbound storage. The content tensors are attached to
/// this node at runtime. Placeholders are used as input and output nodes of
/// the network.
class Placeholder : public Storage {
  /// Specifies if the placeholder is trainable.
  bool isTrainable_;

  /// Specifies if associated Tensors should be zeroed when allocated.
  bool allocZero_{false};

  /// Specifies if this is a static placeholder: it is set once before the
  /// first network run and is reused by subsequent runs.
  bool isStatic_{false};

public:
  /// Create a new placeholder.
  Placeholder(llvm::StringRef name, TypeRef Ty, bool isTrainable,
              const std::string &layout)
      : Storage(Kinded::Kind::PlaceholderKind, name, layout),
        isTrainable_(isTrainable) {
    addResult(Ty);
  }

  /// \returns True if the placeholder is trainable during differentiation.
  bool isTraining() const { return isTrainable_; }

  /// \returns True if associated Tensors should be zeroed when allocated.
  bool allocZero() const { return allocZero_; }

  /// Update the isStatic_ field.
  void setStatic(bool isStatic) { isStatic_ = isStatic; }

  /// Get the status of the isStatic_ flag.
  bool isStatic() const { return isStatic_; }

  /// Sets whether or not associated Tensors should be zeroed.
  void setAllocZero(bool on = true) { allocZero_ = on; }

  static bool classof(const Kinded *k) {
    return k->getKind() == Kinded::Kind::PlaceholderKind;
  }

  bool isDataParallel() const { return false; }

  std::string getDebugDesc(bool skipUsers = false) const;

  llvm::hash_code getHash() const;
};
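
// Illustrative sketch of a Placeholder (names are hypothetical, and
// `someFloatTy` stands in for a TypeRef; in practice placeholders are
// usually created via the owning Module so that types are uniqued):
//
//   Placeholder input("input", someFloatTy, /* isTrainable */ false, "NHWC");
//   input.setStatic(true); // set once before the first run, then reused
//   input.setAllocZero();  // zero-fill backing tensors when allocated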

/// Calculate the size of the output tensor based on the convolution/pooling
/// parameters, where \p sx and \p sy are the input height and width.
inline std::pair<dim_t, dim_t> calculateConvPoolOutputDims(
    size_t sx, size_t sy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    llvm::ArrayRef<unsigned_t> dilation = {1, 1}) {
  PaddingTLBR pdim(pads);
  ShapeHW kdim(kernels);
  ShapeHW sdim(strides);
  size_t outsx = ((sx + pdim.top + pdim.bottom - kdim.height -
                   (kdim.height - 1) * (dilation[0] - 1)) /
                      sdim.height +
                  1);
  size_t outsy = ((sy + pdim.left + pdim.right - kdim.width -
                   (kdim.width - 1) * (dilation[1] - 1)) /
                      sdim.width +
                  1);
  return {outsx, outsy};
}
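
// Worked example (illustrative numbers): a 5x5 input with a 3x3 kernel,
// stride 1, zero padding, and dilation {1, 1} gives
//   outsx = (5 + 0 + 0 - 3 - (3 - 1) * (1 - 1)) / 1 + 1 = 3,
// i.e. a 3x3 output, matching the standard convolution output arithmetic.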

/// Calculate the size of the output tensor based on the 3D convolution/pooling
/// parameters \p inT, \p inH, and \p inW, which are the input's temporal
/// depth, height, and width respectively.
inline ShapeTHW calculate3DConvPoolOutputDims(
    size_t inT, size_t inH, size_t inW, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads) {
  PaddingNFTBLR pdim(pads);
  ShapeTHW kdim(kernels);
  ShapeTHW sdim(strides);

  size_t outT = ((inT + pdim.near + pdim.far - kdim.temporal_frames) /
                     sdim.temporal_frames +
                 1);
  size_t outH =
      ((inH + pdim.top + pdim.bottom - kdim.height) / sdim.height + 1);
  size_t outW = ((inW + pdim.left + pdim.right - kdim.width) / sdim.width + 1);

  llvm::SmallVector<size_t, 3> outDims{outT, outH, outW};
  return ShapeTHW(llvm::makeArrayRef(outDims));
}
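
// Worked example (illustrative numbers): a 4-frame 5x5 input with a 2x3x3
// kernel, strides {2, 1, 1}, and zero padding gives
//   outT = (4 + 0 + 0 - 2) / 2 + 1 = 2,
//   outH = outW = (5 + 0 + 0 - 3) / 1 + 1 = 3,
// i.e. a 2x3x3 output volume.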

/// Calculate the size of the output tensor based on the ConvTranspose
/// parameters, where \p sx and \p sy are the input height and width.
inline std::pair<dim_t, dim_t> calculateConvTransposeOutputDims(
    size_t sx, size_t sy, llvm::ArrayRef<unsigned_t> kernels,
    llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
    llvm::ArrayRef<unsigned_t> dilation = {1, 1}) {
  PaddingTLBR pdim(pads);
  ShapeHW kdim(kernels);
  ShapeHW sdim(strides);

  size_t outsx = (sx - 1) * sdim.height + (kdim.height - 1) * dilation[0] + 1 -
                 pdim.top - pdim.bottom;
  size_t outsy = (sy - 1) * sdim.width + (kdim.width - 1) * dilation[1] + 1 -
                 pdim.left - pdim.right;

  return {outsx, outsy};
}
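
// Worked example (illustrative numbers): inverting the convolution example
// above, a 3x3 input with a 3x3 kernel, stride 1, zero padding, and dilation
// {1, 1} gives
//   outsx = (3 - 1) * 1 + (3 - 1) * 1 + 1 - 0 - 0 = 5,
// recovering the original 5x5 spatial size.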

/// Modes of the padding operation.
enum PaddingMode { CONSTANT = 0, REFLECT, EDGE };

/// Different lengths modes used for SLS variants.
enum class LengthsMode { Variable, AllOne };

/// Convolution layouts.
enum ConvolutionLayout { NHWC = 0, NCHW, NTHWC, NCTHW };
inline bool is3DData(ConvolutionLayout layout) {
  return (layout == NTHWC || layout == NCTHW);
}

/// Modes of pooling for the RoiAlign operation.
enum PoolingMode { AVG = 0, MAX };

/// Activations fused into ConvolutionNode (not supported on all backends).
enum FusedActivation {
  NONE = 0,
  RELU,
  CLIP,
  TANH,
  SIGMOID,
  LEAKY_RELU,
};

/// LUT operators (not supported on all backends).
enum class LUTOperator {
  NONE = 0,
  RELU,
  CLIP,
  TANH,
  SIGMOID,
  LEAKY_RELU,
};

enum SplitEmbeddingPoolingMode {
  EP_SUM = 0,
  EP_MEAN = 1,
  EP_NONE = 2,
  EP_TOTAL,
};
enum SplitEmbeddingSparseType {
  EST_FLOAT = 0,
  EST_FLOAT16 = 1,
  EST_INT8 = 2,
  EST_INT4 = 3,
  EST_INT2 = 4,
  EST_TOTAL,
};

enum WeightsPlacement {
  DEVICE = 0,
  MANAGED = 1,
  MANAGED_CACHING = 2,
  HOST = 3,
};

/// Declare stream output operators.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ConvolutionLayout layout);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                              FusedActivation fusedActivation);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LengthsMode lengthsMode);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, LUTOperator lutOperator);

/// Support for hashing the Nodes. This is required for using
/// llvm::hash_combine.
class Node;
class Tensor;
struct Type;
struct NodeValue;

/// Convert a float into an unsigned integer binary representation.
/// FIXME: This is a workaround, because defining the hash_code
/// hash_value(float) does not work for some reason.
size_t toBinary(float f);
/// Convert a collection of floats into a vector of
/// unsigned integer binary representations.
/// FIXME: This is a workaround, because defining the hash_code
/// hash_value(float) does not work for some reason.
std::vector<size_t> toBinary(llvm::ArrayRef<float> vec);
llvm::hash_code hash_value(const glow::Tensor &T);

llvm::hash_code hash_value(const glow::Type *T);

llvm::hash_code hash_value(glow::Node *T);

llvm::hash_code hash_value(const glow::NodeValue &T);
llvm::hash_code hash_value(const glow::NodeHandle &T);

} // namespace glow

// The rest of the nodes are auto-generated into this file:
#include "glow/AutoGenNodes.h"

namespace glow {

/// A helper class for all the Node visitors.
/// You probably shouldn't use this directly.
template <typename ImplClass> class NodeVisitorBase {
public:
  ImplClass &asImpl() { return static_cast<ImplClass &>(*this); }
};

/// A visitor that visits only nodes. It does not recursively
/// visit any children of nodes.
template <typename ImplClass, typename RetTy = void, typename... ArgTys>
class NodeVisitor : public NodeVisitorBase<ImplClass> {
  using super = NodeVisitorBase<ImplClass>;

public:
  using super::asImpl;

  // Perform any required pre-processing (resp. post-processing) around a
  // visit. Sub-classes can override these hooks to provide their own
  // custom steps.
  void pre(Node *N) {}
  void post(Node *N) {}

  RetTy visit(Node *N, ArgTys... args) {
    asImpl().pre(N, args...);

    switch (N->getKind()) {
#define DEF_NODE(CLASS, NAME)                                                  \
  case glow::Kinded::Kind::CLASS##Kind:                                        \
    return asImpl().visit##CLASS(static_cast<CLASS *>(N),                      \
                                 std::forward<ArgTys>(args)...);
#include "glow/AutoGenNodes.def"

#define DEF_INSTR(CLASS, NAME) case glow::Kinded::Kind::CLASS##Kind:
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#define DEF_VALUE(CLASS, NAME) DEF_INSTR(CLASS, NAME)
#include "glow/AutoGenInstr.def"

      llvm_unreachable(
          "Not reachable, values and instructions are not handled here");
    }
    llvm_unreachable("Not reachable, all cases handled");
  }

// Define default dispatcher implementations that chain to parent nodes.
#define DEF_NODE(CLASS, NAME)                                                  \
  RetTy visit##CLASS(CLASS *N, ArgTys... args) {                               \
    auto Ret = asImpl().visit##PARENT(N, std::forward<ArgTys>(args)...);       \
    asImpl().post(N, args...);                                                 \
    return Ret;                                                                \
  }
#include "glow/AutoGenNodes.def"
};
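
// Hypothetical usage sketch (names invented; assumes the generated
// visit##CLASS dispatchers chain up the class hierarchy to a user-supplied
// visitNode): a visitor that reports whether a node is a SaveNode.
//
//   struct IsSave : NodeVisitor<IsSave, bool> {
//     void pre(Node *N) {}
//     void post(Node *N) {}
//     bool visitNode(Node *N) { return false; }       // root default case
//     bool visitSaveNode(SaveNode *N) { return true; } // overrides dispatch
//   };
//
//   IsSave v;
//   bool isSave = v.visit(someNode); // dispatches on the concrete kind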

/// Helper to get the string name for a node type \p T.
template <typename T> const char *getNodeName() {
  static_assert(std::is_base_of<Node, T>(), "Must be node");

// Do this for every known node.
#undef DEF_NODE
#define DEF_NODE(CLASS, NAME)                                                  \
  if (std::is_same<T, CLASS>()) {                                              \
    return #NAME;                                                              \
  }
// @lint-ignore facebook-hte-DuplicateInclude
#include "glow/AutoGenNodes.def"

  llvm_unreachable("Not reachable, values are not handled here");
}
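
// For example, getNodeName<SaveNode>() would presumably yield "Save",
// assuming the usual CLASS/NAME pairing in AutoGenNodes.def.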

/// Signifiers for exporting and importing properties of Nodes.
constexpr char layoutSignifier[] = "layout";
constexpr char staticSignifier[] = "offline";
constexpr char trainableSignifier[] = "trainable";
constexpr char elemKindSignifier[] = "elemKind";
constexpr char loaderNameSignifier[] = "loaderName";
constexpr char saveNameSignifier[] = "saveName";
constexpr char stridesSignifier[] = "strides";
constexpr char qScaleSignifier[] = "qScale";
constexpr char qOffsetSignifier[] = "qOffset";
constexpr char shapeSignifier[] = "shape";
constexpr char originNameToUniqueOffsetMappingSignifier[] =
    "originNameToUniqueOffsetMapping";
constexpr char constFoldSubgraphNodeName[] = "Glow__ConstFoldSubgraph";
constexpr char staticPHDummyNodeName[] = "Glow__StaticPHDummyNode";

/// \returns the string ID for a type attribute property for a specific
/// \p ioNum and \p signifier, and whether \p isInput. E.g. to retrieve result
/// number 0's shape, you'd pass `(0, "shape", false)`. \p addPrefix is an
/// additional prefix to include at the front of the returned ID.
inline std::string getTypeAttrID(unsigned ioNum, const std::string &signifier,
                                 bool isInput = false,
                                 const std::string &addPrefix = "") {
  return addPrefix + (isInput ? "i" : "o") + std::to_string(ioNum) + "_" +
         signifier;
}
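
// Examples, read directly off the concatenation above:
//   getTypeAttrID(0, shapeSignifier)                       -> "o0_shape"
//   getTypeAttrID(1, elemKindSignifier, /* isInput */ true) -> "i1_elemKind"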

} // namespace glow

#endif // GLOW_GRAPH_NODES_H