1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/builtin_fp16.h>

#include <cmath>
#include <cstdint>
#include <string>

#include "pattern_utils.h"
25
26namespace tvm {
27namespace relay {
28
29class DivToMulRewrite : public MixedModeMutator {
30 Expr Rewrite_(const CallNode* pre, const Expr& post) final {
31 if (const CallNode* call_node = post.as<CallNode>()) {
32 if (call_node->op == Op::Get("divide")) {
33 auto rhs = call_node->args[1].as<ConstantNode>();
34 if (rhs != nullptr) {
35 auto inv =
36 runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
37 std::string dtype = DLDataType2String(rhs->data.DataType());
38 if (dtype == "float32") {
39 float rhs_val = static_cast<float*>(rhs->data->data)[0];
40 // Check for division by zero
41 if (rhs_val == 0.) {
42 return post;
43 }
44 static_cast<float*>(inv->data)[0] = 1. / rhs_val;
45 } else if (dtype == "float64") {
46 double rhs_val = static_cast<double*>(rhs->data->data)[0];
47 // Check for division by zero
48 if (rhs_val == 0.) {
49 return post;
50 }
51 static_cast<double*>(inv->data)[0] = 1. / rhs_val;
52 } else if (dtype == "float16") {
53 // Do f16 math in f32
54 float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
55 // Check for division by zero
56 if (rhs_val == 0.) {
57 return post;
58 }
59 static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
60 } else {
61 // Cannot do 1/int because it will truncate
62 return post;
63 }
64 return Multiply(call_node->args[0], Constant(inv));
65 }
66 }
67 }
68 return post;
69 }
70};
71
72namespace transform {
73
74Pass DivToMul() {
75 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
76 [=](Function f, IRModule m, PassContext pc) {
77 return Downcast<Function>(DivToMulRewrite().Mutate(f));
78 };
79 return CreateFunctionPass(pass_func, 0, "DivToMul", {"InferType", "FoldConstant"});
80}
81
82TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul);
83
84} // namespace transform
85} // namespace relay
86} // namespace tvm
87