1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/ir/ops.h"
17
#include <algorithm>
#include <array>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <utility>
24
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Support/SMLoc.h"
30#include "llvm/Support/raw_ostream.h"
31#include "mlir/IR/Builders.h" // from @llvm-project
32#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33#include "mlir/IR/BuiltinOps.h" // from @llvm-project
34#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35#include "mlir/IR/Diagnostics.h" // from @llvm-project
36#include "mlir/IR/Dialect.h" // from @llvm-project
37#include "mlir/IR/DialectImplementation.h" // from @llvm-project
38#include "mlir/IR/FunctionImplementation.h" // from @llvm-project
39#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
40#include "mlir/IR/MLIRContext.h" // from @llvm-project
41#include "mlir/IR/OpImplementation.h" // from @llvm-project
42#include "mlir/IR/OperationSupport.h" // from @llvm-project
43#include "mlir/IR/SymbolTable.h" // from @llvm-project
44#include "mlir/IR/TypeRange.h" // from @llvm-project
45#include "mlir/IR/TypeUtilities.h" // from @llvm-project
46#include "mlir/IR/Value.h" // from @llvm-project
47#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
48#include "mlir/Support/LLVM.h" // from @llvm-project
49#include "mlir/Support/LogicalResult.h" // from @llvm-project
50#include "tensorflow/core/ir/dialect.h"
51#include "tensorflow/core/ir/interfaces.h"
52#include "tensorflow/core/ir/types/dialect.h"
53#include "tensorflow/core/ir/utility.h"
54
55// Generated definitions.
56#include "tensorflow/core/ir/dialect.cc.inc"
57
58namespace mlir {
59namespace tfg {
60
61//===----------------------------------------------------------------------===//
62// TFGraph dialect.
63//===----------------------------------------------------------------------===//
64
65// Name operation results with the operation name, except control outputs which
66// are named "ctl". MLIR will automatically use a numerical suffix to unique.
67static void GenericGetAsmResultNames(Operation *op,
68 OpAsmSetValueNameFn set_name_fn) {
69 // We only name the results when there are results to name, an op like `print`
70 // which does not have results will just use the `ctl` name for the control
71 // output.
72 if (op->getNumResults() > 1 && !op->getResult(0).getType().isa<ControlType>())
73 set_name_fn(op->getResult(0), op->getName().stripDialect());
74 for (Value result : op->getResults()) {
75 if (result.getType().isa<ControlType>()) {
76 set_name_fn(op->getResult(op->getNumResults() - 1), "ctl");
77 break;
78 }
79 }
80}
81
82// TFGraph support for interacting with the AsmPrinter.
83// Gives prettier names to SSA values.
84struct TFGraphOpAsmInterface
85 : public OpAsmOpInterface::FallbackModel<TFGraphOpAsmInterface> {
86 static bool classof(Operation *op) { return true; }
87
88 void getAsmResultNames(Operation *op, OpAsmSetValueNameFn set_name_fn) const {
89 GenericGetAsmResultNames(op, set_name_fn);
90 }
91 void getAsmBlockArgumentNames(Operation *op, Region &region,
92 OpAsmSetValueNameFn setNameFn) const {}
93 void getAsmBlockNames(Operation *op,
94 mlir::OpAsmSetBlockNameFn setNameFn) const {}
95};
96
97// Dialect construction: there is one instance per context and it registers its
98// operations, types, and interfaces here.
99void TFGraphDialect::initialize() {
100 getContext()->getOrLoadDialect<tf_type::TFTypeDialect>();
101 addOperations<
102#define GET_OP_LIST
103#include "tensorflow/core/ir/ops.cc.inc"
104 >();
105 addAttributes<
106#define GET_ATTRDEF_LIST
107#include "tensorflow/core/ir/attributes.cc.inc"
108 >();
109
110 // Support unknown operations because not all TensorFlow operations are
111 // registered.
112 allowUnknownOperations();
113
114 // Create the fallback OpAsmOpInterface instance.
115 fallbackOpAsmInterface_ = new TFGraphOpAsmInterface;
116
117 // Register the memory effects interface adaptor.
118 addInterfaces<StatefulMemoryEffectInterface>();
119
120 // Initialized the cached operation names.
121#define GET_OP_NAME_DEFS
122#include "tensorflow/core/ir/tf_op_names.inc"
123
124 // Caching some often used context-owned informations for fast-access.
125 name_key_ = StringAttr::get(getContext(), getNameAttrKey());
126 device_key_ = StringAttr::get(getContext(), getDeviceAttrKey());
127 assigned_device_key_ =
128 StringAttr::get(getContext(), getAssignedDeviceAttrKey());
129 fulltype_key_ = StringAttr::get(getContext(), getFullTypeAttrKey());
130 lifted_graph_func_name_ =
131 StringAttr::get(getContext(), getLiftedGraphFuncNameKey());
132 tfg_name_key_ = StringAttr::get(getContext(), getTfgNameAttrKey());
133 tfg_description_key_ =
134 StringAttr::get(getContext(), getTfgDescriptionAttrKey());
135 tfg_is_ref_key_ = StringAttr::get(getContext(), getTfgIsRefAttrKey());
136 tfg_handle_data_key_ =
137 StringAttr::get(getContext(), getTfgHandleDataAttrKey());
138 tfg_full_type_key_ = StringAttr::get(getContext(), getTfgFullTypeAttrKey());
139
140 control_ty_ = ControlType::get(getContext());
141}
142
143// Provides a hook for op interface.
144void *TFGraphDialect::getRegisteredInterfaceForOp(TypeID interface,
145 OperationName opName) {
146 if (interface == TypeID::get<OpAsmOpInterface>()) {
147 return fallbackOpAsmInterface_;
148 }
149
150 // Intrinsic operations explicitly implement intefaces.
151 if (opName.hasTrait<OpTrait::IntrinsicOperation>()) {
152 return nullptr;
153 }
154
155 if (interface == TypeID::get<TensorFlowRegistryInterface>()) {
156 if (auto *instance =
157 getRegisteredInterface<TensorFlowRegistryInterfaceBase>()) {
158 // Important: cast to (Concept *) to shift the pointer off the vtable.
159 return static_cast<TensorFlowRegistryInterfaceBase::Concept *>(
160 const_cast<TensorFlowRegistryInterfaceBase *>(instance));
161 }
162 } else if (interface == TypeID::get<MemoryEffectOpInterface>()) {
163 auto *instance = getRegisteredInterface<StatefulMemoryEffectInterface>();
164 assert(instance && "expected the memory interface to be registered");
165 return static_cast<StatefulMemoryEffectInterface::Concept *>(
166 const_cast<StatefulMemoryEffectInterface *>(instance));
167 }
168
169 return nullptr;
170}
171
172TFGraphDialect::~TFGraphDialect() { delete fallbackOpAsmInterface_; }
173
174// The name of certain optional attributes.
175static std::array<StringRef, 3> keyword_attrs{
176 "_mlir_device", "_mlir_assigned_device", "_mlir_name"};
177
178static void PrintKeywordAttributes(Operation *op, OpAsmPrinter &printer,
179 ArrayRef<StringRef> elided_attrs = {}) {
180 // Handles the optional "device" and "name" attribute.
181 for (StringRef keyword : keyword_attrs) {
182 if (StringAttr value_attr = op->getAttrOfType<StringAttr>(keyword)) {
183 assert(!value_attr.getValue().empty());
184 printer << " " << keyword.drop_front(/*len(_mlir_)*/ 6) << "(\""
185 << value_attr.getValue() << "\")";
186 }
187 }
188
189 // Print attributes (other than name and device).
190 SmallVector<StringRef> attrs_to_elide = llvm::to_vector(elided_attrs);
191 llvm::append_range(attrs_to_elide, keyword_attrs);
192 printer.printOptionalAttrDict(op->getAttrs(), attrs_to_elide);
193}
194
195// Print an operation that belongs to this dialect, if unregistered.
196// The general syntax is:
197// tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
198// name("<node_name>") device("<device>") { attribute-dict } :
199// (input types) -> (result_types)
200void TFGraphDialect::printCustomTfOp(Operation *op,
201 OpAsmPrinter &printer) const {
202 ControlType control_ty = getControlType();
203
204 // Check that all control dependencies are after the regular values,
205 // otherwise print the generic form. We don't expect this to happen but
206 // we're defensive in the printer since this may happen in "hard-to-debug"
207 // issues.
208 {
209 bool has_control_dep = false;
210 for (Value operand : op->getOperands()) {
211 if (operand.getType() == control_ty) {
212 has_control_dep = true;
213 continue;
214 }
215 if (has_control_dep) {
216 printer.printGenericOp(op);
217 return;
218 }
219 }
220 has_control_dep = false;
221 for (Value result : op->getResults()) {
222 if (result.getType() == control_ty) {
223 has_control_dep = true;
224 continue;
225 }
226 if (has_control_dep) {
227 printer.printGenericOp(op);
228 return;
229 }
230 }
231 }
232
233 // Print the inputs (other than the control dependencies), if any.
234 TFOp tfg_op(op);
235 OperandRange data = tfg_op.getNonControlOperands();
236 if (!data.empty()) printer << '(' << data << ')';
237 // Print the control dependencies (if any).
238 OperandRange ctls = tfg_op.getControlOperands();
239 if (!ctls.empty()) printer << " [" << ctls << ']';
240
241 // Print the keyword attributes and optional attribute dictionary.
242 PrintKeywordAttributes(op, printer);
243
244 // Print the type, but omit control dependencies.
245 // If there is a single control return, just print the list of input types,
246 // otherwise print the complete type in a "function-style" way: (operands)
247 // -> (results).
248 ResultRange results = tfg_op.getNonControlResults();
249 if (results.empty()) {
250 if (!data.empty()) printer << " : " << data.getTypes();
251 } else {
252 printer << " : (" << data.getTypes() << ") -> (" << results.getTypes()
253 << ")";
254 }
255}
256
257// Print a custom TFG op.
258static void PrintCustomTfOp(Operation *op, OpAsmPrinter &printer) {
259 cast<TFGraphDialect>(op->getDialect())->printCustomTfOp(op, printer);
260}
261
262llvm::unique_function<void(Operation *, OpAsmPrinter &)>
263TFGraphDialect::getOperationPrinter(Operation *op) const {
264 return [this](Operation *op, OpAsmPrinter &printer) {
265 this->printCustomTfOp(op, printer);
266 };
267}
268
269// Try to parse optional keyword attributes and prefix them with `_mlir_`, of
270// `device`, `assigned_device`, and `name`.
271static ParseResult ParseKeywordAttributes(OpAsmParser &parser,
272 OperationState &result) {
273 for (const char *keyword : {"device", "assigned_device", "name"}) {
274 if (succeeded(parser.parseOptionalKeyword(keyword))) {
275 StringAttr value;
276 if (parser.parseLParen() ||
277 parser.parseAttribute<StringAttr>(
278 value, NoneType::get(parser.getContext())) ||
279 parser.parseRParen())
280 return failure();
281 result.addAttribute((Twine("_mlir_") + keyword).str(), value);
282 }
283 }
284 return parser.parseOptionalAttrDict(result.attributes);
285}
286
287// Parse an operation that belongs to this dialect, if unregistered.
288// The general syntax is:
289// tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
290// name("<node_name>") device("<device>") { attribute-dict } :
291// (input types) -> (result_types)
292static ParseResult ParseCustomTfOp(OpAsmParser &parser,
293 OperationState &result) {
294 MLIRContext *context = parser.getBuilder().getContext();
295 // Parse optional argument list
296 SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
297 if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalParen))
298 return failure();
299 unsigned numNonControlOperands = op_infos.size();
300 // Optional control list, in between brackets.
301 if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalSquare))
302 return failure();
303
304 // Parse the optional keyword attributes and optional attribute dictionary.
305 if (ParseKeywordAttributes(parser, result)) return failure();
306
307 // Parse the functional type.
308 SmallVector<Type> arg_types;
309 arg_types.reserve(op_infos.size());
310 llvm::SMLoc loc = parser.getCurrentLocation();
311 Type control_type = ControlType::get(context);
312 if (failed(parser.parseOptionalColonTypeList(arg_types))) return failure();
313 if (arg_types.size() == 1 && arg_types.front().isa<FunctionType>()) {
314 auto funcType = arg_types.front().cast<FunctionType>();
315 if (funcType.getNumInputs() != numNonControlOperands)
316 return parser.emitError(loc)
317 << "got " << numNonControlOperands
318 << " non-control operands, but the type defines "
319 << funcType.getNumInputs() << " input types";
320 arg_types.clear();
321 arg_types.append(funcType.getInputs().begin(), funcType.getInputs().end());
322 result.types.append(funcType.getResults().begin(),
323 funcType.getResults().end());
324 }
325
326 // The control input are elided from the type list, add them here.
327 arg_types.resize(op_infos.size(), control_type);
328 if (!arg_types.empty())
329 if (parser.resolveOperands(op_infos, arg_types, loc, result.operands))
330 return failure();
331 if (result.name.getStringRef() != "tfg.return")
332 result.types.push_back(control_type);
333 return success();
334}
335
336Optional<Dialect::ParseOpHook> TFGraphDialect::getParseOperationHook(
337 StringRef opName) const {
338 return ParseOpHook(ParseCustomTfOp);
339}
340
341static bool VerifyGenericTFGOperation(Operation &op) {
342 TFGraphDialect *dialect = dyn_cast<TFGraphDialect>(op.getDialect());
343 if (!dialect) return true;
344 ControlType control_ty = dialect->getControlType();
345
346 // verifies that control operands (or results) are always after regular
347 // inputs (or results).
348 auto check_ctl_at_end = [&](TypeRange types, StringRef input_or_output) {
349 int has_control_dep = -1;
350 for (auto &indexed_operand : llvm::enumerate(types)) {
351 if (indexed_operand.value() == control_ty) {
352 has_control_dep = indexed_operand.index();
353 continue;
354 }
355 if (has_control_dep != -1) {
356 op.emitOpError() << "found non-control " << input_or_output
357 << " in position #" << indexed_operand.index()
358 << " after control " << input_or_output
359 << " in position #" << has_control_dep;
360 return false;
361 }
362 }
363 return true;
364 };
365 if (!check_ctl_at_end(op.getOperandTypes(), "input")) return false;
366 if (!check_ctl_at_end(op.getResultTypes(), "result")) return false;
367
368 // Certain attributes are supposed to be inserted with non-empty value.
369 for (StringRef keyword : keyword_attrs) {
370 if (StringAttr value_attr = op.getAttrOfType<StringAttr>(keyword)) {
371 if (value_attr.getValue().empty()) {
372 op.emitOpError() << keyword
373 << " has empty value. Only insert this attribute when "
374 "it has a value";
375 }
376 }
377 }
378
379 return true;
380}
381
382//===----------------------------------------------------------------------===//
383// Graph Operation
384//===----------------------------------------------------------------------===//
385
386LogicalResult GraphOp::verify() {
387 GraphOp op = *this;
388 // Check all ops in the body.
389 if (!all_of(*op.getBody(), VerifyGenericTFGOperation)) return failure();
390
391 return success();
392}
393//===----------------------------------------------------------------------===//
394// Func Operation
395//===----------------------------------------------------------------------===//
396
397bool GraphFuncOp::isMarkedForCompilation() {
398 auto is_enabled = [this](StringRef attr_name) -> bool {
399 Attribute attr = (*this)->getAttr(attr_name);
400 if (!attr) return false;
401 if (auto bool_attr = attr.dyn_cast<BoolAttr>()) return bool_attr.getValue();
402 if (auto str_attr = attr.dyn_cast<StringAttr>())
403 return !str_attr.getValue().empty();
404 return false;
405 };
406 return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
407 is_enabled("_XlaMustCompile");
408}
409
410// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
411// attribute is present and checks if it holds a function type. Ensures
412// getType, getNumFuncArguments, and getNumFuncResults can be called safely
413LogicalResult GraphFuncOp::verifyType() {
414 auto type = getFunctionTypeAttr().getValue();
415 if (!type.isa<FunctionType>())
416 return emitOpError("requires '" + getTypeAttrName() +
417 "' attribute of function type");
418 return success();
419}
420
421// Hook for OpTrait::FunctionLike, called after verifying the function
422// type and the presence of the (potentially empty) function body.
423LogicalResult GraphFuncOp::verifyBody() {
424 FunctionType type = getFunctionType();
425 Block *body = SingleBlock::getBody();
426 // Check that the body is terminated with a tfg.return.
427 if (getRegion().empty() || body->empty())
428 return emitOpError() << "expects a non empty body";
429
430 if (body->getNumArguments() != type.getNumInputs())
431 return emitOpError() << "function type indicated " << type.getNumInputs()
432 << " args but block has " << body->getNumArguments();
433
434 for (auto &arg_types :
435 llvm::enumerate(llvm::zip(type.getInputs(), body->getArgumentTypes()))) {
436 Type signature_arg = std::get<0>(arg_types.value());
437 Type block_arg = std::get<1>(arg_types.value());
438 if (signature_arg != block_arg)
439 return emitOpError() << "type mismatch for arg #" << arg_types.index()
440 << ", signature defines " << signature_arg
441 << " block arg is " << block_arg;
442 }
443
444 if (!isa<ReturnOp>(body->back()))
445 return emitOpError()
446 << "expects body to be terminated with a tfg.return, but got: "
447 << body->back().getName().getStringRef();
448
449 ReturnOp return_op = cast<ReturnOp>(body->getTerminator());
450
451 if (type.getNumResults() > return_op->getNumOperands())
452 return emitOpError() << "expects " << type.getNumResults()
453 << " returned values but tfg.return has "
454 << return_op->getNumOperands() << " operands";
455 for (auto &indexed_type : llvm::enumerate(type.getResults())) {
456 Type expected_type = indexed_type.value();
457 int res_num = indexed_type.index();
458 Type actual_type = return_op->getOperand(res_num).getType();
459 if (!tf_type::AreCastCompatible({expected_type, actual_type})) {
460 return emitOpError() << "type mismatch for returned value #" << res_num
461 << ", expected " << expected_type << " but got "
462 << actual_type;
463 }
464 }
465 Type control_type = getDialect()->getControlType();
466 for (auto &indexed_type : llvm::enumerate(llvm::drop_begin(
467 return_op->getOperandTypes(), type.getNumResults()))) {
468 Type actual_type = indexed_type.value();
469 if (actual_type != control_type) {
470 return emitOpError() << "returned value #" << indexed_type.index()
471 << " overflow the expected " << type.getNumResults()
472 << " returned value for function " << getName()
473 << ", expected a ControlType but got "
474 << actual_type;
475 }
476 }
477
478 // Check all ops in the body.
479 if (!all_of(*SingleBlock::getBody(), VerifyGenericTFGOperation))
480 return failure();
481
482 return success();
483}
484
485LogicalResult GraphFuncOp::canonicalize(GraphFuncOp func_op,
486 PatternRewriter &rewriter) {
487 // Prune function body: the body is a graph where feeds/fetches a materialized
488 // with function arguments and returned values. As such any operation not
489 // reachable from the "fetches" can be pruned. The return statement also has
490 // control input so that side-effecting operations without results (print for
491 // example) aren't pruned.
492 bool changed = true;
493 while (changed) {
494 changed = false;
495 for (Operation &op : llvm::make_early_inc_range(
496 llvm::reverse(*func_op.SingleBlock::getBody()))) {
497 if (isa<ReturnOp>(op)) continue;
498 if (op.getUses().empty()) {
499 rewriter.eraseOp(&op);
500 changed = true;
501 }
502 }
503 }
504 return failure();
505}
506
507LogicalResult GraphFuncOp::verify() {
508 GraphFuncOp func_op = *this;
509 if (func_op.getNumArguments() % 2)
510 return func_op.emitOpError() << "expects an even number of arguments";
511 ArrayAttr args_attrs = func_op.getAllArgAttrs();
512 if (args_attrs && args_attrs.size() != func_op.getNumArguments())
513 return func_op.emitOpError()
514 << "expects argument attributes for each argument ("
515 << args_attrs.size() << " vs " << func_op.getNumArguments() << ")";
516 ArrayAttr res_attrs = func_op.getAllResultAttrs();
517 if (res_attrs && res_attrs.size() != func_op.getNumResults())
518 return func_op.emitOpError()
519 << "expects results attributes for each result (" << res_attrs.size()
520 << " vs " << func_op.getNumResults() << ")";
521 return success();
522}
523
524ParseResult GraphFuncOp::parse(OpAsmParser &parser, OperationState &result) {
525 SmallVector<OpAsmParser::UnresolvedOperand> entry_args;
526 SmallVector<Attribute> arg_attrs;
527 SmallVector<Attribute> result_attrs;
528 SmallVector<Type> arg_types;
529 SmallVector<Type> result_types;
530 auto &builder = parser.getBuilder();
531 MLIRContext *context = builder.getContext();
532
533 // Parse visibility.
534 StringRef visibility;
535 if (!parser.parseOptionalKeyword(&visibility,
536 {"public", "private", "nested"})) {
537 StringAttr visibility_attr = parser.getBuilder().getStringAttr(visibility);
538 result.attributes.push_back(parser.getBuilder().getNamedAttr(
539 SymbolTable::getVisibilityAttrName(), visibility_attr));
540 }
541
542 if (succeeded(parser.parseOptionalKeyword("generic")))
543 result.addAttribute("generic", builder.getUnitAttr());
544
545 // Parse the name as a symbol.
546 StringAttr name_attr;
547 if (parser.parseSymbolName(name_attr, SymbolTable::getSymbolAttrName(),
548 result.attributes))
549 return failure();
550
551 // Parse the function signature.
552 // The difference with usual functions, is that for every single argument
553 // parsed, we create two block arguments: one for the expected value and one
554 // for the control dependency.
555 if (parser.parseLParen()) return failure();
556 Type control_ty = ControlType::get(builder.getContext());
557 std::list<std::string> control_operand_names;
558
559 // Helper to parse a single argument and its attributes.
560 auto parse_argument = [&]() -> ParseResult {
561 // Parse argument name if present.
562 entry_args.emplace_back();
563 arg_types.emplace_back();
564 if (parser.parseOperand(entry_args.back(), /*allowResultNumber=*/false) ||
565 parser.parseColonType(arg_types.back()))
566 return failure();
567
568 // Parse any argument attributes.
569 NamedAttrList attrs;
570 if (parser.parseOptionalAttrDict(attrs)) return failure();
571 arg_attrs.push_back(attrs.getDictionary(context));
572
573 // Define the control input: it's not printed but is added as block
574 // argument. Note the name computed here (suffixed ".ctl") is coupled to the
575 // implementation of:
576 // TFGraphOpAsmInterface::getAsmBlockArgumentNames()
577 // at the top of this file.
578 OpAsmParser::UnresolvedOperand control_operand = entry_args.back();
579 control_operand_names.push_back((control_operand.name + ".ctl").str());
580 control_operand.name = control_operand_names.back();
581 entry_args.push_back(control_operand);
582 arg_types.push_back(control_ty);
583 arg_attrs.push_back(DictionaryAttr::get(context));
584 return success();
585 };
586
587 // Parse the function arguments and their attributes.
588 if (failed(parser.parseOptionalRParen())) {
589 do {
590 if (parse_argument()) return failure();
591 } while (succeeded(parser.parseOptionalComma()));
592 if (parser.parseRParen()) return failure();
593 }
594
595 // Parse the result types and their attributes.
596 if (succeeded(parser.parseOptionalArrow())) {
597 if (failed(parser.parseLParen())) return failure();
598 if (failed(parser.parseOptionalRParen())) {
599 // Parse individual function results.
600 do {
601 result_types.emplace_back();
602 NamedAttrList result_attr;
603 if (parser.parseType(result_types.back()) ||
604 parser.parseOptionalAttrDict(result_attr)) {
605 return failure();
606 }
607 result_attrs.push_back(builder.getDictionaryAttr(result_attr));
608 } while (succeeded(parser.parseOptionalComma()));
609 if (parser.parseRParen()) return failure();
610 }
611 }
612
613 auto type = builder.getFunctionType(arg_types, result_types);
614 result.addAttribute(GraphFuncOp::getTypeAttrName(), TypeAttr::get(type));
615
616 // If function attributes are present, parse them.
617 NamedAttrList parsed_attributes;
618 if (parser.parseOptionalAttrDictWithKeyword(parsed_attributes))
619 return failure();
620 result.attributes.append(parsed_attributes);
621
622 // Add the attributes to the function arguments.
623 assert(arg_attrs.size() == arg_types.size());
624 assert(result_attrs.size() == result_types.size());
625 result.attributes.append(
626 builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
627 builder.getArrayAttr(arg_attrs)));
628 result.attributes.append(
629 builder.getNamedAttr(FunctionOpInterface::getResultDictAttrName(),
630 builder.getArrayAttr(result_attrs)));
631
632 // Parse the function body.
633 auto *body = result.addRegion();
634 llvm::SMLoc loc = parser.getCurrentLocation();
635 SmallVector<OpAsmParser::Argument> args;
636 if (entry_args.size()) {
637 for (auto argAndType : llvm::zip(entry_args, arg_types)) {
638 auto &arg = args.emplace_back();
639 arg.ssaName = std::get<0>(argAndType);
640 arg.type = std::get<1>(argAndType);
641 }
642 }
643
644 if (failed(parser.parseRegion(*body, args, /*enableNameShadowing=*/false)))
645 return failure();
646
647 // Function body was parsed, make sure it's not empty.
648 if (body->empty())
649 return parser.emitError(loc, "expected non-empty function body");
650
651 return success();
652}
653
654void GraphFuncOp::print(OpAsmPrinter &p) {
655 // Print the operation and the function name.
656 Operation *op = *this;
657 p << " ";
658 int argIndentSize = op->getName().getStringRef().size() + 3;
659 StringRef visibility_attr_name = SymbolTable::getVisibilityAttrName();
660 if (auto visibility = op->getAttrOfType<StringAttr>(visibility_attr_name)) {
661 p << visibility.getValue() << ' ';
662 argIndentSize += visibility.getValue().size() + 1;
663 }
664 if (getGeneric()) p << "generic ";
665 auto funcName =
666 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
667 .getValue();
668 p.printSymbolName(funcName);
669 argIndentSize += funcName.size();
670 std::string indent(argIndentSize, ' ');
671 FunctionType fnType = getFunctionType();
672 ArrayRef<Type> arg_types = fnType.getInputs();
673 ArrayRef<Type> result_types = fnType.getResults();
674 assert((arg_types.size() % 2) == 0);
675 // Print operand list with attributes.
676 p << '(';
677 ArrayAttr args_attr = getAllArgAttrs();
678 for (unsigned i = 0, e = arg_types.size(); i < e; i += 2) {
679 // Args come by pair: input+control.
680 p.printOperand(getArgument(i));
681 p << ": ";
682 p.printType(arg_types[i]);
683 if (auto arg_attrs = args_attr[i].dyn_cast<DictionaryAttr>())
684 p.printOptionalAttrDict(arg_attrs.getValue());
685 if (i != e - 2) {
686 p << ", ";
687 p.printNewline();
688 p << indent;
689 }
690 }
691 p << ')';
692
693 // Print result types, if any.
694 if (!result_types.empty()) {
695 p.printNewline();
696 p.getStream() << " -> (";
697 indent = std::string(9, ' ');
698 ArrayAttr results_attr = getAllResultAttrs();
699 for (int i = 0, e = result_types.size(); i < e; ++i) {
700 p.printType(result_types[i]);
701 if (auto result_attrs = results_attr[i].dyn_cast<DictionaryAttr>())
702 p.printOptionalAttrDict(result_attrs.getValue());
703 if (i != e - 1) {
704 p << ", ";
705 p.printNewline();
706 p << indent;
707 }
708 }
709 p << ")";
710 }
711 // Print attributes.
712 if (!op->getAttrs().empty()) {
713 p.printNewline();
714 function_interface_impl::printFunctionAttributes(
715 p, *this, fnType.getNumInputs(), fnType.getNumResults(),
716 {"generic", SymbolTable::getVisibilityAttrName()});
717 }
718 // Print body.
719 p << ' ';
720 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
721}
722
723GraphFuncOp GraphFuncOp::getCalledFunction(Operation *op,
724 SymbolTable &symbol_table) {
725 // Check if a node does indirect function call via PartitionedCallOp.
726 // TODO(aminim): consider replacing with isa<...> when possible.
727 if (op->getName().getStringRef() == "tfg.PartitionCall" ||
728 op->getName().getStringRef() == "tfg.StatefulPartitionedCall") {
729 auto func_attr = op->getAttrOfType<FuncAttr>("f");
730 if (!func_attr) return {};
731 GraphFuncOp callee = symbol_table.lookup<GraphFuncOp>(
732 func_attr.getName().getLeafReference());
733 if (callee) return callee;
734 }
735 return symbol_table.lookup<GraphFuncOp>(op->getName().stripDialect());
736}
737
738BlockArgument GraphFuncOp::getDataValueOf(BlockArgument ctl) {
739 return ctl.getOwner()->getArgument(ctl.getArgNumber() - 1);
740}
741
742BlockArgument GraphFuncOp::getControlTokenOf(BlockArgument data) {
743 return data.getOwner()->getArgument(data.getArgNumber() + 1);
744}
745
746BlockArgument GraphFuncOp::getDataValue(Region &region, unsigned idx) {
747 return region.getArgument(idx * 2);
748}
749
750// This is naming block arguments for GraphFuncOp, we rely on the arg attributes
751// for computing the names.
752void GraphFuncOp::getAsmBlockArgumentNames(Region &region,
753 OpAsmSetValueNameFn set_name_fn) {
754 ArrayRef<BlockArgument> args = SingleBlock::getBody()->getArguments();
755 ControlType control_ty = ControlType::get(getContext());
756 // Sanity checking: this is verified by the op but this may be called before
757 // the verifier or in some diagnostic/debug context, let's not crash.
758 // We expect the function block operands to come as pair: tensor+control.
759 if (args.size() % 2) return;
760 for (unsigned i = 0, e = args.size(); i < e; i += 2)
761 if (args[i].getType() == control_ty || args[i + 1].getType() != control_ty)
762 return;
763
764 // Name the values based on the `tfg.name` arg attribute retrieved from the
765 // func_op.
766 ArrayAttr args_attr = getAllArgAttrs();
767 if (!args_attr || args_attr.size() != args.size()) return;
768 for (int arg_num = 0, e = args.size(); arg_num < e; arg_num += 2) {
769 DictionaryAttr arg_attrs = args_attr[arg_num].dyn_cast<DictionaryAttr>();
770 if (!arg_attrs) continue;
771 if (auto strAttr = arg_attrs.getAs<StringAttr>("tfg.name")) {
772 set_name_fn(args[arg_num], strAttr.getValue());
773 set_name_fn(args[arg_num + 1], (strAttr.getValue() + ".ctl").str());
774 }
775 }
776}
777
778//===----------------------------------------------------------------------===//
779// ReturnOp
780//===----------------------------------------------------------------------===//
781
782LogicalResult ReturnOp::verify() {
783 ReturnOp op = *this;
784 // If the control result attributes are present, there must be the same number
785 // of entries as control results.
786 if (op.getControlRetAttrs().size() != TFOp(op).getControlOperands().size()) {
787 return op.emitOpError(
788 "expected as many control result attributes as there are control "
789 "operands");
790 }
791 return success();
792}
793
794ParseResult ReturnOp::parse(OpAsmParser &parser, OperationState &result) {
795 // ReturnOp has the same assembly format as generic TFG ops except that the
796 // control result attributes are embedded with the control operands:
797 // [%ctl {tfg.name = "foo"}, %ctl_0 {tfg.name = "bar"}]
798 SmallVector<OpAsmParser::UnresolvedOperand> operands;
799 if (parser.parseOperandList(operands, AsmParser::Delimiter::OptionalParen))
800 return failure();
801
802 SmallVector<Attribute> control_ret_attrs;
803 if (succeeded(parser.parseOptionalLSquare())) {
804 OpAsmParser::UnresolvedOperand operand;
805 do {
806 NamedAttrList attrs;
807 OptionalParseResult parse_result = parser.parseOptionalOperand(operand);
808 if (!parse_result.hasValue()) break;
809 if (failed(parse_result.getValue())) return failure();
810 if (parser.parseOptionalAttrDict(attrs)) return failure();
811 control_ret_attrs.push_back(attrs.getDictionary(result.getContext()));
812 operands.push_back(std::move(operand));
813 } while (succeeded(parser.parseOptionalComma()));
814 if (parser.parseRSquare()) return failure();
815 }
816
817 if (ParseKeywordAttributes(parser, result)) return failure();
818 result.addAttribute(ReturnOp::control_ret_attrsAttrName(result.name),
819 ArrayAttr::get(result.getContext(), control_ret_attrs));
820
821 SmallVector<Type> types;
822 if (parser.parseOptionalColonTypeList(types)) return failure();
823 types.resize(operands.size(), ControlType::get(result.getContext()));
824 if (parser.resolveOperands(operands, types, parser.getCurrentLocation(),
825 result.operands))
826 return failure();
827 return success();
828}
829
830void ReturnOp::print(OpAsmPrinter &printer) {
831 TFOp tfg_op(*this);
832 OperandRange data = tfg_op.getNonControlOperands();
833 if (!data.empty()) printer << '(' << data << ')';
834
835 OperandRange ctls = tfg_op.getControlOperands();
836 if (!ctls.empty()) {
837 printer << " [";
838 llvm::interleave(
839 llvm::zip(ctls, getControlRetAttrs().getAsRange<DictionaryAttr>()),
840 printer,
841 [&](auto it) {
842 printer << std::get<0>(it);
843 if (!std::get<1>(it).empty()) printer << ' ' << std::get<1>(it);
844 },
845 ", ");
846 printer << ']';
847 }
848
849 PrintKeywordAttributes(*this, printer, {"control_ret_attrs"});
850
851 if (!data.empty()) printer << " : " << data.getTypes();
852}
853
854void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState,
855 ValueRange operands, ValueRange control_operands) {
856 odsState.addOperands(operands);
857 odsState.addOperands(control_operands);
858 // Populate `control_ret_attrs` with empty dictionaries.
859 odsState.addAttribute(
860 ReturnOp::control_ret_attrsAttrName(odsState.name),
861 odsBuilder.getArrayAttr(SmallVector<Attribute>(
862 control_operands.size(), odsBuilder.getDictionaryAttr({}))));
863}
864
865//===----------------------------------------------------------------------===//
866// Concrete Ops
867//===----------------------------------------------------------------------===//
868
// TODO(jeffniu): The ODS definitions of TFG ops, as well as parts of their
// verifiers, can be autogenerated. These hand-written verifiers focus on
// verifying the ops' operand and result types with respect to their functions'
// types, the logic for which is slightly different between operations.
873
874// Verify that all control operands follow non-control operands, and return the
875// subrange of non-control operands.
876static FailureOr<TypeRange> VerifyOperands(Operation *op) {
877 ControlType control_ty =
878 cast<TFGraphDialect>(op->getDialect())->getControlType();
879 Operation::operand_type_iterator it =
880 llvm::find(op->getOperandTypes(), control_ty);
881 if (!std::all_of(it, op->operand_type_end(),
882 [&](Type type) { return type == control_ty; })) {
883 return op->emitOpError(
884 "not all control tokens come after non-control operands");
885 }
886 return {Operation::operand_type_range(op->operand_type_begin(), it)};
887}
888
889// Verify that the last result of an operation is the only control result, and
890// return a subrange of the non-control results.
891static FailureOr<TypeRange> VerifyResults(Operation *op) {
892 ControlType control_ty =
893 cast<TFGraphDialect>(op->getDialect())->getControlType();
894 Operation::result_type_iterator it =
895 llvm::find(op->getResultTypes(), control_ty);
896 if (it == op->result_type_end())
897 return op->emitOpError("does not define a control result");
898 if (it != std::prev(op->result_type_end())) {
899 return op->emitOpError(
900 "must have a control token result as and only as its last result");
901 }
902 return {Operation::result_type_range(op->result_type_begin(), it)};
903}
904
905// Verify that the signature of the function matches the operation's operands
906// and results.
907static LogicalResult VerifySignature(GraphFuncOp func, Operation *op,
908 TypeRange operands, TypeRange results,
909 const Twine &func_name) {
910 auto attach_func = [&](InFlightDiagnostic diag) -> LogicalResult {
911 return diag.attachNote(func.getLoc()).appendOp(*func, OpPrintingFlags())
912 << "\nsee referenced function";
913 };
914
915 ArrayRef<Type> arguments = func.getFunctionType().getInputs();
916 ArrayRef<Type> returns = func.getFunctionType().getResults();
917 if (operands.size() * 2 != arguments.size()) {
918 return attach_func(op->emitOpError(func_name)
919 << " function has " << arguments.size() / 2
920 << " arguments but was provided " << operands.size());
921 }
922 if (results.size() != returns.size()) {
923 return attach_func(op->emitOpError(func_name)
924 << " function has " << returns.size()
925 << " results but expected " << results.size());
926 }
927
928 if (func.getGeneric()) return success();
929
930 for (auto &it : llvm::enumerate(operands)) {
931 Type arg_type = arguments[it.index() * 2];
932 Type op_type = it.value();
933 if (!tf_type::HasCompatibleElementTypes(arg_type, op_type)) {
934 return attach_func(
935 op->emitOpError(func_name)
936 << " function argument #" << it.index() << " type " << arg_type
937 << " is not compatible with corresponding operand type: " << op_type);
938 }
939 }
940 for (auto &it : llvm::enumerate(results)) {
941 Type ret_type = returns[it.index()];
942 Type res_type = it.value();
943 if (!tf_type::HasCompatibleElementTypes(ret_type, res_type)) {
944 return attach_func(
945 op->emitOpError(func_name)
946 << " function result #" << it.index() << " type " << ret_type
947 << " is not compatible with corresponding result type: " << res_type);
948 }
949 }
950 return success();
951}
952
953// This function verifies that the types of `values`, which are either operands
954// or results of `op`, match the types specified in `types`, which is expected
955// to be an array of type attributes.
956static LogicalResult VerifyTypeArray(Operation *op, ValueRange values,
957 ArrayAttr types, StringRef kind) {
958 // Don't verify if the types are not present.
959 if (!types) return success();
960 if (values.size() != types.size()) {
961 return op->emitOpError("has ") << values.size() << " " << kind << "s but "
962 << types.size() << " " << kind << " types";
963 }
964 for (auto it :
965 llvm::zip(llvm::enumerate(values), types.getAsRange<TypeAttr>())) {
966 Type type = std::get<0>(it).value().getType();
967 Type dtype = std::get<1>(it).getValue();
968 if (!tf_type::HasCompatibleElementTypes(type,
969 UnrankedTensorType::get(dtype))) {
970 return op->emitOpError(kind)
971 << " #" << std::get<0>(it).index()
972 << " is incompatible with dtype " << dtype << ", got: " << type;
973 }
974 }
975 return success();
976}
977
978namespace detail {
979// Check if the op type has `T`.
980template <typename OpT>
981using has_T = decltype(std::declval<OpT>().T());
982template <typename OpT>
983using detect_has_T = llvm::is_detected<has_T, OpT>;
984
985// Get the input and output type arrays. If the op has a single type array,
986// use it for both input and output. Otherwise, return separate type arrays.
987template <typename OpT, bool = detect_has_T<OpT>::value>
988struct GetTypeArray {
989 static ArrayAttr getInputTypes(OpT op) { return op.getTinAttr(); }
990 static ArrayAttr getOutputTypes(OpT op) { return op.getToutAttr(); }
991};
992template <typename OpT>
993struct GetTypeArray<OpT, true> {
994 static ArrayAttr getInputTypes(OpT op) { return op.getTAttr(); }
995 static ArrayAttr getOutputTypes(OpT op) { return op.getTAttr(); }
996};
997} // namespace detail
998
999// Verify a functional op's inputs and outputs against its data type arrays. For
1000// loop ops, this also checks that the number of inputs and outputs match. This
1001// is guaranteed to be valid on import but may be violated by a transformation.
1002template <typename OpT>
1003static LogicalResult VerifyTypeArrayAttributes(OpT op) {
1004 using GetTypeArray = typename detail::GetTypeArray<OpT>;
1005 ValueRange args =
1006 SplitDataAndControlValues(op.getArgs(), ControlType::get(op.getContext()))
1007 .first;
1008 return success(
1009 succeeded(VerifyTypeArray(op, args, GetTypeArray::getInputTypes(op),
1010 "argument")) &&
1011 succeeded(VerifyTypeArray(op, op.getOuts(),
1012 GetTypeArray::getOutputTypes(op), "result")));
1013}
1014
1015//===----------------------------------------------------------------------===//
1016// If-Like Ops
1017
1018template <typename IfLikeOp>
1019static LogicalResult VerifyIfLikeOp(IfLikeOp op,
1020 SymbolTableCollection &symbol_table) {
1021 if (failed(op.verifyInvariants())) return failure();
1022 FailureOr<TypeRange> ins = VerifyOperands(op);
1023 if (failed(ins)) return failure();
1024 FailureOr<TypeRange> outs = VerifyResults(op);
1025 if (failed(outs)) return failure();
1026
1027 // The first operand is the condition and is not passed to the functions.
1028 TypeRange func_args = ins->drop_front();
1029
1030 auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1031 op, op.getThenBranch().getName());
1032 if (then_func &&
1033 failed(VerifySignature(then_func, op, func_args, *outs, "then")))
1034 return failure();
1035
1036 auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1037 op, op.getElseBranch().getName());
1038 if (else_func &&
1039 failed(VerifySignature(else_func, op, func_args, *outs, "else")))
1040 return failure();
1041
1042 return VerifyTypeArrayAttributes(op);
1043}
1044
1045//===----------------------------------------------------------------------===//
1046// Case-Like Ops
1047
1048template <typename CaseLikeOp>
1049static LogicalResult VerifyCaseLikeOp(CaseLikeOp op,
1050 SymbolTableCollection &symbol_table) {
1051 if (failed(op.verifyInvariants())) return failure();
1052 FailureOr<TypeRange> ins = VerifyOperands(op);
1053 if (failed(ins)) return failure();
1054 FailureOr<TypeRange> outs = VerifyResults(op);
1055 if (failed(outs)) return failure();
1056
1057 // The first operand is the branch index and is not passed to the functions.
1058 TypeRange func_args = ins->drop_front();
1059
1060 for (auto &it : llvm::enumerate(op.getBranches())) {
1061 SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName();
1062 auto func =
1063 symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name);
1064 if (func && failed(VerifySignature(func, op, func_args, *outs,
1065 "branch #" + Twine(it.index()))))
1066 return failure();
1067 }
1068
1069 return VerifyTypeArrayAttributes(op);
1070}
1071
1072//===----------------------------------------------------------------------===//
1073// While-Like Ops
1074
1075template <typename WhileLikeOp>
1076static LogicalResult VerifyWhileLikeOp(WhileLikeOp op,
1077 SymbolTableCollection &symbol_table) {
1078 if (failed(op.verifyInvariants())) return failure();
1079 FailureOr<TypeRange> ins = VerifyOperands(op);
1080 if (failed(ins)) return failure();
1081 FailureOr<TypeRange> outs = VerifyResults(op);
1082 if (failed(outs)) return failure();
1083
1084 SymbolRefAttr body_name = op.getBody().getName();
1085
1086 auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1087 op, op.getCond().getName());
1088 auto i1_type = UnrankedTensorType::get(Builder(op.getContext()).getI1Type());
1089 if (cond_func &&
1090 failed(VerifySignature(cond_func, op, *ins, i1_type, "cond")))
1091 return failure();
1092
1093 auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1094 op, op.getBody().getName());
1095 if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body")))
1096 return failure();
1097
1098 return VerifyTypeArrayAttributes(op);
1099}
1100
1101//===----------------------------------------------------------------------===//
1102// ForOp
1103
1104LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1105 if (failed(verifyInvariants())) return failure();
1106 FailureOr<TypeRange> ins = VerifyOperands(*this);
1107 if (failed(ins)) return failure();
1108 FailureOr<TypeRange> outs = VerifyResults(*this);
1109 if (failed(outs)) return failure();
1110
1111 auto body_func = symbolTable.lookupNearestSymbolFrom<GraphFuncOp>(
1112 *this, getBody().getName());
1113 // The first three arguments are the for-loop indices, but the current loop
1114 // index is passed in.
1115 TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2);
1116 if (body_func &&
1117 failed(VerifySignature(body_func, *this, func_args, *outs, "body")))
1118 return failure();
1119
1120 return VerifyTypeArrayAttributes(*this);
1121}
1122
1123//===----------------------------------------------------------------------===//
1124// Region Ops and Terminators
1125//===----------------------------------------------------------------------===//
1126
1127// If a region op has preserved attributes, verify that they match the number of
1128// results and block arguments.
1129static LogicalResult VerifyPreservedAttrs(Operation *op,
1130 ArrayRef<Attribute> preserved_attrs) {
1131 assert(op->getNumRegions() == preserved_attrs.size());
1132 for (auto it : llvm::zip(preserved_attrs, op->getRegions())) {
1133 // Preserved attributes for a particular region may not exist.
1134 auto attrs = std::get<0>(it).dyn_cast_or_null<RegionAttr>();
1135 if (!attrs) continue;
1136 Region &region = std::get<1>(it);
1137
1138 const auto emit_region_error = [&](StringRef msg) {
1139 return op->emitOpError("region #")
1140 << region.getRegionNumber() << " " << msg;
1141 };
1142
1143 unsigned num_args = GetLoopRegionDataArgs(region).size();
1144 if (num_args != attrs.getArgAttrs().size()) {
1145 return emit_region_error("has ")
1146 << num_args << " argument(s) but preserved attributes has "
1147 << attrs.getArgAttrs().size();
1148 }
1149
1150 // All regions are terminated by either a YieldOp or a ConditionOp. In the
1151 // latter case, the function will only have one result.
1152 unsigned num_rets;
1153 Operation *terminator = region.front().getTerminator();
1154 if (isa<ConditionOp>(terminator)) {
1155 num_rets = 1;
1156 } else {
1157 num_rets = cast<RegionBranchTerminatorOpInterface>(terminator)
1158 .getMutableSuccessorOperands(region.getRegionNumber())
1159 .size();
1160 }
1161 if (num_rets != attrs.getResAttrs().size()) {
1162 return emit_region_error("has ")
1163 << num_rets << " result(s) but preserved attributes has "
1164 << attrs.getResAttrs().size();
1165 }
1166 }
1167 return success();
1168}
1169
1170//===----------------------------------------------------------------------===//
1171// YieldOp
1172
1173MutableOperandRange YieldOp::getMutableSuccessorOperands(
1174 Optional<unsigned> index) {
1175 // Get the subrange of non-control operands.
1176 return getArgsMutable();
1177}
1178
1179static bool TerminatedByYield(Block &block) {
1180 return isa<YieldOp>(block.getTerminator());
1181}
1182
1183//===----------------------------------------------------------------------===//
1184// IfLikeRegionOp
1185
1186// Verify an if-like region op.
1187template <typename IfLikeRegionOp>
1188static LogicalResult VerifyIfLikeRegionOp(IfLikeRegionOp op) {
1189 // Verify terminators.
1190 if (!TerminatedByYield(op.then_block()))
1191 return op.emitOpError("then region must be terminated by a 'tfg.yield'");
1192 if (!TerminatedByYield(op.else_block()))
1193 return op.emitOpError("else region must be terminated by a 'tfg.yield'");
1194 return VerifyPreservedAttrs(
1195 op, {op.getThenRegionAttrsAttr(), op.getElseRegionAttrsAttr()});
1196}
1197
1198// Given an potentially null attribute that would represent a constant value,
1199// try to narrow it to a statically known condition.
1200// TODO(jeffniu): Incorporate the other cases of `tf.ToBool`.
1201static Optional<bool> GetStaticallyKnownBranch(Attribute cond_attr) {
1202 // Only handle the case of a scalar tensor of i1.
1203 auto cond = cond_attr.dyn_cast_or_null<ElementsAttr>();
1204 if (cond && cond.getNumElements() == 1 &&
1205 cond.getElementType().isSignlessInteger(1))
1206 return cond.getSplatValue<bool>();
1207 return {};
1208}
1209
1210// Get the successor of the regions of an if-like op.
1211template <typename IfLikeRegionOp>
1212void GetIfLikeRegionOpSuccessorRegions(
1213 IfLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1214 SmallVectorImpl<RegionSuccessor> &regions) {
1215 assert(index.has_value() ||
1216 !operands.empty() && "if-like op expected at least 1 operand");
1217 // Both regions branch back to the parent op.
1218 if (index.has_value()) {
1219 // Ignore the control token.
1220 regions.emplace_back(
1221 ResultRange(op->result_begin(), std::prev(op->result_end())));
1222 } else if (auto cond = GetStaticallyKnownBranch(operands[0])) {
1223 // Add only 1 possible successor if the condition is known.
1224 Region &region = *cond ? op.getThenRegion() : op.getElseRegion();
1225 regions.emplace_back(&region, GetLoopRegionDataArgs(region));
1226 } else {
1227 // Unknown successor.
1228 regions.emplace_back(&op.getThenRegion(),
1229 GetLoopRegionDataArgs(op.getThenRegion()));
1230 regions.emplace_back(&op.getElseRegion(),
1231 GetLoopRegionDataArgs(op.getElseRegion()));
1232 }
1233}
1234
1235//===----------------------------------------------------------------------===//
1236// CaseLikeRegionOp
1237
1238// Verify a case-like region op.
1239template <typename CaseLikeRegionOp>
1240static LogicalResult VerifyCaseLikeRegionOp(CaseLikeRegionOp op) {
1241 for (auto &it : llvm::enumerate(op.getBranches())) {
1242 if (!TerminatedByYield(it.value().front())) {
1243 return op.emitOpError("branch region #")
1244 << it.index() << " is not terminated by a 'tfg.yield' op";
1245 }
1246 }
1247
1248 if (op.getBranchAttrs() &&
1249 op.getBranches().size() != op.getBranchAttrs()->size()) {
1250 return op.emitOpError("has ")
1251 << op.getBranches().size() << " regions but "
1252 << op.getBranchAttrs()->size() << " branch function attributes";
1253 }
1254 if (auto region_attrs = op.getRegionAttrsAttr()) {
1255 if (region_attrs.size() != op.getNumRegions()) {
1256 return op.emitOpError("expected ")
1257 << op.getNumRegions() << " region attribute(s) but got "
1258 << region_attrs.size();
1259 }
1260 if (failed(VerifyPreservedAttrs(op, region_attrs.getValue())))
1261 return failure();
1262 }
1263 return success();
1264}
1265
1266// Given a potentially null attribute that would represent a constant value,
1267// try to narrow it to a statically known branch index.
1268static Optional<unsigned> GetStaticallyKnownCaseBranch(Attribute branch_attr) {
1269 auto branch = branch_attr.dyn_cast_or_null<ElementsAttr>();
1270 if (branch && branch.getNumElements() == 1 &&
1271 branch.getElementType().isSignlessInteger(32))
1272 return branch.getSplatValue<unsigned>();
1273 return {};
1274}
1275
1276// Get the successor of the regions of a case-like op.
1277template <typename CaseLikeRegionOp>
1278void GetCaseLikeRegionOpSuccessorRegions(
1279 CaseLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1280 SmallVectorImpl<RegionSuccessor> &regions) {
1281 assert(index.has_value() ||
1282 !operands.empty() && "case-like op expected at least 1 operand");
1283 // All branch regions branch back to the parent op.
1284 if (index.has_value()) {
1285 // Ignore the control token.
1286 regions.emplace_back(
1287 ResultRange(op->result_begin(), std::prev(op->result_end())));
1288 } else if (auto branch_index = GetStaticallyKnownCaseBranch(operands[0])) {
1289 // Add only 1 possible successor if the condition is known.
1290 Region &region = op.getBranches()[*branch_index];
1291 regions.emplace_back(&region, GetLoopRegionDataArgs(region));
1292 } else {
1293 // Unknown successor. Add all of them.
1294 for (Region &branch : op.getBranches())
1295 regions.emplace_back(&branch, GetLoopRegionDataArgs(branch));
1296 }
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// ConditionOp
1301
1302MutableOperandRange ConditionOp::getMutableSuccessorOperands(
1303 Optional<unsigned> index) {
1304 // Get the subrange of non-control operands that are forwarded to the
1305 // successor region.
1306 return getArgsMutable();
1307}
1308
1309//===----------------------------------------------------------------------===//
1310// WhileLikeRegionOp
1311
1312// Verify that the loop regions of a region-based loop op have N control tokens
1313// immediately following N data values in their entry block arguments.
1314// `RegionBranchOpInterface` will verify the number of arguments and their
1315// types.
1316static LogicalResult VerifyLoopRegionArgs(Operation *op, Region &region) {
1317 const auto arg_error = [&](BlockArgument arg) {
1318 return op->emitOpError("region #")
1319 << region.getRegionNumber() << " argument #" << arg.getArgNumber()
1320 << " ";
1321 };
1322
1323 // The interface trait verifies the number of data and control arguments. If
1324 // the first half of the arguments are not control tokens, then we know for
1325 // sure that the second half is only control tokens.
1326 for (BlockArgument data : GetLoopRegionDataArgs(region))
1327 if (data.getType().isa<ControlType>())
1328 return arg_error(data) << "should not be a control token";
1329 return success();
1330}
1331
1332// Verify a while-like region op.
1333template <typename WhileLikeRegionOp>
1334static LogicalResult VerifyWhileLikeRegionOp(WhileLikeRegionOp op) {
1335 // Verify terminators.
1336 if (!isa<ConditionOp>(op.cond_block().getTerminator())) {
1337 return op.emitOpError(
1338 "condition region must be terminated by a 'tfg.condition' op");
1339 }
1340 if (!TerminatedByYield(op.body_block()))
1341 op.emitOpError("body region must be terminated by a 'tfg.yield' op");
1342
1343 if (failed(VerifyLoopRegionArgs(op, op.getCondRegion())) ||
1344 failed(VerifyLoopRegionArgs(op, op.getBodyRegion())))
1345 return failure();
1346 if (failed(VerifyPreservedAttrs(
1347 op, {op.getCondRegionAttrsAttr(), op.getBodyRegionAttrsAttr()})))
1348 return failure();
1349
1350 return success();
1351}
1352
1353template <typename WhileLikeRegionOp>
1354static void GetWhileLikeRegionOpSuccessorRegions(
1355 WhileLikeRegionOp op, Optional<unsigned> index,
1356 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
1357 // The parent op and the body region always branch to the condion region.
1358 if (!index || *index == 1) {
1359 regions.emplace_back(&op.getCondRegion(),
1360 GetLoopRegionDataArgs(op.getCondRegion()));
1361 return;
1362 }
1363 assert(*index == 0 && "invalid region index");
1364 // The condition regions branches to the loop body or back to the parent.
1365 // Try to narrow the condition value to a constant.
1366 auto condition =
1367 cast<ConditionOp>(op.getCondRegion().front().getTerminator());
1368 Attribute cond_attr;
1369 matchPattern(condition.getCond(), m_Constant(&cond_attr));
1370 Optional<bool> cond = GetStaticallyKnownBranch(cond_attr);
1371 if (!cond || *cond) {
1372 regions.emplace_back(&op.getBodyRegion(),
1373 GetLoopRegionDataArgs(op.getBodyRegion()));
1374 }
1375 if (!cond || !*cond) {
1376 // Drop the control token.
1377 regions.emplace_back(op.getResults().drop_back());
1378 }
1379}
1380
1381//===----------------------------------------------------------------------===//
1382// ForRegionOp
1383
1384LogicalResult ForRegionOp::verify() {
1385 if (!TerminatedByYield(body_block())) {
1386 return emitOpError("body region must be terminated by a 'tfg.yield' op");
1387 }
1388
1389 Block::BlockArgListType args = body_block().getArguments();
1390 if (args.empty()) {
1391 return emitOpError(
1392 "expected the body block to have at least have the loop index as an "
1393 "argument");
1394 }
1395 auto index = args.front().getType().dyn_cast<TensorType>();
1396 if (!index || !index.getElementType().isSignlessInteger(32)) {
1397 return emitOpError(
1398 "expected first body block argument to be an i32 tensor");
1399 }
1400
1401 if (failed(VerifyLoopRegionArgs(*this, getBodyRegion()))) return failure();
1402 return VerifyPreservedAttrs(*this, {getRegionAttrsAttr()});
1403}
1404
1405OperandRange ForRegionOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1406 return getInit();
1407}
1408
1409void ForRegionOp::getSuccessorRegions(
1410 Optional<unsigned> index, ArrayRef<Attribute> operands,
1411 SmallVectorImpl<RegionSuccessor> &regions) {
1412 // Both the parent op and the body region branch to the body. Ignore the loop
1413 // index block argument, as it is not modified by the loop body itself.
1414 regions.emplace_back(&getBodyRegion(),
1415 GetLoopRegionDataArgs(getBodyRegion()).drop_front());
1416 if (!index) return;
1417 // The body might branch back to the parent. Drop the control token.
1418 regions.emplace_back((*this)->getResults().drop_back());
1419}
1420
1421BlockArgument ForRegionOp::getDataValueOf(BlockArgument ctl) {
1422 return GetLoopRegionDataOf(ctl);
1423}
1424BlockArgument ForRegionOp::getControlTokenOf(BlockArgument data) {
1425 return GetLoopRegionControlOf(data);
1426}
1427BlockArgument ForRegionOp::getDataValue(Region &region, unsigned idx) {
1428 return GetLoopRegionDataArgs(region)[idx];
1429}
1430BlockArgument ForRegionOp::getControlToken(Region &region, unsigned idx) {
1431 return GetLoopRegionControlTokens(region)[idx];
1432}
1433
1434//===----------------------------------------------------------------------===//
1435// Function Table
1436//===----------------------------------------------------------------------===//
1437
1438FunctionTable::FunctionTable(ModuleOp module) {
1439 // Collect function names (to be used for disambiguating legacy call
1440 // behavior).
1441 for (auto &op : module.getOps()) {
1442 if (auto func = dyn_cast<GraphFuncOp>(op)) functions.insert(func.getName());
1443 }
1444}
1445
1446bool FunctionTable::MayBeCall(Operation *op) const {
1447 if (IsLegacyCall(op)) return true;
1448 // The operation might be a call if it references a symbol.
1449 bool references_symbol = false;
1450 op->getAttrDictionary().walkSubAttrs(
1451 [&](Attribute attr) { references_symbol |= attr.isa<SymbolRefAttr>(); });
1452 return references_symbol;
1453}
1454
1455bool FunctionTable::IsLegacyCall(Operation *op) const {
1456 // If the operation name refers to a function in the module, then it is
1457 // guaranteed to be a legacy call. Otherwise, it is not.
1458 return functions.count(op->getName().stripDialect());
1459}
1460
1461} // namespace tfg
1462} // namespace mlir
1463
1464//===----------------------------------------------------------------------===//
1465// ODS Definitions
1466//===----------------------------------------------------------------------===//
1467
1468#define GET_OP_CLASSES
1469#include "tensorflow/core/ir/ops.cc.inc"
1470#define GET_ATTRDEF_CLASSES
1471#include "tensorflow/core/ir/attributes.cc.inc"
1472