1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/calculate_workspace.cc
 * \brief Calculate the peak workspace and constant memory required by PrimFuncs.
23 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/utils.h>

#include <cstdint>
30
31namespace tvm {
32namespace tir {
33
34template <typename T>
35class WorkspaceCalculator : public StmtExprVisitor {
36 public:
37 WorkspaceCalculator() = default;
38 size_t operator()(const PrimFunc& func);
39 size_t byte_alignment = tvm::runtime::kDefaultWorkspaceAlignment;
40
41 private:
42 void VisitStmt_(const T* op) override;
43 size_t GetByteAlignedSize(Integer non_aligned_size);
44 size_t CalculateExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents);
45 size_t current_size = 0;
46 size_t max_size = 0;
47};
48
49template <typename T>
50size_t WorkspaceCalculator<T>::operator()(const PrimFunc& func) {
51 this->VisitStmt(func->body);
52 return this->max_size;
53}
54
55template <typename T>
56size_t WorkspaceCalculator<T>::GetByteAlignedSize(Integer non_aligned_size) {
57 return non_aligned_size.defined()
58 ? ((non_aligned_size.IntValue() + byte_alignment - 1) / byte_alignment) *
59 byte_alignment
60 : 0;
61}
62
63template <typename T>
64void WorkspaceCalculator<T>::VisitStmt_(const T* op) {
65 auto size = GetByteAlignedSize(usmp::CalculateExtentsSize(op));
66 current_size += size;
67 if (current_size > max_size) {
68 max_size = current_size;
69 }
70 StmtExprVisitor::VisitStmt(op->body);
71 current_size -= size;
72}
73
74size_t CalculateConstantBytes(const PrimFunc& func, const Integer& byte_alignment) {
75 WorkspaceCalculator<AllocateConstNode> wc;
76 wc.byte_alignment = byte_alignment->value;
77 return wc(func);
78}
79
80size_t CalculateWorkspaceBytes(const PrimFunc& func, const Integer& byte_alignment) {
81 WorkspaceCalculator<AllocateNode> wc;
82 wc.byte_alignment = byte_alignment->value;
83 return wc(func);
84}
85
86TVM_REGISTER_GLOBAL("tir.analysis.calculate_constant_bytes")
87 .set_body_typed([](PrimFunc func, Integer constant_byte_alignment) {
88 return static_cast<int>(CalculateConstantBytes(func, constant_byte_alignment));
89 });
90
91TVM_REGISTER_GLOBAL("tir.analysis.calculate_workspace_bytes")
92 .set_body_typed([](PrimFunc func, Integer workspace_byte_alignment) {
93 return static_cast<int>(CalculateWorkspaceBytes(func, workspace_byte_alignment));
94 });
95
96} // namespace tir
97} // namespace tvm
98