#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/jit.h>

namespace torch {
namespace jit {

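// When no schema string is supplied, the schema is inferred from the C++
// signature of the kernel: arguments get positional names ("_0", "_1", ...)
// and C++ types map to schema types (double -> float, at::Tensor -> Tensor).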
TEST(CustomOperatorTest, InferredSchema) {
  torch::RegisterOperators reg(
      "foo::bar", [](double a, at::Tensor b) { return a + b; });
  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foo::bar");

  ASSERT_EQ(op->schema().arguments().size(), 2);
  ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

  Stack stack;
  push(stack, 2.0f, at::ones(5));
  op->getOperation()(stack);
  at::Tensor output;
  pop(stack, output);

  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
}

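// An explicit schema string preserves the declared argument names ("a", "b")
// and the declared return arity, instead of the inferred defaults above.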
TEST(CustomOperatorTest, ExplicitSchema) {
  torch::RegisterOperators reg(
      "foo::bar_with_schema(float a, Tensor b) -> Tensor",
      [](double a, at::Tensor b) { return a + b; });

  auto& ops =
      getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");

  ASSERT_EQ(op->schema().arguments().size(), 2);
  ASSERT_EQ(op->schema().arguments()[0].name(), "a");
  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(op->schema().arguments()[1].name(), "b");
  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

  Stack stack;
  push(stack, 2.0f, at::ones(5));
  op->getOperation()(stack);
  at::Tensor output;
  pop(stack, output);

  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
}

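// Schema list types (int[], float[], complex[], Tensor[]) map to
// torch::List<int64_t>, torch::List<double>,
// torch::List<c10::complex<double>> and torch::List<at::Tensor> in the
// kernel signature.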
TEST(CustomOperatorTest, ListParameters) {
  // Check that lists work well.
  torch::RegisterOperators reg(
      "foo::lists(int[] ints, float[] floats, complex[] complexdoubles, Tensor[] tensors) -> float[]",
      [](torch::List<int64_t> ints,
         torch::List<double> floats,
         torch::List<c10::complex<double>> complexdoubles,
         torch::List<at::Tensor> tensors) { return floats; });

  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foo::lists");

  ASSERT_EQ(op->schema().arguments().size(), 4);
  ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
  ASSERT_TRUE(
      op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts()));
  ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
  ASSERT_TRUE(
      op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats()));
  ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles");
  ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf(
      *ListType::ofComplexDoubles()));
  ASSERT_EQ(op->schema().arguments()[3].name(), "tensors");
  ASSERT_TRUE(
      op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors()));

  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_TRUE(
      op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats()));

  Stack stack;
  push(stack, c10::List<int64_t>({1, 2}));
  push(stack, c10::List<double>({1.0, 2.0}));
  push(
      stack,
      c10::List<c10::complex<double>>(
          {c10::complex<double>(2.4, -5.5), c10::complex<double>(-1.3, 2)}));
  push(stack, c10::List<at::Tensor>({at::ones(5)}));
  op->getOperation()(stack);
  c10::List<double> output;
  pop(stack, output);

  ASSERT_EQ(output.size(), 2);
  ASSERT_EQ(output.get(0), 1.0);
  ASSERT_EQ(output.get(1), 2.0);
}

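// A Tensor[] argument can be returned as-is: the list round-trips through
// the interpreter stack unchanged.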
TEST(CustomOperatorTest, ListParameters2) {
  torch::RegisterOperators reg(
      "foo::lists2(Tensor[] tensors) -> Tensor[]",
      [](torch::List<at::Tensor> tensors) { return tensors; });

  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foo::lists2");

  ASSERT_EQ(op->schema().arguments().size(), 1);
  ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
  ASSERT_TRUE(
      op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors()));

  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_TRUE(
      op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors()));

  Stack stack;
  push(stack, c10::List<at::Tensor>({at::ones(5)}));
  op->getOperation()(stack);
  c10::List<at::Tensor> output;
  pop(stack, output);

  ASSERT_EQ(output.size(), 1);
  ASSERT_TRUE(output.get(0).allclose(at::ones(5)));
}

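// A custom op registered without alias annotations is treated conservatively
// by alias analysis: it is assumed to write to all of its inputs, and its
// outputs are wildcards that may alias anything, so passes like DCE must not
// remove it.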
TEST(CustomOperatorTest, Aliasing) {
  torch::RegisterOperators reg(
      "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
        a.add_(b);
        return a;
      });
  getAllOperatorsFor(Symbol::fromQualString("foo::aliasing"));

  {
    auto graph = std::make_shared<Graph>();
    parseIR(
        R"IR(
graph(%x: Tensor, %y: Tensor):
  %ret : Tensor = foo::aliasing(%x, %y)
  return (%ret)
  )IR",
        graph.get());

    auto opNode = *graph->block()->nodes().begin();

    AliasDb aliasDb(graph);
    for (const auto input : opNode->inputs()) {
      // The custom op writes to all its inputs
      ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input}));
      // The output should be a wildcard and thus alias all inputs
      ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input));
    }
  }
  {
    // DCE should not remove a custom op
    auto graph = std::make_shared<Graph>();
    const auto text = R"IR(
graph(%x: Tensor, %y: Tensor):
  # CHECK: foo::aliasing
  %ret : Tensor = foo::aliasing(%x, %y)
  return (%x)
  )IR";
    parseIR(text, graph.get());
    EliminateDeadCode(graph);

    testing::FileCheck().run(text, *graph);
  }
}

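// Selective build: OperatorGenerator only registers an operator whose name
// appears in op_list. TORCH_SELECTIVE_NAME_IN_SCHEMA performs the allowlist
// lookup at compile time and produces a SelectiveStr that enables or
// disables registration of the schema that follows.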
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static constexpr char op_list[] = "foofoo::bar.template;foo::another";
#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n)                                   \
  torch::detail::SelectiveStr<c10::impl::op_allowlist_contains_name_in_schema( \
      l, n)>(n)

TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
  // Try to register an op name that does not exist in op_list.
  // Expected: the op name is not registered.
  torch::jit::RegisterOperators reg({OperatorGenerator(
      TORCH_SELECTIVE_NAME_IN_SCHEMA(
          op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
      [](Stack& stack) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        double a;
        at::Tensor b;
        pop(stack, a, b);
        push(stack, a + b);
      },
      aliasAnalysisFromSchema())});

  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
  ASSERT_EQ(ops.size(), 0);
}

TEST(TestCustomOperator, OperatorGeneratorBasic) {
  // The operator should be registered successfully, since its name
  // ("foofoo::bar.template") is present in the op_list allowlist.
  torch::jit::RegisterOperators reg({OperatorGenerator(
      TORCH_SELECTIVE_NAME_IN_SCHEMA(
          op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
      [](Stack& stack) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        double a;
        at::Tensor b;
        pop(stack, a, b);
        push(stack, a + b);
      },
      aliasAnalysisFromSchema())});

  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
  ASSERT_EQ(ops.size(), 1);

  auto& op = ops.front();
  ASSERT_EQ(op->schema().name(), "foofoo::bar");

  ASSERT_EQ(op->schema().arguments().size(), 2);
  ASSERT_EQ(op->schema().arguments()[0].name(), "a");
  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
  ASSERT_EQ(op->schema().arguments()[1].name(), "b");
  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

  ASSERT_EQ(op->schema().returns().size(), 1);
  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

  Stack stack;
  push(stack, 2.0f, at::ones(5));
  op->getOperation()(stack);
  at::Tensor output;
  pop(stack, output);

  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
}

} // namespace jit
} // namespace torch