1 | #include "offline_cache_util.h" |
2 | |
3 | #include "taichi/common/core.h" |
4 | #include "taichi/common/serialization.h" |
5 | #include "taichi/ir/snode.h" |
6 | #include "taichi/ir/transforms.h" |
7 | #include "taichi/program/compile_config.h" |
8 | #include "taichi/program/kernel.h" |
9 | |
10 | #include "picosha2.h" |
11 | |
12 | #include <vector> |
13 | |
14 | namespace taichi::lang { |
15 | |
// Serializes a selected set of CompileConfig fields into a byte buffer.
//
// get_hashed_offline_cache_key() feeds the returned bytes into its SHA-256
// hasher. Adding, removing, or reordering any serialize call below therefore
// changes every offline cache key produced from it.
//
// Some fields are serialized only for the backend(s) they apply to:
// - CPU vs. GPU launch parameters;
// - the C backend's compile/link commands;
// - the OpenGL/GLES NV shader extension flag.
// Changing such a field while targeting an unrelated backend does not alter
// that backend's key.
static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
    const CompileConfig &config) {
  BinaryOutputSerializer serializer;
  serializer.initialize();
  serializer(config.arch);
  serializer(config.debug);
  serializer(config.cfg_optimization);
  serializer(config.check_out_of_bound);
  serializer(config.opt_level);
  serializer(config.external_optimization_level);
  serializer(config.move_loop_invariant_outside_if);
  serializer(config.demote_dense_struct_fors);
  serializer(config.advanced_optimization);
  serializer(config.constant_folding);
  serializer(config.kernel_profiler);
  serializer(config.fast_math);
  serializer(config.flatten_if);
  serializer(config.make_thread_local);
  serializer(config.make_block_local);
  serializer(config.detect_read_only);
  // Default data types contribute through their string representation.
  serializer(config.default_fp->to_string());
  serializer(config.default_ip.to_string());
  // Launch/threading parameters are only hashed for the matching arch family.
  if (arch_is_cpu(config.arch)) {
    serializer(config.default_cpu_block_dim);
    serializer(config.cpu_max_num_threads);
  } else if (arch_is_gpu(config.arch)) {
    serializer(config.default_gpu_block_dim);
    serializer(config.gpu_max_reg);
    serializer(config.saturating_grid_dim);
    // NOTE(review): cpu_max_num_threads is also hashed for GPU archs; confirm
    // whether this is intentional. Removing it would change existing GPU keys.
    serializer(config.cpu_max_num_threads);
  }
  serializer(config.ad_stack_size);
  serializer(config.default_ad_stack_size);
  serializer(config.random_seed);
  // Backend-specific toolchain / extension settings.
  if (config.arch == Arch::cc) {
    serializer(config.cc_compile_cmd);
    serializer(config.cc_link_cmd);
  } else if (config.arch == Arch::opengl || config.arch == Arch::gles) {
    serializer(config.allow_nv_shader_extension);
  }
  serializer(config.make_mesh_block_local);
  serializer(config.optimize_mesh_reordered_mapping);
  serializer(config.mesh_localize_to_end_mapping);
  serializer(config.mesh_localize_from_end_mapping);
  serializer(config.mesh_localize_all_attr_mappings);
  serializer(config.demote_no_access_mesh_fors);
  serializer(config.experimental_auto_mesh_local);
  serializer(config.auto_mesh_local_default_occupacy);
  serializer(config.real_matrix_scalarize);
  serializer.finalize();

  return serializer.data;
}
69 | |
70 | static void get_offline_cache_key_of_snode_impl( |
71 | const SNode *snode, |
72 | BinaryOutputSerializer &serializer, |
73 | std::unordered_set<int> &visited) { |
74 | if (auto iter = visited.find(snode->id); iter != visited.end()) { |
75 | serializer(snode->id); // Use snode->id as placeholder to identify a snode |
76 | return; |
77 | } |
78 | |
79 | visited.insert(snode->id); |
80 | for (auto &c : snode->ch) { |
81 | get_offline_cache_key_of_snode_impl(c.get(), serializer, visited); |
82 | } |
83 | for (int i = 0; i < taichi_max_num_indices; ++i) { |
84 | auto & = snode->extractors[i]; |
85 | serializer(extractor.num_elements_from_root); |
86 | serializer(extractor.shape); |
87 | serializer(extractor.acc_shape); |
88 | serializer(extractor.active); |
89 | } |
90 | serializer(snode->index_offsets); |
91 | serializer(snode->num_active_indices); |
92 | serializer(snode->physical_index_position); |
93 | serializer(snode->id); |
94 | serializer(snode->depth); |
95 | serializer(snode->name); |
96 | serializer(snode->num_cells_per_container); |
97 | serializer(snode->chunk_size); |
98 | serializer(snode->cell_size_bytes); |
99 | serializer(snode->offset_bytes_in_parent_cell); |
100 | serializer(snode->dt->to_string()); |
101 | serializer(snode->has_ambient); |
102 | if (!snode->ambient_val.dt->is_primitive(PrimitiveTypeID::unknown)) { |
103 | serializer(snode->ambient_val.stringify()); |
104 | } |
105 | if (snode->grad_info && !snode->grad_info->is_primal()) { |
106 | if (auto *adjoint_snode = snode->grad_info->adjoint_snode()) { |
107 | get_offline_cache_key_of_snode_impl(adjoint_snode, serializer, visited); |
108 | } |
109 | if (auto *dual_snode = snode->grad_info->dual_snode()) { |
110 | get_offline_cache_key_of_snode_impl(dual_snode, serializer, visited); |
111 | } |
112 | } |
113 | if (snode->physical_type) { |
114 | serializer(snode->physical_type->to_string()); |
115 | } |
116 | serializer(snode->id_in_bit_struct); |
117 | serializer(snode->is_bit_level); |
118 | serializer(snode->is_path_all_dense); |
119 | serializer(snode->node_type_name); |
120 | serializer(snode->type); |
121 | serializer(snode->_morton); |
122 | serializer(snode->get_snode_tree_id()); |
123 | } |
124 | |
125 | std::string get_hashed_offline_cache_key_of_snode(const SNode *snode) { |
126 | TI_ASSERT(snode); |
127 | |
128 | BinaryOutputSerializer serializer; |
129 | serializer.initialize(); |
130 | { |
131 | std::unordered_set<int> visited; |
132 | get_offline_cache_key_of_snode_impl(snode, serializer, visited); |
133 | } |
134 | serializer.finalize(); |
135 | |
136 | picosha2::hash256_one_by_one hasher; |
137 | hasher.process(serializer.data.begin(), serializer.data.end()); |
138 | hasher.finish(); |
139 | |
140 | return picosha2::get_hash_hex_string(hasher); |
141 | } |
142 | |
143 | std::string get_hashed_offline_cache_key(const CompileConfig &config, |
144 | Kernel *kernel) { |
145 | std::string kernel_ast_string; |
146 | if (kernel) { |
147 | std::ostringstream oss; |
148 | gen_offline_cache_key(kernel->ir.get(), &oss); |
149 | kernel_ast_string = oss.str(); |
150 | } |
151 | |
152 | auto compile_config_key = get_offline_cache_key_of_compile_config(config); |
153 | std::string autodiff_mode = |
154 | std::to_string(static_cast<std::size_t>(kernel->autodiff_mode)); |
155 | picosha2::hash256_one_by_one hasher; |
156 | hasher.process(compile_config_key.begin(), compile_config_key.end()); |
157 | hasher.process(kernel_ast_string.begin(), kernel_ast_string.end()); |
158 | hasher.process(autodiff_mode.begin(), autodiff_mode.end()); |
159 | hasher.finish(); |
160 | |
161 | auto res = picosha2::get_hash_hex_string(hasher); |
162 | res.insert(res.begin(), 'T'); // The key must start with a letter |
163 | return res; |
164 | } |
165 | |
166 | } // namespace taichi::lang |
167 | |