1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
17#define TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
18
namespace tensorflow {
namespace dtensor {
// Constants used within the DTensor scope.

// Qualified attribute name without the `_` prefix.
// Used in op attribute registration.
static constexpr char kQualifiedLayoutAttr[] = "layout";

// Internal attribute used by DTensor MLIR passes and Graph nodes.
// Prefixed with `_` so that it doesn't require op attribute registration.
static constexpr char kLayoutAttr[] = "_layout";

// Indicates a non-binding layout hint provided by the user.
// Same attribute name as kDefaultLayoutAttr, but with the `tf.` prefix that
// the MLIR importer attaches for dialect requirements.
static constexpr char kCustomDefaultLayoutAttr[] = "tf._default_layout";

// Indicates a non-binding layout hint provided by the user.
static constexpr char kDefaultLayoutAttr[] = "_default_layout";

// Attribute that carries layout information from custom device arguments.
// `tf` prefix attached in MLIR importer for dialect requirements.
static constexpr char kCustomDeviceAttr[] = "tf._layout";

// Attribute attached to an _Arg node for the mesh config.
static constexpr char kMeshAttr[] = "_mesh";

// Attribute that carries mesh information from custom device arguments.
// `tf` prefix attached in MLIR importer for dialect requirements.
static constexpr char kCustomDeviceMeshAttr[] = "tf._mesh";

// Attribute that carries the argument indices of resource handles whose
// layouts were newly inferred.
static constexpr char kNewResourceLayoutIndices[] =
    "_inferred_resource_indices";

// Attribute that carries the newly inferred layouts of resource handles.
static constexpr char kNewResourceArgLayouts[] = "_inferred_resource_layouts";

// Attribute that carries input layout information for shape ops.
static constexpr char kShapeOpInputLayout[] = "_shape_input_layout";

// Attribute that carries input layout indices for shape ops. This forms a
// 1 -> 1 mapping with kShapeOpInputLayout above.
static constexpr char kShapeOpInputLayoutIndices[] = "_shape_input_indices";

// Attribute that carries the global shape of an operation. Used to preserve
// the global shape so it can be used during SPMD expansion.
static constexpr char kGlobalShape[] = "_global_shape";

// Global shape attribute with the `tf.` dialect prefix, used for annotating
// func op arguments/return values.
static constexpr char kGlobalShapeDialectAttr[] = "tf._global_shape";

// Attribute attached to resource-type function arguments containing the local
// shape of the tensor that is being assigned to it.
static constexpr char kAssignedResourceLocalShape[] =
    "tf._assigned_resource_local_shape";

// Tensor handles smaller than this are considered small tensors, and some
// optimizations are performed around them. For example, they are transformed
// into constant values during graph building instead of being passed as
// inputs. In addition, small non-DTensor tensors may be automatically
// broadcast to the DTensor device, which is very useful for shape/axis info
// tensors in eager mode (it eliminates the need to force users to do an
// explicit copy-to-mesh).
// NOTE(review): the unit of this threshold (presumably the number of elements)
// is not stated here — confirm against callers.
static constexpr int kSmallTensorThreshold = 20;

// Contains a serialized mesh. Will be attached to a FloorMod op to denote which
// mesh the output of the FloorMod op is giving coordinates for.
static constexpr char kMeshCoordinatesAttr[] = "_mesh_coordinates";

// Attribute used to determine whether a module pass should log long-form
// information such as IR dumps.
static constexpr char kDoNotLog[] = "dtensor.do_not_log";

// The number of TPU cores in a donut.
static constexpr int kTpuDonutSize = 8;

// An attribute used to cache the computation of device seeds, so that device
// seeds in a cluster are not constantly recomputed for a given layout.
static constexpr char kDeviceSeedForMeshDims[] =
    "dtensor.device_seed_for_mesh_dims";

// Attribute that determines whether to skip XLA compilation. Some ops run on a
// TPU mesh but are not expected to be compiled by XLA, e.g. VarHandleOp,
// DestroyResourceOp, etc. In such a case, set this attribute to true on the
// StatefulPartitionedCallOp generated by MLIR lowering.
static constexpr char kSkipXlaCompilation[] = "_skip_xla_compilation";

// Prefix of a pipelining mesh name; the full name is
// kPipelineMeshNamePrefix + composite device name.
static constexpr char kPipelineMeshNamePrefix[] = "pipe_cluster:";

// An attribute that stores the cache_key for the graph in the module. Used
// to uniquely name functions.
static constexpr char kCacheKey[] = "dtensor.cache_key";

// An attribute that marks a tensor as sparse: if this attribute exists on a
// tensor, then that tensor is a sparse tensor.
static constexpr char kSparseValue[] = "tf._sparse";

// TPUEmbedding configuration attribute with the `tf.` dialect prefix, used for
// annotating func ops that contain TPU embedding configuration ops.
static constexpr char kTPUEmbeddingConfiguration[] =
    "tf._tpu_embedding_configuration";

// Attribute mapping a table_id to func op arguments used as TPUEmbedding
// tables. `tf` prefix attached in MLIR importer for dialect requirements.
static constexpr char kTPUEmbeddingTableID[] = "tf._tpu_embedding_table_id";

// Attribute mapping a slot_id to func op arguments used as TPUEmbedding slot
// variables. `tf` prefix attached in MLIR importer for dialect requirements.
static constexpr char kTPUEmbeddingSlotID[] = "tf._tpu_embedding_slot_id";

// Name of the DTensor load-embedding function.
static constexpr char kLoadEmbeddingFn[] = "load_embedding_fn";

// Name of the DTensor retrieve-embedding function.
static constexpr char kRetrieveEmbeddingFn[] = "retrieve_embedding_fn";
}  // namespace dtensor
}  // namespace tensorflow
139
140#endif // TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
141