#include <partition.h>

#include <ATen/core/jit_type.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <instrumentation.h>
#include <parser.h>
#include <utils.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

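// Sentinel device index used to flag conflicting devices; note that
// c10::Device already uses -1 to mean "unspecified index", hence -2.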
const c10::DeviceIndex INVALID_INDEX = -2;

namespace {

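// Returns true if `node` (or, for a prim::CudaFusionGroup, any node in its
// subgraph) is neither a constant nor an element-wise operation. Used to
// restrict fusion on devices that only support element-wise fusion.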
bool hasNonElementWiseOperation(const Node* node) {
  if (node->kind() == prim::CudaFusionGroup) {
    for (auto n : node->g(attr::Subgraph)->nodes()) {
      if (hasNonElementWiseOperation(n)) {
        return true;
      }
    }
  } else {
    // prim::Constant is not parsible, but it is also not a non-elementwise op
    if (node->kind() != prim::Constant && !isElementWiseNode(node)) {
      return true;
    }
  }
  return false;
}

// Returns the device of `value` if it is a tensor carrying device
// information, c10::nullopt otherwise. Together with the Node overload below,
// this is used to check that all outputs are:
//   1. TensorType;
//   2. on the same device.
// TODO: update this when codegen can output scalars
static c10::optional<c10::Device> getDevice(const Value* value) {
  if (!value->type()->isSubtypeOf(*TensorType::get())) {
    // not a tensor type; only tensor values carry device information
    return c10::nullopt;
  }
  auto tensor_type = value->type()->expectRef<TensorType>();
  // special case for scalar tensors: return c10::nullopt instead of the CPU
  // device. This allows fusing a scalar CPU tensor with CUDA tensors, while
  // avoiding merging ops that work purely on scalar CPU tensors.
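  // Example (illustrative): in `cuda_tensor + cpu_scalar_tensor`, the CPU
  // scalar contributes no device, so the merged device computed by the Node
  // overload below resolves to the CUDA device of the other operand.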
  if (is_cpu_scalar(tensor_type)) {
    return c10::nullopt;
  }
  return tensor_type.device();
}

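// Returns true if any tensor input or output of `node` has BFloat16 dtype;
// used below to reject bfloat fusion on devices without hardware support.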
static bool hasBfloat(const Node* node) {
  auto has_bfloat = [](const Value* value) {
    if (!value->type()->isSubtypeOf(*TensorType::get())) {
      return false;
    }
    auto opt_scalar_type = value->type()->expectRef<TensorType>().scalarType();
    return opt_scalar_type.has_value() &&
        opt_scalar_type.value() == at::ScalarType::BFloat16;
  };

  return std::any_of(
             node->inputs().begin(), node->inputs().end(), has_bfloat) ||
      std::any_of(node->outputs().begin(), node->outputs().end(), has_bfloat);
}

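// Merges the devices of all tensor inputs and outputs of `node` and returns:
// - c10::nullopt if no value carries device information;
// - the common device if all devices agree;
// - a device whose index is INVALID_INDEX if two values disagree.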
static c10::optional<c10::Device> getDevice(const Node* node) {
  c10::optional<c10::Device> ret = c10::nullopt;
  auto merge_devices = [&ret](const c10::optional<c10::Device>& device) {
    if (device.has_value()) {
      if (ret.has_value()) {
        if (ret.value() != device.value()) {
          // invalidate device to reflect conflicts
          ret->set_index(INVALID_INDEX);
          // return false to indicate early termination
          return false;
        } else {
          // same device, do nothing
          return true;
        }
      } else {
        // initialize return device
        ret = device.value();
        return true;
      }
    }
    // no device information, do nothing
    return true;
  };
  for (auto val : node->inputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  for (auto val : node->outputs()) {
    if (!merge_devices(getDevice(val))) {
      return ret;
    }
  }
  return ret;
}

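// Hardware gating: we only attempt non-elementwise fusion on Volta
// (compute capability 7.x) or newer, and BFloat16 fusion on Ampere (8.x)
// or newer.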
static bool isDeviceCompatible(const Node* node, const c10::Device& device) {
  // we only fuse ops on CUDA devices
  if (!device.is_cuda()) {
    GRAPH_UPDATE("rejecting node (non-cuda device): ", *node);
    return false;
  }
  const auto major = at::cuda::getDeviceProperties(device.index())->major;
  // disable non-elementwise fusion on pre-Volta devices
  if (major < 7 && hasNonElementWiseOperation(node)) {
    GRAPH_UPDATE(
        "rejecting node (non element-wise op not supported on SM < 70): ",
        *node);
    return false;
  }
  // disable bfloat fusion on pre-Ampere devices
  if (major < 8 && hasBfloat(node)) {
    GRAPH_UPDATE("rejecting node (bfloat not supported on SM < 80): ", *node);
    return false;
  }
  return true;
}

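// Variant used when merging `node` into an existing fusion that is already
// pinned to `device`; callers must pass a validated device.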
static bool isFusibleDevice(const Node* node, const c10::Device& device) {
  TORCH_INTERNAL_ASSERT(
      device.index() != INVALID_INDEX, "fusible device needs to be validated");
  auto opt_device = getDevice(node);
  // we can be more relaxed here, since we know this function tries to merge
  // `node` into an existing fusion already placed on `device`
  if (opt_device.has_value() &&
      (opt_device->index() == INVALID_INDEX || opt_device != device)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (outputs device not matching fusion): ",
        *node);
    return false;
  }
  if (!isDeviceCompatible(node, device)) {
    return false;
  }
  return true;
}

// TODO: we need to check input types when we handle `to()`
static bool isFusibleDevice(const Node* node) {
  auto device = getDevice(node);
  // be conservative and only fuse CUDA operations; this avoids initializing
  // operations that produce CPU scalar outputs
  if (!device.has_value() || device->index() == INVALID_INDEX) {
    return false;
  }

  if (!isDeviceCompatible(node, device.value())) {
    return false;
  }
  return true;
}

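// Returns false if `val` is a tensor with a dtype the codegen cannot handle
// (i.e. one that maps to DataType::Null), a complex tensor while complex
// support is disabled, or a tensor of rank > 8; returns true otherwise.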
bool compatibleType(const torch::jit::Value* val) {
  if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
    if (tensor_type->scalarType().has_value()) {
      if (aten_to_data_type(tensor_type->scalarType().value()) ==
          DataType::Null) {
        return false;
      }
      if (!isOptionEnabled(EnableOption::Complex)) {
        // complex is disabled by default until its support is fully added
        // TODO: remove this logic
        if (isComplexType(
                aten_to_data_type(tensor_type->scalarType().value()))) {
          return false;
        }
      }
    }
    // magic number 8 here, since our kernel arguments only support rank <= 8
    if (tensor_type->dim().has_value() && (tensor_type->dim().value() > 8)) {
      return false;
    }
  }
  return true;
}

bool checkInputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->inputs().size())) {
    const auto& val = node->inputs()[i];
    if (!compatibleType(val)) {
      // special case for aten::_batch_norm_impl_index_backward: the input at
      // index 11 is going to be discarded, so there is no need to check its
      // data type.
      if (node->kind() ==
              c10::Symbol::fromQualString(
                  "aten::_batch_norm_impl_index_backward") &&
          i == 11) {
        continue;
      }
      return false;
    }
  }
  return true;
}

bool checkOutputTensorTypes(const Node* node) {
  for (const auto i : c10::irange(node->outputs().size())) {
    const auto& val = node->outputs()[i];
    if (!compatibleType(val)) {
      // special case for aten::_batch_norm_impl_index: the output at index 3
      // is going to be discarded, so there is no need to check its data type.
      if (node->kind() ==
              c10::Symbol::fromQualString("aten::_batch_norm_impl_index") &&
          i == 3) {
        continue;
      }
      return false;
    }
  }
  return true;
}

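// A node is fusible if it is already a CudaFusionGroup, or if it has a
// parsing rule and all of its tensor inputs and outputs have supported types.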
inline bool isFusibleNode(const Node* node) {
  // a node that is already a CudaFusionGroup is fusible by construction
  if (node->kind() == prim::CudaFusionGroup)
    return true;
  // check that we have a parsing rule for the node
  if (!isNodeParsible(node)) {
    // skip logging for profile & constant nodes to avoid noise when debugging
    if (node->kind() != prim::Constant &&
        node->kind() != prim::profile_ivalue && node->kind() != prim::profile &&
        node->kind() != prim::Param) {
      GRAPH_UPDATE("rejecting node from fusion (node not parsible): ", *node);
    }
    return false;
  }
  // check that any tensor types present are ones we support
  if (!checkInputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (input scalar type not supported): ",
        *node);
    return false;
  }
  if (!checkOutputTensorTypes(node)) {
    GRAPH_UPDATE(
        "rejecting node from fusion (output scalar type not supported): ",
        *node);
    return false;
  }
  return true;
}

} // namespace

bool isFusibleCudaFusionGroup(const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");

  return isFusibleNode(node) && isFusibleDevice(node);
}

bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) {
  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
  bool fused = false;
  // TODO: lift the restriction on fusing a producer that contains a reduction
  // once we have proper scheduling.
  if (isFusibleNode(node)) {
    // ensure that if the node has a designated device, it is on the same
    // device as the fusion.
    // TODO: is there a danger of fusing operations that are supposed to run
    // on separate GPUs? And is that necessarily bad?
    auto device = getDevice(fusion);
    fused = (!device.has_value() || isFusibleDevice(node, device.value()));
  }
  return fused;
}
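
// Illustrative (hypothetical) use by a fusion pass; `fusion_group` and
// `candidate` are placeholder names, not part of this file:
//
//   if (isFusibleCudaFusionGroup(fusion_group, candidate)) {
//     // safe to merge `candidate` into `fusion_group`
//   }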

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch