#include <partition.h>

#include <ATen/core/jit_type.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <instrumentation.h>
#include <parser.h>
#include <utils.h>
#include <torch/csrc/jit/jit_log.h>

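// Utilities for deciding which nodes can start or join a
// prim::CudaFusionGroup (see isFusibleCudaFusionGroup below).
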
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

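// Sentinel device index used to mark a conflict between the devices of the
// tensors seen so far (see getDevice below).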
const c10::DeviceIndex INVALID_INDEX = -2;

namespace {

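// Returns true if the node (or, for a prim::CudaFusionGroup, any node in its
// subgraph) is not a purely element-wise operation.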
bool hasNonElementWiseOperation(const Node* node) {
  if (node->kind() == prim::CudaFusionGroup) {
    for (auto n : node->g(attr::Subgraph)->nodes()) {
      if (hasNonElementWiseOperation(n)) {
        return true;
      }
    }
  } else {
    // prim::Constant is not parsible, but it is also not non-element-wise
    if (node->kind() != prim::Constant && !isElementWiseNode(node)) {
      return true;
    }
  }
  return false;
}

// Check that all outputs are:
// 1. TensorType
// 2. on the same device
// TODO: update this when codegen can output scalars
static c10::optional<c10::Device> getDevice(const Value* value) {
  if (!value->type()->isSubtypeOf(*TensorType::get())) {
    // not a tensor type; return nullopt since codegen cannot output scalars
    // yet.
    return c10::nullopt;
  }
  auto tensor_type = value->type()->expectRef<TensorType>();
  // special case for scalar tensors: return c10::nullopt instead of the cpu
  // device. this allows us to fuse a cpu scalar tensor with cuda tensors,
  // while avoiding merging ops with pure cpu scalar tensors.
  if (is_cpu_scalar(tensor_type)) {
    return c10::nullopt;
  }
  return tensor_type.device();
}

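// Returns true if any tensor input or output of the node has a BFloat16
// scalar type; used to reject bfloat fusion on pre-Ampere devices.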
static bool hasBfloat(const Node* node) {
  auto has_bfloat = [](const Value* value) {
    if (!value->type()->isSubtypeOf(*TensorType::get())) {
      return false;
    }
    auto opt_scalar_type = value->type()->expectRef<TensorType>().scalarType();
    if (opt_scalar_type.has_value() &&
        opt_scalar_type.value() == at::ScalarType::BFloat16) {
      return true;
    }
    return false;
  };

  if (std::any_of(node->inputs().begin(), node->inputs().end(), has_bfloat) ||
      std::any_of(node->outputs().begin(), node->outputs().end(), has_bfloat)) {
    return true;
  }
  return false;
}

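// Merges the device of every tensor input and output of the node. Returns
// nullopt if no tensor carries device information; if two tensors disagree,
// the returned device has its index set to INVALID_INDEX.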
static c10::optional<c10::Device> getDevice(const Node* node) {
  c10::optional<c10::Device> ret = c10::nullopt;
  auto merge_devices = [&ret](const c10::optional<c10::Device>& device) {
    if (device.has_value()) {
      if (ret.has_value()) {
        if (ret.value() != device.value()) {
          // invalidate device to reflect conflicts
          ret->set_index(INVALID_INDEX);
          // return false to indicate early termination
          return false;
        } else {
          // same device, do nothing
          return true;
        }
      } else {
        // initialize return device
        ret = device.value();
        return true;
      }
    }
    // no device information, do nothing
    return true;
  };
  for (auto val : node->inputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  for (auto val : node->outputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  return ret;
}

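// Rejects fusion on devices the codegen cannot target: non-CUDA devices,
// non-element-wise ops on pre-Volta (SM major < 7) devices, and BFloat16 on
// pre-Ampere (SM major < 8) devices.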
static bool isDeviceCompatible(const Node* node, const c10::Device& device) {
  // only fuse ops on cuda devices
  if (!device.is_cuda()) {
    GRAPH_UPDATE("rejecting node (non-cuda device): ", *node);
    return false;
  }
  const auto major = at::cuda::getDeviceProperties(device.index())->major;
  // disable non-elementwise fusion on pre-volta devices
  if (major < 7 && hasNonElementWiseOperation(node)) {
    GRAPH_UPDATE(
        "rejecting node (non element-wise op not supported on SM < 7X): ",
        *node);
    return false;
  }
  // disable bfloat fusion on pre-ampere devices
  if (major < 8 && hasBfloat(node)) {
    GRAPH_UPDATE("rejecting node (bfloat not supported on SM < 8X): ", *node);
    return false;
  }
  return true;
}

static bool isFusibleDevice(const Node* node, const c10::Device& device) {
  TORCH_INTERNAL_ASSERT(
      device.index() != INVALID_INDEX, "fusible device needs to be validated");
  auto opt_device = getDevice(node);
  // we can be more relaxed here as we know that this function tries to merge
  // the node into an existing fusion on `device`
  if (opt_device.has_value() &&
      (opt_device->index() == INVALID_INDEX || opt_device != device)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (outputs device not matching fusion): ",
        *node);
    return false;
  }
  if (!isDeviceCompatible(node, device)) {
    return false;
  }
  return true;
}

// TODO: we need to check input type when we handle `to()`
static bool isFusibleDevice(const Node* node) {
  auto device = getDevice(node);
  // be conservative and only fuse cuda operations; this avoids initializing
  // operations that produce cpu scalar outputs
  if (!device.has_value() || device->index() == INVALID_INDEX) {
    return false;
  }

  if (!isDeviceCompatible(node, device.value())) {
    return false;
  }
  return true;
}

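// Returns true if the value's tensor type (if any) can be handled by codegen:
// the scalar type must map to a supported DataType (complex only when
// explicitly enabled) and the rank must be at most 8.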
bool compatibleType(const torch::jit::Value* val) {
  if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
    if (tensor_type->scalarType().has_value()) {
      if (aten_to_data_type(tensor_type->scalarType().value()) ==
          DataType::Null) {
        return false;
      }
      if (!isOptionEnabled(EnableOption::Complex)) {
        // Complex is disabled by default until its support is completely added
        // TODO: remove this logic
        if (isComplexType(
                aten_to_data_type(tensor_type->scalarType().value()))) {
          return false;
        }
      }
    }
    // magic number 8 here since our kernel arguments only support rank <= 8
    if (tensor_type->dim().has_value() && (tensor_type->dim().value() > 8)) {
      return false;
    }
  }
  return true;
}

bool checkInputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->inputs().size())) {
    const auto& val = node->inputs()[i];
    if (!compatibleType(val)) {
      // special case on aten::_batch_norm_impl_index_backward: the input at
      // index 11 is going to be discarded, so no need to check its data type.
      if (node->kind() ==
              c10::Symbol::fromQualString(
                  "aten::_batch_norm_impl_index_backward") &&
          i == 11) {
        continue;
      }
      return false;
    }
  }
  return true;
}

bool checkOutputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->outputs().size())) {
    const auto& val = node->outputs()[i];
    if (!compatibleType(val)) {
      // special case on aten::_batch_norm_impl_index: the 4th output
      // is going to be discarded, so no need to check its data type.
      if (node->kind() ==
              c10::Symbol::fromQualString("aten::_batch_norm_impl_index") &&
          i == 3) {
        continue;
      }
      return false;
    }
  }
  return true;
}

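// Node-local fusibility check: the node must either already be a
// prim::CudaFusionGroup or have a parsing rule, and all of its tensor inputs
// and outputs must have types supported by codegen.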
inline bool isFusibleNode(const Node* node) {
  // Check if already part of a fusion group
  if (node->kind() == prim::CudaFusionGroup)
    return true;
  // Check that we have a parsing rule
  if (!isNodeParsible(node)) {
    // ignore profile nodes & constant nodes to avoid noise when debugging
    if (node->kind() != prim::Constant &&
        node->kind() != prim::profile_ivalue && node->kind() != prim::profile &&
        node->kind() != prim::Param) {
      GRAPH_UPDATE("rejecting node from fusion (node not parsible): ", *node);
    }
    return false;
  }
  // If we have tensor types, check that they are ones we support
  if (!checkInputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (input scalar type not supported): ",
        *node);
    return false;
  }
  if (!checkOutputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (output scalar type not supported): ",
        *node);
    return false;
  }
  return true;
}

256} // namespace
257
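// Returns true if `node` can form (or already is) a CUDA fusion group on its
// own: it must pass the node-local checks and sit on a single, compatible
// CUDA device.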
bool isFusibleCudaFusionGroup(const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");

  if (isFusibleNode(node)) {
    return isFusibleDevice(node);
  }
  return false;
}

bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
  bool fused = false;
  // TODO: lift the restriction of not fusing a producer that contains a
  // reduction once we have proper scheduling.
  if (isFusibleNode(node)) {
    // ensure that, if the node has a designated device, it is on the same
    // device as the fusion.
    // TODO: is there a danger of fusing operations that are supposed to run
    // on separate GPUs? And is that necessarily bad?
    auto device = getDevice(fusion);
    fused = (!device.has_value() || isFusibleDevice(node, device.value()));
  }
  return fused;
}

284} // namespace cuda
285} // namespace fuser
286} // namespace jit
287} // namespace torch
288