1#pragma once
2
3#include <c10/macros/Export.h>
4
5#include <ir_base_nodes.h>
6#include <utils.h>
7
8#include <deque>
9#include <unordered_map>
10#include <unordered_set>
11
12namespace torch {
13namespace jit {
14namespace fuser {
15namespace cuda {
16
17class IrBuilderPasskey;
18class ExprPasskey;
19class OptOutMutator;
20
21class Int;
22class Bool;
23class NamedScalar;
24
//! Passkey for the container to register names with statements.
//!
//! Only IrContainer (declared a friend below) can construct one, because the
//! default constructor is private. A function that takes an
//! IrContainerPasskey parameter can therefore only be reached by code that
//! obtained the key from IrContainer.
class IrContainerPasskey {
 private:
  // Private so that only the friend below can mint a key.
  explicit IrContainerPasskey() {}

  friend class IrContainer;
};
32
33class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase {
34 public:
35 IrContainer();
36
37 IrContainer(const IrContainer& other);
38 IrContainer(IrContainer&& other) noexcept;
39
40 IrContainer& operator=(const IrContainer& other);
41 IrContainer& operator=(IrContainer&& other) noexcept;
42
43 virtual ~IrContainer();
44
45 bool inContainer(const Statement* stmt) const;
46
47 void assertInContainer(const Statement* stmt, const std::string& msg) const {
48 TORCH_CHECK(
49 inContainer(stmt), msg, " it was not found in the active container.");
50 }
51
52 //! Return in insertion order
53 const std::deque<Val*> deterministic_vals() const noexcept {
54 std::deque<Val*> vals_deque;
55 std::transform(
56 vals_up_.begin(),
57 vals_up_.end(),
58 std::back_inserter(vals_deque),
59 [](const std::unique_ptr<Val>& val_up) { return val_up.get(); });
60 return vals_deque;
61 }
62
63 //! Register the Statement with this container
64 virtual void registerStmt(IrBuilderPasskey, Statement* stmt);
65
66 //! Register the Val with this container
67 virtual void registerVal(IrBuilderPasskey, Val* val);
68
69 //! Register expr with this container.
70 virtual void registerExpr(IrBuilderPasskey, Expr* expr);
71
72 //! Allow expr's to register themselves with a container, this is only used
73 //! for broadcastOp so it can register itself in its constructor so root maps
74 //! can be built.
75 virtual void registerExpr(ExprPasskey, Expr* expr);
76
77 //! Return the set of Exprs registered with this fusion. Warning: This will
78 //! return exprs outside inputs/outputs, so can be unsafe for use with
79 //! segmented fusions.
80 const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
81 return exprs_;
82 }
83
84 //! Return the set of Vals registered with this fusion
85 const std::unordered_set<Val*>& vals() const noexcept {
86 return vals_;
87 }
88
89 // Shortcuts for frequently used vals
90 Int* zeroVal();
91 Int* oneVal();
92 Bool* falseVal();
93 Bool* trueVal();
94 NamedScalar* magicZeroVal();
95
96 protected:
97 static IrCloner copy(const IrContainer* from, IrContainer* to);
98
99 friend void swap(IrContainer& a, IrContainer& b) noexcept;
100
101 // Let mutator remove Exprs.
102 friend OptOutMutator;
103
104 virtual void removeExpr(Expr* expr);
105
106 //! Completely remove val from the fusion, break all dependencies associated
107 //! with it
108 virtual void removeVal(Val* val);
109
110 //! Register the Val with this container
111 virtual void registerVal(Val* val);
112
113 //! Register expr with this container.
114 virtual void registerExpr(Expr* expr);
115
116 StmtNameType getValName(ValType vtype) {
117 if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
118 val_type_name_map_[vtype] = 0;
119 }
120 return val_type_name_map_[vtype]++;
121 }
122
123 StmtNameType getExprName() {
124 return expr_name_counter_++;
125 }
126
127 void clear() noexcept;
128
129 // Deque of unique pointer is the memory owning data structure
130 std::deque<std::unique_ptr<Val>> vals_up_;
131
132 // A convenient set to return when we just need an unordered set to do
133 // something like check if a Val is in this container
134 std::unordered_set<Val*> vals_;
135
136 // Deque of unique pointer is the memory owning data structure
137 std::deque<std::unique_ptr<Expr>> exprs_up_;
138
139 // A convenient set to return when we just need an unordered set to do
140 // something like check if an Expr is in this container
141 std::unordered_set<Expr*> exprs_;
142
143 // Used to implement a generic "inContainer" that can be passed an invalid
144 // pointer. Specifically a pointer to a Statement owned by another container
145 // that has been freed. We can't check normally with the unordered_sets we
146 // already have because it would require a const_cast from a constant
147 // expr/val, or a dynamic cast from a Statement.
148 std::unordered_set<void*> raw_ptrs_;
149
150 // Values names counters
151 std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_;
152
153 // Expression names counter
154 StmtNameType expr_name_counter_ = 0;
155
156 // Manually store some persistent, frequently used nodes. It's very
157 // challenging to do this anything but manually as detecting when a container
158 // may or may not have one of these vals is tricky. Specifically because if
159 // the container doesn't own it, it's hard to understand from the outside if
160 // the node may have been removed then re-registered. It could also be tricky
161 // to know when we're using a different container as in FusionCopy_test
162 // demonstrates deleting then creating containers can result in the same
163 // pointer for the container.
164 std::unique_ptr<Bool> true_val_;
165 std::unique_ptr<Bool> false_val_;
166 std::unique_ptr<Int> one_val_;
167 std::unique_ptr<Int> zero_val_;
168 std::unique_ptr<NamedScalar> magic_zero_val_;
169};
170
171} // namespace cuda
172} // namespace fuser
173} // namespace jit
174} // namespace torch
175