#include <test/cpp/jit/test_custom_class_registrations.h>

#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace torch::jit;

namespace {

struct DefaultArgs : torch::CustomClassHolder {
  int x;
  DefaultArgs(int64_t start = 3) : x(start) {}
  int64_t increment(int64_t val = 1) {
    x += val;
    return x;
  }
  int64_t decrement(int64_t val = 1) {
    x -= val;
    return x;
  }
  int64_t scale_add(int64_t add, int64_t scale = 1) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    x = scale * x + add;
    return x;
  }
  int64_t divide(c10::optional<int64_t> factor) {
    if (factor) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      x = x / *factor;
    }
    return x;
  }
};

struct Foo : torch::CustomClassHolder {
  int x, y;
  Foo() : x(0), y(0) {}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  int64_t info() {
    return this->x * this->y;
  }
  int64_t add(int64_t z) {
    return (x + y) * z;
  }
  void increment(int64_t z) {
    this->x += z;
    this->y += z;
  }
  int64_t combine(c10::intrusive_ptr<Foo> b) {
    return this->info() + b->info();
  }
};

struct _StaticMethod : torch::CustomClassHolder {
  // NOLINTNEXTLINE(modernize-use-equals-default)
  _StaticMethod() {}
  static int64_t staticMethod(int64_t input) {
    return 2 * input;
  }
};

struct FooGetterSetter : torch::CustomClassHolder {
  FooGetterSetter() : x(0), y(0) {}
  FooGetterSetter(int64_t x_, int64_t y_) : x(x_), y(y_) {}

  int64_t getX() {
    // to make sure this is not just attribute lookup
    return x + 2;
  }
  void setX(int64_t z) {
    // to make sure this is not just attribute lookup
    x = z + 2;
  }

  int64_t getY() {
    // to make sure this is not just attribute lookup
    return y + 4;
  }

 private:
  int64_t x, y;
};

struct FooGetterSetterLambda : torch::CustomClassHolder {
  int64_t x;
  FooGetterSetterLambda() : x(0) {}
  FooGetterSetterLambda(int64_t x_) : x(x_) {}
};

struct FooReadWrite : torch::CustomClassHolder {
  int64_t x;
  const int64_t y;
  FooReadWrite() : x(0), y(0) {}
  FooReadWrite(int64_t x_, int64_t y_) : x(x_), y(y_) {}
};

struct LambdaInit : torch::CustomClassHolder {
  int x, y;
  LambdaInit(int x_, int y_) : x(x_), y(y_) {}
  int64_t diff() {
    return this->x - this->y;
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct NoInit : torch::CustomClassHolder {
  int64_t x;
};

struct PickleTester : torch::CustomClassHolder {
  PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
  std::vector<int64_t> vals;
};

at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
  return torch::zeros({instance->vals.back(), 4});
}

struct ElementwiseInterpreter : torch::CustomClassHolder {
  using InstructionType = std::tuple<
      std::string /*op*/,
      std::vector<std::string> /*inputs*/,
      std::string /*output*/>;

  // NOLINTNEXTLINE(modernize-use-equals-default)
  ElementwiseInterpreter() {}

  // Load a list of instructions into the interpreter. As specified by
  // InstructionType above, each instruction names the operation (currently
  // "add" and "mul" are supported), the input values it consumes, and the
  // single output value it produces.
  void setInstructions(std::vector<InstructionType> instructions) {
    instructions_ = std::move(instructions);
  }
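  // For illustration (not used anywhere in this file), an instruction that
  // computes out = a + b would be written as:
  //   InstructionType{"add", {"a", "b"}, "out"}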

  // Add a constant. The interpreter maintains a set of constants across
  // calls. Constants are keyed by name and can be referenced from
  // instructions by that name.
  void addConstant(const std::string& name, at::Tensor value) {
    constants_.insert_or_assign(name, std::move(value));
  }
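  // For illustration, addConstant("c", torch::ones({2, 2})) would let later
  // instructions reference the value "c" by name.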

  // Set the string names for the positional inputs to the function this
  // interpreter represents. When invoked, the interpreter will assign
  // the positional inputs to the names in the corresponding position in
  // input_names.
  void setInputNames(std::vector<std::string> input_names) {
    input_names_ = std::move(input_names);
  }

  // Specify the output name for the function this interpreter represents. This
  // should match the "output" field of one of the instructions in the
  // instruction list, typically the last instruction.
  void setOutputName(std::string output_name) {
    output_name_ = std::move(output_name);
  }

  // Invoke this interpreter. This takes a list of positional inputs and
  // returns a single output. Currently, inputs and outputs must all be
  // Tensors.
  at::Tensor __call__(std::vector<at::Tensor> inputs) {
    // Environment to hold local variables
    std::unordered_map<std::string, at::Tensor> environment;

    // Load inputs according to the specified names
    if (inputs.size() != input_names_.size()) {
      std::stringstream err;
      err << "Expected " << input_names_.size() << " inputs, but got "
          << inputs.size() << "!";
      throw std::runtime_error(err.str());
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
      environment[input_names_[i]] = inputs[i];
    }

    for (InstructionType& instr : instructions_) {
      // Retrieve all input values for this op. A distinct name is used here
      // so the `inputs` parameter above is not shadowed.
      std::vector<at::Tensor> instr_inputs;
      for (const auto& input_name : std::get<1>(instr)) {
        // Operator output values shadow constants.
        // Imagine all constants are defined in statements at the beginning
        // of a function (a la K&R C). Any definition of an output value must
        // necessarily come after constant definition in textual order. Thus,
        // we look up values in the environment first and the constant table
        // second to implement this shadowing behavior.
        if (environment.find(input_name) != environment.end()) {
          instr_inputs.push_back(environment.at(input_name));
        } else if (constants_.find(input_name) != constants_.end()) {
          instr_inputs.push_back(constants_.at(input_name));
        } else {
          std::stringstream err;
          err << "Instruction referenced unknown value " << input_name << "!";
          throw std::runtime_error(err.str());
        }
      }

      // Run the specified operation
      at::Tensor result;
      const auto& op = std::get<0>(instr);
      if (op == "add") {
        if (instr_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for add op!");
        }
        result = instr_inputs[0] + instr_inputs[1];
      } else if (op == "mul") {
        if (instr_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for mul op!");
        }
        result = instr_inputs[0] * instr_inputs[1];
      } else {
        std::stringstream err;
        err << "Unknown operator " << op << "!";
        throw std::runtime_error(err.str());
      }

      // Write back result into environment
      const auto& output_name = std::get<2>(instr);
      environment[output_name] = std::move(result);
    }

    if (!output_name_) {
      throw std::runtime_error("Output name not specified!");
    }

    return environment.at(*output_name_);
  }

  // Ser/De infrastructure. See
  // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes
  // for more info.

  // This is the type we will use to marshall information on disk during
  // ser/de. It is a simple tuple composed of primitive types and simple
  // collection types like vector, optional, and dict.
  using SerializationType = std::tuple<
      std::vector<std::string> /*input_names_*/,
      c10::optional<std::string> /*output_name_*/,
      c10::Dict<std::string, at::Tensor> /*constants_*/,
      std::vector<InstructionType> /*instructions_*/
      >;

  // This function yields the SerializationType instance for `this`.
  SerializationType __getstate__() const {
    return SerializationType{
        input_names_, output_name_, constants_, instructions_};
  }

  // This function will create an instance of `ElementwiseInterpreter` given
  // an instance of `SerializationType`.
  static c10::intrusive_ptr<ElementwiseInterpreter> __setstate__(
      SerializationType state) {
    auto instance = c10::make_intrusive<ElementwiseInterpreter>();
    std::tie(
        instance->input_names_,
        instance->output_name_,
        instance->constants_,
        instance->instructions_) = std::move(state);
    return instance;
  }

  // Class members
  std::vector<std::string> input_names_;
  c10::optional<std::string> output_name_;
  c10::Dict<std::string, at::Tensor> constants_;
  std::vector<InstructionType> instructions_;
};

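// A minimal usage sketch of the interpreter above, driven directly from C++.
// The program, value names, and shapes are illustrative only; this helper is
// not registered with TorchScript and is not called anywhere in this file.
[[maybe_unused]] at::Tensor elementwiseInterpreterExample() {
  auto interp = c10::make_intrusive<ElementwiseInterpreter>();
  // Compute out = (a + b) * c, where "c" is a constant held by the
  // interpreter and "a"/"b" are positional inputs.
  interp->setInputNames({"a", "b"});
  interp->addConstant("c", torch::ones({2, 2}));
  std::vector<ElementwiseInterpreter::InstructionType> program;
  program.emplace_back("add", std::vector<std::string>{"a", "b"}, "t0");
  program.emplace_back("mul", std::vector<std::string>{"t0", "c"}, "out");
  interp->setInstructions(std::move(program));
  interp->setOutputName("out");
  return interp->__call__({torch::randn({2, 2}), torch::randn({2, 2})});
}
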
struct ReLUClass : public torch::CustomClassHolder {
  at::Tensor run(const at::Tensor& t) {
    return t.relu();
  }
};

TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.class_<ScalarTypeClass>("_ScalarTypeClass")
      .def(torch::init<at::ScalarType>())
      .def_pickle(
          [](const c10::intrusive_ptr<ScalarTypeClass>& self) {
            return std::make_tuple(self->scalar_type_);
          },
          [](std::tuple<at::ScalarType> s) {
            return c10::make_intrusive<ScalarTypeClass>(std::get<0>(s));
          });

  m.class_<ReLUClass>("_ReLUClass")
      .def(torch::init<>())
      .def("run", &ReLUClass::run);

  m.class_<_StaticMethod>("_StaticMethod")
      .def(torch::init<>())
      .def_static("staticMethod", &_StaticMethod::staticMethod);

  m.class_<DefaultArgs>("_DefaultArgs")
      .def(torch::init<int64_t>(), "", {torch::arg("start") = 3})
      .def("increment", &DefaultArgs::increment, "", {torch::arg("val") = 1})
      .def("decrement", &DefaultArgs::decrement, "", {torch::arg("val") = 1})
      .def(
          "scale_add",
          &DefaultArgs::scale_add,
          "",
          {torch::arg("add"), torch::arg("scale") = 1})
      .def(
          "divide",
          &DefaultArgs::divide,
          "",
          {torch::arg("factor") = torch::arg::none()});
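  // Illustrative note: with the defaults above, a Python caller can write
  // torch.classes._TorchScriptTesting._DefaultArgs() to start at 3 and
  // obj.increment() to add 1, without passing arguments explicitly.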

  m.class_<Foo>("_Foo")
      .def(torch::init<int64_t, int64_t>())
      // .def(torch::init<>())
      .def("info", &Foo::info)
      .def("increment", &Foo::increment)
      .def("add", &Foo::add)
      .def("combine", &Foo::combine);

  m.class_<FooGetterSetter>("_FooGetterSetter")
      .def(torch::init<int64_t, int64_t>())
      .def_property("x", &FooGetterSetter::getX, &FooGetterSetter::setX)
      .def_property("y", &FooGetterSetter::getY);
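  // Note: "y" is registered with a getter only, so it is exposed as a
  // read-only property.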

  m.class_<FooGetterSetterLambda>("_FooGetterSetterLambda")
      .def(torch::init<int64_t>())
      .def_property(
          "x",
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self) {
            return self->x;
          },
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self,
             int64_t val) { self->x = val; });

  m.class_<FooReadWrite>("_FooReadWrite")
      .def(torch::init<int64_t, int64_t>())
      .def_readwrite("x", &FooReadWrite::x)
      .def_readonly("y", &FooReadWrite::y);

  m.class_<LambdaInit>("_LambdaInit")
      .def(torch::init([](int64_t x, int64_t y, bool swap) {
        if (swap) {
          return c10::make_intrusive<LambdaInit>(y, x);
        } else {
          return c10::make_intrusive<LambdaInit>(x, y);
        }
      }))
      .def("diff", &LambdaInit::diff);

  m.class_<NoInit>("_NoInit").def(
      "get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

  m.class_<MyStackClass<std::string>>("_StackString")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("clone", &MyStackClass<std::string>::clone)
      .def("merge", &MyStackClass<std::string>::merge)
      .def_pickle(
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            // __getstate__
            return self->stack_;
          },
          [](std::vector<std::string> state) { // __setstate__
            // Note: the serialized `state` is ignored here; a fixed stack is
            // returned so it is easy to tell that this __setstate__ ran.
            return c10::make_intrusive<MyStackClass<std::string>>(
                std::vector<std::string>{"i", "was", "deserialized"});
          })
      .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
      .def(
          "top",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
              -> std::string { return self->stack_.back(); })
      .def(
          "__str__",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            std::stringstream ss;
            ss << "[";
            for (size_t i = 0; i < self->stack_.size(); ++i) {
              ss << self->stack_[i];
              if (i != self->stack_.size() - 1) {
                ss << ", ";
              }
            }
            ss << "]";
            return ss.str();
          });
  // clang-format off
  // The following will fail with a static assert telling you that you have
  // to take an intrusive_ptr<MyStackClass> as the first argument.
  // .def("foo", [](int64_t a) -> int64_t{ return 3;});
  // clang-format on

  m.class_<PickleTester>("_PickleTester")
      .def(torch::init<std::vector<int64_t>>())
      .def_pickle(
          [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
            return std::vector<int64_t>{1, 3, 3, 7};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<PickleTester>(std::move(state));
          })
      .def(
          "top",
          [](const c10::intrusive_ptr<PickleTester>& self) {
            return self->vals.back();
          })
      .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
        auto val = self->vals.back();
        self->vals.pop_back();
        return val;
      });

  m.def(
      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
      take_an_instance);
  // test that schema inference works too
  m.def("take_an_instance_inferred", take_an_instance);

  m.class_<ElementwiseInterpreter>("_ElementwiseInterpreter")
      .def(torch::init<>())
      .def("set_instructions", &ElementwiseInterpreter::setInstructions)
      .def("add_constant", &ElementwiseInterpreter::addConstant)
      .def("set_input_names", &ElementwiseInterpreter::setInputNames)
      .def("set_output_name", &ElementwiseInterpreter::setOutputName)
      .def("__call__", &ElementwiseInterpreter::__call__)
      .def_pickle(
          /* __getstate__ */
          [](const c10::intrusive_ptr<ElementwiseInterpreter>& self) {
            return self->__getstate__();
          },
          /* __setstate__ */
          [](ElementwiseInterpreter::SerializationType state) {
            return ElementwiseInterpreter::__setstate__(std::move(state));
          });
}

} // namespace