1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/op/contrib/ethosu/binary_elementwise.cc
22 * \brief Binary elementwise operators definitions for the Arm(R) Ethos(TM)-U NPU.
23 */
24#include <tvm/relay/op.h>
25
26#include "common.h"
27#include "op_attrs.h"
28
29namespace tvm {
30namespace relay {
31namespace op {
32namespace contrib {
33namespace ethosu {
34
35bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
36 const TypeReporter& reporter) {
37 const int ifm_index = 0;
38 const int ifm2_index = 1;
39 const int result_index = 3;
40 ICHECK_EQ(types.size(), result_index + 1);
41
42 const auto* ifm = types[ifm_index].as<TensorTypeNode>();
43 const auto* ifm2 = types[ifm2_index].as<TensorTypeNode>();
44 if (ifm == nullptr) return false;
45 if (ifm2 == nullptr) return false;
46
47 const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
48 ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr.";
49
50 const String operator_name = "ethosu_binary_elementwise";
51 const String operator_type = param->operator_type;
52 const DataType ifm_dtype = ifm->dtype;
53 const DataType ifm2_dtype = ifm2->dtype;
54 const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
55
56 CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type);
57
58 if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
59 auto allowed_types = {DataType::Int(8), DataType::UInt(8), DataType::Int(16),
60 DataType::Int(32)};
61 CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
62 CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type);
63 } else if (operator_type == "MIN" || operator_type == "MAX") {
64 auto allowed_types = {DataType::Int(8), DataType::UInt(8)};
65 CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
66 CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type);
67 } else if (operator_type == "SHR") {
68 CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
69 CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)},
70 operator_name, "ofm", operator_type);
71 } else if (operator_type == "SHL") {
72 CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
73 CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type);
74 } else {
75 reporter->GetDiagCtx().EmitFatal(
76 Diagnostic::Error(reporter->GetSpan())
77 << "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or "
78 << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type);
79 return false;
80 }
81
82 // Assign ofm type
83 auto ofm_shape = EthosuInferElementwiseOutputShape(ifm->shape, param->ifm_layout,
84 param->ofm_layout, param->ifm_channels);
85 reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
86 return true;
87}
88
89Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_type,
90 double ifm_scale, int ifm_zero_point, double ifm2_scale,
91 int ifm2_zero_point, double ofm_scale, int ofm_zero_point,
92 IndexExpr ifm_channels, IndexExpr ifm2_channels,
93 bool reversed_operands, String activation, int clip_min,
94 int clip_max, String rounding_mode, String ifm_layout,
95 String ifm2_layout, String ofm_layout, String ofm_dtype,
96 bool use_rescale, int rescale_scale, int rescale_shift) {
97 auto attrs = make_object<EthosuBinaryElementwiseAttrs>();
98
99 attrs->operator_type = std::move(operator_type);
100 attrs->ifm_scale = ifm_scale;
101 attrs->ifm_zero_point = ifm_zero_point;
102 attrs->ifm2_scale = ifm2_scale;
103 attrs->ifm2_zero_point = ifm2_zero_point;
104 attrs->ofm_scale = ofm_scale;
105 attrs->ofm_zero_point = ofm_zero_point;
106 attrs->ifm_channels = std::move(ifm_channels);
107 attrs->ifm2_channels = std::move(ifm2_channels);
108 attrs->reversed_operands = reversed_operands;
109 attrs->activation = std::move(activation);
110 attrs->clip_min = clip_min;
111 attrs->clip_max = clip_max;
112 attrs->rounding_mode = std::move(rounding_mode);
113 attrs->ifm_layout = std::move(ifm_layout);
114 attrs->ifm2_layout = std::move(ifm2_layout);
115 attrs->ofm_layout = std::move(ofm_layout);
116 attrs->ofm_dtype = std::move(ofm_dtype);
117 attrs->use_rescale = use_rescale;
118 attrs->rescale_scale = rescale_scale;
119 attrs->rescale_shift = rescale_shift;
120
121 static const Op& op = Op::Get("contrib.ethosu.binary_elementwise");
122 return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {});
123}
124
// Registers MakeEthosuBinaryElementwise as a typed global function under the name
// "relay.op._make.ethosu_binary_elementwise".
// NOTE(review): presumably looked up by name from the frontend bindings - confirm the caller.
TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise")
    .set_body_typed(MakeEthosuBinaryElementwise);
127
// Registers the "contrib.ethosu.binary_elementwise" Relay operator: its description text,
// attribute type, the three inputs (ifm, ifm2, lut), support level and type relation.
// NOTE(review): keep set_num_inputs ahead of add_type_rel unless the registry API is
// confirmed to be order-independent.
RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise")
    .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise operator.

This Relay operator corresponds to the hardware-implemented quantized
binary elementwise operation found on Ethos(TM)-U NPU. It accepts either NHWC
or NHCWB16 format for the inputs data (input feature maps, or IFMs).

Reference: https://developer.arm.com/documentation/102420/0200/

- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
 NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
 NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
- **ofm**: (1, ofm_height, ofm_width, ifm_channels)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<EthosuBinaryElementwiseAttrs>()
    .set_num_inputs(3)  // ifm, ifm2 and lut, declared in this order below
    .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).")
    .add_argument("ifm2", "Tensor", "The Input Feature Map tensor 2 (IFM2).")
    .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'")
    .set_support_level(11)
    // Type checking and result-type inference are delegated to EthosuBinaryElementwiseRel.
    .add_type_rel("EthosuBinaryElementwise", EthosuBinaryElementwiseRel);
151
152} // namespace ethosu
153} // namespace contrib
154} // namespace op
155} // namespace relay
156} // namespace tvm
157