1 | #include <ATen/core/function_schema.h> |
2 | #include <ATen/core/jit_type.h> |
3 | #include <ATen/core/symbol.h> |
4 | #include <c10/core/ScalarType.h> |
5 | #include <c10/util/ArrayRef.h> |
6 | #include <c10/util/Optional.h> |
7 | #include <torch/csrc/jit/ir/alias_analysis.h> |
8 | #include <torch/csrc/jit/ir/ir.h> |
9 | #include <torch/csrc/jit/jit_log.h> |
10 | #include <torch/csrc/jit/passes/dtype_analysis.h> |
11 | #include <torch/csrc/jit/passes/utils/op_registry.h> |
12 | #include <torch/library.h> |
13 | |
14 | #ifndef AT_PER_OPERATOR_HEADERS |
15 | #include <ATen/Functions.h> |
16 | #else |
17 | #include <ATen/ops/empty.h> |
18 | #endif |
19 | |
20 | #include <algorithm> |
21 | #include <memory> |
22 | #include <stdexcept> |
23 | |
24 | namespace torch { |
25 | namespace jit { |
26 | |
27 | namespace { |
28 | |
// Local shorthands for the ATen types used throughout this pass.
using Tensor = at::Tensor;
using ScalarType = at::ScalarType;

// ----------------------------------------------------------------------------------
// Metatensor Inference for Dtype
// ----------------------------------------------------------------------------------
35 | |
36 | std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) { |
37 | auto stack = std::make_unique<std::vector<IValue>>(); |
38 | for (Value* inp : n->inputs()) { |
39 | if (auto tp = inp->type()->cast<TensorType>()) { |
40 | // Zero-dim tensors have special type promotion behavoir, hence the need |
41 | // for rank. |
42 | auto rank = tp->symbolic_sizes().rank(); // Validity checked earlier |
43 | auto tensor_size = std::vector<int64_t>(rank.value(), 1); |
44 | stack->emplace_back(at::empty( |
45 | tensor_size, at::TensorOptions(at::kMeta).dtype(*tp->scalarType()))); |
46 | continue; |
47 | } |
48 | // Someday Todo: Fill in concrete values that we know. |
49 | if (inp->type() == FloatType::get()) { |
50 | stack->emplace_back(1.); |
51 | } else if (inp->type() == IntType::get()) { |
52 | stack->emplace_back(1); |
53 | } else if (inp->type() == BoolType::get()) { |
54 | throw std::runtime_error( |
55 | "Bool currently unsupported, need to verify it's safe to add for all ops" ); |
56 | stack->emplace_back(false); |
57 | } else { |
58 | // Arrays of values are specifically not handled due |
59 | // to the fact that naive default vaules would likely be |
60 | // incorrect anyways. |
61 | throw std::runtime_error("Unsupported input type for Tensor argument" ); |
62 | } |
63 | } |
64 | return stack; |
65 | }; |
66 | |
67 | bool MTensorNodeArgValid(Value* value) { |
68 | auto tensor_type = value->type()->cast<TensorType>(); |
69 | if (!tensor_type) { |
70 | return true; |
71 | } |
72 | if (!tensor_type->scalarType().has_value()) { |
73 | GRAPH_DEBUG("Argument missing Dtype" ); |
74 | return false; |
75 | } |
76 | auto rank = tensor_type->symbolic_sizes().rank(); |
77 | return rank.has_value(); |
78 | } |
79 | |
80 | static bool canBeInferredWithMetaTensor(Node* n) { |
81 | // Not a guarantee that the metatensor will not error out |
82 | // Do not have a allowlist for now and let things error out in execution. |
83 | // Has Tensor output is checked in another place |
84 | bool args_valid = |
85 | std::all_of(n->inputs().begin(), n->inputs().end(), MTensorNodeArgValid); |
86 | |
87 | if (!args_valid) { |
88 | return false; |
89 | } |
90 | if (n->outputs().size() != 1) { |
91 | // Currently not supporting multiple outputs |
92 | return false; |
93 | } |
94 | auto opt_op = n->maybeOperator(); |
95 | if (!opt_op) { |
96 | GRAPH_DEBUG("not registered with Meta" ); |
97 | return false; |
98 | } |
99 | return true; |
100 | } |
101 | |
102 | c10::optional<Tensor> inferWithMetaTensor(Node* n) { |
103 | GRAPH_DEBUG("inferWithMetaTensor" , getHeader(n)); |
104 | if (!canBeInferredWithMetaTensor(n)) { |
105 | return c10::nullopt; |
106 | } |
107 | Operation op = n->getOperation(); |
108 | try { |
109 | auto stack = MTensorArgumentCreator(n); |
110 | GRAPH_DEBUG("Running op for " , getHeader(n)); |
111 | op(*stack); |
112 | GRAPH_DEBUG("op run successfully" , getHeader(n)); |
113 | GRAPH_DEBUG("After receive!" ); |
114 | return stack->back().toTensor(); |
115 | |
116 | } catch (...) { |
117 | GRAPH_DEBUG("caught exception with Metatensor run!" ); |
118 | }; |
119 | return c10::nullopt; |
120 | } |
121 | |
122 | bool setDtype( |
123 | Value* value, |
124 | ScalarType scalarType, |
125 | bool can_overwrite_dtype = false) { |
126 | auto tensor_type = value->type()->cast<TensorType>(); |
127 | TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type" ); |
128 | if (!tensor_type->scalarType().has_value()) { |
129 | value->setType(tensor_type->withScalarType(scalarType)); |
130 | return true; |
131 | } |
132 | if (tensor_type->scalarType().value() != scalarType) { |
133 | TORCH_INTERNAL_ASSERT( |
134 | can_overwrite_dtype, |
135 | "Expected tensor type to be " , |
136 | scalarType, |
137 | " but found " , |
138 | tensor_type->scalarType().value()); |
139 | value->setType(tensor_type->withScalarType(scalarType)); |
140 | return true; |
141 | } |
142 | return false; |
143 | } |
144 | |
145 | bool tryApplyDtypeMetaTensor(Node* n) { |
146 | // returns if anything was changed |
147 | auto return_tensor = inferWithMetaTensor(n); |
148 | if (!return_tensor) { |
149 | return false; |
150 | } |
151 | GRAPH_DEBUG("Received " , toString(return_tensor->scalar_type())); |
152 | return setDtype(n->output(), return_tensor->scalar_type()); |
153 | } |
154 | |
// ----------------------------------------------------------------------------------
// Custom Rules for Dtype
// ----------------------------------------------------------------------------------

// Function to propagate dtype information for a node.
// Returns true if the dtype information was changed.
using DtypePropRule = std::function<bool(Node*)>;
161 | |
162 | bool setIfAllDtypeMatch(Node* n) { |
163 | // Sets all tensor outputs to the dtype of the first input |
164 | // only if all inputs are the same dtype, otherwise do nothing |
165 | TORCH_INTERNAL_ASSERT(!n->inputs().empty()); |
166 | auto first_arg = n->inputs().at(0); |
167 | auto tensor_type = first_arg->type()->cast<TensorType>(); |
168 | TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type" ); |
169 | auto scalar_type = tensor_type->scalarType(); |
170 | if (!scalar_type.has_value()) { |
171 | return false; |
172 | } |
173 | for (auto arg : n->inputs()) { |
174 | tensor_type = arg->type()->cast<TensorType>(); |
175 | if (!tensor_type) { |
176 | continue; |
177 | } |
178 | auto arg_scalar_type = tensor_type->scalarType(); |
179 | |
180 | if (!arg_scalar_type.has_value()) { // Allow None for optional args |
181 | continue; |
182 | } |
183 | if (arg_scalar_type != scalar_type) { |
184 | return false; |
185 | } |
186 | } |
187 | |
188 | bool changed = false; |
189 | for (auto output : n->outputs()) { |
190 | if (output->type()->cast<TensorType>()) { |
191 | changed |= setDtype(output, scalar_type.value()); |
192 | } |
193 | } |
194 | return changed; |
195 | } |
196 | |
// DtypePropagationPass is an analysis pass that walks through a graph in
// topological order and forward propagate Dtypes (ScalarTypes) from graph
// inputs (expressed in input_descriptors) to all output tensor nodes in the
// graph.
struct DtypePropagationPass {
  explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {
    buildDtypeRuleRegistry();
  }

  // returns true if at least one node has its scalar type set on a tensor node
  bool run() {
    return processBlocks(graph_->block());
  }

 private:
  // Runs propagation over each block in `blocks`; returns true if any node in
  // any block had a type updated.
  bool processBlocks(at::ArrayRef<Block*> blocks) {
    bool changed = false;
    for (auto block : blocks) {
      changed |= processBlock(block);
    }
    return changed;
  }

  // Visits the block's nodes in order; returns true if any node had a type
  // updated.
  bool processBlock(Block* block) {
    GRAPH_DEBUG("processBlock" );
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
      changed |= processNode(*it);
    }
    return changed;
  }

  // Dispatches a single node: prim::If recurses into both branches;
  // Loop/CallMethod/CallFunction are unsupported (hard assert); nodes with no
  // tensor output are skipped; aten ops go through processAtenOps.
  // Returns true if any output type was updated.
  bool processNode(Node* n) {
    GRAPH_DEBUG("processNode" );
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
      case prim::CallMethod:
      case prim::CallFunction:
        TORCH_INTERNAL_ASSERT(false, "Loop/Call not handled now" );
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // if output contains no tensor, nothing to propagate
      return false;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This is already been propagated by something else in freezing
        return false;
      case prim::ListConstruct:
      case prim::ListUnpack:
        TORCH_INTERNAL_ASSERT(
            false,
            "List Construct and Unpack is not supported in Dtype Propagation" );
        break;
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          TORCH_INTERNAL_ASSERT(
              false,
              n->kind().toDisplayString(),
              "Op is not supported in Dtype Propagation" );
        }
    }
    return false;
  }

  // Merges the tensor properties of an If node's two branch-output lists.
  // Currently only a placeholder that accepts empty lists.
  bool mergeTensorProperties(
      const at::ArrayRef<Value*>& list1,
      const at::ArrayRef<Value*>& list2) {
    // This is currently a placeholder for MobileNet
    // After Month1: implement the merge function
    TORCH_INTERNAL_ASSERT(list1.empty(), "Not implemented yet" );
    return false;
  }

  // Propagates through both branches of an If node, then merges their
  // outputs (see mergeTensorProperties for the current limitation).
  bool processIf(Node* node) {
    GRAPH_DEBUG("processIf" );
    bool changed = false;
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    changed |= processBlock(true_block);
    changed |= processBlock(false_block);

    changed |=
        mergeTensorProperties(true_block->outputs(), false_block->outputs());

    return changed;
  }

  // for efficiency
  // Applies a registered custom dtype rule when one exists for the operator;
  // otherwise falls back to meta-tensor inference.
  bool processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps" );
    GRAPH_DEBUG("case = " , n->kind(), " " , *n);
    // Custom Rule Matching
    if (auto prop_fn = dtype_prop_registry_->find(n->getOperator())) {
      DtypePropRule rule = *prop_fn;
      return rule(n);
    }
    return tryApplyDtypeMetaTensor(n);
  }

  // Populates dtype_prop_registry_ with the custom rules applied before the
  // meta-tensor fallback.
  void buildDtypeRuleRegistry() {
    // building a registry for all of the custom dtype rules
    dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();

    dtype_prop_registry_->insert(
        *nn_ops_first_input_preserving(), setIfAllDtypeMatch);
    dtype_prop_registry_->insert(
        *ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
  }
  // Maps operators to their custom propagation rules.
  std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
  // The graph this pass analyzes and mutates (types only).
  std::shared_ptr<Graph> graph_;
};
325 | |
326 | } // anonymous namespace |
327 | |
328 | // This analysis propagates input dtypes (if any) throughout the |
329 | // graph. |
330 | bool DtypePropagation(std::shared_ptr<Graph>& graph) { |
331 | DtypePropagationPass tp = DtypePropagationPass(graph); |
332 | bool changed = tp.run(); |
333 | if (changed) { |
334 | GRAPH_DUMP("After TensorPropertyPropagation pass:" , graph); |
335 | } |
336 | return changed; |
337 | } |
338 | |
339 | } // namespace jit |
340 | } // namespace torch |
341 | |