1#pragma once
2
3#include <c10/core/DeviceType.h>
4#include <c10/util/Exception.h>
5
6namespace c10 {
7
/**
 * QScheme enumerates the kinds of quantization that are supported. Each
 * value corresponds one-to-one with a Quantizer; see
 * ATen/quantized/Quantizer.h for the Quantizer classes.
 *
 * The numeric values are spelled out explicitly. Keep this file in sync
 * with torch/nn/_qscheme.py.
 */
enum class QScheme : uint8_t {
  PER_TENSOR_AFFINE = 0,
  PER_CHANNEL_AFFINE = 1,
  PER_TENSOR_SYMMETRIC = 2,
  PER_CHANNEL_SYMMETRIC = 3,
  PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4,
  // Sentinel: equals the number of schemes listed above. It is not a
  // quantization scheme itself.
  COMPILE_TIME_NUM_QSCHEMES = 5,
};
22
23constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
24constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
25constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
26constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
27constexpr auto kPerChannelAffineFloatQParams =
28 QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
29constexpr int COMPILE_TIME_NUM_QSCHEMES =
30 static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
31
32inline std::string toString(QScheme qscheme) {
33 switch (qscheme) {
34 case kPerTensorAffine:
35 return "per_tensor_affine";
36 case kPerChannelAffine:
37 return "per_channel_affine";
38 case kPerTensorSymmetric:
39 return "per_tensor_symmetric";
40 case kPerChannelSymmetric:
41 return "per_channel_symmetric";
42 case kPerChannelAffineFloatQParams:
43 return "per_channel_affine_float_qparams";
44 default:
45 TORCH_CHECK(false, "Unrecognized qscheme: ", static_cast<int>(qscheme));
46 }
47}
48
49} // namespace c10
50