1#pragma once
2
3#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
4#include <torch/csrc/lazy/core/ir.h>
5#include <torch/csrc/lazy/core/ir_builder.h>
6#include <torch/csrc/lazy/core/shape_inference.h>
7#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
8#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
9#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
10#include <torch/csrc/lazy/ts_backend/ops/generic.h>
11#include <torch/csrc/lazy/ts_backend/ts_node.h>
12
13namespace torch {
14namespace lazy {
15
16struct TorchScriptIrBuilder : IrBuilder {
17 NodePtr MakeDeviceData(
18 const std::shared_ptr<BackendData>& data) const override {
19 return DeviceData::Create(data);
20 }
21 // TODO: Scalar node is not currently used by ts_backend. Enable reusing
22 // Scalar node later if needed.
23 NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type)
24 const override {
25 return MakeNode<Scalar>(value, type);
26 }
27 NodePtr MakeExpand(
28 const Value& input0,
29 const std::vector<int64_t>& size,
30 const bool& is_scalar_expand) const override {
31 return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand);
32 }
33 NodePtr MakeCast(
34 const Value& input0,
35 const at::ScalarType& dtype,
36 const c10::optional<at::ScalarType>& stype =
37 c10::nullopt) const override {
38 return ReuseOrMakeNode<Cast>(input0, dtype, stype);
39 }
40 NodePtr MakeTensorList(const OpList& inputs) const override {
41 return ReuseOrMakeNode<TensorList>(inputs);
42 }
43 // Generic needs cleanup
44 NodePtr MakeGeneric(
45 const OpKind& op,
46 const OpList& operands,
47 const Shape& shape,
48 const size_t& num_outputs = 1,
49 const hash_t& hash_seed =
50 static_cast<uint32_t>(0x5a2d296e9)) const override {
51 return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed);
52 }
53
54 // dynamic ir nodes
55 // TODO: verify if IR node reusing works for Dynamic shape ops
56 NodePtr MakeSizeNode(const Value& input, size_t dim) const override {
57 return MakeNode<SizeNode>(input, dim);
58 }
59 NodePtr MakeSizeAdd(const Value& a, const Value& b) const override {
60 return MakeNode<SizeAdd>(a, b);
61 }
62 NodePtr MakeSizeMul(const Value& a, const Value& b) const override {
63 return MakeNode<SizeMul>(a, b);
64 }
65 NodePtr MakeSizeDiv(const Value& a, const Value& b) const override {
66 return MakeNode<SizeDiv>(a, b);
67 }
68};
69
70} // namespace lazy
71} // namespace torch
72