1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/backend/te_compiler.cc |
22 | * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p PrimFuncs. Also |
23 | * handles invocation of external codegen. |
24 | * |
25 | * \p LowerTEPass handles the following (as a monolithic blob of code): |
26 | * |
27 | * - Most importantly, any function with the "Primitive" attribute is first converted to TE by |
28 | * \p LowerToTECompute (see te_compiler_cache.cc) using each operator's 'compute' function. |
29 | * The TE is then 'scheduled' to TIR using the 'anchor' operator's 'schedule' function. Both |
30 | * of those functions come from the \p OpStrategy returned by the Python |
31 | * 'relay.backend.lower_call' function (see te_compiler.py). |
32 | * The TIR is packed as a \p PrimFunc and introduced as a new global function. Calls to the |
33 | * original "Primitive" function are then rewritten to the form: |
34 | * \code |
35 | * call_lowered(@new_global, (... original args...), attributes) |
36 | * \endcode |
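 *
 *   For example, schematically (@add_prim and the argument names are hypothetical; the actual
 *   global name is chosen by the compiler):
 *   \code
 *   %0 = fn(%x, %y, Primitive=1) { add(%x, %y) }
 *   ... %0(%a, %b) ...
 *   ==>
 *   ... call_lowered(@add_prim, (%a, %b), <attributes>) ...
 *   \endcode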
37 | * |
38 | * - The above "Primitive" function can appear: |
39 | * - As a global function |
40 | * - As a let-bound function |
 *   - As an inline function, i.e. the 'op' of calls.
42 | * In all three cases it is possible for the same "Primitive" function to be called multiple |
43 | * times, and that sharing must be respected. |
44 | * |
45 | * - "Primitive" functions must have a "global_symbol" attribute matching their desired or |
46 | * existing global name. Care is taken to ensure GlobalVars with the same name are shared. |
47 | * |
48 | * - It is possible for multiple structurally equal "Primitive" functions to appear in the same |
49 | * \p IRModule. Only one implementation should be generated, and all calls should share that |
50 | * implementation. |
51 | * |
 * - When later converting to DPS (see memory_alloc.cc) we must handle functions whose result
53 | * tensor shapes depend at runtime on the input tensor shapes and/or data. |
54 | * - That dependency is first described in TE form (see \p MakeShapeFunc in |
55 | * te_compiler_cache.cc), then scheduled to yield a 'dynamic shape function' \p PrimFunc. |
56 | * This relies on each operator's "FShapeFunc" and "TShapeDataDependent" attributes. |
57 | * Since shapes are rank-1 tensors everything can be reflected back down into the regular |
58 | * TE/TIR forms. |
59 | * - Then the call_lowered attributes must record everything about the dynamic shape function |
60 | * later needed by memory_alloc.cc. We call this 'cross linking' the call with the shape |
61 | * function. |
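 *     For example, a sketch using the metadata keys set later in this file (@f and @f_shape_fn
 *     are hypothetical names):
 *     \code
 *     call_lowered(@f, (%x), metadata={"prim_shape_fn_var": @f_shape_fn,
 *                                      "prim_shape_fn_num_inputs": 1,
 *                                      "prim_shape_fn_num_outputs": 1, ...})
 *     \endcode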
62 | * |
63 | * - Two external codegen mechanisms are supported, both triggered by "Primitive" functions which |
64 | * also have a "Compiler" attribute bound to $compiler: |
65 | * - Function-at-a-time (old style): The primitive function is passed to the function |
66 | * registered as 'relay.ext.$compiler'. The function returns a runtime::Module which |
67 | * should return true for \p ImplementsFunction for the function's global name. That |
68 | * module is added to the IRModule's "external_mods" attributes. |
 *   - IRModule-at-a-time (new style): The \p RelayToTIRTargetHook sub-pass looks for
70 | * $compiler names which correspond to TargetKind names with a \p RelayToTIR attribute. |
71 | * The \p Pass bound to that attribute is run, and each such 'custom' pass can do what |
72 | * it likes, including replacing Functions with PrimFuncs, or adding new runtime::Modules |
73 | * to the IRModule's "external_mods" attribute. |
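 *   For example, a candidate for either mechanism looks schematically like the following
 *   ("my_compiler" and "my_func" are hypothetical names):
 *   \code
 *   fn (%x, Primitive=1, Compiler="my_compiler", global_symbol="my_func") { ... }
 *   \endcode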
74 | * |
75 | * - Calls to functions added by external codegen are also rewritten to call_lowered form, and |
76 | * may also require cross-linking to dynamic shape functions. However, since the functions |
77 | * are/will be implemented by a runtime::Module all the Relay type information is no longer |
78 | * available. So the Relay definitions for these "Primitive" "Compiler" functions are retained |
79 | * in the \p IRModule, but marked with the "Extern" attribute to signal the function is now |
80 | * just for carrying metadata. |
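 *   For example, the retained metadata-only definition looks schematically like the following
 *   (all other attributes having been erased):
 *   \code
 *   fn (%x, Extern=1) { ... }
 *   \endcode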
81 | * |
82 | * - Some operators are handled specially: |
83 | * - 'reshape', since it's a no-op on the underlying tensor buffer, and this is handled by |
84 | * condition tests in many passes. |
85 | * - 'debug', since it's intercepted differently depending on runtimes. |
86 | * |
87 | * TODO(mbs): This desperately deserves a refactor to separate all these concerns. See Relax. |
88 | */ |
89 | |
90 | #include "./te_compiler.h" |
91 | |
92 | #include <tvm/driver/driver_api.h> |
93 | #include <tvm/ir/attrs.h> |
94 | #include <tvm/ir/function.h> |
95 | #include <tvm/ir/name_supply.h> |
96 | #include <tvm/relay/analysis.h> |
97 | #include <tvm/relay/attrs/annotation.h> |
98 | #include <tvm/relay/attrs/call.h> |
99 | #include <tvm/relay/attrs/device_copy.h> |
100 | #include <tvm/relay/expr.h> |
101 | #include <tvm/relay/expr_functor.h> |
102 | #include <tvm/relay/op.h> |
103 | #include <tvm/runtime/device_api.h> |
104 | #include <tvm/runtime/registry.h> |
105 | #include <tvm/te/schedule.h> |
106 | #include <tvm/te/schedule_pass.h> |
107 | #include <tvm/tir/transform.h> |
108 | #include <tvm/topi/tags.h> |
109 | |
110 | #include <functional> |
111 | #include <limits> |
112 | #include <mutex> |
113 | #include <tuple> |
114 | #include <unordered_map> |
115 | #include <utility> |
116 | #include <vector> |
117 | |
118 | #include "../op/annotation/annotation.h" |
119 | #include "../op/call/call.h" |
120 | #include "../op/memory/device_copy.h" |
121 | #include "../transforms/device_aware_visitors.h" |
122 | #include "./te_compiler_cache.h" |
123 | #include "./utils.h" |
124 | |
125 | namespace tvm { |
126 | namespace relay { |
127 | // TODO(@jroesch, @csullivan): declare directly elsewhere |
128 | backend::StaticMemoryPlan GraphPlanMemory(const Function& func); |
129 | |
130 | namespace tec { |
131 | |
132 | using namespace tvm::relay::transform; |
133 | |
134 | TVM_REGISTER_OBJECT_TYPE(TECompilerNode); |
135 | |
136 | class TECompilerImpl : public TECompilerNode { |
137 | public: |
138 | explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String> opt_mod_name) |
      : global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))),
        constant_name_supply_(NameSupply("")) {
141 | // Make sure we don't collide with any existing globals in the module. |
142 | if (opt_mod) { |
143 | for (const auto& kv : opt_mod.value()->functions) { |
144 | global_var_supply_->name_supply_->ReserveName(kv.first->name_hint, false); |
145 | } |
146 | } |
147 | } |
148 | |
149 | // Lower the function. |
150 | CachedFunc Lower(const CCacheKey& key) { |
151 | return LowerInternal(key, global_var_supply_)->cached_func; |
152 | } |
153 | |
154 | // TODO(gigiblender): Only to be called by the global TE compiler. |
155 | // Remove this when the global TE compiler is removed. |
156 | CachedFunc Lower(const CCacheKey& key, const String mod_name) { |
157 | global_var_supply_->name_supply_->prefix_ = mod_name; |
158 | return LowerInternal(key, global_var_supply_)->cached_func; |
159 | } |
160 | |
161 | // For now, build one module per function. |
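  // A sketch of intended use, e.g. via the "relay.backend._TECompilerJIT" global registered
  // below (assuming a CCacheKey built with "relay.backend._make_CCacheKey"): the returned
  // PackedFunc wraps the single lowered PrimFunc and can be invoked directly.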
162 | PackedFunc JIT(const CCacheKey& key) final { |
    CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
164 | if (value->packed_func != nullptr) { |
165 | return value->packed_func; |
166 | } |
167 | auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); |
168 | value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); |
169 | return value->packed_func; |
170 | } |
171 | |
172 | CachedFunc LowerShapeFunc(const CCacheKey& key) final { |
173 | return LowerShapeFuncInternal(key)->cached_func; |
174 | } |
175 | |
176 | IRModule GetLoweredFunctions() { |
    VLOG(1) << "GetLoweredFunctions";
178 | IRModule mod; |
179 | // Extract lowered functions from the cache |
180 | for (const auto& it : cache_) { |
181 | auto source_func = it.first; |
182 | auto lowered_func = it.second; |
183 | |
184 | IRModule lowered_mod = lowered_func->cached_func->funcs; |
185 | |
186 | // Annotate functions with their target and put them in the return module |
187 | for (const auto& kv : lowered_mod->functions) { |
188 | const GlobalVar& var = kv.first; |
189 | const BaseFunc& func = kv.second; |
190 | |
191 | // Only add functions that are not external functions |
192 | if (!func->GetAttr<String>(attr::kCompiler).defined()) { |
193 | ICHECK(func->IsInstance<tir::PrimFuncNode>()) |
194 | << "Expected all functions that are not external to be PrimFuncs, but found:" |
195 | << std::endl |
196 | << PrettyPrint(func); |
197 | const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func); |
198 | mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); |
199 | } |
200 | } |
201 | } |
202 | |
203 | // Extract lowered dynamic shape functions from the shape cache |
204 | for (const auto& it : shape_func_cache_) { |
205 | auto source_func = it.first; |
206 | auto lowered_func = it.second; |
207 | auto target = source_func->target; |
208 | IRModule lowered_mod = lowered_func->cached_func->funcs; |
209 | |
210 | // Annotate functions with their target and put them in the return module |
211 | for (auto kv : lowered_mod->functions) { |
212 | const GlobalVar& var = kv.first; |
213 | const BaseFunc& func = kv.second; |
214 | const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func); |
215 | mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target)); |
216 | } |
217 | } |
218 | |
219 | return mod; |
220 | } |
221 | |
222 | void AddExterns(IRModule module) { |
223 | // Everything tagged with "Compiler" has been compiled, so remove those definitions. |
224 | std::vector<GlobalVar> to_be_deleted; |
225 | for (const auto& kv : module->functions) { |
226 | if (kv.second->GetAttr<String>(attr::kCompiler).defined()) { |
227 | to_be_deleted.push_back(kv.first); |
228 | } |
229 | } |
230 | for (const auto& global_var : to_be_deleted) { |
      VLOG(1) << "Removing definition for external codegened '" << global_var->name_hint << "'";
232 | module->Remove(global_var); |
233 | } |
234 | // HOWEVER we still need a Relay definition to go with those now external functions, so |
    // retrieve them from the cache and mark them with the "Extern" attribute.
236 | for (const auto& kv1 : cache_) { |
237 | auto src_func = kv1.first->source_func; |
238 | ICHECK(src_func.defined()); |
239 | if (src_func->GetAttr<String>(attr::kCompiler).defined()) { |
240 | for (const auto& kv2 : kv1.second->cached_func->funcs->functions) { |
241 | if (const auto* function_node = kv2.second.as<FunctionNode>()) { |
242 | // Abandon the existing function annotations. |
243 | |
244 | // Unfortunately, Optional<DictAttrs>() is indistinguishable from |
245 | // NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we |
          // need to pass in DictAttrs(Map<String, ObjectRef>()), which is a DictAttrs containing no
247 | // attributes. |
248 | Function function = |
249 | WithFields(GetRef<Function>(function_node), function_node->params, |
250 | function_node->body, function_node->ret_type, function_node->type_params, |
251 | /* erase attributes */ DictAttrs(Map<String, ObjectRef>())); |
252 | // Mark function as 'extern'. |
253 | function = WithAttr(std::move(function), attr::kExtern, Integer(1)); |
254 | module->Add(kv2.first, function); |
255 | } |
256 | } |
257 | } |
258 | } |
259 | } |
260 | |
261 | Array<tvm::runtime::Module> LowerExternalFunctions() { |
262 | Array<tvm::runtime::Module> ret; |
263 | std::vector<CCacheKey> cached_ext_funcs; |
264 | |
265 | for (const auto& it : cache_) { |
266 | auto src_func = it.first->source_func; |
267 | ICHECK(src_func.defined()); |
268 | Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler); |
269 | if (opt_compiler.defined()) { |
270 | Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
271 | ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl |
272 | << PrettyPrint(src_func); |
273 | VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '" |
274 | << opt_symbol_name.value() << "' and function:" << std::endl |
275 | << PrettyPrint(src_func); |
276 | cached_ext_funcs.push_back(it.first); |
277 | |
278 | std::string ext_name = "relay.ext." + opt_compiler.value(); |
279 | auto pf = tvm::runtime::Registry::Get(ext_name); |
280 | ICHECK(pf) << "Failed to find the external codegen tool for " << ext_name; |
281 | // No need to keep compiler attribute at this point, functions have been |
282 | // extracted for specific codegen. |
283 | src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>()); |
284 | VLOG_CONTEXT << opt_compiler.value(); |
285 | With<Target> with_target(it.first->target); |
286 | runtime::Module ext_mod = (*pf)(src_func); |
287 | if (ext_mod.defined()) { |
        // TODO(mbs): Can this be an ICHECK?
289 | if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) { |
290 | VLOG(1) << "Note that the external codegen for '" << opt_compiler.value() |
291 | << "' returned a runtime module which does not appear to implement '" |
                  << opt_symbol_name.value() << "'";
293 | } |
294 | ret.push_back(ext_mod); |
295 | } else { |
296 | // It is valid for the external codegen function to return null: |
297 | // - Unit tests can use it. |
298 | // - The true compilation may have already been handled by a RelayToTIR custom pass |
299 | // on the Target's kind. The original Relay functions will be left in place so |
300 | // that we can capture that their function names are now externally defined. |
301 | VLOG(1) << "Note that no external runtime module was generated by external codegen '" |
                << opt_compiler.value() << "'";
303 | } |
304 | } |
305 | } |
306 | |
307 | // No need to cache external functions as we collected them all to create |
308 | // external runtime modules. |
309 | for (const auto& it : cached_ext_funcs) { |
310 | cache_.erase(it); |
311 | } |
312 | return ret; |
313 | } |
314 | |
315 | Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; } |
316 | void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) { |
317 | device_contexts_ = device_contexts; |
318 | } |
319 | |
320 | void Clear() final { cache_.clear(); } |
321 | |
322 | // List all items in the cache. |
323 | Array<ObjectRef> ListItems() { |
324 | std::lock_guard<std::mutex> lock(mutex_); |
325 | Array<ObjectRef> items; |
326 | for (auto& kv : cache_) { |
327 | items.push_back(kv.first); |
328 | items.push_back(kv.second); |
329 | } |
330 | return items; |
331 | } |
332 | |
333 | /*! |
334 | * \brief Get the cache key of the function that is being lowered currently |
335 | * \return the cache key |
336 | */ |
337 | CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } |
338 | |
339 | private: |
340 | // implement lowered func |
341 | CCacheValue LowerInternal(const CCacheKey& key, GlobalVarSupply global_var_supply) { |
342 | VLOG(1) << "lowering:" << std::endl |
343 | << PrettyPrint(key->source_func) << std::endl |
344 | << "for target:" << std::endl |
345 | << key->target->ToDebugString(); |
346 | std::lock_guard<std::mutex> lock(mutex_); |
347 | CCacheValue value; |
348 | auto it = cache_.find(key); |
349 | if (it != cache_.end()) { |
350 | VLOG(1) << "already lowered to name:" << std::endl |
351 | << PrettyPrint(it->second->cached_func->prim_fn_var); |
352 | it->second->use_count += 1; |
353 | if (it->second->cached_func.defined()) return it->second; |
354 | value = it->second; |
355 | } else { |
356 | value = CCacheValue(make_object<CCacheValueNode>()); |
357 | value->use_count = 1; |
358 | cache_[key] = value; |
359 | } |
360 | cur_ccache_key_ = key; |
361 | |
362 | Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler); |
363 | if (opt_compiler.defined()) { |
364 | // Don't compile now since we don't have anywhere to put the resulting runtime module. |
365 | // Instead place the original definition in the cache and wait for LowerExternalFunctions. |
366 | IRModule ir_module({}, {}); |
367 | Optional<String> opt_global_symbol = |
368 | key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
      ICHECK(opt_global_symbol.defined()) << "External function has not yet been given a name.";
370 | // Note that the source_func may already be bound to a global function in the module |
371 | // we are compiling, in which case we should not attempt to make its name unique w.r.t. |
372 | // the module's globals. Furthermore, the external codegen tool must bind the compiled |
373 | // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName |
374 | // here. |
375 | auto global_var = global_var_supply->UniqueGlobalFor(opt_global_symbol.value(), false); |
376 | global_var->checked_type_ = key->source_func->checked_type(); |
377 | ir_module->Add(global_var, key->source_func); |
378 | value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr}, |
379 | tir::PrimFunc{nullptr}, {}, ir_module); |
380 | // Collect these here as it's removed in LowerExternalFunctions() |
381 | device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value()); |
382 | VLOG(1) << "preparing to use external codegen '" << opt_compiler.value() |
383 | << "' with name:" << std::endl |
384 | << PrettyPrint(value->cached_func->prim_fn_var) << std::endl |
385 | << "and definitions:" << std::endl |
386 | << PrettyPrint(value->cached_func->funcs); |
387 | return value; |
388 | } |
389 | |
    // Enforce use of the target.
391 | With<Target> target_scope(key->target); |
392 | |
393 | ICHECK(!value->cached_func.defined()); |
394 | value->cached_func = |
395 | PrimFuncFor(key->source_func, key->target, global_var_supply, constant_name_supply_); |
396 | |
397 | if (value->cached_func->prim_func.defined()) { |
      VLOG(1) << "Lowering PrimFunc";
399 | IRModule lowered = tvm::LowerPrimFunc(value->cached_func->prim_func.value(), |
400 | value->cached_func->prim_fn_var->name_hint, false); |
401 | ICHECK_EQ(lowered->functions.size(), 1); |
402 | for (const auto& kv : lowered->functions) { |
403 | value->cached_func->funcs->Add(value->cached_func->prim_fn_var, kv.second); |
404 | } |
405 | } else { |
406 | // NOTE: array will copy on write. |
407 | Array<te::Tensor> all_args = Array<te::Tensor>(value->cached_func->inputs); |
408 | for (te::Tensor arg : value->cached_func->outputs) { |
409 | all_args.push_back(arg); |
410 | } |
411 | Array<runtime::NDArray> all_consts; |
412 | for (auto kv : value->cached_func->constant_tensors) { |
413 | all_args.push_back(kv.second); |
414 | all_consts.push_back(kv.first->data); |
415 | } |
416 | // lower the function |
417 | std::unordered_map<te::Tensor, tir::Buffer> binds; |
418 | |
419 | // If we have memory scopes, need to create tir::Buffer knowing this info |
      size_t i = 0;  // index into the corresponding entry of the inputs tensor array
421 | for (Var param : key->source_func->params) { |
422 | if (!param->virtual_device()->memory_scope.empty()) { |
423 | for (const auto& ttype : FlattenTupleType(param->checked_type())) { |
424 | te::Tensor x_ref = value->cached_func->inputs[i]; |
            // Verify that the function parameters and prepared tensors are in sync.
            ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
                << "function parameter does not correspond to prepared tensor";
428 | binds[x_ref] = |
429 | tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, |
430 | false, param->virtual_device()->memory_scope); |
431 | } |
432 | } |
433 | i++; |
434 | } |
435 | if (key->virtual_device != VirtualDevice::FullyUnconstrained() && |
436 | !key->virtual_device->memory_scope.empty() && |
          key->virtual_device->memory_scope != "global") {
        ICHECK(value->cached_func->outputs.size() == 1)
            << "Expect only one output for defined memory scope";
440 | te::Tensor x_ref = value->cached_func->outputs[0]; |
441 | binds[x_ref] = |
442 | tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, |
443 | false, key->virtual_device->memory_scope); |
444 | } |
445 | auto func_name = value->cached_func->prim_fn_var->name_hint; |
      VLOG(1) << "scheduling";
447 | IRModule scheduled_module = tvm::LowerSchedule(value->cached_func->schedule, all_args, |
448 | func_name, binds, global_var_supply); |
449 | scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module)); |
450 | for (const auto& kv : scheduled_module->functions) { |
451 | GlobalVar global_var = kv.first; |
452 | auto func = kv.second; |
453 | // Propagate the structural hash of the relay function to the tir |
454 | // function so associations can be made between the two. |
        Optional<String> hash = key->source_func->attrs.GetAttr<String>("hash");
456 | if (hash) { |
          func = WithAttrs(Downcast<tir::PrimFunc>(func), {{String("hash"), hash.value()}});
458 | } |
459 | value->cached_func->funcs->Add(global_var, func); |
460 | } |
461 | ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var) |
462 | .as<tir::PrimFuncNode>()); |
463 | } |
464 | VLOG(1) << "lowered to name:" << std::endl |
465 | << PrettyPrint(value->cached_func->prim_fn_var) << std::endl |
466 | << "with definitions:" << std::endl |
467 | << PrettyPrint(value->cached_func->funcs); |
468 | |
469 | return value; |
470 | } |
471 | |
472 | // implement lowered shape func |
473 | CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { |
474 | VLOG(1) << "lowering dynamic shape function for:" << std::endl |
475 | << PrettyPrint(key->source_func) << std::endl |
476 | << "for target:" << std::endl |
477 | << key->target->ToDebugString(); |
478 | std::lock_guard<std::mutex> lock(mutex_); |
479 | CCacheValue value; |
480 | auto it = shape_func_cache_.find(key); |
481 | if (it != shape_func_cache_.end()) { |
482 | it->second->use_count += 1; |
483 | if (it->second->cached_func.defined()) return it->second; |
484 | value = it->second; |
485 | } else { |
486 | value = CCacheValue(make_object<CCacheValueNode>()); |
487 | value->use_count = 0; |
488 | shape_func_cache_[key] = value; |
489 | } |
    // Enforce use of the target.
491 | With<Target> target_scope(key->target); |
492 | |
493 | ICHECK(!value->cached_func.defined()); |
494 | |
495 | using tvm::transform::PassContext; |
496 | With<PassContext> fresh_pass_ctx_scope(PassContext::Create()); |
497 | value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply_); |
498 | |
499 | ICHECK( |
500 | value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>()); |
501 | |
502 | VLOG(1) << "lowered to name:" << std::endl |
503 | << PrettyPrint(value->cached_func->prim_fn_var) << std::endl |
504 | << "with definitions:" << std::endl |
505 | << PrettyPrint(value->cached_func->funcs); |
506 | return value; |
507 | } |
508 | |
509 | Map<String, Integer> GetOpWeights() const { |
510 | Map<String, Integer> weights; |
511 | for (const auto& kv : cache_) { |
512 | auto value = kv.second; |
513 | auto name = value->cached_func->prim_fn_var->name_hint; |
514 | weights.Set(name, value->use_count); |
515 | } |
516 | return weights; |
517 | } |
518 | |
519 | // TODO(mbs): Hold the output module here and reduce the cache_ to just be from |
520 | // Function to GlobalVar. |
521 | |
  /*! \brief compiler cache lock */
523 | std::mutex mutex_; |
524 | /*! \brief internal GlobalVarSupply to get unique GlobalVars */ |
525 | GlobalVarSupply global_var_supply_; |
526 | /*! \brief A NameSupply object for assigning unique names to constants, across different |
527 | * invocations of PrimFuncFor. */ |
528 | NameSupply constant_name_supply_; |
529 | /*! \brief internal compiler cache */ |
530 | std::unordered_map<CCacheKey, CCacheValue> cache_; |
531 | /*! \brief internal compiler cache for shape funcs */ |
532 | std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_; |
  /*! \brief the cache key of the function that is being lowered currently */
534 | CCacheKey cur_ccache_key_; |
535 | /*! \brief Map of GlobalVar to C Device API context names */ |
536 | Map<GlobalVar, String> device_contexts_; |
537 | }; |
538 | |
539 | TECompiler::TECompiler(Optional<IRModule> opt_mod, Optional<String> mod_name) { |
540 | auto object = make_object<TECompilerImpl>(std::move(opt_mod), std::move(mod_name)); |
541 | data_ = object; |
542 | } |
543 | |
544 | /*! \brief The global TE compiler */ |
545 | // TODO(mbs): To be terminated with extreme prejudice. |
546 | TECompiler& TECompiler::Global() { |
547 | static TECompiler* inst = |
548 | new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>(), Optional<String>())); |
549 | return *inst; |
550 | } |
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String);
555 | |
TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
557 | return TECompiler::Global(); |
558 | }); |
559 | |
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
561 | .set_body_typed([](Function source_func, Target target) { |
562 | return CCacheKey(source_func, target); |
563 | }); |
564 | |
TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
566 | .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) { |
567 | return LoweredOutput(outputs, impl); |
568 | }); |
569 | |
TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) {
571 | self->Clear(); |
572 | }); |
573 | |
TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower")
575 | .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) { |
576 | return self->Lower(key, mod_name); |
577 | }); |
578 | |
TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT")
580 | .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); }); |
581 | |
TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) {
583 | TECompilerImpl* ptr = dynamic_cast<TECompilerImpl*>(self.operator->()); |
584 | ICHECK(ptr != nullptr); |
585 | return ptr->ListItems(); |
586 | }); |
587 | |
588 | using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>; |
589 | |
590 | /*! |
591 | * \brief Rewrites call expressions to Relay Functions marked as "primitive" |
592 | * to calls to the corresponding TIR PrimFunc for the appropriate target. |
593 | * |
594 | * \code |
595 | * %0 = fn(...) { prim_op(...) } OR let %p = fn(...) { prim_op(...) } |
596 | * ... %0(...) ... ... %p(...) ... |
597 | * ==> |
598 | * def @q(..., target=<target>) { <tir body> } |
599 | * ... @q(...) ... |
600 | * \endcode |
601 | * |
602 | * Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run. |
603 | * |
604 | * FuseOps is needed to identify and lift all prim op calls: |
605 | * \code |
606 | * ... prim_op(...) ... |
607 | * ==> |
608 | * %0 = fn(...) { prim_op(...) } |
609 | * ... %0(...) ... |
610 | * \endcode |
611 | * |
612 | * ToANormalForm is needed so we only need to consider vars and function literals as the call |
613 | * target. |
614 | * |
 * EtaExpand is needed to ensure all calls to primitives are direct:
616 | * \code |
617 | * let %p1 = fn(...) { prim_op1(...) } |
618 | * let %p2 = fn(...) { prim_op2(...) } |
619 | * let %p = if (...) { %p1 } else { %p2 } |
620 | * ... %p(...) ... |
621 | * ==> |
622 | * let %p1 = fn(...) { prim_op1(...) } |
623 | * let %p2 = fn(...) { prim_op2(...) } |
624 | * let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } } |
625 | * ... %p(...) ... |
626 | * \endcode |
627 | */ |
628 | class LowerTensorExprMutator : public DeviceAwareExprMutator { |
629 | public: |
630 | LowerTensorExprMutator(IRModule module, ProcessFn process_fn, CompilationConfig config, |
631 | TECompiler compiler) |
632 | : DeviceAwareExprMutator(module), |
633 | module_(std::move(module)), |
634 | process_fn_(std::move(process_fn)), |
635 | config_(std::move(config)), |
636 | compiler_(std::move(compiler)), |
        debug_op_(Op::Get("debug")) {}
638 | |
639 | /*! |
640 | * \brief Returns the primitive function associated with \p expr, or nullptr if none. |
641 | */ |
642 | BaseFunc ResolveToPrimitive(const Expr& expr) { |
643 | // NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order |
644 | // expressions. |
645 | if (const auto* global_var_node = expr.as<GlobalVarNode>()) { |
646 | if (!module_->ContainGlobalVar(global_var_node->name_hint)) { |
647 | // TODO(mbs): extern function cleanup |
648 | // Assume the function is extern and thus no longer in the IRModule. |
649 | return {}; |
650 | } else { |
651 | BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(global_var_node)); |
652 | return ResolveToPrimitive(base_func); |
653 | } |
654 | } else if (const auto* prim_func_node = expr.as<tir::PrimFuncNode>()) { |
655 | return GetRef<tir::PrimFunc>(prim_func_node); |
656 | } else if (const auto* var_node = expr.as<VarNode>()) { |
657 | auto itr = primitive_functions_.find(var_node); |
658 | if (itr == primitive_functions_.end()) { |
659 | // Not bound to a primitive function. |
660 | return {}; |
661 | } else { |
662 | return itr->second; |
663 | } |
664 | } else if (const auto* function_node = expr.as<FunctionNode>()) { |
665 | if (function_node->HasNonzeroAttr(attr::kExtern)) { |
666 | // We have a regular call to an 'extern' function. The call itself needs to be rewritten |
667 | // to call_lowered form, and any required dynamic shape functions generated and |
668 | // cross-linked. |
669 | return GetRef<Function>(function_node); |
670 | } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
671 | if (const auto* call_node = function_node->body.as<CallNode>()) { |
672 | if (call_node->op == debug_op_) { |
673 | // Debug 'primitives' are not lowered. |
674 | return {}; |
675 | } |
676 | } |
677 | // We have a regular call to a 'primitive' function (possibly with a 'Compiler' attribute). |
678 | // We need to lower and rewrite the call. |
679 | return GetRef<Function>(function_node); |
680 | } else { |
681 | // Not marked as primitive during partitioning or TVM fusion. |
682 | return {}; |
683 | } |
684 | } else { |
685 | return {}; |
686 | } |
687 | } |
688 | |
689 | /*! |
690 | * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and \p span with all the |
691 | * required attributes filled in. Generally \p prim_fn_var will correspond to the lowered or |
692 | * externally codegen-ed form of \p original_function, where \p lowered_functions binds all |
693 | * the required lowered functions. |
694 | * |
695 | * The call's attributes will capture: |
696 | * - Any attributes on the original_function. |
697 | * - All the lowered functions. |
698 | * TODO(mbs): Pretty sure that's no longer needed. |
 * - Details needed to cross-link the call to its dynamic shape function, if any.
700 | */ |
701 | Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var, |
702 | Array<Expr> args, Span span, const Target& target, |
703 | const Map<GlobalVar, BaseFunc>& lowered_functions) { |
704 | auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler); |
705 | |
706 | // Add some metadata on top of the *original function* and invoke the callback so it can |
707 | // be captured. |
708 | // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT |
709 | Map<GlobalVar, tir::PrimFunc> prim_fns; |
710 | Array<GlobalVar> all_prim_fn_vars; |
711 | for (const auto& kv : lowered_functions) { |
712 | if (opt_compiler) { |
713 | // We expect the original function to have just the "Extern" attribute signaling the |
714 | // function (will be) compiled externally. |
715 | ICHECK(kv.second.as<FunctionNode>()) |
            << PrettyPrint(kv.first) << " must be bound to an (external) Function";
717 | } else { |
718 | // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive, |
719 | // and the rest are in support of that via tir::Calls. |
720 | ICHECK(kv.second.as<tir::PrimFuncNode>()) |
            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
722 | prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second)); |
723 | all_prim_fn_vars.push_back(kv.first); |
724 | } |
725 | } |
726 | |
727 | // Alas, WithAttr cannot work with base classes. |
    if (const auto* prim_func_node = original_function.as<tir::PrimFuncNode>()) {
      auto func_with_metadata = GetRef<tir::PrimFunc>(prim_func_node);
      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
732 | func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); |
733 | this->process_fn_(func_with_metadata); |
734 | } else { |
735 | const auto* function_node = original_function.as<FunctionNode>(); |
736 | ICHECK(function_node); |
737 | auto func_with_metadata = GetRef<Function>(function_node); |
      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
740 | func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); |
741 | this->process_fn_(func_with_metadata); |
742 | } |
743 | |
744 | // Now prepare the attributes of the call_lowered. |
745 | CallLoweredAttrs call_lowered_attrs; |
746 | |
747 | // TODO(mbs): "reshape" cleanup. |
748 | if (!opt_compiler && original_function->HasNonzeroAttr(attr::kReshapeOnly)) { |
749 | call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); |
750 | } |
751 | |
    call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs);
    call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
754 | |
755 | if (const auto* function_node = original_function.as<FunctionNode>()) { |
756 | if (IsDynamic(function_node->ret_type)) { |
757 | // Create a dynamic shape function to calculate the expected shape of the results of |
758 | // the lowered function. |
759 | // Shape function keys use the original function as their 'function', but the generic 'cpu' |
760 | // target as the target since all shape functions run on the host cpu irrespective of where |
761 | // the primitive runs. |
762 | CCacheKey shape_key(GetRef<Function>(function_node), config_->host_virtual_device->target); |
763 | CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); |
764 | |
765 | // Capture the shape function's global var and parameters 'states' in call |
766 | // annotations so calling convention can be recovered. |
767 | // TODO(mbs): Shape cleanup. |
        call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
        call_lowered_attrs.metadata.Set("prim_shape_fn_states",
                                        lowered_shape_func->shape_func_param_states);
        call_lowered_attrs.metadata.Set(
            "prim_shape_fn_num_inputs",
            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
        call_lowered_attrs.metadata.Set(
            "prim_shape_fn_num_outputs",
776 | Integer(static_cast<int>(lowered_shape_func->outputs.size()))); |
777 | Array<GlobalVar> all_prim_shape_fn_vars; |
778 | for (const auto& kv : lowered_shape_func->funcs->functions) { |
          CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
780 | all_prim_shape_fn_vars.push_back(kv.first); |
781 | } |
        call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
783 | } |
784 | } |
785 | |
786 | return CallLowered(prim_fn_var, std::move(args), std::move(call_lowered_attrs), |
787 | std::move(span)); |
788 | } |
789 | |
790 | std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final { |
791 | Var new_var = Downcast<Var>(Mutate(var)); |
792 | Expr new_value = Mutate(value); |
793 | BaseFunc prim_func = ResolveToPrimitive(new_value); |
794 | |
795 | if (prim_func.defined()) { |
796 | // Remember let var is bound (possibly indirectly) to a primitive function. |
797 | primitive_functions_.emplace(var.get(), prim_func); |
798 | } |
799 | return {new_var, new_value}; |
800 | } |
801 | |
802 | Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final { |
803 | BaseFunc prim_func = ResolveToPrimitive(post_let_node->value); |
804 | if (prim_func.defined()) { |
805 | // Leaving let var scope |
806 | primitive_functions_.erase(pre_let_node->var.get()); |
807 | // Drop the let node |
808 | return post_let_node->body; |
809 | } |
810 | return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); |
811 | } |
812 | |
813 | Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { |
814 | if (function_node->HasNonzeroAttr(attr::kPrimitive) || |
815 | function_node->HasNonzeroAttr(attr::kExtern)) { |
816 | // Nothing to lower inside primitive/external functions. |
817 | return GetRef<Function>(function_node); |
818 | } else { |
819 | return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node); |
820 | } |
821 | } |
822 | |
823 | Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { |
824 | // We can see six forms of calls: |
825 | // 1. A 'normal' Relay call to a Function with the "Primitive" attribute and not "Compiler" |
826 | // attribute. We will need to lower that to a global PrimFunc and rewrite the call to: |
827 | // call_lowered(@new_global, (arg1, ..., argn), <attributes>) |
828 | // If needed, the call needs to be cross-linked with any dynamic shape functions. |
829 | // (However, some primitives are special and handled separately.) |
830 | // 2. A 'normal' Relay call to a Function with the "Primitive" and "Compiler" attributes. We |
831 | // will need to invoke the "relay.ext.<compiler>" function to yield a runtime module, and |
832 | // rewrite the call to the same form as above. Dynamic shape function cross-linking may |
833 | // also be needed. |
834 | // 3. A 'normal' Relay call to a Function with the "Extern" attribute. This function has |
835 | // already been compiled by an external codegen and a definition for it exists in some |
836 | // runtime module. Again, we rewrite to call_lowered form, and cross-link with a dynamic |
837 | // shape function if needed. |
838 | // 4. A 'normal' Relay call to a PrimFunc which has already been supplied via a global |
839 | // definition. We rewrite those to use the call_lowered form, but otherwise nothing else |
840 | // needs to be done. |
841 | // 5. A 'call_lowered' call from an earlier invocation of this pass or otherwise deliberately |
842 | // inserted. It has all the required attributes, and any associated dynamic shape function |
843 | // has been generated and cross-linked. These calls are not changed. |
844 | // 6. A 'normal' Relay call to a Relay Function without any special attribute. These |
845 | // calls are not changed. |
846 | // |
847 | // Note that ResolveToPrimitive will yield non-null only for cases 1-4. |
848 | |
849 | // Prepare the arguments and op. |
850 | Array<Expr> new_args; |
851 | for (const auto& arg : call_node->args) { |
852 | new_args.push_back(VisitExpr(arg)); |
853 | } |
854 | Expr new_op = VisitExpr(call_node->op); |
855 | |
856 | // Look for (possibly indirect) calls to primitives. |
857 | BaseFunc primitive_func = ResolveToPrimitive(call_node->op); |
858 | if (!primitive_func.defined()) { |
859 | // Cases 5 and 6: Leave as ordinary call. |
860 | if (const auto* function_node = call_node->op.as<FunctionNode>()) { |
861 | process_fn_(GetRef<Function>(function_node)); |
862 | } |
863 | return WithFields(GetRef<Call>(call_node), std::move(new_op), std::move(new_args)); |
864 | } |
865 | |
866 | // Special case for case 1: device_copies are left as calls to primitive operators |
867 | // so that each backend can handle them directly. |
868 | // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone. |
869 | if (const auto* function_node = primitive_func.as<FunctionNode>()) { |
870 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body); |
871 | if (device_copy_props.body.defined()) { |
872 | ICHECK_EQ(new_args.size(), 1); |
873 | return DeviceCopy(new_args[0], device_copy_props.src_virtual_device, |
874 | device_copy_props.dst_virtual_device); |
875 | } |
876 | } |
877 | |
    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
879 | |
880 | // Case 4: If the function has already been lowered we just need to update the call. |
881 | if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) { |
882 | // Function should already be Target annotated by this point |
883 | // but the TE Compiler metadata is still needed for the callback |
884 | // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks |
885 | Optional<Target> opt_target = primitive_func->GetAttr<Target>(tvm::attr::kTarget); |
886 | ICHECK(opt_target.defined()); |
887 | auto prim_fn_var = Downcast<GlobalVar>(call_node->op); |
888 | tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node); |
889 | Map<GlobalVar, BaseFunc> prim_fns = {{prim_fn_var, prim_func}}; |
890 | return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, |
891 | opt_target.value(), prim_fns); |
892 | } |
893 | |
894 | // Determine the target for lowering or external codegen. |
895 | Target target; |
896 | Optional<String> opt_compiler = primitive_func->GetAttr<String>(attr::kCompiler); |
897 | if (opt_compiler.defined()) { |
898 | // This function needs to be compiled with external codegen. |
899 | Optional<Target> opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value()); |
900 | if (opt_target.defined()) { |
901 | // The target is what's supplied by the compilation config for kind matching the |
902 | // "Compiler" name. |
903 | target = opt_target.value(); |
904 | } else { |
905 | // Legacy fallback. |
        target = Target("ext_dev");
907 | } |
908 | } else { |
909 | // The target corresponding to the call_node expression's annotation. |
910 | VirtualDevice virtual_device = GetVirtualDevice(GetRef<Call>(call_node)); |
911 | ICHECK(!virtual_device->IsFullyUnconstrained()) << PrettyPrint(GetRef<Call>(call_node)); |
912 | target = virtual_device->target; |
913 | ICHECK(target.defined()); |
914 | } |
915 | |
916 | if (primitive_func->HasNonzeroAttr(attr::kExtern)) { |
917 | // Case 3: Function has already been compiled. |
918 | GlobalVar prim_fn_var = Downcast<GlobalVar>(call_node->op); |
919 | return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, |
920 | target, /*lowered_functions=*/{}); |
921 | } else { |
922 | // Cases 1 and 2: lower the primitive function for the desired target, possibly using external |
923 | // codegen. |
924 | CCacheKey key(Downcast<Function>(primitive_func), target, |
925 | GetVirtualDevice(GetRef<Call>(call_node))); |
926 | CachedFunc cfunc = compiler_->Lower(key); |
927 | ICHECK(cfunc.defined()); |
928 | return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args), |
929 | call_node->span, target, cfunc->funcs->functions); |
930 | } |
931 | } |
932 | |
933 | IRModule module_; |
934 | ProcessFn process_fn_; |
935 | /*! \brief All available targets. */ |
936 | CompilationConfig config_; |
937 | // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which |
938 | // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered |
939 | // primitive function as we go. Those vars will be bound in the target device-type specific |
940 | // module we'll ultimately emit for each required device-type. Note that a primitive may be |
941 | // lowered for multiple device types, each which will be assigned a fresh var. |
942 | std::unordered_map<const VarNode*, BaseFunc> primitive_functions_; |
943 | TECompiler compiler_; |
944 | // Cache ops that need to be frequently used later to reduce lookup overhead. |
945 | const Op& debug_op_; |
946 | }; |
947 | |
948 | Pass LowerTensorExpr(TECompiler compiler, ProcessFn process_fn, CompilationConfig config) { |
949 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
950 | [=](Function func, IRModule module, PassContext ctx) { |
951 | LowerTensorExprMutator lower_te(module, process_fn, config, compiler); |
952 | return Downcast<Function>(lower_te.Mutate(func)); |
953 | }; |
  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
955 | } |
956 | |
957 | backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const CompilationConfig& config, |
958 | Map<Expr, backend::StorageInfo> storage_info_map) { |
  Function func = Downcast<Function>(mod->Lookup("main"));
960 | |
  VLOG_CONTEXT << "UpdateMainWorkspaceSize";
962 | VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func); |
963 | |
964 | // This is a Map<device,Map<storage_id, size>> |
965 | // TODO(mbs): Collapsing VirtualDevices to just device type. |
966 | std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash> |
967 | sid_workspace; |
968 | // This is a Map<device, size_of_inputs_and_outputs> |
969 | std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io; |
970 | // This is a Map<device, size_of_constants> |
971 | std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts; |
972 | |
973 | // Initialize the mapping from all storage identifiers to workspace sizes, |
974 | // the amount of device io, and the device constants. |
975 | for (const auto& kv : storage_info_map) { |
976 | const backend::StorageInfo& storage_info = kv.second; |
977 | const std::vector<int64_t>& storage_ids = storage_info->storage_ids; |
978 | const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices; |
979 | CHECK_EQ(storage_ids.size(), virtual_devices.size()); |
980 | for (uint32_t i = 0; i < virtual_devices.size(); i++) { |
981 | DLDeviceType device_type = virtual_devices[i]->device_type(); |
982 | sid_workspace[device_type][storage_ids[i]] = 0; |
983 | device_io[device_type] = 0; |
984 | device_consts[device_type] = 0; |
985 | } |
986 | } |
987 | |
988 | // Iterate the storage map to compute all the tensor sizes in the program. |
989 | // There are 3 cases in this code: |
990 | // |
991 | // First we need to compute the sizes of all |
992 | // inline constants. |
993 | // |
994 | // Second we compute the size of any bound variable as these are input and output |
995 | // sizes of the program. |
996 | // |
997 | // Finally for all other expressions we check which storage identifier they have |
998 | // been assigned and we compute the maximal size of the storage, as tensors can |
999 | // share storage with other tensors which are the same size or larger. |
1000 | // |
1001 | // In this final case there is only one allocation for all tensors which share storage |
1002 | // which will be the maximal size of all tensors which were assigned to it. |
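  // For example (hypothetical sizes): if a 16 byte and a 32 byte tensor share storage id 3 on
  // the same device, the single allocation for id 3 must be 32 bytes, the maximum of the two.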
1003 | for (const auto& kv : storage_info_map) { |
1004 | const Expr& expr = kv.first; |
1005 | const backend::StorageInfo& storage_info = kv.second; |
1006 | int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type()); |
1007 | VLOG(1) << "expression:" << std::endl |
1008 | << PrettyPrint(expr) << std::endl |
1009 | << "of type:" << std::endl |
1010 | << PrettyPrint(expr->checked_type()) << std::endl |
1011 | << "has size " << size_bytes << " and storage info:" << std::endl |
1012 | << storage_info; |
1013 | const std::vector<int64_t>& storage_ids = storage_info->storage_ids; |
1014 | const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices; |
1015 | |
1016 | if (expr->IsInstance<ConstantNode>()) { |
1017 | for (const auto& virtual_device : virtual_devices) { |
1018 | DLDeviceType device_type = virtual_device->device_type(); |
1019 | ICHECK_EQ(device_consts.count(device_type), 1); |
1020 | device_consts[device_type] += size_bytes; |
1021 | } |
1022 | } else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) { |
      CHECK(size_bytes == 0 || virtual_devices.size() >= 1) << "must be at least one device";
1024 | for (const auto& virtual_device : virtual_devices) { |
1025 | DLDeviceType device_type = virtual_device->device_type(); |
1026 | device_io[device_type] += size_bytes; |
1027 | } |
1028 | } else { |
      // TODO(@electriclilies): This code is never being called, which means sid_workspace is
      // not updated. This means that storage info is probably not being created correctly, or
      // is not equivalent to what was here previously.
1032 | for (uint32_t i = 0; i < storage_ids.size(); i++) { |
1033 | // Here we record the largest size of the tensor |
1034 | // that share the same storage id, because storage_id will |
1035 | // be shared between multiple tensors that are not live simultaneously. |
1036 | DLDeviceType device_type = virtual_devices[i]->device_type(); |
1037 | if (size_bytes > sid_workspace[device_type][storage_ids[i]]) { |
1038 | sid_workspace[device_type][storage_ids[i]] = size_bytes; |
1039 | } |
1040 | } |
1041 | } |
1042 | } |
1043 | |
1044 | // This is a Map<device, workspace_size> |
1045 | std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace; |
1046 | // Once we know the sizes of sids, we need to accumulate per device |
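  // For example (hypothetical sizes): sid_workspace[kDLCPU] = {3: 32, 4: 16} accumulates to
  // device_workspace[kDLCPU] = 48.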
1047 | for (const auto& dev_sid_size : sid_workspace) { |
1048 | auto dev = dev_sid_size.first; |
1049 | device_workspace[dev] = 0; |
1050 | for (const auto& sid_size : dev_sid_size.second) { |
1051 | device_workspace[dev] += sid_size.second; |
1052 | } |
1053 | } |
1054 | |
1055 | Map<Target, Integer> workspace_sizes; |
1056 | Map<Target, Integer> io_sizes; |
1057 | Map<Target, Integer> constant_sizes; |
1058 | Map<Target, tir::PrimFunc> tir_primfuncs; |
1059 | Map<Target, Function> relay_primfuncs; |
1060 | |
1061 | // Initialize all target workspaces to zero |
1062 | for (const auto& target : config->primitive_targets) { |
1063 | workspace_sizes.Set(target, 0); |
1064 | } |
1065 | |
1066 | for (const auto& dev_and_size : device_workspace) { |
1067 | Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); |
1068 | workspace_sizes.Set(target, dev_and_size.second); |
1069 | relay_primfuncs.Set(target, func); |
1070 | } |
1071 | for (const auto& dev_and_size : device_io) { |
1072 | Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); |
1073 | io_sizes.Set(target, dev_and_size.second); |
1074 | } |
1075 | |
1076 | for (const auto& dev_and_size : device_consts) { |
1077 | Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); |
1078 | ICHECK_EQ(constant_sizes.count(target), 0); |
1079 | constant_sizes.Set(target, dev_and_size.second); |
1080 | } |
1081 | |
1082 | backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes), |
1083 | std::move(constant_sizes), std::move(tir_primfuncs), |
1084 | std::move(relay_primfuncs)); |
1085 | VLOG(1) << "func_info: " << func_info; |
1086 | return std::move(func_info); |
1087 | } |
1088 | |
1089 | /*! |
 * \brief A function to create the function metadata for an input function (i.e. calculate buffer
 * input/output sizes)
 * \param func The function to calculate function metadata for
 * \param function_metadata The map that stores all the function metadata
1094 | */ |
1095 | void UpdateFunctionMetadata(BaseFunc func, |
1096 | Map<String, backend::FunctionInfo>& function_metadata, // NOLINT(*) |
1097 | Integer workspace_byte_alignment) { |
  VLOG_CONTEXT << "UpdateFunctionMetadata";
1099 | VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func); |
  // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored
  // there. Now the goal is to take only one func, because process_fn should be controlling the
  // iteration. However, to do the workspace calculations we need the primfuncs, so process_fn
  // needs to either access the cached funcs or be directly passed primfuncs. This is bad, and
  // ideally we don't want process_fn to look at primfuncs. There's also the question now of what
  // the function metadata is and how it is used; if we can do something else to replicate the
  // behavior of the function metadata that might be good (i.e. annotating functions or something).
1107 | Map<Target, Integer> workspace_sizes; |
1108 | Map<Target, Integer> io_sizes; |
1109 | Map<Target, Integer> constant_sizes; |
1110 | Map<Target, tir::PrimFunc> tir_primfuncs; |
1111 | Map<Target, Function> relay_primfuncs; |
1112 | |
1113 | Optional<Map<GlobalVar, tir::PrimFunc>> prim_fns = |
      func->GetAttr<Map<GlobalVar, tir::PrimFunc>>("prim_funcs");
  CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler.";
1116 | |
  Optional<GlobalVar> prim_fn_var = func->GetAttr<GlobalVar>("prim_fn_var");
  CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler.";
1119 | |
1120 | Optional<Target> relay_target = func->GetAttr<Target>(tvm::attr::kTarget); |
  CHECK(relay_target) << "target must be set on Relay functions by the TECompiler.";
1122 | |
1123 | for (const auto& kv : prim_fns.value()) { |
1124 | auto prim_fn = Downcast<tir::PrimFunc>(kv.second); |
    CHECK(prim_fn.defined()) << "the primitive function must be defined";
1126 | |
1127 | Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment); |
1128 | |
1129 | // Workspace sizes |
1130 | Target prim_fn_target; |
1131 | if (prim_fn->attrs->dict.count(tvm::attr::kTarget)) { |
1132 | prim_fn_target = Downcast<Target>(prim_fn->attrs->dict[tvm::attr::kTarget]); |
1133 | } else { |
1134 | prim_fn_target = relay_target.value(); |
1135 | } |
1136 | |
1137 | workspace_sizes.Set(prim_fn_target, workspace_size); |
1138 | |
1139 | // Calculating size for I/O |
1140 | // TODO(mbs): See also the other three utils for calculating tensor bytesize. |
1141 | for (auto const& param : prim_fn->params) { |
1142 | bool not_a_buffer = prim_fn->buffer_map.count(param) == 0; |
1143 | if (not_a_buffer) { |
1144 | io_sizes.Set(prim_fn_target, 0); |
1145 | continue; |
1146 | } |
1147 | |
1148 | auto p_shape = prim_fn->buffer_map[param]->shape; |
1149 | int num_of_elements = 1; |
1150 | for (const auto& dim_index_expr : p_shape) { |
1151 | if (dim_index_expr->IsInstance<IntImmNode>()) { |
1152 | num_of_elements *= dim_index_expr.as<IntImmNode>()->value; |
1153 | } else { |
1154 | // If shape is dynamic, we cannot calculate workspace in compile time. |
1155 | num_of_elements = 0; |
1156 | } |
1157 | } |
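      // For example, a buffer of shape (2, 3, 4) with dtype float32 contributes
      // 24 * 4 = 96 bytes (assuming all dimensions are static).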
1158 | int element_size = prim_fn->buffer_map[param]->dtype.bytes(); |
1159 | io_sizes.Set(prim_fn_target, element_size * num_of_elements); |
1160 | } |
1161 | |
1162 | constant_sizes.Set(prim_fn_target, 0); |
1163 | tir_primfuncs.Set(prim_fn_target, prim_fn); |
1164 | if (func->IsInstance<FunctionNode>()) { |
1165 | relay_primfuncs.Set(prim_fn_target, Downcast<Function>(func)); |
1166 | } |
1167 | } |
1168 | |
1169 | backend::FunctionInfo fi = backend::FunctionInfo( |
1170 | std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes), |
1171 | std::move(tir_primfuncs), std::move(relay_primfuncs)); |
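  // For illustration (hypothetical values), a function lowered for a single LLVM target might
  // yield:
  //   FunctionInfo(workspace_sizes={llvm: 256}, io_sizes={llvm: 24}, constant_sizes={llvm: 0},
  //                tir_primfuncs={llvm: @fused_add}, relay_primfuncs={llvm: %f})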
1172 | |
1173 | VLOG(1) << "FunctionInfo: " << PrettyPrint(prim_fn_var.value()) << " = " << PrettyPrint(fi); |
1174 | |
1175 | // The primitive function name here corresponds to the string we will use to generate |
1176 | // this Relay function at the low level. |
1177 | function_metadata.Set(prim_fn_var.value()->name_hint, fi); |
1178 | } |
1179 | |
/*! \brief Main lowering driver. */
1181 | IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, |
1182 | CompilationConfig config) { |
1183 | TECompiler compiler(module, module_name); |
1184 | |
  // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
  // module as we go (including rewritten Functions, lowered primitives, and runtime modules
  // generated by external toolchains), and to use a pair of maps, from vars and from global vars
  // to global vars, to remember which functions have already been lowered.
1189 | |
1190 | // Lower all the callees in module: |
1191 | // - Functions tagged with "Compiler" are unchanged (checked by CreateFunctionPass) |
1192 | // - Functions tagged with "Primitive" are unchanged (checked by LowerTensorExprMutator) |
1193 | // - Called functions tagged with "Compiler" are copied into the compiler cache with a fresh |
1194 | // GlobalVar, and calls updated (sticking with regular Relay Call). |
  // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
  //   (using the call_lowered convention), for example as sketched below.
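  //
  // A hypothetical fragment (names invented for illustration): a call
  //   %0 = %p(%x, %y)                   // %p bound to a "Primitive" function
  // is rewritten to
  //   %0 = call_lowered(@tvmgen_default_fused_add, (%x, %y), metadata)
  // where @tvmgen_default_fused_add is the GlobalVar of the freshly lowered PrimFunc.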
1197 | IRModule updated_module = |
1198 | LowerTensorExpr(compiler, std::move(process_fn), std::move(config))(module); |
1199 | |
1200 | // The Functions tagged with "Compiler" are now residing in the cache ready to be |
1201 | // compiled by LowerExternalFunctions. However we still need a record of them in the |
1202 | // IRModule so that the various executors can see which function names need to be |
1203 | // retrieved. They may, however, have been renamed. |
1204 | compiler->AddExterns(updated_module); |
1205 | |
1206 | // Add the lowered functions. |
1207 | IRModule lowered_module = compiler->GetLoweredFunctions(); |
1208 | VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions" ; |
1209 | for (const auto& kv : lowered_module->functions) { |
1210 | if (updated_module->ContainGlobalVar(kv.first->name_hint)) { |
1211 | LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint |
1212 | << "'. Existing is:" << std::endl |
1213 | << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) << std::endl |
1214 | << "while new is:" << std::endl |
1215 | << PrettyPrint(kv.second); |
1216 | } |
1217 | updated_module->Add(kv.first, kv.second); |
1218 | } |
1219 | |
1220 | // Invoke external codegen for all Functions in the cache tagged with "Compiler", and |
1221 | // annotate the module with the resulting runtime modules. |
1222 | // TODO(mbs): runtime modules should be first class rather than attributes. |
1223 | Array<runtime::Module> external_mods = |
1224 | module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({}); |
1225 | Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions(); |
1226 | VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size() |
1227 | << " new external modules" ; |
1228 | for (const auto& mod : new_external_mods) { |
1229 | external_mods.push_back(mod); // copy-on-write. |
1230 | } |
1231 | |
  // Annotate the module with the C Device API context mapping. This is needed only until we have
  // Targets annotated for the C Device API.
  // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph annotated properly with
  // Targets.
  Map<GlobalVar, String> device_contexts =
      module->GetAttr<Map<GlobalVar, String>>("device_contexts", Map<GlobalVar, String>()).value();
1238 | Map<GlobalVar, String> new_device_contexts = compiler->GetDeviceContexts(); |
1239 | VLOG(1) << "capturing " << device_contexts.size() << " existing and " |
          << new_device_contexts.size() << " new device contexts for external functions";
1241 | for (const auto& kv : new_device_contexts) { |
1242 | ICHECK_EQ(device_contexts.count(kv.first), 0); |
1243 | device_contexts.Set(kv.first, kv.second); // copy-on-write. |
1244 | } |
1245 | |
1246 | updated_module = WithAttrs(updated_module, {{tvm::attr::kExternalMods, std::move(external_mods)}, |
1247 | {"device_contexts" , std::move(device_contexts)}}); |
1248 | |
1249 | if (backend::IsAutoSchedulerEnabled()) { |
1250 | // Capture all the 'operator weights', ie usage counts for each PrimFunc. |
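    // E.g. (hypothetical): if the PrimFunc @fused_multiply is called from three distinct sites,
    // op_weights["fused_multiply"] would be 3.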
1251 | Map<String, Integer> op_weights = |
        module->GetAttr<Map<String, Integer>>("op_weights", Map<String, Integer>()).value();
1253 | Map<String, Integer> new_op_weights = compiler->GetOpWeights(); |
1254 | VLOG(1) << "capturing " << op_weights.size() << " existing and " << new_op_weights.size() |
1255 | << " new operator weights for PrimFuncs" ; |
1256 | for (const auto& kv : new_op_weights) { |
1257 | ICHECK_EQ(op_weights.count(kv.first), 0); |
1258 | op_weights.Set(kv.first, kv.second); // copy-on-write. |
1259 | } |
    updated_module = WithAttr(updated_module, "op_weights", std::move(op_weights));
1261 | } |
1262 | |
1263 | return updated_module; |
1264 | } |
1265 | |
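/*!
 * \brief Partitions the \p tir::PrimFuncs in \p mod by their \p kTarget attribute, returning one
 * \p IRModule per \p Target (each initialized with the input module's attributes). Relay
 * \p Functions are skipped; any other function kind is a fatal error.
 *
 * A hypothetical example: a module holding @f1 (target llvm) and @f2 (target cuda) yields
 * {llvm: IRModule(@f1), cuda: IRModule(@f2)}.
 */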
1266 | Map<Target, IRModule> GetPerTargetModules(IRModule mod) { |
1267 | std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual> |
1268 | per_target_modules; |
1269 | for (const auto& kv : mod->functions) { |
1270 | const GlobalVar& var = kv.first; |
1271 | const BaseFunc& func = kv.second; |
1272 | if (func->IsInstance<tir::PrimFuncNode>()) { |
1273 | // Extract target |
1274 | Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget); |
      ICHECK(target) << "Target should be set at this point";
1276 | |
1277 | // Put the function in per_target_modules |
1278 | if (!per_target_modules.count(target.value())) { |
1279 | // Initialize the IRModule for this target with the attributes from the input IRModule |
1280 | IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs); |
1281 | // Add the function to the IRModule |
1282 | target_module->Add(var, func); |
1283 | per_target_modules[target.value()] = target_module; |
1284 | } else { |
1285 | // The IRModule for this target is initialized, so just add the function. |
1286 | IRModule target_module = per_target_modules.at(target.value()); |
1287 | target_module->Add(var, func); |
1288 | } |
1289 | } else if (!func->IsInstance<relay::FunctionNode>()) { |
      LOG(FATAL)
          << "Functions in the IRModule should be either relay::Function or tir::PrimFunc, but got "
          << func->GetTypeKey();
1293 | } |
1294 | } |
1295 | return per_target_modules; |
1296 | } |
1297 | |
Pass LowerTE(String module_name, CompilationConfig compilation_config, ProcessFn process_fn) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
                                                                            PassContext ctx) {
    return LowerTE(module, module_name, process_fn, compilation_config);
  };
1303 | |
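  // Pipeline sketch: run any target-specific RelayToTIR hooks first, then the LowerTE module
  // pass itself (which declares "InferType" as a required pass), re-infer types over the
  // rewritten module, and finally extract the constants embedded in the new PrimFuncs.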
1304 | return tvm::transform::Sequential( |
      {tvm::relay::transform::RelayToTIRTargetHook(compilation_config),
       tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(),
1307 | tvm::tir::transform::ExtractPrimFuncConstants()}); |
1308 | } |
1309 | |
TVM_REGISTER_GLOBAL("relay.tec.LowerTE")
1311 | .set_body_typed([](String module_name, CompilationConfig compilation_config) { |
1312 | return LowerTE(std::move(module_name), std::move(compilation_config)); |
1313 | }); |
1314 | |
1315 | } // namespace tec |
1316 | } // namespace relay |
1317 | } // namespace tvm |
1318 | |