Skip to content

Commit 2a74f30

Browse files
[CIR] Add coroutine cleanup handling and update co_return semantics (#189281)
This PR adds cleanup handling for coroutine frame destruction. The cleanup is emitted as a conditional that checks the result of the `coro.free` builtin, which is used to determine whether the coroutine frame was heap-allocated, if the returned pointer is null, no destruction is performed. Additionally, this PR changes how co_return is represented: previously, it was lowered directly into a branch to the block containing the final suspend logic, but now a new `cir.coro.body` operation is introduced to represent the user-written coroutine body. Inside this region, `cir.co_return` operations mark exits from the coroutine body and represent structured control flow that transfers execution to the final suspend point. The lowering of this structured control flow into explicit branches is deferred to a future PR in the FlattenCFG pass.
1 parent 0bdcf4e commit 2a74f30

15 files changed

Lines changed: 760 additions & 227 deletions

File tree

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -954,9 +954,9 @@ def CIR_ConditionOp : CIR_Op<"condition", [
954954
//===----------------------------------------------------------------------===//
955955

956956
defvar CIR_YieldableScopes = [
957-
"ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "CleanupScopeOp", "DoWhileOp",
958-
"ForOp", "GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp", "TryOp",
959-
"WhileOp"
957+
"ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "CleanupScopeOp", "CoroBodyOp",
958+
"DoWhileOp", "ForOp", "GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp",
959+
"TryOp", "WhileOp"
960960
];
961961

962962
def CIR_YieldOp : CIR_Op<"yield", [
@@ -4156,6 +4156,70 @@ def CIR_AwaitOp : CIR_Op<"await",[
41564156

41574157
let hasVerifier = 1;
41584158
}
4159+
//===----------------------------------------------------------------------===//
4160+
// CoroBody
4161+
//===----------------------------------------------------------------------===//
4162+
def CIR_CoroBodyOp : CIR_Op<"coro.body", [
4163+
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>,
4164+
RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments,
4165+
RecursiveMemoryEffects
4166+
]> {
4167+
let summary = "Region containing the user-authored coroutine body";
4168+
let description = [{
4169+
The `cir.coro.body` operation models the region where the user-authored
4170+
coroutine code is emitted.
4171+
4172+
This operation serves as a structural boundary separating the coroutine
4173+
setup and teardown logic (e.g. initial suspend, final suspend, and cleanup)
4174+
from the user-provided statements inside the coroutine.
4175+
4176+
The body region contains the code corresponding to the original function
4177+
body, including `co_await` and `co_return` expressions. In particular,
4178+
`cir.co_return` operations inside this region mark coroutine exit points
4179+
and introduce structured control flow that transfers execution to the
4180+
final suspend point of the coroutine.
4181+
}];
4182+
4183+
let regions = (region AnyRegion:$body);
4184+
4185+
let skipDefaultBuilders = 1;
4186+
4187+
let builders = [
4188+
OpBuilder<(ins "BuilderCallbackRef":$bodyBuilder)>
4189+
];
4190+
4191+
let assemblyFormat = [{
4192+
$body attr-dict
4193+
}];
4194+
4195+
let hasLLVMLowering = false;
4196+
let hasVerifier = 1;
4197+
}
4198+
4199+
//===----------------------------------------------------------------------===//
4200+
// CoReturnOp
4201+
//===----------------------------------------------------------------------===//
4202+
4203+
def CIR_CoReturnOp : CIR_Op<"co_return", [
4204+
ReturnLike, Pure, Terminator
4205+
]> {
4206+
let summary = "Coroutine return operation";
4207+
let description = [{
4208+
The `cir.co_return` operation models a coroutine return point inside a
4209+
`cir.coro.body` region.
4210+
This operation is expected to appear only within a `cir.coro.body` region,
4211+
but it may be nested within other operations or regions inside that body.
4212+
}];
4213+
4214+
let assemblyFormat = [{
4215+
attr-dict
4216+
}];
4217+
4218+
let hasVerifier = 1;
4219+
4220+
let hasLLVMLowering = false;
4221+
}
4222+
41594223

41604224
//===----------------------------------------------------------------------===//
41614225
// CopyOp

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID,
12711271
return emitCoroutineFrame();
12721272
}
12731273
case Builtin::BI__builtin_coro_free:
1274+
return RValue::get(emitCoroFreeBuiltin(e).getResult());
12741275
case Builtin::BI__builtin_coro_size: {
12751276
GlobalDecl gd{fd};
12761277
mlir::Type ty = cgm.getTypes().getFunctionType(

clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp

Lines changed: 125 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,16 @@ struct clang::CIRGen::CGCoroData {
3333
// Stores the result of __builtin_coro_begin call.
3434
mlir::Value coroBegin = nullptr;
3535

36-
// Stores the insertion point for final suspend, this happens after the
37-
// promise call (return_xxx promise member) but before a cir.br to the return
38-
// block.
39-
mlir::Operation *finalSuspendInsPoint;
40-
4136
// How many co_return statements are in the coroutine. Used to decide whether
4237
// we need to add co_return; equivalent at the end of the user authored body.
4338
unsigned coreturnCount = 0;
4439

4540
// The promise type's 'unhandled_exception' handler, if it defines one.
4641
Stmt *exceptionHandler = nullptr;
42+
43+
// Stores the last emitted coro.free for the deallocate expressions, we use it
44+
// to wrap dealloc code with if(auto mem = coro.free) dealloc(mem).
45+
cir::CallOp lastCoroFree = nullptr;
4746
};
4847

4948
// Defining these here allows to keep CGCoroData private to this file.
@@ -110,6 +109,63 @@ struct ParamReferenceReplacerRAII {
110109
};
111110
} // namespace
112111

112+
namespace {
113+
// Make sure to call coro.delete on scope exit.
114+
struct CallCoroDelete final : public EHScopeStack::Cleanup {
115+
Stmt *deallocate;
116+
117+
// Emit "if (coro.free(CoroId, CoroBegin)) Deallocate;"
118+
119+
// Note: That deallocation will be emitted twice: once for a normal exit and
120+
// once for exceptional exit. This usage is safe because Deallocate does not
121+
// contain any declarations. The SubStmtBuilder::makeNewAndDeleteExpr()
122+
// builds a single call to a deallocation function which is safe to emit
123+
// multiple times.
124+
void emit(CIRGenFunction &cgf, Flags) override {
125+
// Remember the current point, as we are going to emit deallocation code
126+
// first to get to coro.free instruction that is an argument to a delete
127+
// call.
128+
129+
if (cgf.emitStmt(deallocate, /*useCurrentScope=*/true).failed()) {
130+
cgf.cgm.error(deallocate->getBeginLoc(),
131+
"failed to emit coroutine deallocation expression");
132+
return;
133+
}
134+
135+
CIRGenBuilderTy &builder = cgf.getBuilder();
136+
cir::CallOp coroFree = cgf.curCoro.data->lastCoroFree;
137+
138+
if (!coroFree) {
139+
cgf.cgm.error(deallocate->getBeginLoc(),
140+
"Deallocation expression does not refer to coro.free");
141+
return;
142+
}
143+
144+
builder.setInsertionPointAfter(coroFree);
145+
mlir::Value isPtrNotNull = builder.createPtrIsNotNull(coroFree.getResult());
146+
147+
llvm::SmallVector<mlir::Operation *> opsToMove;
148+
mlir::Block *block = builder.getInsertionBlock();
149+
mlir::Block::iterator it(isPtrNotNull.getDefiningOp());
150+
151+
for (++it; it != block->end(); ++it)
152+
opsToMove.push_back(&*it);
153+
154+
auto ifOp =
155+
cir::IfOp::create(builder, cgf.getLoc(deallocate->getSourceRange()),
156+
isPtrNotNull, /*withElseRegion*/ false,
157+
[&](mlir::OpBuilder &builder, mlir::Location loc) {
158+
cir::YieldOp::create(builder, loc);
159+
});
160+
161+
mlir::Operation *yieldOp = ifOp.getThenRegion().back().getTerminator();
162+
for (auto *op : opsToMove)
163+
op->moveBefore(yieldOp);
164+
}
165+
explicit CallCoroDelete(Stmt *deallocStmt) : deallocate(deallocStmt) {}
166+
};
167+
} // namespace
168+
113169
RValue CIRGenFunction::emitCoroutineFrame() {
114170
if (curCoro.data && curCoro.data->coroBegin) {
115171
return RValue::get(curCoro.data->coroBegin);
@@ -235,6 +291,28 @@ cir::CallOp CIRGenFunction::emitCoroEndBuiltinCall(mlir::Location loc,
235291
loc, fnOp, mlir::ValueRange{nullPtr, builder.getBool(false, loc)});
236292
}
237293

294+
cir::CallOp CIRGenFunction::emitCoroFreeBuiltin(const CallExpr *e) {
295+
mlir::Operation *builtin = cgm.getGlobalValue(cgm.builtinCoroFree);
296+
mlir::Location loc = getLoc(e->getBeginLoc());
297+
cir::FuncOp fnOp;
298+
if (!builtin) {
299+
fnOp = cgm.createCIRBuiltinFunction(
300+
loc, cgm.builtinCoroFree,
301+
cir::FuncType::get({uInt32Ty, voidPtrTy}, voidPtrTy),
302+
/*fd=*/nullptr);
303+
assert(fnOp && "should always succeed");
304+
} else {
305+
fnOp = cast<cir::FuncOp>(builtin);
306+
}
307+
cir::CallOp coroFree =
308+
builder.createCallOp(loc, fnOp,
309+
mlir::ValueRange{curCoro.data->coroId.getResult(),
310+
curCoro.data->coroBegin});
311+
312+
curCoro.data->lastCoroFree = coroFree;
313+
return coroFree;
314+
}
315+
238316
mlir::LogicalResult
239317
CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
240318
mlir::Location openCurlyLoc = getLoc(s.getBeginLoc());
@@ -280,6 +358,8 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
280358
{
281359
assert(!cir::MissingFeatures::generateDebugInfo());
282360
ParamReferenceReplacerRAII paramReplacer(localDeclMap);
361+
RunCleanupsScope resumeScope(*this);
362+
ehStack.pushCleanup<CallCoroDelete>(NormalAndEHCleanup, s.getDeallocate());
283363
// Create mapping between parameters and copy-params for coroutine
284364
// function.
285365
llvm::ArrayRef<const Stmt *> paramMoves = s.getParamMoves();
@@ -326,11 +406,31 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
326406

327407
curCoro.data->currentAwaitKind = cir::AwaitKind::User;
328408

329-
// FIXME(cir): wrap emitBodyAndFallthrough with try/catch bits.
330-
if (s.getExceptionHandler())
331-
assert(!cir::MissingFeatures::coroutineExceptions());
332-
if (emitBodyAndFallthrough(*this, s, s.getBody(), curLexScope).failed())
333-
return mlir::failure();
409+
mlir::OpBuilder::InsertPoint userBody;
410+
auto coroBodyOp =
411+
cir::CoroBodyOp::create(builder, openCurlyLoc, /*scopeBuilder=*/
412+
[&](mlir::OpBuilder &b, mlir::Location loc) {
413+
userBody = b.saveInsertionPoint();
414+
});
415+
{
416+
mlir::OpBuilder::InsertionGuard guard(builder);
417+
builder.restoreInsertionPoint(userBody);
418+
// FIXME(cir): wrap emitBodyAndFallthrough with try/catch bits.
419+
if (s.getExceptionHandler()) {
420+
assert(!cir::MissingFeatures::coroutineExceptions());
421+
cgm.errorNYI("exceptions in coroutines are not yet supported in CIR");
422+
}
423+
if (emitBodyAndFallthrough(*this, s, s.getBody(), curLexScope).failed()) {
424+
return mlir::failure();
425+
}
426+
}
427+
428+
mlir::Block &coroBodyBlock = coroBodyOp.getBody().back();
429+
if (!coroBodyBlock.mightHaveTerminator()) {
430+
mlir::OpBuilder::InsertionGuard guard(builder);
431+
builder.setInsertionPointToEnd(&coroBodyBlock);
432+
cir::YieldOp::create(builder, openCurlyLoc);
433+
}
334434

335435
// Note that LLVM checks CanFallthrough by looking into the availability
336436
// of the insert block which is kinda brittle and unintuitive, seems to be
@@ -346,13 +446,26 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
346446
curCoro.data->currentAwaitKind = cir::AwaitKind::Final;
347447
{
348448
mlir::OpBuilder::InsertionGuard guard(builder);
349-
builder.setInsertionPoint(curCoro.data->finalSuspendInsPoint);
350449
if (emitStmt(s.getFinalSuspendStmt(), /*useCurrentScope=*/true)
351450
.failed())
352451
return mlir::failure();
353452
}
354453
}
355454
}
455+
456+
emitCoroEndBuiltinCall(
457+
openCurlyLoc, builder.getNullPtr(builder.getVoidPtrTy(), openCurlyLoc));
458+
if (auto *ret = cast_or_null<ReturnStmt>(s.getReturnStmt())) {
459+
// Since we already emitted the return value above, so we shouldn't
460+
// emit it again here.
461+
Expr *previousRetValue = ret->getRetValue();
462+
ret->setRetValue(nullptr);
463+
if (emitStmt(ret, /*useCurrentScope=*/true).failed())
464+
return mlir::failure();
465+
// Set the return value back. The code generator, as the AST **Consumer**,
466+
// shouldn't change the AST.
467+
ret->setRetValue(previousRetValue);
468+
}
356469
return mlir::success();
357470
}
358471

@@ -538,13 +651,7 @@ mlir::LogicalResult CIRGenFunction::emitCoreturnStmt(CoreturnStmt const &s) {
538651
// it. The actual return instruction is only inserted during current
539652
// scope cleanup handling.
540653
mlir::Location loc = getLoc(s.getSourceRange());
541-
mlir::Block *retBlock = curLexScope->getOrCreateRetBlock(*this, loc);
542-
curCoro.data->finalSuspendInsPoint =
543-
cir::BrOp::create(builder, loc, retBlock);
544-
545-
// Insert the new block to continue codegen after branch to ret block,
546-
// this will likely be an empty block.
547-
builder.createBlock(builder.getBlock()->getParent());
654+
cir::CoReturnOp::create(builder, loc);
548655

549656
return mlir::success();
550657
}

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,6 @@ cir::ReturnOp CIRGenFunction::LexicalScope::emitReturn(mlir::Location loc) {
318318
auto fn = dyn_cast<cir::FuncOp>(cgf.curFn);
319319
assert(fn && "emitReturn from non-function");
320320

321-
// If we are on a coroutine, add the coro_end builtin call.
322-
if (fn.getCoroutine())
323-
cgf.emitCoroEndBuiltinCall(loc,
324-
builder.getNullPtr(builder.getVoidPtrTy(), loc));
325321
if (!fn.getFunctionType().hasVoidReturn()) {
326322
// Load the value from `__retval` and return it via the `cir.return` op.
327323
auto value = cir::LoadOp::create(

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,8 @@ class CIRGenFunction : public CIRGenTypeCache {
16771677
cir::CallOp emitCoroAllocBuiltinCall(mlir::Location loc);
16781678
cir::CallOp emitCoroBeginBuiltinCall(mlir::Location loc,
16791679
mlir::Value coroframeAddr);
1680+
1681+
cir::CallOp emitCoroFreeBuiltin(const CallExpr *e);
16801682
RValue emitCoroutineFrame();
16811683

16821684
void emitDestroy(Address addr, QualType type, Destroyer *destroyer);

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ class CIRGenModule : public CIRGenTypeCache {
705705
static constexpr const char *builtinCoroAlloc = "__builtin_coro_alloc";
706706
static constexpr const char *builtinCoroBegin = "__builtin_coro_begin";
707707
static constexpr const char *builtinCoroEnd = "__builtin_coro_end";
708+
static constexpr const char *builtinCoroFree = "__builtin_coro_free";
708709

709710
/// Given a builtin id for a function like "__builtin_fabsf", return a
710711
/// Function* for "fabsf".

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2538,15 +2538,24 @@ mlir::LogicalResult cir::FuncOp::verify() {
25382538

25392539
if (!isDeclaration() && getCoroutine()) {
25402540
bool foundAwait = false;
2541+
int coroBodyCount = 0;
25412542
this->walk([&](Operation *op) {
25422543
if (auto await = dyn_cast<AwaitOp>(op)) {
25432544
foundAwait = true;
2544-
return;
2545+
} else if (isa<CoroBodyOp>(op)) {
2546+
coroBodyCount++;
2547+
if (coroBodyCount > 1) {
2548+
return mlir::WalkResult::interrupt();
2549+
}
25452550
}
2551+
return mlir::WalkResult::advance();
25462552
});
25472553
if (!foundAwait)
25482554
return emitOpError()
25492555
<< "coroutine body must use at least one cir.await op";
2556+
if (coroBodyCount != 1)
2557+
return emitOpError()
2558+
<< "coroutine function must have exactly one cir.body op";
25502559
}
25512560

25522561
llvm::SmallSet<llvm::StringRef, 16> labels;
@@ -2959,6 +2968,48 @@ LogicalResult cir::AwaitOp::verify() {
29592968
return success();
29602969
}
29612970

2971+
LogicalResult cir::CoReturnOp::verify() {
2972+
if (!getOperation()->getParentOfType<CoroBodyOp>())
2973+
return emitOpError("must be inside a cir.coro.body");
2974+
return success();
2975+
}
2976+
2977+
//===----------------------------------------------------------------------===//
2978+
// CoroBody
2979+
//===----------------------------------------------------------------------===//
2980+
2981+
void cir::CoroBodyOp::getSuccessorRegions(
2982+
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2983+
if (!point.isParent()) {
2984+
regions.push_back(RegionSuccessor::parent());
2985+
return;
2986+
}
2987+
2988+
regions.push_back(RegionSuccessor(&getBody()));
2989+
}
2990+
2991+
mlir::ValueRange
2992+
cir::CoroBodyOp::getSuccessorInputs(RegionSuccessor successor) {
2993+
return ValueRange();
2994+
}
2995+
2996+
LogicalResult cir::CoroBodyOp::verify() {
2997+
if (!getOperation()->getParentOfType<FuncOp>().getCoroutine())
2998+
return emitOpError("enclosing function must be a coroutine");
2999+
return success();
3000+
}
3001+
3002+
void cir::CoroBodyOp::build(OpBuilder &builder, OperationState &result,
3003+
BuilderCallbackRef bodyBuilder) {
3004+
assert(bodyBuilder &&
3005+
"the builder callback for 'CoroBodyOp' must be present");
3006+
OpBuilder::InsertionGuard guard(builder);
3007+
3008+
Region *bodyRegion = result.addRegion();
3009+
builder.createBlock(bodyRegion);
3010+
bodyBuilder(builder, result.location);
3011+
}
3012+
29623013
//===----------------------------------------------------------------------===//
29633014
// CopyOp Definitions
29643015
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)