1#include <gtest/gtest.h>
2
3#include <torch/torch.h>
4
5#include <test/cpp/api/support.h>
6
// Pretty-print check for a single torch enum tag.
//
// Assigns torch::k<name> to a variable named `v` that must already be
// declared in the enclosing scope, then asserts that
// torch::enumtype::get_enum_name(v) returns the string "k<name>".
// The expansion is a braced block, so call sites need no trailing semicolon.
#define TORCH_ENUM_PRETTY_PRINT_TEST(name)                          \
  {                                                                 \
    v = torch::k##name;                                             \
    const std::string expected_enum_name("k" #name);                \
    ASSERT_EQ(torch::enumtype::get_enum_name(v), expected_enum_name); \
  }
14
15TEST(EnumTest, AllEnums) {
16 c10::variant<
17 torch::enumtype::kLinear,
18 torch::enumtype::kConv1D,
19 torch::enumtype::kConv2D,
20 torch::enumtype::kConv3D,
21 torch::enumtype::kConvTranspose1D,
22 torch::enumtype::kConvTranspose2D,
23 torch::enumtype::kConvTranspose3D,
24 torch::enumtype::kSigmoid,
25 torch::enumtype::kTanh,
26 torch::enumtype::kReLU,
27 torch::enumtype::kLeakyReLU,
28 torch::enumtype::kFanIn,
29 torch::enumtype::kFanOut,
30 torch::enumtype::kConstant,
31 torch::enumtype::kReflect,
32 torch::enumtype::kReplicate,
33 torch::enumtype::kCircular,
34 torch::enumtype::kNearest,
35 torch::enumtype::kBilinear,
36 torch::enumtype::kBicubic,
37 torch::enumtype::kTrilinear,
38 torch::enumtype::kArea,
39 torch::enumtype::kSum,
40 torch::enumtype::kMean,
41 torch::enumtype::kMax,
42 torch::enumtype::kNone,
43 torch::enumtype::kBatchMean,
44 torch::enumtype::kZeros,
45 torch::enumtype::kBorder,
46 torch::enumtype::kReflection,
47 torch::enumtype::kRNN_TANH,
48 torch::enumtype::kRNN_RELU,
49 torch::enumtype::kLSTM,
50 torch::enumtype::kGRU>
51 v;
52
53 TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
54 TORCH_ENUM_PRETTY_PRINT_TEST(Conv1D)
55 TORCH_ENUM_PRETTY_PRINT_TEST(Conv2D)
56 TORCH_ENUM_PRETTY_PRINT_TEST(Conv3D)
57 TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose1D)
58 TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose2D)
59 TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose3D)
60 TORCH_ENUM_PRETTY_PRINT_TEST(Sigmoid)
61 TORCH_ENUM_PRETTY_PRINT_TEST(Tanh)
62 TORCH_ENUM_PRETTY_PRINT_TEST(ReLU)
63 TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU)
64 TORCH_ENUM_PRETTY_PRINT_TEST(FanIn)
65 TORCH_ENUM_PRETTY_PRINT_TEST(FanOut)
66 TORCH_ENUM_PRETTY_PRINT_TEST(Constant)
67 TORCH_ENUM_PRETTY_PRINT_TEST(Reflect)
68 TORCH_ENUM_PRETTY_PRINT_TEST(Replicate)
69 TORCH_ENUM_PRETTY_PRINT_TEST(Circular)
70 TORCH_ENUM_PRETTY_PRINT_TEST(Nearest)
71 TORCH_ENUM_PRETTY_PRINT_TEST(Bilinear)
72 TORCH_ENUM_PRETTY_PRINT_TEST(Bicubic)
73 TORCH_ENUM_PRETTY_PRINT_TEST(Trilinear)
74 TORCH_ENUM_PRETTY_PRINT_TEST(Area)
75 TORCH_ENUM_PRETTY_PRINT_TEST(Sum)
76 TORCH_ENUM_PRETTY_PRINT_TEST(Mean)
77 TORCH_ENUM_PRETTY_PRINT_TEST(Max)
78 TORCH_ENUM_PRETTY_PRINT_TEST(None)
79 TORCH_ENUM_PRETTY_PRINT_TEST(BatchMean)
80 TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
81 TORCH_ENUM_PRETTY_PRINT_TEST(Border)
82 TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
83 TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
84 TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
85 TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
86 TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
87}
88