1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file correlation.cc |
22 | * \brief Correlation operators |
23 | */ |
24 | #include <tvm/relay/attrs/nn.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/topi/nn.h> |
29 | |
30 | #include <vector> |
31 | |
32 | #include "../op_common.h" |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
// relay.nn.correlation
// Registers CorrelationAttrs as a node type via the TVM_REGISTER_NODE_TYPE macro
// (the macro itself is defined outside this file). CorrelationAttrs is the
// attribute container filled in by MakeCorrelation and read by CorrelationRel
// and CorrelationInferCorrectLayout below.
TVM_REGISTER_NODE_TYPE(CorrelationAttrs);
39 | |
40 | InferCorrectLayoutOutput CorrelationInferCorrectLayout( |
41 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
42 | const Array<tvm::relay::Type>& old_in_types) { |
43 | const auto* params = attrs.as<CorrelationAttrs>(); |
44 | Layout layout{params->layout}; |
45 | return InferCorrectLayoutOutput({layout, layout}, {layout}, attrs); |
46 | } |
47 | |
48 | // Positional relay function to create correlation operator |
49 | // used by frontend FFI. |
50 | Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1, |
51 | int stride2, Array<IndexExpr> padding, bool is_multiply, String layout) { |
52 | auto attrs = make_object<CorrelationAttrs>(); |
53 | attrs->kernel_size = kernel_size; |
54 | attrs->max_displacement = max_displacement; |
55 | attrs->stride1 = stride1; |
56 | attrs->stride2 = stride2; |
57 | attrs->padding = std::move(padding); |
58 | attrs->is_multiply = is_multiply; |
59 | attrs->layout = std::move(layout); |
60 | static const Op& op = Op::Get("nn.correlation" ); |
61 | return Call(op, {data1, data2}, Attrs(attrs), {}); |
62 | } |
63 | |
64 | bool CorrelationRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
65 | const TypeReporter& reporter) { |
66 | ICHECK_EQ(types.size(), 3); |
67 | const auto* data1 = types[0].as<TensorTypeNode>(); |
68 | const auto* data2 = types[1].as<TensorTypeNode>(); |
69 | if (data1 == nullptr || data2 == nullptr) return false; |
70 | |
71 | const CorrelationAttrs* param = attrs.as<CorrelationAttrs>(); |
72 | ICHECK(param != nullptr); |
73 | ICHECK_EQ(param->layout, "NCHW" ) << "layout not supported." ; |
74 | IndexExpr pad_h, pad_w; |
75 | GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); |
76 | IndexExpr padded_height = data1->shape[2] + pad_h; |
77 | IndexExpr padded_width = data2->shape[3] + pad_w; |
78 | int kernel_radius = (param->kernel_size - 1) / 2; |
79 | int border_size = param->max_displacement + kernel_radius; |
80 | int displacement_radius = param->max_displacement / param->stride2; |
81 | int displacement_size = 2 * displacement_radius + 1; |
82 | int out_channel = displacement_size * displacement_size; |
83 | IndexExpr out_height = |
84 | indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1); |
85 | IndexExpr out_width = |
86 | indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1); |
87 | Array<tvm::PrimExpr> oshape{data1->shape[0], out_channel, out_height, out_width}; |
88 | // assign output type |
89 | reporter->Assign(types[2], TensorType(oshape, data1->dtype)); |
90 | return true; |
91 | } |
92 | |
// Registers MakeCorrelation under the global function name
// "relay.op.nn._make.correlation" (the positional constructor used by the
// frontend FFI, per the comment on MakeCorrelation).
TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation" ).set_body_typed(MakeCorrelation);
94 | |
// Registration of the "nn.correlation" Relay operator:
//   - attributes type: CorrelationAttrs
//   - two inputs, declared as "data1" and "data2"
//   - support level 2
//   - layout inference: CorrelationInferCorrectLayout
//   - type relation "Correlation": CorrelationRel
//   - op pattern: kOutEWiseFusable
// The text passed to describe() is the operator's description string and is
// runtime data, not a comment.
// NOTE(review): the order of the builder calls is kept exactly as written;
// whether add_type_rel relies on set_num_inputs having run first is not
// visible from this file -- confirm before reordering.
RELAY_REGISTER_OP("nn.correlation" )
    .describe(R"code(Applies correlation to inputs.

The correlation layer performs multiplicative patch comparisons between two feature maps.
Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels,
the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.

For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and
:math:`x_{2}` in the second map is then defined as:

.. math::
c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>

for a square patch of size :math:`K:=2k+1`.

Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other
data. For this reason, it has no training weights.

Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations.

Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
centered around :math:`x_{1}`.

The final output is defined by the following expression:

.. math::
out[n, q, i, j] = c(x_{i, j}, x_{q})

where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<CorrelationAttrs>()
    .set_num_inputs(2)
    .add_argument("data1" , "Tensor" , "Input data1 to the correlation." )
    .add_argument("data2" , "Tensor" , "Input data2 to the correlation." )
    .set_support_level(2)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , CorrelationInferCorrectLayout)
    .add_type_rel("Correlation" , CorrelationRel)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
134 | |
135 | } // namespace relay |
136 | } // namespace tvm |
137 | |