1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8
#include <array>
#include <stdexcept>
#include <string>
#include <type_traits>
12
13/*
14 * Copied from fbgemm/ConvUtils.h to avoid the dependency on fbgemm
15 */
16template <int N, int... Vals>
17constexpr
18 typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
19 array_of_ones() {
20 return std::array<int, N>{{Vals...}};
21}
22
23template <int N, int... Vals>
24constexpr
25 typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
26 array_of_ones() {
27 return array_of_ones<N, Vals..., 1>();
28}
29
30/**
31 * @brief A struct to conveniently store all convolution parameters.
32 */
33template <int SPATIAL_DIM = 2> struct conv_param_t {
34 int MB; ///< Mini Batch size
35 int IC; ///< Number of Input Channels
36 int OC; ///< Number of Output Channels
37 std::array<int, SPATIAL_DIM> IN_DIM; ///< Input Image Dimension
38 int G; ///< Number of Groups
39 std::array<int, SPATIAL_DIM> K; ///< Filter (Kernel) dimensions
40 std::array<int, SPATIAL_DIM> stride; //< Strides
41 std::array<int, SPATIAL_DIM * 2>
42 pad; //< Padding (first SPATIAL_DIM is for prev/top/left padding, second
43 // SPATIAL_DIM is for next/bottom/right padding)
44 std::array<int, SPATIAL_DIM> dilation; //< Kernel dilation
45
46 // The following are derived parameters
47 std::array<int, SPATIAL_DIM> OUT_DIM; //< Output Image Dimension
48 std::array<int, SPATIAL_DIM> IN_DIMP; //< Input Image Dimension Padded
49
50 /**
51 * @brief Constructor for initializing the convolution parameters.
52 */
53 conv_param_t(
54 int mb, int ic, int oc, std::array<int, SPATIAL_DIM> in_dim, int g,
55 std::array<int, SPATIAL_DIM> k, std::array<int, SPATIAL_DIM> strd,
56 std::array<int, SPATIAL_DIM * 2> pd,
57 std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>())
58 : MB(mb), IC(ic), OC(oc), IN_DIM(in_dim), G(g), K(k), stride(strd),
59 pad(pd), dilation(dilations) {
60 if (ic % g != 0) {
61 throw std::runtime_error(
62 "groups = " + std::to_string(g) +
63 " does not divide number of input channels = " + std::to_string(ic));
64 }
65 if (oc % g != 0) {
66 throw std::runtime_error(
67 "groups = " + std::to_string(g) +
68 " does not divide number of output channels = " + std::to_string(oc));
69 }
70
71 for (int d = 0; d < SPATIAL_DIM; ++d) {
72 IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
73 OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
74 }
75 }
76
77 /**
78 * @brief Helper function to get convolution parameters as string.
79 */
80 std::string toString() const {
81 std::string dim_string[3] = {"T", "H", "W"};
82
83 std::string out = "";
84 out += "MB:" + std::to_string(MB) + ", ";
85 out += "IC:" + std::to_string(IC) + ", ";
86 out += "OC:" + std::to_string(OC) + ", ";
87 if (SPATIAL_DIM <= 3) {
88 for (int d = 0; d < SPATIAL_DIM; ++d) {
89 out += "I" + dim_string[3 - SPATIAL_DIM + d] + ":" +
90 std::to_string(IN_DIM[d]) + ", ";
91 }
92 } else {
93 for (int d = 0; d < SPATIAL_DIM; ++d) {
94 out += "I" + std::to_string(d) + ":" + std::to_string(IN_DIM[d]) + ", ";
95 }
96 }
97 out += "G:" + std::to_string(G) + ", ";
98 if (SPATIAL_DIM <= 3) {
99 for (int d = 0; d < SPATIAL_DIM; ++d) {
100 out += "K" + dim_string[3 - SPATIAL_DIM + d] + ":" +
101 std::to_string(K[d]) + ", ";
102 }
103 for (int d = 0; d < SPATIAL_DIM; ++d) {
104 out += "stride_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
105 std::to_string(stride[d]) + ", ";
106 }
107 for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
108 out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
109 std::to_string(pad[d]) + ", ";
110 }
111 for (int d = 0; d < SPATIAL_DIM; ++d) {
112 out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
113 std::to_string(dilation[d]);
114 if (d < SPATIAL_DIM - 1) {
115 out += ", ";
116 }
117 }
118 } else {
119 for (int d = 0; d < SPATIAL_DIM; ++d) {
120 out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", ";
121 }
122 for (int d = 0; d < SPATIAL_DIM; ++d) {
123 out += "stride_" + std::to_string(d) + ":" + std::to_string(stride[d]) +
124 ", ";
125 }
126 for (int d = 0; d < SPATIAL_DIM; ++d) {
127 out += "pad_" + std::to_string(d) + ":" + std::to_string(pad[d]);
128 if (d < SPATIAL_DIM * 2 - 1) {
129 out += ", ";
130 }
131 }
132 for (int d = 0; d < SPATIAL_DIM; ++d) {
133 out += "dilation_" + std::to_string(d) + ":" +
134 std::to_string(dilation[d]) + ", ";
135 }
136 }
137 return out;
138 }
139};
140
/**
 * @brief A struct to conveniently store all average pooling parameters.
 *
 * @tparam SPATIAL_DIM Number of spatial dimensions; it sizes every
 *         per-dimension array below (the padding array holds twice as many
 *         entries).
 */
template <int SPATIAL_DIM = 2>
struct avg_pool_param_t {
  int MB; ///< Mini Batch size
  int IC; ///< Number of Input Channels
  std::array<int, SPATIAL_DIM> IN_DIM; ///< Input Image Dimension
  std::array<int, SPATIAL_DIM> K; ///< Kernel dimensions
  std::array<int, SPATIAL_DIM> stride; ///< Strides
  std::array<int, SPATIAL_DIM * 2>
      pad; ///< Padding (first SPATIAL_DIM is for prev/top/left padding,
           ///< second SPATIAL_DIM is for next/bottom/right padding)

  // The following are derived parameters
  int OC; ///< Number of Output Channels (always equal to IC)
  std::array<int, SPATIAL_DIM> OUT_DIM; ///< Output Image Dimension
  std::array<int, SPATIAL_DIM> IN_DIMP; ///< Input Image Dimension Padded

  /**
   * @brief Constructor for initializing the average pool parameters.
   *
   * Stores the given parameters, sets OC = IC and derives, for every spatial
   * dimension d:
   *   IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]
   *   OUT_DIM[d] = (IN_DIMP[d] - (K[d] - 1) - 1) / stride[d] + 1
   *
   * @param mb     mini batch size
   * @param ic     number of input channels
   * @param in_dim input image dimensions
   * @param k      kernel dimensions
   * @param strd   strides; every entry must be non-zero
   * @param pd     padding (leading pads first, then trailing pads)
   *
   * @throws std::runtime_error if any stride is zero.
   */
  avg_pool_param_t(
      int mb,
      int ic,
      std::array<int, SPATIAL_DIM> in_dim,
      std::array<int, SPATIAL_DIM> k,
      std::array<int, SPATIAL_DIM> strd,
      std::array<int, SPATIAL_DIM * 2> pd)
      : MB(mb),
        IC(ic),
        IN_DIM(in_dim),
        K(k),
        stride(strd),
        pad(pd),
        OC(ic), // assigned once here instead of on every loop iteration
        OUT_DIM(),
        IN_DIMP() {
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      // A zero stride would divide by zero in the OUT_DIM computation.
      if (stride[d] == 0) {
        throw std::runtime_error(
            "stride[" + std::to_string(d) + "] = 0 is invalid: must be non-zero");
      }
      IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
      OUT_DIM[d] = (IN_DIMP[d] - (K[d] - 1) - 1) / stride[d] + 1;
    }
  }

  /**
   * @brief Helper function to get average pool parameters as string.
   *
   * Fields are emitted as "name:value" pairs separated by ", " in the order
   * MB, IC, OC, input dims, kernel dims, strides, all 2 * SPATIAL_DIM pads.
   * For SPATIAL_DIM <= 3 dimensions are labelled with the trailing letters of
   * T/H/W; otherwise they are labelled by index. No separator follows the
   * last pad entry.
   */
  std::string toString() const {
    std::string out;
    out += "MB:" + std::to_string(MB) + ", ";
    out += "IC:" + std::to_string(IC) + ", ";
    out += "OC:" + std::to_string(OC) + ", ";
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      out += "I" + dimLabel(d) + ":" + std::to_string(IN_DIM[d]) + ", ";
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      out += "K" + dimLabel(d) + ":" + std::to_string(K[d]) + ", ";
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      out += "stride_" + dimLabel(d) + ":" + std::to_string(stride[d]) + ", ";
    }
    // pad holds 2 * SPATIAL_DIM entries; print every one of them.
    const int num_pads = SPATIAL_DIM * 2;
    for (int d = 0; d < num_pads; ++d) {
      out += "pad_" + padLabel(d) + ":" + std::to_string(pad[d]);
      // Pads come last, so no separator after the final entry.
      if (d < num_pads - 1) {
        out += ", ";
      }
    }
    return out;
  }

 private:
  /// Label of spatial dimension d: a letter from T/H/W (right-aligned, so 2D
  /// yields H, W) when SPATIAL_DIM <= 3, the decimal index otherwise.
  static std::string dimLabel(int d) {
    if (SPATIAL_DIM > 3) {
      return std::to_string(d);
    }
    static const char* const kNames[3] = {"T", "H", "W"};
    return kNames[3 - SPATIAL_DIM + d];
  }

  /// Label of pad entry d (0 <= d < 2 * SPATIAL_DIM): the dimension letter
  /// the entry belongs to when SPATIAL_DIM <= 3, the raw entry index otherwise.
  static std::string padLabel(int d) {
    if (SPATIAL_DIM > 3) {
      return std::to_string(d);
    }
    return dimLabel(d % SPATIAL_DIM);
  }
};
225