1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16/// \file This file describes the high-level API for converting a function
17/// from one type to another.
18#ifndef GLOW_CONVERTER_FUNCTIONCONVERTER_H
19#define GLOW_CONVERTER_FUNCTIONCONVERTER_H
20
21#include "glow/Base/Type.h"
22
23#include <utility> // For std::pair.
24
25namespace glow {
26class PlaceholderBindings;
27class Function;
28class Node;
29struct NodeValue;
30class Placeholder;
31class Tensor;
32
33/// This class implements the high-level APIs used to convert a function
34/// from one type to another. The actual conversions must be implemented
35/// by derived classes.
36class FunctionConverter {
37protected:
38 /// The function to be converted.
39 Function &function_;
40
41 /// \return the type that \p out needs to have at the end of the conversion
42 /// procedure. In other words, this is the type this value will have at the
43 /// end of ::convert.
44 /// E.g., let say we want to convert:
45 /// \verbatim
46 /// res = matmul float
47 /// \endverbatim
48 /// into
49 /// \verbatim
50 /// res = matmul fp16
51 /// \endverbatim
52 /// The target type for res is fp16.
53 ///
54 /// Using this information, the conversion procedure will insert a conversion
55 /// of \p out from this type to the current type of \p out.
56 /// \verbatim
57 /// res = matmul fp16
58 /// ... = convert fp16 res to res's current type
59 /// \endverbatim
60 ///
61 /// If nullptr is returned or the returned type is identical to the current
62 /// type of the related value, no conversion will be inserted by the
63 /// conversion procedure.
64 virtual TypeRef getTargetTypeForOutput(const NodeValue &out) const;
65
66 /// \return the type that the input operand described by \p idx-th input of
67 /// \p use needs to have at the end of the conversion procedure. In other
68 /// words, this is the type this value will have at the end of ::convert.
69 /// E.g., let say we want to convert:
70 /// \verbatim
71 /// res = matmul float A, B
72 /// \endverbatim
73 /// into
74 /// \verbatim
75 /// res = matmul fp16 A, B
76 /// \endverbatim
77 /// The target type for A (i.e., (matmul, 0)) is fp16.
78 ///
79 /// Using this information, the conversion procedure will insert a conversion
80 /// of (\p node, \p idx) from its current type to the returned type.
81 /// \verbatim
82 /// convertedA = convert A's current type A to returned type
83 /// res = matmul fp16 convertedA, B
84 /// \endverbatim
85 ///
86 /// If nullptr is returned or the returned type is identical to the current
87 /// type of the related value, no conversion will be inserted by the
88 /// conversion procedure.
89 virtual TypeRef getTargetTypeForInput(const Node &use, unsigned idx) const;
90
91 /// Check if \p node can be converted.
92 /// \return false if \p node shouldn't be considered for conversion.
93 virtual bool canConvert(const Node &node) const;
94
95 /// Create a conversion with \p val as input and \p destTy as the destination
96 /// type in \p function, given \p node. In other words, creates something like
97 /// cast val to destTy. \p isInput represents if this is converting an input.
98 virtual Node *createConversion(Function &function, const Node &node,
99 NodeValue &val, TypeRef destTy,
100 bool isInput) = 0;
101
102 /// Given a \p conversion, get its output value.
103 /// The default implementation returns the zero-th result.
104 /// If a conversion node defined more than one value, this
105 /// method must be overloaded.
106 virtual NodeValue getConversionOutput(Node &conversion) const;
107
108 /// Mutate the outputs of \p node to the expected output target
109 /// type (\see getTargetTypeForOutput) and insert the conversions
110 /// to preserve the type consistency with the rest of the network.
111 void convertOutputs(Node &node);
112
113 /// Insert conversion node for each input of \p node that don't
114 /// match getTargetTypeForInput.
115 void convertInputs(Node &node);
116
117 /// Convert the \p input tensor to the \p destTy destination type.
118 virtual void convertTensor(Tensor &input, TypeRef destTy) = 0;
119
120 /// Morph \p node into its final form. For the most part
121 /// this method should be a noop and just return \p node.
122 /// However, this hook provides a way to perform changes
123 /// on more than just the type of the inputs and outputs,
124 /// like changing the opcode of an operation.
125 ///
126 /// \warning \p node must not be deleted.
127 ///
128 /// \pre All the inputs of \p node have been converted to
129 /// their target type using ::getTargetTypeForInput.
130 /// \pre All the results of \p node have been converted to
131 /// their target type using ::getTargetTypeForOutput.
132 ///
133 /// \return the final morphed node.
134 virtual Node &morphNode(Node &node);
135
136 /// Hook to perform some post processing on the final morphed node.
137 virtual void postProcessing(Node &node);
138
139 /// Hook to do a final clean-up after all operations have been converted.
140 virtual void cleanUp() {}
141
142public:
143 /// Create a function converter for \p F.
144 ///
145 /// \note This method will modify \p F when calling ::convert.
146 /// If one wants to keep the original function around,
147 /// they need to clone it before creating this converter.
148 FunctionConverter(Function &F) : function_(F) {}
149
150 virtual ~FunctionConverter() {}
151
152 /// Convert \p F according to ::getTargetTypeForOutput and
153 /// ::getTargetTypeForInput.
154 ///
155 /// The high level algorithm looks like:
156 /// \code
157 /// for each node in function:
158 /// insert conversions for the inputs of node
159 /// update the inputs of node to use the results of the conversions
160 /// mutate the type of the outputs of node
161 /// insert conversions for the outputs of node
162 /// morph node
163 /// postProcessing node
164 /// cleanUp
165 /// \endcode
166 void convert();
167
168 /// Modify the type of \p placeholder according to getTargetTypeForOutput.
169 /// If the \p context is provided and \p placeholder has a backing tensor,
170 /// this tensor is also updated.
171 /// Note: If \p placeholder is used in functions other than F, changes to
172 /// those functions will be made as well to accommodate the converted
173 /// placeholder.
174 void convertPlaceholder(Placeholder &placeholder,
175 PlaceholderBindings *context);
176};
177} // namespace glow
178#endif
179