1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | |
5 | #include <torch/csrc/jit/ir/alias_analysis.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
8 | #include <torch/csrc/jit/runtime/custom_operator.h> |
9 | #include <torch/csrc/jit/runtime/register_ops_utils.h> |
10 | #include <torch/jit.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | TEST(CustomOperatorTest, InferredSchema) { |
16 | torch::RegisterOperators reg( |
17 | "foo::bar" , [](double a, at::Tensor b) { return a + b; }); |
18 | auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar" )); |
19 | ASSERT_EQ(ops.size(), 1); |
20 | |
21 | auto& op = ops.front(); |
22 | ASSERT_EQ(op->schema().name(), "foo::bar" ); |
23 | |
24 | ASSERT_EQ(op->schema().arguments().size(), 2); |
25 | ASSERT_EQ(op->schema().arguments()[0].name(), "_0" ); |
26 | ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); |
27 | ASSERT_EQ(op->schema().arguments()[1].name(), "_1" ); |
28 | ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); |
29 | |
30 | ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); |
31 | |
32 | Stack stack; |
33 | push(stack, 2.0f, at::ones(5)); |
34 | op->getOperation()(stack); |
35 | at::Tensor output; |
36 | pop(stack, output); |
37 | |
38 | ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); |
39 | } |
40 | |
41 | TEST(CustomOperatorTest, ExplicitSchema) { |
42 | torch::RegisterOperators reg( |
43 | "foo::bar_with_schema(float a, Tensor b) -> Tensor" , |
44 | [](double a, at::Tensor b) { return a + b; }); |
45 | |
46 | auto& ops = |
47 | getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema" )); |
48 | ASSERT_EQ(ops.size(), 1); |
49 | |
50 | auto& op = ops.front(); |
51 | ASSERT_EQ(op->schema().name(), "foo::bar_with_schema" ); |
52 | |
53 | ASSERT_EQ(op->schema().arguments().size(), 2); |
54 | ASSERT_EQ(op->schema().arguments()[0].name(), "a" ); |
55 | ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); |
56 | ASSERT_EQ(op->schema().arguments()[1].name(), "b" ); |
57 | ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); |
58 | |
59 | ASSERT_EQ(op->schema().returns().size(), 1); |
60 | ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); |
61 | |
62 | Stack stack; |
63 | push(stack, 2.0f, at::ones(5)); |
64 | op->getOperation()(stack); |
65 | at::Tensor output; |
66 | pop(stack, output); |
67 | |
68 | ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); |
69 | } |
70 | |
71 | TEST(CustomOperatorTest, ListParameters) { |
72 | // Check that lists work well. |
73 | torch::RegisterOperators reg( |
74 | "foo::lists(int[] ints, float[] floats, complex[] complexdoubles, Tensor[] tensors) -> float[]" , |
75 | [](torch::List<int64_t> ints, |
76 | torch::List<double> floats, |
77 | torch::List<c10::complex<double>> complexdoubles, |
78 | torch::List<at::Tensor> tensors) { return floats; }); |
79 | |
80 | auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists" )); |
81 | ASSERT_EQ(ops.size(), 1); |
82 | |
83 | auto& op = ops.front(); |
84 | ASSERT_EQ(op->schema().name(), "foo::lists" ); |
85 | |
86 | ASSERT_EQ(op->schema().arguments().size(), 4); |
87 | ASSERT_EQ(op->schema().arguments()[0].name(), "ints" ); |
88 | ASSERT_TRUE( |
89 | op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofInts())); |
90 | ASSERT_EQ(op->schema().arguments()[1].name(), "floats" ); |
91 | ASSERT_TRUE( |
92 | op->schema().arguments()[1].type()->isSubtypeOf(*ListType::ofFloats())); |
93 | ASSERT_EQ(op->schema().arguments()[2].name(), "complexdoubles" ); |
94 | ASSERT_TRUE(op->schema().arguments()[2].type()->isSubtypeOf( |
95 | *ListType::ofComplexDoubles())); |
96 | ASSERT_EQ(op->schema().arguments()[3].name(), "tensors" ); |
97 | ASSERT_TRUE( |
98 | op->schema().arguments()[3].type()->isSubtypeOf(*ListType::ofTensors())); |
99 | |
100 | ASSERT_EQ(op->schema().returns().size(), 1); |
101 | ASSERT_TRUE( |
102 | op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofFloats())); |
103 | |
104 | Stack stack; |
105 | push(stack, c10::List<int64_t>({1, 2})); |
106 | push(stack, c10::List<double>({1.0, 2.0})); |
107 | push( |
108 | stack, |
109 | c10::List<c10::complex<double>>( |
110 | {c10::complex<double>(2.4, -5.5), c10::complex<double>(-1.3, 2)})); |
111 | push(stack, c10::List<at::Tensor>({at::ones(5)})); |
112 | op->getOperation()(stack); |
113 | c10::List<double> output; |
114 | pop(stack, output); |
115 | |
116 | ASSERT_EQ(output.size(), 2); |
117 | ASSERT_EQ(output.get(0), 1.0); |
118 | ASSERT_EQ(output.get(1), 2.0); |
119 | } |
120 | |
121 | TEST(CustomOperatorTest, ListParameters2) { |
122 | torch::RegisterOperators reg( |
123 | "foo::lists2(Tensor[] tensors) -> Tensor[]" , |
124 | [](torch::List<at::Tensor> tensors) { return tensors; }); |
125 | |
126 | auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2" )); |
127 | ASSERT_EQ(ops.size(), 1); |
128 | |
129 | auto& op = ops.front(); |
130 | ASSERT_EQ(op->schema().name(), "foo::lists2" ); |
131 | |
132 | ASSERT_EQ(op->schema().arguments().size(), 1); |
133 | ASSERT_EQ(op->schema().arguments()[0].name(), "tensors" ); |
134 | ASSERT_TRUE( |
135 | op->schema().arguments()[0].type()->isSubtypeOf(*ListType::ofTensors())); |
136 | |
137 | ASSERT_EQ(op->schema().returns().size(), 1); |
138 | ASSERT_TRUE( |
139 | op->schema().returns()[0].type()->isSubtypeOf(*ListType::ofTensors())); |
140 | |
141 | Stack stack; |
142 | push(stack, c10::List<at::Tensor>({at::ones(5)})); |
143 | op->getOperation()(stack); |
144 | c10::List<at::Tensor> output; |
145 | pop(stack, output); |
146 | |
147 | ASSERT_EQ(output.size(), 1); |
148 | ASSERT_TRUE(output.get(0).allclose(at::ones(5))); |
149 | } |
150 | |
151 | TEST(CustomOperatorTest, Aliasing) { |
152 | torch::RegisterOperators reg( |
153 | "foo::aliasing" , [](at::Tensor a, at::Tensor b) -> at::Tensor { |
154 | a.add_(b); |
155 | return a; |
156 | }); |
157 | getAllOperatorsFor(Symbol::fromQualString("foo::aliasing" )); |
158 | |
159 | { |
160 | auto graph = std::make_shared<Graph>(); |
161 | parseIR( |
162 | R"IR( |
163 | graph(%x: Tensor, %y: Tensor): |
164 | %ret : Tensor = foo::aliasing(%x, %y) |
165 | return (%ret) |
166 | )IR" , |
167 | graph.get()); |
168 | |
169 | auto opNode = *graph->block()->nodes().begin(); |
170 | |
171 | AliasDb aliasDb(graph); |
172 | for (const auto input : opNode->inputs()) { |
173 | // The custom op writes to all its inputs |
174 | ASSERT_TRUE(aliasDb.writesToAlias(opNode, {input})); |
175 | // The output should be a wildcard and thus alias all inputs |
176 | ASSERT_TRUE(aliasDb.mayAlias(opNode->output(), input)); |
177 | } |
178 | } |
179 | { |
180 | // DCE should not remove a custom op |
181 | auto graph = std::make_shared<Graph>(); |
182 | const auto text = R"IR( |
183 | graph(%x: Tensor, %y: Tensor): |
184 | # CHECK: foo::aliasing |
185 | %ret : Tensor = foo::aliasing(%x, %y) |
186 | return (%x) |
187 | )IR" ; |
188 | parseIR(text, graph.get()); |
189 | EliminateDeadCode(graph); |
190 | |
191 | testing::FileCheck().run(text, *graph); |
192 | } |
193 | } |
194 | |
// Operator allowlist used by the selective-registration tests below:
// semicolon-separated operator names; "foofoo::bar.template" carries an
// overload name.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static constexpr char op_list[] = "foofoo::bar.template;foo::another";

// Wraps schema string `n` in a torch::detail::SelectiveStr whose bool
// template argument is computed at compile time by checking whether the
// operator name in schema `n` is listed in allowlist `l`.
#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n)                   \
  torch::detail::SelectiveStr<                                 \
      c10::impl::op_allowlist_contains_name_in_schema((l), (n))>(n)
200 | |
201 | TEST(TestCustomOperator, OperatorGeneratorUndeclared) { |
202 | // Try to register an op name that does not exist in op_list. |
203 | // Expected: the op name is not registered. |
204 | torch::jit::RegisterOperators reg({OperatorGenerator( |
205 | TORCH_SELECTIVE_NAME_IN_SCHEMA( |
206 | op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor" ), |
207 | [](Stack& stack) { |
208 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
209 | double a; |
210 | at::Tensor b; |
211 | pop(stack, a, b); |
212 | push(stack, a + b); |
213 | }, |
214 | aliasAnalysisFromSchema())}); |
215 | |
216 | auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist" )); |
217 | ASSERT_EQ(ops.size(), 0); |
218 | } |
219 | |
220 | TEST(TestCustomOperator, OperatorGeneratorBasic) { |
221 | // The operator should be successfully registered since its name is in the |
222 | // whitelist. |
223 | torch::jit::RegisterOperators reg({OperatorGenerator( |
224 | TORCH_SELECTIVE_NAME_IN_SCHEMA( |
225 | op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor" ), |
226 | [](Stack& stack) { |
227 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
228 | double a; |
229 | at::Tensor b; |
230 | pop(stack, a, b); |
231 | push(stack, a + b); |
232 | }, |
233 | aliasAnalysisFromSchema())}); |
234 | |
235 | auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar" )); |
236 | ASSERT_EQ(ops.size(), 1); |
237 | |
238 | auto& op = ops.front(); |
239 | ASSERT_EQ(op->schema().name(), "foofoo::bar" ); |
240 | |
241 | ASSERT_EQ(op->schema().arguments().size(), 2); |
242 | ASSERT_EQ(op->schema().arguments()[0].name(), "a" ); |
243 | ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); |
244 | ASSERT_EQ(op->schema().arguments()[1].name(), "b" ); |
245 | ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); |
246 | |
247 | ASSERT_EQ(op->schema().returns().size(), 1); |
248 | ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); |
249 | |
250 | Stack stack; |
251 | push(stack, 2.0f, at::ones(5)); |
252 | op->getOperation()(stack); |
253 | at::Tensor output; |
254 | pop(stack, output); |
255 | |
256 | ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); |
257 | } |
258 | |
259 | } // namespace jit |
260 | } // namespace torch |
261 | |