1#pragma once
2#include <c10/macros/Export.h>
3
4#include <ir_all_nodes.h>
5
6#include <vector>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13//! Transform for-loop structure to handle misaligned addresses
14//!
15//! Sections of misaligned addresses are handled sequentially
16//! while aligned addresses use vectorized memory accesses.
17//!
18//! ---------------------------------------------------------------------------
19//! Before Misaligned Vectorization:
20//!
21//! Inputs: T0
22//! Outputs: T3
23//!
24//! for(...) {
25//! T1[vector_size];
26//! for( i : vector_size ) {
27//! T1[i] = T0[...]
28//! }
29//!
30//! T2[vector_size];
31//! for( i : vector_size ) {
32//! T2[i] = unaryOp(T1[i])
33//! }
34//!
35//! for( i : vector_size ) {
36//! T3[...] = T2[i]
37//! }
38//! }
39//!
40//! ---------------------------------------------------------------------------
41//! After Misaligned Vectorization:
42//!
43//! Inputs: T0
44//! Outputs: T3
45//!
46//! for(...) {
47//! T1[vector_size];
48//! T2[vector_size];
49//!
50//! if (inline_predicate_except_last_root_domain) {
51//! index_except_last_root_domain = ...
52//! address = (int64_t) &T1[index_except_last_root_domain]
53//!
54//! offset_size = (address % vector_size_bytes) / data_type_size_bytes
55//! shift_init = vector_size - offset_size
56//! shift = (shift_init == vector_size) ? 0 : shift_init
57//!
58//! // size of the last root domain
59//! extent = ...
60//! remainder = (extent - shift) % vector_size
61//!
62//! last_root_domain_index = ...
63//!
64//! // Vectorize Section
65//! if ( (last_root_domain_index + shift) < (extent - remainder) ) {
66//! T1[0] = vectorize_load( T0[index + shift] );
67//!
68//! for( i : vector_size ) {
69//! T2[i] = unaryOp(T1[i])
70//! }
71//!
72//! T3[index + shift] = vectorize_store( T2[0] );
73//! }
74//!
75//! // Initial Section
76//! if ( last_root_domain_index == 0 ) {
77//! for( i : shift ) {
78//! T1[i] = T0[...]
79//! }
80//!
81//! for( i : shift ) {
82//! T2[i] = unaryOp(T1[i])
83//! }
84//!
85//! for( i : shift ) {
86//! T3[...] = T2[i]
87//! }
88//! }
89//!
90//! // Remainder Section
91//! if ( (last_root_domain_index + shift) >= (extent - remainder) &&
92//! (last_root_domain_index + shift) < extent) {
93//!
94//! for( i : remainder ) {
95//! T1[i] = T0[index + shift]
96//! }
97//!
98//! for( i : remainder ) {
99//! T2[i] = unaryOp(T1[i])
100//! }
101//!
102//! for( i : remainder ) {
103//! T3[index + shift] = T2[i]
104//! }
105//! }
106//! }
107//! }
108//!
//! Declaration associated with the misaligned-vectorization transform
//! documented in the comment block above.
//!
//! \param exprs Expressions to process. Taken by const reference, so the
//!   caller's vector is not modified through this parameter.
//! \return A vector of expression pointers, returned by value.
//!
//! NOTE(review): only the declaration is visible in this header. Presumably
//! the result is `exprs` with the affected for-loops rewritten into the
//! "After Misaligned Vectorization" form shown above -- confirm against the
//! definition.
109std::vector<Expr*> processMisalignedVectorization(
110 const std::vector<Expr*>& exprs);
111
//! Boolean query on a `kir::ForLoop`.
//!
//! \param fl For-loop to inspect, passed as a pointer-to-const.
//! \return A bool; see the note below for its presumed meaning.
//!
//! NOTE(review): the definition is not in this header. Going by the name,
//! this presumably returns true when at least one immediate child of `fl`
//! is a misaligned-vectorize loop -- confirm against the definition,
//! including whether a null `fl` is accepted.
112bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl);
113
114} // namespace cuda
115} // namespace fuser
116} // namespace jit
117} // namespace torch
118