/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/memory_alloc.cc
 * \brief A pass for manifesting explicit memory allocations.
 */

#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/target/target.h>

#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_set>
#include <vector>

#include "../backend/te_compiler.h"
#include "../backend/te_compiler_cache.h"
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../op/memory/memory.h"
#include "../op/vm/vm.h"
#include "./device_aware_visitors.h"
#include "./let_list.h"
#include "./pass_utils.h"
#include "./pattern_utils.h"

using namespace tvm::runtime;

namespace tvm {
namespace relay {

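/*!
 * \brief Rewrites calls to lowered primitive functions into destination-passing style with
 * explicit memory allocation, accumulating all new bindings in the enclosing let scope (the
 * input is therefore expected to be in A-normal form).
 *
 * A statically-shaped primitive call is rewritten (roughly) as:
 * \code
 *   let %storage_0 = memory.alloc_storage(size, alignment, ...);
 *   let %tensor_0 = memory.alloc_tensor(%storage_0, offset, shape, ...);
 *   vm.invoke_tvm_op(@prim, (inputs), (%tensor_0));
 *   ... %tensor_0 ...
 * \endcode
 * Dynamically-shaped results additionally require invoking the primitive's companion shape
 * function to size the allocations, and reshape-only primitives become vm.reshape_tensor calls.
 */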
class DialectRewriter : public transform::DeviceAwareExprMutator {
 public:
  DialectRewriter(IRModule mod, VirtualDevice host_virtual_device)
      : transform::DeviceAwareExprMutator(mod),
        mod_(std::move(mod)),
        host_virtual_device_(std::move(host_virtual_device)) {}

  Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }

 private:
  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const TupleNode* tuple_node) final {
    LetList& scope = scopes_.back();
    Array<Expr> new_fields;
    new_fields.reserve(tuple_node->fields.size());

    for (auto field : tuple_node->fields) {
      auto new_field = Mutate(field);
      if (const auto* op = new_field.as<ConstantNode>()) {
        DataType dtype(op->data->dtype);
        bool is_simple_const = (dtype == DataType::Int(32) || dtype == DataType::Int(64) ||
                                dtype == DataType::Float(32) || dtype == DataType::Float(64) ||
                                dtype == DataType::Bool());
        if (!op->is_scalar() || !is_simple_const) {
          VirtualDevice virtual_device = GetVirtualDevice(field);
          ICHECK(!virtual_device->IsFullyUnconstrained());
          Var const_var("const", Type(nullptr));
          new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, virtual_device));
        }
      }
      new_fields.push_back(new_field);
    }
    return WithFields(GetRef<Tuple>(tuple_node), new_fields);
  }

  void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }

  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    Expr new_value = Mutate(value);
    VirtualDevice virtual_device = GetVirtualDevice(value);
    ICHECK(!virtual_device->IsFullyUnconstrained());
    scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, virtual_device));
    // Since we always need a let block on which to bind sub-expressions, the rewritten bindings
    // are tracked in the current scope. But return the rewritten binding anyway.
    return {var, new_value};
  }

  Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
    // The current scope has captured all the rewritten let-bindings, as well as any additional
    // bindings we needed to add. All we need is the rewritten body.
    Expr new_body = post_let_node->body;
    while (const auto* inner_let_node = new_body.as<LetNode>()) {
      new_body = inner_let_node->body;
    }
    auto ret = scopes_.back().Get(new_body);
    scopes_.pop_back();
    return ret;
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

    if (device_copy_props.body.defined()) {
      // Special case: device_copy calls remain in their original (and functional) form.
      // TODO(mbs): device_copy cleanup.
      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    if (!call_lowered_props.lowered_func.defined()) {
      // This is a call to a user-defined Relay function, which will be handled directly by
      // the VM and does not need conversion to DPS.
      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    Call call = GetRef<Call>(call_node);
    VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call);

    VirtualDevice virtual_device = GetVirtualDevice(call);
    ICHECK(!virtual_device->IsFullyUnconstrained());
    ICHECK(!scopes_.empty())
        << "Calls outside of a let block are not supported. Did you forget to transform "
        << "with ToANormalForm or to set opt_level >= 1 in the pass context?";
    LetList& scope = scopes_.back();

    std::vector<Expr> new_args;
    for (const auto& arg : call_lowered_props.arguments) {
      new_args.push_back(Mutate(arg));
    }
    Tuple ins(new_args);
    Type ret_type = call_node->checked_type_;
    std::vector<TensorType> out_types = FlattenTupleType(ret_type);

    // Handle reshape.
    // Original:
    //   reshape(body, <ReshapeAttrs>)
    //   dyn.reshape(body, shape, <ReshapeAttrs>)
    // After FuseOps:
    //   let %f = fn(x, primitive=1, relay.reshape_only=1) { reshape(x, <ReshapeAttrs>) }
    //   %f(body)
    // After LowerTEPass:
    //   call_lowered(@xxx_reshape, (body), <LoweredCallAttrs with
    //                                       relay_attrs|->dict[relay.reshape_only] = 1>)
    //   -OR-
    //   call_lowered(@xxx_dyn_reshape, (body, shape), <LoweredCallAttrs with same>)
    //   where @xxx_reshape is bound as a PrimFunc.
    //   (The name is irrelevant, only the relay.reshape_only attribute matters.)
    // After this pass:
    //   vm.reshape_tensor(body, shape, <TIRCallAttrs>)
    if (IsReshapeOnly(call_lowered_props)) {
      return EmitReshapeTensor(&scope, ins, call_lowered_props.attrs, ret_type);
    }

    // At this point we could be calling a PrimFunc or an 'external' and already compiled
    // primitive. The calling conventions are identical.

    // Handle 'dynamic' calls, i.e. to PrimFuncs whose result shape must first be computed
    // by a companion shape function.
    if (IsDynamic(ret_type)) {
      return DynamicInvoke(&scope, call_lowered_props.lowered_func, ins, call_lowered_props.attrs,
                           out_types, ret_type, virtual_device);
    }

    // Handle ordinary primitive calls.
    Array<Expr> outputs;
    for (size_t i = 0; i < out_types.size(); ++i) {
      outputs.push_back(
          MakeStaticAllocation(&scope, out_types[i], virtual_device, std::to_string(i)));
    }
    Tuple outs(outputs);
    Expr invoke =
        InvokeTVMOp(call_lowered_props.lowered_func, ins, outs,
                    Downcast<DictAttrs>(call_lowered_props.attrs.metadata.at("relay_attrs")));
    scope.Push(MaybeOnDeviceFixed(invoke, virtual_device));
    return ToTupleType(ret_type, std::vector<Expr>(outputs.begin(), outputs.end()));
  }

  /*!
   * \brief Returns the Relay Constant representing the 1d tensor with \p value.
   *
   * CAUTION: Make sure the constant ends up on the correct device.
   */
  inline Constant MakeConstant(const std::vector<int64_t>& value) {
    return MakeConstantTensor(DataType::Int(64), {static_cast<int64_t>(value.size())}, value);
  }

  /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */
  inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
                          Array<IndexExpr> assert_shape) {
    Expr offset =
        MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_virtual_device_);
    return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype,
                                   assert_shape);
  }

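  /*!
   * \brief Returns, as an Int(64) constant, the byte alignment to use for storage of \p dtype:
   * the element size in bytes, rounded up to a minimum of 64.
   */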
  Expr ComputeAlignment(const DataType& dtype) const {
    int64_t align = dtype.bits() / 8 * dtype.lanes();
    if (align < 64) {
      align = 64;
    }
    return MakeConstantScalar(DataType::Int(64), align);
  }

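  /*!
   * \brief Returns a Relay expression computing the number of bytes needed for a tensor of
   * \p type with the (possibly dynamic) \p shape, i.e. prod(shape) * ceil(bits * lanes / 8).
   */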
  Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const {
    auto dtype = DataType(type->dtype);
    Expr els = Prod(shape, Array<Integer>(nullptr), false, false);
    Expr num = MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
    Expr add = Add(num, MakeConstantScalar(DataType::Int(64), 7));
    Expr div = MakeConstantScalar(DataType::Int(64), 8);
    Expr ret = Multiply(els, Divide(add, div));
    return ret;
  }

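  /*!
   * \brief Returns, as an Int(64) constant, the number of bytes needed for a tensor of
   * statically-shaped \p type: prod(shape) * ceil(bits * lanes / 8).
   */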
  Expr ComputeStorage(const TensorType& type) {
    int64_t size = 1;
    for (auto it : type->shape) {
      auto val = it.as<IntImmNode>();
      CHECK(val);
      size *= val->value;
    }
    size *= (type->dtype.bits() * type->dtype.lanes() + 7) / 8;
    return MakeConstantScalar(DataType::Int(64), size);
  }

  // Allocate a tensor with a statically known shape.
  Var MakeStaticAllocation(LetList* scope, const TensorType& type,
                           const VirtualDevice& virtual_device, String name_hint) {
    std::vector<int64_t> int_shape;
    for (auto it : type->shape) {
      const auto* imm = it.as<IntImmNode>();
      CHECK(imm) << "expect static int shape";
      int_shape.push_back(imm->value);
    }
    Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_virtual_device_);
    Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_virtual_device_);
    // Alignment is directly captured in the instruction rather than calculated, so we
    // don't want to wrap it with an "on_device".
    Expr alignment = ComputeAlignment(type->dtype);
    // Run type inference later to get the correct type.
    Var var("storage_" + name_hint, Type(nullptr));
    Expr value = AllocStorage(size, alignment, virtual_device, type->dtype);
    auto sto = scope->Push(var, MaybeOnDeviceFixed(value, virtual_device));

    // TODO(@jroesch): There is a bug with typing based on the constant shape.
    auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape);
    Var tensor_var("tensor_" + name_hint, Type(nullptr));
    return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, virtual_device));
  }

  /*!
   * \brief Appends to \p scope the computation necessary to call the shape function given
   * in \p attrs, and binds the resulting output shapes into \p scope. The shapes are for a
   * call to a primitive with \p ins arguments. Some combination of the data and/or shapes
   * of \p ins will be needed by the shape function.
   */
  Array<Expr> EmitShapeFunc(LetList* scope, const Tuple& ins, const CallLoweredAttrs& attrs) {
    ICHECK(attrs.metadata.count("prim_shape_fn_states"));
    Array<Integer> input_states =
        Downcast<Array<Integer>>(attrs.metadata.at("prim_shape_fn_states"));
    ICHECK(attrs.metadata.count("prim_shape_fn_var"));
    auto prim_fn_var = Downcast<GlobalVar>(attrs.metadata.at("prim_shape_fn_var"));

    const auto* func_type_node = prim_fn_var->checked_type().as<FuncTypeNode>();
    ICHECK(func_type_node);

    // Establish the arguments to the shape function.
    Array<Expr> shape_func_ins;
    int input_pos = 0;
    ICHECK_EQ(ins->fields.size(), input_states.size());
    for (size_t i = 0; i < ins->fields.size(); ++i) {
      const Expr& arg = ins->fields[i];
      Type ty;
      if (const auto* vn = arg.as<VarNode>()) {
        ty = vn->type_annotation;
      } else {
        ty = arg->checked_type();
      }
      int64_t state = input_states[i]->value;
      if (state == tec::kNeedInputShape) {
        // Pass the argument's shape(s).
        std::vector<Expr> exprs = FromTupleType(ty, arg);
        for (size_t j = 0; j < exprs.size(); ++j) {
          Expr sh_of = Mutate(ShapeOf(exprs[j]));
          Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
          shape_func_ins.push_back(
              scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_virtual_device_)));
          input_pos++;
        }
      } else if (state == tec::kNeedInputData) {
        auto new_arg = Mutate(arg);  // already accounts for device
        VirtualDevice arg_virtual_device = GetVirtualDevice(arg);
        ICHECK(!arg_virtual_device->IsFullyUnconstrained());
        // The dynamic shape function expects its data on the host/CPU, so insert a device_copy
        // if necessary. (We'll need to fuse & lower these copies in the same way we fuse & lower
        // other operators we insert for, e.g., dynamic tensor size calculation.)
        new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_virtual_device),
                                  arg_virtual_device, host_virtual_device_);
        Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
        shape_func_ins.push_back(
            scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_virtual_device_)));
        input_pos++;
      } else {
        // TODO(@jroesch): handle kNeedBoth
        LOG(FATAL) << "unsupported shape function input state";
      }
    }
    ICHECK_EQ(shape_func_ins.size(), func_type_node->arg_types.size());

    // Establish the result shapes.
    const auto* res_tuple_node = func_type_node->ret_type.as<TupleTypeNode>();
    ICHECK(res_tuple_node);

    Array<Expr> out_shapes;
    for (size_t i = 0; i < res_tuple_node->fields.size(); ++i) {
      const auto* tensor_type_node = res_tuple_node->fields[i].as<TensorTypeNode>();
      ICHECK(tensor_type_node);
      // Put the shape func on the host. This also ensures that everything between
      // shape_of and shape_func is similarly on the host.
      Var alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node),
                                       host_virtual_device_, "out_shape_" + std::to_string(i));
      out_shapes.push_back(alloc);
    }

    // Represent the call in DPS form.
    auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes),
                                  Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
    Var shape_func_var("shape_func", Type(nullptr));
    scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_virtual_device_));
    return out_shapes;
  }

  // Generate the code for invoking the TVM primitive \p func whose results have dynamic shapes.
  Expr DynamicInvoke(LetList* scope, const Expr& func, const Tuple& ins,
                     const CallLoweredAttrs& attrs, const std::vector<TensorType>& out_types,
                     const Type& ret_type, const VirtualDevice& virtual_device) {
    Array<Expr> out_shapes = EmitShapeFunc(scope, ins, attrs);
    std::vector<Var> storages;
    CHECK_EQ(out_shapes.size(), out_types.size());
    for (size_t i = 0; i < out_shapes.size(); ++i) {
      auto out_shape = out_shapes[i];
      auto out_type = out_types[i];
      auto size =
          MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_virtual_device_);
      // Alignment is directly captured in the instruction so don't wrap in "on_device".
      auto alignment = ComputeAlignment(out_type->dtype);
      Var sto_var("storage_" + std::to_string(i), Type(nullptr));
      auto val = AllocStorage(size, alignment, virtual_device, out_type->dtype);
      storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, virtual_device)));
    }

    Array<Expr> outs;
    for (size_t i = 0; i < storages.size(); ++i) {
      auto out_shape = out_shapes[i];
      auto out_type = out_types[i];
      auto storage = storages[i];
      auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
      Var out_var("out_" + std::to_string(i), Type(nullptr));
      outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, virtual_device)));
    }

    Tuple tuple_outs(outs);
    auto call =
        InvokeTVMOp(func, ins, tuple_outs, Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
    scope->Push(MaybeOnDeviceFixed(call, virtual_device));
    return ToTupleType(ret_type,
                       std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
  }

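  /*!
   * \brief Replaces a lowered reshape-only call with a vm.reshape_tensor call. For dynamic
   * reshapes the target shape is first computed by the primitive's shape function; otherwise
   * it is emitted as a constant derived from \p ret_type.
   */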
  Expr EmitReshapeTensor(LetList* scope, const Tuple& ins, const CallLoweredAttrs& attrs,
                         const Type& ret_type) {
    ICHECK_GE(ins->fields.size(), 1);  // static reshape
    ICHECK_LE(ins->fields.size(), 2);  // dynamic reshape, second arg is shape
    TensorType ret_ty = Downcast<TensorType>(ret_type);
    Expr shape_expr;
    if (IsDynamic(ret_type)) {
      // Even though the desired output shape has been passed as the second argument to
      // the dyn.reshape primitive, we'll still call that primitive's shape function. Go figure.
      Array<Expr> out_shapes = EmitShapeFunc(scope, ins, attrs);
      ICHECK_EQ(out_shapes.size(), 1);
      shape_expr = out_shapes[0];
    } else {
      std::vector<int64_t> shape;
      for (const auto& it : ret_ty->shape) {
        const auto* imm = it.as<IntImmNode>();
        CHECK(imm) << "expect static int shape";
        shape.push_back(imm->value);
      }
      shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_virtual_device_);
    }
    return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape);
  }

 private:
  const Op& device_copy_op_ = Op::Get("device_copy");
  runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
  IRModule mod_;
  VirtualDevice host_virtual_device_;

  std::vector<LetList> scopes_;
};

namespace transform {

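/*! \brief Imports the "core.rly" prelude into the module before allocations are manifested. */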
Pass ManifestAllocImportStorage() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext pass_cnxt) {
    mod.CopyOnWrite();
    mod->ImportFromStd("core.rly");
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "ManifestAllocImportStorage",
                                          /*required=*/{});
}

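/*! \brief Rewrites every function in the module with DialectRewriter, placing shape and size
 * computations on \p host_virtual_device. */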
Pass ManifestAllocImpl(VirtualDevice host_virtual_device) {
  auto pass_func = [host_virtual_device](Function func, IRModule mod, PassContext ctxt) {
    return DialectRewriter(mod, host_virtual_device).Rewrite(func);
  };
  return CreateFunctionPass(pass_func, 0, "ManifestAllocImpl", {});
}

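/*!
 * \brief The overall ManifestAlloc pass: import the prelude, re-infer types, manifest explicit
 * allocations, then re-infer types again so the newly introduced memory operations are typed.
 */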
Pass ManifestAlloc(VirtualDevice cpu_virtual_device) {
  std::vector<Pass> passes = {ManifestAllocImportStorage(), InferType(),
                              ManifestAllocImpl(std::move(cpu_virtual_device)), InferType()};
  return Sequential(passes, "ManifestAlloc");
}

TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc);
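
// Note: the registration above exposes the pass over TVM's FFI, so it can also be retrieved and
// applied from frontend code via the "relay.transform.ManifestAlloc" global; within C++ it can be
// invoked directly via transform::ManifestAlloc.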

}  // namespace transform

}  // namespace relay
}  // namespace tvm