/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_
#define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_

#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/c/eager/c_api_experimental.h"

namespace tensorflow {
namespace dtensor {

// Configures a custom device which runs DTensor while executing operations on
// `underlying_devices`. Allocates `device_info` and fills `device`, which
// should then be passed to TFE_RegisterCustomDevice. This only affects eager
// execution.
//
// The `device_name` argument should match the `device_name` argument to
// TFE_RegisterCustomDevice, and is the name of the custom device itself (e.g.
// pass it to `tf.device` to place operations on it from Python).
void AllocateDTensorDevice(absl::string_view device_name,
                           TFE_CustomDevice* device, void** device_info);
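
// A minimal registration sketch (not part of this API's contract; the device
// name below is illustrative, and `context`/`status` are assumed to be a live
// TFE_Context* and TF_Status*):
//
//   TFE_CustomDevice device;
//   void* device_info = nullptr;
//   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
//   AllocateDTensorDevice(name, &device, &device_info);
//   TFE_RegisterCustomDevice(context, device, name, device_info, status);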

// Adds a mesh to the layout propagator indicated by `device_info`.
//
// `serialized_mesh` is a serialized Mesh proto.
//
// `is_async` indicates whether DTensor operations on this mesh will return
// immediately (with "non-ready" handles) or block until executed. This is
// exposed as an option for ease of debugging, and will typically be on.
//
// `is_host_mesh` indicates this is a CPU mesh used only for sea-of-donuts-style
// host collectives.
void AddMesh(const std::string& serialized_mesh, void* device_info,
             bool is_async, bool is_host_mesh, TF_Status* status);

// Sets a requested layout for outputs of all operations.
void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
                                  void* device_info, TF_Status* status);
void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status);

// TODO(b/175928457): remove once the bug is fixed.
// Sets a requested default mesh.
void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
                                void* device_info, TF_Status* status);
void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status);

// Determines whether tensors with a shape previously associated with only one
// layout use that layout if nothing else can be inferred.
void SetSameShapePolicy(void* device_info, bool enabled);

// Sets the global device ID-to-core ID mapping for a mesh. Global device IDs
// are equal to XLA replica IDs for the single XLA computation used by DTensor.
//
// See the comment above Mesh::tpu_core_ids() for some nuances.
void SetTPUCoreIDs(const std::string& mesh_name,
                   const std::vector<int>& tpu_core_ids, void* device_info,
                   TF_Status* status);

// TODO(b/187112276): Delete once we have the TPUCoreIDs live with Device.
void ClearTPUCoreIDs(void* device_info);

// Returns TPU core locations when given a list of TPU core IDs.
std::vector<std::vector<int>> TPUCoreIDsToLocations(
    TFE_Context* context, const std::vector<int>& tpu_core_ids,
    void* device_info);

// Returns TPU core IDs when given a list of TPU core locations.
std::vector<int> TPUCoreLocationsToIDs(
    TFE_Context* context,
    const std::vector<std::vector<int>>& tpu_core_locations, void* device_info);

// Packs `inputs` tensors into a single parallel tensor handle.
TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
                       TFE_TensorHandle** inputs,
                       const std::string& string_layout, void* device_info,
                       TF_Status* status);

// Returns the raw components placed on each device of `input`'s mesh.
std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
                                      TFE_TensorHandle* input,
                                      void* device_info, TF_Status* status);
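
// A minimal pack/unpack sketch (assumes `component0`/`component1` are
// per-device handles already placed on the mesh's devices, and `layout_str`
// is a serialized layout string, e.g. as returned by FetchLayout):
//
//   TFE_TensorHandle* components[] = {component0, component1};
//   TFE_TensorHandle* packed =
//       Pack(context, 2, components, layout_str, device_info, status);
//   if (TF_GetCode(status) == TF_OK) {
//     std::vector<TFE_TensorHandle*> unpacked =
//         Unpack(context, packed, device_info, status);
//   }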

// Returns the layout of the DTensor `input`.
std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
                        void* device_info, TF_Status* status);

// Packs `indices`, `values`, and `shapes` tensors into a
// SparseTensorWithLayout.
TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
                             TFE_TensorHandle** indices,
                             TFE_TensorHandle** values,
                             TFE_TensorHandle** shapes,
                             const std::string& string_layout,
                             void* device_info, TF_Status* status);

// Returns whether `input` is a sparse DTensor. Used in `Unpack` at the Python
// level to determine whether we should wrap component tensors back into a
// SparseTensor.
bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
                     void* device_info, TF_Status* status);

// Returns a dictionary with cache hit and cache miss information.
// The cache hit count is mapped under 'hit', and the cache miss count is
// mapped under 'miss'.
std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
    TFE_Context* context, void* device_info, TF_Status* status);
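
// A minimal usage sketch (assumes a live `context`, `device_info`, and
// `status`):
//
//   std::unordered_map<std::string, int> counts =
//       GetFunctionCacheHitAndMissCount(context, device_info, status);
//   if (TF_GetCode(status) == TF_OK) {
//     int hits = counts["hit"];
//     int misses = counts["miss"];
//   }
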
}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_