1#pragma once
2
#include <unordered_map>

#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11
12class ContigIDs;
13
14void validateIr(Fusion* fusion);
15
16//! Validate vectorization and collect information on vectorization
17//! used in code generation as well as runtime validation.
18void validateAndCollectVectorizeInfo(Fusion* fusion);
19
20//! Find the contig root domains that a vectorized leaf domain
21//! of a consumer TV depends on. Required for runtime validation.
22void fillConsumerVectorizedContigRootDomains(
23 const TensorView* consumer_tv,
24 const ContigIDs& contig_finder);
25
26//! Find the contig root domains that a vectorized leaf domain
27//! of a producer TV depends on. Required for runtime validation.
28//! Producer must be transformed as consumer.
29void fillProducerVectorizedContigRootDomains(
30 const TensorView* producer_tv,
31 const TensorView* consumer_tv,
32 const std::unordered_map<IterDomain*, IterDomain*>& c2p_map,
33 const ContigIDs& contig_finder);
34
35//! Validates partial split expressions. Partial split only uses an
36//! inner subdomain specified by start and stop offsets, ignoring the
37//! values outside the range. It's designed to be used with non-padded
38//! shift, which introduces non-zero start and stop smaller than the
39//! extent. This function makes sure all tensors have all values
40//! calculated that are necessary for output values.
41void validatePartialSplit(Fusion* fusion);
42
43//! Validate data format and GPU arch compatibility of scheduled
44//! mma operators on the fusion.
45void validateMma(Fusion* fusion);
46
47//! Validates swizzle ops to ensure consistent indexing:
48//! - Currently only allow swizzle ops on the right of CA axis,
49//! - (Except ZShape) All swizzle ops have to be on const sized ids
50//! - Xor and Transpose swizzle have to have equal dimensions on the
51//! participating ids.
52void validateSwizzle(Fusion* fusion);
53
54//! Validate use of ParallelType::Group. It is currently only allowed
55//! in ReductionOp and not in WelfordOp. Group has similar constraints
56//! as Vectorize, e.g., it can only be used with IterDomains with
57//! static extents. Differences are, e.g, it has no constraints on
58//! alignments and predicates. Each individual reduction has its own
59//! predicate, so it is possile for only part of grouped reductions to
60//! be executed.
61//!
62//! Also, grouping is only enabled for persistent grid reductions, in
63//! other words, grid allreduces. Note that no grid reduction without
64//! broadcast is persistent anymore.
65//!
66//! Validated ReductionOp with ParallelType::Group is converted to
67//! GroupedReductionOp.
68void validateAndConvertIterDomainGrouping(Fusion* fusion);
69
70//! Validate the number of grouped reductions is within the limit
71void validateGroupedReductions(Fusion* fusion);
72
73} // namespace cuda
74} // namespace fuser
75} // namespace jit
76} // namespace torch
77