/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/te_compiler.cc
 * \brief Manages the transition from Relay "Primitive" \p Functions to TIR \p PrimFuncs. Also
 * handles invocation of external codegen.
 *
 * \p LowerTEPass handles the following (as a monolithic blob of code):
 *
 * - Most importantly, any function with the "Primitive" attribute is first converted to TE by
 *   \p LowerToTECompute (see te_compiler_cache.cc) using each operator's 'compute' function.
 *   The TE is then 'scheduled' to TIR using the 'anchor' operator's 'schedule' function. Both
 *   of those functions come from the \p OpStrategy returned by the Python
 *   'relay.backend.lower_call' function (see te_compiler.py).
 *   The TIR is packed as a \p PrimFunc and introduced as a new global function. Calls to the
 *   original "Primitive" function are then rewritten to the form:
 *   \code
 *     call_lowered(@new_global, (... original args...), attributes)
 *   \endcode
 *
 * - The above "Primitive" function can appear:
 *    - As a global function
 *    - As a let-bound function
 *    - As an inline function, i.e. the 'op' of calls.
 *   In all three cases it is possible for the same "Primitive" function to be called multiple
 *   times, and that sharing must be respected, as the sketch below illustrates.
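 *   A minimal sketch of how sharing is preserved (hypothetical names): both call sites are
 *   rewritten to target the single lowered global:
 *   \code
 *     let %p = fn(...) { ... }  // "Primitive"
 *     ... %p(%a) ... %p(%b) ...
 *     ==>
 *     def @p_lowered(..., target=<target>) { <tir body> }
 *     ... call_lowered(@p_lowered, (%a), attributes) ...
 *     ... call_lowered(@p_lowered, (%b), attributes) ...
 *   \endcode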
 *
 * - "Primitive" functions must have a "global_symbol" attribute matching their desired or
 *   existing global name. Care is taken to ensure GlobalVars with the same name are shared.
 *
 * - It is possible for multiple structurally equal "Primitive" functions to appear in the same
 *   \p IRModule. Only one implementation should be generated, and all calls should share that
 *   implementation.
 *
 * - When later converting to DPS (see memory_alloc.cc) we must handle functions whose result
 *   tensor shapes depend at runtime on the input tensor shapes and/or data.
 *    - That dependency is first described in TE form (see \p MakeShapeFunc in
 *      te_compiler_cache.cc), then scheduled to yield a 'dynamic shape function' \p PrimFunc.
 *      This relies on each operator's "FShapeFunc" and "TShapeDataDependent" attributes.
 *      Since shapes are rank-1 tensors everything can be reflected back down into the regular
 *      TE/TIR forms.
 *    - Then the call_lowered attributes must record everything about the dynamic shape function
 *      later needed by memory_alloc.cc. We call this 'cross linking' the call with the shape
 *      function; a sketch of the recorded metadata follows.
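 *      For example, the metadata keys set by \p MakeLoweredCall below look roughly like this
 *      (names of the globals are hypothetical):
 *      \code
 *        call_lowered(@f_lowered, (%x),
 *                     metadata={"prim_shape_fn_var": @f_shape_fn,
 *                               "prim_shape_fn_num_inputs": 1,
 *                               "prim_shape_fn_num_outputs": 1, ...})
 *      \endcode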
 *
 * - Two external codegen mechanisms are supported, both triggered by "Primitive" functions which
 *   also have a "Compiler" attribute bound to $compiler:
 *    - Function-at-a-time (old style): The primitive function is passed to the function
 *      registered as 'relay.ext.$compiler'. The function returns a runtime::Module which
 *      should return true for \p ImplementsFunction for the function's global name. That
 *      module is added to the IRModule's "external_mods" attributes (see the registration
 *      sketch below).
 *    - IRModule-at-a-time (new style): The \p RelayToTIRTargetHook sub-pass looks for
 *      $compiler names which correspond to TargetKind names with a \p RelayToTIR attribute.
 *      The \p Pass bound to that attribute is run, and each such 'custom' pass can do what
 *      it likes, including replacing Functions with PrimFuncs, or adding new runtime::Modules
 *      to the IRModule's "external_mods" attribute.
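 *   For example, an old-style codegen function could be registered as below. This is a minimal
 *   sketch; 'my_compiler' and MyCodegen are hypothetical names:
 *   \code
 *     runtime::Module MyCodegen(const Function& func) { ... }
 *     TVM_REGISTER_GLOBAL("relay.ext.my_compiler").set_body_typed(MyCodegen);
 *   \endcode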
 *
 * - Calls to functions added by external codegen are also rewritten to call_lowered form, and
 *   may also require cross-linking to dynamic shape functions. However, since the functions
 *   are/will be implemented by a runtime::Module all the Relay type information is no longer
 *   available. So the Relay definitions for these "Primitive" "Compiler" functions are retained
 *   in the \p IRModule, but marked with the "Extern" attribute to signal the function is now
 *   just for carrying metadata.
 *
 * - Some operators are handled specially:
 *    - 'reshape', since it is a no-op on the underlying tensor buffer, which is handled by
 *      special-case tests in many passes.
 *    - 'debug', since it is intercepted differently depending on the runtime.
 *
 * TODO(mbs): This desperately deserves a refactor to separate all these concerns. See Relax.
 */

#include "./te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/ir/name_supply.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>

#include <functional>
#include <limits>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler_cache.h"
#include "./utils.h"

namespace tvm {
namespace relay {
// TODO(@jroesch, @csullivan): declare directly elsewhere
backend::StaticMemoryPlan GraphPlanMemory(const Function& func);

namespace tec {

using namespace tvm::relay::transform;

TVM_REGISTER_OBJECT_TYPE(TECompilerNode);

class TECompilerImpl : public TECompilerNode {
 public:
  explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String> opt_mod_name)
      : global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))),
        constant_name_supply_(NameSupply("")) {
    // Make sure we don't collide with any existing globals in the module.
    if (opt_mod) {
      for (const auto& kv : opt_mod.value()->functions) {
        global_var_supply_->name_supply_->ReserveName(kv.first->name_hint, false);
      }
    }
  }

  // Lower the function.
  CachedFunc Lower(const CCacheKey& key) {
    return LowerInternal(key, global_var_supply_)->cached_func;
  }

  // TODO(gigiblender): Only to be called by the global TE compiler.
  // Remove this when the global TE compiler is removed.
  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
    global_var_supply_->name_supply_->prefix_ = mod_name;
    return LowerInternal(key, global_var_supply_)->cached_func;
  }

  // For now, build one module per function.
  PackedFunc JIT(const CCacheKey& key) final {
    CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
    if (value->packed_func != nullptr) {
      return value->packed_func;
    }
    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
    return value->packed_func;
  }

  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
    return LowerShapeFuncInternal(key)->cached_func;
  }

  IRModule GetLoweredFunctions() {
    VLOG(1) << "GetLoweredFunctions";
    IRModule mod;
    // Extract lowered functions from the cache
    for (const auto& it : cache_) {
      auto source_func = it.first;
      auto lowered_func = it.second;

      IRModule lowered_mod = lowered_func->cached_func->funcs;

      // Annotate functions with their target and put them in the return module
      for (const auto& kv : lowered_mod->functions) {
        const GlobalVar& var = kv.first;
        const BaseFunc& func = kv.second;

        // Only add functions that are not external functions
        if (!func->GetAttr<String>(attr::kCompiler).defined()) {
          ICHECK(func->IsInstance<tir::PrimFuncNode>())
              << "Expected all functions that are not external to be PrimFuncs, but found:"
              << std::endl
              << PrettyPrint(func);
          const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
          mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
        }
      }
    }

    // Extract lowered dynamic shape functions from the shape cache
    for (const auto& it : shape_func_cache_) {
      auto source_func = it.first;
      auto lowered_func = it.second;
      auto target = source_func->target;
      IRModule lowered_mod = lowered_func->cached_func->funcs;

      // Annotate functions with their target and put them in the return module
      for (const auto& kv : lowered_mod->functions) {
        const GlobalVar& var = kv.first;
        const BaseFunc& func = kv.second;
        const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
        mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
      }
    }

    return mod;
  }

  void AddExterns(IRModule module) {
    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
    std::vector<GlobalVar> to_be_deleted;
    for (const auto& kv : module->functions) {
      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
        to_be_deleted.push_back(kv.first);
      }
    }
    for (const auto& global_var : to_be_deleted) {
      VLOG(1) << "Removing definition for externally codegen-ed '" << global_var->name_hint << "'";
      module->Remove(global_var);
    }
    // However, we still need a Relay definition to go with those now-external functions, so
    // retrieve them from the cache and mark them with the "Extern" attribute.
    for (const auto& kv1 : cache_) {
      auto src_func = kv1.first->source_func;
      ICHECK(src_func.defined());
      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
            // Abandon the existing function annotations.

            // Unfortunately, Optional<DictAttrs>() is indistinguishable from
            // NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes we
            // need to pass in DictAttrs(Map<String, ObjectRef>()), which is a DictAttrs
            // containing no attributes.
            Function function =
                WithFields(GetRef<Function>(function_node), function_node->params,
                           function_node->body, function_node->ret_type, function_node->type_params,
                           /* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
            // Mark the function as 'extern'.
            function = WithAttr(std::move(function), attr::kExtern, Integer(1));
            module->Add(kv2.first, function);
          }
        }
      }
    }
  }

  Array<tvm::runtime::Module> LowerExternalFunctions() {
    Array<tvm::runtime::Module> ret;
    std::vector<CCacheKey> cached_ext_funcs;

    for (const auto& it : cache_) {
      auto src_func = it.first->source_func;
      ICHECK(src_func.defined());
      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
      if (opt_compiler.defined()) {
        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
                                          << PrettyPrint(src_func);
        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
                << opt_symbol_name.value() << "' and function:" << std::endl
                << PrettyPrint(src_func);
        cached_ext_funcs.push_back(it.first);

        std::string ext_name = "relay.ext." + opt_compiler.value();
        auto pf = tvm::runtime::Registry::Get(ext_name);
        ICHECK(pf) << "Failed to find the external codegen tool for " << ext_name;
        // No need to keep the compiler attribute at this point; functions have been
        // extracted for the specific codegen.
        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
        VLOG_CONTEXT << opt_compiler.value();
        With<Target> with_target(it.first->target);
        runtime::Module ext_mod = (*pf)(src_func);
        if (ext_mod.defined()) {
          // TODO(mbs): Can this be an ICHECK?
          if (!ext_mod->ImplementsFunction(opt_symbol_name.value())) {
            VLOG(1) << "Note that the external codegen for '" << opt_compiler.value()
                    << "' returned a runtime module which does not appear to implement '"
                    << opt_symbol_name.value() << "'";
          }
          ret.push_back(ext_mod);
        } else {
          // It is valid for the external codegen function to return null:
          //  - Unit tests can use it.
          //  - The true compilation may have already been handled by a RelayToTIR custom pass
          //    on the Target's kind. The original Relay functions will be left in place so
          //    that we can capture that their function names are now externally defined.
          VLOG(1) << "Note that no external runtime module was generated by external codegen '"
                  << opt_compiler.value() << "'";
        }
      }
    }

    // No need to cache external functions as we collected them all to create
    // external runtime modules.
    for (const auto& it : cached_ext_funcs) {
      cache_.erase(it);
    }
    return ret;
  }

  Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; }
  void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) {
    device_contexts_ = device_contexts;
  }

  void Clear() final { cache_.clear(); }

  // List all items in the cache.
  Array<ObjectRef> ListItems() {
    std::lock_guard<std::mutex> lock(mutex_);
    Array<ObjectRef> items;
    for (auto& kv : cache_) {
      items.push_back(kv.first);
      items.push_back(kv.second);
    }
    return items;
  }

  /*!
   * \brief Get the cache key of the function that is being lowered currently
   * \return the cache key
   */
  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }

 private:
  // Implements the lowering of a function.
  CCacheValue LowerInternal(const CCacheKey& key, GlobalVarSupply global_var_supply) {
    VLOG(1) << "lowering:" << std::endl
            << PrettyPrint(key->source_func) << std::endl
            << "for target:" << std::endl
            << key->target->ToDebugString();
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      VLOG(1) << "already lowered to name:" << std::endl
              << PrettyPrint(it->second->cached_func->prim_fn_var);
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 1;
      cache_[key] = value;
    }
    cur_ccache_key_ = key;

    Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
    if (opt_compiler.defined()) {
      // Don't compile now since we don't have anywhere to put the resulting runtime module.
      // Instead place the original definition in the cache and wait for LowerExternalFunctions.
      IRModule ir_module({}, {});
      Optional<String> opt_global_symbol =
          key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
      ICHECK(opt_global_symbol.defined()) << "External function has not yet been assigned a name.";
      // Note that the source_func may already be bound to a global function in the module
      // we are compiling, in which case we should not attempt to make its name unique w.r.t.
      // the module's globals. Furthermore, the external codegen tool must bind the compiled
      // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName
      // here.
      auto global_var = global_var_supply->UniqueGlobalFor(opt_global_symbol.value(), false);
      global_var->checked_type_ = key->source_func->checked_type();
      ir_module->Add(global_var, key->source_func);
      value->cached_func = CachedFunc(key->target, global_var, {}, {}, te::Schedule{nullptr},
                                      tir::PrimFunc{nullptr}, {}, ir_module);
      // Collect these here as they are removed in LowerExternalFunctions().
      device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value());
      VLOG(1) << "preparing to use external codegen '" << opt_compiler.value()
              << "' with name:" << std::endl
              << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
              << "and definitions:" << std::endl
              << PrettyPrint(value->cached_func->funcs);
      return value;
    }

    // Enforce use of the target.
    With<Target> target_scope(key->target);

    ICHECK(!value->cached_func.defined());
    value->cached_func =
        PrimFuncFor(key->source_func, key->target, global_var_supply, constant_name_supply_);

    if (value->cached_func->prim_func.defined()) {
      VLOG(1) << "Lowering PrimFunc";
      IRModule lowered = tvm::LowerPrimFunc(value->cached_func->prim_func.value(),
                                            value->cached_func->prim_fn_var->name_hint, false);
      ICHECK_EQ(lowered->functions.size(), 1);
      for (const auto& kv : lowered->functions) {
        value->cached_func->funcs->Add(value->cached_func->prim_fn_var, kv.second);
      }
    } else {
      // NOTE: array will copy on write.
      Array<te::Tensor> all_args = Array<te::Tensor>(value->cached_func->inputs);
      for (te::Tensor arg : value->cached_func->outputs) {
        all_args.push_back(arg);
      }
      Array<runtime::NDArray> all_consts;
      for (auto kv : value->cached_func->constant_tensors) {
        all_args.push_back(kv.second);
        all_consts.push_back(kv.first->data);
      }
      // Lower the function.
      std::unordered_map<te::Tensor, tir::Buffer> binds;

      // If we have memory scopes, we need to create the tir::Buffers knowing this info.
      size_t i = 0;  // index into the corresponding inputs tensor array
      for (Var param : key->source_func->params) {
        if (!param->virtual_device()->memory_scope.empty()) {
          for (const auto& ttype : FlattenTupleType(param->checked_type())) {
            te::Tensor x_ref = value->cached_func->inputs[i];
            // Verify that the function parameters and prepared tensors are in sync.
            ICHECK(ttype->dtype == x_ref->dtype && ttype->shape.size() == x_ref->shape.size())
                << "function parameter does not correspond to prepared tensor";
            binds[x_ref] =
                tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
                                               false, param->virtual_device()->memory_scope);
          }
        }
        i++;
      }
      if (key->virtual_device != VirtualDevice::FullyUnconstrained() &&
          !key->virtual_device->memory_scope.empty() &&
          key->virtual_device->memory_scope != "global") {
        ICHECK(value->cached_func->outputs.size() == 1)
            << "Expect only one output for a defined memory scope";
        te::Tensor x_ref = value->cached_func->outputs[0];
        binds[x_ref] =
            tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0,
                                           false, key->virtual_device->memory_scope);
      }
      auto func_name = value->cached_func->prim_fn_var->name_hint;
      VLOG(1) << "scheduling";
      IRModule scheduled_module = tvm::LowerSchedule(value->cached_func->schedule, all_args,
                                                     func_name, binds, global_var_supply);
      scheduled_module->Update(tir::transform::BindParams(all_consts)(scheduled_module));
      for (const auto& kv : scheduled_module->functions) {
        GlobalVar global_var = kv.first;
        auto func = kv.second;
        // Propagate the structural hash of the relay function to the tir
        // function so associations can be made between the two.
        Optional<String> hash = key->source_func->attrs.GetAttr<String>("hash");
        if (hash) {
          func = WithAttrs(Downcast<tir::PrimFunc>(func), {{String("hash"), hash.value()}});
        }
        value->cached_func->funcs->Add(global_var, func);
      }
      ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var)
                 .as<tir::PrimFuncNode>());
    }
    VLOG(1) << "lowered to name:" << std::endl
            << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
            << "with definitions:" << std::endl
            << PrettyPrint(value->cached_func->funcs);

    return value;
  }

  // Implements the lowering of a dynamic shape function.
  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
    VLOG(1) << "lowering dynamic shape function for:" << std::endl
            << PrettyPrint(key->source_func) << std::endl
            << "for target:" << std::endl
            << key->target->ToDebugString();
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = shape_func_cache_.find(key);
    if (it != shape_func_cache_.end()) {
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 0;
      shape_func_cache_[key] = value;
    }
    // Enforce use of the target.
    With<Target> target_scope(key->target);

    ICHECK(!value->cached_func.defined());

    using tvm::transform::PassContext;
    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
    value->cached_func = ShapeFuncFor(key->source_func, key->target, global_var_supply_);

    ICHECK(
        value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>());

    VLOG(1) << "lowered to name:" << std::endl
            << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
            << "with definitions:" << std::endl
            << PrettyPrint(value->cached_func->funcs);
    return value;
  }

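  // Returns, for each lowered PrimFunc name, how many times it was used (its 'operator weight').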
  Map<String, Integer> GetOpWeights() const {
    Map<String, Integer> weights;
    for (const auto& kv : cache_) {
      auto value = kv.second;
      auto name = value->cached_func->prim_fn_var->name_hint;
      weights.Set(name, value->use_count);
    }
    return weights;
  }

  // TODO(mbs): Hold the output module here and reduce the cache_ to just be from
  // Function to GlobalVar.

  /*! \brief compiler cache lock */
  std::mutex mutex_;
  /*! \brief internal GlobalVarSupply to get unique GlobalVars */
  GlobalVarSupply global_var_supply_;
  /*! \brief A NameSupply object for assigning unique names to constants, across different
   * invocations of PrimFuncFor. */
  NameSupply constant_name_supply_;
  /*! \brief internal compiler cache */
  std::unordered_map<CCacheKey, CCacheValue> cache_;
  /*! \brief internal compiler cache for shape funcs */
  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
  /*! \brief the cache key of the function that is being lowered currently */
  CCacheKey cur_ccache_key_;
  /*! \brief Map of GlobalVar to C Device API context names */
  Map<GlobalVar, String> device_contexts_;
};

TECompiler::TECompiler(Optional<IRModule> opt_mod, Optional<String> mod_name) {
  auto object = make_object<TECompilerImpl>(std::move(opt_mod), std::move(mod_name));
  data_ = object;
}

/*! \brief The global TE compiler */
// TODO(mbs): To be terminated with extreme prejudice.
TECompiler& TECompiler::Global() {
  static TECompiler* inst =
      new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>(), Optional<String>()));
  return *inst;
}
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String);

TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
  return TECompiler::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
    .set_body_typed([](Function source_func, Target target) {
      return CCacheKey(source_func, target);
    });

TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
    .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
      return LoweredOutput(outputs, impl);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) {
  self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower")
    .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) {
      return self->Lower(key, mod_name);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT")
    .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) {
  TECompilerImpl* ptr = dynamic_cast<TECompilerImpl*>(self.operator->());
  ICHECK(ptr != nullptr);
  return ptr->ListItems();
});

using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;

/*!
 * \brief Rewrites call expressions to Relay Functions marked as "primitive"
 * to calls to the corresponding TIR PrimFunc for the appropriate target.
 *
 * \code
 * %0 = fn(...) { prim_op(...) }     OR     let %p = fn(...) { prim_op(...) }
 * ... %0(...) ...                          ... %p(...) ...
 * ==>
 * def @q(..., target=<target>) { <tir body> }
 * ... @q(...) ...
 * \endcode
 *
 * Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run.
 *
 * FuseOps is needed to identify and lift all prim op calls:
 * \code
 * ... prim_op(...) ...
 * ==>
 * %0 = fn(...) { prim_op(...) }
 * ... %0(...) ...
 * \endcode
 *
 * ToANormalForm is needed so we only need to consider vars and function literals as the call
 * target.
 *
 * EtaExpand is needed to ensure all calls to primitives are direct:
 * \code
 * let %p1 = fn(...) { prim_op1(...) }
 * let %p2 = fn(...) { prim_op2(...) }
 * let %p = if (...) { %p1 } else { %p2 }
 * ... %p(...) ...
 * ==>
 * let %p1 = fn(...) { prim_op1(...) }
 * let %p2 = fn(...) { prim_op2(...) }
 * let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } }
 * ... %p(...) ...
 * \endcode
 */
class LowerTensorExprMutator : public DeviceAwareExprMutator {
 public:
  LowerTensorExprMutator(IRModule module, ProcessFn process_fn, CompilationConfig config,
                         TECompiler compiler)
      : DeviceAwareExprMutator(module),
        module_(std::move(module)),
        process_fn_(std::move(process_fn)),
        config_(std::move(config)),
        compiler_(std::move(compiler)),
        debug_op_(Op::Get("debug")) {}

  /*!
   * \brief Returns the primitive function associated with \p expr, or nullptr if none.
   */
  BaseFunc ResolveToPrimitive(const Expr& expr) {
    // NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order
    // expressions.
    if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
      if (!module_->ContainGlobalVar(global_var_node->name_hint)) {
        // TODO(mbs): extern function cleanup
        // Assume the function is extern and thus no longer in the IRModule.
        return {};
      } else {
        BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(global_var_node));
        return ResolveToPrimitive(base_func);
      }
    } else if (const auto* prim_func_node = expr.as<tir::PrimFuncNode>()) {
      return GetRef<tir::PrimFunc>(prim_func_node);
    } else if (const auto* var_node = expr.as<VarNode>()) {
      auto itr = primitive_functions_.find(var_node);
      if (itr == primitive_functions_.end()) {
        // Not bound to a primitive function.
        return {};
      } else {
        return itr->second;
      }
    } else if (const auto* function_node = expr.as<FunctionNode>()) {
      if (function_node->HasNonzeroAttr(attr::kExtern)) {
        // We have a regular call to an 'extern' function. The call itself needs to be rewritten
        // to call_lowered form, and any required dynamic shape functions generated and
        // cross-linked.
        return GetRef<Function>(function_node);
      } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
        if (const auto* call_node = function_node->body.as<CallNode>()) {
          if (call_node->op == debug_op_) {
            // Debug 'primitives' are not lowered.
            return {};
          }
        }
        // We have a regular call to a 'primitive' function (possibly with a 'Compiler' attribute).
        // We need to lower and rewrite the call.
        return GetRef<Function>(function_node);
      } else {
        // Not marked as primitive during partitioning or TVM fusion.
        return {};
      }
    } else {
      return {};
    }
  }

  /*!
   * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and \p span with all the
   * required attributes filled in. Generally \p prim_fn_var will correspond to the lowered or
   * externally codegen-ed form of \p original_function, where \p lowered_functions binds all
   * the required lowered functions.
   *
   * The call's attributes will capture:
   *  - Any attributes on the original_function.
   *  - All the lowered functions.
   *    TODO(mbs): Pretty sure that's no longer needed.
   *  - Details needed to cross-link the call to its dynamic shape function, if any.
   */
  Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var,
                       Array<Expr> args, Span span, const Target& target,
                       const Map<GlobalVar, BaseFunc>& lowered_functions) {
    auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);

    // Add some metadata on top of the *original function* and invoke the callback so it can
    // be captured.
    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
    Map<GlobalVar, tir::PrimFunc> prim_fns;
    Array<GlobalVar> all_prim_fn_vars;
    for (const auto& kv : lowered_functions) {
      if (opt_compiler) {
        // We expect the original function to have just the "Extern" attribute signaling the
        // function (will be) compiled externally.
        ICHECK(kv.second.as<FunctionNode>())
            << PrettyPrint(kv.first) << " must be bound to an (external) Function";
      } else {
        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive,
        // and the rest are in support of that via tir::Calls.
        ICHECK(kv.second.as<tir::PrimFuncNode>())
            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
        prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
        all_prim_fn_vars.push_back(kv.first);
      }
    }

    // Alas, WithAttr cannot work with base classes.
    if (const auto* prim_func_node = original_function.as<tir::PrimFuncNode>()) {
      auto func_with_metadata = GetRef<tir::PrimFunc>(prim_func_node);
      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target);
      this->process_fn_(func_with_metadata);
    } else {
      const auto* function_node = original_function.as<FunctionNode>();
      ICHECK(function_node);
      auto func_with_metadata = GetRef<Function>(function_node);
      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target);
      this->process_fn_(func_with_metadata);
    }

    // Now prepare the attributes of the call_lowered.
    CallLoweredAttrs call_lowered_attrs;

    // TODO(mbs): "reshape" cleanup.
    if (!opt_compiler && original_function->HasNonzeroAttr(attr::kReshapeOnly)) {
      call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
    }

    call_lowered_attrs.metadata.Set("relay_attrs", original_function->attrs);
    call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);

    if (const auto* function_node = original_function.as<FunctionNode>()) {
      if (IsDynamic(function_node->ret_type)) {
        // Create a dynamic shape function to calculate the expected shape of the results of
        // the lowered function.
        // Shape function keys use the original function as their 'function', but the generic 'cpu'
        // target as the target since all shape functions run on the host cpu irrespective of where
        // the primitive runs.
        CCacheKey shape_key(GetRef<Function>(function_node), config_->host_virtual_device->target);
        CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);

        // Capture the shape function's global var and parameters 'states' in call
        // annotations so the calling convention can be recovered.
        // TODO(mbs): Shape cleanup.
        call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
        call_lowered_attrs.metadata.Set("prim_shape_fn_states",
                                        lowered_shape_func->shape_func_param_states);
        call_lowered_attrs.metadata.Set(
            "prim_shape_fn_num_inputs",
            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
        call_lowered_attrs.metadata.Set(
            "prim_shape_fn_num_outputs",
            Integer(static_cast<int>(lowered_shape_func->outputs.size())));
        Array<GlobalVar> all_prim_shape_fn_vars;
        for (const auto& kv : lowered_shape_func->funcs->functions) {
          CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
          all_prim_shape_fn_vars.push_back(kv.first);
        }
        call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
      }
    }

    return CallLowered(prim_fn_var, std::move(args), std::move(call_lowered_attrs),
                       std::move(span));
  }

  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    Var new_var = Downcast<Var>(Mutate(var));
    Expr new_value = Mutate(value);
    BaseFunc prim_func = ResolveToPrimitive(new_value);

    if (prim_func.defined()) {
      // Remember the let var is bound (possibly indirectly) to a primitive function.
      primitive_functions_.emplace(var.get(), prim_func);
    }
    return {new_var, new_value};
  }

  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
    BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
    if (prim_func.defined()) {
      // Leaving the let var's scope.
      primitive_functions_.erase(pre_let_node->var.get());
      // Drop the let node.
      return post_let_node->body;
    }
    return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
  }

  Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
    if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
        function_node->HasNonzeroAttr(attr::kExtern)) {
      // Nothing to lower inside primitive/external functions.
      return GetRef<Function>(function_node);
    } else {
      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
    }
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
    // We can see six forms of calls:
    //  1. A 'normal' Relay call to a Function with the "Primitive" attribute and no "Compiler"
    //     attribute. We will need to lower that to a global PrimFunc and rewrite the call to:
    //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
    //     If needed, the call needs to be cross-linked with any dynamic shape functions.
    //     (However, some primitives are special and handled separately.)
    //  2. A 'normal' Relay call to a Function with the "Primitive" and "Compiler" attributes. We
    //     will need to invoke the "relay.ext.<compiler>" function to yield a runtime module, and
    //     rewrite the call to the same form as above. Dynamic shape function cross-linking may
    //     also be needed.
    //  3. A 'normal' Relay call to a Function with the "Extern" attribute. This function has
    //     already been compiled by an external codegen and a definition for it exists in some
    //     runtime module. Again, we rewrite to call_lowered form, and cross-link with a dynamic
    //     shape function if needed.
    //  4. A 'normal' Relay call to a PrimFunc which has already been supplied via a global
    //     definition. We rewrite those to use the call_lowered form, but otherwise nothing else
    //     needs to be done.
    //  5. A 'call_lowered' call from an earlier invocation of this pass or otherwise deliberately
    //     inserted. It has all the required attributes, and any associated dynamic shape function
    //     has been generated and cross-linked. These calls are not changed.
    //  6. A 'normal' Relay call to a Relay Function without any special attribute. These
    //     calls are not changed.
    //
    // Note that ResolveToPrimitive will yield non-null only for cases 1-4.

    // Prepare the arguments and op.
    Array<Expr> new_args;
    for (const auto& arg : call_node->args) {
      new_args.push_back(VisitExpr(arg));
    }
    Expr new_op = VisitExpr(call_node->op);

    // Look for (possibly indirect) calls to primitives.
    BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
    if (!primitive_func.defined()) {
      // Cases 5 and 6: Leave as an ordinary call.
      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
        process_fn_(GetRef<Function>(function_node));
      }
      return WithFields(GetRef<Call>(call_node), std::move(new_op), std::move(new_args));
    }

    // Special case for case 1: device_copies are left as calls to primitive operators
    // so that each backend can handle them directly.
    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone.
    if (const auto* function_node = primitive_func.as<FunctionNode>()) {
      DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body);
      if (device_copy_props.body.defined()) {
        ICHECK_EQ(new_args.size(), 1);
        return DeviceCopy(new_args[0], device_copy_props.src_virtual_device,
                          device_copy_props.dst_virtual_device);
      }
    }

    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";

    // Case 4: If the function has already been lowered we just need to update the call.
    if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
      // The function should already be Target annotated by this point,
      // but the TE Compiler metadata is still needed for the callback.
      // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks
      Optional<Target> opt_target = primitive_func->GetAttr<Target>(tvm::attr::kTarget);
      ICHECK(opt_target.defined());
      auto prim_fn_var = Downcast<GlobalVar>(call_node->op);
      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
      Map<GlobalVar, BaseFunc> prim_fns = {{prim_fn_var, prim_func}};
      return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span,
                             opt_target.value(), prim_fns);
    }

    // Determine the target for lowering or external codegen.
    Target target;
    Optional<String> opt_compiler = primitive_func->GetAttr<String>(attr::kCompiler);
    if (opt_compiler.defined()) {
      // This function needs to be compiled with external codegen.
      Optional<Target> opt_target = config_->FindPrimitiveTargetForKind(opt_compiler.value());
      if (opt_target.defined()) {
        // The target is what's supplied by the compilation config for the kind matching the
        // "Compiler" name.
        target = opt_target.value();
      } else {
        // Legacy fallback.
        target = Target("ext_dev");
      }
    } else {
      // The target corresponding to the call_node expression's annotation.
      VirtualDevice virtual_device = GetVirtualDevice(GetRef<Call>(call_node));
      ICHECK(!virtual_device->IsFullyUnconstrained()) << PrettyPrint(GetRef<Call>(call_node));
      target = virtual_device->target;
      ICHECK(target.defined());
    }

    if (primitive_func->HasNonzeroAttr(attr::kExtern)) {
      // Case 3: The function has already been compiled.
      GlobalVar prim_fn_var = Downcast<GlobalVar>(call_node->op);
      return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span,
                             target, /*lowered_functions=*/{});
    } else {
      // Cases 1 and 2: Lower the primitive function for the desired target, possibly using
      // external codegen.
      CCacheKey key(Downcast<Function>(primitive_func), target,
                    GetVirtualDevice(GetRef<Call>(call_node)));
      CachedFunc cfunc = compiler_->Lower(key);
      ICHECK(cfunc.defined());
      return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
                             call_node->span, target, cfunc->funcs->functions);
    }
  }

  IRModule module_;
  ProcessFn process_fn_;
  /*! \brief All available targets. */
  CompilationConfig config_;
  // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which
  // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered
  // primitive function as we go. Those vars will be bound in the target device-type specific
  // module we'll ultimately emit for each required device-type. Note that a primitive may be
  // lowered for multiple device types, each of which will be assigned a fresh var.
  std::unordered_map<const VarNode*, BaseFunc> primitive_functions_;
  TECompiler compiler_;
  // Cache ops that need to be frequently used later to reduce lookup overhead.
  const Op& debug_op_;
};

Pass LowerTensorExpr(TECompiler compiler, ProcessFn process_fn, CompilationConfig config) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function func, IRModule module, PassContext ctx) {
        LowerTensorExprMutator lower_te(module, process_fn, config, compiler);
        return Downcast<Function>(lower_te.Mutate(func));
      };
  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const CompilationConfig& config,
                                              Map<Expr, backend::StorageInfo> storage_info_map) {
  Function func = Downcast<Function>(mod->Lookup("main"));

  VLOG_CONTEXT << "UpdateMainWorkspaceSize";
  VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func);

  // This is a Map<device, Map<storage_id, size>>
  // TODO(mbs): Collapsing VirtualDevices to just device type.
  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
      sid_workspace;
  // This is a Map<device, size_of_inputs_and_outputs>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
  // This is a Map<device, size_of_constants>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;

  // Initialize the mapping from all storage identifiers to workspace sizes,
  // the amount of device io, and the device constants.
  for (const auto& kv : storage_info_map) {
    const backend::StorageInfo& storage_info = kv.second;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
    CHECK_EQ(storage_ids.size(), virtual_devices.size());
    for (uint32_t i = 0; i < virtual_devices.size(); i++) {
      DLDeviceType device_type = virtual_devices[i]->device_type();
      sid_workspace[device_type][storage_ids[i]] = 0;
      device_io[device_type] = 0;
      device_consts[device_type] = 0;
    }
  }

  // Iterate over the storage map to compute all the tensor sizes in the program.
  // There are three cases in this code:
  //
  //  1. We compute the sizes of all inline constants.
  //
  //  2. We compute the size of any bound variable, as these are the input and output
  //     sizes of the program.
  //
  //  3. For all other expressions we check which storage identifier they have been
  //     assigned and compute the maximal size of the storage, as tensors can share
  //     storage with other tensors which are the same size or larger. In this final
  //     case there is only one allocation for all tensors which share storage, which
  //     will be the maximal size of all tensors which were assigned to it; see the
  //     worked example below.
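  //
  // For example (hypothetical numbers): if tensors of 256 and 1024 bytes are assigned the same
  // storage id on the same device, that id contributes max(256, 1024) = 1024 bytes to the
  // device's workspace, not 256 + 1024 = 1280 bytes.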
  for (const auto& kv : storage_info_map) {
    const Expr& expr = kv.first;
    const backend::StorageInfo& storage_info = kv.second;
    int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type());
    VLOG(1) << "expression:" << std::endl
            << PrettyPrint(expr) << std::endl
            << "of type:" << std::endl
            << PrettyPrint(expr->checked_type()) << std::endl
            << "has size " << size_bytes << " and storage info:" << std::endl
            << storage_info;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;

    if (expr->IsInstance<ConstantNode>()) {
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        ICHECK_EQ(device_consts.count(device_type), 1);
        device_consts[device_type] += size_bytes;
      }
    } else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) {
      CHECK(size_bytes == 0 || virtual_devices.size() >= 1) << "must be at least one device";
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        device_io[device_type] += size_bytes;
      }
    } else {
      // TODO(@electriclilies): This code is never being called, which means sid_workspace is not
      // updated. This means that the storage info is probably not being created correctly, or is
      // not equivalent to what was here previously.
      for (uint32_t i = 0; i < storage_ids.size(); i++) {
        // Here we record the largest size of the tensors
        // that share the same storage id, because a storage_id will
        // be shared between multiple tensors that are not live simultaneously.
        DLDeviceType device_type = virtual_devices[i]->device_type();
        if (size_bytes > sid_workspace[device_type][storage_ids[i]]) {
          sid_workspace[device_type][storage_ids[i]] = size_bytes;
        }
      }
    }
  }

  // This is a Map<device, workspace_size>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
  // Once we know the sizes of sids, we need to accumulate per device.
  for (const auto& dev_sid_size : sid_workspace) {
    auto dev = dev_sid_size.first;
    device_workspace[dev] = 0;
    for (const auto& sid_size : dev_sid_size.second) {
      device_workspace[dev] += sid_size.second;
    }
  }

  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;

  // Initialize all target workspaces to zero.
  for (const auto& target : config->primitive_targets) {
    workspace_sizes.Set(target, 0);
  }

  for (const auto& dev_and_size : device_workspace) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    workspace_sizes.Set(target, dev_and_size.second);
    relay_primfuncs.Set(target, func);
  }
  for (const auto& dev_and_size : device_io) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    io_sizes.Set(target, dev_and_size.second);
  }

  for (const auto& dev_and_size : device_consts) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    ICHECK_EQ(constant_sizes.count(target), 0);
    constant_sizes.Set(target, dev_and_size.second);
  }

  backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes),
                                  std::move(constant_sizes), std::move(tir_primfuncs),
                                  std::move(relay_primfuncs));
  VLOG(1) << "func_info: " << func_info;
  return std::move(func_info);
}

/*!
 * \brief Creates the function metadata for an input function (i.e. calculates buffer
 * input/output sizes).
 * \param func The function to calculate function metadata for.
 * \param function_metadata The map that stores all the function metadata.
 */
void UpdateFunctionMetadata(BaseFunc func,
                            Map<String, backend::FunctionInfo>& function_metadata,  // NOLINT(*)
                            Integer workspace_byte_alignment) {
  VLOG_CONTEXT << "UpdateFunctionMetadata";
  VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func);
  // Originally UpdateFunctionMetadata took in a CCachedFunc and looped through all the funcs
  // stored there. Now the goal is to take only one func, because process_fn should be controlling
  // the iteration. However, to do the workspace calculations we need the primfuncs, so process_fn
  // needs to either access the cached funcs or be directly passed primfuncs. This is bad, and
  // ideally we don't want process_fn to look at primfuncs. There is also the question of what
  // the function metadata is and how it is used; if we can do something else to replicate its
  // behavior that might be good (i.e. annotating functions or something).
  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;

  Optional<Map<GlobalVar, tir::PrimFunc>> prim_fns =
      func->GetAttr<Map<GlobalVar, tir::PrimFunc>>("prim_funcs");
  CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler.";

  Optional<GlobalVar> prim_fn_var = func->GetAttr<GlobalVar>("prim_fn_var");
  CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler.";

  Optional<Target> relay_target = func->GetAttr<Target>(tvm::attr::kTarget);
  CHECK(relay_target) << "target must be set on Relay functions by the TECompiler.";

  for (const auto& kv : prim_fns.value()) {
    auto prim_fn = Downcast<tir::PrimFunc>(kv.second);
    CHECK(prim_fn.defined()) << "the primitive function must be defined";

    Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment);

    // Workspace sizes
    Target prim_fn_target;
    if (prim_fn->attrs->dict.count(tvm::attr::kTarget)) {
      prim_fn_target = Downcast<Target>(prim_fn->attrs->dict[tvm::attr::kTarget]);
    } else {
      prim_fn_target = relay_target.value();
    }

    workspace_sizes.Set(prim_fn_target, workspace_size);

    // Calculate sizes for I/O.
    // TODO(mbs): See also the other three utils for calculating tensor bytesize.
    for (const auto& param : prim_fn->params) {
      bool not_a_buffer = prim_fn->buffer_map.count(param) == 0;
      if (not_a_buffer) {
        io_sizes.Set(prim_fn_target, 0);
        continue;
      }

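      // The I/O size of a statically-shaped buffer is its element count times the element
      // size; e.g. (hypothetical numbers) a (2, 3) float32 buffer contributes 6 * 4 = 24 bytes.
      // Any dynamic dimension forces the size to 0 since it cannot be known at compile time.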
      auto p_shape = prim_fn->buffer_map[param]->shape;
      int num_of_elements = 1;
      for (const auto& dim_index_expr : p_shape) {
        if (dim_index_expr->IsInstance<IntImmNode>()) {
          num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
        } else {
          // If the shape is dynamic, we cannot calculate the size at compile time.
          num_of_elements = 0;
        }
      }
      int element_size = prim_fn->buffer_map[param]->dtype.bytes();
      io_sizes.Set(prim_fn_target, element_size * num_of_elements);
    }

    constant_sizes.Set(prim_fn_target, 0);
    tir_primfuncs.Set(prim_fn_target, prim_fn);
    if (func->IsInstance<FunctionNode>()) {
      relay_primfuncs.Set(prim_fn_target, Downcast<Function>(func));
    }
  }

  backend::FunctionInfo fi = backend::FunctionInfo(
      std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes),
      std::move(tir_primfuncs), std::move(relay_primfuncs));

  VLOG(1) << "FunctionInfo: " << PrettyPrint(prim_fn_var.value()) << " = " << PrettyPrint(fi);

  // The primitive function name here corresponds to the string we will use to generate
  // this Relay function at the low level.
  function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

/*! \brief Main lowering driver. */
IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
                 CompilationConfig config) {
  TECompiler compiler(module, module_name);

  // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
  // module as we go (including rewritten Functions, lowered primitives, and runtime modules
  // generated by external toolchains), and use a pair of maps over vars and global vars
  // to global vars to remember which functions have already been lowered.

  // Lower all the callees in module:
  //  - Functions tagged with "Compiler" are unchanged (checked by CreateFunctionPass)
  //  - Functions tagged with "Primitive" are unchanged (checked by LowerTensorExprMutator)
  //  - Called functions tagged with "Compiler" are copied into the compiler cache with a fresh
  //    GlobalVar, and calls updated (sticking with regular Relay Call).
  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
  //    (using the call_lowered convention).
  IRModule updated_module =
      LowerTensorExpr(compiler, std::move(process_fn), std::move(config))(module);

  // The Functions tagged with "Compiler" are now residing in the cache, ready to be
  // compiled by LowerExternalFunctions. However we still need a record of them in the
  // IRModule so that the various executors can see which function names need to be
  // retrieved. They may, however, have been renamed.
  compiler->AddExterns(updated_module);

  // Add the lowered functions.
  IRModule lowered_module = compiler->GetLoweredFunctions();
  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions";
  for (const auto& kv : lowered_module->functions) {
    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
                 << "'. Existing is:" << std::endl
                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) << std::endl
                 << "while new is:" << std::endl
                 << PrettyPrint(kv.second);
    }
    updated_module->Add(kv.first, kv.second);
  }

  // Invoke external codegen for all Functions in the cache tagged with "Compiler", and
  // annotate the module with the resulting runtime modules.
  // TODO(mbs): runtime modules should be first class rather than attributes.
  Array<runtime::Module> external_mods =
      module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
  Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions();
  VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size()
          << " new external modules";
  for (const auto& mod : new_external_mods) {
    external_mods.push_back(mod);  // copy-on-write.
  }

  // Annotate the module with the C Device API context mapping (this is until we have Targets
  // annotated for the C Device API).
  // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph annotated properly with
  // Targets.
  Map<GlobalVar, String> device_contexts =
      module->GetAttr<Map<GlobalVar, String>>("device_contexts", Map<GlobalVar, String>()).value();
  Map<GlobalVar, String> new_device_contexts = compiler->GetDeviceContexts();
  VLOG(1) << "capturing " << device_contexts.size() << " existing and "
          << new_device_contexts.size() << " new device contexts for external functions";
  for (const auto& kv : new_device_contexts) {
    ICHECK_EQ(device_contexts.count(kv.first), 0);
    device_contexts.Set(kv.first, kv.second);  // copy-on-write.
  }

  updated_module = WithAttrs(updated_module, {{tvm::attr::kExternalMods, std::move(external_mods)},
                                              {"device_contexts", std::move(device_contexts)}});

  if (backend::IsAutoSchedulerEnabled()) {
    // Capture all the 'operator weights', i.e. usage counts for each PrimFunc.
    Map<String, Integer> op_weights =
        module->GetAttr<Map<String, Integer>>("op_weights", Map<String, Integer>()).value();
    Map<String, Integer> new_op_weights = compiler->GetOpWeights();
    VLOG(1) << "capturing " << op_weights.size() << " existing and " << new_op_weights.size()
            << " new operator weights for PrimFuncs";
    for (const auto& kv : new_op_weights) {
      ICHECK_EQ(op_weights.count(kv.first), 0);
      op_weights.Set(kv.first, kv.second);  // copy-on-write.
    }
    updated_module = WithAttr(updated_module, "op_weights", std::move(op_weights));
  }

  return updated_module;
}

Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
  std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
      per_target_modules;
  for (const auto& kv : mod->functions) {
    const GlobalVar& var = kv.first;
    const BaseFunc& func = kv.second;
    if (func->IsInstance<tir::PrimFuncNode>()) {
      // Extract the target.
      Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
      ICHECK(target) << "Target should be set at this point";

      // Put the function in per_target_modules.
      if (!per_target_modules.count(target.value())) {
        // Initialize the IRModule for this target with the attributes from the input IRModule.
        IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs);
        // Add the function to the IRModule.
        target_module->Add(var, func);
        per_target_modules[target.value()] = target_module;
      } else {
        // The IRModule for this target is initialized, so just add the function.
        IRModule target_module = per_target_modules.at(target.value());
        target_module->Add(var, func);
      }
    } else if (!func->IsInstance<relay::FunctionNode>()) {
      LOG(FATAL)
          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
          << func->GetTypeKey();
    }
  }
  return per_target_modules;
}

Pass LowerTE(String module_name, CompilationConfig compilation_config, ProcessFn process_fn) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
                                                                            PassContext ctx) {
    return LowerTE(module, module_name, process_fn, compilation_config);
  };

  return tvm::transform::Sequential(
      {tvm::relay::transform::RelayToTIRTargetHook(compilation_config),
       tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(),
       tvm::tir::transform::ExtractPrimFuncConstants()});
}
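
// A minimal usage sketch of the pass above (assumes a populated IRModule `mod` and a
// CompilationConfig `config`; both names, and the no-op process_fn, are hypothetical):
//   IRModule lowered = LowerTE("mod", config, [](BaseFunc) {})(mod);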

TVM_REGISTER_GLOBAL("relay.tec.LowerTE")
    .set_body_typed([](String module_name, CompilationConfig compilation_config) {
      return LowerTE(std::move(module_name), std::move(compilation_config));
    });

}  // namespace tec
}  // namespace relay
}  // namespace tvm