diff options
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp | 29 | ||||
-rw-r--r-- | mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir | 4 |
2 files changed, 32 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 0196a21f4a69..2a6e7f281860 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -663,6 +663,17 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector +/// of i1. +class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> { +public: + using OpConversionPattern<XOrOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, return success(); } +LogicalResult +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const { + assert(operands.size() == 2); + + if (!isBoolScalarOrVector(operands.front().getType())) + return failure(); + + auto dstType = getTypeConverter()->convertType(xorOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType, + operands); + return success(); +} + //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>, UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>, UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>, - SignedRemIOpPattern, XOrOpPattern, + SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 0148a0731dc9..fe769482c787 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -224,6 +224,8 @@ func @logical_scalar(%arg0 : i1, %arg1 : i1) { %0 = and %arg0, %arg1 : i1 // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : i1 + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : i1 return } @@ -233,6 +235,8 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { %0 = and %arg0, %arg1 : vector<4xi1> // CHECK: spv.LogicalOr %1 = or %arg0, %arg1 : vector<4xi1> + // CHECK: spv.LogicalNotEqual + %2 = xor %arg0, %arg1 : vector<4xi1> return } |