1#include "taichi/cache/gfx/cache_manager.h"
2#include "taichi/analysis/offline_cache_util.h"
3#include "taichi/codegen/spirv/snode_struct_compiler.h"
4#include "taichi/common/cleanup.h"
5#include "taichi/common/version.h"
6#include "taichi/program/kernel.h"
7#include "taichi/runtime/gfx/aot_module_loader_impl.h"
8#include "taichi/runtime/gfx/snode_tree_manager.h"
9#include "taichi/util/lock.h"
10#include "taichi/util/offline_cache.h"
11
12namespace taichi::lang {
13
14namespace {
15
// File names used inside one offline-cache directory.

// Lock file held (via lock_with_file) while the metadata files are accessed.
constexpr char kMetadataFileLockName[] = "metadata.lock";

// Offline-cache bookkeeping (CacheManager::Metadata), stored in binary form.
constexpr char kOfflineCacheMetadataFilename[] = "offline_cache_metadata.tcb";

// AOT-module files kept next to the cached kernels. All three are deleted by
// CacheCleanerUtils::remove_other_files().
constexpr char kAotMetadataFilename[] = "metadata.tcb";
constexpr char kDebuggingAotMetadataFilename[] = "metadata.json";
constexpr char kGraphMetadataFilename[] = "graphs.tcb";
// Shorthand for the GFX runtime's RegisterParams type. It is the result type
// of load_or_compile(), try_load_cached_kernel() (wrapped in std::optional)
// and compile_and_cache_kernel() below.
using CompiledKernelData = gfx::GfxRuntime::RegisterParams;
22
23inline gfx::CacheManager::Metadata::KernelMetadata make_kernel_metadata(
24 const std::string &key,
25 const gfx::GfxRuntime::RegisterParams &compiled) {
26 std::size_t codes_size = 0;
27 for (const auto &e : compiled.task_spirv_source_codes) {
28 codes_size += e.size() * sizeof(*e.data());
29 }
30
31 gfx::CacheManager::Metadata::KernelMetadata res;
32 res.kernel_key = key;
33 res.size = codes_size;
34 res.created_at = std::time(nullptr);
35 res.last_used_at = std::time(nullptr);
36 res.num_files = compiled.task_spirv_source_codes.size();
37 return res;
38}
39
40} // namespace
41
42namespace offline_cache {
43
44template <>
45struct CacheCleanerUtils<gfx::CacheManager::Metadata> {
46 using MetadataType = gfx::CacheManager::Metadata;
47 using KernelMetaData = MetadataType::KernelMetadata;
48
49 // To save metadata as file
50 static bool save_metadata(const CacheCleanerConfig &config,
51 const MetadataType &data) {
52 // Update AOT metadata
53 gfx::TaichiAotData old_aot_data, new_aot_data;
54 auto aot_metadata_path =
55 taichi::join_path(config.path, kAotMetadataFilename);
56 if (read_from_binary_file(old_aot_data, aot_metadata_path)) {
57 const auto &kernels = data.kernels;
58 for (auto &k : old_aot_data.kernels) {
59 if (kernels.count(k.name)) {
60 new_aot_data.kernels.push_back(std::move(k));
61 }
62 }
63 write_to_binary_file(new_aot_data, aot_metadata_path);
64 }
65 write_to_binary_file(
66 data, taichi::join_path(config.path, config.metadata_filename));
67 return true;
68 }
69
70 static bool save_debugging_metadata(const CacheCleanerConfig &config,
71 const MetadataType &data) {
72 // Do nothing
73 return true;
74 }
75
76 // To get cache files name
77 static std::vector<std::string> get_cache_files(
78 const CacheCleanerConfig &config,
79 const KernelMetaData &kernel_meta) {
80 std::vector<std::string> result;
81 for (std::size_t i = 0; i < kernel_meta.num_files; ++i) {
82 result.push_back(kernel_meta.kernel_key + std::to_string(i) + "." +
83 kSpirvCacheFilenameExt);
84 }
85 return result;
86 }
87
88 // To remove other files except cache files and offline cache metadta files
89 static void remove_other_files(const CacheCleanerConfig &config) {
90 taichi::remove(taichi::join_path(config.path, kAotMetadataFilename));
91 taichi::remove(
92 taichi::join_path(config.path, kDebuggingAotMetadataFilename));
93 taichi::remove(taichi::join_path(config.path, kGraphMetadataFilename));
94 }
95
96 // To check if a file is cache file
97 static bool is_valid_cache_file(const CacheCleanerConfig &config,
98 const std::string &name) {
99 return filename_extension(name) == kSpirvCacheFilenameExt;
100 }
101};
102
103} // namespace offline_cache
104
105namespace gfx {
106
107CacheManager::CacheManager(Params &&init_params)
108 : mode_(init_params.mode),
109 runtime_(init_params.runtime),
110 compile_config_(*init_params.compile_config),
111 compiled_structs_(*init_params.compiled_structs) {
112 TI_ASSERT(init_params.runtime);
113 TI_ASSERT(init_params.compile_config);
114 TI_ASSERT(init_params.compiled_structs);
115
116 path_ = offline_cache::get_cache_path_by_arch(init_params.cache_path,
117 init_params.arch);
118 { // Load cached module with checking
119 using Error = offline_cache::LoadMetadataError;
120 using offline_cache::load_metadata_with_checking;
121 Metadata tmp;
122 auto filepath = taichi::join_path(path_, kOfflineCacheMetadataFilename);
123 if (load_metadata_with_checking(tmp, filepath) == Error::kNoError) {
124 auto lock_path = taichi::join_path(path_, kMetadataFileLockName);
125 auto exists =
126 taichi::path_exists(taichi::join_path(path_, kAotMetadataFilename)) &&
127 taichi::path_exists(taichi::join_path(path_, kGraphMetadataFilename));
128 if (exists) {
129 if (lock_with_file(lock_path)) {
130 auto _ = make_cleanup([&lock_path]() {
131 if (!unlock_with_file(lock_path)) {
132 TI_WARN(
133 "Unlock {} failed. You can remove this .lock file manually "
134 "and try again.",
135 lock_path);
136 }
137 });
138 gfx::AotModuleParams params;
139 params.module_path = path_;
140 params.runtime = runtime_;
141 params.enable_lazy_loading = true;
142 cached_module_ = gfx::make_aot_module(params, init_params.arch);
143 } else {
144 TI_WARN(
145 "Lock {} failed. You can run 'ti cache clean -p {}' and try "
146 "again.",
147 lock_path, path_);
148 }
149 }
150 }
151 }
152
153 caching_module_builder_ = std::make_unique<gfx::AotModuleBuilderImpl>(
154 compiled_structs_, init_params.arch, compile_config_,
155 std::move(init_params.caps));
156
157 offline_cache_metadata_.version[0] = TI_VERSION_MAJOR;
158 offline_cache_metadata_.version[1] = TI_VERSION_MINOR;
159 offline_cache_metadata_.version[2] = TI_VERSION_PATCH;
160}
161
162CompiledKernelData CacheManager::load_or_compile(const CompileConfig &config,
163 Kernel *kernel) {
164 if (kernel->is_evaluator) {
165 spirv::lower(config, kernel);
166 return gfx::run_codegen(kernel, runtime_->get_ti_device()->arch(),
167 runtime_->get_ti_device()->get_caps(),
168 compiled_structs_, config);
169 }
170 std::string kernel_key = make_kernel_key(config, kernel);
171 if (mode_ > NotCache) {
172 if (auto opt = this->try_load_cached_kernel(kernel, kernel_key)) {
173 return *opt;
174 }
175 }
176 return this->compile_and_cache_kernel(kernel_key, kernel);
177}
178
179void CacheManager::dump_with_merging() const {
180 if (mode_ == MemAndDiskCache && !offline_cache_metadata_.kernels.empty()) {
181 taichi::create_directories(path_);
182 auto *cache_builder =
183 static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
184 cache_builder->mangle_aot_data();
185
186 auto lock_path = taichi::join_path(path_, kMetadataFileLockName);
187 if (lock_with_file(lock_path)) {
188 auto _ = make_cleanup([&lock_path]() {
189 if (!unlock_with_file(lock_path)) {
190 TI_WARN(
191 "Unlock {} failed. You can remove this .lock file manually and "
192 "try again.",
193 lock_path);
194 }
195 });
196
197 // Update metadata.{tcb,json}
198 cache_builder->merge_with_old_meta_data(path_);
199 cache_builder->dump(path_, "");
200
201 // Update offline_cache_metadata.tcb
202 using offline_cache::load_metadata_with_checking;
203 using Error = offline_cache::LoadMetadataError;
204 Metadata old_data;
205 const auto filename =
206 taichi::join_path(path_, kOfflineCacheMetadataFilename);
207 if (load_metadata_with_checking(old_data, filename) == Error::kNoError) {
208 for (auto &[k, v] : offline_cache_metadata_.kernels) {
209 auto iter = old_data.kernels.find(k);
210 if (iter != old_data.kernels.end()) { // Update
211 iter->second.last_used_at = v.last_used_at;
212 } else { // Add new
213 old_data.size += v.size;
214 old_data.kernels[k] = std::move(v);
215 }
216 }
217 write_to_binary_file(old_data, filename);
218 } else {
219 write_to_binary_file(offline_cache_metadata_, filename);
220 }
221 }
222 }
223}
224
225void CacheManager::clean_offline_cache(offline_cache::CleanCachePolicy policy,
226 int max_bytes,
227 double cleaning_factor) const {
228 if (mode_ == MemAndDiskCache) {
229 using CacheCleaner = offline_cache::CacheCleaner<Metadata>;
230 offline_cache::CacheCleanerConfig params;
231 params.path = path_;
232 params.policy = policy;
233 params.cleaning_factor = cleaning_factor;
234 params.max_size = max_bytes;
235 params.metadata_filename = kOfflineCacheMetadataFilename;
236 params.debugging_metadata_filename = ""; // No debugging file
237 params.metadata_lock_name = kMetadataFileLockName;
238 CacheCleaner::run(params);
239 }
240}
241
242std::optional<CompiledKernelData> CacheManager::try_load_cached_kernel(
243 Kernel *kernel,
244 const std::string &key) {
245 if (mode_ == NotCache) {
246 return std::nullopt;
247 }
248 // Find in memory-cache
249 auto *cache_builder =
250 static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
251 auto params_opt = cache_builder->try_get_kernel_register_params(key);
252 if (params_opt.has_value()) {
253 TI_DEBUG("Create kernel '{}' from in-memory cache (key='{}')",
254 kernel->get_name(), key);
255 // TODO: Support multiple SNodeTrees in AOT.
256 params_opt->num_snode_trees = compiled_structs_.size();
257 return params_opt;
258 }
259 // Find in disk-cache
260 if (mode_ == MemAndDiskCache && cached_module_) {
261 if (auto *aot_kernel = cached_module_->get_kernel(key)) {
262 TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
263 key);
264 auto *aot_kernel_impl = static_cast<gfx::KernelImpl *>(aot_kernel);
265 auto compiled = aot_kernel_impl->params();
266 // TODO: Support multiple SNodeTrees in AOT.
267 compiled.num_snode_trees = compiled_structs_.size();
268 auto kmetadata = make_kernel_metadata(key, compiled);
269 offline_cache_metadata_.size += kmetadata.size;
270 offline_cache_metadata_.kernels[key] = std::move(kmetadata);
271 return compiled;
272 }
273 }
274 return std::nullopt;
275}
276
277CompiledKernelData CacheManager::compile_and_cache_kernel(
278 const std::string &key,
279 Kernel *kernel) {
280 TI_DEBUG_IF(mode_ == MemAndDiskCache, "Cache kernel '{}' (key='{}')",
281 kernel->get_name(), key);
282 auto *cache_builder =
283 static_cast<gfx::AotModuleBuilderImpl *>(caching_module_builder_.get());
284 TI_ASSERT(cache_builder != nullptr);
285 cache_builder->add(key, kernel);
286 auto params_opt = cache_builder->try_get_kernel_register_params(key);
287 TI_ASSERT(params_opt.has_value());
288 // TODO: Support multiple SNodeTrees in AOT.
289 params_opt->num_snode_trees = compiled_structs_.size();
290 auto kmetadata = make_kernel_metadata(key, *params_opt);
291 offline_cache_metadata_.size += kmetadata.size;
292 offline_cache_metadata_.kernels[key] = std::move(kmetadata);
293 return *params_opt;
294}
295
296std::string CacheManager::make_kernel_key(const CompileConfig &config,
297 Kernel *kernel) const {
298 if (mode_ < MemAndDiskCache) {
299 return kernel->get_name();
300 }
301 auto key = kernel->get_cached_kernel_key();
302 if (key.empty()) {
303 key = get_hashed_offline_cache_key(config, kernel);
304 kernel->set_kernel_key_for_cache(key);
305 }
306 return key;
307}
308
309} // namespace gfx
310} // namespace taichi::lang
311