1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
/// \file
/// This file describes the high-level API for converting a function
/// from one type to another.
18 | #ifndef GLOW_CONVERTER_FUNCTIONCONVERTER_H |
19 | #define GLOW_CONVERTER_FUNCTIONCONVERTER_H |
20 | |
21 | #include "glow/Base/Type.h" |
22 | |
23 | #include <utility> // For std::pair. |
24 | |
25 | namespace glow { |
26 | class PlaceholderBindings; |
27 | class Function; |
28 | class Node; |
29 | struct NodeValue; |
30 | class Placeholder; |
31 | class Tensor; |
32 | |
33 | /// This class implements the high-level APIs used to convert a function |
34 | /// from one type to another. The actual conversions must be implemented |
35 | /// by derived classes. |
36 | class FunctionConverter { |
37 | protected: |
38 | /// The function to be converted. |
39 | Function &function_; |
40 | |
41 | /// \return the type that \p out needs to have at the end of the conversion |
42 | /// procedure. In other words, this is the type this value will have at the |
43 | /// end of ::convert. |
44 | /// E.g., let say we want to convert: |
45 | /// \verbatim |
46 | /// res = matmul float |
47 | /// \endverbatim |
48 | /// into |
49 | /// \verbatim |
50 | /// res = matmul fp16 |
51 | /// \endverbatim |
52 | /// The target type for res is fp16. |
53 | /// |
54 | /// Using this information, the conversion procedure will insert a conversion |
55 | /// of \p out from this type to the current type of \p out. |
56 | /// \verbatim |
57 | /// res = matmul fp16 |
58 | /// ... = convert fp16 res to res's current type |
59 | /// \endverbatim |
60 | /// |
61 | /// If nullptr is returned or the returned type is identical to the current |
62 | /// type of the related value, no conversion will be inserted by the |
63 | /// conversion procedure. |
64 | virtual TypeRef getTargetTypeForOutput(const NodeValue &out) const; |
65 | |
66 | /// \return the type that the input operand described by \p idx-th input of |
67 | /// \p use needs to have at the end of the conversion procedure. In other |
68 | /// words, this is the type this value will have at the end of ::convert. |
69 | /// E.g., let say we want to convert: |
70 | /// \verbatim |
71 | /// res = matmul float A, B |
72 | /// \endverbatim |
73 | /// into |
74 | /// \verbatim |
75 | /// res = matmul fp16 A, B |
76 | /// \endverbatim |
77 | /// The target type for A (i.e., (matmul, 0)) is fp16. |
78 | /// |
79 | /// Using this information, the conversion procedure will insert a conversion |
80 | /// of (\p node, \p idx) from its current type to the returned type. |
81 | /// \verbatim |
82 | /// convertedA = convert A's current type A to returned type |
83 | /// res = matmul fp16 convertedA, B |
84 | /// \endverbatim |
85 | /// |
86 | /// If nullptr is returned or the returned type is identical to the current |
87 | /// type of the related value, no conversion will be inserted by the |
88 | /// conversion procedure. |
89 | virtual TypeRef getTargetTypeForInput(const Node &use, unsigned idx) const; |
90 | |
91 | /// Check if \p node can be converted. |
92 | /// \return false if \p node shouldn't be considered for conversion. |
93 | virtual bool canConvert(const Node &node) const; |
94 | |
95 | /// Create a conversion with \p val as input and \p destTy as the destination |
96 | /// type in \p function, given \p node. In other words, creates something like |
97 | /// cast val to destTy. \p isInput represents if this is converting an input. |
98 | virtual Node *createConversion(Function &function, const Node &node, |
99 | NodeValue &val, TypeRef destTy, |
100 | bool isInput) = 0; |
101 | |
102 | /// Given a \p conversion, get its output value. |
103 | /// The default implementation returns the zero-th result. |
104 | /// If a conversion node defined more than one value, this |
105 | /// method must be overloaded. |
106 | virtual NodeValue getConversionOutput(Node &conversion) const; |
107 | |
108 | /// Mutate the outputs of \p node to the expected output target |
109 | /// type (\see getTargetTypeForOutput) and insert the conversions |
110 | /// to preserve the type consistency with the rest of the network. |
111 | void convertOutputs(Node &node); |
112 | |
113 | /// Insert conversion node for each input of \p node that don't |
114 | /// match getTargetTypeForInput. |
115 | void convertInputs(Node &node); |
116 | |
117 | /// Convert the \p input tensor to the \p destTy destination type. |
118 | virtual void convertTensor(Tensor &input, TypeRef destTy) = 0; |
119 | |
120 | /// Morph \p node into its final form. For the most part |
121 | /// this method should be a noop and just return \p node. |
122 | /// However, this hook provides a way to perform changes |
123 | /// on more than just the type of the inputs and outputs, |
124 | /// like changing the opcode of an operation. |
125 | /// |
126 | /// \warning \p node must not be deleted. |
127 | /// |
128 | /// \pre All the inputs of \p node have been converted to |
129 | /// their target type using ::getTargetTypeForInput. |
130 | /// \pre All the results of \p node have been converted to |
131 | /// their target type using ::getTargetTypeForOutput. |
132 | /// |
133 | /// \return the final morphed node. |
134 | virtual Node &morphNode(Node &node); |
135 | |
136 | /// Hook to perform some post processing on the final morphed node. |
137 | virtual void postProcessing(Node &node); |
138 | |
139 | /// Hook to do a final clean-up after all operations have been converted. |
140 | virtual void cleanUp() {} |
141 | |
142 | public: |
143 | /// Create a function converter for \p F. |
144 | /// |
145 | /// \note This method will modify \p F when calling ::convert. |
146 | /// If one wants to keep the original function around, |
147 | /// they need to clone it before creating this converter. |
148 | FunctionConverter(Function &F) : function_(F) {} |
149 | |
150 | virtual ~FunctionConverter() {} |
151 | |
152 | /// Convert \p F according to ::getTargetTypeForOutput and |
153 | /// ::getTargetTypeForInput. |
154 | /// |
155 | /// The high level algorithm looks like: |
156 | /// \code |
157 | /// for each node in function: |
158 | /// insert conversions for the inputs of node |
159 | /// update the inputs of node to use the results of the conversions |
160 | /// mutate the type of the outputs of node |
161 | /// insert conversions for the outputs of node |
162 | /// morph node |
163 | /// postProcessing node |
164 | /// cleanUp |
165 | /// \endcode |
166 | void convert(); |
167 | |
168 | /// Modify the type of \p placeholder according to getTargetTypeForOutput. |
169 | /// If the \p context is provided and \p placeholder has a backing tensor, |
170 | /// this tensor is also updated. |
171 | /// Note: If \p placeholder is used in functions other than F, changes to |
172 | /// those functions will be made as well to accommodate the converted |
173 | /// placeholder. |
174 | void convertPlaceholder(Placeholder &placeholder, |
175 | PlaceholderBindings *context); |
176 | }; |
177 | } // namespace glow |
178 | #endif |
179 | |