1#pragma once
2#include <manager.h>
3#include <transform_view.h>
4
5#include <c10/macros/Export.h>
6#include <torch/csrc/jit/codegen/cuda/interface.h>
7#include <torch/csrc/jit/ir/ir.h>
8#include <torch/csrc/jit/passes/pass_manager.h>
9#include <torch/csrc/jit/runtime/profiling_record.h>
10
/*
 * This file contains the APIs for the CUDA fuser.
 *
 * We use an empty static struct to hold the function pointers, which are
 * registered separately. This is to support CPU-only compilation.
 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp
 */
18
19namespace torch {
20namespace jit {
21namespace fuser {
22namespace cuda {
23
24TORCH_CUDA_CU_API bool complyWith(
25 const at::Tensor& tensor,
26 const c10::TensorTypePtr& guard_tensor_type);
27
28struct TORCH_CUDA_CU_API NVFuserPassManager
29 : public PassManager<NVFuserPassManager> {
30 static bool registerPass(bool enabled) {
31 bool old_value = PassManager::isRegistered();
32 if (enabled) {
33 PassManager::registerPass(fuseGraph);
34 } else {
35 PassManager::clearPass();
36 }
37 return old_value;
38 }
39
40 static bool isRegistered() {
41 return PassManager::isRegistered();
42 }
43};
44
45} // namespace cuda
46} // namespace fuser
47} // namespace jit
48} // namespace torch
49