#include <instrumentation.h>
#include <ir_iostream.h>
#include <kernel.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

IrBuilderPasskey::IrBuilderPasskey(IrContainer* ir_container)
    : ir_container_(ir_container) {}

namespace kir {

namespace {

//! Scan all primary expressions in the Kernel IR and build
//! lists of specialized nodes and other interesting information
class KernelIrScanner : private IrVisitor {
 public:
  explicit KernelIrScanner(const Kernel* kernel) {
    IrVisitor::handle(kernel->topLevelExprs());
    const auto gpu_lower = GpuLower::current();
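    // Collect (extent, factor) pairs for the splits that, judging by
    // nonDivisibleSplitInfo(), still need a divisibility check; they are
    // recorded in KernelSummary::splits_to_validate.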
    for (auto split : gpu_lower->nonDivisibleSplitInfo().splitsToValidate()) {
      auto extent = split->in()->extent();
      auto factor = split->factor();
      summary_.splits_to_validate.emplace_back(extent, factor);
    }
  }

  const auto& summary() const {
    return summary_;
  }

 private:
  using IrVisitor::handle;
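  // Dispatch on the expression itself, then visit its inputs and outputs so
  // that the Val handlers below (e.g. handle(TensorIndex*)) also run.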
  void handle(Expr* expr) final {
    IrVisitor::handle(expr);
    for (auto inp : expr->inputs()) {
      handle(inp);
    }
    for (auto out : expr->outputs()) {
      handle(out);
    }
  }
  void handle(BlockSync* sync) final {
    // TODO: Move to a dedicated validation pass
    // which is not on the common execution/compilation path
    if (sync->isWarHazardSync()) {
      ++summary_.war_hazard_syncs_count;
    }
  }

  void handle(GridSync* sync) final {
    summary_.has_cooperative_grid_reduction = true;
  }

  void handle(Allocate* allocate) final {
    switch (allocate->memoryType()) {
      case MemoryType::Global:
        summary_.global_allocations.push_back(allocate);
        break;
      case MemoryType::Shared:
        summary_.dynamic_smem_allocations.push_back(allocate);
        break;
      case MemoryType::Local:
        if (!ExpressionEvaluator::isConst(allocate->size())) {
          summary_.has_dynamic_local_memory_allocations = true;
          summary_.dynamic_lmem_allocations.emplace_back(allocate);
        }
        break;
    }
  }

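  // Track the largest RNG offset used by any RNG op in the kernel.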
  void handle(RNGOp* rng_op) final {
    summary_.max_rng_offsets =
        std::max<int>(summary_.max_rng_offsets, rng_op->getRNGOffset());
  }

  void handle(TensorIndex* tensor_index) final {
    const auto tv = tensor_index->view();
    const auto domain = tv->domain();
    // Do we have any reductions?
    summary_.has_block_reductions =
        summary_.has_block_reductions || domain->hasBlockReduction();

    // Update the largest smem data type
    if (domain->hasBlockReduction() || domain->hasGridReduction() ||
        tv->getMemoryType() == MemoryType::Shared) {
      const auto data_type = tv->dtype();
      const size_t type_size = dataTypeSize(data_type);
      if (type_size > max_smem_type_size_) {
        max_smem_type_size_ = type_size;
        summary_.largest_smem_data_type = data_type;
      }
    }
  }

  void handle(WelfordOp* welford_op) final {
    summary_.has_welford = true;
    TORCH_INTERNAL_ASSERT(welford_op->outAvg()->isA<TensorIndex>());
    auto out_dom = welford_op->outAvg()->as<TensorIndex>()->view()->domain();
    summary_.has_block_welford =
        summary_.has_block_welford || out_dom->hasBlockReduction();
  }

  void handle(GridWelford* grid_welford) final {
    summary_.has_welford = true;
    summary_.has_grid_welford = true;
    summary_.has_grid_reductions = true;
    if (grid_welford->welford_op()->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GridReduction* grid_reduction) final {
    summary_.has_grid_reductions = true;
    if (grid_reduction->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GroupedGridReduction* grid_reduction) final {
    summary_.has_grid_reductions = true;
    if (grid_reduction->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GroupedGridWelford* grid_welford) final {
    summary_.has_welford = true;
    summary_.has_grid_welford = true;
    summary_.has_grid_reductions = true;
    if (grid_welford->isAllreduce()) {
      summary_.has_cooperative_grid_reduction = true;
    }
  }

  void handle(GridBroadcast* grid_broadcast) final {
    summary_.has_cooperative_grid_reduction = true;
    handle(grid_broadcast->broadcast_op());
  }

  void handle(BroadcastOp* bop) final {
    const ParallelTypeBitmap parallel_types =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(
            bop->out()->as<TensorIndex>()->view());
    summary_.broadcast_parallel_types.emplace(bop, parallel_types);
    // Do we have block broadcasts?
    summary_.has_block_broadcasts =
        summary_.has_block_broadcasts || parallel_types.hasTID();
    // Do we have grid broadcasts?
    summary_.has_grid_broadcasts =
        summary_.has_grid_broadcasts || parallel_types.hasBID();
  }

 private:
  size_t max_smem_type_size_ = 0;
  KernelSummary summary_;
};

//! Make sure tensors have valid allocations even when parallelized
//! loops potentially have larger iteration counts than the number of
//! threads.
//!
//! When an IterDomain of a tensor is parallelized, the IterDomain
//! may not contribute to the allocation of the tensor. For example,
//! it is assumed that the allocation of a local-memory tensor does not
//! need to account for a parallelized IterDomain. This is true
//! when it is guaranteed that each thread only needs to execute the
//! loop body once. However, if not, the allocation is invalid as it
//! only has space for one value per thread.
//!
//! ValidateAllocation checks all tensor allocations and sees if any
//! tensor may have a parallelized loop whose iteration count may
//! be larger than the number of threads. If so, an error is thrown if
//! the tensor is not allocated on thread-shared memories. Note that
//! when allocated on a shared memory (i.e., MemoryType::Shared or
//! MemoryType::Global for tensors parallelized with threadIdx, or
//! MemoryType::Global for tensors parallelized with blockIdx), it is
//! assumed that the allocation is properly extended for the iteration
//! count.
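//!
//! As a hypothetical example: if a local-memory tensor has an IterDomain of
//! extent 2 * blockDim.x that is parallelized with threadIdx.x, each thread
//! executes the loop body twice, but the local buffer only holds a single
//! value per thread, so the tensor would need to be placed in shared or
//! global memory instead.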
class ValidateAllocation : private OptOutConstDispatch {
 public:
  static void validate(const Kernel* kernel) {
    ValidateAllocation validate_allocation(kernel);
  }

 private:
  explicit ValidateAllocation(const Kernel* kernel) {
    live_allocations_.emplace_back(std::vector<const Allocate*>());
    for (const auto& expr : kernel->topLevelExprs()) {
      OptOutConstDispatch::handle(expr);
    }
    live_allocations_.pop_back();
    TORCH_INTERNAL_ASSERT(live_allocations_.empty());
  }

  void handle(const Allocate* allocate) final {
    TORCH_INTERNAL_ASSERT(!live_allocations_.empty());
    live_allocations_.back().push_back(allocate);
  }

  // for_loop is parallelized and its stop value is not guaranteed to
  // be <= the number of threads, which breaks an assumption made
  // during allocation lowering if it's thread-parallel and not
  // allocated on shared or global memory, or if it's block-parallel
  // and not allocated on global memory.
  void validate(const ForLoop* for_loop) {
    const auto loop_id = for_loop->iter_domain();
    for (const auto& allocations : live_allocations_) {
      for (const auto& allocate : allocations) {
        const auto tv = dynamic_cast<TensorView*>(allocate->buffer());
        if (tv == nullptr) {
          continue;
        }
        for (const auto& axis : tv->domain()->domain()) {
          if (!GpuLower::current()->caMap()->areMapped(
                  loop_id, axis, IdMappingMode::LOOP)) {
            continue;
          }
          if (isParallelTypeThreadDim(loop_id->getParallelType())) {
            TORCH_INTERNAL_ASSERT(
                tv->getMemoryType() == MemoryType::Shared ||
                    tv->getMemoryType() == MemoryType::Global,
                "Tensor t",
                tv->name(),
                " must be allocated on SMEM or GMEM.");
          } else if (isParallelTypeBlockDim(loop_id->getParallelType())) {
            TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);
          }
        }
      }
    }
  }

  void handle(const ForLoop* for_loop) final {
    if (for_loop->stop() != for_loop->iter_domain()->extent() &&
        isParallelTypeThread(for_loop->iter_domain()->getParallelType())) {
      validate(for_loop);
    }

    live_allocations_.emplace_back(std::vector<const Allocate*>());
    for (const auto& expr : for_loop->body().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
    live_allocations_.pop_back();
  }

  void handle(const IfThenElse* ite) final {
    for (const auto& expr : ite->thenBody().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
    for (const auto& expr : ite->elseBody().exprs()) {
      OptOutConstDispatch::handle(expr);
    }
  }

 private:
  std::vector<std::vector<const Allocate*>> live_allocations_;
};

} // namespace

// TODO(kir): Kernel IR validation
void Kernel::finalize(std::vector<Expr*> top_level_exprs) {
  TORCH_INTERNAL_ASSERT(top_level_exprs_.empty());
  top_level_exprs_ = std::move(top_level_exprs);
  warp_padded_parallel_info_ = GpuLower::current()->getWarpPaddedParallelInfo();
  profile_ = GpuLower::current()->profile();
  ValidateAllocation::validate(this);
  analyze();
  // Make sure this is after analyze as it sets summary_
  summary_.vectorized_accesses = GpuLower::current()->vectorizedAccesses();
  summary_.vectorized_set_info = GpuLower::current()->vectorizedSetInfo();
  summary_.sync_map = GpuLower::current()->syncMap();
  summary_.parallel_dimension_map_ =
      GpuLower::current()->parallelDimensionMap();
}

void Kernel::analyze() {
  FUSER_PERF_SCOPE("Kernel::analyze");

  const KernelIrScanner ir_scanner(this);
  summary_ = ir_scanner.summary();
}

void Kernel::print() const {
  IrPrinter ir_printer(std::cout);
  ir_printer.handle(this);
}

//! Register the Val with this fusion
void Kernel::registerVal(Val* val) {
  if (inContainer(val)) {
    return;
  }
  if (val->kernel()) {
    TORCH_CHECK(
        val->kernel() == this,
        val->toString(),
        " was not found in the active kernel.");
  }

  Fusion::registerVal(val);
}

//! Register expr with this fusion.
//! When we register an expression, we want to update the dependency tracking
//! of Vals. We add expr to our general expr_set_.
void Kernel::registerExpr(Expr* expr) {
  if (inContainer(expr)) {
    return;
  }

  if (expr->kernel()) {
    TORCH_CHECK(
        expr->kernel() == this,
        expr->toString(),
        " was not found in the active kernel.");
  }

  for (Val* input : expr->inputs()) {
    TORCH_INTERNAL_ASSERT(
        inContainer(input),
        "Input\n",
        input->toString(),
        " to expr,\n",
        expr->toString(),
        ",\n is invalid because it is not in the same kernel.");
  }

  for (Val* output : expr->outputs()) {
    TORCH_INTERNAL_ASSERT(
        inContainer(output),
        "Output\n",
        output->toString(),
        " to expr,\n",
        expr->toString(),
        ",\n is invalid because it is not in the same kernel.");
  }

  // Register expr is explicitly non-SSA when coming from a kernel. This is
  // detected inside Fusion::registerExpr
  Fusion::registerExpr(expr);
}

std::vector<Expr*>& KernelInternalProxy::topLevelExprs() {
  return kernel_->top_level_exprs_;
}

void KernelPerformanceProfile::registerExpr(const Expr* expr) {
  if (expr_entry_map_.find(expr) != expr_entry_map_.end()) {
    return;
  }

  auto slot = getNewIndex();
  expr_entry_map_.emplace(expr, slot);
}

int KernelPerformanceProfile::getNewIndex() {
  return num_profile_entries_++;
}

bool KernelPerformanceProfile::isProfiled(const Expr* expr) const {
  return expr_entry_map_.find(expr) != expr_entry_map_.end();
}

c10::optional<int> KernelPerformanceProfile::getIndex(const Expr* expr) const {
  auto it = expr_entry_map_.find(expr);
  if (it == expr_entry_map_.end()) {
    return c10::optional<int>();
  } else {
    return it->second;
  }
}

std::array<int, 2> KernelPerformanceProfile::getIndicesInProfileBuffer(
    const Expr* expr) const {
  TORCH_INTERNAL_ASSERT(
      isProfiled(expr), "Not a profiled expression: ", expr->toString());

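  // Each profiled expression occupies two consecutive slots in the profile
  // buffer: the accumulated cycle count followed by the call count.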
  int cycle_index = getIndex(expr).value() * 2;
  int count_index = cycle_index + 1;

  return {cycle_index, count_index};
}

std::string KernelPerformanceProfile::toString(const at::Tensor& buffer) const {
  std::stringstream ss;
  ss << "Kernel performance profile:\n";
  if (!buffer.defined()) {
    ss << "No profile found\n";
    return ss.str();
  }

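  // clockRate is reported in kHz, so cycles / kilo_freq yields milliseconds;
  // the extra factor of 1000.0 below converts the per-call time to
  // microseconds.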
  double kilo_freq = at::cuda::getCurrentDeviceProperties()->clockRate;

  ss << std::setprecision(3) << std::fixed;

  for (const auto& kv : expr_entry_map_) {
    auto expr = kv.first;
    auto index = kv.second;
    auto out_tv = ir_utils::getTvOutput(expr);
    double cycles = static_cast<double>(buffer[index][0].item<int64_t>());
    auto count = buffer[index][1].item<int64_t>();
    auto cycles_per_call = count == 0 ? 0.0 : cycles / count;
    auto us_per_call = cycles_per_call / kilo_freq * 1000.0;
    ss << expr->getExprType().value() << ", T" << out_tv->name() << ", "
       << us_per_call << " us, " << count << "\n";
  }

  return ss.str();
}

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch