#include <executor.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <kernel_cache.h>
#include <manager.h>
#include <parser.h>
#include <scheduler/all_schedulers.h>
#include <type_inference.h>
#include <utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <ATen/DimVector.h>
#include <c10/core/DeviceType.h>
#include <c10/util/irange.h>

#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! [ Note -- cache entry indexing ]
//!
//! CudaFusionManager holds the cache and handles interfacing to
//! CudaFusionGroup nodes, including the selection, construction and execution
//! of FusionExecutors.
//!
//! CudaFusionManager bridges the PyTorch IR node CudaFusionGroup to
//! GraphCache. We therefore want to cache on the stringified graph. Since it
//! is expensive to stringify and hash a computational graph on every call, we
//! cache the id assigned to the stringified graph on the node via
//! `attr::cache_id`.
//!
//! A CudaFusionGroup node stores:
//!       i.  a PyTorch IR in `attr::Subgraph`
//!       ii. an int in `attr::cache_id` (the cache id assigned to its
//!           `attr::Subgraph`)
//!
//! CudaFusionManager keeps 2 unordered_maps:
//!       std::unordered_map<std::string, int32_t> graph_cache_ids_;
//!       std::unordered_map<int64_t, std::unique_ptr<GraphCache>> graph_cache_;
//!
//! The mapping from std::string to cache_id ensures that CudaFusionGroup nodes
//! with identical computational graphs are assigned the same cache_id,
//! allowing kernel reuse; the direct mapping from cache_id to GraphCache
//! allows efficient graph_cache_ indexing.
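//!
//! For illustration (a sketch of the flow implemented below, using the names
//! in this file): at compile time a node's canonicalized subgraph string is
//! mapped to a cache_id through registerOrGetCacheId(), and at run time
//!
//!   fusion_node->i(attr::cache_id) -> graph_cache_[cache_id]
//!       -> GraphCache::runGraphWithInputs(inputs)
//!
//! so two CudaFusionGroup nodes with identical canonicalized subgraphs share
//! one GraphCache entry and its compiled kernels.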

namespace {

// TODO remove this (75983):
// we don't need this any more; I think we can use revertAliasCopyOps.
// A similar refactor should be done in the fallback graph used by the fusion
// guard, and the implementation of the xxxx_copy ops should be removed.
//
// Mark a string attribute on alias-copy nodes to enable their implementation
// in the fallback path.
void enableAliasCopyNodes(const std::shared_ptr<Graph>& graph, Block* block) {
  static std::unordered_set<Symbol> alias_copy_op(
      {prim::expand_copy,
       prim::expand_as_copy,
       prim::flatten_copy,
       prim::permute_copy,
       prim::reshape_copy,
       prim::squeeze_copy,
       prim::t_copy,
       prim::transpose_copy,
       prim::unsqueeze_copy,
       prim::view_copy});

  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      enableAliasCopyNodes(graph, b);
    }
    if (alias_copy_op.find(n->kind()) != alias_copy_op.end()) {
      n->s_(attr::name, "CudaFusionGroup");
    }
  }
}

static std::unique_ptr<Code> createFallbackCode(const Node* fusion_node) {
  auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
  EraseShapeInformation(copied_graph);
  enableAliasCopyNodes(copied_graph, copied_graph->block());
  auto code = std::make_unique<Code>(copied_graph, "fallback_cuda_fuser");
  return code;
}

// CudaFusionManager is not thread safe!
// TODO: we should weigh the tradeoff of using a thread_local instance
//       instead of a global singleton;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class CudaFusionManager {
 public:
  static CudaFusionManager& getManager() {
    static CudaFusionManager cuda_fusion_manager_;
    return cuda_fusion_manager_;
  };

  // TODO: I'm assuming we have stride information in `graph->toString()`.
  // We need to make sure stride information is in the final string, as we
  // want to AVOID kernel reuse between different fusion_nodes unless they
  // have identical contiguity information! (So identical stride + shape is
  // even more restrictive, in a good way.)
  int32_t registerOrGetCacheId(std::shared_ptr<Graph>& graph) {
    // prepare graph for lowering;
    // We should not call `EraseShapeInformation(graph)`: the graph
    // representation does not incorporate static sizes, only the rank of the
    // input tensors, which is exactly what we want.
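    // Canonicalize(graph, false) presumably passes keep_unique_names=false so
    // that value-name differences alone do not perturb `repr`, and hence the
    // cache key.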
    auto canonical_graph = Canonicalize(graph, false);
    auto repr = canonical_graph->toString(false);

    std::lock_guard<std::mutex> guard(mutex_);
    // create new graph_cache_ids_ entry if none existed yet;
    if (graph_cache_ids_.count(repr) == 0) {
      int32_t kernel_id = getNextUniqueID();
      graph_cache_ids_[repr] = kernel_id;
      TORCH_CHECK(
          graph_cache_.emplace(kernel_id, std::make_unique<GraphCache>(graph))
              .second);
    }
    return graph_cache_ids_[repr];
  };

  // get fallback kernel id
  int32_t getFallbackKernelId() {
    std::lock_guard<std::mutex> guard(mutex_);
    return getNextUniqueID();
  }

  void unregisterCacheId(std::shared_ptr<Graph>& graph) {
    auto canonical_graph = Canonicalize(graph, false);
    auto repr = canonical_graph->toString(false);

    // erase the graph_cache_ids_ / graph_cache_ entries if they exist;
    if (graph_cache_ids_.count(repr) > 0) {
      int32_t kernel_id = graph_cache_ids_[repr];
      graph_cache_.erase(kernel_id);
      graph_cache_ids_.erase(repr);
    }
  }

  std::vector<at::Tensor> runFusionNode(
      int32_t kernel_id,
      const at::ArrayRef<IValue> inputs) {
    std::lock_guard<std::mutex> guard(mutex_);
    TORCH_INTERNAL_ASSERT(
        graph_cache_.count(kernel_id) > 0, "graph cache miss at run time");
    return graph_cache_[kernel_id]->runGraphWithInputs(inputs);
  }

  bool hasFallbackCode(int32_t kernel_id) {
    std::lock_guard<std::mutex> guard(mutex_);
    return fallback_cache_.count(kernel_id);
  }

  Code* getFallbackCode(int32_t kernel_id, const Node* fusion_node) {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      auto it = fallback_cache_.find(kernel_id);
      if (it != fallback_cache_.end()) {
        return it->second.get();
      }
    }

    std::unique_ptr<Code> code = createFallbackCode(fusion_node);

    std::lock_guard<std::mutex> guard(mutex_);
    auto it = fallback_cache_.insert({kernel_id, std::move(code)}).first;
    return it->second.get();
  }

 private:
  // TODO: Dimension collapsing should be abstracted out and integrated into
  // graph caching.

  // Dimension collapsing is only applicable to the profiling executor at the
  // moment.
  bool graphHasReduction(const std::shared_ptr<Graph>& graph) {
    for (const auto& n : graph->nodes()) {
      if (isReductionNode(n)) {
        return true;
      }
    }
    return false;
  }

 private:
  std::mutex mutex_;

  void runCudaKernel(
      int32_t key,
      const std::vector<int>& contiguity_tag,
      const c10::Device){};

  int32_t getNextUniqueID() {
    return next_unique_id_++;
  };

  std::unordered_map<std::string, int32_t> graph_cache_ids_;
  std::unordered_map<int64_t, std::unique_ptr<GraphCache>> graph_cache_;
  std::unordered_map<int64_t, std::unique_ptr<Code>> fallback_cache_;

  int32_t next_unique_id_ = 0;
};

} // namespace

void compileCudaFusionGroup(Node* fusion_node) {
  FUSER_PERF_SCOPE("nvFuser::Manager::compileCudaFusionGroup");

  TORCH_CHECK(
      fusion_node->kind() == prim::CudaFusionGroup,
      "Only prim::CudaFusionGroup can be compiled");
  if (fusion_node->hasAttribute(attr::cache_id)) {
    TORCH_WARN("Double registration of CudaFusionGroup on CudaFusionManager");
  }
  // This is not a critical code path; it's OK to do a graph copy here.
  auto graph = fusion_node->g(attr::Subgraph)->copy();

  auto compile_fusion = [&]() {
    // Type propagation is needed, as the protocol only requires scalar types
    // on input tensors.
    // Note that even for the profiling executor, scalar types could still be
    // missing, especially for the output tensors of a given node (a profiling
    // node only inserts meta information after itself).
    PropagateShapesOnGraph(graph);
    TypePropagate(graph);

    int32_t fusion_cache_id =
        CudaFusionManager::getManager().registerOrGetCacheId(graph);
    fusion_node->i_(attr::cache_id, fusion_cache_id);
  };

  if (useFallback()) {
    try {
      compile_fusion();
    } catch (...) {
      TORCH_WARN(
          "FALLBACK path has been taken inside: ",
          __FUNCTION__,
          ". This is an indication that codegen failed for some reason.\n"
          "To debug, try disabling the codegen fallback path by setting the env"
          " variable `export PYTORCH_NVFUSER_DISABLE=fallback`\n"
          "To report the issue, try enabling logging by setting the env"
          " variable `export PYTORCH_JIT_LOG_LEVEL=manager.cpp`\n");
      GRAPH_DUMP("`compile_fusion` hits fallback on graph\n", graph);
      CudaFusionManager::getManager().unregisterCacheId(graph);
    }
  } else {
    compile_fusion();
  }

  // Assigning a cache_id to facilitate graph execution and fallback
  if (!fusion_node->hasAttribute(attr::cache_id)) {
    int32_t fusion_cache_id =
        CudaFusionManager::getManager().getFallbackKernelId();
    fusion_node->i_(attr::cache_id, fusion_cache_id);
  }
}

void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
  FUSER_PERF_SCOPE("nvFuser::Manager::runCudaFusionGroup");
  TORCH_CHECK(
      fusion_node->hasAttribute(attr::cache_id),
      "node prim::CudaFusionGroup has not been compiled yet");

  // Fallback to use if anything goes wrong
  auto take_fallback = [&](Stack& stack) {
    int32_t kernel_id = fusion_node->i(attr::cache_id);
    Code* fallback_code =
        CudaFusionManager::getManager().getFallbackCode(kernel_id, fusion_node);
    InterpreterState{*fallback_code}.run(stack);
  };

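  // When the comparison callback requests a fallback run, keep a copy of the
  // incoming stack: the non-input items are copied as-is and the trailing
  // input items are deep-copied, presumably so that any in-place mutation
  // during the fused run cannot leak into the fallback run compared against
  // below.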
  c10::optional<Stack> stack_copy;
  auto compare_callback = getCudaFuserComparisonCallback();
  if (compare_callback.run_fallback) {
    // make a copy of the stack
    int64_t inputs_size =
        static_cast<int64_t>(fusion_node->g(attr::Subgraph)->inputs().size());
    TORCH_INTERNAL_ASSERT(int64_t(stack.size()) >= inputs_size);
    stack_copy = Stack();
    stack_copy->insert(
        stack_copy->end(), stack.begin(), stack.end() - inputs_size);
    // deepcopy the last (inputs_size) stack items
    std::transform(
        stack.end() - inputs_size,
        stack.end(),
        std::back_inserter(*stack_copy),
        [](const c10::IValue& ivalue) { return ivalue.deepcopy(); });
  }

  auto run_fusion = [&]() {
    TORCH_CHECK(
        fusion_node->kind() == prim::CudaFusionGroup,
        "prim::CudaFusionGroup expected");
    int32_t kernel_id = fusion_node->i(attr::cache_id);
    // Currently we just construct I/O tensors for static graph;
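    // The top nInputs entries on the stack are the fusion inputs: they are
    // consumed below and replaced by the tensors produced by the fusion.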

    const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();

    at::ArrayRef<IValue> inputs = last(stack, nInputs);

    auto outputs =
        CudaFusionManager::getManager().runFusionNode(kernel_id, inputs);

    drop(stack, inputs.size());
    stack.insert(
        stack.end(),
        std::make_move_iterator(outputs.begin()),
        std::make_move_iterator(outputs.end()));
  };

  if (useFallback()) {
    try {
      // If fusion failed once, it's likely to fail again, and failures are
      // slow, so once a fusion has failed we record the failure and always
      // take the fallback instead.
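      // A previously recorded failure shows up as an entry in fallback_cache_:
      // the catch block below calls take_fallback(), which populates the cache
      // through getFallbackCode(), so hasFallbackCode() returns true on
      // subsequent calls.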
      int32_t kernel_id = fusion_node->i(attr::cache_id);
      bool force_fallback =
          CudaFusionManager::getManager().hasFallbackCode(kernel_id);
      if (force_fallback) {
        take_fallback(stack);
      } else {
        run_fusion();
      }
    } catch (...) {
      TORCH_WARN(
          "FALLBACK path has been taken inside: ",
          __FUNCTION__,
          ". This is an indication that codegen failed for some reason.\n"
          "To debug, try disabling the codegen fallback path by setting the env"
          " variable `export PYTORCH_NVFUSER_DISABLE=fallback`\n");
      take_fallback(stack);
    }
  } else {
    run_fusion();
  }

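  // If a comparison callback was registered, run the fallback interpreter on
  // the copied stack and hand both sets of outputs (fused vs. fallback),
  // together with the subgraph string, to the callback.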
  if (compare_callback.callback != nullptr) {
    Stack fused_outputs;
    Stack fallback_outputs;
    int64_t output_count =
        static_cast<int64_t>(fusion_node->g(attr::Subgraph)->outputs().size());
    TORCH_CHECK(
        output_count <= int64_t(stack.size()),
        "Expected ",
        output_count,
        " outputs but found only ",
        stack.size(),
        " items on the stack");

    fused_outputs.insert(
        fused_outputs.begin(), stack.end() - output_count, stack.end());

    if (stack_copy) {
      take_fallback(*stack_copy);
      TORCH_CHECK(
          stack_copy->size() == stack.size(),
          "Fused graph returns stack with ",
          stack.size(),
          " items, compared to ",
          stack_copy->size(),
          " from unfused graph");
      fallback_outputs.insert(
          fallback_outputs.begin(),
          stack_copy->end() - output_count,
          stack_copy->end());
    }
    auto graph_str = fusion_node->g(attr::Subgraph)->toString();
    compare_callback.callback(fused_outputs, fallback_outputs, graph_str);
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch