1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/tir/builtin.h>
21#include <tvm/tir/stmt.h>
22#include <tvm/tir/transform.h>
23
24#include "../../arith/ir_visitor_with_analyzer.h"
25
26namespace tvm {
27namespace tir {
28
/*!
 * \brief Tell whether a storage scope string refers to the "global.vtcm" scope.
 *
 * The check is a substring match, so any scope that contains "global.vtcm"
 * anywhere (for example with an extra suffix) is accepted.
 *
 * \param scope The storage scope string to inspect. Taken by const reference
 *        because the string is only searched, never modified or stored.
 * \return true if "global.vtcm" occurs in \p scope, false otherwise
 *         (including for an empty string).
 */
inline bool IsVtcmStorage(const std::string& scope) {
  return scope.find("global.vtcm") != std::string::npos;
}
32
33class VtcmAllocator : public StmtExprMutator {
34 public:
35 using StmtExprMutator::VisitStmt_;
36 VtcmAllocator() {}
37
38 Stmt VisitStmt_(const AllocateNode* op) final {
39 std::string storage_scope = GetStorageScope(op->buffer_var);
40 if (IsVtcmStorage(storage_scope)) {
41 Stmt body = this->VisitStmt(op->body);
42 Array<PrimExpr> args;
43 args.push_back(StringImm(storage_scope));
44 args.push_back(IntImm(DataType::Int(64), op->extents.size()));
45 args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents));
46 return LetStmt(op->buffer_var,
47 Call(op->buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), body);
48 }
49 return StmtExprMutator::VisitStmt_(op);
50 }
51
52 protected:
53 std::string GetStorageScope(const Var& var) {
54 auto* ptr = var->type_annotation.as<PointerTypeNode>();
55 ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
56 return ptr->storage_scope;
57 }
58};
59
60PrimFunc LowerVtcmAlloc(PrimFunc func) {
61 auto fptr = func.CopyOnWrite();
62 fptr->body = VtcmAllocator()(std::move(fptr->body));
63 return func;
64}
65
66namespace transform {
67
68Pass LowerVtcmAlloc() {
69 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
70 return LowerVtcmAlloc(std::move(f));
71 };
72 return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {});
73}
74
75TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc);
76
77} // namespace transform
78
79} // namespace tir
80} // namespace tvm
81