/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file to_mixed_precision.cc
 * \brief Automatic mixed floating-point precision for Relay graphs, e.g. turning a graph
 * into fp16.
 */

#include <tvm/ir/attrs.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/object.h>

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
// A callable which hashes std::pair (std::hash has no specialization for pairs)
struct pair_hash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    auto h1 = std::hash<T1>()(pair.first);
    auto h2 = std::hash<T2>()(pair.second);

    // Use boost's hash_combine strategy
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};

// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and
// memory savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have
// speedups to justify a cast. MIXED_PRECISION_NEVER ops should not be done in lower precision
// for numerical reasons.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};
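
// Illustrative examples only (the authoritative per-op registrations live on the Python
// side, in python/tvm/relay/transform/mixed_precision.py): compute-heavy ops such as
// nn.conv2d and nn.dense are typically ALWAYS, cheap elementwise ops such as add are
// typically FOLLOW, and numerically sensitive ops such as exp or nn.softmax are
// typically NEVER.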

// A map from (parent node, wanted dtype) to an existing node cast to the wanted dtype
using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;

// The returned array has the form [MixedTypeConversionCategory (int), String, String],
// i.e. [conversion_category, accumulation_datatype, output_datatype].
// The Call is the call node being examined; the string is the mixed precision dtype.
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
    const Call& call_node, const std::string& target_dtype_str)>;
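
// For example, a conversion function for an ALWAYS op that accumulates in fp32 but outputs
// fp16 would return (a sketch mirroring the descriptor convention documented above):
//   Array<ObjectRef>{Integer(MIXED_PRECISION_ALWAYS), String("float32"), String("float16")}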

/*! \brief This class transforms the given relay module into a version where
 * as many operations as possible operate in the target mixed precision dtype.
 *
 * Input : A Relay module with operations registered with FTVMMixedPrecisionConversionType
 *         functions. These describe when and how the operations will be transformed
 *         into the target precision dtype.
 *
 * Output : A Relay module with some operations transformed according to the below
 *          methodology.
 *
 * Methodology :
 *      1) Each Relay Op has a conversion category of ALWAYS, FOLLOW, or NEVER,
 *         defined by the associated FTVMMixedPrecisionConversionType function.
 *         If an operation is not registered, it is assumed to be FOLLOW by default.
 *      2) ALWAYS operations always convert the input floating point args into
 *         the target mixed precision dtype. FOLLOW Ops will convert the input
 *         floating point args back into FP32 unless all floating point args
 *         are already in the target mixed precision dtype. NEVER ops will always
 *         cast inputs back into FP32.
 *      3) Each ALWAYS Op, and each FOLLOW Op with mixed precision dtype arguments,
 *         also has an associated accumulation_dtype and output_dtype. The
 *         accumulation_dtype describes whether a larger dtype is used to accumulate
 *         the results of the operation, while the output_dtype describes the dtype
 *         most Ops should use when reading from this accumulator.
 */
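/* An illustrative before/after sketch (assuming a float16 target; the exact output of the
 * pass may differ). A graph like:
 *   fn (%x: Tensor[(1, 3, 32, 32), float32], %w: Tensor[(8, 3, 3, 3), float32]) {
 *     nn.conv2d(%x, %w, ...)
 *   }
 * becomes, roughly:
 *   fn (%x: Tensor[(1, 3, 32, 32), float32], %w: Tensor[(8, 3, 3, 3), float32]) {
 *     %0 = cast(%x, dtype="float16");
 *     %1 = cast(%w, dtype="float16");
 *     %2 = nn.conv2d(%0, %1, ..., out_dtype="float32");  // ALWAYS op, fp32 accumulation
 *     cast(%2, dtype="float16")                          // cast accumulator to output_dtype
 *   }
 */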
class MixedPrecisionPass : public MixedModeMutator {
 private:
  /*! \brief A cache of nodes + target dtype to a cast version of the node with target dtype. */
  CachedCastNodes cast_nodes_cache_;

  /*! \brief The target datatype we want to convert to, e.g. FP16 */
  const DataType mixed_precision_type_;

  /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
   * encountered. Used for emitting warnings on missing ops in the pass.
   */
  std::unordered_map<std::string, int> missing_ops_;

  /*! \brief The body of the function being converted, used to recognize the output nodes. */
  const RelayExprNode* root_;

  /*! \brief The original output dtypes, recorded so they can be restored if requested. */
  std::vector<DataType> original_dtype_;

  /*! \brief Whether to cast the final outputs back to their original dtypes. */
  bool keep_orig_output_dtype_;

  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
    /* If the accumulation dtype is in the attributes, make a copy and mutate the field. */
    Attrs cur_attrs = call->attrs;
    if (cur_attrs.get() != nullptr) {
      // TODO(AndrewZhaoLuo): Figure out a better way to do this
      // modify out_dtype attributes (accumulation dtypes for ops)
      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
      }

      // modify dtype attributes (creating new tensors of type dtype)
      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
        return ModifyAttrsDType(attrs, accumulation_dtype);
      }
    }

    return cur_attrs;
  }

  template <typename T>
  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
    /*
     Helper template to modify relevant attributes with out_dtype type.
     These represent accumulation dtypes for some operations, e.g.
     conv2d might take in fp16 and give an fp32 result.
     Attrs is const because we get it as a const.
    */
    DataType cur_type = (attrs->out_dtype);
    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
    if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) {
      new_attrs->out_dtype = accumulation_dtype;
    }
    return Attrs(new_attrs);
  }

  template <typename T>
  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
    /*
     Helper template to modify relevant attributes with dtype type.
     This determines the output dtype for some ops. For example,
     zeros creates a tensor of zeros of the specified dtype.
     Attrs is const because we get it as a const.
    */
    DataType cur_type = (attrs->dtype);
    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
    if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) {
      new_attrs->dtype = accumulation_dtype;
    }
    return Attrs(new_attrs);
  }
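
  // E.g. a zeros call with dtype="float32" would be rewritten above to create its tensor
  // directly in the accumulation dtype (such as float16), avoiding a separate cast.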

  Type GetType(const Expr& expr) const {
    // The expression has not been changed AND its existing type
    // is known to still be valid. (See the special handling for tuples etc.
    // below, where we null out checked_type_ when we cannot be sure it is
    // still valid.)
    Type checked_type = expr->checked_type_;
    if (checked_type.defined()) {
      return checked_type;
    }

    // This also populates the checked_type_ field for expr
    return transform::InferTypeLocal(expr);
  }

  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
    /* Returns whether t is a type with only target mixed precision type elements.
       If ignore_non_float, then ignore non-floating types.
    */
    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
      bool is_supported_floating_point_type =
          (tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16();
      return (ignore_non_float && !is_supported_floating_point_type) ||
             tensor_type->dtype == mixed_precision_type_;
    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
      for (Type t : tuple_type->fields) {
        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
      }
      return true;
    } else {
      LOG(FATAL) << "Unsupported type " << t << " in mixed precision pass";
      return false;  // unreachable: LOG(FATAL) throws
    }
  }

  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */

    // If this is not a floating point type, do not cast. E.g. it might be an integer
    if (!(expr_dtype.is_float() || expr_dtype.is_bfloat16())) {
      return expr;
    }

    if (expr_dtype == wanted_dtype) {
      return expr;
    }

    const ExprNode* expr_node = expr.as<ExprNode>();
    CHECK(expr_node) << "Non-expression node found in cast: " << expr;

    // Use cached result if possible.
    auto search = cast_nodes_cache_.find({expr_node, wanted_dtype});
    if (search != cast_nodes_cache_.end()) {
      return search->second;
    }

    Expr result = Cast(expr, wanted_dtype);
    cast_nodes_cache_[{expr_node, wanted_dtype}] = result;

    // Cache the reverse direction too: casting the result back to the original
    // dtype should simply return the original node.
    const ExprNode* new_expr_node = result.as<ExprNode>();
    cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr;
    return result;
  }
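
  // E.g. after CachedCast(x, float32, float16) produces %c = cast(x, float16), a later
  // CachedCast(%c, float16, float32) returns the original x instead of stacking casts.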

  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
    /* Helper for casting arguments to call_nodes handling all relevant cases. */
    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
      Array<Expr> new_expr;
      bool all_same = true;
      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
        Expr tuple_element = GetField(expr, i);
        Type tuple_element_dtype = (tuple_type->fields)[i];
        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
        new_expr.push_back(casted_element);
        all_same &= casted_element.same_as(tuple_element);
      }
      return all_same ? expr : Tuple(new_expr);
    }
    CHECK(0) << "Unsupported type " << expr_type << " encountered when casting arguments!";
    return expr;
  }

  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
                                                  const Array<Type>& cur_arg_types,
                                                  const DataType& wanted_dtype) {
    Array<Expr> new_args;
    Array<Type> new_arg_types;
    for (size_t i = 0; i < cur_args.size(); i++) {
      Expr cur_arg = cur_args[i];
      Type cur_arg_type = cur_arg_types[i];
      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
      Type new_arg_type = GetType(new_arg);
      new_args.push_back(new_arg);
      new_arg_types.push_back(new_arg_type);
    }
    return {new_args, new_arg_types};
  }

 public:
  using MixedModeMutator::VisitExpr_;

  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
                              DataType mixed_precision_type = DataType::Float(16))
      : MixedModeMutator(),
        mixed_precision_type_(mixed_precision_type),
        root_(Downcast<Function>(base)->body.get()),
        keep_orig_output_dtype_(keep_orig_output_dtype) {
    if (keep_orig_output_dtype_) {
      if (root_->IsInstance<tvm::relay::TupleNode>()) {
        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
        for (Type t : tuple_type->fields) {
          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
          original_dtype_.push_back(tensor_type->dtype);
        }
      } else if (root_->IsInstance<tvm::relay::CallNode>()) {
        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
      }
    }
    if (!(mixed_precision_type_.is_float() || mixed_precision_type_.is_bfloat16())) {
      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                 << mixed_precision_type_;
    }
  }

  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
    const CallNode* post_call_node = post.as<CallNode>();
    CHECK(post_call_node) << "Expected a CallNode, but got " << post;

    Expr cur_op = post_call_node->op;

    // TODO(AndrewZhaoLuo): Support ADTs
    // Relay's algebraic data types are not supported yet.
    ICHECK(!cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
           && !cur_op.as<ConstructorNode>()  // constructing ADT types
           && !cur_op.as<VarNode>())         // used for calling recursive functions
        << "Algebraic Data Types (ADT) are not supported yet for the mixed precision pass.";

    // Get info on the operation being called:
    // conversion category (int), accumulation dtype (str), output dtype (str)
    MixedTypeConversionCategory initial_category;
    DataType accumulation_dtype, output_dtype;
    if (cur_op.as<FunctionNode>()) {
      // Avoid messing with functions, to avoid changing their signatures
      initial_category = MIXED_PRECISION_NEVER;
      accumulation_dtype = DataType::Float(32);
      output_dtype = DataType::Float(32);
    } else if (cur_op.as<OpNode>()) {
      static auto attr_map =
          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
      Op op = Downcast<Op>(cur_op);
      if (attr_map.count(op)) {
        // Calculate the conversion category and dtypes from the registered attribute.
        FTVMMixedPrecisionConversionType func = attr_map[op];
        Array<ObjectRef> op_descriptor =
            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
        ICHECK(op_descriptor.size() == 3)
            << "got the wrong number of returned arguments (expected 3, got "
            << op_descriptor.size() << ") from FTVMMixedPrecisionConversionType for "
            << AsText(op, false);

        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
      } else {
        missing_ops_[op->name] += 1;

        // If not registered, assume it is a generic FOLLOW operation by default.
        initial_category = MIXED_PRECISION_FOLLOW;
        accumulation_dtype = mixed_precision_type_;
        output_dtype = mixed_precision_type_;
      }
    } else {
      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
    }

    // First check if all the new mutated args are in lower precision form
    Array<Type> cur_arg_types;
    bool all_args_mixed_type_compatible = true;
    for (Expr arg : post_call_node->args) {
      Type cur_arg_type = GetType(arg);
      cur_arg_types.push_back(cur_arg_type);

      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
        // We can cast Vars and Constants to the right types, so don't care about their types.
        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
                                        arg->IsInstance<VarNode>() ||
                                        arg->IsInstance<ConstantNode>();
        all_args_mixed_type_compatible &= is_mixed_type_compatible;
      }
    }

    // Determine the final category we want for conversion
    MixedTypeConversionCategory final_category = initial_category;
    if (initial_category == MIXED_PRECISION_FOLLOW) {
      final_category =
          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
    }

    // Create the new arguments to the call.
    DataType wanted_arg_dtypes =
        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32);
    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
    Array<Expr> new_args = call_args_and_types.first;
    Array<Type> new_arg_types;

    if (pre_call_node->op.as<FunctionNode>()) {
      // Function nodes don't store type info in the Call; type_args should be an empty [].
      new_arg_types = pre_call_node->type_args;
    } else {
      new_arg_types = call_args_and_types.second;
    }

    // Finally create the new attributes.
    if (final_category == MIXED_PRECISION_ALWAYS) {
      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
      if (accumulation_dtype != output_dtype) {
        output = CastArg(output, GetType(output), output_dtype);
      }
      if (pre_call_node == root_ && keep_orig_output_dtype_) {
        if (original_dtype_[0] != output_dtype) {
          output = CastArg(output, GetType(output), original_dtype_[0]);
        }
      }
      return output;
    }

    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
  }

  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
    // The old checked type in the expression may not be valid, so clear it
    post->checked_type_ = Type(nullptr);
    return post;
  }

  Expr Rewrite_(const TupleNode* pre, const Expr& post) {
    // The old checked type in the expression may not be valid, so clear it
    post->checked_type_ = Type(nullptr);
    if (pre == root_ && keep_orig_output_dtype_) {
      Array<Expr> new_expr;
      bool all_same = true;
      for (size_t i = 0; i < original_dtype_.size(); i++) {
        Expr output_element = GetField(post, i);
        Expr casted_element;
        auto output_element_type = transform::InferTypeLocal(output_element);
        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
        new_expr.push_back(casted_element);
        all_same &= casted_element.same_as(output_element);
      }
      if (!all_same) {
        return Tuple(new_expr);
      }
    }
    return post;
  }

  Expr VisitExpr_(const FunctionNode* func) final {
    // Erase the ret_type annotation and let the normal pass recalculate it
    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
    return ExprMutator::VisitExpr_(func);
  }

  Expr VisitExpr_(const LetNode* op) final {
    // First convert as much of the bound computation to lower precision as possible
    Expr value = this->Mutate(op->value);

    // Then rewrite the var type and associated expression
    Var var = Downcast<Var>(this->Mutate(op->var));
    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
    mutable_var->type_annotation = GetType(value);
    mutable_var->checked_type_ = mutable_var->type_annotation;

    // Mutate the body last, as it may depend on previous results
    Expr body = this->Mutate(op->body);
    return Let(var, value, body, op->span);
  }

  // To access the map of unregistered ops for error reporting
  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
                               const DataType& mixed_precision_type, int missing_op_mode);
};

Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
                      const DataType& mixed_precision_type, int missing_op_mode) {
  /*
   missing_op_mode:
    0: Does not allow any missing ops. Will throw errors and terminate the pass when
       encountering any.
    1: Allow missing ops but throw warnings.
    2: Allow missing ops and silently ignore them.
  */
  ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
      << "missing_op_mode must be either 0, 1, or 2, got " << missing_op_mode;

  MixedPrecisionPass converter =
      MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
  auto result = converter.Mutate(expr);

  for (auto it = converter.missing_ops_.begin();
       missing_op_mode != 2 && it != converter.missing_ops_.end(); it++) {
    std::string op_name = it->first;
    int appear_count = it->second;

    LOG(WARNING) << "Op \"" << op_name << "\" is not registered with an "
                 << "FTVMMixedPrecisionConversionType; it appears " << appear_count
                 << " times in the graph.";
  }

  if (!converter.missing_ops_.empty() && missing_op_mode == 0) {
    CHECK(0) << "Missing ops were found!";
  }
  return result;
}

namespace transform {

Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        bool keep_orig_output_dtype =
            pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype", Bool(false)).value();
        return Downcast<Function>(
            ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
      };
  return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
}
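
// Typical usage from the Python API (a hedged sketch):
//   with tvm.transform.PassContext(
//       config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}):
//       mod = relay.transform.ToMixedPrecision("float16")(mod)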

TVM_REGISTER_GLOBAL("relay._transform.ToMixedPrecision").set_body_typed(ToMixedPrecision);

}  // namespace transform

}  // namespace relay
}  // namespace tvm