#include <gtest/gtest.h>

#include <ATen/native/quantized/PackedParams.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
using namespace torch::indexing;

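// Test fixture: clear the "must use LLVM" flag so the kernels under test can
// fall back to the simple IR evaluator when the LLVM codegen is unavailable.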
class Quantization : public ::testing::Test {
 public:
  void SetUp() override {
    getTEMustUseLLVMOnCPU() = false;
  }
};

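// Each test below parses a TorchScript graph, computes a reference result
// with the eager quantized ops, compiles the graph with TensorExprKernel,
// and checks that both paths agree. Per-tensor affine quantization maps
//   q = clamp(round(x / scale) + zero_point, qmin, qmax)
// and dequantization recovers x' = (q - zero_point) * scale.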
TEST_F(Quantization, QuantDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %3 : int = prim::Constant[value=13]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

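// Same round-trip with unsigned quint8 (dtype constant 13) and a zero_point
// of 122; inputs are scaled into [0, 2) so they quantize without clamping.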
TEST_F(Quantization, QuantDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

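// Round-trip through a non-contiguous input: sizes [1, 2, 2] with strides
// [4, 1, 2], i.e. a channels-last-style (NLC) memory layout.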
TEST_F(Quantization, QuantDequantUInt8_NLC) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(1, 2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
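  // Rewrite the tensor metadata so x carries the non-contiguous NLC strides
  // declared by the graph input.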
  x.unsafeGetTensorImpl()->set_sizes_and_strides(
      std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
  auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

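// Helper: look up the quantized::add schema in the dispatcher and call it
// through a typed handle.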
at::Tensor quantized_add(
    at::Tensor x1,
    at::Tensor x2,
    double scale,
    int64_t zero) {
  const auto qadd_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::add", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return qadd_op.call(x1, x2, scale, zero);
}

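// Quantize both inputs, add in the quantized domain, then dequantize; the
// compiled kernel must match the eager quantized::add result.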
TEST_F(Quantization, QuantAddDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantAddDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

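// aten::sigmoid on a quantized tensor: the quantized kernel picks its own
// output scale/zero_point, which is why none appear in the graph.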
TEST_F(Quantization, QuantSigmoidDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %qa : QUInt8(2, 2) = aten::sigmoid(%q1)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto qs = at::sigmoid(q1);
  auto y_expected = at::dequantize(qs);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "qs:\n" << qs << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

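// Helper: dispatch to quantized::mul, mirroring quantized_add above.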
at::Tensor quantized_mul(
    at::Tensor x1,
    at::Tensor x2,
    double scale,
    int64_t zero) {
  const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::mul", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return op.call(x1, x2, scale, zero);
}

TEST_F(Quantization, QuantMulDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_mul(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

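// Nearest-neighbor upsampling of a quantized tensor replicates the stored
// integer values; scale and zero_point carry over unchanged.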
TEST_F(Quantization, QuantUpsampleNearest2dDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[6, 6]]()
        %qz : int = prim::Constant[value=13]()
        %qs : float = prim::Constant[value=0.1]()
        %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
        %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4)
        %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qu = at::upsample_nearest2d(q, {6, 6});
  auto y_expected = at::dequantize(qu);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

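// Float baseline for the upsample_nearest2d lowering, with no quantization.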
TEST_F(Quantization, UpsampleNearest2d) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[4, 4]]()
        %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4)
        return (%u))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected = at::upsample_nearest2d(x, {4, 4});

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

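// Helper: call quantized::cat, redispatching directly to the QuantizedCPU
// backend; the output scale and zero_point go through the schema's optionals.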
at::Tensor quantized_cat(
    c10::List<at::Tensor> const& xs,
    int64_t dim,
    double scale,
    int64_t zero) {
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("quantized::cat", "")
                      .typed<at::Tensor(
                          c10::List<at::Tensor> const&,
                          int64_t,
                          c10::optional<double>,
                          c10::optional<int64_t>)>();
  return op.redispatch(
      DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero);
}

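// Concatenate three tensors quantized with different scales/zero_points;
// quantized::cat requantizes all inputs to the requested output qparams
// (here those of %qx).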
TEST_F(Quantization, QuantCatDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %qdt : int = prim::Constant[value=13]()
        %qxz : int = prim::Constant[value=13]()
        %qxs : float = prim::Constant[value=0.1]()
        %qyz : int = prim::Constant[value=16]()
        %qys : float = prim::Constant[value=0.15]()
        %qzz : int = prim::Constant[value=19]()
        %qzs : float = prim::Constant[value=0.2]()
        %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt)
        %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt)
        %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt)
        %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz)
        %catd : int = prim::Constant[value=0]()
        %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz)
        %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat)
        return (%cat))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8);
  auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8);
  auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13);
  auto expected = at::dequantize(qcat);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y, z};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto result = stack[0].toTensor();
  bool check = at::allclose(expected, result);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y:\n" << y << std::endl;
    std::cout << "z:\n" << z << std::endl;
    std::cout << "qx:\n" << qx << std::endl;
    std::cout << "qy:\n" << qy << std::endl;
    std::cout << "qz:\n" << qz << std::endl;
    std::cout << "qcat:\n" << qcat << std::endl;
    std::cout << "expected:\n" << expected << std::endl;
    std::cout << "result:\n" << result << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

} // namespace jit
} // namespace torch