#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <sstream>
#include <utility>

namespace torch {
namespace jit {

// Inserts the compute for each symbolic shape in the TensorExpr graph
// and returns a map from symbolic shape value to its runtime Value*.
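// For example, if input generalization introduced symbolic dims SS(-2) and
// SS(-3), the returned map has keys -2 and -3, each mapping to the Value* in
// the enclosing graph that computes that dimension at runtime.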
std::map<int64_t, Value*> InsertSymbolicShapesCompute(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* tensorexpr_graph) {
  WithInsertPoint guard(tensorexpr_graph);
  auto enclosing_graph = tensorexpr_graph->owningGraph();

  std::map<Value*, Value*> shape_graph_input_to_enclosing_graph_value;
  for (const auto& pair :
       shape_mapping.enclosing_graph_value_to_shape_graph_input_) {
    shape_graph_input_to_enclosing_graph_value[pair.second] = pair.first;
  }
  std::vector<Value*> shape_compute_graph_inputs;
  for (Value* shape_graph_input :
       shape_mapping.partial_eval_shape_graph->inputs()) {
    auto enclosing_graph_input =
        shape_graph_input_to_enclosing_graph_value.find(shape_graph_input);
    TORCH_INTERNAL_ASSERT(
        enclosing_graph_input !=
        shape_graph_input_to_enclosing_graph_value.end());
    if (*enclosing_graph_input->second->type() == *shape_graph_input->type()) {
      shape_compute_graph_inputs.push_back(tensorexpr_graph->inputs().at(
          enclosing_graph_input->second->offset()));
    } else {
      TORCH_INTERNAL_ASSERT(
          enclosing_graph_input->second->type()->cast<TensorType>() &&
          shape_graph_input->type()->isSubtypeOf(ListType::ofInts()));
      shape_compute_graph_inputs.push_back(enclosing_graph->insert(
          aten::size,
          {tensorexpr_graph->inputs().at(
              enclosing_graph_input->second->offset())}));
    }
  }
  auto sym_shape_values = insertGraph(
      *enclosing_graph,
      *shape_mapping.partial_eval_shape_graph,
      shape_compute_graph_inputs);
  std::map<int64_t, Value*> sym_shape_to_enclosing_graph_value;
  for (size_t i = 0;
       i < shape_mapping.partial_eval_shape_graph->outputs().size();
       ++i) {
    Value* output = shape_mapping.partial_eval_shape_graph->outputs().at(i);
    auto sym_shape =
        shape_mapping.graph_output_to_symbolic_shape_dim_.find(output);
    TORCH_INTERNAL_ASSERT(
        sym_shape != shape_mapping.graph_output_to_symbolic_shape_dim_.end());
    sym_shape_to_enclosing_graph_value[sym_shape->second] = sym_shape_values[i];
  }
  return sym_shape_to_enclosing_graph_value;
}

void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides);

std::string toString(StrideInput si) {
  switch (si) {
    case StrideInput::TENSOR_CONT:
      return "TENSOR_CONT";
    case StrideInput::TENSOR_CONT_CHANNELS_LAST:
      return "TENSOR_CONT_CHANNELS_LAST";
    case StrideInput::S_ONE:
      return "S_ONE";
    case StrideInput::S_CONT:
      return "S_CONT";
    case StrideInput::S_TRAN_CONT:
      return "S_TRAN_CONT";
    case StrideInput::S_AS_ARG:
      return "S_AS_ARG";
  }
  TORCH_INTERNAL_ASSERT(false);
}

StrideInput strideInputFromString(const std::string& si) {
  if (si == "TENSOR_CONT") {
    return StrideInput::TENSOR_CONT;
  } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  } else if (si == "S_ONE") {
    return StrideInput::S_ONE;
  } else if (si == "S_CONT") {
    return StrideInput::S_CONT;
  } else if (si == "S_TRAN_CONT") {
    return StrideInput::S_TRAN_CONT;
  } else if (si == "S_AS_ARG") {
    return StrideInput::S_AS_ARG;
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
}

// In the runtime guard, strides are serialized as one flat vector.
// stride_inputs_offset indexes into that vector where the strides of
// this tensor begin.
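// For example, a contiguous tensor with sizes [2, 3, 4] and strides
// [12, 4, 1] summarizes per-dim to [S_CONT, S_CONT, S_ONE].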
inline StrideInput summarizeStrideDim(
    const c10::IntArrayRef sizes,
    const c10::IntArrayRef strides,
    size_t dim,
    const std::vector<StrideInput>& stride_inputs,
    size_t stride_inputs_offset) {
  if (strides[dim] == 1) {
    return StrideInput::S_ONE;
  } else if (
      dim + 1 < sizes.size() &&
      strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
    return StrideInput::S_CONT;
    // Transposed contiguous depends on the prior dim and contiguous depends
    // on the next dim, so to avoid a mutual dependence only treat this dim as
    // transposed contiguous if the prior dim was not summarized as S_CONT.
  } else if (
      dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
      (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
    return StrideInput::S_TRAN_CONT;
  } else {
    return StrideInput::S_AS_ARG;
  }
}

std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  if (c10::is_contiguous_strides(sizes, strides)) {
    return {StrideInput::TENSOR_CONT};
    // TODO: channels last 3d
  } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
  }
  std::vector<StrideInput> stride_inputs;
  for (size_t dim = 0; dim < sizes.size(); ++dim) {
    stride_inputs.push_back(
        summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
  }
  return stride_inputs;
}

// TODO: incorporate in codegen
StrideInput summarizeOutputStrides(const TensorType& tt) {
  auto strides = *tt.strides().concrete_sizes();
  auto sizes = *tt.sizes().concrete_sizes();
  // We only try to maintain output striding for channels-last tensors;
  // otherwise we fall back to contiguous.
  // TODO: channels last 3d
  if (c10::is_channels_last_strides_2d(sizes, strides)) {
    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
  }
  return StrideInput::TENSOR_CONT;
}

// Generalize complete-shape inputs to symbolic shapes.
// Dimensions of value 1 will be preserved; otherwise dimensions with the
// same value will be bucketed to the same symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// Also summarizes input striding behavior. The size information is stored on
// the type; the striding is returned. See StrideInput for a description of
// the stride specializations.
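// Returns c10::nullopt if any Tensor input does not have complete sizes and
// strides.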
c10::optional<std::vector<std::vector<StrideInput>>>
TryGeneralizeInputDimensionsToSymbolicShapes(
    std::shared_ptr<Graph> tensorexpr_graph) {
  std::map<size_t, int64_t> shape_to_sym_shape;
  std::vector<std::vector<StrideInput>> input_striding;

  for (Value* v : tensorexpr_graph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return c10::nullopt;
    }
    input_striding.push_back(summarizeInputStrides(tt));
    std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
    auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
      auto value = shape.value();
      TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
      if (value == 1) {
        return value;
      } else if (shape_to_sym_shape.count(static_cast<size_t>(value))) {
        return shape_to_sym_shape[value];
      } else {
        auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
        shape_to_sym_shape[static_cast<size_t>(value)] = new_shape_symbol;
        return new_shape_symbol;
      }
    });
    v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
  }
  return input_striding;
}

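// Hoist constant Tensors out of the TE subgraph into the enclosing graph and
// feed them back in as subgraph inputs, so that symbolic shape analysis does
// not need to reason about broadcasts against constant tensors (see the note
// in GenerateGuard below).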
void moveConstantTensorsOutOfSubgraph(
    Node* tensorexpr_graph_node,
    std::shared_ptr<Graph> tensorexpr_graph) {
  auto parent = tensorexpr_graph_node->owningGraph();

  auto env = [&](Value* v) {
    TORCH_INTERNAL_ASSERT(
        false,
        "this should never happen since constant nodes do not have any inputs",
        v->debugName());
    return v;
  };

  WithInsertPoint wip(tensorexpr_graph_node);
  std::vector<Node*> to_destroy;
  for (auto node : tensorexpr_graph->nodes()) {
    if (node->kind() == prim::Constant) {
      if (!node->output()->type()->cast<TensorType>()) {
        continue;
      }

      // copy the constant and insert that copy into the parent graph.
      auto copy = parent->createClone(node, env);
      parent->insertNode(copy);

      // add a new input to the te subgraph and replace the uses of the
      // constant with this input.
      auto new_const = tensorexpr_graph->addInput();
      new_const->setType(node->output()->type());
      node->output()->replaceAllUsesWith(new_const);

      // add the copy as input to the te node
      tensorexpr_graph_node->addInput(copy->output());

      to_destroy.push_back(node);
    }
  }

  for (auto n : to_destroy) {
    n->destroy();
  }
}

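// Try to generalize the TE fusion group to symbolic shapes: summarize input
// and output striding, propagate symbolic shapes through the subgraph, and
// insert a prim::TensorExprDynamicGuard plus a versioning if around it.
// Returns false, without inserting a guard, if any Tensor input or output
// lacks complete shape information or if shape propagation fails.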
bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
  auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);

  // Move constant tensors from the subgraph to the outer scope.
  // This is necessary because symbolic shape analysis does not handle the
  // case of broadcast(constant, symbolic_shape) well and that results in poor
  // performance.
  moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);

  // Generalize Inputs
  auto input_striding =
      TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
  if (!input_striding) {
    return false;
  }

  // Get output striding behavior
  std::vector<StrideInput> output_striding;
  for (Value* v : tensorexpr_graph->outputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    auto tt = v->type()->expectRef<TensorType>();
    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
      return false;
    }
    output_striding.push_back(summarizeOutputStrides(tt));
  }

  // Try To Propagate Shapes
  auto maybe_shape_compute_mapping =
      PropagateShapesAndBuildLargeShapeComputeGraph(
          tensorexpr_graph,
          *tensorexpr_graph->nodes().begin(),
          *tensorexpr_graph->nodes().end());
  if (!maybe_shape_compute_mapping) {
    return false;
  }

  // Insert Guard
  insertDynamicShapesGuard(
      *maybe_shape_compute_mapping,
      tensorexpr_graph_node,
      add_composed_op,
      *input_striding,
      output_striding);
  return true;
}

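// Find the prim::FallbackGraph node produced for the guarded fusion group,
// inline it, and append a prim::StaticRuntimeCopyOuts node over the outputs of
// the fallback (else) block so Static Runtime can copy results into
// pre-allocated output tensors (see the comment on StaticRuntimeCopyOuts
// below).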
void inlineFallbackGraphAndAddSRCopyOutOp(std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::FallbackGraph) {
      break;
    }
  }
  TORCH_INTERNAL_ASSERT(n != nullptr, "Expected to find fallback graph");

  auto if_node = n->owningBlock()->owningNode();
  IfView if_v(if_node);
  SubgraphUtils::unmergeSubgraph(n);

  auto false_block = if_v.elseBlock();
  std::vector<Value*> false_block_outputs(
      if_v.elseOutputs().begin(), if_v.elseOutputs().end());
  TORCH_INTERNAL_ASSERT(!false_block_outputs.empty());

  for (auto out : false_block_outputs) {
    TORCH_INTERNAL_ASSERT(out->type()->cast<TensorType>());
  }
  auto copy_node = graph->create(
      prim::StaticRuntimeCopyOuts,
      false_block_outputs,
      false_block_outputs.size());
  false_block->appendNode(copy_node);
  for (size_t i = 0; i < false_block_outputs.size(); ++i) {
    false_block->replaceOutput(i, copy_node->outputs().at(i));
  }
}

// TODO: share more logic with tensorexpr_fuser ?
void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node,
    bool add_composed_op,
    std::vector<std::vector<StrideInput>>& input_info,
    std::vector<StrideInput>& output_strides) {
  GRAPH_DEBUG(
      "Inserting a prim::TensorExprDynamicGuard guard for a node",
      *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (const auto i : c10::irange(guarded_node->inputs().size())) {
    Value* node_input = guarded_node->inputs().at(i);
    // We only check Tensor inputs of the guarded node
    if (!node_input->type()->cast<TensorType>()) {
      continue;
    }
    inputs_to_check.push_back(node_input);
    guard_types.emplace_back(
        subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
            c10::VaryingShape<c10::Stride>()));
  }
  TORCH_INTERNAL_ASSERT(inputs_to_check.size());

  // prim::TensorExprDynamicGuard nodes look like the following:
  //   %types_match : bool = prim::TensorExprDynamicGuard[attr:types](%inp1 :
  //   Tensor, %inp2 : Tensor)
  // The input tensors are checked against the expected types on attr::types.
  // We omit refining the input Tensors for now because they are not actually
  // used within tensorexpr/kernel.cpp (only the inputs to the Graph are, not
  // the inputs to the node) and we would have to redo the mapping to compute
  // symbolic shapes.

  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(Symbol::prim("TensorExprDynamicGuard"), inputs_to_check, 1)
          ->insertBefore(guarded_node);

  typecheck_node->tys_(attr::types, std::move(guard_types));
  Value* typecheck_result = typecheck_node->output()->setType(BoolType::get());

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);

  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // types get copied to the fallback graph, so remove specializations before
  // replacing
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());

  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }

  // Insert Symbolic Shapes Compute and add as inputs to TE Node/Graph
  // symbolic_shape_inputs will be a list of each symbolic shape,
  // and the last N inputs to TE Graph/Node will be the N
  // symbolic shape values
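  // E.g. with symbolic dims SS(-2) and SS(-3), the node gains two int inputs
  // (named SS_3 and SS_2, following the ascending key order of the map below)
  // and attr::symbolic_shape_inputs is set to [-3, -2].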
  auto map = InsertSymbolicShapesCompute(shape_mapping, guarded_node);
  std::vector<int64_t> symbolic_shape_inputs;
  for (const auto& pair : map) {
    symbolic_shape_inputs.push_back(pair.first);
    guarded_node->addInput(pair.second);
    std::stringstream ss;
    ss << "SS_" << -pair.first;
    subgraph->addInput(ss.str())->setType(IntType::get());
  }
  guarded_node->is_(
      attr::symbolic_shape_inputs, std::move(symbolic_shape_inputs));

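  // Serialize the per-input striding summaries as strings so they can be
  // stored as an IValue attribute on both the guard node and the TE node
  // (they are parsed back with strideInputFromString when the guard runs).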
  std::vector<std::vector<std::string>> input_striding;
  for (auto& vec : input_info) {
    auto string_info =
        fmap(vec, [&](StrideInput inp) { return toString(inp); });
    input_striding.push_back(string_info);
  }
  auto ival = IValue(input_striding);
  guarded_node->ival_(attr::striding_inputs_desc, ival);
  typecheck_node->ival_(attr::striding_inputs_desc, std::move(ival));

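  // Erase concrete stride information from the subgraph's tensor types;
  // striding is tracked separately via the striding_inputs_desc /
  // striding_outputs_desc attributes.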
  for (Value* v : subgraph->inputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }
  for (Value* v : subgraph->outputs()) {
    if (auto t = v->type()->cast<TensorType>()) {
      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
    }
  }

  std::vector<std::string> output_striding =
      fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
  auto output_ival = IValue(output_striding);
  guarded_node->ival_(attr::striding_outputs_desc, std::move(output_ival));

  if (add_composed_op) {
    // Only in the SR flow do we check for values on the stack and
    // forward them along as tensor outputs
    // TODO: refactor and make this an explicit part of the TE Kernel api
    guarded_node->i_(attr::allow_stack_outputs, 1);

    // Create a TensorExprDynamicGroup node
    auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
        typecheck_node, prim::TensorExprDynamicGroup);
    SubgraphUtils::mergeNodeIntoSubgraph(versioning_if, te_dyn_group);
    inlineFallbackGraphAndAddSRCopyOutOp(
        SubgraphUtils::getSubgraph(te_dyn_group));
  }
}

// This operator is inserted at the end of the fallback block computing outputs
// for the fusion group. We convert
// block1():
//   %14 : Tensor = aten::mul(%0, %1)
//   %15 : Tensor = aten::mul(%0, %14)
//   -> (%15, %14)
// to
// block1():
//   %14 : Tensor = aten::mul(%0, %1)
//   %15 : Tensor = aten::mul(%0, %14)
//   %16 : Tensor, %17 : Tensor = prim::StaticRuntimeCopyOuts(%15, %14)
//   -> (%16, %17)
// Every output of the block is added as an input, and for each input there is
// a StaticRuntimeCopyOuts output. SR invokes the composed operator first with
// no tensors on the stack, in which case the Op will just return the inputs.
// Second, it invokes it with pre-allocated tensors, one for each output of the
// Fusion group, which is the same number of outputs as the fallback block. In
// this case we copy over the values of the inputs to the pre-allocated
// tensors.
// Note: this logic is meant to reflect the invocation of the TE Kernel
// and `runWithAllocatedOutputs` in tensorexpr_fuser.cpp
Operation StaticRuntimeCopyOuts(const Node* node) {
  auto num_ten_inputs = node->inputs().size();
  return [num_ten_inputs](Stack& stack) {
    std::vector<IValue> inputs = pop(stack, num_ten_inputs);
    // uncommon case - first run
    if (stack.empty()) {
      for (IValue elem : inputs) {
        push(stack, std::move(elem));
      }
    } else {
      at::ArrayRef<IValue> outputs = last(stack, num_ten_inputs);
      for (size_t i = 0; i < inputs.size(); ++i) {
        IValue out = outputs[i];
        at::Tensor& out_t = out.toTensor();
        fastResizeToZero(out_t);
        out_t.resize_as_(inputs[i].toTensor());
        out_t.copy_(inputs[i].toTensor());
      }
    }
    return 0;
  };
}

RegisterOperators SRCopyOuts({
    torch::jit::Operator(
        prim::StaticRuntimeCopyOuts,
        StaticRuntimeCopyOuts,
        AliasAnalysisKind::CONSERVATIVE),
});

// On each invocation of this guard, we need to check all of the static
// information (dtype/device/requires grad/contiguity/static dims),
// and also that the symbolic shape dimensions are observed.
// For any symbolic dimension we need to set its value on its first
// use and for all subsequent uses check that the values are equal.
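// For example, with attr::types of [Tensor(SS(-2), 3, SS(-3)), Tensor(SS(-2), 1)],
// flattened_input_dims below becomes [-1, 3, -2, -1, 1]: negative entries index
// (offset by one) into flattened_symbolic_dims, which records the observed
// value of each symbolic dimension per invocation.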
RegisterOperators reg_guard({
    Operator(
        "prim::TensorExprDynamicGuard(...) -> bool",
        [](const Node* node) -> Operation {
          const auto& types = node->tys(attr::types);

          // Each input's expected # of dims
          std::vector<size_t> expected_dims;

          // A flattened vector of all the expected values for all
          // tensor dims. A positive value corresponds to a static
          // shape to check and a negative value corresponds to a symbolic
          // dimension index to check
          std::vector<int64_t> flattened_input_dims;

          // Each input's expected scalar type
          std::vector<c10::ScalarType> expected_scalar_types;

          // Map from symbolic dimension value to its set's index
          std::map<int64_t, size_t> sym_dim_flat_index;
          TORCH_INTERNAL_ASSERT(!types.empty());

          // we should just be fusing fusion groups with a single device
          // and with tensors not requiring grad
          auto maybe_device = types[0]->expect<TensorType>()->device();
          TORCH_INTERNAL_ASSERT(maybe_device);
          auto device = *maybe_device;

          // flattened vector of each input's striding behavior
          std::vector<StrideInput> flattened_input_striding;
          const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
          std::vector<std::vector<std::string>> sym_strides_strs =
              sym_strides.to<std::vector<std::vector<std::string>>>();
          for (const auto& vec : sym_strides_strs) {
            for (const std::string& str : vec) {
              flattened_input_striding.push_back(strideInputFromString(str));
            }
          }

          for (const auto& type : types) {
            auto tt = type->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            TORCH_INTERNAL_ASSERT(ss.rank());
            expected_dims.push_back(*ss.rank());
            TORCH_INTERNAL_ASSERT(tt->scalarType());
            expected_scalar_types.push_back(*tt->scalarType());
            TORCH_INTERNAL_ASSERT(tt->device() && *tt->device() == device);
            for (size_t i = 0; i < *ss.rank(); ++i) {
              auto sym_dim = ss[i];
              auto value = sym_dim.value();
              if (value >= 0) {
                flattened_input_dims.push_back(value);
              } else {
                // use index for set if it exists, otherwise extend the vector
                // of sym shapes by 1
                int64_t sym_dim_index;
                if (sym_dim_flat_index.count(value)) {
                  sym_dim_index = sym_dim_flat_index[value];
                } else {
                  auto size = sym_dim_flat_index.size();
                  sym_dim_flat_index[value] = (-1) - size;
                  sym_dim_index = sym_dim_flat_index[value];
                }
                // TODO: potential optimization - if there is a symbolic
                // dim with only one use we don't need to test anything
                flattened_input_dims.push_back(sym_dim_index);
              }
            }
          }

          const auto num_inputs = types.size();
          const auto num_symbolic_dims = sym_dim_flat_index.size();
          return [num_inputs,
                  expected_dims,
                  device,
                  expected_scalar_types,
                  flattened_input_dims,
                  flattened_input_striding,
                  num_symbolic_dims](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, num_inputs);
            drop(stack, num_inputs);
            // On each invocation we need to reset the observed value of each
            // symbolic dimension.
            // TODO: could this be a reference and not allocated on
            // each invocation, or would that break multithreaded
            // inference since we are writing to it?
            // TODO: use a SmallVector here?
            bool grad_mode_enabled = at::GradMode::is_enabled();
            std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
            size_t flattened_dim_offset = 0;
            size_t flattened_stride_offset = 0;
            for (const auto i : c10::irange(num_inputs)) {
              at::Tensor tensor = inputs[i].toTensor();
              if (C10_UNLIKELY(
                      tensor.device() != device ||
                      tensor.dtype() != expected_scalar_types[i])) {
                push(stack, false);
                return;
              }
              if (C10_UNLIKELY(grad_mode_enabled && tensor.requires_grad())) {
                push(stack, false);
                return;
              }
              const auto& sizes = tensor.sizes();
              const auto num_dims = sizes.size();
              if (C10_UNLIKELY(num_dims != expected_dims[i])) {
                push(stack, false);
                return;
              }
              auto striding = flattened_input_striding[flattened_stride_offset];
              // Tensors natively store whether they are contiguous
              // in the default memory format or in channels last,
              // so it is more efficient to query that property than to
              // iterate over the dimensions and check ourselves
              if (striding == StrideInput::TENSOR_CONT) {
                if (C10_UNLIKELY(
                        !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
                // TODO: 5D channels last
                if (C10_UNLIKELY(!tensor.is_contiguous(
                        at::MemoryFormat::ChannelsLast))) {
                  push(stack, false);
                  return;
                }
                flattened_stride_offset += 1;
              } else {
                auto strides = tensor.strides();
                for (size_t dim = 0; dim < num_dims; ++dim) {
                  auto summarized_dim = summarizeStrideDim(
                      sizes,
                      strides,
                      dim,
                      flattened_input_striding,
                      flattened_stride_offset);
                  if (C10_UNLIKELY(
                          summarized_dim !=
                          flattened_input_striding
                              [dim + flattened_stride_offset])) {
                    push(stack, false);
                    return;
                  }
                }
                flattened_stride_offset += num_dims;
              }
              for (const auto dim_index : c10::irange(num_dims)) {
                const int64_t dim_value =
                    flattened_input_dims[dim_index + flattened_dim_offset];
                const int64_t tensor_dim = sizes[dim_index];
                if (dim_value >= 0) {
                  if (C10_UNLIKELY(dim_value != tensor_dim)) {
                    push(stack, false);
                    return;
                  }
                } else {
                  // flattened sym indices start at -1,
                  // so -1 -> index 0, -2 -> index 1
                  const auto flattened_sym_index = (-dim_value) - 1;
                  const auto flattened_sym_value =
                      flattened_symbolic_dims[flattened_sym_index];
                  // sym symbol already seen, check value
                  if (flattened_symbolic_dims[flattened_sym_index] >= 0) {
                    if (C10_UNLIKELY(flattened_sym_value != tensor_dim)) {
                      push(stack, false);
                      return;
                    }
                  } else {
                    // not seen, write value
                    flattened_symbolic_dims[flattened_sym_index] = tensor_dim;
                  }
                }
              }
              flattened_dim_offset += num_dims;
            }

            push(stack, true);
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

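// Run the Code compiled for a prim::TensorExprDynamicGroup subgraph by
// interpreting it on the given stack.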
void runTensorExprDynamicGroup(const Code& code, Stack& stack) {
  InterpreterState interpreter{code};
  interpreter.run(stack);
}

Operation createTensorExprDynamicGroup(const Node* node) {
  const auto& graph = node->g(attr::Subgraph);
  Code code(graph, "");
  // This implementation creates an InterpreterState on every call to
  // TensorExprDynamicGroup, which affects performance. Ideally, we should be
  // reusing the InterpreterState across calls to this op, but that results in
  // a "No frames found" error.
  // TODO: Improve the performance of this by figuring out a better approach.
  // NB: this is only run in SR, which is single-threaded
  return [code](Stack& stack) {
    runTensorExprDynamicGroup(code, stack);
    return 0;
  };
}

RegisterOperators TensorExprDynamicOp({
    torch::jit::Operator(
        prim::TensorExprDynamicGroup,
        createTensorExprDynamicGroup,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace jit
} // namespace torch