1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/device_constraint_utils.h
22 * \brief Utilities for extracting and applying device-related constraints to \p PrimFunc
23 * parameters.
24 *
25 * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope
26 * information from \p PrimFuncs and convert them back into \p VirtualDevice form w.r.t. the
27 * original Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and
28 * conversion to destination-passing style aka DPS).
29 *
30 * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters.
31 * However that's still in EXPERIMENTAL form.
32 *
33 * We may extend these utilities to also gather/apply layout information should we add that to
34 * \p VirtualDevice.
35 */
36
37#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
38#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
39
40#include <tvm/target/virtual_device.h>
41#include <tvm/tir/function.h>
42
43namespace tvm {
44namespace tir {
45
46/*!
47 * A Relay Function with type:
48 * \code
49 * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...])
50 *     ^            ^             ^                ^            ^
51 *     a            b             c                d            e
52 * \endcode
53 * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 arguments a..e.
54 * \code
55 * primfn(a: handle, b: handle, c: handle, d: handle, e: handle) {
56 * buffers = { ... }
57 * buffer_map = { ... }
58 * ...
59 * }
60 * \endcode
61 *
62 * Each such PrimFunc argument will be mapped to a \p Buffer whose underlying \p data \p Var
63 * has a \p PointerType.
64 *
65 * The PrimFunc may have additional non-pointer arguments, eg for:
66 * - scalar inputs and tensor dimensions
67 * - device contexts
68 * Those should be ignored here since they have no counterpart in the Relay Function.
69 *
70 * We'll need helpers to map on-the-fly between the Relay and TIR view of functions.
71 */
72
73/*!
74 * \brief Returns the \p VirtualDevices capturing the memory (aka storage) scope constraints for all
75 * the arguments and result of \p prim_func. However, the result will be w.r.t. the \p prim_func's
76 * representation as a Relay \p Function of \p relay_func_type before lowering and conversion to
77 * DPS.
78 */
79Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
80 const FuncType& relay_func_type);
81
82/*!
83 * \brief Returns \p prim_func written to capture the memory (aka storage) scope constraints
84 * for each of the \p prim_func's parameters given by \p arg_and_result_virtual_devices. However,
85 * \p arg_and_result_virtual_devices should be w.r.t. the \p prim_func's representation as a Relay
86 * \p Function of \p relay_func_type before lowering and conversion to DPS.
87 *
88 * CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all
89 * new memory scopes.
90 */
91PrimFunc ApplyPrimFuncArgAndResultConstraints(
92 const PrimFunc& prim_func, const FuncType& relay_func_type,
93 const Array<VirtualDevice>& arg_and_result_virtual_devices);
94
95} // namespace tir
96} // namespace tvm
97
98#endif // TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
99