1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
// Assigns torch::k<name> to a variant named `v` that must already exist in the
// calling scope, then asserts that torch::enumtype::get_enum_name(v) equals
// the string "k<name>" (e.g. name=ReLU -> "kReLU").
// The expansion is a braced block, so invocations take no trailing semicolon.
#define TORCH_ENUM_PRETTY_PRINT_TEST(name)                        \
  {                                                               \
    const std::string expected_enum_name{"k" #name};              \
    v = torch::k##name;                                           \
    ASSERT_EQ(torch::enumtype::get_enum_name(v), expected_enum_name); \
  }
14 | |
15 | TEST(EnumTest, AllEnums) { |
16 | c10::variant< |
17 | torch::enumtype::kLinear, |
18 | torch::enumtype::kConv1D, |
19 | torch::enumtype::kConv2D, |
20 | torch::enumtype::kConv3D, |
21 | torch::enumtype::kConvTranspose1D, |
22 | torch::enumtype::kConvTranspose2D, |
23 | torch::enumtype::kConvTranspose3D, |
24 | torch::enumtype::kSigmoid, |
25 | torch::enumtype::kTanh, |
26 | torch::enumtype::kReLU, |
27 | torch::enumtype::kLeakyReLU, |
28 | torch::enumtype::kFanIn, |
29 | torch::enumtype::kFanOut, |
30 | torch::enumtype::kConstant, |
31 | torch::enumtype::kReflect, |
32 | torch::enumtype::kReplicate, |
33 | torch::enumtype::kCircular, |
34 | torch::enumtype::kNearest, |
35 | torch::enumtype::kBilinear, |
36 | torch::enumtype::kBicubic, |
37 | torch::enumtype::kTrilinear, |
38 | torch::enumtype::kArea, |
39 | torch::enumtype::kSum, |
40 | torch::enumtype::kMean, |
41 | torch::enumtype::kMax, |
42 | torch::enumtype::kNone, |
43 | torch::enumtype::kBatchMean, |
44 | torch::enumtype::kZeros, |
45 | torch::enumtype::kBorder, |
46 | torch::enumtype::kReflection, |
47 | torch::enumtype::kRNN_TANH, |
48 | torch::enumtype::kRNN_RELU, |
49 | torch::enumtype::kLSTM, |
50 | torch::enumtype::kGRU> |
51 | v; |
52 | |
53 | TORCH_ENUM_PRETTY_PRINT_TEST(Linear) |
54 | TORCH_ENUM_PRETTY_PRINT_TEST(Conv1D) |
55 | TORCH_ENUM_PRETTY_PRINT_TEST(Conv2D) |
56 | TORCH_ENUM_PRETTY_PRINT_TEST(Conv3D) |
57 | TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose1D) |
58 | TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose2D) |
59 | TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose3D) |
60 | TORCH_ENUM_PRETTY_PRINT_TEST(Sigmoid) |
61 | TORCH_ENUM_PRETTY_PRINT_TEST(Tanh) |
62 | TORCH_ENUM_PRETTY_PRINT_TEST(ReLU) |
63 | TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU) |
64 | TORCH_ENUM_PRETTY_PRINT_TEST(FanIn) |
65 | TORCH_ENUM_PRETTY_PRINT_TEST(FanOut) |
66 | TORCH_ENUM_PRETTY_PRINT_TEST(Constant) |
67 | TORCH_ENUM_PRETTY_PRINT_TEST(Reflect) |
68 | TORCH_ENUM_PRETTY_PRINT_TEST(Replicate) |
69 | TORCH_ENUM_PRETTY_PRINT_TEST(Circular) |
70 | TORCH_ENUM_PRETTY_PRINT_TEST(Nearest) |
71 | TORCH_ENUM_PRETTY_PRINT_TEST(Bilinear) |
72 | TORCH_ENUM_PRETTY_PRINT_TEST(Bicubic) |
73 | TORCH_ENUM_PRETTY_PRINT_TEST(Trilinear) |
74 | TORCH_ENUM_PRETTY_PRINT_TEST(Area) |
75 | TORCH_ENUM_PRETTY_PRINT_TEST(Sum) |
76 | TORCH_ENUM_PRETTY_PRINT_TEST(Mean) |
77 | TORCH_ENUM_PRETTY_PRINT_TEST(Max) |
78 | TORCH_ENUM_PRETTY_PRINT_TEST(None) |
79 | TORCH_ENUM_PRETTY_PRINT_TEST(BatchMean) |
80 | TORCH_ENUM_PRETTY_PRINT_TEST(Zeros) |
81 | TORCH_ENUM_PRETTY_PRINT_TEST(Border) |
82 | TORCH_ENUM_PRETTY_PRINT_TEST(Reflection) |
83 | TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH) |
84 | TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU) |
85 | TORCH_ENUM_PRETTY_PRINT_TEST(LSTM) |
86 | TORCH_ENUM_PRETTY_PRINT_TEST(GRU) |
87 | } |
88 | |