/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The Stream is used in conjunction with the StreamExecutor "parent" to
// perform actions with a linear stream of dependencies. Dependencies can also
// be created between Streams to do task management (i.e. limit which tasks
// can be performed concurrently and specify what task dependencies exist).
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_

#include <complex>
#include <cstdint>
#include <functional>
#include <memory>
#include <type_traits>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/event.h"
#include "tensorflow/compiler/xla/stream_executor/fft.h"
#include "tensorflow/compiler/xla/stream_executor/kernel.h"
#include "tensorflow/compiler/xla/stream_executor/launch_dim.h"
#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h"

namespace stream_executor {

namespace host {
class HostBlas;
class HostFft;
class HostRng;
class HostTimer;
}  // namespace host

namespace ocl {
class CLBlas;
}  // namespace ocl

namespace internal {
class StreamInterface;
}  // namespace internal

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;

class Timer;

namespace dnn {
class BatchDescriptor;
class FilterDescriptor;
class ConvolutionDescriptor;
class ProfileResult;
class AlgorithmDesc;
}  // namespace dnn

class StreamExecutor;
class ScratchAllocator;

namespace detail {

// Helper class to prevent a template function argument from being deduced.
// This is identical to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
  using type = T;
};
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;
// Helper that returns whether `T` is the same type as `First` or any of
// `Rest`.
template <typename T>
constexpr bool is_any_of() {
  return false;
}

template <typename T, typename First, typename... Rest>
constexpr bool is_any_of() {
  return std::is_same_v<T, First> || is_any_of<T, Rest...>();
}

}  // namespace detail

// Convert a type to the corresponding QuantizedActivationMode.
template <typename ElementType>
struct Quantization;

// Represents a stream of dependent computations on a GPU device.
//
// The operations within a stream execute linearly and asynchronously until
// BlockHostUntilDone() is invoked, which synchronously joins host code with
// the execution of the stream.
//
// If any given operation fails when entraining work for the stream, ok() will
// indicate that an error has occurred. After initialization, once a stream is
// !ok(), it will never be ok().
//
// Thread-safe post-initialization.
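//
// Example (illustrative sketch only; assumes `executor` is an already-created
// StreamExecutor, that the host/device buffers shown exist, and that
// BlockHostUntilDone() is declared later in this class):
//
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenMemcpy(&device_dst, host_src, size)
//         .ThenMemcpy(host_dst, device_src, size);
//   port::Status status = stream.BlockHostUntilDone();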
class Stream {
 public:
  // Instantiate a stream tied to parent as a platform executor. Work
  // entrained onto this stream will be launched/managed on that
  // StreamExecutor's platform.
  explicit Stream(StreamExecutor *parent);

  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
  ~Stream();

  // Returns whether any errors have occurred while entraining work for this
  // stream.
  bool ok() const { return !InErrorState(); }

  // Retrieves execution status back into the stream from the underlying
  // implementation without blocking the stream.
  //
  // Normally, Stream::BlockHostUntilDone is used to get execution status.
  // However, some devices use out-of-band mechanisms to ensure their streams
  // have finished on-device work, without needing to block the streams. (These
  // devices should also override AllowsSyncOnCompletion to return false.) For
  // these devices, this method can be used after work is finished to retrieve
  // execution status.
  port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_);

  // Initialize the stream. This must be performed before entraining any other
  // operations.
  Stream &Init() TF_LOCKS_EXCLUDED(mu_);

  // Initializes timer t via the StreamExecutor.
  Stream &InitTimer(Timer *t);

  // Convenience wrapper around Init() and InitTimer().
  Stream &InitWithTimer(Timer *t);

  // Get or create a sub-stream from this stream. If there is any sub-stream in
  // the pool that can be reused, returns that sub-stream; otherwise creates a
  // new sub-stream.
  //
  // TODO(b/112196569): The semantics of failed sub-streams are error-prone.
  Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_);

  // Return the sub-stream back to the host stream so that it can be reused
  // later. Sub-streams that are !ok() will not be reused.
  //
  // TODO(b/112196569): The semantics of failed sub-streams are error-prone.
  void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_);
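  //
  // Example (illustrative sketch; the work enqueued on the sub-stream and the
  // buffers involved are hypothetical):
  //
  //   Stream *sub = stream.GetOrCreateSubStream();
  //   sub->ThenMemcpy(&device_dst, host_src, size);
  //   stream.ThenWaitFor(sub);
  //   stream.ReturnSubStream(sub);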

  // Allocates temporary memory that the stream will deallocate when it is
  // blocked on or destroyed.
  template <typename T>
  port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
  AllocateTemporaryArray(uint64_t element_count);

  // Entrains onto the stream of operations: a kernel launch with the given
  // (variadic) parameters for the invocation. These arguments can be things
  // like DeviceMemory or primitive types such as int. What arguments you may
  // pass to a given kernel are noted as the template parameters to the
  // TypedKernel type that the machocc compiler generates.
  //
  // Template parameters:
  //  Params...   The type list of formal parameters that the typed kernel
  //              expects, which is matched against Args...
  //  Args...     The deduced type list for passed actual arguments
  //
  // Implementation: A compile-time compatibility check is performed that has
  // some leniency versus an exact parameter pack match -- for example,
  // `const DeviceMemory<T>` is considered "pack compatible" with a
  // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
  // perfect forwarding support without rvalue references. It also attempts to
  // spit out helpful static_assert error traces with information as to the
  // argument number and types that were mismatched.
  template <typename... Params, typename... Args>
  port::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                          const TypedKernel<Params...> &kernel, Args... args);
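  //
  // Example (illustrative sketch; the kernel type, how it was built, and the
  // launch dimensions are hypothetical):
  //
  //   TypedKernel<DeviceMemory<float>, int> add_one_kernel = ...;
  //   TF_RETURN_IF_ERROR(stream.ThenLaunch(ThreadDim(256), BlockDim(32),
  //                                        add_one_kernel, device_buf, 42));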

  // Record a "start" event for the interval timer at this point in the
  // stream's execution (relative to the previously and subsequently enqueued
  // items in the stream's execution). Streams may be started/stopped multiple
  // times.
  Stream &ThenStartTimer(Timer *t);

  // Record a "stop" event for the interval timer at this point in the
  // stream's execution. See also Stream::ThenStartTimer.
  Stream &ThenStopTimer(Timer *t);

  // TODO(leary) If work is added to the stream that is being depended upon,
  // then what? Have to describe what happens.
  template <typename... Params>
  Stream &ThenWaitFor(Stream *other, Params... more_streams) {
    return ThenWaitFor(more_streams...).ThenWaitFor(other);
  }

  // Create a dependency for this stream's next work on the other stream
  // completing. Does not take ownership of other, and other must not be
  // null.
  //
  // Checks that a stream does not wait for itself, and it is up to the
  // user to guarantee that a stream does not come to wait on itself in a
  // cyclic manner; in that case, behavior is undefined.
  //
  // N.B. Base recursion case for the variadic ThenWaitFor.
  Stream &ThenWaitFor(Stream *other);

  // Waits for all streams in others.
  // Checks that there is no shallow circular wait (i.e. that "this" is not in
  // others).
  template <typename P>
  Stream &ThenWaitFor(P others) {
    for (auto &stream : *others) {
      CHECK_NE(stream.get(), this);
      ThenWaitFor(stream.get());
    }
    return *this;
  }
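  //
  // Example (illustrative sketch): make this stream's next work depend on two
  // producer streams having finished their previously enqueued work:
  //
  //   consumer.ThenWaitFor(&producer1, &producer2);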

  // Waits for an event object to be set.
  // Note that ThenRecordEvent must have been called on the event before
  // you call this function; otherwise the event will be considered complete
  // and this wait will do nothing.
  Stream &ThenWaitFor(Event *event);

  // Inserts the specified event into the end of this stream. Once the stream
  // has processed all events prior to the insertion point, the event will be
  // marked as completed.
  // The stream does not take ownership of event - meaning that event's
  // lifetime must extend past the point at which it is marked complete!
  Stream &ThenRecordEvent(Event *event);
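  //
  // Example (illustrative sketch; assumes Event is initialized as in event.h
  // and that `event` outlives its completion):
  //
  //   Event event(executor);
  //   event.Init();
  //   producer.ThenRecordEvent(&event);
  //   consumer.ThenWaitFor(&event);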

  ////////////////
  // DNN support
  //
  // See DnnSupport::* for comments on the following methods.

  Stream &ThenBatchNormalizationForward(
      const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
      const DeviceMemory<float> &offset,
      const DeviceMemory<float> &estimated_mean,
      const DeviceMemory<float> &estimated_variance,
      const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
      const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
      DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
      DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
      bool is_training, ScratchAllocator *reserve_space_allocator,
      ScratchAllocator *workspace_allocator);

  Stream &ThenBatchNormalizationBackward(
      const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
      const DeviceMemory<float> &scale, const DeviceMemory<float> &offset,
      const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
      const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc,
      const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop,
      DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
      DeviceMemory<float> *side_input_backprop,
      DeviceMemory<uint8_t> *reserve_space_data,
      ScratchAllocator *workspace_allocator);

  Stream &ThenBatchNormalizationForward(
      const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
      const DeviceMemory<float> &offset,
      const DeviceMemory<float> &estimated_mean,
      const DeviceMemory<float> &estimated_variance,
      const DeviceMemory<Eigen::half> &side_input,
      const dnn::BatchDescriptor &x_desc,
      const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
      DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
      DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
      bool is_training, ScratchAllocator *reserve_space_allocator,
      ScratchAllocator *workspace_allocator);

  Stream &ThenBatchNormalizationBackward(
      const DeviceMemory<Eigen::half> &y_backprop,
      const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
      const DeviceMemory<float> &offset, const DeviceMemory<float> &mean,
      const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y,
      const dnn::BatchDescriptor &x_desc,
      const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode,
      DeviceMemory<Eigen::half> *x_backprop,
      DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
      DeviceMemory<Eigen::half> *side_input_backprop,
      DeviceMemory<uint8_t> *reserve_space_data,
      ScratchAllocator *workspace_allocator);

  Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
                       const DeviceMemory<float> &input_data,
                       const dnn::FilterDescriptor &filter_descriptor,
                       const DeviceMemory<float> &filter_data,
                       const dnn::ConvolutionDescriptor &convolution_descriptor,
                       const dnn::BatchDescriptor &output_descriptor,
                       DeviceMemory<float> *output);

  Stream &ThenConvolveQuantized(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<float> &input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      const DeviceMemory<int8_t> &filter_coefficients,
      const DeviceMemory<float> &coefficient_scales,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemory<float> *output_data);

  Stream &ThenConvolveQuantized(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<float> &input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      const DeviceMemory<int16> &filter_coefficients,
      const DeviceMemory<float> &coefficient_scales,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemory<float> *output_data);

  template <typename InputType, typename OutputType>
  port::Status ConvolveWithAlgorithm(
      dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor,
      DeviceMemory<InputType> input_data,
      const dnn::FilterDescriptor &filter_descriptor,
      DeviceMemory<InputType> filter_data,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemory<OutputType> output_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      ScratchAllocator *scratch_allocator,
      const dnn::AlgorithmConfig &algorithm_config,
      dnn::ProfileResult *output_profile_result) {
    DeviceMemory<uint8_t> scratch_memory;
    dnn::AlgorithmDesc algorithm_desc;
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      TF_RETURN_IF_ERROR(dnn->PrepareForConvolution(
          kind, this, input_descriptor, input_data, filter_descriptor,
          filter_data, output_descriptor, output_data, convolution_descriptor,
          algorithm_config, scratch_allocator, &algorithm_desc,
          &scratch_memory));
      return dnn->DoConvolve(kind, dnn::ToDataType<InputType>::value,
                             dnn::ToDataType<OutputType>::value, this,
                             input_descriptor, input_data, filter_descriptor,
                             filter_data, output_descriptor, output_data,
                             convolution_descriptor, algorithm_desc,
                             scratch_memory, output_profile_result);
    }
    return port::UnimplementedError("DNN library is not found.");
  }
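  //
  // Example (illustrative sketch; all descriptors, buffers, and the algorithm
  // configuration are hypothetical and their construction is omitted):
  //
  //   TF_RETURN_IF_ERROR(stream.ConvolveWithAlgorithm(
  //       dnn::ConvolutionKind::FORWARD, input_desc, input, filter_desc,
  //       filter, output_desc, output, conv_desc, scratch_allocator,
  //       algorithm_config, /*output_profile_result=*/nullptr));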

  template <typename InputT, typename ScaleT, typename SideInputT,
            typename BiasT, typename OutputT>
  port::Status FusedConvolveWithAlgorithm(
      const dnn::BatchDescriptor &conv_input_descriptor,
      const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale,
      const dnn::FilterDescriptor &filter_descriptor,
      const DeviceMemory<InputT> &filter_data,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale,
      const dnn::BatchDescriptor &bias_descriptor,
      const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator,
      const dnn::AlgorithmConfig &algorithm_config,
      dnn::ProfileResult *output_profile_result) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      return dnn->DoFusedConvolve(
          this, dnn::ToDataType<InputT>::value,
          dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value,
          dnn::ToDataType<OutputT>::value, conv_input_descriptor,
          conv_input_data, conv_input_scale, filter_descriptor, filter_data,
          convolution_descriptor, side_input_data, side_input_scale,
          bias_descriptor, biases, activation_mode, output_descriptor, *output,
          scratch_allocator, algorithm_config, output_profile_result);
    }
    return port::UnimplementedError("DNN library is not found.");
  }

  port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
      const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
      dnn::DataType element_type, dnn::DataType output_type,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor) {
    dnn::DnnSupport *dnn_support = parent_->AsDnn();
    if (!dnn_support) {
      return port::UnimplementedError("DNN library is not found.");
    }
    return dnn_support->ConvolveRunnerFromDesc(
        this, algorithm_desc, kind, element_type, output_type, input_descriptor,
        filter_descriptor, output_descriptor, convolution_descriptor);
  }

  port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
  FusedConvolveRunnerFromDesc(
      const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind,
      dnn::DataType element_type, dnn::DataType bias_type,
      dnn::DataType output_type, double conv_input_scale,
      double side_input_scale, double leakyrelu_alpha,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::BatchDescriptor &bias_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      dnn::ActivationMode activation_mode) {
    dnn::DnnSupport *dnn_support = parent_->AsDnn();
    if (!dnn_support) {
      return port::UnimplementedError("DNN library is not found.");
    }
    return dnn_support->FusedConvolveRunnerFromDesc(
        this, algorithm_desc, kind, element_type, bias_type, output_type,
        conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor,
        filter_descriptor, bias_descriptor, output_descriptor,
        convolution_descriptor, activation_mode);
  }

  Stream &ThenSeparableConvolve(
      const dnn::BatchDescriptor &input_descriptor,
      const DeviceMemory<float> &input_data,
      const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
      const DeviceMemory<float> &first_weights,
      const DeviceMemory<float> &second_weights,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      DeviceMemory<float> *output);

  Stream &ThenMatMul(const DeviceMemory<float> &input_data,
                     const DeviceMemory<float> &weights,
                     const dnn::BatchDescriptor &input_dimensions,
                     const dnn::BatchDescriptor &output_dimensions,
                     DeviceMemory<float> *output_data);

  Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                              const DeviceMemory<int8_t> &weights,
                              const DeviceMemory<float> &weight_scales,
                              const dnn::BatchDescriptor &input_dimensions,
                              const dnn::BatchDescriptor &output_dimensions,
                              DeviceMemory<float> *output_data);

  Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                              const DeviceMemory<int16> &weights,
                              const DeviceMemory<float> &weight_scales,
                              const dnn::BatchDescriptor &input_dimensions,
                              const dnn::BatchDescriptor &output_dimensions,
                              DeviceMemory<float> *output_data);

  Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
                      const DeviceMemory<float> &biases,
                      const dnn::BatchDescriptor &dimensions,
                      DeviceMemory<float> *output_data);

  template <typename ElementType>
  port::Status ThenPoolForward(
      const dnn::PoolingDescriptor &pooling_dimensions,
      const dnn::BatchDescriptor &input_dimensions,
      const DeviceMemory<ElementType> &input_data,
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<ElementType> *output_data,
      ScratchAllocator *workspace_allocator = nullptr) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      return dnn->DoPoolForward(dnn::ToDataType<ElementType>::value, this,
                                pooling_dimensions, input_dimensions,
                                input_data, output_dimensions, *output_data,
                                workspace_allocator);
    }
    return port::UnimplementedError("DNN library is not found.");
  }

  template <typename ElementType>
  port::Status ThenPoolBackward(
      const dnn::PoolingDescriptor &pooling_dimensions,
      const dnn::BatchDescriptor &input_dimensions,
      const DeviceMemory<ElementType> &input_data,
      const dnn::BatchDescriptor &output_dimensions,
      const DeviceMemory<ElementType> &output_data,
      const DeviceMemory<ElementType> &input_diff_data,
      DeviceMemory<ElementType> *output_diff_data,
      ScratchAllocator *workspace_allocator = nullptr) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      return dnn->DoPoolBackward(
          dnn::ToDataType<ElementType>::value, this, pooling_dimensions,
          input_dimensions, input_data, output_dimensions, output_data,
          input_diff_data, *output_diff_data, workspace_allocator);
    }
    return port::UnimplementedError("DNN library is not found.");
  }

  Stream &ThenNormalizeWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);

  Stream &ThenNormalizeBackwardWithDimensions(
      const dnn::NormalizeDescriptor &normalize_descriptor,
      const dnn::BatchDescriptor &dimensions,
      const DeviceMemory<float> &raw_data,
      const DeviceMemory<float> &normalized_data,
      const DeviceMemory<float> &normalized_variable_gradient,
      DeviceMemory<float> *raw_variable_gradient,
      ScratchAllocator *workspace_allocator = nullptr);

  Stream &ThenActivate(dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor &dimensions,
                       const DeviceMemory<float> &input_data,
                       DeviceMemory<float> *output_data);

  // Same as ThenActivate, but also takes an options argument that can be used
  // for platform-specific option flags.
  Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
                                  const dnn::BatchDescriptor &dimensions,
                                  const DeviceMemory<float> &input_data,
                                  DeviceMemory<float> *output_data,
                                  uint64_t options);

  Stream &ThenDepthConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
      port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
      DeviceMemory<float> *output_data);

  Stream &ThenSpaceConcatenate(
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
      port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
      DeviceMemory<float> *output_data,
      dnn::SpaceConcatenateMode concat_direction);

  // Change the layout of the data by shrinking one dimension (or set of
  // dimensions) and growing another dimension (or set of dimensions), while
  // keeping the total number of data elements constant, and maintaining the
  // current data ordering.
  Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
                      const DeviceMemory<float> &input_data,
                      const dnn::BatchDescriptor &output_dimensions,
                      DeviceMemory<float> *output_data);

  // Depth to space takes an X by Y image with depth D*M² and changes it to an
  // MX x MY image with depth D. Each input location (x,y) with depth D*M² in
  // the input image is changed to an MxM contiguous area in the output image,
  // with the values being laid out in raster order specified by
  // DepthToSpaceLayout, and will have a new depth of D.
  // See the DoDepthToSpace comment for more information.
  Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &depth_to_space_layout,
                           const int sqrt_depth_reduction,
                           DeviceMemory<float> *output_data);

  // Space to depth is the inverse of depth to space. Space to depth takes each
  // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
  // the input, and transforms it to a 1 by 1 patch with depth D*M². If the
  // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number
  // of data elements is not changed.
  Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
                           const DeviceMemory<float> &input_data,
                           const dnn::DepthToSpaceLayout &space_to_depth_layout,
                           const int sqrt_depth_increase,
                           DeviceMemory<float> *output_data);

  Stream &ThenElementwiseOperate(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
      port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  Stream &ThenElementwiseOperateScaledQuantized(
      dnn::ElementwiseOperation operation,
      port::ArraySlice<int> input_multiplicands,  // non-absl ok
      int output_divisor,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,   // non-absl ok
      port::ArraySlice<const DeviceMemory<float> *> input_data,  // non-absl ok
      const dnn::BatchDescriptor &output_dimensions,
      DeviceMemory<float> *output_data);

  Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64_t left_pad,
                    int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
                    DeviceMemory<float> *output_data);

  Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                      const DeviceMemory<float> &input_data, int64_t left_trim,
                      int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                      DeviceMemory<float> *output_data);

  // Grows the input tensor by replicating the X and Y dimensions. The batch
  // and depth/feature_map dimensions are unchanged. Currently, the input
  // tensor is limited to X=1 and Y=1.
  Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
                          const DeviceMemory<float> &input_data,
                          int64_t replicate_x, int64_t replicate_y,
                          DeviceMemory<float> *output_data);

  // See DnnSupport::DoMemcpyD2HQuantized.
  Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
                                 dnn::QuantizedActivationMode mode,
                                 void *host_dst, uint64_t size);

  // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
  // and uses the Quantization trait to call the generic version of
  // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
  template <typename ElementType>
  Stream &ThenMemcpyD2HQuantized(
      const DeviceMemory<float> &gpu_unquantized_src,
      port::MutableArraySlice<ElementType> host_dst) {
    return ThenMemcpyD2HQuantized(
        gpu_unquantized_src, Quantization<ElementType>::kModeId,
        host_dst.data(), host_dst.size() * sizeof(ElementType));
  }

  // See DnnSupport::DoMemcpyH2DQuantized.
  Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size,
                                 dnn::QuantizedActivationMode mode,
                                 DeviceMemory<float> *gpu_unquantized_dst);

  // Template version of ThenMemcpyH2DQuantized that takes an array slice
  // and uses the Quantization trait to call the generic version of
  // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
  template <typename ElementType>
  Stream &ThenMemcpyH2DQuantized(
      port::ArraySlice<ElementType> host_src,  // non-absl ok
      DeviceMemory<float> *gpu_unquantized_dst) {
    return ThenMemcpyH2DQuantized(
        host_src.data(), host_src.size() * sizeof(ElementType),
        Quantization<ElementType>::kModeId, gpu_unquantized_dst);
  }

  // See DnnSupport::DoCopyHostBuffer2Device.
  Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
                                    DeviceMemory<float> *gpu_unquantized_dst);

  // See DnnSupport::DoCopyDevice2HostBuffer.
  Stream &ThenCopyDevice2HostBuffer(
      const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);

  /////////////////
  // BLAS support

  // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
  // present in DeviceMemory, it must be an execution-time constant (i.e. a
  // value that the stream does not change or populate during the course of
  // execution). The value is effectively captured at stream-enqueue time.
  Stream &ThenBlasAxpy(uint64_t elem_count, float alpha,
                       const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasAxpy(uint64_t elem_count, double alpha,
                       const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);
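  //
  // Example (illustrative sketch; `x` and `y` are DeviceMemory<float> buffers
  // with at least `elem_count` elements): computes y = 2.0f * x + y.
  //
  //   stream.ThenBlasAxpy(elem_count, 2.0f, x, /*incx=*/1, &y, /*incy=*/1);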

  // See BlasSupport::DoBlasCopy.
  Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x,
                       int incx, DeviceMemory<float> *y, int incy);
  Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x,
                       int incx, DeviceMemory<double> *y, int incy);
  Stream &ThenBlasCopy(uint64_t elem_count,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasCopy(uint64_t elem_count,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       DeviceMemory<std::complex<double>> *y, int incy);

  // See BlasSupport::DoBlasScal.
  Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory<float> *x,
                       int incx);
  Stream &ThenBlasScal(uint64_t elem_count, double alpha,
                       DeviceMemory<double> *x, int incx);
  Stream &ThenBlasScal(uint64_t elem_count, float alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64_t elem_count, double alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);
  Stream &ThenBlasScal(uint64_t elem_count, std::complex<float> alpha,
                       DeviceMemory<std::complex<float>> *x, int incx);
  Stream &ThenBlasScal(uint64_t elem_count, std::complex<double> alpha,
                       DeviceMemory<std::complex<double>> *x, int incx);

  // See BlasSupport::DoBlasGemv.
  Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
                       double alpha, const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
                       std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       const DeviceMemory<std::complex<float>> &x, int incx,
                       std::complex<float> beta,
                       DeviceMemory<std::complex<float>> *y, int incy);
  Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
                       std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       const DeviceMemory<std::complex<double>> &x, int incx,
                       std::complex<double> beta,
                       DeviceMemory<std::complex<double>> *y, int incy);

  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
                                    float alpha, const DeviceMemory<float> &a,
                                    int lda, const DeviceMemory<float> &x,
                                    int incx, float beta,
                                    DeviceMemory<float> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n,
                                    double alpha, const DeviceMemory<double> &a,
                                    int lda, const DeviceMemory<double> &x,
                                    int incx, double beta,
                                    DeviceMemory<double> *y, int incy,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &x, int incx,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemvWithProfiling(
      blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &x, int incx,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
      int incy, blas::ProfileResult *output_profile_result);

  // See BlasSupport::DoBlasSbmv.
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha,
                       const DeviceMemory<float> &a, int lda,
                       const DeviceMemory<float> &x, int incx, float beta,
                       DeviceMemory<float> *y, int incy);
  Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
                       double alpha, const DeviceMemory<double> &a, int lda,
                       const DeviceMemory<double> &x, int incx, double beta,
                       DeviceMemory<double> *y, int incy);

  template <typename InputType>
  port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                            uint64_t m, uint64 n, uint64 k,
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            DeviceMemory<InputType> *c, int ldc,
                            blas::ComputePrecision precision) {
    InputType alpha{1.0};
    InputType beta{0.0};
    return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                        ldc, precision);
  }

  // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
  template <typename InputType>
  port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                            uint64_t m, uint64 n, uint64 k,
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            DeviceMemory<InputType> *c, int ldc) {
    return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc,
                        blas::kDefaultComputePrecision);
  }

  template <typename InputType, typename ConstantType>
  port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                            uint64_t m, uint64 n, uint64 k, ConstantType alpha,
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            ConstantType beta, DeviceMemory<InputType> *c,
                            int ldc, blas::ComputePrecision precision) {
    static_assert(
        detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float,
                          double, std::complex<float>, std::complex<double>>(),
        "Input can be half, bf16, float, double, std::complex<float> or "
        "std::complex<double>");
    static_assert(!std::is_same_v<InputType, Eigen::half> ||
                      detail::is_any_of<ConstantType, float, Eigen::half>(),
                  "If input is Eigen::half, constant has to be either "
                  "Eigen::half or float");
    static_assert(
        detail::is_any_of<InputType, Eigen::half, ConstantType>(),
        "If input is not Eigen::half, constant and input types have to match");
    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    return blas->DoBlasGemm(this, transa, transb, m, n, k,
                            blas::ToDataType<InputType>::value, alpha_ptr, a,
                            lda, b, ldb, beta_ptr, c, ldc, precision);
  }

  // TODO(parkers): Update all callers to pass kDefaultComputePrecision.
  template <typename InputType, typename ConstantType>
  port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                            uint64_t m, uint64 n, uint64 k, ConstantType alpha,
                            const DeviceMemory<InputType> &a, int lda,
                            const DeviceMemory<InputType> &b, int ldb,
                            ConstantType beta, DeviceMemory<InputType> *c,
                            int ldc) {
    return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
                        ldc, blas::kDefaultComputePrecision);
  }
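  //
  // Example (illustrative sketch; `a` and `b` are DeviceMemory<float> operands
  // and `c` is the DeviceMemory<float> result, with column-major leading
  // dimensions lda/ldb/ldc): computes c = a * b.
  //
  //   TF_RETURN_IF_ERROR(stream.ThenBlasGemm(
  //       blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n,
  //       k, a, lda, b, ldb, &c, ldc, blas::kDefaultComputePrecision));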

  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64_t m,
                                    uint64 n, uint64_t k, float alpha,
                                    const DeviceMemory<Eigen::half> &a, int lda,
                                    const DeviceMemory<Eigen::half> &b, int ldb,
                                    float beta, DeviceMemory<Eigen::half> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64_t m,
                                    uint64 n, uint64_t k, float alpha,
                                    const DeviceMemory<float> &a, int lda,
                                    const DeviceMemory<float> &b, int ldb,
                                    float beta, DeviceMemory<float> *c, int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
                                    blas::Transpose transb, uint64_t m,
                                    uint64 n, uint64_t k, double alpha,
                                    const DeviceMemory<double> &a, int lda,
                                    const DeviceMemory<double> &b, int ldb,
                                    double beta, DeviceMemory<double> *c,
                                    int ldc,
                                    blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, std::complex<float> alpha,
      const DeviceMemory<std::complex<float>> &a, int lda,
      const DeviceMemory<std::complex<float>> &b, int ldb,
      std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
      blas::ProfileResult *output_profile_result);
  Stream &ThenBlasGemmWithProfiling(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, std::complex<double> alpha,
      const DeviceMemory<std::complex<double>> &a, int lda,
      const DeviceMemory<std::complex<double>> &b, int ldb,
      std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
      blas::ProfileResult *output_profile_result);

  template <typename InputType, typename OutputType>
  port::Status ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, const DeviceMemory<InputType> &a, int lda,
      const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c,
      int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result) {
    OutputType alpha{1};
    OutputType beta{0};
    return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
                                     ldb, beta, c, ldc, computation_type,
                                     algorithm, blas::kDefaultComputePrecision,
                                     output_profile_result);
  }

  template <typename InputType, typename OutputType, typename ConstantType>
  port::Status ThenBlasGemmWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
      DeviceMemory<OutputType> *c, int ldc,
      blas::ComputationType computation_type, blas::AlgorithmType algorithm,
      blas::ComputePrecision precision,
      blas::ProfileResult *output_profile_result) {
    TF_RETURN_IF_ERROR(
        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
            computation_type));

    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    port::Status st = blas->DoBlasGemmWithAlgorithm(
        this, transa, transb, m, n, k, alpha_ptr, a,
        blas::ToDataType<InputType>::value, lda, b,
        blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
        blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
        precision, output_profile_result);
    if (output_profile_result) {
      // The error is recorded in the profile.
      return ::tsl::OkStatus();
    }
    return st;
  }

  template <typename InputType, typename OutputType, typename ConstantType>
  port::Status ThenBlasGemmStridedBatchedWithAlgorithm(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
      int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
      int64_t stride_c, int batch_count, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm, blas::ComputePrecision precision,
      blas::ProfileResult *output_profile_result) {
    TF_RETURN_IF_ERROR(
        CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
            computation_type));

    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }
    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);
    port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm(
        this, transa, transb, m, n, k, alpha_ptr, a,
        blas::ToDataType<InputType>::value, lda, stride_a, b,
        blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
        blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
        computation_type, algorithm, precision, output_profile_result);
    if (output_profile_result) {
      // The error is recorded in the profile.
      return ::tsl::OkStatus();
    }
    return st;
  }

  template <typename T>
  using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>;  // non-absl ok

  // See BlasSupport::DoBlasGemmBatched.
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64_t m, uint64 n, uint64_t k, float alpha,
                              const DeviceMemorySlice<Eigen::half> &a, int lda,
                              const DeviceMemorySlice<Eigen::half> &b, int ldb,
                              float beta,
                              const DeviceMemorySlice<Eigen::half> &c, int ldc,
                              int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64_t m, uint64 n, uint64 k, float alpha,
                              const DeviceMemorySlice<float> &a, int lda,
                              const DeviceMemorySlice<float> &b, int ldb,
                              float beta, const DeviceMemorySlice<float> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64 k, double alpha,
      const port::ArraySlice<DeviceMemory<double> *> &a,  // non-absl ok
      int lda,
      const port::ArraySlice<DeviceMemory<double> *> &b,  // non-absl ok
      int ldb, double beta,
      const port::ArraySlice<DeviceMemory<double> *> &c,  // non-absl ok
      int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64_t m, uint64 n, uint64_t k,
                              std::complex<float> alpha,
                              const DeviceMemorySlice<std::complex<float>> &a,
                              int lda,
                              const DeviceMemorySlice<std::complex<float>> &b,
                              int ldb, std::complex<float> beta,
                              const DeviceMemorySlice<std::complex<float>> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
                              uint64_t m, uint64 n, uint64_t k,
                              std::complex<double> alpha,
                              const DeviceMemorySlice<std::complex<double>> &a,
                              int lda,
                              const DeviceMemorySlice<std::complex<double>> &b,
                              int ldb, std::complex<double> beta,
                              const DeviceMemorySlice<std::complex<double>> &c,
                              int ldc, int batch_count);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, float alpha, const DeviceMemorySlice<Eigen::half> &a, int lda,
      const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta,
      const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count,
      ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, float alpha, const DeviceMemorySlice<float> &a, int lda,
      const DeviceMemorySlice<float> &b, int ldb, float beta,
      const DeviceMemorySlice<float> &c, int ldc, int batch_count,
      ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, double alpha, const DeviceMemorySlice<double> &a, int lda,
      const DeviceMemorySlice<double> &b, int ldb, double beta,
      const DeviceMemorySlice<double> &c, int ldc, int batch_count,
      ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, std::complex<float> alpha,
      const DeviceMemorySlice<std::complex<float>> &a, int lda,
      const DeviceMemorySlice<std::complex<float>> &b, int ldb,
      std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c,
      int ldc, int batch_count, ScratchAllocator *scratch_allocator);
  Stream &ThenBlasGemmBatchedWithScratch(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, std::complex<double> alpha,
      const DeviceMemorySlice<std::complex<double>> &a, int lda,
      const DeviceMemorySlice<std::complex<double>> &b, int ldb,
      std::complex<double> beta,
      const DeviceMemorySlice<std::complex<double>> &c, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);

  template <typename InputType, typename ConstantType>
  port::Status ThenBlasGemmStridedBatched(
      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
      uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
      int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
      int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
      int64_t stride_c, int batch_count, blas::ComputePrecision precision) {
    static_assert(
        detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
                          double, std::complex<float>, std::complex<double>>(),
        "Unsupported input type");
    static_assert(
        std::is_same_v<ConstantType, InputType> ||
            (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
             std::is_same_v<ConstantType, float>),
        "Mismatched input and alpha/beta types");
    blas::BlasSupport *blas = parent()->AsBlas();
    if (!blas) {
      return port::InternalError(
          "Attempting to perform BLAS operation using "
          "StreamExecutor without BLAS support");
    }

    void *alpha_ptr = &alpha;
    void *beta_ptr = &beta;
    float alpha_storage, beta_storage;
    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
                                    &beta_storage);

    return blas->DoBlasGemmStridedBatched(
        this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
        alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
        stride_c, batch_count, precision);
  }
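  //
  // Example (illustrative sketch; computes c_i = a_i * b_i over `batch_count`
  // float matrices packed with the given strides; all buffers and sizes are
  // hypothetical):
  //
  //   TF_RETURN_IF_ERROR(stream.ThenBlasGemmStridedBatched(
  //       blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n,
  //       k, 1.0f, a, lda, stride_a, b, ldb, stride_b, 0.0f, &c, ldc, stride_c,
  //       batch_count, blas::kDefaultComputePrecision));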

  // See BlasSupport::DoBlasTrsm.
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64_t m,
                       uint64_t n, float alpha, const DeviceMemory<float> &a,
                       int lda, DeviceMemory<float> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64_t m,
                       uint64_t n, double alpha, const DeviceMemory<double> &a,
                       int lda, DeviceMemory<double> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64_t m,
                       uint64_t n, std::complex<float> alpha,
                       const DeviceMemory<std::complex<float>> &a, int lda,
                       DeviceMemory<std::complex<float>> *b, int ldb);
  Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
                       blas::Transpose transa, blas::Diagonal diag, uint64_t m,
                       uint64_t n, std::complex<double> alpha,
                       const DeviceMemory<std::complex<double>> &a, int lda,
                       DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasTrsmBatched.
  Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
                              blas::Transpose transa, blas::Diagonal diag,
                              uint64_t m, uint64 n, float alpha,
                              const DeviceMemory<float *> &as, int lda,
                              DeviceMemory<float *> *bs, int ldb,
                              int batch_count);
  Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
                              blas::Transpose transa, blas::Diagonal diag,
                              uint64_t m, uint64 n, double alpha,
                              const DeviceMemory<double *> &as, int lda,
                              DeviceMemory<double *> *bs, int ldb,
                              int batch_count);
  Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
                              blas::Transpose transa, blas::Diagonal diag,
                              uint64_t m, uint64 n, std::complex<float> alpha,
                              const DeviceMemory<std::complex<float> *> &as,
                              int lda, DeviceMemory<std::complex<float> *> *bs,
                              int ldb, int batch_count);
  Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
                              blas::Transpose transa, blas::Diagonal diag,
                              uint64_t m, uint64 n, std::complex<double> alpha,
                              const DeviceMemory<std::complex<double> *> &as,
                              int lda, DeviceMemory<std::complex<double> *> *bs,
                              int ldb, int batch_count);

  // See FftSupport::DoFft.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
                  DeviceMemory<std::complex<float>> *output);
  Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
                  DeviceMemory<std::complex<double>> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
                  DeviceMemory<float> *output);
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<double>> &input,
                  DeviceMemory<double> *output);

  // Makes the RNG use the provided value as the basis for further generation.
  // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
  // sources of seed data if the default (high quality) sources are not
  // desired.
  // For most use cases, this function will not be necessary; each provided
  // back-end implementation will be appropriately seeded by default.
  // At a minimum 16 bytes of data are required in the seed buffer.
  //
  // To seed with good (non-reproducible) data:
  //   File* f = File::Open("/dev/random", "r");
  //   int64_t bytes_read = f->Read(seed_data, bytes_to_read);
  //   < error checking >
  //   stream.ThenSetRngSeed(seed_data, bytes_read);
  //
  // To seed with reproducible data:
  //   uint64_t seed_data[2] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
  Stream &ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes);

  // Populates the memory indicated by values with uniform-random-distribution
  // values. TODO(leary) seeding API/description
  //
  // Uses the type and size of the DeviceMemory to infer what data should be
  // populated.
  Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
  Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
  Stream &ThenPopulateRandGaussian(float mean, float stddev,
                                   DeviceMemory<float> *values);
  Stream &ThenPopulateRandGaussian(double mean, double stddev,
                                   DeviceMemory<double> *values);

  // Entrain onto the stream: a memcpy to a host destination from a GPU source
  // of the given target size. host_dst must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
                     uint64_t size);

  // Entrain onto the stream: a memcpy to a GPU destination from a host source
  // of the given target size. host_src must be a pointer to host memory
  // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
  // then registered with StreamExecutor::HostMemoryRegister.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
                     uint64_t size);

  // Alternative interface for memcpying from device to host that takes an
  // array slice. Checks that the host destination can accommodate the size of
  // the device source.
  template <typename T>
  Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
                        port::MutableArraySlice<T> host_dst) {
    auto host_size = host_dst.size() * sizeof(T);
    CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
    return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
  }

  // Alternative interface for memcpying from host to device that takes an
  // array slice. Checks that the destination size can accommodate the host
  // slice size.
  template <typename T>
  Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,  // non-absl ok
                        DeviceMemory<T> *gpu_dst) {
    auto host_size = host_src.size() * sizeof(T);
    CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
    return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
  }
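  //
  // Example (illustrative sketch; `host_vec` is a std::vector<float> and
  // `device_buf` is a DeviceMemory<float> with at least host_vec.size()
  // elements; assumes the array-slice types accept these conversions):
  //
  //   stream.ThenMemcpyH2D<float>(host_vec, &device_buf);
  //   stream.ThenMemcpyD2H<float>(device_buf, absl::MakeSpan(host_vec));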

  // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
  // of the given target size. gpu_src/dst must be pointers to GPU memory and
  // peer access must be enabled between their owning StreamExecutors.
  Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
                     uint64_t size);

  // Calls the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64_t size) {
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }

  // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
  // The location must not be null.
  Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size);

  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible
  // by 4). The location must not be null.
  Stream &ThenMemset32(DeviceMemoryBase *location, uint32_t pattern,
                       uint64_t size);
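  //
  // Example (illustrative sketch; `buf` is a DeviceMemoryBase whose size is a
  // multiple of 4 bytes):
  //
  //   stream.ThenMemZero(&buf, buf.size());
  //   stream.ThenMemset32(&buf, 0xDEADBEEF, buf.size());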

  // Enqueue a forward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnForward for more details.
  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<Eigen::half> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<Eigen::half> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<Eigen::half> &input_c_data,
                         const DeviceMemory<Eigen::half> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<Eigen::half> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<Eigen::half> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<Eigen::half> *output_c_data,
                         bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<float> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<float> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<float> &input_c_data,
                         const DeviceMemory<float> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<float> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<float> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<float> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
                         const dnn::RnnSequenceTensorDescriptor &input_desc,
                         const DeviceMemory<double> &input_data,
                         const DeviceMemory<int> &seq_lengths_data,
                         const dnn::RnnStateTensorDescriptor &input_h_desc,
                         const DeviceMemory<double> &input_h_data,
                         const dnn::RnnStateTensorDescriptor &input_c_desc,
                         const DeviceMemory<double> &input_c_data,
                         const DeviceMemory<double> &params,
                         const dnn::RnnSequenceTensorDescriptor &output_desc,
                         DeviceMemory<double> *output_data,
                         const dnn::RnnStateTensorDescriptor &output_h_desc,
                         DeviceMemory<double> *output_h_data,
                         const dnn::RnnStateTensorDescriptor &output_c_desc,
                         DeviceMemory<double> *output_c_data, bool is_training,
                         ScratchAllocator *reserve_space_allocator,
                         ScratchAllocator *workspace_allocator,
                         dnn::ProfileResult *output_profile_result);

  // Enqueue a backward operation of the RNN model onto the stream.
  // See DnnSupport::DoRnnBackward for more details.
  Stream &ThenRnnBackward(
      const dnn::RnnDescriptor &rnn_desc,
      const dnn::RnnSequenceTensorDescriptor &input_desc,
      const DeviceMemory<Eigen::half> &input_data,
      const DeviceMemory<int> &seq_lengths_data,
      const dnn::RnnStateTensorDescriptor &input_h_desc,
      const DeviceMemory<Eigen::half> &input_h_data,
      const dnn::RnnStateTensorDescriptor &input_c_desc,
      const DeviceMemory<Eigen::half> &input_c_data,
      const DeviceMemory<Eigen::half> &params,
      const dnn::RnnSequenceTensorDescriptor &output_desc,
      const DeviceMemory<Eigen::half> &output_data,
      const dnn::RnnStateTensorDescriptor &output_h_desc,
      const DeviceMemory<Eigen::half> &output_h_data,
      const dnn::RnnStateTensorDescriptor &output_c_desc,
      const DeviceMemory<Eigen::half> &output_c_data,
      const DeviceMemory<Eigen::half> &output_backprop_data,
      const DeviceMemory<Eigen::half> &output_h_backprop_data,
      const DeviceMemory<Eigen::half> &output_c_backprop_data,
      DeviceMemory<Eigen::half> *input_backprop_data,
      DeviceMemory<Eigen::half> *input_h_backprop_data,
      DeviceMemory<Eigen::half> *input_c_backprop_data,
      DeviceMemory<Eigen::half> *params_backprop_data,
      DeviceMemory<uint8_t> *reserve_space_data,
      ScratchAllocator *workspace_allocator,
      dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<float> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<float> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<float> &input_c_data,
                          const DeviceMemory<float> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<float> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<float> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<float> &output_c_data,
                          const DeviceMemory<float> &output_backprop_data,
                          const DeviceMemory<float> &output_h_backprop_data,
                          const DeviceMemory<float> &output_c_backprop_data,
                          DeviceMemory<float> *input_backprop_data,
                          DeviceMemory<float> *input_h_backprop_data,
                          DeviceMemory<float> *input_c_backprop_data,
                          DeviceMemory<float> *params_backprop_data,
                          DeviceMemory<uint8_t> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
                          const dnn::RnnSequenceTensorDescriptor &input_desc,
                          const DeviceMemory<double> &input_data,
                          const DeviceMemory<int> &seq_lengths_data,
                          const dnn::RnnStateTensorDescriptor &input_h_desc,
                          const DeviceMemory<double> &input_h_data,
                          const dnn::RnnStateTensorDescriptor &input_c_desc,
                          const DeviceMemory<double> &input_c_data,
                          const DeviceMemory<double> &params,
                          const dnn::RnnSequenceTensorDescriptor &output_desc,
                          const DeviceMemory<double> &output_data,
                          const dnn::RnnStateTensorDescriptor &output_h_desc,
                          const DeviceMemory<double> &output_h_data,
                          const dnn::RnnStateTensorDescriptor &output_c_desc,
                          const DeviceMemory<double> &output_c_data,
                          const DeviceMemory<double> &output_backprop_data,
                          const DeviceMemory<double> &output_h_backprop_data,
                          const DeviceMemory<double> &output_c_backprop_data,
                          DeviceMemory<double> *input_backprop_data,
                          DeviceMemory<double> *input_h_backprop_data,
                          DeviceMemory<double> *input_c_backprop_data,
                          DeviceMemory<double> *params_backprop_data,
                          DeviceMemory<uint8_t> *reserve_space_data,
                          ScratchAllocator *workspace_allocator,
                          dnn::ProfileResult *output_profile_result);

  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);

  // Enqueue onto the stream an operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);

  // The templated version of the above ThenTransformTensor. Useful when the
  // input and output types are statically known.
  template <typename InElemT, typename OutElemT>
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              const DeviceMemory<InElemT> &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              DeviceMemory<OutElemT> *output_data) {
    return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>::value,
                               input_data, output_desc,
                               dnn::ToDataType<OutElemT>::value,
                               /*scale=*/1.0f, output_data);
  }
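
  // As a hypothetical sketch, converting an fp32 tensor to fp16 with the same
  // shape might look like the following (the BatchDescriptor setters are
  // assumed from dnn.h; `f32_data`/`f16_data` are pre-allocated buffers and
  // all names below are illustrative only):
  //
  //   dnn::BatchDescriptor desc;
  //   desc.set_count(batch).set_height(h).set_width(w).set_feature_map_count(c);
  //   stream.ThenTransformTensor(desc, f32_data, desc, &f16_data);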

  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);
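
  // A minimal sketch of the usual enqueue-then-wait pattern (assuming
  // `stream`, `host_dst`, `device_mem`, and `size` already exist):
  //
  //   stream.ThenMemcpy(host_dst, device_mem, size);
  //   port::Status status = stream.BlockHostUntilDone();
  //   if (!status.ok()) {
  //     LOG(ERROR) << "Stream failed: " << status;
  //   }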

  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream the way device functions
  // (or synchronous host callbacks) do. No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);

  // Returns the (opaque) platform-specific backing object. Ownership is not
  // transferred to the caller.
  internal::StreamInterface *implementation() { return implementation_.get(); }

  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but should only be used for
  // callbacks that either never fail or whose failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
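
  // For example, a hypothetical callback that logs completion of all work
  // enqueued so far (names below are illustrative only):
  //
  //   stream.ThenDoHostCallbackWithStatus([]() -> port::Status {
  //     VLOG(1) << "preceding stream work finished";
  //     return ::tsl::OkStatus();
  //   });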

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);

  // Returns the StreamExecutor (parent object) associated with this stream.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }

  // Returns the CUDA compute capability of the device this stream is bound to.
  CudaComputeCapability GetCudaComputeCapability() const {
    return parent()->GetDeviceDescription().cuda_compute_capability();
  }
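
  // For example, a caller might gate an Ampere-only code path like this
  // (IsAtLeast is assumed from CudaComputeCapability in device_description.h):
  //
  //   if (stream.GetCudaComputeCapability().IsAtLeast(8, 0)) {
  //     // Take the SM 8.0+ specific path.
  //   }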

  // Returns the ROCm compute capability of the device this stream is bound to.
  RocmComputeCapability GetRocmComputeCapability() const {
    return parent()->GetDeviceDescription().rocm_compute_capability();
  }

  // Returns the (internal usage) temporary-memory-allocation manager
  // associated with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;

 private:
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.

  // Checks whether the argument types match before a call to the extended BLAS
  // version.
  template <typename ABType, typename CType, typename ScaleType>
  port::Status CheckTypesForExtendedBlas(
      blas::ComputationType computation_type) {
    static_assert(
        detail::is_any_of<ABType, Eigen::half, Eigen::bfloat16, float, double,
                          int8_t, std::complex<float>, std::complex<double>>(),
        "The only buffer types supported are: Eigen::half, Eigen::bfloat16, "
        "float, double, int8, std::complex<float> and std::complex<double>");
    static_assert(
        std::is_same_v<ABType, CType> ||
            (std::is_same_v<ABType, int8_t> && std::is_same_v<CType, int32_t>),
        "Input and output buffer types should be the same unless input is "
        "int8 and output is int32");
    static_assert(
        std::is_same_v<ScaleType, CType> ||
            (std::is_same_v<ScaleType, float> &&
             detail::is_any_of<CType, Eigen::half, Eigen::bfloat16>()),
        "Mismatched alpha/beta and output types");

    bool valid_computation_type = [computation_type] {
      switch (computation_type) {
        case blas::ComputationType::kF16:
          return std::is_same_v<CType, Eigen::half>;
        case blas::ComputationType::kF32:
          return detail::is_any_of<CType, Eigen::half, Eigen::bfloat16, float,
                                   std::complex<float>>();
        case blas::ComputationType::kF64:
          return detail::is_any_of<CType, double, std::complex<double>>();
        case blas::ComputationType::kI32:
          return std::is_same_v<CType, int32_t>;
        case blas::ComputationType::kF16AsF32:   // fall-through
        case blas::ComputationType::kBF16AsF32:  // fall-through
        case blas::ComputationType::kTF32AsF32:
          return detail::is_any_of<CType, float, std::complex<float>>();
      }
    }();

    if (!valid_computation_type) {
      return port::InternalError(absl::StrCat(
          "Invalid computation type ",
          blas::ComputationTypeString(computation_type), " for output type: ",
          blas::DataTypeString(blas::ToDataType<CType>::value)));
    }
    return ::tsl::OkStatus();
  }
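
  // As a hypothetical illustration of the constraints above, these
  // combinations would pass the checks:
  //
  //   CheckTypesForExtendedBlas<float, float, float>(
  //       blas::ComputationType::kF32);  // fp32 in/out, fp32 scale
  //   CheckTypesForExtendedBlas<Eigen::half, Eigen::half, float>(
  //       blas::ComputationType::kF32);  // fp16 buffers, fp32 scale
  //   CheckTypesForExtendedBlas<int8_t, int32_t, int32_t>(
  //       blas::ComputationType::kI32);  // int8 inputs, int32 output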

  bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
    absl::ReaderMutexLock lock(&mu_);
    return !status_.ok();
  }

  // Sets the error state if operation_retcode is false.
  // This is a useful shorthand for many stream routines.
  void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_);

  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);

  void SetError() { CheckError(false /* = operation_retcode */); }

  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }

  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone.
  void RunAfterBlockHostUntilDoneCallbacks();

  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ ABSL_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ ABSL_GUARDED_BY(mu_);
  // Sub-streams that are generated from this stream. Each element has a
  // pointer to a sub-stream and a boolean value indicating whether that
  // sub-stream is ready to be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      ABSL_GUARDED_BY(mu_);

  // Streams can allocate temporary memory to help with the work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      ABSL_GUARDED_BY(mu_);

  // The non-extended BLAS interface requires alpha/beta to be floats when the
  // input type is Eigen::half or Eigen::bfloat16. However, for consistency it
  // is convenient for the interface to accept those types and upcast here.
  template <typename T>
  void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr,
                         float *alpha_storage, float *beta_storage) {
    if (std::is_same<T, Eigen::half>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    } else if (std::is_same<T, Eigen::bfloat16>::value) {
      *alpha_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr));
      *beta_storage =
          static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr));
      *alpha_ptr = alpha_storage;
      *beta_ptr = beta_storage;
    }
  }

  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};

////////////
// Inlines

template <typename... Params, typename... Args>
inline port::Status Stream::ThenLaunch(ThreadDim thread_dims,
                                       BlockDim block_dims,
                                       const TypedKernel<Params...> &kernel,
                                       Args... args) {
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();

  // This is the core that allows type-safe kernel launching.
  // Since the platforms take kernel arguments as tuples of (void *, size),
  // we pack the variadic parameters passed as ...args into the desired
  // tuple form and pass that packed form to the StreamExecutor::Launch()
  // implementation.
  KernelArgsArray<sizeof...(args)> kernel_args;
  kernel.PackParams(&kernel_args, args...);
  TF_RETURN_IF_ERROR(
      parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args));
  return ::tsl::OkStatus();
}
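
// For example, a hypothetical launch of a typed kernel that takes a device
// buffer and an element count (kernel and buffer construction elided; all
// names below are illustrative only):
//
//   TypedKernel<DeviceMemory<float>, int> add_one = ...;
//   DeviceMemory<float> data = ...;
//   TF_RETURN_IF_ERROR(
//       stream.ThenLaunch(ThreadDim(256), BlockDim(16), add_one, data, 1024));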

template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64_t element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}

inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}

template <>
struct Quantization<uint8_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
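
// The Quantization trait maps an element type to its activation mode; as a
// hypothetical sketch, a helper template could use it like this:
//
//   template <typename T>
//   dnn::QuantizedActivationMode ModeFor() {
//     return Quantization<T>::kModeId;
//   }
//   // ModeFor<uint8_t>() == dnn::QuantizedActivationMode::k8Bit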

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_