/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Local response normalization op constructions
 * \file nn/local_response_norm.h
 */
#ifndef TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_
#define TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_
#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Local response normalization inference operator
 *
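 * For axis = 1 (NCHW), the result is (a sketch matching the code below):
 *
 *   out(n, c, h, w) = data(n, c, h, w) /
 *       pow(bias + (alpha / size) *
 *               sum(data(n, j, h, w)^2 for j in [c - size/2, c + size/2]),
 *           beta)
 *
 * with zero padding at the channel borders; axis = 3 (NHWC) applies the
 * same windowed reduction over the last dimension.
 *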
 * \param data The input tensor. 4-D, in NCHW or NHWC layout
 * \param size Width of the normalization window; must be odd
 * \param axis Channel axis of the input layout: 1 for NCHW, 3 for NHWC
 * \param alpha Scaling factor
 * \param beta Exponent applied to the normalization term
 * \param bias Offset that keeps the denominator away from zero
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the local response normalization operation
 */
inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001,
                  float beta = 0.75, float bias = 2, std::string name = "tensor",
                  std::string tag = kBroadcast) {
  ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
  ICHECK_EQ(size % 2, 1) << "size should be an odd number";
  ICHECK(axis == 1 || axis == 3) << "axis should be 1 (NCHW) or 3 (NHWC)";
  ICHECK(data->dtype.is_float()) << "datatype should be float";
  auto input_shape = data->shape;
  // Zero-pad the channel axis by half the window so the reduction below
  // stays in bounds at the channel borders.
  Array<PrimExpr> pad_before{0, 0, 0, 0};
  Array<PrimExpr> pad_after{0, 0, 0, 0};
  pad_before.Set(axis, static_cast<PrimExpr>(size / 2));
  pad_after.Set(axis, static_cast<PrimExpr>(size / 2));
  auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
  auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs");
  // Sum of squares over the size-wide window centered on each channel.
  Tensor sqr_sum;
  if (axis == 1) {
    sqr_sum = tvm::te::compute(
        input_shape,
        [&](Var n, Var c, Var h, Var w) {
          return tvm::sum(pad_data(n, c + rxs, h, w) * pad_data(n, c + rxs, h, w), {rxs});
        },
        "tensor", "sqr_sum");
  } else if (axis == 3) {
    sqr_sum = tvm::te::compute(
        input_shape,
        [&](Var n, Var h, Var w, Var c) {
          return tvm::sum(pad_data(n, h, w, c + rxs) * pad_data(n, h, w, c + rxs), {rxs});
        },
        "tensor", "sqr_sum");
  }
  PrimExpr alpha_imm = tvm::te::make_const(data->dtype, alpha);
  PrimExpr beta_imm = tvm::te::make_const(data->dtype, beta);
  PrimExpr bias_imm = tvm::te::make_const(data->dtype, bias);
  // Denominator: pow(bias + alpha / size * sqr_sum, beta), element-wise.
  auto sqrt_sum_up = tvm::te::compute(
      input_shape,
      [&](Var i, Var j, Var k, Var l) {
        return tvm::pow(bias_imm + (div(alpha_imm * sqr_sum(i, j, k, l), size)), beta_imm);
      },
      "tensor", kElementWise);
  return topi::divide(data, sqrt_sum_up);
}
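
// Usage sketch (illustrative only; the placeholder shape and the
// AlexNet-style hyper-parameters below are assumptions, not taken from
// this header):
//
//   // 4-D float32 NCHW input; axis = 1 selects the channel dimension.
//   tvm::te::Tensor data =
//       tvm::te::placeholder({1, 64, 56, 56}, tvm::DataType::Float(32), "data");
//   tvm::te::Tensor out = tvm::topi::nn::lrn(data, /*size=*/5, /*axis=*/1,
//                                            /*alpha=*/0.0001f, /*beta=*/0.75f,
//                                            /*bias=*/2.0f);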
}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_LOCAL_RESPONSE_NORM_H_