1#pragma once
2
3#include <inlining.h>
4#include <root_domain_map.h>
5#include <transform_replay.h>
6
7#include <c10/macros/Export.h>
8#include <c10/util/Exception.h>
9
10#include <deque>
11#include <unordered_map>
12#include <unordered_set>
13#include <vector>
14
15namespace torch {
16namespace jit {
17namespace fuser {
18namespace cuda {
19
20class TensorDomain;
21class TensorView;
22
23struct ComputeAt {
24 public:
25 // Runs the compute at pass making producer look like consumer, computing
26 // producer relative to consumer
27 static void runAt(
28 TensorView* producer,
29 TensorView* consumer,
30 int64_t consumer_position,
31 ComputeAtMode mode = ComputeAtMode::Standard);
32
33 // Runs the compute with pass making consumer look like producer, computing
34 // producer relative to consumer
35 static void runWith(
36 TensorView* producer,
37 TensorView* consumer,
38 int64_t producer_position,
39 ComputeAtMode mode = ComputeAtMode::Standard);
40};
41
42} // namespace cuda
43} // namespace fuser
44} // namespace jit
45} // namespace torch
46