#pragma once

#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <unordered_map>
#include <vector>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13struct IndexFromIdGraph;
14
15//! Insert magic zero definition at the begining of the kernel. Insert magic
16//! zero update after every (outer most) loop nest with a compile time extent.
17//!
18//! This will make sure nvrtc does not aggressively save predicate and indices.
19std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs);
20
21//! Check if val is a reference to the magic zero variable
22TORCH_CUDA_CU_API bool isMagicZero(const Val* val);
23
24//! Check if val is protected with magic zero.
25//!
26//! Specifically, this returns true if val is defined as "x + magic_zero".
27bool isProtectedWithMagicZero(const Val* val);
28
29// Determine if we may run into over reuse of predicates or registers in the
30// compiler. If the loop can be unrolled and the index and domain are not
31// "simple" we likely want the loop protected.
32//
33// Magic zero protection should only be done for global memory and predicates.
34// We should avoid use on registers. Shared memory does not require it, but
35// likely wouldn't hurt.
36bool needsMagicZero(
37 kir::ForLoop* loop,
38 IterDomain* reference_domain = nullptr,
39 Val* ind = nullptr);
40
41struct IndexMagicZeroInfo {
42 //! Index that may be updated with magic zero
43 Val* index = nullptr;
44 //! Loop index that is protected by magic zero. nullptr if no loop
45 //! is protected
46 Val* original_loop_index = nullptr;
47 //! Protected loop index. nullptr if no loop is protected
48 Val* protected_loop_index = nullptr;
49 //! Protected loop. nullptr if no loop is protected
50 IterDomain* loop_id = nullptr;
51};
52
53//! Protect an index val of an IterDomain with magic zero
54//!
55//! This should be only used for predicate indexing.
56//!
57//! No protection is done if none of the loops is determined to require
58//! protection by needsMagicZero.
59IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
60 Val* index,
61 const IndexFromIdGraph& id_graph,
62 const std::vector<kir::ForLoop*>& loops);
63
64//! Protect an index val of a tensor with magic zero
65//!
66//! This should be only used for non-predicate indexing.
67//!
68//! No protection is done if none of the loops is determined to require
69//! protection by needsMagicZero.
70void protectNonPredicateIndexWithMagicZero(
71 const std::vector<kir::ForLoop*>& loops,
72 const std::vector<IterDomain*>& loop_domains,
73 std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map);
74
75} // namespace cuda
76} // namespace fuser
77} // namespace jit
78} // namespace torch
79