1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/contrib/ethosu/binary_elementwise.cc |
22 | * \brief Binary elementwise operators definitions for the Arm(R) Ethos(TM)-U NPU. |
23 | */ |
24 | #include <tvm/relay/op.h> |
25 | |
26 | #include "common.h" |
27 | #include "op_attrs.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace op { |
32 | namespace contrib { |
33 | namespace ethosu { |
34 | |
35 | bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
36 | const TypeReporter& reporter) { |
37 | const int ifm_index = 0; |
38 | const int ifm2_index = 1; |
39 | const int result_index = 3; |
40 | ICHECK_EQ(types.size(), result_index + 1); |
41 | |
42 | const auto* ifm = types[ifm_index].as<TensorTypeNode>(); |
43 | const auto* ifm2 = types[ifm2_index].as<TensorTypeNode>(); |
44 | if (ifm == nullptr) return false; |
45 | if (ifm2 == nullptr) return false; |
46 | |
47 | const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>(); |
48 | ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr." ; |
49 | |
50 | const String operator_name = "ethosu_binary_elementwise" ; |
51 | const String operator_type = param->operator_type; |
52 | const DataType ifm_dtype = ifm->dtype; |
53 | const DataType ifm2_dtype = ifm2->dtype; |
54 | const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); |
55 | |
56 | CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm" , "ifm2" , operator_type); |
57 | |
58 | if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL" ) { |
59 | auto allowed_types = {DataType::Int(8), DataType::UInt(8), DataType::Int(16), |
60 | DataType::Int(32)}; |
61 | CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm" , operator_type); |
62 | CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm" , operator_type); |
63 | } else if (operator_type == "MIN" || operator_type == "MAX" ) { |
64 | auto allowed_types = {DataType::Int(8), DataType::UInt(8)}; |
65 | CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm" , operator_type); |
66 | CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm" , "ofm" , operator_type); |
67 | } else if (operator_type == "SHR" ) { |
68 | CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm" , operator_type); |
69 | CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)}, |
70 | operator_name, "ofm" , operator_type); |
71 | } else if (operator_type == "SHL" ) { |
72 | CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm" , operator_type); |
73 | CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm" , operator_type); |
74 | } else { |
75 | reporter->GetDiagCtx().EmitFatal( |
76 | Diagnostic::Error(reporter->GetSpan()) |
77 | << "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or " |
78 | << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type); |
79 | return false; |
80 | } |
81 | |
82 | // Assign ofm type |
83 | auto ofm_shape = EthosuInferElementwiseOutputShape(ifm->shape, param->ifm_layout, |
84 | param->ofm_layout, param->ifm_channels); |
85 | reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype)); |
86 | return true; |
87 | } |
88 | |
89 | Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_type, |
90 | double ifm_scale, int ifm_zero_point, double ifm2_scale, |
91 | int ifm2_zero_point, double ofm_scale, int ofm_zero_point, |
92 | IndexExpr ifm_channels, IndexExpr ifm2_channels, |
93 | bool reversed_operands, String activation, int clip_min, |
94 | int clip_max, String rounding_mode, String ifm_layout, |
95 | String ifm2_layout, String ofm_layout, String ofm_dtype, |
96 | bool use_rescale, int rescale_scale, int rescale_shift) { |
97 | auto attrs = make_object<EthosuBinaryElementwiseAttrs>(); |
98 | |
99 | attrs->operator_type = std::move(operator_type); |
100 | attrs->ifm_scale = ifm_scale; |
101 | attrs->ifm_zero_point = ifm_zero_point; |
102 | attrs->ifm2_scale = ifm2_scale; |
103 | attrs->ifm2_zero_point = ifm2_zero_point; |
104 | attrs->ofm_scale = ofm_scale; |
105 | attrs->ofm_zero_point = ofm_zero_point; |
106 | attrs->ifm_channels = std::move(ifm_channels); |
107 | attrs->ifm2_channels = std::move(ifm2_channels); |
108 | attrs->reversed_operands = reversed_operands; |
109 | attrs->activation = std::move(activation); |
110 | attrs->clip_min = clip_min; |
111 | attrs->clip_max = clip_max; |
112 | attrs->rounding_mode = std::move(rounding_mode); |
113 | attrs->ifm_layout = std::move(ifm_layout); |
114 | attrs->ifm2_layout = std::move(ifm2_layout); |
115 | attrs->ofm_layout = std::move(ofm_layout); |
116 | attrs->ofm_dtype = std::move(ofm_dtype); |
117 | attrs->use_rescale = use_rescale; |
118 | attrs->rescale_scale = rescale_scale; |
119 | attrs->rescale_shift = rescale_shift; |
120 | |
121 | static const Op& op = Op::Get("contrib.ethosu.binary_elementwise" ); |
122 | return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {}); |
123 | } |
124 | |
// Register MakeEthosuBinaryElementwise as a global function named
// "relay.op._make.ethosu_binary_elementwise".
// NOTE(review): presumably this is the entry point used by the frontend bindings to
// construct the operator call - confirm against the caller.
TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise" )
    .set_body_typed(MakeEthosuBinaryElementwise);
127 | |
// Register the Relay operator "contrib.ethosu.binary_elementwise".
// The registration attaches the description below, uses EthosuBinaryElementwiseAttrs as
// the attribute type, declares three inputs (ifm, ifm2, lut), sets support level 11 and
// installs EthosuBinaryElementwiseRel as the type relation.
RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise" )
    .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise operator.

This Relay operator corresponds to the hardware-implemented quantized
binary elementwise operation found on Ethos(TM)-U NPU. It accepts either NHWC
or NHCWB16 format for the inputs data (input feature maps, or IFMs).

Reference: https://developer.arm.com/documentation/102420/0200/

- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
- **ofm**: (1, ofm_height, ofm_width, ifm_channels)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<EthosuBinaryElementwiseAttrs>()
    .set_num_inputs(3)
    .add_argument("ifm" , "Tensor" , "The Input Feature Map tensor (IFM)." )
    .add_argument("ifm2" , "Tensor" , "The Input Feature Map tensor 2 (IFM2)." )
    .add_argument("lut" , "Tensor" , "The look-up table of values to use if activation = 'LUT'" )
    .set_support_level(11)
    .add_type_rel("EthosuBinaryElementwise" , EthosuBinaryElementwiseRel);
151 | |
152 | } // namespace ethosu |
153 | } // namespace contrib |
154 | } // namespace op |
155 | } // namespace relay |
156 | } // namespace tvm |
157 | |