1#include <torch/csrc/jit/runtime/static/impl.h>
2
3#include <ATen/MemoryOverlap.h>
4#include <ATen/core/symbol.h>
5#include <ATen/record_function.h>
6#include <c10/core/CPUAllocator.h>
7#include <c10/core/InferenceMode.h>
8#include <c10/macros/Macros.h>
9#include <c10/util/MaybeOwned.h>
10#include <c10/util/irange.h>
11#include <caffe2/core/scope_guard.h>
12#include <caffe2/core/timer.h>
13#include <torch/csrc/jit/ir/alias_analysis.h>
14#include <torch/csrc/jit/jit_log.h>
15#include <torch/csrc/jit/passes/add_if_then_else.h>
16#include <torch/csrc/jit/passes/canonicalize.h>
17#include <torch/csrc/jit/passes/dead_code_elimination.h>
18#include <torch/csrc/jit/passes/eliminate_no_ops.h>
19#include <torch/csrc/jit/passes/freeze_module.h>
20#include <torch/csrc/jit/passes/remove_mutation.h>
21#include <torch/csrc/jit/passes/subgraph_rewrite.h>
22#include <torch/csrc/jit/passes/variadic_ops.h>
23#include <torch/csrc/jit/runtime/graph_iterator.h>
24#include <torch/csrc/jit/runtime/static/fusion.h>
25#include <torch/csrc/jit/runtime/static/memory_planner.h>
26#include <torch/csrc/jit/runtime/static/ops.h>
27#include <torch/csrc/jit/runtime/static/passes.h>
28#include <torch/csrc/jit/runtime/vararg_functions.h>
29#include <algorithm>
30
31#ifndef AT_PER_OPERATOR_HEADERS
32#include <ATen/NativeFunctions.h>
33#else
34#include <ATen/ops/clone_native.h>
35#endif
36
37#include <iterator>
38#include <limits>
39#include <sstream>
40#include <stdexcept>
41
42#ifdef FBCODE_CAFFE2
43#include <common/logging/logging.h>
44#include <folly/dynamic.h>
45#include <folly/json.h>
46#endif
47
48// used in test only
49C10_DEFINE_bool(
50 static_runtime_disable_debug_memory_overlap_check,
51 false,
52 "If true, disable the memory overlap check in debug mode in ProcessedNode::run()");
53
54namespace torch {
55namespace jit {
56
57namespace {
58
59bool allArgsAreTensors(const Node* node) {
60 const auto& inputs = node->inputs();
61 return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) {
62 return value->type()->kind() == TypeKind::TensorType;
63 });
64}
65
66} // namespace
67
68// A manually curated set of ops that are disallowed in static runtime.
69// These are rarely-used ops. Disallowing them typically eliminates
70// corner cases in graph optimizations, allowing for more aggressive
71// optimizations and better performance.
72bool isUnsupportedOp(const Node* node) {
73 auto kind = node->kind();
74 if (kind != aten::__is__ && kind != aten::__isnot__) {
75 return false;
76 }
77
78 // We can't support aten::__is__ (and __isnot__) with tensor arguments.
79 // Consider the following graph:
80 // def forward(x):
81 // y = x.detach()
82 // return x is y
83 // We have a graph optimization that removes the `detach` node since it is
84 // a no-op during inference. But this affects the result - we get true
85 // instead of false! There are many other graph passes affected by this
86 // issue.
87 return allArgsAreTensors(node);
88}
89
90namespace {
91
92bool canEnableStaticRuntimeImpl(const Block* block) {
93 if (block == nullptr) {
94 return false;
95 }
96
97 bool can_support = true;
98 for (auto* node : block->nodes()) {
99 for (auto* subblock : node->blocks()) {
      // The ordering prevents && from short-circuiting, which we want:
      // it's useful to see *all* the unsupported ops.
102 can_support = canEnableStaticRuntimeImpl(subblock) && can_support;
103 }
104
105 const auto kind = node->kind();
106 if (kind == prim::Constant) {
107 continue;
108 }
109 // check if can get op from Node
110 const Operator* op = node->maybeOperator();
111 if (isUnsupportedOp(node) || (!op && !nativeOpIsRegistered(kind))) {
112 can_support = false;
113 LOG(WARNING) << "Found unsupported op: " << kind.toQualString();
114 }
115 }
116 return can_support;
117}
118
119} // namespace
120
121// Graph must be frozen. canEnableStaticRuntime will return false
122// if there's any prim::CallMethod ops left in the graph.
123bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
124 return canEnableStaticRuntimeImpl(graph->block());
125}
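// A minimal usage sketch (assumes a scripted module `mod`; as noted above, the
// graph must be frozen before this check is meaningful):
//
//   torch::jit::Module module = mod.copy();
//   module.eval();
//   module = torch::jit::freeze_module(module);
//   auto graph = module.get_method("forward").graph();
//   if (torch::jit::canEnableStaticRuntime(graph)) {
//     // The graph can be handed to StaticModule / StaticRuntime.
//   }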
126
127namespace {
128
129auto sr_metadata_registerer = torch::class_<StaticRuntimeMetadata>(
130 "StaticRuntime",
131 "StaticRuntimeMetadata");
132
133} // namespace
134
135std::string dumpValueSet(
136 const FastSet<const Value*>& value_set,
137 const char* set_name) {
138 std::ostringstream oss;
139 oss << set_name << ": {";
140 for (const auto* val : value_set) {
141 oss << "%" << val->debugName() << ", ";
142 }
143 oss << "}";
144 return oss.str();
145}
146
147namespace {
148
149void OptimizeGraph(
150 std::shared_ptr<torch::jit::Graph>& graph,
151 const StaticModuleOptions& opts,
152 std::vector<IValue> sample_inputs) {
153 GRAPH_DUMP("Before optimizations: ", graph);
154 if (opts.enable_tensorexpr_fusion) {
155 if (sample_inputs.empty()) {
156 VLOG(1) << "Cannot perform TensorExpr fusion - sample_inputs is empty";
157 } else {
158 VLOG(1) << "Performing TensorExpr fusion";
159 performTensorExprFusion(graph, std::move(sample_inputs));
160 }
161 }
162 Inline(*graph);
163 ConstantPropagation(graph);
164 Canonicalize(graph);
165 ConstantPropagation(graph);
166 RemoveTensorMutation(graph);
167 ConstantPropagation(graph);
168 EliminateNoOpSlice(graph);
169 EliminateDeadCode(graph);
170 FuseInferenceOpsForSparseNN(graph);
171 UseVariadicCat(graph);
172 UseVariadicStack(graph);
173 EliminateTrivialEquallySplit(graph);
174 EliminateExtraPermuteOps(graph);
175
176 if (opts.enable_out_variant) {
177 UseVariadicOp(
178 graph,
179 fromQualString("fb::sigrid_transforms_torch_bind"),
180 fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
181 UseVariadicOp(
182 graph,
183 fromQualString("torcharrow::inference_wrapper_run_flat"),
184 fromQualString("torcharrow::variadic_inference_wrapper_run_flat"));
185 // These fused ops only have out variants - we can't do the fusion when
186 // out variants are disabled.
187 FuseSignLog1P(graph);
188 FuseClampNaNToNum(graph);
189
190#ifdef FBCODE_CAFFE2
191 if (opts.use_copy_variants && !opts.enable_tensorexpr_fusion) {
192 ReplaceWithCopy(graph);
193 } else {
194 ReplacePermuteWithCopy(graph);
195 }
196 if (opts.use_maybe_copy_variants && !opts.enable_tensorexpr_fusion) {
197 ReplaceWithMaybeCopy(graph);
198 }
199 FuseListUnpack(graph);
200 RemoveUnnecessaryOutputs(graph);
201 PrepackWeights(graph);
202#endif
203 }
204
205 ConstantPropagation(graph);
206 RemoveImmutableInputDictLookups(graph);
207 UseVariadicTupleUnpack(graph);
208 UseVariadicGroupedAccessor(graph);
209 EliminateNoOps(
210 graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
211 AddIfThenElseOp(graph);
212 UseSplitAndSqueeze(graph);
213 UseInPlaceGetRealInputsFromOptionalInputsV2(graph);
214 GRAPH_DUMP("Final graph after optimizations: ", graph);
215}
216
217bool IsSelfInGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
218 return !graph->inputs().empty() && graph->inputs().at(0)->type()->is_module();
219}
220
221// remove unused input 0 from graph
222bool removeSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
223 if (graph->inputs().at(0)->type()->is_module()) {
224 if (graph->inputs().at(0)->hasUses()) {
225 return false;
226 }
227 graph->eraseInput(0);
228 }
229 return true;
230}
231
232std::vector<Value*> valueVecFromFastSet(const FastSet<const Value*>& s) {
233 std::vector<Value*> result;
234 result.reserve(s.size());
235 for (auto* v : s) {
236 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
237 result.emplace_back(const_cast<Value*>(v));
238 }
239 return result;
240}
241
242bool mayContainAlias(const AliasDb& db, const Value* v1, const Value* v2) {
243 // AliasDb is not const-correct here, so we have to const_cast
244 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
245 return db.mayContainAlias(const_cast<Value*>(v1), const_cast<Value*>(v2));
246}
247
248bool mayContainAlias(
249 const AliasDb& db,
250 const Value* a,
251 const FastSet<const Value*>& b) {
252 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
253 return db.mayContainAlias(const_cast<Value*>(a), valueVecFromFastSet(b));
254}
255
256bool escapesScope(const AliasDb& db, const Value* a) {
257 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
258 return db.escapesScope({const_cast<Value*>(a)});
259}
260
261void PrepareGraphForStaticModule(
262 std::shared_ptr<torch::jit::Graph> graph,
263 const StaticModuleOptions& opts,
264 std::vector<IValue> sample_inputs) {
265 TORCH_CHECK(canEnableStaticRuntime(graph));
266 OptimizeGraph(graph, opts, std::move(sample_inputs));
267
268 // Static runtime moves its outputs out of the runtime
269 // by default. In some rare cases, this is not actually safe to
270 // do - for example, if the value is a constant, static runtime
271 // needs to hold onto a copy. Rather than adding special logic
272 // to handle this rare case, we use this pass to detect it and
273 // create an owned reference that can be safely moved out of the
274 // runtime.
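  // For example (a sketch), a graph that returns a constant directly:
  //   graph():
  //     %c : int = prim::Constant[value=1]()
  //     return (%c)
  // needs an owned reference for %c, since moving the constant itself out of
  // the runtime would leave later iterations without a valid value.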
275 CreateOwnedRefsForSpecialValues(*graph);
276
277 // We assume that each sub-block has at least one output. If we
278 // detect any that have 0, force the sub-block to return None.
279 ForceNonEmptyOutputs(*graph);
280}
281
282std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
283 const torch::jit::Module& m,
284 bool is_frozen,
285 const StaticModuleOptions& opts,
286 std::vector<IValue> sample_inputs) {
287 LOG(INFO) << "StaticModuleOptions: enable_out_variant "
288 << opts.enable_out_variant << ", optimize_memory "
289 << opts.optimize_memory << ", manage_output_tensors "
290 << opts.manage_output_tensors << ", use_copy_variants "
291 << opts.use_copy_variants << ", use_maybe_copy_variants "
292 << opts.use_maybe_copy_variants << ", enable_tensorexpr_fusion "
293 << opts.enable_tensorexpr_fusion;
294
295 Module module = m.copy();
296 if (!is_frozen) {
297 module.eval();
298 module = freeze_module(module);
299 }
300
301 Method method = module.get_method("forward");
302 auto graph = module.get_method("forward").graph();
303
304 if (!sample_inputs.empty() && IsSelfInGraphInput(graph)) {
305 sample_inputs.insert(sample_inputs.begin(), m._ivalue());
306 }
307 PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
308
309 return std::make_pair(graph, module);
310}
311
312std::pair<std::shared_ptr<Graph>, c10::optional<Module>> PrepareForStaticModule(
313 std::shared_ptr<torch::jit::Graph> graph,
314 const StaticModuleOptions& opts,
315 std::vector<IValue> sample_inputs) {
316 PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs));
317 return std::make_pair(graph, c10::nullopt);
318}
319
320} // namespace
321
322void ValueGroup::init(const Block& block, const AliasDb& db) {
323 external_aliases_.clear();
324 output_aliases_.clear();
  // Build `external_aliases_` by walking the nodes forward from the block's
  // inputs and collecting every value that may alias (or contain an alias of)
  // an input.
328 external_aliases_.insert(block.inputs().begin(), block.inputs().end());
329 for (const auto* node : block.nodes()) {
330 if (node->kind() == prim::Constant) {
331 for (const auto* output : node->outputs()) {
332 external_aliases_.insert(output);
333 }
334 }
335 }
336 for (const auto* node : block.nodes()) {
337 if (node->kind() == prim::Constant) {
338 // Constants are already in `external_aliases`.
339 continue;
340 }
341 for (const auto* v : node->outputs()) {
342 if (escapesScope(db, v) || mayContainAlias(db, v, external_aliases_)) {
343 external_aliases_.insert(v);
344 }
345 }
346 }
347
  // Build `output_aliases_` by walking the nodes in reverse, starting from the
  // block's outputs and following the data flow backwards.
350 output_aliases_.insert(block.outputs().begin(), block.outputs().end());
351 for (const auto* node : block.nodes().reverse()) {
352 if (node->kind() == prim::Constant) {
353 // Constants cannot create any aliases.
354 continue;
355 }
356 for (const auto* v : node->outputs()) {
357 if (mayContainAlias(db, v, output_aliases_)) {
358 output_aliases_.insert(v);
359 }
360 }
361 }
362}
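// Example (a sketch): for a block with input %x
//   %y = aten::slice(%x, ...)  // may alias the input %x
//   %z = aten::relu(%y)
//   return (%z)
// %x and %y end up in external_aliases_, while %z ends up in output_aliases_.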
363
364namespace {
365
366bool isTensorList(const Value* value) {
367 auto* type = value->type()->castRaw<ListType>();
368 if (!type) {
369 return false;
370 }
371 return type->getElementType()->kind() == c10::TypeKind::TensorType;
372}
373
374bool containTensorsOnly(at::ArrayRef<Value*> values) {
  // Returns true only if every value is a Tensor or a list of Tensors.
376 return std::all_of(values.begin(), values.end(), [](const Value* value) {
377 return value->type()->kind() == c10::TypeKind::TensorType ||
378 isTensorList(value);
379 });
380}
381
382bool isPureFunction(const Node* node) {
383 auto* schema = node->maybeSchema();
384 return schema &&
385 schema->aliasAnalysis() == c10::AliasAnalysisKind::PURE_FUNCTION;
386}
387
388} // namespace
389
390ManagedTensorRanges::ManagedTensorRanges(
391 Block& block,
392 const AliasDb& alias_db,
393 const FastSet<const Value*>& managed_tensor_values) {
394 const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
395 const FastSet<const Value*> graph_inputs(
396 block.inputs().begin(), block.inputs().end());
397
398 const auto num_nodes = nodes.size();
399 for (const auto i : c10::irange(num_nodes)) {
400 auto* node = nodes[i];
401 for (auto* input : node->inputs()) {
402 auto* lifetime = getLifetime(input);
403 if (!lifetime) {
404 continue;
405 }
406 DCHECK(lifetime->end <= i);
407 lifetime->end = i;
408 }
409 for (auto* output : node->outputs()) {
410 if (!alias_db.isMutableType(output)) {
411 continue;
412 }
413 value_lifetimes_.emplace(output, Lifetime(i, i));
414 }
415 }
416 for (auto* graph_output : block.outputs()) {
417 auto* lifetime = getLifetime(graph_output);
418 if (!lifetime) {
419 continue;
420 }
421 lifetime->end = num_nodes;
422 }
423
424 // Handle aliases. Aliases may extend a Value*'s lifetime. If a node
425 // has an input and output that may alias each other, set the input's
426 // lifetime end to max(input.lifetime_end, output.lifetime_end). Iterate
427 // backwards to handle chains of aliases.
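  // For example (a sketch):
  //   %a = aten::mul(%x, %y)
  //   %b = aten::slice(%a, ...)  // %b may alias %a
  //   %c = aten::relu(%b)
  // Even if %a has no uses after the slice, its lifetime end is extended to
  // %b's, since %a's storage cannot be reused while %b is still live.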
428 for (const auto* node : block.nodes().reverse()) {
429 if (isPureFunction(node)) {
430 // If the node is a pure function, it doesn't create any aliases,
431 // so we can safely skip it.
432 continue;
433 }
434
435 auto inputs = collectValuesWithTrackedLifetimes(node->inputs());
436 auto outputs = collectValuesWithTrackedLifetimes(node->outputs());
437 for (auto* input : inputs) {
438 auto* input_lifetime = getLifetime(input);
439 DCHECK(input_lifetime != nullptr);
440 for (auto* output : outputs) {
441 if (mayContainAlias(alias_db, input, output)) {
442 auto* output_lifetime = getLifetime(output);
443 DCHECK(output_lifetime != nullptr);
444 input_lifetime->end =
445 std::max(output_lifetime->end, input_lifetime->end);
446 }
447 }
448 }
449 }
450 for (auto* managed_tensor : managed_tensor_values) {
451 auto* lifetime = getLifetime(managed_tensor);
452 DCHECK(lifetime && lifetime->end <= num_nodes);
453 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
454 Node* freeing_node;
455 if (lifetime->end == num_nodes) {
456 freeing_node = block.return_node();
457 } else {
458 freeing_node = nodes[lifetime->end];
459 }
460 node_to_newly_free_tensors_[freeing_node].emplace_back(managed_tensor);
461 }
462}
463
464bool ManagedTensorRanges::nodeFreesManagedTensors(Node* node) const {
465 auto it = node_to_newly_free_tensors_.find(node);
466 return it != node_to_newly_free_tensors_.end() && !it->second.empty();
467}
468
469const std::vector<const Value*>& ManagedTensorRanges::
470 availableTensorValuesAfterNode(Node* node) const {
471 return node_to_newly_free_tensors_.at(node);
472}
473
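// For example, lifetimes [2, 5] and [4, 9] overlap (4 <= 5), while [2, 3] and
// [5, 7] do not.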
474bool ManagedTensorRanges::lifetimesOverlap(const Value* v1, const Value* v2)
475 const {
476 const auto* v1_lifetime = getLifetime(v1);
477 const auto* v2_lifetime = getLifetime(v2);
478 if (!v1_lifetime || !v2_lifetime) {
479 return false;
480 }
481
482 if (v1_lifetime->start < v2_lifetime->start) {
483 return v1_lifetime->end >= v2_lifetime->start;
484 }
485 return v2_lifetime->end >= v1_lifetime->start;
486}
487
488const ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
489 const Value* value) const {
490 auto it = value_lifetimes_.find(value);
491 if (it != value_lifetimes_.end()) {
492 return &it->second;
493 }
494 return nullptr;
495}
496
497ManagedTensorRanges::Lifetime* ManagedTensorRanges::getLifetime(
498 const Value* value) {
499 // const_cast is safe here, this is just a way to avoid code duplication
500 // between the const/non-const versions of getLifetime.
501
502 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
503 const auto* const_this = const_cast<const ManagedTensorRanges*>(this);
504
505 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
506 return const_cast<ManagedTensorRanges::Lifetime*>(
507 const_this->getLifetime(value));
508}
509
510std::vector<const Value*> ManagedTensorRanges::
511 collectValuesWithTrackedLifetimes(at::ArrayRef<const Value*> values) {
512 std::vector<const Value*> mutable_values;
513 mutable_values.reserve(values.size());
514 std::copy_if(
515 values.begin(),
516 values.end(),
517 std::back_inserter(mutable_values),
518 [this](const Value* value) { return getLifetime(value) != nullptr; });
519 return mutable_values;
520}
521
522StaticModule::StaticModule(
523 std::shared_ptr<torch::jit::Graph> g,
524 const StaticModuleOptions& opts,
525 std::vector<IValue> sample_inputs)
526 : StaticModule(
527 PrepareForStaticModule(g->copy(), opts, std::move(sample_inputs)),
528 opts) {}
529
530StaticModule::StaticModule(
531 const torch::jit::Module& m,
532 bool is_frozen,
533 const StaticModuleOptions& opts,
534 std::vector<IValue> sample_inputs)
535 : StaticModule(
536 PrepareForStaticModule(m, is_frozen, opts, std::move(sample_inputs)),
537 opts) {}
538
539StaticModule::StaticModule(
540 std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<Module>>
541 graph_and_module,
542 const StaticModuleOptions& opts)
543 : opts_(opts),
544 graph_(std::move(graph_and_module.first)),
545 module_(std::move(graph_and_module.second)),
546 num_inputs_(graph_->inputs().size()) {
547 sr_metadata_ = c10::make_intrusive<jit::StaticRuntimeMetadata>(opts_);
548 // recursively attach metadata to prim::fork nodes
549 attachNodeMetadata(graph_->block());
550
551 // check opt flags
552 if (opts.manage_output_tensors) {
553 TORCH_CHECK(
554 opts_.enable_out_variant,
555 "When manage_output_tensors is true, enable_out_variant must be set to true");
556 }
557 if (opts_.optimize_memory) {
558 TORCH_CHECK(
559 opts_.enable_out_variant,
560 "When optimize_memory is true, enable_out_variant must be set to true");
561 }
562
563 // handle schema
564 if (module_.has_value()) {
565 Method method = module_->get_method("forward");
566 schema_ = method.function().getSchema();
567 const auto num_schema_args = schema_->arguments().size();
568 DCHECK(num_schema_args > 0);
569 if (removeSelfFromGraphInput(graph_)) {
570 module_ = c10::nullopt;
571 num_inputs_ = num_schema_args - 1;
572 }
573 }
574
575 {
576 size_t nodes_size = 0, constants_size = 0;
577 for (Node* node : graph_->nodes()) {
578 ++(node->kind() == prim::Constant ? constants_size : nodes_size);
579 }
580
581 constants_.reserve(constants_size);
582 functions_.reserve(nodes_size);
583 }
584
  // Create the ProcessedFunction instances first so that their addresses are
  // stable by the time they are passed to ProcessedNode.
587 AliasDb alias_db(graph_, /*isFrozen=*/false);
588 GRAPH_DEBUG("AliasDb: ", alias_db.toString());
589
590 // Maps each Value* in the graph to its index in the values_ array that will
591 // eventually be created by StaticRuntime.
592 FastMap<const Value*, uint32_t> value_to_index;
593 prepareFunctionsAndConstants(graph_->block(), alias_db, value_to_index);
594
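  // Layout of the eventual values array: all constants first, followed by the
  // inputs and node outputs of each (sub-)block in traversal order.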
595 const auto constants_index_offset = 0;
596 const auto values_index_offset = constants_index_offset + constants().size();
597 value_buffer_size_ = values_index_offset;
598
599 value_buffer_size_ +=
600 prepareBlockInfo(graph_->block(), values_index_offset, value_to_index);
601
602 prepareStaticNodeInfos(graph_->block(), value_to_index, alias_db);
603
604 for (auto& block_and_info : block_infos_) {
605 auto& block_info = block_and_info.second;
606 block_info.prepare_for_memory_planner(alias_db, opts);
607 }
608}
609
610size_t StaticModule::prepareBlockInfo(
611 Block* block,
612 const size_t start_idx,
613 FastMap<const Value*, uint32_t>& value_to_index) {
614 block_infos_.emplace(block, BlockInfo(start_idx, *block));
615
616 const auto num_inputs = block->inputs().size();
617 for (const auto i : c10::irange(num_inputs)) {
618 value_to_index.emplace(block->inputs()[i], start_idx + i);
619 }
620 auto cur_idx = start_idx + num_inputs;
621
622 for (auto* node : block->nodes()) {
623 for (auto* sub_block : node->blocks()) {
624 cur_idx += prepareBlockInfo(sub_block, cur_idx, value_to_index);
625 }
626
627 if (node->kind() == prim::Constant) {
628 continue;
629 }
630
    TORCH_CHECK(
        cur_idx < (1 << 16),
        "outputs offset in values table ",
        cur_idx,
        " would overflow 2-byte index storage");
636
637 const auto num_outputs = node->outputs().size();
638 for (const auto i : c10::irange(num_outputs)) {
639 value_to_index.emplace(node->outputs()[i], cur_idx + i);
640 }
641 cur_idx += num_outputs;
642 }
643
644 std::vector<uint16_t> output_indices;
645 output_indices.reserve(block->outputs().size());
646 for (auto* output : block->outputs()) {
647 const auto output_idx = value_to_index.at(output);
    TORCH_CHECK(
        output_idx < (1 << 16),
        "outputs offset in values table ",
        output_idx,
        " would overflow 2-byte index storage");
653 output_indices.push_back(output_idx);
654 }
655
656 block_infos_.at(block).set_output_indices(std::move(output_indices));
657 return cur_idx - start_idx;
658}
659
660void StaticModule::attachNodeMetadata(Block* block) {
661 for (auto* node : block->nodes()) {
662 if (node->kind() == prim::fork) {
663 node->ival_(getStaticRuntimeMetadataSymbol(), IValue(sr_metadata_));
664 }
665 for (auto* sub_block : node->blocks()) {
666 attachNodeMetadata(sub_block);
667 }
668 }
669}
670
671void StaticModule::prepareFunctionsAndConstants(
672 Block* block,
673 const AliasDb& alias_db,
674 FastMap<const Value*, uint32_t>& value_to_index) {
675 for (auto* node : block->nodes()) {
676 for (auto* sub_block : node->blocks()) {
677 prepareFunctionsAndConstants(sub_block, alias_db, value_to_index);
678 }
679
680 if (node->kind() == prim::Constant) {
681 auto* v = node->output();
682 TORCH_CHECK(v->type()->kind() != FunctionType::Kind);
683 value_to_index.emplace(v, constants_.size());
684 constants_.emplace_back(toIValue(v).value());
685 continue;
686 }
687
688 // see [Check and correct bad schema alias info at runtime]
689 bool check_outputs_for_overlap =
690 !alias_db.mayContainAlias(node->inputs(), node->outputs()) &&
691 containTensorsOnly(node->outputs());
692 // new ProcessedFunction
693 functions_.emplace_back(
694 node, opts_.enable_out_variant, check_outputs_for_overlap);
695 }
696}
697
698size_t StaticModule::prepareStaticNodeInfos(
699 Block* block,
700 const FastMap<const Value*, uint32_t>& value_to_index,
701 const AliasDb& alias_db,
702 size_t node_idx) {
703 const auto node_start = node_idx;
704
705 auto& block_info = block_infos_.at(block);
706 std::vector<StaticNodeInfo> nodes;
707 FastMap<Node*, bool> node_has_out_variant;
708
709 for (auto* node : block->nodes()) {
710 if (node->kind() == prim::Constant) {
711 continue;
712 }
713
714 for (auto* sub_block : node->blocks()) {
715 node_idx +=
716 prepareStaticNodeInfos(sub_block, value_to_index, alias_db, node_idx);
717 }
718 ProcessedNodeInputs input_indices(node->inputs().size());
719 for (const auto input_idx : c10::irange(node->inputs().size())) {
720 auto* input = node->inputs()[input_idx];
721 auto input_ivalue_idx = value_to_index.at(input);
722 TORCH_CHECK(
723 input_ivalue_idx < (1 << 16),
724 "input index in values table ",
725 input_ivalue_idx,
726 " would overflow 2-byte index storage");
727 input_indices[input_idx] = input_ivalue_idx;
728 }
729
730 ProcessedFunction* fn = &functions_[node_idx];
731
732 // create a new ProcessedNode
733 const auto node_output_idx = node->outputs().empty()
734 // The index is unused if there are no outputs, so just create a
735 // placeholder value.
736 ? std::numeric_limits<uint16_t>::max()
737 : value_to_index.at(node->output(0));
738 nodes.emplace_back(node, fn, std::move(input_indices), node_output_idx);
739
740 node_has_out_variant.emplace(node, nodes.back().has_out_variant());
741 ++node_idx;
742 }
743
744 block_info.set_nodes(std::move(nodes), node_has_out_variant);
745 block_info.init_value_group(alias_db);
746
747 return node_idx - node_start;
748}
749
750void BlockInfo::set_nodes(
751 std::vector<StaticNodeInfo> nodes,
752 const FastMap<Node*, bool>& node_has_out_variant) {
753 nodes_ = std::move(nodes);
754
755 for (auto& node : nodes_) {
756 if (node.num_outputs() == 1 &&
757 isOptimizableContainerType(node.node(), node_has_out_variant)) {
758 node_is_optimizable_container_type_.emplace(node.node());
759 }
760 }
761}
762void BlockInfo::prepare_for_memory_planner(
763 const AliasDb& alias_db,
764 const StaticModuleOptions& opts) {
765 if (!opts.enable_out_variant) {
766 return;
767 }
768
769 // Never manage graph outputs so that we can do std::move(output_ivalue).
770 // This does not affect performance if the graph returns a collection object.
771 FastSet<const Value*> graph_output_values(
772 block_.outputs().begin(), block_.outputs().end());
773
774 // collect register indices of outputs of ops with out variant
775 for (StaticNodeInfo& pnode : nodes_) {
776 if (!pnode.has_out_variant()) {
777 continue;
778 }
779 auto outputs = pnode.node()->outputs();
780 for (const auto i : c10::irange(outputs.size())) {
781 const Value* out_v = outputs[i];
782 // Types are stored in the underlying TorchScript IR
783 bool is_tensor_type = out_v->type()->castRaw<TensorType>();
784 if (opts.manage_output_tensors && is_tensor_type &&
785 graph_output_values.find(out_v) == graph_output_values.end() &&
786 value_group_.isOutputAlias(out_v)) {
787 managed_output_tensor_values_.insert(out_v);
788 continue;
789 }
790 if (value_group_.isAlwaysAlive(out_v)) {
791 continue;
792 }
793 if (is_tensor_type) {
794 managed_tensor_values_.insert(out_v);
795 } else if (node_is_optimizable_container_type(pnode.node())) {
796 // We "leak" certain container types because their allocations
797 // take a long time
798 leaked_values_.insert(out_v);
799 }
800 }
801 }
802
803 for (const Value* output : block_.outputs()) {
804 managed_tensor_values_.erase(output);
805 }
806 GRAPH_DEBUG("managed_tensor_values: ", dumpValueSet(managed_tensor_values_));
807 GRAPH_DEBUG(
808 "managed_output_tensor_values_: ",
809 dumpValueSet(managed_output_tensor_values_));
810
811 managed_tensor_ranges_ =
812 ManagedTensorRanges(block_, alias_db, managed_tensor_values_);
813}
814
815const StaticModuleOptions& StaticModule::opts() const {
816 return opts_;
817}
818
819size_t StaticModule::num_outputs() const {
820 return graph_->outputs().size();
821}
822
823size_t StaticModule::num_inputs() const {
824 return num_inputs_;
825}
826
827StaticRuntime& StaticModule::runtime() {
828 if (!cached_runtime_) {
829 cached_runtime_ = std::make_unique<StaticRuntime>(*this);
830 }
831 return *cached_runtime_;
832}
833
834Node* StaticModule::findNodeWithKindForTesting(const std::string& kind) const {
835 for (auto& block_and_info : block_infos_) {
836 auto& block_info = block_and_info.second;
837 for (auto& pnode : block_info.nodes()) {
838 if (pnode.node()->kind().toQualString() == kind) {
839 return pnode.node();
840 }
841 }
842 }
843 return nullptr;
844}
845
846c10::IValue StaticModule::operator()(
847 const std::vector<c10::IValue>& args,
848 const KeywordArgs& kwargs) {
849 return runtime()(args, kwargs);
850}
851
852c10::IValue StaticModule::operator()(
853 std::vector<c10::IValue>&& args,
854 const KeywordArgs& kwargs) {
855 return runtime()(std::move(args), kwargs);
856}
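// Typical client usage (a minimal sketch; `mod` and the example input are
// assumptions, and the options shown are not required):
//
//   torch::jit::StaticModuleOptions opts;
//   opts.enable_out_variant = true;
//   opts.optimize_memory = true;
//   torch::jit::StaticModule smod(mod, /*is_frozen=*/false, opts,
//                                 /*sample_inputs=*/{});
//   std::vector<c10::IValue> args{torch::randn({4, 8})};
//   c10::IValue out = smod(args, /*kwargs=*/{});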
857
858BlockRunner::BlockRunner(
859 const StaticModule& sm,
860 IValue* values,
861 Block* block,
862 torch::jit::TaskLauncher* launcher,
863 bool is_root_block)
864 : static_module_(sm),
865 block_info_(static_module_.block_info(block)),
866 is_root_block_(is_root_block),
867 first_input_is_self_(
868 is_root_block_ && static_module_.first_input_is_self()),
869 inputs_begin_(block_info_.block_inputs_idx()),
870 // TODO(T108633124): Turn on manage output tensors for sub-blocks.
871 manage_output_tensors_enabled_(
872 is_root_block_ && sm.opts().manage_output_tensors),
873 values_(values) {
874 nodes_.reserve(block_info_.nodes().size());
875 for (auto& pre_pnode : block_info_.nodes()) {
876 nodes_.emplace_back(pre_pnode, values_);
877 }
878
879 for (auto index : block_info_.block_output_indices()) {
880 outputs_.emplace_back(&values_[index]);
881 }
882
883 for (auto& pnode : nodes_) {
884 auto* node = pnode.node();
885
886 // attach the async taskLauncher to processedNodes
887 pnode.set_metadata(launcher);
888 auto blocks = node->blocks();
889 const auto num_blocks = blocks.size();
890 if (num_blocks == 0) {
891 continue;
892 }
893 DCHECK(node->kind() == prim::If || node->kind() == prim::Loop);
894 std::vector<BlockRunner> block_runners;
895 block_runners.reserve(num_blocks);
896
897 for (auto* b : blocks) {
898 block_runners.emplace_back(sm, values_, b, launcher);
899 }
900 pnode.set_metadata(std::move(block_runners));
901 }
902}
903
904// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
905BlockRunner::BlockRunner(BlockRunner&&) noexcept = default;
906
907BlockRunner::~BlockRunner() = default;
908
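// When the graph still takes `self` as input 0, user-supplied arguments are
// shifted by one slot; `first_input_is_self_` is 0 or 1 accordingly.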
909void BlockRunner::set_arg(const size_t idx, std::vector<IValue>&& args) {
910 DCHECK(idx < args.size());
911 Input(idx + first_input_is_self_) = std::move(args[idx]);
912}
913
914void BlockRunner::set_arg(const size_t idx, const std::vector<IValue>& args) {
915 DCHECK(idx < args.size());
916 Input(idx + first_input_is_self_) = args[idx];
917}
918
919void BlockRunner::set_arg(const size_t idx, const IValue& arg) {
920 Input(idx + first_input_is_self_) = arg;
921}
922
923namespace {
924void check_type(const Argument& schema_arg, const IValue& arg) {
925 // Fast path for most common case
926 if (arg.isTensor() &&
927 schema_arg.type()->kind() == c10::TypeKind::TensorType) {
928 return;
929 }
930 TORCH_CHECK(arg.type()->isSubtypeOf(schema_arg.type()));
931}
932} // namespace
933
934template <typename IValueList>
935void BlockRunner::set_inputs(
936 IValueList&& args,
937 const std::unordered_map<std::string, c10::IValue>& kwargs) {
938 const auto& schema = static_module_.schema();
939 if (first_input_is_self_) {
940 Input(0) = static_module_.module()._ivalue();
941 }
942
943 if (!is_root_block_ || C10_UNLIKELY(!schema)) {
944 TORCH_CHECK(
945 kwargs.empty(), "Schema is not available, but BlockRunner got kwargs.");
946
947 const auto total_num_inputs = args.size() + first_input_is_self_;
948 TORCH_CHECK(total_num_inputs == block_info_.num_inputs());
949
950 for (size_t i = 0; i < args.size(); ++i) {
951 set_arg(i, std::forward<IValueList>(args));
952 }
953 return;
954 }
955
956 const auto& schema_args = schema->arguments();
957 size_t consumed_kwargs = 0;
958 DCHECK(!schema_args.empty());
959 TORCH_CHECK(
960 args.size() < schema_args.size(),
961 "Static runtime got too many arguments");
962 for (size_t i = 0; i < schema_args.size() - 1; ++i) {
963 // Start at 1 since the schema always contains `self`.
964 const auto& schema_arg = schema_args[i + 1];
965
966 if (i < args.size()) {
967 check_type(schema_arg, args[i]);
968 set_arg(i, std::forward<IValueList>(args));
969 continue;
970 }
971
972 auto it = kwargs.find(schema_arg.name());
973 if (it != kwargs.end()) {
974 check_type(schema_arg, it->second);
975 set_arg(i, it->second);
976 ++consumed_kwargs;
977 continue;
978 }
979
980 auto maybe_default_val = schema_arg.default_value();
981 if (maybe_default_val) {
982 set_arg(i, *maybe_default_val);
983 continue;
984 }
985
986 TORCH_CHECK(
987 false, "Static runtime is missing required kwarg ", schema_arg.name());
988 }
989 TORCH_CHECK(consumed_kwargs == kwargs.size());
990}
991
992void BlockRunner::create_memory_planner() {
993 if (!planner_) {
994 planner_ = std::make_unique<StandardMemoryPlanner>(
995 this,
996 block_info_,
997 static_module_.opts().enable_out_variant,
998 manage_output_tensors_enabled_,
999 static_module_.opts().optimize_memory);
1000 }
1001}
1002
1003namespace {
1004
1005void destroyNodeOutputs(ProcessedNode& p_node) {
1006 const auto borrows_outputs = borrowsOutputs(p_node.node()->kind());
1007 for (const auto i : c10::irange(p_node.num_outputs())) {
1008 auto& output = p_node.Output(i);
1009 if (doesNotHeapAllocateWhenStoredInIValue(*output.type())) {
1010 continue;
1011 }
1012
1013 if (borrows_outputs) {
1014 // NB: No need to incref here. This codepath is only hit if the run didn't
1015 // finish, so we shouldn't be returning anything to the client.
1016 c10::MaybeOwnedTraits<IValue>::destroyBorrow(output);
1017 } else {
1018 output = IValue();
1019 }
1020 }
1021}
1022
1023} // namespace
1024
1025void BlockRunner::clean_up_intermediate_ivalues() noexcept {
1026 // We have to iterate in reverse order here due to borrowed
1027 // IValues - we don't want to destroy a value until all of its
1028 // borrows are cleaned up!
1029 for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
1030 destroyNodeOutputs(*it);
1031 }
1032}
1033
1034void BlockRunner::resetMemory() noexcept {
1035 planner_.reset();
1036 // We must clean up intermediate values before inputs in case
1037 // there are borrowed inputs and static runtime owns the only
1038 // reference (e.g. the inputs were std::move'd into the runtime)
1039 clean_up_intermediate_ivalues();
1040 clean_up_input_ivalues();
1041}
1042
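// Small output counts are special-cased so that c10::ivalue::Tuple::create
// does not have to build a std::vector of IValues for the common cases.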
1043c10::IValue BlockRunner::move_outputs_to_tuple(uint32_t num_outputs) {
1044 switch (num_outputs) {
1045 case 1:
1046 return c10::ivalue::Tuple::create(IValue(std::move(*outputs_[0])));
1047 case 2:
1048 return c10::ivalue::Tuple::create(
1049 IValue(std::move(*outputs_[0])), IValue(std::move(*outputs_[1])));
1050 case 3:
1051 return c10::ivalue::Tuple::create(
1052 IValue(std::move(*outputs_[0])),
1053 IValue(std::move(*outputs_[1])),
1054 IValue(std::move(*outputs_[2])));
1055 default: {
1056 std::vector<c10::IValue> outputs;
1057 outputs.reserve(num_outputs);
1058 for (const auto i : c10::irange(num_outputs)) {
1059 // use move here. Otherwise, clean up outputs_[i] explicitly
1060 outputs.emplace_back(std::move(*outputs_[i]));
1061 }
1062 return c10::ivalue::Tuple::create(std::move(outputs));
1063 }
1064 }
1065}
1066
1067/// [Check and correct bad schema alias info at runtime]
1068/// Static runtime relies on the operator schema's alias info to be correct for
1069/// memory planning. Because it's hard to enforce the alias info to be correct,
1070/// we need to do runtime detection for accidental aliases that do not comply
1071/// with the schema. Only aliases of managed tensors are problematic. To avoid
1072/// runtime crashes, we can add runtime detection and force the op to comply
1073/// with its schema by cloning the alias. Because all managed tensors' data_ptrs
1074/// are part of the internal buffer that the MemoryPlanner allocates, we can
1075/// check aliases by checking the memory overlap with this internal buffer. But
/// a tensor's storage can be resized during inference, so we need another way
/// to handle the resized case.
1078///
1079/// There are two ways for incorrect schema to break memory planning. Let's look
1080/// at two examples:
1081///
1082/// Example 1:
1083/// @code
1084/// def forward(x):
1085/// a = x + x
1086/// b = bad_op(a) # b ends up aliasing a incorrectly
1087/// return (b)
1088/// @endcode
1089/// bad_op: its schema says it returns a new Tensor, but it actually returns an
1090/// alias. In this case, the memory planner would recognize `a` as a managed
1091/// tensor and clean up its memory before returning `b`. But `b` is actually an
1092/// alias of `a`, when `a`'s data_ptr get reset, `b`'s data_ptr gets reset too.
1093///
1094/// Example 2:
1095/// @code
1096/// def forward(x):
1097/// a = x + x
/// a2 = bad_op(a) # a2 ends up aliasing a incorrectly
1099/// b = a + a
1100/// c = b * b # c shares storage with a
1101/// d = c + 2 # d shares storage with b
1102/// e = a2 * a2
1103/// return (d, e)
1104/// @endcode
1105/// With the memory reuse algorithm, `c` could end up sharing storage with `a`,
1106/// but because of bad_op, `a2` now aliases `a`. `c` overwrites `a` and
1107/// therefore `a2`, leading to the wrong results. We solve this problem with two
1108/// steps. Note this doesn't happen with the current memory reuse algorithm
1109/// because of the way it's implemented. Things could change with a different
1110/// implementation.
1111///
1112/// Step 1, annotate the ProcessedNodes with a flag `check_memory_overlap_` set
1113/// to true if its outputs do not alias its inputs as indicated by the AliasDb
1114/// and all of its outputs are Tensors. Then at runtime, we check that the
1115/// nodes' output tensors do not overlap with the internal buffer that the
1116/// MemoryPlanner allocates. For latency concerns, we only run this check for
1117/// fallback ops. The schemas of native ops and out variants are vetted and
1118/// enforced with static runtime unit tests. For the first iteration, we do a
1119/// full memory overlap check with
1120/// ProcessedNode::verify_and_correct_memory_overlap() because the internal
1121/// buffer doesn't exist yet.
1122///
1123/// Step 2, if a managed tensor gets resized during inference, it gets a new
1124/// data_ptr which is not from the buffer. We can tackle this corner case by
1125/// delaying the deallocation of the managed tensors to after the outputs are no
1126/// longer used (essentially merging the internal/output buffers into one).
1127/// Before the merging is implemented, we add another flag `overlap_detected_`
1128/// to flag any node with overlap detected in Step 1 and do a full memory
1129/// overlap check if the fast check (by checking memory overlap with internal
/// buffer) fails. There is still a corner case that fails with the added flag:
/// if a resize is triggered at the same time as the op creates an alias, the
/// current checks would fail to detect the alias.
1133void BlockRunner::verify_and_correct_memory_overlap(ProcessedNode& n) {
1134 // The slow check can be removed once the internal/output buffers are merged
1135 if (C10_UNLIKELY(n.check_outputs_for_memory_overlap())) {
1136 if (C10_UNLIKELY(!planner_)) {
1137 // slow check, for first iter only
1138 n.verify_and_correct_memory_overlap();
1139 } else {
1140 bool overlap_detected_with_fast_check = false;
1141 for (size_t i = 0; i < n.outputs().size(); i++) {
1142 auto& output = n.Output(i);
1143 if (output.isTensor()) {
1144 overlap_detected_with_fast_check |=
1145 fast_check_and_correct_overlap_with(n, output);
1146 } else if (output.isTensorList()) {
1147 auto tensor_list = output.toListRef();
1148 for (auto& ival : tensor_list) {
1149 overlap_detected_with_fast_check |=
1150 fast_check_and_correct_overlap_with(
1151 n,
1152 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1153 const_cast<c10::IValue&>(ival));
1154 }
1155 }
1156 }
1157 if (n.outputs_memory_overlap_detected() &&
1158 !overlap_detected_with_fast_check) {
1159 // slow check. Only run when the fast check fails.
1160 n.verify_and_correct_memory_overlap();
1161 }
1162 }
1163 }
1164}
1165
1166bool BlockRunner::fast_check_and_correct_overlap_with(
1167 ProcessedNode& n,
1168 c10::IValue& tensor_ival) {
1169 auto& tensor = tensor_ival.toTensor();
1170 if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) {
1171 DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node());
1172 tensor_ival = at::native::clone(tensor, c10::nullopt);
1173 n.set_outputs_memory_overlap_detected();
1174 return true;
1175 }
1176 return false;
1177}
1178
1179BlockRunner::Deallocator::~Deallocator() {
1180 // Assume cleanup cannot throw.
1181 cleanupImpl();
1182#ifndef NDEBUG
1183 block_runner_.check_for_memory_leak(/*output_returned*/ false);
1184#endif
1185}
1186
1187void BlockRunner::Deallocator::cleanupImpl() {
1188 // MemoryPlanner is created after the first invocation of `run()`. This
1189 // is done intentionally because MemoryPlanner uses `Tensor` sizes of
1190 // the previous `run()` for memory planning of subsequent runs
1191 if (C10_LIKELY(finished_)) {
1192 block_runner_.create_memory_planner();
1193 }
1194
1195 if (C10_LIKELY(block_runner_.planner_)) {
1196 block_runner_.planner_->deallocate();
1197 } else {
1198 // This is the first run, and it didn't finish, so we can't use a
    // `MemoryPlanner` to deallocate stuff. Just reset everything manually.
1200 block_runner_.resetMemory();
1201 }
1202 // clean up owning refs of input tensors
1203 block_runner_.clean_up_input_ivalues();
1204 if (C10_UNLIKELY(!finished_)) {
1205 block_runner_.deallocateOutputTensors();
1206 }
1207}
1208
1209template <typename IValueList>
1210c10::IValue BlockRunner::run_impl(
1211 IValueList&& args,
1212 const KeywordArgs& kwargs) {
1213 // We assume inference workloads, so we do not need
1214 // autograd. Enabling this is a significant win on dispatcher
1215 // overhead because it saves a round of dispatch for at least some
1216 // functions, such as resize_ and resize_as_.
1217 c10::InferenceMode mode;
1218
1219 {
1220 auto on_exit = Deallocator(*this);
1221
1222 if (planner_) {
1223 DCHECK(!manage_output_tensors_enabled_ || checkOutputTensorMemoryLeaks());
1224 planner_->allocate();
1225 }
1226
1227 set_inputs(std::forward<IValueList>(args), kwargs);
1228
1229 for (auto& n : nodes_) {
1230 // LOG(INFO) << "Running node: " << PrintNode(n.node());
1231 n.run();
1232 // Check for incorrect schema alias info.
1233 verify_and_correct_memory_overlap(n);
1234 }
1235 on_exit.setFinished();
1236 }
1237
1238 // no need to keep references of outputs in static runtime anymore
1239 if (block_info_.num_outputs() > 1) {
1240 return move_outputs_to_tuple(block_info_.num_outputs());
1241 }
1242
1243 DCHECK(check_for_memory_leak(/*output_returned*/ false));
1244
1245 // use move here. Otherwise, clean up outputs_[0] explicitly
1246 return std::move(*outputs_[0]);
1247}
1248
1249template <typename IValueList>
1250c10::IValue BlockRunner::run_impl_record_functions(
1251 IValueList&& args,
1252 const KeywordArgs& kwargs) {
1253 auto step_callbacks =
1254 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1255 if (C10_UNLIKELY(step_callbacks.has_value())) {
1256 at::RecordFunction guard(std::move(*step_callbacks));
1257 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1258 guard.needsInputs()
1259 ? guard.before(
1260 "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1261 : guard.before("forward");
1262
1263 return run_impl(std::forward<IValueList>(args), kwargs);
1264 }
1265 return run_impl(std::forward<IValueList>(args), kwargs);
1266}
1267
1268template <typename IValueList>
1269c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::run_impl_async(
1270 IValueList&& args,
1271 const KeywordArgs& kwargs) {
1272 // run the graph inline in the caller thread. Async ops will be
1273 // executed on taskLauncher attached to the metadata of ProcessedNodes
1274 c10::IValue output = run_impl(args, kwargs);
1275
1276 // If the output is of type future, return it
1277 if (output.isFuture()) {
1278 return output.toFuture();
1279 }
1280
1281 // wrap the output into future, mark completed and return it
1282 TypePtr return_type;
1283 if (block_info_.num_outputs() > 1) {
1284 return_type = TupleType::create(
1285 fmap(outputs(), [](const IValue* v) { return v->type(); }));
1286 } else {
1287 return_type = outputs().at(0)->type();
1288 }
1289 c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type);
1290 future->markCompleted(output);
1291 return future;
1292}
1293
1294template <typename IValueList>
1295c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::
1296 run_impl_record_functions_async(
1297 IValueList&& args,
1298 const KeywordArgs& kwargs) {
1299 auto step_callbacks =
1300 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1301 if (C10_UNLIKELY(step_callbacks.has_value())) {
1302 at::RecordFunction guard(std::move(*step_callbacks));
1303 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1304 guard.needsInputs()
1305 ? guard.before(
1306 "forward", c10::ArrayRef<const IValue>(args.data(), args.size()))
1307 : guard.before("forward");
1308
1309 return run_impl_async(std::forward<IValueList>(args), kwargs);
1310 }
1311 return run_impl_async(std::forward<IValueList>(args), kwargs);
1312}
1313
1314c10::IValue BlockRunner::operator()(
1315 const std::vector<c10::IValue>& args,
1316 const KeywordArgs& kwargs) {
1317#ifdef PYTORCH_DISABLE_NET_PROFILING
1318 return run_impl(args, kwargs);
1319#else
1320 return run_impl_record_functions(args, kwargs);
1321#endif
1322}
1323
1324c10::IValue BlockRunner::operator()(
1325 std::vector<c10::IValue>&& args,
1326 const KeywordArgs& kwargs) {
1327#ifdef PYTORCH_DISABLE_NET_PROFILING
1328 return run_impl(std::move(args), kwargs);
1329#else
1330 return run_impl_record_functions(std::move(args), kwargs);
1331#endif
1332}
1333
1334c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1335 const std::vector<c10::IValue>& args,
1336 const KeywordArgs& kwargs) {
1337#ifdef PYTORCH_DISABLE_NET_PROFILING
1338 return run_impl_async(args, kwargs);
1339#else
1340 return run_impl_record_functions_async(args, kwargs);
1341#endif
1342}
1343
1344c10::intrusive_ptr<c10::ivalue::Future> BlockRunner::runAsync(
1345 std::vector<c10::IValue>&& args,
1346 const KeywordArgs& kwargs) {
1347#ifdef PYTORCH_DISABLE_NET_PROFILING
1348 return run_impl_async(std::move(args), kwargs);
1349#else
1350 return run_impl_record_functions_async(std::move(args), kwargs);
1351#endif
1352}
1353
1354namespace {
1355
1356std::string generate_latency_json(const std::string& label, double millis) {
1357#ifdef FBCODE_CAFFE2
1358 folly::dynamic json = folly::dynamic::object();
1359 json["type"] = label;
1360 json["metric"] = "latency";
1361 json["unit"] = "ms";
1362 json["value"] = millis;
1363 return "PyTorchObserver " + folly::toJson(json);
1364#else
1365 return "";
1366#endif
1367}
1368
1369} // namespace
1370
1371void BlockRunner::benchmark(
1372 const std::vector<std::vector<c10::IValue>>& args_list,
1373 const std::vector<KeywordArgs>& kwargs_list,
1374 const int warmup_runs,
1375 const int main_runs,
1376 bool print_per_node_time,
1377 bool generate_ai_pep_output) {
1378 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1379 std::cout << "Input size: " << args_list.size() << std::endl;
1380 float time_per_iter =
1381 benchmark_model(args_list, kwargs_list, warmup_runs, main_runs);
1382 std::cout << "Static runtime ms per iter: " << time_per_iter
1383 << ". Iters per second: " << 1000.0 / time_per_iter << std::endl;
1384
1385 IndividualMetrics results =
1386 benchmark_individual_ops(args_list, kwargs_list, warmup_runs, main_runs);
1387
1388 if (print_per_node_time) {
1389 for (const auto i : c10::irange(nodes_.size())) {
1390 const Node* node = nodes_[i].node();
1391 std::cout << "Node #" << i << ": " << results.time_per_node[i]
1392 << " ms/iter, ";
1393 node->print(std::cout, 0, nullptr, false);
1394 }
1395 }
1396
1397 std::vector<std::pair<std::string, double>> time_per_node_type_vec{
1398 results.time_per_node_type.begin(), results.time_per_node_type.end()};
1399 if (args_list.empty()) {
1400 std::sort(
1401 time_per_node_type_vec.begin(),
1402 time_per_node_type_vec.end(),
1403 [&results](auto& left, auto& right) {
1404 return results.instances_per_node_type[left.first] >
1405 results.instances_per_node_type[right.first];
1406 });
1407 } else {
1408 std::sort(
1409 time_per_node_type_vec.begin(),
1410 time_per_node_type_vec.end(),
1411 [](auto& left, auto& right) { return left.second > right.second; });
1412 }
1413 std::cout << "Time per node type:" << std::endl;
1414 for (const auto& p : time_per_node_type_vec) {
1415 const std::string& kind = p.first;
1416 const double ms = p.second;
1417 std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
1418 << results.percent_per_node_type[kind] << "%. " << kind << " ("
1419 << results.instances_per_node_type[kind] << " nodes";
1420 if (results.out_nodes.count(kind)) {
1421 std::cout << ", out variant)" << std::endl;
1422 } else if (results.native_nodes.count(kind)) {
1423 std::cout << ", native)" << std::endl;
1424 } else {
1425 std::cout << ")" << std::endl;
1426 }
1427
1428 if (generate_ai_pep_output) {
1429 LOG(INFO) << generate_latency_json(kind, ms);
1430 }
1431 }
1432 if (generate_ai_pep_output) {
1433 LOG(INFO) << generate_latency_json(
1434 "static_runtime_first_iter", results.first_iter_time);
1435 }
1436 std::cout << std::setw(15) << results.total_time << " ms. in Total"
1437 << std::endl;
1438 std::cout << "BlockRunner setup time: " << results.setup_time << " ms"
1439 << std::endl;
1440 std::cout << "Memory allocation time: " << results.memory_alloc_time
1441 << " ms\n";
1442 std::cout << "Memory deallocation time: " << results.memory_dealloc_time
1443 << " ms" << std::endl;
1444 std::cout << "Outputs deallocation time: " << results.output_dealloc_time
1445 << " ms" << std::endl;
1446 std::cout << "First iter time: " << results.first_iter_time << " ms"
1447 << std::endl;
1448 std::cout << "Number of operators: " << nodes_.size() << std::endl;
1449
1450 if (planner_) {
1451 std::cout << "Total number of managed tensors: "
1452 << planner_->total_num_managed_tensors() << std::endl;
1453 std::cout << "Total number of managed output tensors: "
1454 << planner_->total_num_managed_output_tensors() << std::endl;
1455 std::cout << "Total number of unmanaged values: "
1456 << planner_->total_num_unmanaged() << std::endl;
1457 std::cout << "Number of unmanaged values requiring cleanup: "
1458 << planner_->num_unmanaged_non_scalars() << std::endl;
1459 std::cout << "Number of unmanaged values not requiring cleanup: "
1460 << planner_->num_unmanaged_scalars() << std::endl;
1461 std::cout << "Total memory managed: " << planner_->total_managed()
1462 << " bytes" << std::endl;
1463 if (static_module_.opts().optimize_memory) {
1464 std::cout << "Total number of reused tensors: "
1465 << planner_->total_reused_tensors() << std::endl;
1466 }
1467 }
1468
1469 auto unsupported_nodes_count = results.total_nodes_count -
1470 results.out_nodes_count - results.native_nodes.size();
1471 std::cout << "Total number of 'out' variant nodes/total number of nodes: "
1472 << results.out_nodes_count << "/" << results.total_nodes_count
1473 << " ("
1474 << 100.0 * (results.out_nodes_count) /
1475 static_cast<float>(results.total_nodes_count)
1476 << "%)" << std::endl;
1477 std::cout << "Total number of nodes not covered by SR/total number of nodes: "
1478 << unsupported_nodes_count << "/" << results.total_nodes_count
1479 << " ("
1480 << 100.0 * (unsupported_nodes_count) /
1481 static_cast<float>(results.total_nodes_count)
1482 << "%)" << std::endl;
1483
1484 check_for_memory_leak();
1485
1486#ifndef NDEBUG
1487 KeywordArgs empty_kwargs;
1488 display_nodes(
1489 args_list[0], kwargs_list.size() > 0 ? kwargs_list[0] : empty_kwargs);
1490#endif
1491}
1492
1493float BlockRunner::benchmark_model(
1494 const std::vector<std::vector<c10::IValue>>& args_list,
1495 const std::vector<KeywordArgs>& kwargs_list,
1496 const int warmup_runs,
1497 const int main_runs) {
1498 TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);
1499 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1500
1501 const bool is_kwargs_empty = kwargs_list.empty();
1502 const KeywordArgs empty_kwargs;
1503 for (const auto i : c10::irange(warmup_runs)) {
1504 (void)i; // Suppress unused variable warning
1505 for (const auto j : c10::irange(args_list.size())) {
1506 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1507 if (manage_output_tensors_enabled_) {
1508 deallocateOutputTensors();
1509 }
1510 }
1511 }
1512 caffe2::Timer timer;
1513 for (const auto i : c10::irange(main_runs)) {
1514 (void)i; // Suppress unused variable warning
1515 for (const auto j : c10::irange(args_list.size())) {
1516 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1517 if (manage_output_tensors_enabled_) {
1518 deallocateOutputTensors();
1519 }
1520 }
1521 }
1522 float millis = timer.MilliSeconds();
1523 return millis / (static_cast<float>(main_runs) * args_list.size());
1524}
1525
1526bool display_ivalue(const IValue& iv) {
1527 if (iv.isTensor()) {
1528 std::cout << "Tensor " << iv.toTensor().toString() << " {";
1529 for (const auto i : c10::irange(iv.toTensor().sizes().size())) {
1530 std::cout << iv.toTensor().sizes()[i];
1531 if (iv.toTensor().sizes().size() > i + 1) {
1532 std::cout << ", ";
1533 }
1534 }
1535 std::cout << "}\n";
1536 return true;
1537 } else if (iv.isTensorList()) {
1538 std::cout << "TensorList {" << iv.toTensorList().size() << "}\n";
1539 return true;
1540 } else if (iv.isGenericDict()) {
1541 std::cout << "Dict {" << iv.toGenericDict().size() << "}\n";
1542 return true;
1543 } else if (iv.isTuple()) {
1544 std::cout << "Tuple {" << iv.toTupleRef().elements().size() << "}\n";
1545 return true;
1546 } else if (iv.isInt()) {
1547 std::cout << "int {" << iv.toInt() << "}\n";
1548 return true;
1549 } else if (iv.isBool()) {
1550 std::cout << "bool {" << iv.toBool() << "}\n";
1551 return true;
1552 } else if (iv.isDouble()) {
1553 std::cout << "double {" << iv.toDouble() << "}\n";
1554 return true;
1555 }
1556 return false;
1557}
1558
1559void display_pnode_info(const ProcessedNode& pnode) {
1560 pnode.node()->print(std::cout, 0, nullptr, false);
1561 for (const auto i : c10::irange(pnode.num_inputs())) {
1562 std::cout << "\ti" << i << ": ";
1563 if (!display_ivalue(pnode.Input(i))) {
1564 std::cout << *(pnode.node()->inputs()[i]->type()) << '\n';
1565 }
1566 }
1567 const auto outputs = pnode.outputs();
1568 for (const auto i : c10::irange(outputs.size())) {
1569 std::cout << "\to" << i << ": ";
1570 if (!display_ivalue(outputs[i])) {
1571 std::cout << *(pnode.node()->outputs()[i]->type()) << '\n';
1572 }
1573 }
1574}
1575
1576void BlockRunner::display_nodes(
1577 const std::vector<c10::IValue>& args,
1578 const KeywordArgs& kwargs) {
1579 c10::InferenceMode mode;
1580
1581 auto on_exit = Deallocator(*this);
1582
1583 if (planner_) {
1584 planner_->allocate();
1585 }
1586 set_inputs(args, kwargs);
1587
1588 for (auto& node : nodes_) {
1589 node.run();
1590 display_pnode_info(node);
1591 }
1592 on_exit.setFinished();
1593}
1594
1595BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
1596 const std::vector<std::vector<c10::IValue>>& args_list,
1597 const std::vector<KeywordArgs>& kwargs_list,
1598 const int warmup_runs,
1599 const int main_runs) {
1600 TORCH_CHECK(kwargs_list.empty() || args_list.size() == kwargs_list.size());
1601 TORCH_CHECK(warmup_runs >= 1 && main_runs >= 1);
1602
1603 IndividualMetrics results;
1604 results.time_per_node.resize(nodes_.size(), 0);
1605 if (args_list.empty()) {
1606 // When the given input is empty, compute the op statistics from the given
1607 // graph without executing it.
1608 for (const auto i : c10::irange(nodes_.size())) {
1609 const Node* node = nodes_[i].node();
1610 std::string kind(node->kind().toQualString());
1611 // TODO: Collect op statistics from sub-blocks here.
1612 results.time_per_node[i] = 0;
1613 results.time_per_node_type[kind] = 0;
1614 results.instances_per_node_type[kind]++;
1615 if (nodes_[i].has_out_variant()) {
1616 results.out_nodes.insert(kind);
1617 results.out_nodes_count++;
1618 } else if (nodes_[i].has_native()) {
1619 results.native_nodes.insert(kind);
1620 }
1621 results.total_time += results.time_per_node[i];
1622 }
1623 results.total_nodes_count = nodes_.size();
1624 results.memory_alloc_time = 0;
1625 results.memory_dealloc_time = 0;
1626 results.output_dealloc_time = 0;
1627 for (const auto& p : results.time_per_node_type) {
1628 const std::string& kind = p.first;
1629 results.percent_per_node_type[kind] = 0;
1630 }
1631 return results;
1632 }
1633
1634 const bool is_kwargs_empty = kwargs_list.empty();
1635 const KeywordArgs empty_kwargs;
1636 bool manage_output_tensors = static_module_.opts().manage_output_tensors;
1637 // See comment on above use of InferenceMode for
1638 // explanation.
1639 c10::InferenceMode mode;
1640
1641 // setup time
1642 caffe2::Timer timer;
1643
1644 set_inputs(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1645
1646 results.setup_time = timer.MilliSeconds();
1647
1648 // The first iteration profiles each node's output Tensors' sizes and
1649  // initializes the memory planner with the profile information. Following
1650  // iterations just reuse the already established memory plan.
1651 timer.Start();
1652 operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
1653 if (manage_output_tensors) {
1654 deallocateOutputTensors();
1655 }
1656 results.first_iter_time = timer.MilliSeconds();
1657
1658 // warmup runs
1659 for (const auto i : c10::irange(warmup_runs - 1)) {
1660 (void)i; // Suppress unused variable warning
1661 for (const auto j : c10::irange(args_list.size())) {
1662 operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1663 if (manage_output_tensors) {
1664 deallocateOutputTensors();
1665 }
1666 }
1667 }
1668
1669 // main runs
1670 for (const auto i : c10::irange(main_runs)) {
1671 (void)i; // Suppress unused variable warning
1672
1673 for (const auto j : c10::irange(args_list.size())) {
1674 set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
1675
1676 timer.Start();
1677 if (planner_) {
1678 planner_->allocate();
1679 }
1680 float millis = timer.MilliSeconds();
1681 results.memory_alloc_time += millis;
1682
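      // Time each node individually. verify_and_correct_memory_overlap()
      // clones any output that aliases one of the node's inputs
      // (see check_and_correct_overlap_with below).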
1683 for (const auto k : c10::irange(nodes_.size())) {
1684 timer.Start();
1685 nodes_[k].run();
1686 millis = timer.MilliSeconds();
1687 results.time_per_node[k] += millis;
1688 verify_and_correct_memory_overlap(nodes_[k]);
1689 }
1690 timer.Start();
1691 create_memory_planner();
1692 planner_->deallocate();
1693 // clean up owning refs of input tensors
1694 clean_up_input_ivalues();
1695 if (manage_output_tensors) {
1696 deallocateOutputTensors();
1697 }
1698 millis = timer.MilliSeconds();
1699 results.memory_dealloc_time += millis;
1700
1701 timer.Start();
1702      // No need to keep references to the outputs in static runtime anymore.
1703 c10::IValue output;
1704 if (static_module_.num_outputs() > 1) {
1705 output = move_outputs_to_tuple(static_module_.num_outputs());
1706 }
1707
1708 DCHECK(check_for_memory_leak(/*output_returned*/ false));
1709
1710      // Use std::move here. Otherwise, outputs_[0] must be cleaned up explicitly.
1711 output = std::move(*outputs_[0]);
1712 // release outputs explicitly to measure the time it takes
1713 output = IValue();
1714 millis = timer.MilliSeconds();
1715 results.output_dealloc_time += millis;
1716 }
1717 }
1718
1719 // post processing
1720 const float num_total_iters =
1721 (static_cast<float>(main_runs) * args_list.size());
1722 for (const auto i : c10::irange(nodes_.size())) {
1723 const Node* node = nodes_[i].node();
1724 std::string kind = std::string(node->kind().toQualString());
1725 results.time_per_node[i] /= num_total_iters;
1726 results.time_per_node_type[kind] += results.time_per_node[i];
1727 results.instances_per_node_type[kind]++;
1728 if (nodes_[i].has_out_variant()) {
1729 results.out_nodes.insert(kind);
1730 results.out_nodes_count++;
1731 } else if (nodes_[i].has_native()) {
1732 results.native_nodes.insert(kind);
1733 }
1734 results.total_time += results.time_per_node[i];
1735 }
1736 results.total_nodes_count = nodes_.size();
1737 results.memory_alloc_time /= num_total_iters;
1738 results.memory_dealloc_time /= num_total_iters;
1739 results.output_dealloc_time /= num_total_iters;
1740 for (const auto& p : results.time_per_node_type) {
1741 const std::string& kind = p.first;
1742 results.percent_per_node_type[kind] = p.second / results.total_time * 100;
1743 }
1744 return results;
1745}
1746
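// Debug-mode sanity check: verifies that every block input slot has been
// reset to None and that each intermediate output is either None, a tensor
// whose storage has already been freed or is owned by the memory planner, an
// optimizable container type, or a type that does not heap-allocate when
// stored in an IValue. Managed output tensors are skipped since the client
// frees them via deallocateOutputTensors(). Always returns true; violations
// raise via TORCH_CHECK.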
1747bool BlockRunner::check_for_memory_leak(
1748 bool output_returned,
1749 bool recurse_on_sub_blocks) {
1750 // check for inputs
1751 for (const auto i : c10::irange(block_info_.num_inputs())) {
1752 TORCH_CHECK(
1753 values_[i + block_info_.block_inputs_idx()].isNone(),
1754 "Input ",
1755 i,
1756 " was not cleaned up");
1757 }
1758 FastSet<const IValue*> output_ivalues(outputs_.begin(), outputs_.end());
1759 for (const auto n : c10::irange(nodes_.size())) {
1760 auto& pnode = nodes_[n];
1761 for (const auto i : c10::irange(pnode.num_outputs())) {
1762 const IValue* ival = &pnode.Output(i);
1763 const Value* val = pnode.node()->output(i);
1764 // subtlety: isManagedOutputTensorValue may give a false
1765 // negative here if an output is an alias of this value, so
1766 // check the actual tensor!
1767 if (planner_ &&
1768 (isManagedOutputTensor(*ival) || isManagedOutputTensorValue(val))) {
1769        // `ival` contains a managed output tensor that the runtime does
1770        // not reclaim at the end of an iteration; instead, the client
1771        // reclaims it by explicitly calling
1772        // `BlockRunner::deallocateOutputTensors`.
1773 continue;
1774 }
1775 const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
1776 val->debugName() + " of node " + c10::to_string(n) +
1777 " which has kind " + pnode.node()->kind().toQualString() +
1778 " was not cleaned up";
1779 if (output_ivalues.count(ival) == 0) {
1780 // check for intermediates
1781 if (!ival->isNone()) {
1782 TORCH_CHECK(
1783 ival->isTensor() ||
1784 block_info_.node_is_optimizable_container_type(
1785 pnode.node()) ||
1786 doesNotHeapAllocateWhenStoredInIValue(*val->type()),
1787 error_msg);
1788 if (ival->isTensor()) {
1789 const auto& t = ival->toTensor();
1790 if (t.defined()) {
1791 auto* storage_impl = t.storage().unsafeGetStorageImpl();
1792 TORCH_CHECK(
1793 storage_impl->data() == nullptr ||
1794 (planner_ &&
1795 planner_->isManagedStorageImpl(storage_impl)),
1796 error_msg);
1797 }
1798 }
1799 }
1800 } else {
1801 // check for outputs
1802 if (output_returned) {
1803 TORCH_CHECK(ival->isNone(), error_msg);
1804 }
1805 }
1806 }
1807 auto* metadata = pnode.metadata();
1808 if (recurse_on_sub_blocks && metadata) {
1809 auto& block_runners = metadata->block_runners();
1810 for (auto& block_runner : block_runners) {
1811 block_runner.check_for_memory_leak(
1812 output_returned, recurse_on_sub_blocks);
1813 }
1814 }
1815 }
1816 VLOG(1) << "Finished checking for memory leak";
1817 return true;
1818}
1819
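// Frees the buffer that backs managed output tensors. This is the client's
// responsibility when manage_output_tensors is enabled; when it is disabled,
// this is a no-op that only asserts the output buffer is indeed empty.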
1820void BlockRunner::deallocateOutputTensors() {
1821 if (!static_module_.opts().manage_output_tensors) {
1822 TORCH_CHECK(
1823 !planner_ || planner_->numOutputBufferBytes() == 0,
1824 "manage_output_tensors is disabled, but output tensor buffer is not empty.");
1825 return;
1826 }
1827 if (planner_) {
1828 planner_->deallocateOutputTensors();
1829 DCHECK(checkOutputTensorMemoryLeaks());
1830 }
1831}
1832
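// Debug check run after deallocateOutputTensors(): every managed output
// tensor's storage must have been released (i.e. its data pointer is null).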
1833bool BlockRunner::checkOutputTensorMemoryLeaks() {
1834 if (!static_module_.opts().manage_output_tensors || !planner_) {
1835 return true;
1836 }
1837 for (const auto n : c10::irange(nodes_.size())) {
1838 auto& pnode = nodes_[n];
1839 for (const auto i : c10::irange(pnode.num_outputs())) {
1840 const IValue* ival = &pnode.Output(i);
1841 const Value* val = pnode.node()->output(i);
1842 if (!isManagedOutputTensorValue(val) || !ival->isTensor()) {
1843      // ival may not be a tensor if it is managed by ops like
1844      // to_maybe_copy_out; see ReplaceWithMaybeCopy for details.
1845 continue;
1846 }
1847 const auto& t = ival->toTensor();
1848 if (t.defined()) {
1849 auto* storage_impl = t.storage().unsafeGetStorageImpl();
1850 const std::string error_msg = "Output " + c10::to_string(i) + ", %" +
1851 val->debugName() + " of node " + c10::to_string(n) +
1852 " was not cleaned up";
1853 TORCH_CHECK(storage_impl->data() == nullptr, error_msg);
1854 }
1855 }
1856 }
1857 VLOG(1) << "Finished checking for memory leak from output tensors";
1858 return true;
1859}
1860
1861bool BlockRunner::isManagedOutputTensor(const IValue& ivalue) const {
1862 return planner_ && planner_->isManagedOutputTensor(ivalue);
1863}
1864
1865bool BlockRunner::isManagedOutputTensorValue(const Value* value) const {
1866 // It's possible that manage_output_tensors_ was disabled after initializing
1867 // managed_output_tensor_values, so we have to check that flag here.
1868 if (!planner_ || !manage_output_tensors_enabled_) {
1869 return false;
1870 }
1871 const auto& managed_outputs = block_info_.managed_output_tensor_values();
1872 return managed_outputs.find(value) != managed_outputs.end();
1873}
1874
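// Turns off output tensor management for subsequent runs: clears every
// ProcessedNode output and drops the memory planner so that it gets rebuilt
// (without output management) on the next run.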
1875void BlockRunner::disableManageOutputTensors() {
1876 if (!manage_output_tensors_enabled_) {
1877 return;
1878 }
1879 manage_output_tensors_enabled_ = false;
1880 if (!planner_) {
1881 return;
1882 }
1883  // Reset all IValues and destroy planner_ so that it can be reconstructed on
1884  // the next run.
1885 for (auto& n : nodes_) {
1886 for (const auto i : c10::irange(n.outputs().size())) {
1887 n.Output(i) = IValue();
1888 }
1889 }
1890 planner_.reset();
1891}
1892
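// Selects the execution strategy for a node, in order of preference:
// (1) an out variant, which writes into preallocated outputs and for which
//     the memory-overlap check is disabled; (2) a native (unboxed)
//     implementation; or (3) the boxed JIT operator as an interpreter
//     fallback, which pushes the inputs onto a stack, invokes the Operation,
//     and moves the results back into the node's outputs.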
1893ProcessedFunction::ProcessedFunction(
1894 Node* node,
1895 bool enable_out_variant,
1896 bool check_memory_overlap)
1897 : check_memory_overlap_(check_memory_overlap),
1898 num_outputs_(node->outputs().size()) {
1899 if (enable_out_variant) {
1900 f_ = getOutOfPlaceOperation(node);
1901 if (f_) {
1902 kind_ = ProcessedFunction::Kind::kOutVariant;
1903 // do not check memory overlap for out variants
1904 check_memory_overlap_ = false;
1905 VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
1906 return;
1907 }
1908 }
1909 {
1910 f_ = getNativeOperation(node);
1911 if (f_) {
1912 kind_ = ProcessedFunction::Kind::kNativeFunction;
1913#ifdef NDEBUG
1914 // skip this check in opt mode because these ops are better vetted
1915 check_memory_overlap_ = false;
1916#endif
1917 VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
1918 return;
1919 }
1920 }
1921 {
1922 const Operator& op = node->getOperator();
1923 f_ = [node_op = op.getOperation(node),
1924 has_var_args = hasVarArgs(node)](ProcessedNode* pnode) mutable {
1925 std::vector<IValue> stack;
1926 const size_t size = pnode->num_inputs();
1927 stack.reserve(size + has_var_args);
1928 for (const auto i : c10::irange(size)) {
1929 stack.emplace_back(pnode->Input(i));
1930 }
1931 // Need to store the number of inputs in stack for variadic ops.
1932 if (has_var_args) {
1933 stack.emplace_back(static_cast<int>(size));
1934 }
1935 node_op(stack);
1936 TORCH_DCHECK_EQ(stack.size(), pnode->num_outputs());
1937 for (const auto i : c10::irange(pnode->num_outputs())) {
1938 pnode->Output(i) = std::move(stack[i]);
1939 }
1940 };
1941 kind_ = ProcessedFunction::Kind::kInterpreterFallback;
1942 VLOG(1) << "Fallback interpreter for node: " << PrintNode(node);
1943 }
1944}
1945
1946StaticNodeInfo::StaticNodeInfo(
1947 Node* node,
1948 ProcessedFunction* fn,
1949 ProcessedNodeInputs inputs,
1950 uint16_t outputs_offset)
1951 : node_(node),
1952 fn_(fn),
1953 inputs_(std::move(inputs)),
1954 outputs_offset_(outputs_offset) {
1955 TORCH_CHECK(num_outputs() == node->outputs().size());
1956}
1957
1958std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
1959 std::vector<IValue> result;
1960 result.reserve(inputs_.size());
1961 for (const auto idx : c10::irange(num_inputs())) {
1962 result.emplace_back(Input(idx));
1963 }
1964 return result;
1965}
1966
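// Executes this node's ProcessedFunction. When per-op profiling is enabled
// and step callbacks are registered, the call is wrapped in a RecordFunction
// guard (optionally capturing the inputs). In debug builds the run is
// followed by a memory-overlap verification, which the
// static_runtime_disable_debug_memory_overlap_check flag demotes to a
// non-fatal check.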
1967void ProcessedNode::run() {
1968#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
1969 auto step_callbacks =
1970 at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
1971 if (C10_UNLIKELY(step_callbacks.has_value())) {
1972 at::RecordFunction guard(std::move(*step_callbacks));
1973 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
1974 if (guard.needsInputs()) {
1975 const auto inputs = inputs_ivalue_vec();
1976 guard.before(
1977 get_op_name(),
1978 c10::ArrayRef<const IValue>(inputs.data(), inputs.size()));
1979 } else {
1980 guard.before(get_op_name());
1981 }
1982 if (has_out_variant()) {
1983 guard._setStaticRuntimeOutVariant();
1984 }
1985
1986 fn_->run(this);
1987 } else {
1988 fn_->run(this);
1989 }
1990#else
1991 fn_->run(this);
1992#endif
1993#ifndef NDEBUG
1994 if (FLAGS_static_runtime_disable_debug_memory_overlap_check) {
1995 // run check but do not enforce
1996 verify_no_memory_overlap();
1997 } else {
1998 DCHECK(verify_no_memory_overlap());
1999 }
2000#endif
2001}
2002
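// Returns false only when ATen reports a definite full or partial overlap
// between the two tensors; a TooHard status (overlap cannot be determined
// cheaply) is logged and treated as non-overlapping.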
2003static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
2004 at::MemOverlapStatus status = at::get_overlap_status(a, b);
2005 if (status == at::MemOverlapStatus::Full ||
2006 status == at::MemOverlapStatus::Partial) {
2007 return false;
2008 }
2009 if (status == at::MemOverlapStatus::TooHard) {
2010 VLOG(1) << "Detected TOO_HARD memory overlap status";
2011 }
2012 return true;
2013}
2014
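// A few ops intentionally forward or alias their inputs (e.g. prim::TypeCheck
// and static_runtime::select_tensor), so they are exempt from the overlap
// check unless force_check is set.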
2015bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
2016 const static std::array<c10::Symbol, 7> special_case_ops = {
2017 fromQualString("prim::TypeCheck"),
2018 fromQualString("prim::IfThenElse"),
2019 fromQualString("static_runtime::select_tensor"),
2020 fromQualString("static_runtime::VarTupleUnpack"),
2021 fromQualString("static_runtime::dict_unpack"),
2022 fromQualString("static_runtime::fused_split_and_squeeze"),
2023 fromQualString("static_runtime::create_owned_ref")};
2024 if (!force_check &&
2025 std::find(
2026 begin(special_case_ops), end(special_case_ops), node()->kind()) !=
2027 end(special_case_ops)) {
2028 return true;
2029 }
2030
2031 return verify_outputs_dont_overlap_each_other() &&
2032 verify_inputs_dont_overlap_outputs(force_check);
2033}
2034
2035bool ProcessedNode::verify_outputs_dont_overlap_each_other() const {
2036 for (const auto i : c10::irange(num_outputs())) {
2037 if (!Output(i).isTensor()) {
2038 continue;
2039 }
2040 const auto& out0_t = Output(i).toTensor();
2041 for (const auto j : c10::irange(i + 1, num_outputs())) {
2042 if (!Output(j).isTensor()) {
2043 continue;
2044 }
2045 const auto& out1_t = Output(j).toTensor();
2046 if (!checkNoMemoryOverlap(out0_t, out1_t)) {
2047 LOG(INFO) << "Node output " << i << " overlaps with output " << j
2048 << ", " << PrintNode(node_);
2049 return false;
2050 }
2051 }
2052 }
2053 return true;
2054}
2055
2056bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const {
2057 auto schema = node()->maybeSchema();
2058 // skip memory overlap check for mutable or view ops with only one output
2059 bool skip_check = !schema ||
2060 ((schema->is_mutable() || !fn_->checkMemoryOverlap()) &&
2061 num_outputs() == 1);
2062 if (!schema || (!force_check && skip_check)) {
2063 if (!schema) {
2064 VLOG(2) << "Detected that op schema is null";
2065 return true;
2066 }
2067 VLOG(2) << "schema->is_mutable: " << schema->is_mutable()
2068 << ", fn_->checkMemoryOverlap: " << fn_->checkMemoryOverlap()
2069 << ", num_outputs_: " << num_outputs();
2070 return true;
2071 }
2072
2073 for (const auto i : c10::irange(inputs_.size())) {
2074 const IValue* in = &Input(i);
2075 if (!in->isTensor()) {
2076 continue;
2077 }
2078 const auto& in_t = in->toTensor();
2079 for (const auto j : c10::irange(num_outputs())) {
2080 const IValue& out = Output(j);
2081 if (!out.isTensor()) {
2082 continue;
2083 }
2084 const auto& out_t = out.toTensor();
2085 if (!checkNoMemoryOverlap(in_t, out_t)) {
2086 LOG(INFO) << "Node input " << i << " overlaps with output " << j << ", "
2087 << PrintNode(node_);
2088 LOG(INFO) << *schema;
2089 return false;
2090 }
2091 }
2092 }
2093 return true;
2094}
2095
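// If `output_ival` aliases `input`, replaces it with a clone and records that
// an overlap was detected. Returns true iff a clone was made.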
2096bool ProcessedNode::check_and_correct_overlap_with(
2097 const at::Tensor& input,
2098 c10::IValue& output_ival) {
2099 auto& tensor = output_ival.toTensor();
2100 if (!checkNoMemoryOverlap(input, tensor)) {
2101 DLOG(INFO) << "Detected alias for node: " << PrintNode(node());
2102 output_ival = at::native::clone(tensor, c10::nullopt);
2103 set_outputs_memory_overlap_detected();
2104 return true;
2105 }
2106 return false;
2107}
2108
2109void ProcessedNode::verify_and_correct_memory_overlap() {
2110 for (const auto i : c10::irange(inputs_.size())) {
2111 const IValue& in = Input(i);
2112 if (!in.isTensor()) {
2113 continue;
2114 }
2115 const auto& in_t = in.toTensor();
2116 for (const auto j : c10::irange(num_outputs())) {
2117 auto& output = Output(j);
2118 if (output.isTensor()) {
2119 check_and_correct_overlap_with(in_t, output);
2120 } else if (output.isTensorList()) {
2121 auto tensors = output.toListRef();
2122 for (const auto& ival : tensors) {
2123 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
2124 check_and_correct_overlap_with(in_t, const_cast<c10::IValue&>(ival));
2125 }
2126#ifdef FBCODE_CAFFE2
2127 if (outputs_memory_overlap_detected()) {
2128 LOG_EVERY_MS(WARNING, 60000)
2129 << "Detected alias for node: " << PrintNode(node());
2130 }
2131#endif
2132 }
2133 }
2134 }
2135}
2136
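// A StaticRuntime owns the IValue buffer shared by its BlockRunners; the
// module's constants are copied into that buffer once here. A minimal usage
// sketch (the model path and inputs below are illustrative, not part of this
// file):
//
//   torch::jit::Module mod = torch::jit::load("model.pt");
//   torch::jit::StaticModule smod(mod);
//   torch::jit::StaticRuntime runtime(smod);
//   std::vector<c10::IValue> inputs{at::randn({2, 3})};
//   c10::IValue out = runtime(inputs, {});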
2137StaticRuntime::StaticRuntime(const StaticModule& sm)
2138 : values_(sm.value_buffer_size()) {
2139 std::copy(sm.constants().begin(), sm.constants().end(), values_.data());
2140 // default task launcher set to inter-op thread pool
2141 async_task_launcher_ = at::launch;
2142 block_ = std::make_unique<BlockRunner>(
2143 sm,
2144 values_.data(),
2145 sm.root_block(),
2146 &async_task_launcher_,
2147 true /*is_root_block*/);
2148}
2149
2150c10::IValue StaticRuntime::operator()(
2151 const std::vector<c10::IValue>& args,
2152 const KeywordArgs& kwargs) {
2153 return (*block_)(args, kwargs);
2154}
2155
2156c10::IValue StaticRuntime::operator()(
2157 std::vector<c10::IValue>&& args,
2158 const KeywordArgs& kwargs) {
2159 return (*block_)(std::move(args), kwargs);
2160}
2161
2162c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2163 const std::vector<c10::IValue>& args,
2164 const KeywordArgs& kwargs,
2165 torch::jit::TaskLauncher taskLauncher) {
2166 async_task_launcher_ = std::move(taskLauncher);
2167 return block_->runAsync(args, kwargs);
2168}
2169
2170c10::intrusive_ptr<c10::ivalue::Future> StaticRuntime::runAsync(
2171 std::vector<c10::IValue>&& args,
2172 const KeywordArgs& kwargs,
2173 torch::jit::TaskLauncher taskLauncher) {
2174 async_task_launcher_ = std::move(taskLauncher);
2175 return block_->runAsync(std::move(args), kwargs);
2176}
2177
2178bool StaticRuntime::check_for_memory_leak(bool output_returned) {
2179 return block_->check_for_memory_leak(
2180 output_returned, /* recurse_on_sub_blocks */ true);
2181}
2182
2183bool StaticRuntime::checkOutputTensorMemoryLeaks() {
2184 return block_->checkOutputTensorMemoryLeaks();
2185}
2186
2187void StaticRuntime::deallocateOutputTensors() {
2188 block_->deallocateOutputTensors();
2189}
2190
2191bool StaticRuntime::isManagedOutputTensor(const IValue& ivalue) const {
2192 return block_->isManagedOutputTensor(ivalue);
2193}
2194
2195void StaticRuntime::disableManageOutputTensors() {
2196 block_->disableManageOutputTensors();
2197}
2198
2199const MemoryPlanner* StaticRuntime::get_memory_planner() const {
2200 return block_->get_memory_planner();
2201}
2202
2203} // namespace jit
2204} // namespace torch
2205