1 | #include <gtest/gtest.h> |
2 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
3 | #include <torch/csrc/jit/tensorexpr/llvm_codegen.h> |
4 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
5 | #include <torch/csrc/jit/tensorexpr/operators/conv2d.h> |
6 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
7 | #include <torch/torch.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | namespace te = torch::jit::tensorexpr; |
13 | namespace F = torch::nn::functional; |
14 | |
15 | #ifdef TORCH_ENABLE_LLVM |
16 | |
17 | // Generate test data with few bits of precision, to minimize error |
18 | // accumulation from floating-point reordering. |
19 | static at::Tensor genTestData(c10::IntArrayRef args) { |
20 | return at::trunc(at::randn(args) * 256.0f) / 256.0f; |
21 | } |
22 | |
23 | TEST(Conv, DepthwiseConv2D) { |
24 | constexpr int N = 1, C = 72, H = 56, W = 56; |
25 | constexpr int K = 72, R = 3, S = 3; |
26 | constexpr int kPad = 1, kStride = 2, kGroups = C; |
27 | constexpr int CperG = C / kGroups; |
28 | |
29 | te::BufHandle input("input" , {N, C, H, W}, te::kFloat); |
30 | te::BufHandle weight("weight" , {K, CperG, R, S}, te::kFloat); |
31 | te::BufHandle bias("bias" , {K}, te::kFloat); |
32 | te::Tensor output = |
33 | te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); |
34 | |
35 | te::LoopNest loop({output}); |
36 | loop.simplify(); |
37 | loop.prepareForCodegen(); |
38 | te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); |
39 | |
40 | auto it = genTestData({N, C, H, W}); |
41 | auto wt = genTestData({K, CperG, R, S}); |
42 | auto bt = genTestData({K}); |
43 | auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); |
44 | auto ot = at::zeros_like(ref); |
45 | cg.call( |
46 | {it.data_ptr<float>(), |
47 | wt.data_ptr<float>(), |
48 | bt.data_ptr<float>(), |
49 | ot.data_ptr<float>()}); |
50 | |
51 | ASSERT_TRUE(at::allclose(ref, ot)); |
52 | } |
53 | |
54 | TEST(Conv, DepthwiseConv2DNoBias) { |
55 | constexpr int N = 1, C = 72, H = 56, W = 56; |
56 | constexpr int K = 72, R = 3, S = 3; |
57 | constexpr int kPad = 1, kStride = 2, kGroups = C; |
58 | constexpr int CperG = C / kGroups; |
59 | |
60 | te::BufHandle input("input" , {N, C, H, W}, te::kFloat); |
61 | te::BufHandle weight("weight" , {K, CperG, R, S}, te::kFloat); |
62 | te::Tensor output = |
63 | te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); |
64 | |
65 | te::LoopNest loop({output}); |
66 | loop.simplify(); |
67 | loop.prepareForCodegen(); |
68 | te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); |
69 | |
70 | auto it = genTestData({N, C, H, W}); |
71 | auto wt = genTestData({K, CperG, R, S}); |
72 | auto ref = |
73 | at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); |
74 | auto ot = at::zeros_like(ref); |
75 | cg.call({it.data_ptr<float>(), wt.data_ptr<float>(), ot.data_ptr<float>()}); |
76 | |
77 | ASSERT_TRUE(at::allclose(ref, ot)); |
78 | } |
79 | |
80 | TEST(Conv, DepthwiseConv2DDynamicShapes) { |
81 | te::VarHandle N_var("N" , te::kInt); |
82 | te::VarHandle C_var("C" , te::kInt); |
83 | te::VarHandle H_var("H" , te::kInt); |
84 | te::VarHandle W_var("W" , te::kInt); |
85 | te::VarHandle K_var("K" , te::kInt); |
86 | te::VarHandle CperG_var("CperG" , te::kInt); |
87 | te::VarHandle R_var("R" , te::kInt); |
88 | te::VarHandle S_var("S" , te::kInt); |
89 | te::VarHandle kPad_var("kPad" , te::kInt); |
90 | te::VarHandle kStride_var("kStride" , te::kInt); |
91 | te::VarHandle kGroups_var("kGroups" , te::kInt); |
92 | |
93 | te::BufHandle input("input" , {N_var, C_var, H_var, W_var}, te::kFloat); |
94 | te::BufHandle weight("weight" , {K_var, CperG_var, R_var, S_var}, te::kFloat); |
95 | te::Tensor output = te::conv2d_depthwise( |
96 | input, |
97 | weight, |
98 | N_var, |
99 | C_var, |
100 | H_var, |
101 | W_var, |
102 | K_var, |
103 | CperG_var, |
104 | R_var, |
105 | S_var, |
106 | kStride_var, |
107 | kPad_var, |
108 | kGroups_var); |
109 | |
110 | te::LoopNest loop({output}); |
111 | loop.simplify(); |
112 | loop.prepareForCodegen(); |
113 | std::vector<te::CodeGen::BufferArg> buffer_args = { |
114 | input, |
115 | weight, |
116 | N_var, |
117 | C_var, |
118 | H_var, |
119 | W_var, |
120 | K_var, |
121 | CperG_var, |
122 | R_var, |
123 | S_var, |
124 | kPad_var, |
125 | kStride_var, |
126 | kGroups_var, |
127 | output}; |
128 | te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); |
129 | |
130 | constexpr int N = 1, C = 72, H = 56, W = 56; |
131 | constexpr int K = 72, R = 3, S = 3; |
132 | constexpr int kPad = 1, kStride = 2, kGroups = C; |
133 | constexpr int CperG = C / kGroups; |
134 | |
135 | auto it = genTestData({N, C, H, W}); |
136 | auto wt = genTestData({K, CperG, R, S}); |
137 | auto ref = |
138 | at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); |
139 | auto ot = at::zeros_like(ref); |
140 | std::vector<te::CodeGen::CallArg> call_args = { |
141 | it.data_ptr<float>(), |
142 | wt.data_ptr<float>(), |
143 | N, |
144 | C, |
145 | H, |
146 | W, |
147 | K, |
148 | CperG, |
149 | R, |
150 | S, |
151 | kPad, |
152 | kStride, |
153 | kGroups, |
154 | ot.data_ptr<float>()}; |
155 | cg.call(call_args); |
156 | |
157 | ASSERT_TRUE(at::allclose(ref, ot)); |
158 | } |
159 | |
160 | #endif |
161 | |
162 | TEST(Conv, Conv2D) { |
163 | // Input dimensions. |
164 | constexpr int N = 1; |
165 | constexpr int C = 3; |
166 | constexpr int H = 11; |
167 | constexpr int W = 11; |
168 | |
169 | // Filter dimensions. |
170 | constexpr int K = 8; |
171 | constexpr int R = 3; |
172 | constexpr int S = 3; |
173 | |
174 | // Output dims. |
175 | constexpr int OH = H - R + 1; |
176 | constexpr int OW = W - S + 1; |
177 | |
178 | // Compute reference result. |
179 | at::Tensor input = torch::randn({N, C, H, W}); |
180 | at::Tensor filter = torch::randn({K, C, R, S}); |
181 | at::Tensor ref = F::conv2d(input, filter); |
182 | |
183 | // Double check the output size is as expected. |
184 | ASSERT_EQ(ref.size(0), N); |
185 | ASSERT_EQ(ref.size(1), K); |
186 | ASSERT_EQ(ref.size(2), OH); |
187 | ASSERT_EQ(ref.size(3), OW); |
188 | |
189 | te::BufHandle inputB("input" , {N, C, H, W}, te::kFloat); |
190 | te::BufHandle filterB("filter" , {K, C, R, S}, te::kFloat); |
191 | |
192 | te::Tensor conv = te::Reduce( |
193 | "conv" , |
194 | {N, K, OH, OW}, |
195 | te::Sum(), |
196 | // FIXME: We have to use a `std::vector` parameter here and then unpack |
197 | // it, because we don't have an overload allowing for an arbitrary number |
198 | // of ExprHandle/VarHandle parameters. |
199 | [&](const std::vector<te::VarHandle>& v) { |
200 | auto const& n = v[0]; |
201 | auto const& k = v[1]; |
202 | auto const& oh = v[2]; |
203 | auto const& ow = v[3]; |
204 | auto const& c = v[4]; |
205 | auto const& r = v[5]; |
206 | auto const& s = v[6]; |
207 | // FIXME: We have to use `call` and construct a `std::vector` here |
208 | // because the `operator()` overload is only specialized for a small |
209 | // number of arguments. |
210 | return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); |
211 | }, |
212 | // FIXME: If you forget one of the reduction dims, you get a segfault. |
213 | // Could that be caught by a verifier? |
214 | {C, R, S}); |
215 | |
216 | // FIXME: It'd be nice to have a single header that pulls in things like |
217 | // LoopNest, IRSimplifier, etc. |
218 | te::LoopNest loop({conv}); |
219 | loop.prepareForCodegen(); |
220 | te::StmtPtr s = loop.root_stmt(); |
221 | s = te::IRSimplifier::simplify(s); |
222 | |
223 | at::Tensor result = at::empty_like(ref); |
224 | te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); |
225 | cg.call( |
226 | {input.data_ptr<float>(), |
227 | filter.data_ptr<float>(), |
228 | result.data_ptr<float>()}); |
229 | |
230 | ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); |
231 | } |
232 | |
233 | } // namespace jit |
234 | } // namespace torch |
235 | |