1#include "offline_cache_util.h"
2
3#include "taichi/common/core.h"
4#include "taichi/common/serialization.h"
5#include "taichi/ir/snode.h"
6#include "taichi/ir/transforms.h"
7#include "taichi/program/compile_config.h"
8#include "taichi/program/kernel.h"
9
10#include "picosha2.h"
11
12#include <vector>
13
14namespace taichi::lang {
15
16static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
17 const CompileConfig &config) {
18 BinaryOutputSerializer serializer;
19 serializer.initialize();
20 serializer(config.arch);
21 serializer(config.debug);
22 serializer(config.cfg_optimization);
23 serializer(config.check_out_of_bound);
24 serializer(config.opt_level);
25 serializer(config.external_optimization_level);
26 serializer(config.move_loop_invariant_outside_if);
27 serializer(config.demote_dense_struct_fors);
28 serializer(config.advanced_optimization);
29 serializer(config.constant_folding);
30 serializer(config.kernel_profiler);
31 serializer(config.fast_math);
32 serializer(config.flatten_if);
33 serializer(config.make_thread_local);
34 serializer(config.make_block_local);
35 serializer(config.detect_read_only);
36 serializer(config.default_fp->to_string());
37 serializer(config.default_ip.to_string());
38 if (arch_is_cpu(config.arch)) {
39 serializer(config.default_cpu_block_dim);
40 serializer(config.cpu_max_num_threads);
41 } else if (arch_is_gpu(config.arch)) {
42 serializer(config.default_gpu_block_dim);
43 serializer(config.gpu_max_reg);
44 serializer(config.saturating_grid_dim);
45 serializer(config.cpu_max_num_threads);
46 }
47 serializer(config.ad_stack_size);
48 serializer(config.default_ad_stack_size);
49 serializer(config.random_seed);
50 if (config.arch == Arch::cc) {
51 serializer(config.cc_compile_cmd);
52 serializer(config.cc_link_cmd);
53 } else if (config.arch == Arch::opengl || config.arch == Arch::gles) {
54 serializer(config.allow_nv_shader_extension);
55 }
56 serializer(config.make_mesh_block_local);
57 serializer(config.optimize_mesh_reordered_mapping);
58 serializer(config.mesh_localize_to_end_mapping);
59 serializer(config.mesh_localize_from_end_mapping);
60 serializer(config.mesh_localize_all_attr_mappings);
61 serializer(config.demote_no_access_mesh_fors);
62 serializer(config.experimental_auto_mesh_local);
63 serializer(config.auto_mesh_local_default_occupacy);
64 serializer(config.real_matrix_scalarize);
65 serializer.finalize();
66
67 return serializer.data;
68}
69
70static void get_offline_cache_key_of_snode_impl(
71 const SNode *snode,
72 BinaryOutputSerializer &serializer,
73 std::unordered_set<int> &visited) {
74 if (auto iter = visited.find(snode->id); iter != visited.end()) {
75 serializer(snode->id); // Use snode->id as placeholder to identify a snode
76 return;
77 }
78
79 visited.insert(snode->id);
80 for (auto &c : snode->ch) {
81 get_offline_cache_key_of_snode_impl(c.get(), serializer, visited);
82 }
83 for (int i = 0; i < taichi_max_num_indices; ++i) {
84 auto &extractor = snode->extractors[i];
85 serializer(extractor.num_elements_from_root);
86 serializer(extractor.shape);
87 serializer(extractor.acc_shape);
88 serializer(extractor.active);
89 }
90 serializer(snode->index_offsets);
91 serializer(snode->num_active_indices);
92 serializer(snode->physical_index_position);
93 serializer(snode->id);
94 serializer(snode->depth);
95 serializer(snode->name);
96 serializer(snode->num_cells_per_container);
97 serializer(snode->chunk_size);
98 serializer(snode->cell_size_bytes);
99 serializer(snode->offset_bytes_in_parent_cell);
100 serializer(snode->dt->to_string());
101 serializer(snode->has_ambient);
102 if (!snode->ambient_val.dt->is_primitive(PrimitiveTypeID::unknown)) {
103 serializer(snode->ambient_val.stringify());
104 }
105 if (snode->grad_info && !snode->grad_info->is_primal()) {
106 if (auto *adjoint_snode = snode->grad_info->adjoint_snode()) {
107 get_offline_cache_key_of_snode_impl(adjoint_snode, serializer, visited);
108 }
109 if (auto *dual_snode = snode->grad_info->dual_snode()) {
110 get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited);
111 }
112 }
113 if (snode->physical_type) {
114 serializer(snode->physical_type->to_string());
115 }
116 serializer(snode->id_in_bit_struct);
117 serializer(snode->is_bit_level);
118 serializer(snode->is_path_all_dense);
119 serializer(snode->node_type_name);
120 serializer(snode->type);
121 serializer(snode->_morton);
122 serializer(snode->get_snode_tree_id());
123}
124
125std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) {
126 TI_ASSERT(snode);
127
128 BinaryOutputSerializer serializer;
129 serializer.initialize();
130 {
131 std::unordered_set<int> visited;
132 get_offline_cache_key_of_snode_impl(snode, serializer, visited);
133 }
134 serializer.finalize();
135
136 picosha2::hash256_one_by_one hasher;
137 hasher.process(serializer.data.begin(), serializer.data.end());
138 hasher.finish();
139
140 return picosha2::get_hash_hex_string(hasher);
141}
142
143std::string get_hashed_offline_cache_key(const CompileConfig &config,
144 Kernel *kernel) {
145 std::string kernel_ast_string;
146 if (kernel) {
147 std::ostringstream oss;
148 gen_offline_cache_key(kernel->ir.get(), &oss);
149 kernel_ast_string = oss.str();
150 }
151
152 auto compile_config_key = get_offline_cache_key_of_compile_config(config);
153 std::string autodiff_mode =
154 std::to_string(static_cast<std::size_t>(kernel->autodiff_mode));
155 picosha2::hash256_one_by_one hasher;
156 hasher.process(compile_config_key.begin(), compile_config_key.end());
157 hasher.process(kernel_ast_string.begin(), kernel_ast_string.end());
158 hasher.process(autodiff_mode.begin(), autodiff_mode.end());
159 hasher.finish();
160
161 auto res = picosha2::get_hash_hex_string(hasher);
162 res.insert(res.begin(), 'T'); // The key must start with a letter
163 return res;
164}
165
166} // namespace taichi::lang
167