#include <test/cpp/jit/test_custom_class_registrations.h>

#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
9 | |
10 | using namespace torch::jit; |
11 | |
12 | namespace { |
13 | |
14 | struct DefaultArgs : torch::CustomClassHolder { |
15 | int x; |
16 | DefaultArgs(int64_t start = 3) : x(start) {} |
17 | int64_t increment(int64_t val = 1) { |
18 | x += val; |
19 | return x; |
20 | } |
21 | int64_t decrement(int64_t val = 1) { |
22 | x += val; |
23 | return x; |
24 | } |
25 | int64_t scale_add(int64_t add, int64_t scale = 1) { |
26 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
27 | x = scale * x + add; |
28 | return x; |
29 | } |
30 | int64_t divide(c10::optional<int64_t> factor) { |
31 | if (factor) { |
32 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
33 | x = x / *factor; |
34 | } |
35 | return x; |
36 | } |
37 | }; |
38 | |
39 | struct Foo : torch::CustomClassHolder { |
40 | int x, y; |
41 | Foo() : x(0), y(0) {} |
42 | Foo(int x_, int y_) : x(x_), y(y_) {} |
43 | int64_t info() { |
44 | return this->x * this->y; |
45 | } |
46 | int64_t add(int64_t z) { |
47 | return (x + y) * z; |
48 | } |
49 | void increment(int64_t z) { |
50 | this->x += z; |
51 | this->y += z; |
52 | } |
53 | int64_t combine(c10::intrusive_ptr<Foo> b) { |
54 | return this->info() + b->info(); |
55 | } |
56 | }; |
57 | |
58 | struct _StaticMethod : torch::CustomClassHolder { |
59 | // NOLINTNEXTLINE(modernize-use-equals-default) |
60 | _StaticMethod() {} |
61 | static int64_t staticMethod(int64_t input) { |
62 | return 2 * input; |
63 | } |
64 | }; |
65 | |
66 | struct FooGetterSetter : torch::CustomClassHolder { |
67 | FooGetterSetter() : x(0), y(0) {} |
68 | FooGetterSetter(int64_t x_, int64_t y_) : x(x_), y(y_) {} |
69 | |
70 | int64_t getX() { |
71 | // to make sure this is not just attribute lookup |
72 | return x + 2; |
73 | } |
74 | void setX(int64_t z) { |
75 | // to make sure this is not just attribute lookup |
76 | x = z + 2; |
77 | } |
78 | |
79 | int64_t getY() { |
80 | // to make sure this is not just attribute lookup |
81 | return y + 4; |
82 | } |
83 | |
84 | private: |
85 | int64_t x, y; |
86 | }; |
87 | |
88 | struct FooGetterSetterLambda : torch::CustomClassHolder { |
89 | int64_t x; |
90 | FooGetterSetterLambda() : x(0) {} |
91 | FooGetterSetterLambda(int64_t x_) : x(x_) {} |
92 | }; |
93 | |
94 | struct FooReadWrite : torch::CustomClassHolder { |
95 | int64_t x; |
96 | const int64_t y; |
97 | FooReadWrite() : x(0), y(0) {} |
98 | FooReadWrite(int64_t x_, int64_t y_) : x(x_), y(y_) {} |
99 | }; |
100 | |
101 | struct LambdaInit : torch::CustomClassHolder { |
102 | int x, y; |
103 | LambdaInit(int x_, int y_) : x(x_), y(y_) {} |
104 | int64_t diff() { |
105 | return this->x - this->y; |
106 | } |
107 | }; |
108 | |
109 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
110 | struct NoInit : torch::CustomClassHolder { |
111 | int64_t x; |
112 | }; |
113 | |
114 | struct PickleTester : torch::CustomClassHolder { |
115 | PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {} |
116 | std::vector<int64_t> vals; |
117 | }; |
118 | |
119 | at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) { |
120 | return torch::zeros({instance->vals.back(), 4}); |
121 | } |
122 | |
123 | struct ElementwiseInterpreter : torch::CustomClassHolder { |
124 | using InstructionType = std::tuple< |
125 | std::string /*op*/, |
126 | std::vector<std::string> /*inputs*/, |
127 | std::string /*output*/>; |
128 | |
129 | // NOLINTNEXTLINE(modernize-use-equals-default) |
130 | ElementwiseInterpreter() {} |
131 | |
132 | // Load a list of instructions into the interpreter. As specified above, |
133 | // instructions specify the operation (currently support "add" and "mul"), |
134 | // the names of the input values, and the name of the single output value |
135 | // from this instruction |
136 | void setInstructions(std::vector<InstructionType> instructions) { |
137 | instructions_ = std::move(instructions); |
138 | } |
139 | |
140 | // Add a constant. The interpreter maintains a set of constants across |
141 | // calls. They are keyed by name, and constants can be referenced in |
142 | // Instructions by the name specified |
143 | void addConstant(const std::string& name, at::Tensor value) { |
144 | constants_.insert_or_assign(name, std::move(value)); |
145 | } |
146 | |
147 | // Set the string names for the positional inputs to the function this |
148 | // interpreter represents. When invoked, the interpreter will assign |
149 | // the positional inputs to the names in the corresponding position in |
150 | // input_names. |
151 | void setInputNames(std::vector<std::string> input_names) { |
152 | input_names_ = std::move(input_names); |
153 | } |
154 | |
155 | // Specify the output name for the function this interpreter represents. This |
156 | // should match the "output" field of one of the instructions in the |
157 | // instruction list, typically the last instruction. |
158 | void setOutputName(std::string output_name) { |
159 | output_name_ = std::move(output_name); |
160 | } |
161 | |
162 | // Invoke this interpreter. This takes a list of positional inputs and returns |
163 | // a single output. Currently, inputs and outputs must all be Tensors. |
164 | at::Tensor __call__(std::vector<at::Tensor> inputs) { |
165 | // Environment to hold local variables |
166 | std::unordered_map<std::string, at::Tensor> environment; |
167 | |
168 | // Load inputs according to the specified names |
169 | if (inputs.size() != input_names_.size()) { |
170 | std::stringstream err; |
171 | err << "Expected " << input_names_.size() << " inputs, but got " |
172 | << inputs.size() << "!" ; |
173 | throw std::runtime_error(err.str()); |
174 | } |
175 | for (size_t i = 0; i < inputs.size(); ++i) { |
176 | environment[input_names_[i]] = inputs[i]; |
177 | } |
178 | |
179 | for (InstructionType& instr : instructions_) { |
180 | // Retrieve all input values for this op |
181 | std::vector<at::Tensor> inputs; |
182 | for (const auto& input_name : std::get<1>(instr)) { |
183 | // Operator output values shadow constants. |
184 | // Imagine all constants are defined in statements at the beginning |
185 | // of a function (a la K&R C). Any definition of an output value must |
186 | // necessarily come after constant definition in textual order. Thus, |
187 | // We look up values in the environment first then the constant table |
188 | // second to implement this shadowing behavior |
189 | if (environment.find(input_name) != environment.end()) { |
190 | inputs.push_back(environment.at(input_name)); |
191 | } else if (constants_.find(input_name) != constants_.end()) { |
192 | inputs.push_back(constants_.at(input_name)); |
193 | } else { |
194 | std::stringstream err; |
195 | err << "Instruction referenced unknown value " << input_name << "!" ; |
196 | throw std::runtime_error(err.str()); |
197 | } |
198 | } |
199 | |
200 | // Run the specified operation |
201 | at::Tensor result; |
202 | const auto& op = std::get<0>(instr); |
203 | if (op == "add" ) { |
204 | if (inputs.size() != 2) { |
205 | throw std::runtime_error("Unexpected number of inputs for add op!" ); |
206 | } |
207 | result = inputs[0] + inputs[1]; |
208 | } else if (op == "mul" ) { |
209 | if (inputs.size() != 2) { |
210 | throw std::runtime_error("Unexpected number of inputs for mul op!" ); |
211 | } |
212 | result = inputs[0] * inputs[1]; |
213 | } else { |
214 | std::stringstream err; |
215 | err << "Unknown operator " << op << "!" ; |
216 | throw std::runtime_error(err.str()); |
217 | } |
218 | |
219 | // Write back result into environment |
220 | const auto& output_name = std::get<2>(instr); |
221 | environment[output_name] = std::move(result); |
222 | } |
223 | |
224 | if (!output_name_) { |
225 | throw std::runtime_error("Output name not specified!" ); |
226 | } |
227 | |
228 | return environment.at(*output_name_); |
229 | } |
230 | |
231 | // Ser/De infrastructure. See |
232 | // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes |
233 | // for more info. |
234 | |
235 | // This is the type we will use to marshall information on disk during |
236 | // ser/de. It is a simple tuple composed of primitive types and simple |
237 | // collection types like vector, optional, and dict. |
238 | using SerializationType = std::tuple< |
239 | std::vector<std::string> /*input_names_*/, |
240 | c10::optional<std::string> /*output_name_*/, |
241 | c10::Dict<std::string, at::Tensor> /*constants_*/, |
242 | std::vector<InstructionType> /*instructions_*/ |
243 | >; |
244 | |
245 | // This function yields the SerializationType instance for `this`. |
246 | SerializationType __getstate__() const { |
247 | return SerializationType{ |
248 | input_names_, output_name_, constants_, instructions_}; |
249 | } |
250 | |
251 | // This function will create an instance of `ElementwiseInterpreter` given |
252 | // an instance of `SerializationType`. |
253 | static c10::intrusive_ptr<ElementwiseInterpreter> __setstate__( |
254 | SerializationType state) { |
255 | auto instance = c10::make_intrusive<ElementwiseInterpreter>(); |
256 | std::tie( |
257 | instance->input_names_, |
258 | instance->output_name_, |
259 | instance->constants_, |
260 | instance->instructions_) = std::move(state); |
261 | return instance; |
262 | } |
263 | |
264 | // Class members |
265 | std::vector<std::string> input_names_; |
266 | c10::optional<std::string> output_name_; |
267 | c10::Dict<std::string, at::Tensor> constants_; |
268 | std::vector<InstructionType> instructions_; |
269 | }; |
270 | |
271 | struct ReLUClass : public torch::CustomClassHolder { |
272 | at::Tensor run(const at::Tensor& t) { |
273 | return t.relu(); |
274 | } |
275 | }; |
276 | |
277 | TORCH_LIBRARY(_TorchScriptTesting, m) { |
278 | m.class_<ScalarTypeClass>("_ScalarTypeClass" ) |
279 | .def(torch::init<at::ScalarType>()) |
280 | .def_pickle( |
281 | [](const c10::intrusive_ptr<ScalarTypeClass>& self) { |
282 | return std::make_tuple(self->scalar_type_); |
283 | }, |
284 | [](std::tuple<at::ScalarType> s) { |
285 | return c10::make_intrusive<ScalarTypeClass>(std::get<0>(s)); |
286 | }); |
287 | |
288 | m.class_<ReLUClass>("_ReLUClass" ) |
289 | .def(torch::init<>()) |
290 | .def("run" , &ReLUClass::run); |
291 | |
292 | m.class_<_StaticMethod>("_StaticMethod" ) |
293 | .def(torch::init<>()) |
294 | .def_static("staticMethod" , &_StaticMethod::staticMethod); |
295 | |
296 | m.class_<DefaultArgs>("_DefaultArgs" ) |
297 | .def(torch::init<int64_t>(), "" , {torch::arg("start" ) = 3}) |
298 | .def("increment" , &DefaultArgs::increment, "" , {torch::arg("val" ) = 1}) |
299 | .def("decrement" , &DefaultArgs::decrement, "" , {torch::arg("val" ) = 1}) |
300 | .def( |
301 | "scale_add" , |
302 | &DefaultArgs::scale_add, |
303 | "" , |
304 | {torch::arg("add" ), torch::arg("scale" ) = 1}) |
305 | .def( |
306 | "divide" , |
307 | &DefaultArgs::divide, |
308 | "" , |
309 | {torch::arg("factor" ) = torch::arg::none()}); |
310 | |
311 | m.class_<Foo>("_Foo" ) |
312 | .def(torch::init<int64_t, int64_t>()) |
313 | // .def(torch::init<>()) |
314 | .def("info" , &Foo::info) |
315 | .def("increment" , &Foo::increment) |
316 | .def("add" , &Foo::add) |
317 | .def("combine" , &Foo::combine); |
318 | |
319 | m.class_<FooGetterSetter>("_FooGetterSetter" ) |
320 | .def(torch::init<int64_t, int64_t>()) |
321 | .def_property("x" , &FooGetterSetter::getX, &FooGetterSetter::setX) |
322 | .def_property("y" , &FooGetterSetter::getY); |
323 | |
324 | m.class_<FooGetterSetterLambda>("_FooGetterSetterLambda" ) |
325 | .def(torch::init<int64_t>()) |
326 | .def_property( |
327 | "x" , |
328 | [](const c10::intrusive_ptr<FooGetterSetterLambda>& self) { |
329 | return self->x; |
330 | }, |
331 | [](const c10::intrusive_ptr<FooGetterSetterLambda>& self, |
332 | int64_t val) { self->x = val; }); |
333 | |
334 | m.class_<FooReadWrite>("_FooReadWrite" ) |
335 | .def(torch::init<int64_t, int64_t>()) |
336 | .def_readwrite("x" , &FooReadWrite::x) |
337 | .def_readonly("y" , &FooReadWrite::y); |
338 | |
339 | m.class_<LambdaInit>("_LambdaInit" ) |
340 | .def(torch::init([](int64_t x, int64_t y, bool swap) { |
341 | if (swap) { |
342 | return c10::make_intrusive<LambdaInit>(y, x); |
343 | } else { |
344 | return c10::make_intrusive<LambdaInit>(x, y); |
345 | } |
346 | })) |
347 | .def("diff" , &LambdaInit::diff); |
348 | |
349 | m.class_<NoInit>("_NoInit" ).def( |
350 | "get_x" , [](const c10::intrusive_ptr<NoInit>& self) { return self->x; }); |
351 | |
352 | m.class_<MyStackClass<std::string>>("_StackString" ) |
353 | .def(torch::init<std::vector<std::string>>()) |
354 | .def("push" , &MyStackClass<std::string>::push) |
355 | .def("pop" , &MyStackClass<std::string>::pop) |
356 | .def("clone" , &MyStackClass<std::string>::clone) |
357 | .def("merge" , &MyStackClass<std::string>::merge) |
358 | .def_pickle( |
359 | [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) { |
360 | return self->stack_; |
361 | }, |
362 | [](std::vector<std::string> state) { // __setstate__ |
363 | return c10::make_intrusive<MyStackClass<std::string>>( |
364 | std::vector<std::string>{"i" , "was" , "deserialized" }); |
365 | }) |
366 | .def("return_a_tuple" , &MyStackClass<std::string>::return_a_tuple) |
367 | .def( |
368 | "top" , |
369 | [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) |
370 | -> std::string { return self->stack_.back(); }) |
371 | .def( |
372 | "__str__" , |
373 | [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) { |
374 | std::stringstream ss; |
375 | ss << "[" ; |
376 | for (size_t i = 0; i < self->stack_.size(); ++i) { |
377 | ss << self->stack_[i]; |
378 | if (i != self->stack_.size() - 1) { |
379 | ss << ", " ; |
380 | } |
381 | } |
382 | ss << "]" ; |
383 | return ss.str(); |
384 | }); |
385 | // clang-format off |
386 | // The following will fail with a static assert telling you you have to |
387 | // take an intrusive_ptr<MyStackClass> as the first argument. |
388 | // .def("foo", [](int64_t a) -> int64_t{ return 3;}); |
389 | // clang-format on |
390 | |
391 | m.class_<PickleTester>("_PickleTester" ) |
392 | .def(torch::init<std::vector<int64_t>>()) |
393 | .def_pickle( |
394 | [](c10::intrusive_ptr<PickleTester> self) { // __getstate__ |
395 | return std::vector<int64_t>{1, 3, 3, 7}; |
396 | }, |
397 | [](std::vector<int64_t> state) { // __setstate__ |
398 | return c10::make_intrusive<PickleTester>(std::move(state)); |
399 | }) |
400 | .def( |
401 | "top" , |
402 | [](const c10::intrusive_ptr<PickleTester>& self) { |
403 | return self->vals.back(); |
404 | }) |
405 | .def("pop" , [](const c10::intrusive_ptr<PickleTester>& self) { |
406 | auto val = self->vals.back(); |
407 | self->vals.pop_back(); |
408 | return val; |
409 | }); |
410 | |
411 | m.def( |
412 | "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y" , |
413 | take_an_instance); |
414 | // test that schema inference is ok too |
415 | m.def("take_an_instance_inferred" , take_an_instance); |
416 | |
417 | m.class_<ElementwiseInterpreter>("_ElementwiseInterpreter" ) |
418 | .def(torch::init<>()) |
419 | .def("set_instructions" , &ElementwiseInterpreter::setInstructions) |
420 | .def("add_constant" , &ElementwiseInterpreter::addConstant) |
421 | .def("set_input_names" , &ElementwiseInterpreter::setInputNames) |
422 | .def("set_output_name" , &ElementwiseInterpreter::setOutputName) |
423 | .def("__call__" , &ElementwiseInterpreter::__call__) |
424 | .def_pickle( |
425 | /* __getstate__ */ |
426 | [](const c10::intrusive_ptr<ElementwiseInterpreter>& self) { |
427 | return self->__getstate__(); |
428 | }, |
429 | /* __setstate__ */ |
430 | [](ElementwiseInterpreter::SerializationType state) { |
431 | return ElementwiseInterpreter::__setstate__(std::move(state)); |
432 | }); |
433 | } |
434 | |
435 | } // namespace |
436 | |