1 | #pragma once |
2 | |
3 | #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> |
4 | #include <torch/csrc/lazy/core/ir.h> |
5 | #include <torch/csrc/lazy/core/ir_builder.h> |
6 | #include <torch/csrc/lazy/core/shape_inference.h> |
7 | #include <torch/csrc/lazy/generated/LazyNonNativeIr.h> |
8 | #include <torch/csrc/lazy/ts_backend/dynamic_ir.h> |
9 | #include <torch/csrc/lazy/ts_backend/ops/device_data.h> |
10 | #include <torch/csrc/lazy/ts_backend/ops/generic.h> |
11 | #include <torch/csrc/lazy/ts_backend/ts_node.h> |
12 | |
13 | namespace torch { |
14 | namespace lazy { |
15 | |
16 | struct TorchScriptIrBuilder : IrBuilder { |
17 | NodePtr MakeDeviceData( |
18 | const std::shared_ptr<BackendData>& data) const override { |
19 | return DeviceData::Create(data); |
20 | } |
21 | // TODO: Scalar node is not currently used by ts_backend. Enable reusing |
22 | // Scalar node later if needed. |
23 | NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) |
24 | const override { |
25 | return MakeNode<Scalar>(value, type); |
26 | } |
27 | NodePtr MakeExpand( |
28 | const Value& input0, |
29 | const std::vector<int64_t>& size, |
30 | const bool& is_scalar_expand) const override { |
31 | return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand); |
32 | } |
33 | NodePtr MakeCast( |
34 | const Value& input0, |
35 | const at::ScalarType& dtype, |
36 | const c10::optional<at::ScalarType>& stype = |
37 | c10::nullopt) const override { |
38 | return ReuseOrMakeNode<Cast>(input0, dtype, stype); |
39 | } |
40 | NodePtr MakeTensorList(const OpList& inputs) const override { |
41 | return ReuseOrMakeNode<TensorList>(inputs); |
42 | } |
43 | // Generic needs cleanup |
44 | NodePtr MakeGeneric( |
45 | const OpKind& op, |
46 | const OpList& operands, |
47 | const Shape& shape, |
48 | const size_t& num_outputs = 1, |
49 | const hash_t& hash_seed = |
50 | static_cast<uint32_t>(0x5a2d296e9)) const override { |
51 | return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); |
52 | } |
53 | |
54 | // dynamic ir nodes |
55 | // TODO: verify if IR node reusing works for Dynamic shape ops |
56 | NodePtr MakeSizeNode(const Value& input, size_t dim) const override { |
57 | return MakeNode<SizeNode>(input, dim); |
58 | } |
59 | NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { |
60 | return MakeNode<SizeAdd>(a, b); |
61 | } |
62 | NodePtr MakeSizeMul(const Value& a, const Value& b) const override { |
63 | return MakeNode<SizeMul>(a, b); |
64 | } |
65 | NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { |
66 | return MakeNode<SizeDiv>(a, b); |
67 | } |
68 | }; |
69 | |
70 | } // namespace lazy |
71 | } // namespace torch |
72 | |