1#include <ATen/core/function_schema.h>
2#include <ATen/core/jit_type.h>
3#include <ATen/core/symbol.h>
4#include <c10/core/ScalarType.h>
5#include <c10/util/ArrayRef.h>
6#include <c10/util/Optional.h>
7#include <torch/csrc/jit/ir/alias_analysis.h>
8#include <torch/csrc/jit/ir/ir.h>
9#include <torch/csrc/jit/jit_log.h>
10#include <torch/csrc/jit/passes/dtype_analysis.h>
11#include <torch/csrc/jit/passes/utils/op_registry.h>
12#include <torch/library.h>
13
14#ifndef AT_PER_OPERATOR_HEADERS
15#include <ATen/Functions.h>
16#else
17#include <ATen/ops/empty.h>
18#endif
19
20#include <algorithm>
21#include <memory>
22#include <stdexcept>
23
24namespace torch {
25namespace jit {
26
27namespace {
28
29using Tensor = at::Tensor;
30using ScalarType = at::ScalarType;
31
32// ----------------------------------------------------------------------------------
33// Metatensor Inference for Dtype
34// ----------------------------------------------------------------------------------
35
36std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) {
37 auto stack = std::make_unique<std::vector<IValue>>();
38 for (Value* inp : n->inputs()) {
39 if (auto tp = inp->type()->cast<TensorType>()) {
40 // Zero-dim tensors have special type promotion behavoir, hence the need
41 // for rank.
42 auto rank = tp->symbolic_sizes().rank(); // Validity checked earlier
43 auto tensor_size = std::vector<int64_t>(rank.value(), 1);
44 stack->emplace_back(at::empty(
45 tensor_size, at::TensorOptions(at::kMeta).dtype(*tp->scalarType())));
46 continue;
47 }
48 // Someday Todo: Fill in concrete values that we know.
49 if (inp->type() == FloatType::get()) {
50 stack->emplace_back(1.);
51 } else if (inp->type() == IntType::get()) {
52 stack->emplace_back(1);
53 } else if (inp->type() == BoolType::get()) {
54 throw std::runtime_error(
55 "Bool currently unsupported, need to verify it's safe to add for all ops");
56 stack->emplace_back(false);
57 } else {
58 // Arrays of values are specifically not handled due
59 // to the fact that naive default vaules would likely be
60 // incorrect anyways.
61 throw std::runtime_error("Unsupported input type for Tensor argument");
62 }
63 }
64 return stack;
65};
66
67bool MTensorNodeArgValid(Value* value) {
68 auto tensor_type = value->type()->cast<TensorType>();
69 if (!tensor_type) {
70 return true;
71 }
72 if (!tensor_type->scalarType().has_value()) {
73 GRAPH_DEBUG("Argument missing Dtype");
74 return false;
75 }
76 auto rank = tensor_type->symbolic_sizes().rank();
77 return rank.has_value();
78}
79
80static bool canBeInferredWithMetaTensor(Node* n) {
81 // Not a guarantee that the metatensor will not error out
82 // Do not have a allowlist for now and let things error out in execution.
83 // Has Tensor output is checked in another place
84 bool args_valid =
85 std::all_of(n->inputs().begin(), n->inputs().end(), MTensorNodeArgValid);
86
87 if (!args_valid) {
88 return false;
89 }
90 if (n->outputs().size() != 1) {
91 // Currently not supporting multiple outputs
92 return false;
93 }
94 auto opt_op = n->maybeOperator();
95 if (!opt_op) {
96 GRAPH_DEBUG("not registered with Meta");
97 return false;
98 }
99 return true;
100}
101
102c10::optional<Tensor> inferWithMetaTensor(Node* n) {
103 GRAPH_DEBUG("inferWithMetaTensor", getHeader(n));
104 if (!canBeInferredWithMetaTensor(n)) {
105 return c10::nullopt;
106 }
107 Operation op = n->getOperation();
108 try {
109 auto stack = MTensorArgumentCreator(n);
110 GRAPH_DEBUG("Running op for ", getHeader(n));
111 op(*stack);
112 GRAPH_DEBUG("op run successfully", getHeader(n));
113 GRAPH_DEBUG("After receive!");
114 return stack->back().toTensor();
115
116 } catch (...) {
117 GRAPH_DEBUG("caught exception with Metatensor run!");
118 };
119 return c10::nullopt;
120}
121
122bool setDtype(
123 Value* value,
124 ScalarType scalarType,
125 bool can_overwrite_dtype = false) {
126 auto tensor_type = value->type()->cast<TensorType>();
127 TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
128 if (!tensor_type->scalarType().has_value()) {
129 value->setType(tensor_type->withScalarType(scalarType));
130 return true;
131 }
132 if (tensor_type->scalarType().value() != scalarType) {
133 TORCH_INTERNAL_ASSERT(
134 can_overwrite_dtype,
135 "Expected tensor type to be ",
136 scalarType,
137 " but found ",
138 tensor_type->scalarType().value());
139 value->setType(tensor_type->withScalarType(scalarType));
140 return true;
141 }
142 return false;
143}
144
145bool tryApplyDtypeMetaTensor(Node* n) {
146 // returns if anything was changed
147 auto return_tensor = inferWithMetaTensor(n);
148 if (!return_tensor) {
149 return false;
150 }
151 GRAPH_DEBUG("Received ", toString(return_tensor->scalar_type()));
152 return setDtype(n->output(), return_tensor->scalar_type());
153}
154
155// ----------------------------------------------------------------------------------
156// Custom Rules for Dtype
157// ----------------------------------------------------------------------------------
158using DtypePropRule = std::function<bool(Node*)>;
159// Function to propagate dtype information for a node
160// Returns true if the dtype information was changed
161
162bool setIfAllDtypeMatch(Node* n) {
163 // Sets all tensor outputs to the dtype of the first input
164 // only if all inputs are the same dtype, otherwise do nothing
165 TORCH_INTERNAL_ASSERT(!n->inputs().empty());
166 auto first_arg = n->inputs().at(0);
167 auto tensor_type = first_arg->type()->cast<TensorType>();
168 TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
169 auto scalar_type = tensor_type->scalarType();
170 if (!scalar_type.has_value()) {
171 return false;
172 }
173 for (auto arg : n->inputs()) {
174 tensor_type = arg->type()->cast<TensorType>();
175 if (!tensor_type) {
176 continue;
177 }
178 auto arg_scalar_type = tensor_type->scalarType();
179
180 if (!arg_scalar_type.has_value()) { // Allow None for optional args
181 continue;
182 }
183 if (arg_scalar_type != scalar_type) {
184 return false;
185 }
186 }
187
188 bool changed = false;
189 for (auto output : n->outputs()) {
190 if (output->type()->cast<TensorType>()) {
191 changed |= setDtype(output, scalar_type.value());
192 }
193 }
194 return changed;
195}
196
197// DtypePropagationPass is an analysis pass that walks through a graph in
198// topological order and forward propagate Dtypes (ScalarTypes) from graph
199// inputs (expressed in input_descriptors) to all output tensor nodes in the
200// graph.
struct DtypePropagationPass {
  // Takes shared ownership of the graph; the custom-rule registry is built
  // once up front so per-node lookup is cheap.
  explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {
    buildDtypeRuleRegistry();
  }

  // returns true if at least one node has its scalar type set on a tensor node
  bool run() {
    return processBlocks(graph_->block());
  }

 private:
  // Processes each block in order, OR-ing together whether anything changed.
  bool processBlocks(at::ArrayRef<Block*> blocks) {
    bool changed = false;
    for (auto block : blocks) {
      changed |= processBlock(block);
    }
    return changed;
  }

  // Visits every node of `block` in graph order (forward propagation).
  bool processBlock(Block* block) {
    GRAPH_DEBUG("processBlock");
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
      changed |= processNode(*it);
    }
    return changed;
  }

  // Dispatches a single node. prim::If recurses into both branches;
  // Loop/CallMethod/CallFunction are unsupported and assert. Nodes with no
  // tensor outputs are skipped. Constants are left alone, ListConstruct/
  // ListUnpack assert, and aten ops go through processAtenOps. Any other
  // op kind asserts as unsupported.
  bool processNode(Node* n) {
    GRAPH_DEBUG("processNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
      case prim::CallMethod:
      case prim::CallFunction:
        TORCH_INTERNAL_ASSERT(false, "Loop/Call not handled now");
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // if output contains no tensor, nothing to propagate
      return false;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This is already been propagated by something else in freezing
        return false;
      case prim::ListConstruct:
      case prim::ListUnpack:
        TORCH_INTERNAL_ASSERT(
            false,
            "List Construct and Unpack is not supported in Dtype Propagation");
        break;
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          TORCH_INTERNAL_ASSERT(
              false,
              n->kind().toDisplayString(),
              "Op is not supported in Dtype Propagation");
        }
    }
    return false;
  }

  // Placeholder for merging dtypes of the two If-branch outputs; currently
  // only accepts empty output lists and asserts otherwise.
  bool mergeTensorProperties(
      const at::ArrayRef<Value*>& list1,
      const at::ArrayRef<Value*>& list2) {
    // This is currently a placeholder for MobileNet
    // After Month1: implement the merge function
    TORCH_INTERNAL_ASSERT(list1.empty(), "Not implemented yet");
    return false;
  }

  // Recurses into both branches of a prim::If, then (in the future) merges
  // the branch outputs onto the If node's outputs.
  bool processIf(Node* node) {
    GRAPH_DEBUG("processIf");
    bool changed = false;
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    changed |= processBlock(true_block);
    changed |= processBlock(false_block);

    changed |=
        mergeTensorProperties(true_block->outputs(), false_block->outputs());

    return changed;
  }

  // for efficiency
  // Prefers a registered custom rule for the node's operator; falls back to
  // meta-tensor execution when no rule is registered.
  bool processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    if (auto prop_fn = dtype_prop_registry_->find(n->getOperator())) {
      DtypePropRule rule = *prop_fn;
      return rule(n);
    }
    return tryApplyDtypeMetaTensor(n);
  }

  // Registers setIfAllDtypeMatch for the two op families defined in
  // passes/utils/op_registry.h.
  void buildDtypeRuleRegistry() {
    // building a registry for all of the custom dtype rules
    dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();

    dtype_prop_registry_->insert(
        *nn_ops_first_input_preserving(), setIfAllDtypeMatch);
    dtype_prop_registry_->insert(
        *ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
  }
  // Operator -> propagation-rule lookup table; built once in the constructor.
  std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
  std::shared_ptr<Graph> graph_;
};
325
326} // anonymous namespace
327
328// This analysis propagates input dtypes (if any) throughout the
329// graph.
330bool DtypePropagation(std::shared_ptr<Graph>& graph) {
331 DtypePropagationPass tp = DtypePropagationPass(graph);
332 bool changed = tp.run();
333 if (changed) {
334 GRAPH_DUMP("After TensorPropertyPropagation pass:", graph);
335 }
336 return changed;
337}
338
339} // namespace jit
340} // namespace torch
341