Skip to content

Commit 537f124

Browse files
authored
[mlir][LLVM] Add fastmath flags support to fpext/fptrunc ops. (#192185)
Add fastmath attributes to llvm fpext/fptrunc ops, FastmathFlagsInterface op interface support.
1 parent 57d2a2c commit 537f124

3 files changed

Lines changed: 38 additions & 2 deletions

File tree

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,26 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
588588
}];
589589
}
590590

591+
class LLVM_CastOpWithFastMathFlag<string mnemonic, string instName, Type type,
592+
Type resultType, list<Trait> traits = []> :
593+
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
594+
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType);"> {
595+
let arguments = (
596+
ins type:$arg,
597+
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
598+
let results = (outs resultType:$res);
599+
let builders = [LLVM_OneResultOpBuilder];
600+
let assemblyFormat = "$arg (`fastmath` `` $fastmathFlags^)? "
601+
"attr-dict `:` type($arg) `to` type($res)";
602+
string llvmInstName = instName;
603+
string mlirBuilder = [{
604+
auto op = $_qualCppClassName::create($_builder,
605+
$_location, $_resultType, $arg);
606+
moduleImport.setFastmathFlagsAttr(inst, op);
607+
$res = op;
608+
}];
609+
}
610+
591611
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
592612
Type resultType, list<Trait> traits = []> :
593613
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
@@ -699,10 +719,10 @@ def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",
699719
def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "FPToUI",
700720
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
701721
LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
702-
def LLVM_FPExtOp : LLVM_CastOp<"fpext", "FPExt",
722+
def LLVM_FPExtOp : LLVM_CastOpWithFastMathFlag<"fpext", "FPExt",
703723
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
704724
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
705-
def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
725+
def LLVM_FPTruncOp : LLVM_CastOpWithFastMathFlag<"fptrunc", "FPTrunc",
706726
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
707727
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
708728

mlir/test/Target/LLVMIR/Import/fastmath.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ define void @fastmath_inst(float %arg1, float %arg2, i1 %arg3) {
1919

2020
; // -----
2121

22+
; CHECK-LABEL: @fastmath_cast
23+
define void @fastmath_cast(float %arg1) {
24+
; CHECK: llvm.fpext %{{.*}} fastmath<nnan> : f32 to f64
25+
%1 = fpext nnan float %arg1 to double
26+
; CHECK: llvm.fptrunc %{{.*}} fastmath<fast> : f32 to f16
27+
%2 = fptrunc fast float %arg1 to half
28+
ret void
29+
}
30+
31+
; // -----
32+
2233
; CHECK-LABEL: @fastmath_fcmp
2334
define void @fastmath_fcmp(float %arg1, float %arg2) {
2435
; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,11 @@ llvm.func @fastmathFlags(%arg0: f32, %arg1 : vector<2xf32>) {
22082208
%25 = llvm.mlir.constant(true) : i1
22092209
// CHECK: select contract i1
22102210
%26 = llvm.select %25, %arg0, %20 {fastmathFlags = #llvm.fastmath<contract>} : i1, f32
2211+
2212+
// CHECK: {{.*}} = fpext nnan float {{.*}} to double
2213+
// CHECK: {{.*}} = fptrunc fast float {{.*}} to half
2214+
%27 = llvm.fpext %arg0 fastmath<nnan> : f32 to f64
2215+
%28 = llvm.fptrunc %arg0 fastmath<fast> : f32 to f16
22112216
llvm.return
22122217
}
22132218

0 commit comments

Comments
 (0)