1#pragma once
2
#include <ir_all_nodes.h>

#include <unordered_set>
4
5namespace torch {
6namespace jit {
7namespace fuser {
8namespace cuda {
9
10//! Keep track of certain patterns of reductions.
11//!
12//! - Allreduce IterDomain: reduced and broadcast domain.
13class FusedReductionInfo {
14 public:
15 void markAsAllreduce(IterDomain* id);
16
17 bool isAllreduce(IterDomain* id) const;
18
19 private:
20 // Reduction IterDomains that are also broadcast
21 std::unordered_set<IterDomain*> allreduce_ids_;
22};
23
24//! Detect reductions and broadcasts that are eligible for the fused
25//! reduction kernel. When found, the predicate flags of the broadcast
26//! is unset, which effectively makes the broadcast just a unary set
27//! op.
28//! TODO: Consider moving the warp-based fused reduction here.
29void fuseReductionsAndBroadcasts(Fusion*);
30
31} // namespace cuda
32} // namespace fuser
33} // namespace jit
34} // namespace torch
35