#include <executor.h>
#include <fusion.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <kernel_cache.h>
#include <manager.h>
#include <parser.h>
#include <scheduler/all_schedulers.h>
#include <type_inference.h>
#include <utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <ATen/DimVector.h>
#include <c10/core/DeviceType.h>
#include <c10/util/irange.h>

#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! [ Note -- cache entry indexing ]
//!
//! CudaFusionManager holds the cache and handles interfacing to the
//! CudaFusionGroup node, including selection, construction and execution of
//! FusionExecutors.
//!
//! CudaFusionManager bridges the PyTorch IR node CudaFusionGroup to
//! GraphCache. Therefore, we want to cache on the stringified graph. But
//! since it is expensive to stringify and hash a computational graph, we
//! cache the hash of the stringified graph on the node via `attr::cache_id`.
//!
//! A CudaFusionGroup node stores:
//!     i.  a PyTorch IR in `attr::Subgraph`
//!     ii. an int in `attr::cache_id` (a cached hash value of
//!         `attr::Subgraph`)
//!
//! We keep 2 unordered_maps in CudaFusionManager:
//!   std::unordered_map<std::string, int32_t> graph_cache_ids_;
//!   std::unordered_map<int64_t, std::unique_ptr<GraphCache>> graph_cache_;
//!
//! The mapping from std::string to graph_cache_id ensures that we assign the
//! same cache_id to CudaFusionGroups with identical computational graphs,
//! allowing kernel reuse. The direct mapping from cache_id to GraphCache
//! allows efficient graph_cache indexing.
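//!
//! For example, two CudaFusionGroup nodes whose canonicalized subgraphs
//! stringify identically are assigned the same cache_id by
//! registerOrGetCacheId(), so they share a single GraphCache entry and the
//! kernels compiled for it.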

namespace {

// TODO remove this (75983):
// we don't need this anymore. I think we can use revertAliasCopyOps.
// A similar refactor should be done in the fallback graph used by the fusion
// guard. The implementation of xxxx_copy ops should be removed.
//
// Mark a string attribute on alias-copy nodes to enable their implementation
// in the fallback path.
void enableAliasCopyNodes(const std::shared_ptr<Graph>& graph, Block* block) {
  static std::unordered_set<Symbol> alias_copy_op(
      {prim::expand_copy,
       prim::expand_as_copy,
       prim::flatten_copy,
       prim::permute_copy,
       prim::reshape_copy,
       prim::squeeze_copy,
       prim::t_copy,
       prim::transpose_copy,
       prim::unsqueeze_copy,
       prim::view_copy});

  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      enableAliasCopyNodes(graph, b);
    }
    if (alias_copy_op.find(n->kind()) != alias_copy_op.end()) {
      n->s_(attr::name, "CudaFusionGroup");
    }
  }
}

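// Builds the interpreter fallback for a fusion node: copy its subgraph, erase
// the profiled shape information, re-enable the alias-copy ops, and compile
// the result into interpreter Code.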
static std::unique_ptr<Code> createFallbackCode(const Node* fusion_node) {
  auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
  EraseShapeInformation(copied_graph);
  enableAliasCopyNodes(copied_graph, copied_graph->block());
  auto code = std::make_unique<Code>(copied_graph, "fallback_cuda_fuser");
  return code;
}

// CudaFusionManager is not thread safe!
// TODO: we should make the tradeoff here to use thread_local instead of global
// singleton;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class CudaFusionManager {
 public:
  static CudaFusionManager& getManager() {
    static CudaFusionManager cuda_fusion_manager_;
    return cuda_fusion_manager_;
  };

  // TODO: I'm assuming we have stride information in `graph->toString`.
  // We need to make sure stride information is in the final string, as we
  // want to AVOID kernel reuse between different fusion_nodes, unless they
  // have identical contiguity information! (Identical stride + shape is even
  // more restrictive, which works in our favor.)
  int32_t registerOrGetCacheId(std::shared_ptr<Graph>& graph) {
    // prepare graph for lowering;
    // We should not call `EraseShapeInformation(graph);` here: the graph
    // representation does not incorporate static sizes, just the rank of the
    // input tensors, which is exactly what we want.
    auto canonical_graph = Canonicalize(graph, false);
    auto repr = canonical_graph->toString(false);

    std::lock_guard<std::mutex> guard(mutex_);
    // create a new graph_cache_ids_ entry if none exists yet;
    if (graph_cache_ids_.count(repr) == 0) {
      int32_t kernel_id = getNextUniqueID();
      graph_cache_ids_[repr] = kernel_id;
      TORCH_CHECK(
          graph_cache_.emplace(kernel_id, std::make_unique<GraphCache>(graph))
              .second);
    }
    return graph_cache_ids_[repr];
  };

  // get a fallback kernel id
  int32_t getFallbackKernelId() {
    std::lock_guard<std::mutex> guard(mutex_);
    return getNextUniqueID();
  }

  void unregisterCacheId(std::shared_ptr<Graph>& graph) {
    auto canonical_graph = Canonicalize(graph, false);
    auto repr = canonical_graph->toString(false);

    // drop the graph_cache_ids_ / graph_cache_ entries if they exist;
    if (graph_cache_ids_.count(repr) > 0) {
      int32_t kernel_id = graph_cache_ids_[repr];
      graph_cache_.erase(kernel_id);
      graph_cache_ids_.erase(repr);
    }
  }

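  // Looks up the GraphCache registered under `kernel_id` and runs it with
  // `inputs`, returning the fused kernel's output tensors.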
  std::vector<at::Tensor> runFusionNode(
      int32_t kernel_id,
      const at::ArrayRef<IValue> inputs) {
    std::lock_guard<std::mutex> guard(mutex_);
    TORCH_INTERNAL_ASSERT(
        graph_cache_.count(kernel_id) > 0, "graph cache miss at run time");
    return graph_cache_[kernel_id]->runGraphWithInputs(inputs);
  }

  bool hasFallbackCode(int32_t kernel_id) {
    std::lock_guard<std::mutex> guard(mutex_);
    return fallback_cache_.count(kernel_id);
  }

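  // Returns the cached interpreter fallback for `kernel_id`, building it from
  // `fusion_node` on first request. The Code is constructed outside the lock;
  // if another thread inserted an entry in the meantime, insert() keeps the
  // existing one and that entry is returned.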
  Code* getFallbackCode(int32_t kernel_id, const Node* fusion_node) {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      auto it = fallback_cache_.find(kernel_id);
      if (it != fallback_cache_.end()) {
        return it->second.get();
      }
    }

    std::unique_ptr<Code> code = createFallbackCode(fusion_node);

    std::lock_guard<std::mutex> guard(mutex_);
    auto it = fallback_cache_.insert({kernel_id, std::move(code)}).first;
    return it->second.get();
  }

 private:
  // TODO: Dimension collapsing should be abstracted out and integrated into
  // graph caching.

  // Dimension collapsing is only applicable to the profiling executor at this
  // moment.
  bool graphHasReduction(const std::shared_ptr<Graph>& graph) {
    for (const auto& n : graph->nodes()) {
      if (isReductionNode(n)) {
        return true;
      }
    }
    return false;
  }

 private:
  std::mutex mutex_;

  void runCudaKernel(
      int32_t key,
      const std::vector<int>& contiguity_tag,
      const c10::Device){};

  int32_t getNextUniqueID() {
    return next_unique_id_++;
  };

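  // graph_cache_ids_ maps a canonicalized graph string to its cache_id;
  // graph_cache_ maps a cache_id to its compiled GraphCache entry;
  // fallback_cache_ maps a cache_id to its interpreter fallback Code.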
  std::unordered_map<std::string, int32_t> graph_cache_ids_;
  std::unordered_map<int64_t, std::unique_ptr<GraphCache>> graph_cache_;
  std::unordered_map<int64_t, std::unique_ptr<Code>> fallback_cache_;

  int32_t next_unique_id_ = 0;
};

} // namespace

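// Registers the subgraph of `fusion_node` with CudaFusionManager and stamps
// the resulting cache_id onto the node. If compilation fails while fallback
// is enabled, the registration is undone and a fresh fallback kernel id is
// assigned instead, so runCudaFusionGroup can still execute the node.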
void compileCudaFusionGroup(Node* fusion_node) {
  FUSER_PERF_SCOPE("nvFuser::Manager::compileCudaFusionGroup");

  TORCH_CHECK(
      fusion_node->kind() == prim::CudaFusionGroup,
      "Only prim::CudaFusionGroup can be compiled");
  if (fusion_node->hasAttribute(attr::cache_id)) {
    TORCH_WARN("Double registration of CudaFusionGroup on CudaFusionManager");
  }
  // This is not a critical code path; it's OK to do a graph copy here.
  auto graph = fusion_node->g(attr::Subgraph)->copy();

  auto compile_fusion = [&]() {
    // type propagation is needed, as the protocol only requires scalar type
    // on input tensors.
    // Note that even for the Profiling Executor, scalar types could still be
    // missing, especially for the output tensor of a given node (as a
    // profiling node only inserts meta information after itself).
    PropagateShapesOnGraph(graph);
    TypePropagate(graph);

    int32_t fusion_cache_id =
        CudaFusionManager::getManager().registerOrGetCacheId(graph);
    fusion_node->i_(attr::cache_id, fusion_cache_id);
  };

  if (useFallback()) {
    try {
      compile_fusion();
    } catch (...) {
      TORCH_WARN(
          "FALLBACK path has been taken inside: ",
          __FUNCTION__,
          ". This is an indication that codegen failed for some reason.\n"
          "To debug, try disabling the codegen fallback path by setting the"
          " env variable `export PYTORCH_NVFUSER_DISABLE=fallback`\n"
          "To report the issue, try enabling logging by setting the env"
          " variable `export PYTORCH_JIT_LOG_LEVEL=manager.cpp`\n");
      GRAPH_DUMP("`compile_fusion` hits fallback on graph\n", graph);
      CudaFusionManager::getManager().unregisterCacheId(graph);
    }
  } else {
    compile_fusion();
  }

  // Assign a cache_id to facilitate graph execution and fallback
  if (!fusion_node->hasAttribute(attr::cache_id)) {
    int32_t fusion_cache_id =
        CudaFusionManager::getManager().getFallbackKernelId();
    fusion_node->i_(attr::cache_id, fusion_cache_id);
  }
}

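// Executes a compiled CudaFusionGroup node: the fusion inputs are taken from
// the stack, run through the cached GraphCache entry, and the outputs are
// pushed back onto the stack. With fallback enabled, any failure (or a
// previously recorded failure for this cache_id) routes execution through the
// interpreter fallback instead.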
void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
  FUSER_PERF_SCOPE("nvFuser::Manager::runCudaFusionGroup");
  TORCH_CHECK(
      fusion_node->hasAttribute(attr::cache_id),
      "node prim::CudaFusionGroup has not been compiled yet");

  // Fallback to use if anything goes wrong
  auto take_fallback = [&](Stack& stack) {
    int32_t kernel_id = fusion_node->i(attr::cache_id);
    Code* fallback_code =
        CudaFusionManager::getManager().getFallbackCode(kernel_id, fusion_node);
    InterpreterState{*fallback_code}.run(stack);
  };

  c10::optional<Stack> stack_copy;
  auto compare_callback = getCudaFuserComparisonCallback();
  if (compare_callback.run_fallback) {
    // make a copy of the stack
    int64_t inputs_size =
        static_cast<int64_t>(fusion_node->g(attr::Subgraph)->inputs().size());
    TORCH_INTERNAL_ASSERT(int64_t(stack.size()) >= inputs_size);
    stack_copy = Stack();
    stack_copy->insert(
        stack_copy->end(), stack.begin(), stack.end() - inputs_size);
    // deep-copy the last `inputs_size` stack items
    std::transform(
        stack.end() - inputs_size,
        stack.end(),
        std::back_inserter(*stack_copy),
        [](const c10::IValue& ivalue) { return ivalue.deepcopy(); });
  }

  auto run_fusion = [&]() {
    TORCH_CHECK(
        fusion_node->kind() == prim::CudaFusionGroup,
        "prim::CudaFusionGroup expected");
    int32_t kernel_id = fusion_node->i(attr::cache_id);
    // Currently we just construct I/O tensors for static graph;

    const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();

    at::ArrayRef<IValue> inputs = last(stack, nInputs);

    auto outputs =
        CudaFusionManager::getManager().runFusionNode(kernel_id, inputs);

    drop(stack, inputs.size());
    stack.insert(
        stack.end(),
        std::make_move_iterator(outputs.begin()),
        std::make_move_iterator(outputs.end()));
  };

  if (useFallback()) {
    try {
      // If fusion failed once, it's likely to fail again, and failures are
      // slow. So once a fusion has failed, record the failure and always use
      // the fallback instead.
      int32_t kernel_id = fusion_node->i(attr::cache_id);
      bool force_fallback =
          CudaFusionManager::getManager().hasFallbackCode(kernel_id);
      if (force_fallback) {
        take_fallback(stack);
      } else {
        run_fusion();
      }
    } catch (...) {
      TORCH_WARN(
          "FALLBACK path has been taken inside: ",
          __FUNCTION__,
          ". This is an indication that codegen failed for some reason.\n"
          "To debug, try disabling the codegen fallback path by setting the"
          " env variable `export PYTORCH_NVFUSER_DISABLE=fallback`\n");
      take_fallback(stack);
    }
  } else {
    run_fusion();
  }

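  // If a comparison callback is installed, collect the fused outputs and,
  // when a fallback run was requested, the fallback outputs computed from the
  // deep-copied inputs, then hand both to the callback along with the
  // subgraph's string representation.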
  if (compare_callback.callback != nullptr) {
    Stack fused_outputs;
    Stack fallback_outputs;
    int64_t output_count =
        static_cast<int64_t>(fusion_node->g(attr::Subgraph)->outputs().size());
    TORCH_CHECK(
        output_count <= int64_t(stack.size()),
        "Expected ",
        output_count,
        " outputs but found only ",
        stack.size(),
        " items on the stack");

    fused_outputs.insert(
        fused_outputs.begin(), stack.end() - output_count, stack.end());

    if (stack_copy) {
      take_fallback(*stack_copy);
      TORCH_CHECK(
          stack_copy->size() == stack.size(),
          "Fused graph returns stack with ",
          stack.size(),
          " items, compared to ",
          stack_copy->size(),
          " from unfused graph");
      fallback_outputs.insert(
          fallback_outputs.begin(),
          stack_copy->end() - output_count,
          stack_copy->end());
    }
    auto graph_str = fusion_node->g(attr::Subgraph)->toString();
    compare_callback.callback(fused_outputs, fallback_outputs, graph_str);
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch