1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file to_mixed_precision.cc |
23 | * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16. |
24 | * |
25 | */ |
26 | |
27 | #include <tvm/ir/attrs.h> |
28 | #include <tvm/relay/expr_functor.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/runtime/object.h> |
31 | |
32 | #include <utility> |
33 | |
34 | #include "pattern_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
// Pass-config key (read by transform::ToMixedPrecision below): when true, the pass casts the
// graph's final outputs back to their original dtypes after conversion.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
// A callable which hashes std::pair, so pairs may key unordered containers
// (used for CachedCastNodes below).
struct pair_hash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    auto h1 = std::hash<T1>()(pair.first);
    auto h2 = std::hash<T2>()(pair.second);

    // Use boost's hash_combine strategy: fold the second hash into the first.
    // (The previous formula swapped the operands — `h1 ^ (h1 + c + (h2 << 6) + (h2 >> 2))` —
    // reusing h1 on both sides of the XOR, which weakens the mix and did not match the
    // boost scheme the comment cited.)
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};
51 | |
// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
// numerical reasons.
// The explicit integer values matter: FTVMMixedPrecisionConversionType functions return the
// category as an Integer, which is static_cast back to this enum in MixedPrecisionPass::Rewrite_.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};
61 | |
// A map of (parent node, wanted dtype) to an existing node casted to the wanted dtype.
// Keyed by raw ExprNode* (non-owning); pair_hash above provides the hash.
using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;

// Signature of the per-op conversion hook registered under "FTVMMixedPrecisionConversionType".
// Return array is of type : [MixedTypeConversionCategory (int), String, String]
// The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
// Call is a call node, DataType is the mixed precision type
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
    const Call& call_node, const std::string& target_dtype_str)>;
70 | |
/*! \brief This class transforms the given relay module into a version where
 * as many operations as possible operate in the target mixed precision dtype.
 *
 * Input : A Relay module with operations registered with FTVMMixedPrecisionConversionType
 *         functions. These describe when and how the operations will be transformed
 *         into the target precision dtype.
 *
 * Output : A Relay module with some operations transformed according to the below
 *          methodology.
 *
 * Methodology :
 *      1) Each relay Op is either of conversion category ALWAYS, FOLLOW, NEVER
 *         defined by the associated FTVMMixedPrecisionConversionType function.
 *         If an operation is not registered, it by default is assumed to be
 *         FOLLOW.
 *      2) ALWAYS operations always convert the input floating point args into
 *         the target mixed precision dtype. FOLLOW Ops will convert the input
 *         floating point args back into FP32 unless all floating point args
 *         are in the target mixed precision dtypes. NEVER ops will always cast
 *         inputs back into FP32.
 *      3) Each ALWAYS Op, and FOLLOW Op with mixed precision dtype arguments
 *         also have an associated accumulation_dtype and output_dtype which
 *         describe whether a larger dtype is used to accumulate the results
 *         of the operation. The output_dtype meanwhile describes the dtype
 *         most Ops should use from this accumulator.
 */
class MixedPrecisionPass : public MixedModeMutator {
 private:
  /*! \brief A cache of nodes + target dtype to a cast version of the node with target dtype. */
  CachedCastNodes cast_nodes_cache_;

  /*! \brief The target datatype we want to convert to e.g. FP16 */
  const DataType mixed_precision_type_;

  /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
   * encountered. Used for emitting warnings on missing ops in the pass.
   */
  std::unordered_map<std::string, int> missing_ops_;
  /*! \brief The body of the function being converted; used to detect when the graph's final
   * output is rewritten so it can optionally be cast back to its original dtype. */
  const RelayExprNode* root_;
  /*! \brief Original output dtypes of the root (one per tuple field, or a single entry for a
   * call root); only populated when keep_orig_output_dtype_ is true. */
  std::vector<DataType> original_dtype_;
  /*! \brief Whether final outputs keep their pre-pass dtypes (cast back at the root). */
  bool keep_orig_output_dtype_;

  /*! \brief Return attrs for \p call with the out_dtype/dtype field set to
   * \p accumulation_dtype, copying the attrs if a mutation is needed. */
  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
    Attrs cur_attrs = call->attrs;
    if (cur_attrs.get() != nullptr) {
      // TODO(AndrewZhaoLuo): Figure out a better way to do this
      // modify output_dtype attributes (accumulation dtypes for ops)
      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      }

      // modify dtype attributes (creating new tensors of type dtype)
      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
        return ModifyAttrsDType(attrs, accumulation_dtype);
      }
    }

    // Attrs type not recognized (or null): return unchanged.
    return cur_attrs;
  }

  template <typename T>
  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
    /*
    Helper template to modify relevant attributes with out_dtype type.
    These represent accumulation dtypes for some operations e.g.
    conv2d might take in fp16 and give a fp32 result.
    Attrs is const because we get it as a const.
    Only floating point / bfloat16 / void (unset) out_dtypes are overwritten;
    integer out_dtypes are left alone.
    */
    DataType cur_type = (attrs->out_dtype);
    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
    if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) {
      new_attrs->out_dtype = accumulation_dtype;
    }
    return Attrs(new_attrs);
  }

  template <typename T>
  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
    /*
    Helper template to modify relevant attributes with dtype type.
    This determines the output dtype for some ops. For example
    zeros creates a tensor of zeros of the specified dtype.
    Attrs is const because we get it as a const.
    Only floating point / bfloat16 / void (unset) dtypes are overwritten.
    */
    DataType cur_type = (attrs->dtype);
    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
    if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) {
      new_attrs->dtype = accumulation_dtype;
    }
    return Attrs(new_attrs);
  }

  /*! \brief Return the checked type of \p expr, running local type inference if the cached
   * checked_type_ has been invalidated (see the tuple handling below which nulls it out). */
  Type GetType(const Expr& expr) const {
    // The expression has not been changed AND its existing type
    // is known to still be valid. (See special handling for tuples etc.
    // below for where we null out checked_type_ when we cannot be
    // sure it is still valid.)
    Type checked_type = expr->checked_type_;
    if (checked_type.defined()) {
      return checked_type;
    }

    // This also populates the checked_type_ field for expr
    return transform::InferTypeLocal(expr);
  }

  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
    /* Returns whether t is a type with only target mixed precision type elements.
       If ignore_non_float, then ignore non-floating types.
       Tuples are checked recursively; any other type kind is fatal.
    */
    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
      bool is_supported_floating_point_type =
          (tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16();
      return (ignore_non_float && !is_supported_floating_point_type) ||
             tensor_type->dtype == mixed_precision_type_;
    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
      for (Type t : tuple_type->fields) {
        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
      }
      return true;
    } else {
      // LOG(FATAL) throws, so control never falls off the end here.
      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
    }
  }

  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */

    // If this is not a floating point type, do not cast. E.g. it might be an integer
    if (!(expr_dtype.is_float() || expr_dtype.is_bfloat16())) {
      return expr;
    }

    // Already the right dtype: nothing to do.
    if (expr_dtype == wanted_dtype) {
      return expr;
    }

    const ExprNode* expr_node = expr.as<ExprNode>();
    CHECK(expr_node) << "Non-expression node found in cast: " << expr;

    // Use cached result if possible.
    auto search = cast_nodes_cache_.find({expr_node, wanted_dtype});
    if (search != cast_nodes_cache_.end()) {
      return search->second;
    }

    Expr result = Cast(expr, wanted_dtype);
    cast_nodes_cache_[{expr_node, wanted_dtype}] = result;

    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
    // (so cast(cast(x, b), a) collapses back to x instead of stacking casts).
    const ExprNode* new_expr_node = result.as<ExprNode>();
    cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr;
    return result;
  }

  /*! \brief Cast \p expr (of type \p expr_type) to \p wanted_dtype, recursing through tuples
   * field-by-field. Returns \p expr unchanged when no cast is needed. */
  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
    /* Helper for casting arguments to call_nodes handling all relevant cases. */
    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
      Array<Expr> new_expr;
      bool all_same = true;
      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
        Expr tuple_element = GetField(expr, i);
        Type tuple_element_dtype = (tuple_type->fields)[i];
        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
        new_expr.push_back(casted_element);
        all_same &= casted_element.same_as(tuple_element);
      }
      // Only build a new Tuple if at least one field actually changed.
      return all_same ? expr : Tuple(new_expr);
    }
    CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
    return expr;
  }

  /*! \brief Cast every arg to \p wanted_dtype via CastArg, returning the new args paired with
   * their (re-inferred) types. */
  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
                                                  const Array<Type>& cur_arg_types,
                                                  const DataType& wanted_dtype) {
    Array<Expr> new_args;
    Array<Type> new_arg_types;
    for (size_t i = 0; i < cur_args.size(); i++) {
      Expr cur_arg = cur_args[i];
      Type cur_arg_type = cur_arg_types[i];
      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
      Type new_arg_type = GetType(new_arg);
      new_args.push_back(new_arg);
      new_arg_types.push_back(new_arg_type);
    }
    return {new_args, new_arg_types};
  }

 public:
  using MixedModeMutator::VisitExpr_;

  /*! \brief Construct the pass over \p base (must be a Function).
   * Records the original output dtypes of the function body when
   * \p keep_orig_output_dtype is set, so Rewrite_ can cast them back. */
  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
                              DataType mixed_precision_type = DataType::Float(16))
      : MixedModeMutator(),
        mixed_precision_type_(mixed_precision_type),
        root_(Downcast<Function>(base)->body.get()),
        keep_orig_output_dtype_(keep_orig_output_dtype) {
    if (keep_orig_output_dtype_) {
      // NOTE(review): assumes the root's checked_type_ is already populated and that a Call
      // root has a TensorType (a call producing a tuple would dereference a null
      // TensorTypeNode below) — confirm upstream guarantees. Roots that are neither Tuple
      // nor Call leave original_dtype_ empty.
      if (root_->IsInstance<tvm::relay::TupleNode>()) {
        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
        for (Type t : tuple_type->fields) {
          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
          original_dtype_.push_back(tensor_type->dtype);
        }
      } else if (root_->IsInstance<tvm::relay::CallNode>()) {
        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
      }
    }
    if (!(mixed_precision_type_.is_float() || mixed_precision_type_.is_bfloat16())) {
      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                 << mixed_precision_type_;
    }
  }

  /*! \brief Core rewrite: decide the conversion category for this call, cast its args to the
   * chosen dtype, patch accumulation/output dtype attrs, and (for the root) restore the
   * original output dtype if requested. */
  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
    const CallNode* post_call_node = post.as<CallNode>();
    CHECK(post_call_node) << "Expected a CallNode, but got " << post;

    Expr cur_op = post_call_node->op;

    // TODO(AndrewZhaoLuo): Support ADTs
    // Relay's algebraic data types are not supported yet.
    ICHECK(!cur_op.as<GlobalVarNode>()    // used to declare functions for recursion
           && !cur_op.as<ConstructorNode>()  // constructing ADT types
           && !cur_op.as<VarNode>())         // used for calling recursive functions
        << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass.";

    // Get info on the operation being called:
    // conversion category (int), accumulation dtype (str), output dtype (str)
    MixedTypeConversionCategory initial_category;
    DataType accumulation_dtype, output_dtype;
    if (cur_op.as<FunctionNode>()) {
      // Avoid messing with functions to avoid changing signature
      initial_category = MIXED_PRECISION_NEVER;
      accumulation_dtype = DataType::Float(32);
      output_dtype = DataType::Float(32);
    } else if (cur_op.as<OpNode>()) {
      static auto attr_map =
          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
      Op op = Downcast<Op>(cur_op);
      if (attr_map.count(op)) {
        // Calculate the conversion category and dtypes from registered attribute.
        FTVMMixedPrecisionConversionType func = attr_map[op];
        Array<ObjectRef> op_descriptor =
            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
        ICHECK(op_descriptor.size() == 3)
            << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
            << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false);

        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
      } else {
        missing_ops_[op->name] += 1;

        // If not registered, by default assume is a generic FOLLOW operation.
        initial_category = MIXED_PRECISION_FOLLOW;
        accumulation_dtype = mixed_precision_type_;
        output_dtype = mixed_precision_type_;
      }
    } else {
      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
    }

    // First check if all the new mutated args are in lower precision form
    Array<Type> cur_arg_types;
    bool all_args_mixed_type_compatible = true;
    for (Expr arg : post_call_node->args) {
      Type cur_arg_type = GetType(arg);
      cur_arg_types.push_back(cur_arg_type);

      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
        // We can cast Vars and Constants to the right types so don't care about the types.
        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
                                        arg->IsInstance<VarNode>() ||
                                        arg->IsInstance<ConstantNode>();
        all_args_mixed_type_compatible &= is_mixed_type_compatible;
      }
    }

    // Determine the final category we want for conversion
    MixedTypeConversionCategory final_category = initial_category;
    if (initial_category == MIXED_PRECISION_FOLLOW) {
      final_category =
          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
    }

    // Create the new arguments to the call.
    DataType wanted_arg_dtypes =
        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32);
    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
    Array<Expr> new_args = call_args_and_types.first;
    Array<Type> new_arg_types;

    if (pre_call_node->op.as<FunctionNode>()) {
      // Function Nodes don't store type info in the Call, it should be a []
      new_arg_types = pre_call_node->type_args;
    } else {
      new_arg_types = call_args_and_types.second;
    }

    // Finally create the new attributes.
    if (final_category == MIXED_PRECISION_ALWAYS) {
      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
      // Downcast from the accumulation dtype to the op's declared output dtype if they differ.
      if (accumulation_dtype != output_dtype) {
        output = CastArg(output, GetType(output), output_dtype);
      }
      // If this call is the function body, optionally restore the original output dtype.
      if (pre_call_node == root_ && keep_orig_output_dtype_) {
        if (original_dtype_[0] != output_dtype) {
          output = CastArg(output, GetType(output), original_dtype_[0]);
        }
      }
      return output;
    }

    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
  }

  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
    // The old checked type in the expression may not be valid so clear it
    // (GetType will re-infer on demand).
    post->checked_type_ = Type(nullptr);
    return post;
  }

  Expr Rewrite_(const TupleNode* pre, const Expr& post) {
    // The old checked type in the expression may not be valid so clear it
    post->checked_type_ = Type(nullptr);
    // If this tuple is the function body, optionally cast each field back to its original dtype.
    if (pre == root_ && keep_orig_output_dtype_) {
      Array<Expr> new_expr;
      bool all_same = true;
      for (size_t i = 0; i < original_dtype_.size(); i++) {
        Expr output_element = GetField(post, i);
        Expr casted_element;
        auto output_element_type = transform::InferTypeLocal(output_element);
        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
        new_expr.push_back(casted_element);
        all_same &= casted_element.same_as(output_element);
      }
      if (!all_same) {
        return Tuple(new_expr);
      }
    }
    return post;
  }

  Expr VisitExpr_(const FunctionNode* func) final {
    // Erase the ret_type annotation and let the normal pass recalculate
    // (conversion may have changed the function's result dtype).
    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
    return ExprMutator::VisitExpr_(func);
  }

  Expr VisitExpr_(const LetNode* op) final {
    // First convert as much of the bound computation to lower precision as possible
    Expr value = this->Mutate(op->value);

    // Then rewrite the var type and associated expression
    // (in-place via const_cast so every use-site of the var sees the new type).
    Var var = Downcast<Var>(this->Mutate(op->var));
    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
    mutable_var->type_annotation = GetType(value);
    mutable_var->checked_type_ = mutable_var->type_annotation;

    // Mutate body last as it may depend on previous results
    Expr body = this->Mutate(op->body);
    return Let(var, value, body, op->span);
  }

  // To access map of ops not registered for error reporting
  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
                               const DataType& mixed_precision_type, int missing_op_mode);
};
469 | |
470 | Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype, |
471 | const DataType& mixed_precision_type, int missing_op_mode) { |
472 | /* |
473 | missing_op_mode: |
474 | |
475 | 0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any. |
476 | 1: Allow missing ops but throw warnings. |
477 | 2: Allow missing ops and silently ignore them. |
478 | */ |
479 | ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2) |
480 | << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode; |
481 | |
482 | MixedPrecisionPass converter = |
483 | MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type); |
484 | auto result = converter.Mutate(expr); |
485 | |
486 | for (auto it = converter.missing_ops_.begin(); |
487 | missing_op_mode != 2 && it != converter.missing_ops_.end(); it++) { |
488 | std::string op_name = it->first; |
489 | int appear_count = it->second; |
490 | |
491 | LOG(WARNING) << "Op \"" << op_name << "\" not registered " |
492 | << "FTVMMixedPrecisionConversionType appears " << appear_count |
493 | << " times in graph." ; |
494 | } |
495 | |
496 | if (converter.missing_ops_.size() != 0 && missing_op_mode == 0) { |
497 | CHECK(0) << "Missing ops were found!" ; |
498 | } |
499 | return result; |
500 | } |
501 | |
502 | namespace transform { |
503 | |
504 | Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) { |
505 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
506 | [=](Function f, IRModule m, PassContext pc) { |
507 | bool keep_orig_output_dtype = false; |
508 | keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype" , |
509 | Bool(keep_orig_output_dtype)) |
510 | .value(); |
511 | return Downcast<Function>( |
512 | ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode)); |
513 | }; |
514 | return CreateFunctionPass(pass_func, 0, "ToMixedPrecision" , {}); |
515 | } |
516 | |
517 | TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision" ).set_body_typed(ToMixedPrecision); |
518 | |
519 | } // namespace transform |
520 | |
521 | } // namespace relay |
522 | } // namespace tvm |
523 | |