diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 5b4ec9c5cf16..e012ff154f4d 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -802,6 +802,33 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + template + mlir::Operation * + replaceImmediateOp(cir::UnaryOp op, mlir::Type type, mlir::Value input, + int64_t n, + mlir::ConversionPatternRewriter &rewriter) const { + if (type.isFloat()) { + auto imm = mlir::arith::ConstantOp::create( + rewriter, op.getLoc(), + mlir::FloatAttr::get(type, static_cast(n))); + if constexpr (rev) + return rewriter.replaceOpWithNewOp(op, type, imm, input); + else + return rewriter.replaceOpWithNewOp(op, type, input, imm); + } + if (type.isInteger()) { + auto imm = mlir::arith::ConstantOp::create( + rewriter, op.getLoc(), mlir::IntegerAttr::get(type, n)); + if constexpr (rev) + return rewriter.replaceOpWithNewOp(op, type, imm, input); + else + return rewriter.replaceOpWithNewOp(op, type, input, imm); + } + op->emitError("Unsupported type: ") << type << " at " << op->getLoc(); + llvm_unreachable("CIRUnaryOpLowering met unsupported type"); + return nullptr; + } + mlir::LogicalResult matchAndRewrite(cir::UnaryOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { @@ -810,15 +837,13 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern { switch (op.getKind()) { case cir::UnaryOpKind::Inc: { - auto One = mlir::arith::ConstantOp::create( - rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1)); - rewriter.replaceOpWithNewOp(op, type, input, One); + replaceImmediateOp( + op, type, input, 1, rewriter); break; } case cir::UnaryOpKind::Dec: { - auto One = mlir::arith::ConstantOp::create( - rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 1)); - rewriter.replaceOpWithNewOp(op, type, input, One); + replaceImmediateOp( + op, type, input, -1, rewriter); break; } case cir::UnaryOpKind::Plus: { @@ -826,20 +851,17 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern { break; } case cir::UnaryOpKind::Minus: { - auto Zero = mlir::arith::ConstantOp::create( - rewriter, op.getLoc(), mlir::IntegerAttr::get(type, 0)); - rewriter.replaceOpWithNewOp(op, type, Zero, input); + replaceImmediateOp( + op, type, input, 0, rewriter); break; } case cir::UnaryOpKind::Not: { - auto MinusOne = mlir::arith::ConstantOp::create( + auto o = mlir::arith::ConstantOp::create( rewriter, op.getLoc(), mlir::IntegerAttr::get(type, -1)); - rewriter.replaceOpWithNewOp(op, type, MinusOne, - input); + rewriter.replaceOpWithNewOp(op, type, o, input); break; } } - return mlir::LogicalResult::success(); } }; diff --git a/clang/test/CIR/Lowering/ThroughMLIR/if.c b/clang/test/CIR/Lowering/ThroughMLIR/if.c index a7d95d3a104a..142846d6a631 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/if.c +++ b/clang/test/CIR/Lowering/ThroughMLIR/if.c @@ -29,8 +29,8 @@ void foo() { //CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref //CHECK: } else { //CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref -//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32 -//CHECK: %[[SEVEN:.+]] = arith.subi %[[SIX]], %[[C1_I32]] : i32 +//CHECK: %[[C1_I32:.+]] = arith.constant -1 : i32 +//CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32 //CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref //CHECK: } //CHECK: } @@ -106,8 +106,8 @@ void foo3() { //CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref //CHECK: } else { //CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref -//CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32 -//CHECK: %[[THIRTEEN:.+]] = arith.subi %[[TWELVE]], %[[C1_I32_5]] : i32 +//CHECK: %[[C1_I32_5:.+]] = arith.constant -1 : i32 +//CHECK: %[[THIRTEEN:.+]] = arith.addi %[[TWELVE]], %[[C1_I32_5]] : i32 //CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref //CHECK: } //CHECK: } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir b/clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir index 1db339fe34fc..35bcc3a42962 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/unary-inc-dec.cir @@ -1,5 +1,5 @@ -// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR -// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM +// RUN: cir-opt %s --cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR +// RUN: cir-opt %s --cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM !s32i = !cir.int module { @@ -17,14 +17,32 @@ module { %5 = cir.load %1 : !cir.ptr, !s32i %6 = cir.unary(dec, %5) : !s32i, !s32i cir.store %6, %1 : !s32i, !cir.ptr + + // test float + %7 = cir.alloca !s32i, !cir.ptr, ["b", init] {alignment = 4 : i64} cir.return } -} // MLIR: = arith.constant 1 // MLIR: = arith.addi -// MLIR: = arith.constant 1 -// MLIR: = arith.subi +// MLIR: = arith.constant -1 +// MLIR: = arith.addi // LLVM: = add i32 %[[#]], 1 -// LLVM: = sub i32 %[[#]], 1 +// LLVM: = add i32 %[[#]], -1 + + + cir.func @floatingPoints(%arg0: !cir.double) { + %0 = cir.alloca !cir.double, !cir.ptr, ["X", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.double, !cir.ptr + %1 = cir.load %0 : !cir.ptr, !cir.double + %2 = cir.unary(inc, %1) : !cir.double, !cir.double + %3 = cir.load %0 : !cir.ptr, !cir.double + %4 = cir.unary(dec, %3) : !cir.double, !cir.double + cir.return + } +// MLIR: = arith.constant 1.0 +// MLIR: = arith.addf +// MLIR: = arith.constant -1.0 +// MLIR: = arith.addf +} diff --git a/clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir b/clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir index ecb7e7ef6734..a6101358e5a4 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/unary-plus-minus.cir @@ -19,6 +19,16 @@ module { cir.store %6, %1 : !s32i, !cir.ptr cir.return } + + cir.func @floatingPoints(%arg0: !cir.double) { + %0 = cir.alloca !cir.double, !cir.ptr, ["X", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.double, !cir.ptr + %1 = cir.load %0 : !cir.ptr, !cir.double + %2 = cir.unary(plus, %1) : !cir.double, !cir.double + %3 = cir.load %0 : !cir.ptr, !cir.double + %4 = cir.unary(minus, %3) : !cir.double, !cir.double + cir.return + } } // MLIR: %[[#INPUT_PLUS:]] = memref.load