1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API.
22 */
23#include <tvm/runtime/device_api.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/target/target.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/buffer.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt_functor.h>
31#include <tvm/tir/transform.h>
32
33#include <unordered_set>
34#include <utility>
35#include <vector>
36
37#include "arg_binder.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
43PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
44 auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
45 ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute";
46
47 auto target = func->GetAttr<Target>(tvm::attr::kTarget);
48 ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute";
49
50 auto* func_ptr = func.CopyOnWrite();
51
52 // Setup device context
53 int target_device_type = target.value()->GetTargetDeviceType();
54 Integer device_type(target_device_type);
55 Integer device_id(0);
56 PrimExpr node = StringImm("default");
57 const Stmt nop = Evaluate(0);
58 std::vector<Stmt> device_init;
59
60 // Collect variables and buffers to map between
61 Array<Var> args;
62
63 for (const Var& param : func->params) {
64 // Ideally all func params should have Buffers defined in the buffer_map
65 // We should look to insert buffer_maps for all PrimFuncs that are returned
66 // to the core compiler.
67 if (func->buffer_map.find(param) != func->buffer_map.end()) {
68 args.push_back(func->buffer_map[param]->data);
69 } else {
70 args.push_back(param);
71 }
72 }
73
74 if (func->buffer_map.size()) {
75 device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop));
76 device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop));
77 }
78
79 func_ptr->body = MergeNest(device_init, func_ptr->body);
80 func_ptr->params = args;
81 func_ptr->ret_type = PrimType(DataType::Int(32));
82 func_ptr->buffer_map = Map<Var, Buffer>();
83
84 // return the function.
85 return std::move(func);
86}
87
88namespace transform {
89
90Pass MakeUnpackedAPI() {
91 auto pass_func = [](IRModule m, PassContext ctx) {
92 IRModuleNode* mptr = m.CopyOnWrite();
93 std::vector<std::pair<GlobalVar, PrimFunc>> updates;
94
95 for (const auto& kv : mptr->functions) {
96 if (auto* n = kv.second.as<PrimFuncNode>()) {
97 PrimFunc func = GetRef<PrimFunc>(n);
98 if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
99 CallingConv::kDefault) {
100 auto updated_func = MakeUnpackedAPI(std::move(func));
101 updates.push_back({kv.first, updated_func});
102 }
103 }
104 }
105
106 for (const auto& pair : updates) {
107 mptr->AddUnchecked(pair.first, pair.second);
108 }
109 return m;
110 };
111
112 return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {});
113}
114
115TVM_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI);
116} // namespace transform
117} // namespace tir
118} // namespace tvm
119