1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file correlation.cc
22 * \brief Correlation operators
23 */
24#include <tvm/relay/attrs/nn.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/data_layout.h>
27#include <tvm/tir/op.h>
28#include <tvm/topi/nn.h>
29
30#include <vector>
31
32#include "../op_common.h"
33
34namespace tvm {
35namespace relay {
36
// relay.nn.correlation
// Register CorrelationAttrs as a node type so the attribute object used by
// the operator below is known to the TVM node system.
TVM_REGISTER_NODE_TYPE(CorrelationAttrs);
39
40InferCorrectLayoutOutput CorrelationInferCorrectLayout(
41 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
42 const Array<tvm::relay::Type>& old_in_types) {
43 const auto* params = attrs.as<CorrelationAttrs>();
44 Layout layout{params->layout};
45 return InferCorrectLayoutOutput({layout, layout}, {layout}, attrs);
46}
47
48// Positional relay function to create correlation operator
49// used by frontend FFI.
50Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1,
51 int stride2, Array<IndexExpr> padding, bool is_multiply, String layout) {
52 auto attrs = make_object<CorrelationAttrs>();
53 attrs->kernel_size = kernel_size;
54 attrs->max_displacement = max_displacement;
55 attrs->stride1 = stride1;
56 attrs->stride2 = stride2;
57 attrs->padding = std::move(padding);
58 attrs->is_multiply = is_multiply;
59 attrs->layout = std::move(layout);
60 static const Op& op = Op::Get("nn.correlation");
61 return Call(op, {data1, data2}, Attrs(attrs), {});
62}
63
64bool CorrelationRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
65 const TypeReporter& reporter) {
66 ICHECK_EQ(types.size(), 3);
67 const auto* data1 = types[0].as<TensorTypeNode>();
68 const auto* data2 = types[1].as<TensorTypeNode>();
69 if (data1 == nullptr || data2 == nullptr) return false;
70
71 const CorrelationAttrs* param = attrs.as<CorrelationAttrs>();
72 ICHECK(param != nullptr);
73 ICHECK_EQ(param->layout, "NCHW") << "layout not supported.";
74 IndexExpr pad_h, pad_w;
75 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
76 IndexExpr padded_height = data1->shape[2] + pad_h;
77 IndexExpr padded_width = data2->shape[3] + pad_w;
78 int kernel_radius = (param->kernel_size - 1) / 2;
79 int border_size = param->max_displacement + kernel_radius;
80 int displacement_radius = param->max_displacement / param->stride2;
81 int displacement_size = 2 * displacement_radius + 1;
82 int out_channel = displacement_size * displacement_size;
83 IndexExpr out_height =
84 indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1);
85 IndexExpr out_width =
86 indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1);
87 Array<tvm::PrimExpr> oshape{data1->shape[0], out_channel, out_height, out_width};
88 // assign output type
89 reporter->Assign(types[2], TensorType(oshape, data1->dtype));
90 return true;
91}
92
// Expose MakeCorrelation in the global function registry under the name
// "relay.op.nn._make.correlation" so it can be called through the FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation").set_body_typed(MakeCorrelation);
94
// Operator registration for nn.correlation: two tensor inputs (data1, data2),
// CorrelationAttrs as the attribute type, CorrelationRel as the type relation,
// CorrelationInferCorrectLayout for layout inference, and the kOutEWiseFusable
// fusion pattern. The raw string below is the operator's description text.
RELAY_REGISTER_OP("nn.correlation")
    .describe(R"code(Applies correlation to inputs.

The correlation layer performs multiplicative patch comparisons between two feature maps.
Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels,
the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.

For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and
:math:`x_{2}` in the second map is then defined as:

.. math::
 c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>

for a square patch of size :math:`K:=2k+1`.

Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other
data. For this reason, it has no training weights.

Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations.

Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
centered around :math:`x_{1}`.

The final output is defined by the following expression:

.. math::
 out[n, q, i, j] = c(x_{i, j}, x_{q})

where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<CorrelationAttrs>()
    .set_num_inputs(2)
    .add_argument("data1", "Tensor", "Input data1 to the correlation.")
    .add_argument("data2", "Tensor", "Input data2 to the correlation.")
    .set_support_level(2)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", CorrelationInferCorrectLayout)
    .add_type_rel("Correlation", CorrelationRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
134
135} // namespace relay
136} // namespace tvm
137