1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * Build CUDA modules from source.
 * Requires CUDA to be available.
23 | * |
24 | * \file build_cuda.cc |
25 | */ |
26 | #if defined(__linux__) |
27 | #include <sys/stat.h> |
28 | #endif |
29 | #include <cuda_runtime.h> |
30 | #include <nvrtc.h> |
31 | |
32 | #include <cstdlib> |
33 | |
34 | #include "../../runtime/cuda/cuda_common.h" |
35 | #include "../../runtime/cuda/cuda_module.h" |
36 | #include "../build_common.h" |
37 | #include "../source/codegen_cuda.h" |
38 | |
39 | namespace tvm { |
40 | namespace codegen { |
41 | |
/*!
 * \brief Evaluate an NVRTC API call and log a fatal error unless it returns NVRTC_SUCCESS.
 *
 * The body is wrapped in `do { ... } while (0)` so that `NVRTC_CALL(...);` expands to
 * exactly one statement and therefore composes correctly with unbraced `if`/`else`.
 * The local that stores the status uses a deliberately unusual name so that it cannot
 * shadow a variable referenced inside the argument expression.
 *
 * \param x Expression yielding an nvrtcResult; it is evaluated exactly once.
 */
#define NVRTC_CALL(x)                                             \
  do {                                                            \
    const nvrtcResult nvrtc_call_result_ = (x);                   \
    if (nvrtc_call_result_ != NVRTC_SUCCESS) {                    \
      LOG(FATAL) << "NvrtcError: " #x " failed with error: "      \
                 << nvrtcGetErrorString(nvrtc_call_result_);      \
    }                                                             \
  } while (0)
49 | |
50 | std::string FindCUDAIncludePath() { |
51 | #if defined(_WIN32) |
52 | const std::string delimiter = "\\" ; |
53 | #else |
54 | const std::string delimiter = "/" ; |
55 | #endif |
56 | std::string cuda_include_path; |
57 | const char* cuda_path_env = std::getenv("CUDA_PATH" ); |
58 | if (cuda_path_env != nullptr) { |
59 | cuda_include_path += cuda_path_env; |
60 | cuda_include_path += delimiter + "include" ; |
61 | return cuda_include_path; |
62 | } |
63 | |
64 | #if defined(__linux__) |
65 | struct stat st; |
66 | cuda_include_path = "/usr/local/cuda/include" ; |
67 | if (stat(cuda_include_path.c_str(), &st) == 0) { |
68 | return cuda_include_path; |
69 | } |
70 | |
71 | if (stat("/usr/include/cuda.h" , &st) == 0) { |
72 | return "/usr/include" ; |
73 | } |
74 | #endif |
75 | LOG(FATAL) << "Cannot find cuda include path." |
76 | << "CUDA_PATH is not set or CUDA is not installed in the default installation path." |
77 | << "In other than linux, it is necessary to set CUDA_PATH." ; |
78 | return cuda_include_path; |
79 | } |
80 | |
81 | std::string NVRTCCompile(const std::string& code, bool include_path = false) { |
82 | std::vector<std::string> compile_params; |
83 | std::vector<const char*> param_cstrings{}; |
84 | nvrtcProgram prog; |
85 | std::string cc = "30" ; |
86 | int major, minor; |
87 | cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); |
88 | cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); |
89 | |
90 | if (e1 == cudaSuccess && e2 == cudaSuccess) { |
91 | cc = std::to_string(major) + std::to_string(minor); |
92 | } else { |
93 | LOG(WARNING) << "cannot detect compute capability from your device, " |
94 | << "fall back to compute_30." ; |
95 | } |
96 | |
97 | compile_params.push_back("-arch=compute_" + cc); |
98 | |
99 | if (include_path) { |
100 | std::string include_option = "--include-path=" + FindCUDAIncludePath(); |
101 | |
102 | compile_params.push_back(include_option); |
103 | } |
104 | |
105 | for (const auto& string : compile_params) { |
106 | param_cstrings.push_back(string.c_str()); |
107 | } |
108 | NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); |
109 | nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); |
110 | |
111 | size_t log_size; |
112 | NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); |
113 | std::string log; |
114 | log.resize(log_size); |
115 | NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); |
116 | ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log; |
117 | size_t ptx_size; |
118 | NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); |
119 | |
120 | std::string ptx; |
121 | ptx.resize(ptx_size); |
122 | NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0])); |
123 | NVRTC_CALL(nvrtcDestroyProgram(&prog)); |
124 | |
125 | return ptx; |
126 | } |
127 | |
128 | runtime::Module BuildCUDA(IRModule mod, Target target) { |
129 | using tvm::runtime::Registry; |
130 | bool output_ssa = false; |
131 | CodeGenCUDA cg; |
132 | cg.Init(output_ssa); |
133 | |
134 | for (auto kv : mod->functions) { |
135 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc" ; |
136 | auto f = Downcast<PrimFunc>(kv.second); |
137 | auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); |
138 | ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) |
139 | << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch" ; |
140 | cg.AddFunction(f); |
141 | } |
142 | |
143 | std::string code = cg.Finish(); |
144 | |
145 | if (const auto* f = Registry::Get("tvm_callback_cuda_postproc" )) { |
146 | code = (*f)(code).operator std::string(); |
147 | } |
148 | std::string fmt = "ptx" ; |
149 | std::string ptx; |
150 | const auto* f_enter = Registry::Get("target.TargetEnterScope" ); |
151 | (*f_enter)(target); |
152 | if (const auto* f = Registry::Get("tvm_callback_cuda_compile" )) { |
153 | ptx = (*f)(code).operator std::string(); |
154 | // Dirty matching to check PTX vs cubin. |
155 | // TODO(tqchen) more reliable checks |
156 | if (ptx[0] != '/') fmt = "cubin" ; |
157 | } else { |
158 | ptx = NVRTCCompile(code, cg.need_include_path()); |
159 | } |
160 | const auto* f_exit = Registry::Get("target.TargetExitScope" ); |
161 | (*f_exit)(target); |
162 | return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); |
163 | } |
164 | |
165 | TVM_REGISTER_GLOBAL("target.build.cuda" ).set_body_typed(BuildCUDA); |
166 | } // namespace codegen |
167 | } // namespace tvm |
168 | |