1#pragma once
2
3#include <memory>
4#include <string>
5#include <unordered_map>
6#include <utility>
7#include <vector>
8
9#include <torch/csrc/lazy/backend/backend_data.h>
10#include <torch/csrc/lazy/backend/backend_device.h>
11#include <torch/csrc/lazy/core/ir.h>
12#include <torch/csrc/lazy/core/ir_util.h>
13
14namespace torch {
15namespace lazy {
16
17class TORCH_API Computation {
18 public:
19 virtual int parameters_size() const = 0;
20
21 virtual const std::vector<Shape>& parameter_shapes() const = 0;
22
23 virtual const std::vector<std::string>& parameter_names() const = 0;
24
25 virtual const Shape& result_shape() const = 0;
26
27 virtual const std::string to_string() const = 0;
28
29 virtual ~Computation() = default;
30
31 // Indicates whether this computation is being executed inside a mark step
32 // Assume false unless set otherwise
33 bool in_mark_step = false;
34};
35
36using ComputationPtr = std::shared_ptr<Computation>;
37
38// Keeps track of the code generation state.
39class TORCH_API LoweringContext {
40 public:
41 LoweringContext(const std::string& name, BackendDevice device);
42 LoweringContext(
43 const std::string& name,
44 BackendDevice device,
45 c10::ArrayRef<const torch::lazy::Node*> post_order,
46 Util::EmissionMap emit_status);
47
48 virtual ~LoweringContext() = default;
49
50 static std::unique_ptr<LoweringContext> Create(
51 const std::string& name,
52 BackendDevice device,
53 c10::ArrayRef<const torch::lazy::Node*> post_order,
54 Util::EmissionMap emit_status);
55
56 static std::unique_ptr<LoweringContext> Create(
57 const std::string& name,
58 BackendDevice device);
59
60 const BackendDevice& device() const {
61 return device_;
62 };
63
64 // Retrieves the vector holding all the tensors associated with the parameter
65 // instructions which have been created.
66 const std::vector<BackendDataPtr>& GetParametersData() const;
67
68 // Adds a new input/output alias.
69 virtual void SetUpAlias(
70 const std::vector<int64_t>& output_index,
71 int64_t param_number,
72 const std::vector<int64_t>& param_index,
73 bool must_alias = false) {
74 // Dummy default implementation to do nothing.
75 }
76
77 // Check if parameter shape matches result at index.
78 virtual bool CheckResultShape(
79 const BackendDataPtr& parameter_data,
80 size_t result_idx) {
81 // Dummy default implementation to do nothing.
82 return false;
83 }
84
85 // Adds the given output as a component of the result tuple and returns its
86 // assigned position within the tuple.
87 virtual size_t AddResult(const torch::lazy::Output& output) = 0;
88
89 // Associates the given output with the input parameter of the given index and
90 // shape. Only used for the operator-by-operator execution, mostly for
91 // debugging purposes.
92 virtual void AddParameter(
93 const torch::lazy::Output& output,
94 size_t index,
95 const Shape& shape,
96 const std::string& name) = 0;
97
98 // Build the computation capturing all the operations created with the
99 // embedded builder (returned by the builder() API).
100 virtual ComputationPtr Build() = 0;
101
102 size_t GetEmittedNodeCount() const {
103 return emit_status_.size();
104 }
105
106 protected:
107 BackendDevice device_;
108 std::vector<BackendDataPtr> parameters_;
109 std::vector<size_t> parameter_sequence_;
110 Util::EmissionMap emit_status_;
111};
112
113} // namespace lazy
114} // namespace torch
115