1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/builtin_fp16.h>

#include <cmath>
#include <cstdint>
#include <string>

#include "pattern_utils.h"
25 | |
26 | namespace tvm { |
27 | namespace relay { |
28 | |
29 | class DivToMulRewrite : public MixedModeMutator { |
30 | Expr Rewrite_(const CallNode* pre, const Expr& post) final { |
31 | if (const CallNode* call_node = post.as<CallNode>()) { |
32 | if (call_node->op == Op::Get("divide" )) { |
33 | auto rhs = call_node->args[1].as<ConstantNode>(); |
34 | if (rhs != nullptr) { |
35 | auto inv = |
36 | runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device); |
37 | std::string dtype = DLDataType2String(rhs->data.DataType()); |
38 | if (dtype == "float32" ) { |
39 | float rhs_val = static_cast<float*>(rhs->data->data)[0]; |
40 | // Check for division by zero |
41 | if (rhs_val == 0.) { |
42 | return post; |
43 | } |
44 | static_cast<float*>(inv->data)[0] = 1. / rhs_val; |
45 | } else if (dtype == "float64" ) { |
46 | double rhs_val = static_cast<double*>(rhs->data->data)[0]; |
47 | // Check for division by zero |
48 | if (rhs_val == 0.) { |
49 | return post; |
50 | } |
51 | static_cast<double*>(inv->data)[0] = 1. / rhs_val; |
52 | } else if (dtype == "float16" ) { |
53 | // Do f16 math in f32 |
54 | float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]); |
55 | // Check for division by zero |
56 | if (rhs_val == 0.) { |
57 | return post; |
58 | } |
59 | static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val); |
60 | } else { |
61 | // Cannot do 1/int because it will truncate |
62 | return post; |
63 | } |
64 | return Multiply(call_node->args[0], Constant(inv)); |
65 | } |
66 | } |
67 | } |
68 | return post; |
69 | } |
70 | }; |
71 | |
72 | namespace transform { |
73 | |
74 | Pass DivToMul() { |
75 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
76 | [=](Function f, IRModule m, PassContext pc) { |
77 | return Downcast<Function>(DivToMulRewrite().Mutate(f)); |
78 | }; |
79 | return CreateFunctionPass(pass_func, 0, "DivToMul" , {"InferType" , "FoldConstant" }); |
80 | } |
81 | |
// Registers DivToMul() in TVM's global function registry under the name
// "relay._transform.DivToMul" (presumably looked up by the Python
// relay.transform bindings — TODO confirm).
TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul);
83 | |
84 | } // namespace transform |
85 | } // namespace relay |
86 | } // namespace tvm |
87 | |