1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
17#define GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
18
19#include "glow/Base/Tensor.h"
20#include "glow/Support/Error.h"
21
22#include <functional>
23
24namespace glow {
25
26class Tensor;
27
28#define GLOW_DRT_DEFAULT_CB [](Error e) { ERR_TO_VOID(std::move(e)); }
29
30class DeviceTensorTransferManager {
31public:
32 virtual ~DeviceTensorTransferManager() {}
33 /// Copies the contents of \p tensor from the host to the \p location address
34 /// on this device. Updates the tensor residency info.
35 virtual void transferToDevice(
36 Tensor &tensor, void *locationContext = nullptr,
37 std::function<void(Error)> resultCB = GLOW_DRT_DEFAULT_CB) = 0;
38
39 /// Copies the device buffer associated with \p tensor to the host.
40 /// The tensor must be resident on this device. If \p release is true, frees
41 /// the device memory. Updates the tensor residency info.
42 virtual void transferFromDevice(
43 Tensor &tensor, bool release = true,
44 std::function<void(Error)> resultCB = GLOW_DRT_DEFAULT_CB) = 0;
45
46 /// Releases the device buffer associated with \p tensor.
47 virtual bool releaseDeviceTensor(void *locationContext) = 0;
48};
49
50} // namespace glow
51
52#endif // GLOW_BASE_DEVICETENSORTRANSFERMANAGER_H
53