#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <sstream>
#include <utility>

namespace torch {
namespace jit {

// Inserts the compute for each symbolic shape in the TensorExpr graph
// and returns a map from each symbolic shape to its runtime Value*.
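// For example (illustrative): if the TE subgraph uses symbolic dims SS(-2)
// and SS(-3), the partially evaluated shape compute graph is inserted before
// the fusion node (fed by aten::size of the fusion inputs where needed) and
// the returned map is {-2: %v1, -3: %v2}, where %v1 and %v2 are the Values
// holding those dimensions in the enclosing graph.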
std::map<int64_t, Value*> InsertSymbolicShapesCompute(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* tensorexpr_graph) {
  WithInsertPoint guard(tensorexpr_graph);
  auto enclosing_graph = tensorexpr_graph->owningGraph();

  std::map<Value*, Value*> shape_graph_input_to_enclosing_graph_value;
  for (const auto& pair :
       shape_mapping.enclosing_graph_value_to_shape_graph_input_) {
    shape_graph_input_to_enclosing_graph_value[pair.second] = pair.first;
  }
  std::vector<Value*> shape_compute_graph_inputs;
  for (Value* shape_graph_input :
       shape_mapping.partial_eval_shape_graph->inputs()) {
    auto enclosing_graph_input =
        shape_graph_input_to_enclosing_graph_value.find(shape_graph_input);
    TORCH_INTERNAL_ASSERT(
        enclosing_graph_input !=
        shape_graph_input_to_enclosing_graph_value.end());
    if (*enclosing_graph_input->second->type() == *shape_graph_input->type()) {
      shape_compute_graph_inputs.push_back(tensorexpr_graph->inputs().at(
          enclosing_graph_input->second->offset()));
    } else {
      TORCH_INTERNAL_ASSERT(
          enclosing_graph_input->second->type()->cast<TensorType>() &&
          shape_graph_input->type()->isSubtypeOf(ListType::ofInts()));
      shape_compute_graph_inputs.push_back(enclosing_graph->insert(
          aten::size,
          {tensorexpr_graph->inputs().at(
              enclosing_graph_input->second->offset())}));
    }
  }
  auto sym_shape_values = insertGraph(
      *enclosing_graph,
      *shape_mapping.partial_eval_shape_graph,
      shape_compute_graph_inputs);
  std::map<int64_t, Value*> sym_shape_to_enclosing_graph_value;
  for (size_t i = 0;
       i < shape_mapping.partial_eval_shape_graph->outputs().size();
       ++i) {
    Value* output = shape_mapping.partial_eval_shape_graph->outputs().at(i);
    auto sym_shape =
        shape_mapping.graph_output_to_symbolic_shape_dim_.find(output);
    TORCH_INTERNAL_ASSERT(
        sym_shape != shape_mapping.graph_output_to_symbolic_shape_dim_.end());
    sym_shape_to_enclosing_graph_value[sym_shape->second] = sym_shape_values[i];
  }
  return sym_shape_to_enclosing_graph_value;
}

void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides);

std::string toString(StrideInput si) {
  switch (si) {
    case StrideInput::TENSOR_CONT:
      return "TENSOR_CONT";
    case StrideInput::TENSOR_CONT_CHANNELS_LAST:
      return "TENSOR_CONT_CHANNELS_LAST";
    case StrideInput::S_ONE:
      return "S_ONE";
    case StrideInput::S_CONT:
      return "S_CONT";
    case StrideInput::S_TRAN_CONT:
      return "S_TRAN_CONT";
    case StrideInput::S_AS_ARG:
      return "S_AS_ARG";
  }
  TORCH_INTERNAL_ASSERT(false);
}

StrideInput strideInputFromString(const std::string& si) {
  if (si == "TENSOR_CONT") {
    return StrideInput::TENSOR_CONT;
  } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  } else if (si == "S_ONE") {
    return StrideInput::S_ONE;
  } else if (si == "S_CONT") {
    return StrideInput::S_CONT;
  } else if (si == "S_TRAN_CONT") {
    return StrideInput::S_TRAN_CONT;
  } else if (si == "S_AS_ARG") {
    return StrideInput::S_AS_ARG;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
}

// In the runtime guard, strides are serialized as one flat
// vector. stride_inputs_offset indexes into that vector
// where the strides of this tensor begin.
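// For example (illustrative): sizes [3, 4] with strides [1, 3] (a transposed
// contiguous 2d tensor) summarize per-dim to [S_ONE, S_TRAN_CONT], while
// sizes [2, 3, 4] with strides [12, 4, 1] summarize to
// [S_CONT, S_CONT, S_ONE].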
inline StrideInput summarizeStrideDim(
    const c10::IntArrayRef sizes,
    const c10::IntArrayRef strides,
    size_t dim,
    const std::vector<StrideInput>& stride_inputs,
    size_t stride_inputs_offset) {
  if (strides[dim] == 1) {
    return StrideInput::S_ONE;
  } else if (
      dim + 1 < sizes.size() &&
      strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
    return StrideInput::S_CONT;
    // Transposed-contiguous depends on the prior dim and contiguous depends
    // on the next dim; to avoid a mutual dependence, only mark S_TRAN_CONT
    // when the prior dim was not already summarized as stride-contiguous.
  } else if (
      dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
      (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
    return StrideInput::S_TRAN_CONT;
  } else {
    return StrideInput::S_AS_ARG;
  }
}

std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  if (c10::is_contiguous_strides(sizes, strides)) {
    return {StrideInput::TENSOR_CONT};
    // TODO: channels last 3d
  } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
  }
  std::vector<StrideInput> stride_inputs;
  for (size_t dim = 0; dim < sizes.size(); ++dim) {
    stride_inputs.push_back(
        summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
  }
  return stride_inputs;
}

// TODO: incorporate into codegen
StrideInput summarizeOutputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  // We only try to maintain output striding for channels last tensors,
  // otherwise we defer to contiguous
  // TODO: channels last 3d
  if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  }
  return StrideInput::TENSOR_CONT;
}

// Generalize inputs with complete shapes to symbolic shapes.
// Dimensions of value 1 will be preserved; otherwise, dimensions with the
// same value will be bucketed into the same symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// Also summarizes input striding behavior. The size information is stored on
// the type; the striding is returned. See StrideInput for a description of
// the stride specializations.
c10::optional<std::vector<std::vector<StrideInput>>>
TryGeneralizeInputDimensionsToSymbolicShapes(
    std::shared_ptr<Graph> tensorexpr_graph) {
  std::map<size_t, int64_t> shape_to_sym_shape;
  std::vector<std::vector<StrideInput>> input_striding;

  for (Value* v : tensorexpr_graph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return c10::nullopt;
    }
    input_striding.push_back(summarizeInputStrides(tt));
    std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
    auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
      auto value = shape.value();
      TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
      if (value == 1) {
        return value;
      } else if (shape_to_sym_shape.count(static_cast<size_t>(value))) {
        return shape_to_sym_shape[value];
      } else {
        auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
        shape_to_sym_shape[static_cast<size_t>(value)] = new_shape_symbol;
        return new_shape_symbol;
      }
    });
    v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
  }
  return input_striding;
}

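// Moves Tensor constants out of the TE subgraph: each constant Tensor node is
// cloned into the parent graph and its value is passed back in as a new
// subgraph input (see GenerateGuard below for why).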
void moveConstantTensorsOutOfSubgraph(
    Node* tensorexpr_graph_node,
    std::shared_ptr<Graph> tensorexpr_graph) {
  auto parent = tensorexpr_graph_node->owningGraph();

  auto env = [&](Value* v) {
    TORCH_INTERNAL_ASSERT(
        false,
        "this should never happen since constant nodes do not have any inputs",
        v->debugName());
    return v;
  };

  WithInsertPoint wip(tensorexpr_graph_node);
  std::vector<Node*> to_destroy;
  for (auto node : tensorexpr_graph->nodes()) {
    if (node->kind() == prim::Constant) {
      if (!node->output()->type()->cast<TensorType>()) {
        continue;
      }

      // copy the constant and insert that copy into the parent graph.
      auto copy = parent->createClone(node, env);
      parent->insertNode(copy);

      // add a new input to the te subgraph and replace the uses of the
      // constant with this input.
      auto new_const = tensorexpr_graph->addInput();
      new_const->setType(node->output()->type());
      node->output()->replaceAllUsesWith(new_const);

      // add the copy as input to the te node
      tensorexpr_graph_node->addInput(copy->output());

      to_destroy.push_back(node);
    }
  }

  for (auto n : to_destroy) {
    n->destroy();
  }
}

bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
  auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);

  // Move constant tensors from the subgraph to the outer scope.
  // This is necessary because symbolic shape analysis does not handle the
  // case of broadcast(constant, symbolic_shape) well and that results in poor
  // performance.
  moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);

  // Generalize Inputs
  auto input_striding =
      TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
  if (!input_striding) {
    return false;
  }

  // Get output striding behavior
  std::vector<StrideInput> output_striding;
  for (Value* v : tensorexpr_graph->outputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return false;
    }
    output_striding.push_back(summarizeOutputStrides(tt));
  }

  // Try To Propagate Shapes
  auto maybe_shape_compute_mapping =
      PropagateShapesAndBuildLargeShapeComputeGraph(
          tensorexpr_graph,
          *tensorexpr_graph->nodes().begin(),
          *tensorexpr_graph->nodes().end());
  if (!maybe_shape_compute_mapping) {
    return false;
  }

  // Insert Guard
  insertDynamicShapesGuard(
      *maybe_shape_compute_mapping,
      tensorexpr_graph_node,
      add_composed_op,
      *input_striding,
      output_striding);
  return true;
}

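// Inlines the prim::FallbackGraph found in the versioning if's false block and
// appends a prim::StaticRuntimeCopyOuts node over the block's outputs, so that
// Static Runtime can copy results into pre-allocated output tensors (see the
// comment above StaticRuntimeCopyOuts below).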
void inlineFallbackGraphAndAddSRCopyOutOp(std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::FallbackGraph) {
      break;
    }
  }
  TORCH_INTERNAL_ASSERT(n != nullptr, "Expected to find fallback graph");

  auto if_node = n->owningBlock()->owningNode();
  IfView if_v(if_node);
  SubgraphUtils::unmergeSubgraph(n);

  auto false_block = if_v.elseBlock();
  std::vector<Value*> false_block_outputs(
      if_v.elseOutputs().begin(), if_v.elseOutputs().end());
  TORCH_INTERNAL_ASSERT(!false_block_outputs.empty());

  for (auto out : false_block_outputs) {
    TORCH_INTERNAL_ASSERT(out->type()->cast<TensorType>());
  }
  auto copy_node = graph->create(
      prim::StaticRuntimeCopyOuts,
      false_block_outputs,
      false_block_outputs.size());
  false_block->appendNode(copy_node);
  for (size_t i = 0; i < false_block_outputs.size(); ++i) {
    false_block->replaceOutput(i, copy_node->outputs().at(i));
  }
}

// TODO: share more logic with tensorexpr_fuser?
void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides) {
  GRAPH_DEBUG(
      "Inserting a prim::TensorExprDynamicGuard guard for a node",
      *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (const auto i : c10::irange(guarded_node->inputs().size())) {
    Value* node_input = guarded_node->inputs().at(i);
    // We only check inputs of the guarded nodes
    if (!node_input->type()->cast<TensorType>()) {
      continue;
    }
    inputs_to_check.push_back(node_input);
    guard_types.emplace_back(
        subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
            c10::VaryingShape<c10::Stride>()));
  }
  TORCH_INTERNAL_ASSERT(inputs_to_check.size());

  // prim::TensorExprDynamicGuard nodes look like the following:
  //   %types_match : bool = prim::TensorExprDynamicGuard[attr:types](%inp1 :
  //   Tensor, %inp2 : Tensor)
  // The input tensors are checked against the expected types on attr::types.
  // We omit refining the input Tensors for now because they are not actually
  // used within tensorexpr/kernel.cpp (only the inputs to the Graph are, not
  // the inputs to the node) and we would have to redo the mapping to compute
  // symbolic shapes.

  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(Symbol::prim("TensorExprDynamicGuard"), inputs_to_check, 1)
          ->insertBefore(guarded_node);

  typecheck_node->tys_(attr::types, std::move(guard_types));
  Value* typecheck_result = typecheck_node->output()->setType(BoolType::get());

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);

  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // types get copied to the fallback graph, so remove specializations before
  // replacing
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());

  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }

  // Insert Symbolic Shapes Compute and add as inputs to TE Node/Graph
  // symbolic_shape_inputs will be a list of each symbolic shape,
  // and the last N inputs to TE Graph/Node will be the N
  // symbolic shape values
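  // For example (illustrative): with symbolic dims SS(-2) and SS(-3),
  // attr::symbolic_shape_inputs becomes [-3, -2] (std::map order) and the
  // subgraph gains trailing int inputs SS_3 and SS_2 holding their runtime
  // values.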
  auto map = InsertSymbolicShapesCompute(shape_mapping, guarded_node);
  std::vector<int64_t> symbolic_shape_inputs;
  for (const auto& pair : map) {
    symbolic_shape_inputs.push_back(pair.first);
    guarded_node->addInput(pair.second);
    std::stringstream ss;
    ss << "SS_" << -pair.first;
    subgraph->addInput(ss.str())->setType(IntType::get());
  }
  guarded_node->is_(
      attr::symbolic_shape_inputs, std::move(symbolic_shape_inputs));

  std::vector<std::vector<std::string>> input_striding;
  for (auto& vec : input_info) {
    auto string_info =
        fmap(vec, [&](StrideInput inp) { return toString(inp); });
    input_striding.push_back(string_info);
  }
  auto ival = IValue(input_striding);
  guarded_node->ival_(attr::striding_inputs_desc, ival);
  typecheck_node->ival_(attr::striding_inputs_desc, std::move(ival));

  for (Value* v : subgraph->inputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }
  for (Value* v : subgraph->outputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }

  std::vector<std::string> output_striding =
      fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
  auto output_ival = IValue(output_striding);
  guarded_node->ival_(attr::striding_outputs_desc, std::move(output_ival));

  if (add_composed_op) {
    // Only in the SR (Static Runtime) flow do we check for values on the
    // stack and forward them along as tensor outputs.
    // TODO: refactor and make this an explicit part of the TE Kernel API
    guarded_node->i_(attr::allow_stack_outputs, 1);

    // Create a TensorExprDynamicGroup node
    auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
        typecheck_node, prim::TensorExprDynamicGroup);
    SubgraphUtils::mergeNodeIntoSubgraph(versioning_if, te_dyn_group);
    inlineFallbackGraphAndAddSRCopyOutOp(
        SubgraphUtils::getSubgraph(te_dyn_group));
  }
}

// This operator is inserted at the end of the fallback block computing outputs
// for the fusion group. We convert block1():
//   %14 : Tensor = aten::mul(%0, %1)
//   %15 : Tensor = aten::mul(%0, %14)
//   -> (%15, %14)
//   return (%3, %4)
// to
// block1():
//   %14 : Tensor = aten::mul(%0, %1)
//   %15 : Tensor = aten::mul(%0, %14)
//   %16 : Tensor, %17 : Tensor = prim::StaticRuntimeCopyOuts(%15, %14)
//   -> (%16, %17)
// Every output of the block is added as an input, and for each input there is
// a StaticRuntimeCopyOuts output. SR invokes the composed operator first with
// no tensors on the stack, in which case the Op will just forward its inputs.
// Second, it invokes it with pre-allocated tensors, one for each output of the
// fusion group, which is the same as the number of outputs of the fallback
// block. In this case we copy the values of the inputs into the pre-allocated
// tensors.
// Note: this logic is meant to reflect the invocation of the TE Kernel
// and `runWithAllocatedOutputs` in tensorexpr_fuser.cpp
Operation StaticRuntimeCopyOuts(const Node* node) {
  auto num_ten_inputs = node->inputs().size();
  return [num_ten_inputs](Stack& stack) {
    std::vector<IValue> inputs = pop(stack, num_ten_inputs);
    // uncommon case - first run
    if (stack.empty()) {
      for (IValue elem : inputs) {
        push(stack, std::move(elem));
      }
    } else {
      at::ArrayRef<IValue> outputs = last(stack, num_ten_inputs);
      for (size_t i = 0; i < inputs.size(); ++i) {
        IValue out = outputs[i];
        at::Tensor& out_t = out.toTensor();
        fastResizeToZero(out_t);
        out_t.resize_as_(inputs[i].toTensor());
        out_t.copy_(inputs[i].toTensor());
      }
    }
    return 0;
  };
}

RegisterOperators SRCopyOuts({
    torch::jit::Operator(
        prim::StaticRuntimeCopyOuts,
        StaticRuntimeCopyOuts,
        AliasAnalysisKind::CONSERVATIVE),
});

// On each invocation of this guard, we need to check all of the static
// information (dtype/device/requires grad/contiguity/static dims),
// and also that the symbolic shape dimensions are observed.
// For any symbolic dimension we need to set its value on its first
// use and for all subsequent uses check that the values are equal.
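// For example (illustrative): for expected types Tensor(SS(-2), 3) and
// Tensor(3, SS(-2)), inputs of sizes [10, 3] and [3, 10] pass, while
// [10, 3] and [3, 7] fail because SS(-2) is first observed as 10 and then
// seen as 7.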
RegisterOperators reg_guard({
    Operator(
        "prim::TensorExprDynamicGuard(...) -> bool",
        [](const Node* node) -> Operation {
          const auto& types = node->tys(attr::types);

          // Each input's expected number of dims
          std::vector<size_t> expected_dims;

          // A flattened vector of the expected values for all tensor dims.
          // A positive value corresponds to a static shape to check and a
          // negative value corresponds to a symbolic dimension index to check.
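          // For example (illustrative): for guard types Tensor(SS(-2), 3) and
          // Tensor(3, SS(-2)) this is [-1, 3, 3, -1], since the first (and
          // only) symbolic dim is assigned the flattened index -1 below.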
          std::vector<int64_t> flattened_input_dims;

          // Each input's expected scalar type
          std::vector<c10::ScalarType> expected_scalar_types;

          // Map from symbolic dimension value to its set's index
          std::map<int64_t, size_t> sym_dim_flat_index;
          TORCH_INTERNAL_ASSERT(!types.empty());

          // We should only be fusing groups with a single device
          // and with tensors that do not require grad
          auto maybe_device = types[0]->expect<TensorType>()->device();
          TORCH_INTERNAL_ASSERT(maybe_device);
          auto device = *maybe_device;

          // Flattened vector of each input's striding behavior
          std::vector<StrideInput> flattened_input_striding;
          const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
          std::vector<std::vector<std::string>> sym_strides_strs =
              sym_strides.to<std::vector<std::vector<std::string>>>();
          for (const auto& vec : sym_strides_strs) {
            for (const std::string& str : vec) {
              flattened_input_striding.push_back(strideInputFromString(str));
            }
          }

          for (const auto& type : types) {
            auto tt = type->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            TORCH_INTERNAL_ASSERT(ss.rank());
            expected_dims.push_back(*ss.rank());
            TORCH_INTERNAL_ASSERT(tt->scalarType());
            expected_scalar_types.push_back(*tt->scalarType());
            TORCH_INTERNAL_ASSERT(tt->device() && *tt->device() == device);
            for (size_t i = 0; i < *ss.rank(); ++i) {
              auto sym_dim = ss[i];
              auto value = sym_dim.value();
              if (value >= 0) {
                flattened_input_dims.push_back(value);
              } else {
                // use index for set if it exists, otherwise extend the vector
                // of sym shapes by 1
                int64_t sym_dim_index;
                if (sym_dim_flat_index.count(value)) {
                  sym_dim_index = sym_dim_flat_index[value];
                } else {
                  auto size = sym_dim_flat_index.size();
                  sym_dim_flat_index[value] = (-1) - size;
                  sym_dim_index = sym_dim_flat_index[value];
                }
                // TODO: potential optimization - if there is a symbolic
                // dim with only one use we don't need to test anything
                flattened_input_dims.push_back(sym_dim_index);
              }
            }
          }

          const auto num_inputs = types.size();
          const auto num_symbolic_dims = sym_dim_flat_index.size();
          return [num_inputs,
                  expected_dims,
                  device,
                  expected_scalar_types,
                  flattened_input_dims,
                  flattened_input_striding,
                  num_symbolic_dims](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, num_inputs);
            drop(stack, num_inputs);
            // On each invocation we need to reset what the value of each
            // symbolic dimension is.
            // TODO: could this be a reference and not allocated on
            // each invocation, or would that interfere with multithreaded
            // inference since we are writing to it?
            // TODO: use a SmallVector here?
            bool grad_mode_enabled = at::GradMode::is_enabled();
            std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
            size_t flattened_dim_offset = 0;
            size_t flattened_stride_offset = 0;
            for (const auto i : c10::irange(num_inputs)) {
              at::Tensor tensor = inputs[i].toTensor();
              if (C10_UNLIKELY(
                      tensor.device() != device ||
                      tensor.dtype() != expected_scalar_types[i])) {
                push(stack, false);
                return;
              }
              if (C10_UNLIKELY(grad_mode_enabled && tensor.requires_grad())) {
                push(stack, false);
                return;
              }
              const auto& sizes = tensor.sizes();
              const auto num_dims = sizes.size();
              if (C10_UNLIKELY(num_dims != expected_dims[i])) {
                push(stack, false);
                return;
              }
              auto striding = flattened_input_striding[flattened_stride_offset];
              // Tensors natively store whether they are contiguous
              // in the default memory format or in channels last,
              // so it is more efficient to query for that property
              // than to iterate over the dimensions and check manually.
              if (striding == StrideInput::TENSOR_CONT) {
                if (C10_UNLIKELY(
                        !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
                // TODO: 5D channels last
                if (C10_UNLIKELY(!tensor.is_contiguous(
                        at::MemoryFormat::ChannelsLast))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else {
                auto strides = tensor.strides();
                for (size_t dim = 0; dim < num_dims; ++dim) {
                  auto summarized_dim = summarizeStrideDim(
                      sizes,
                      strides,
                      dim,
                      flattened_input_striding,
                      flattened_stride_offset);
                  if (C10_UNLIKELY(
                          summarized_dim !=
                          flattened_input_striding
                              [dim + flattened_stride_offset])) {
                    push(stack, false);
                    return;
                  }
                }
                flattened_stride_offset += num_dims;
              }
              for (const auto dim_index : c10::irange(num_dims)) {
                const int64_t dim_value =
                    flattened_input_dims[dim_index + flattened_dim_offset];
                const int64_t tensor_dim = sizes[dim_index];
                if (dim_value >= 0) {
                  if (C10_UNLIKELY(dim_value != tensor_dim)) {
                    push(stack, false);
                    return;
                  }
                } else {
                  // flattened sym indices start at -1,
                  // so -1 -> index 0, -2 -> index 1
                  const auto flattened_sym_index = (-dim_value) - 1;
                  const auto flattened_sym_value =
                      flattened_symbolic_dims[flattened_sym_index];
                  // sym symbol already seen, check value
                  if (flattened_symbolic_dims[flattened_sym_index] >= 0) {
                    if (C10_UNLIKELY(flattened_sym_value != tensor_dim)) {
                      push(stack, false);
                      return;
                    }
                  } else {
                    // not seen, write value
                    flattened_symbolic_dims[flattened_sym_index] = tensor_dim;
                  }
                }
              }
              flattened_dim_offset += num_dims;
            }

            push(stack, true);
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

void runTensorExprDynamicGroup(const Code& code, Stack& stack) {
  InterpreterState interpreter{code};
  interpreter.run(stack);
}

Operation createTensorExprDynamicGroup(const Node* node) {
  const auto& graph = node->g(attr::Subgraph);
  Code code(graph, "");
  // This implementation creates a Code object and InterpreterState on every
  // call to TensorExprDynamicGroup, which affects performance. Ideally, we
  // should be reusing Code and InterpreterState across calls to this op.
  // But that is resulting in a "No frames found" error.
  // TODO: Improve the performance of this by figuring out a better approach.
  // NB: this is only run in SR, which is single-threaded
  return [code](Stack& stack) {
    runTensorExprDynamicGroup(code, stack);
    return 0;
  };
}

RegisterOperators TensorExprDynamicOp({
    torch::jit::Operator(
        prim::TensorExprDynamicGroup,
        createTensorExprDynamicGroup,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace jit
} // namespace torch