aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp29
1 files changed, 28 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 0196a21f4a69..2a6e7f281860 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -663,6 +663,17 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector
+/// of i1.
+class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
+public:
+ using OpConversionPattern<XOrOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1250,6 +1261,22 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
return success();
}
+LogicalResult
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ assert(operands.size() == 2);
+
+ if (!isBoolScalarOrVector(operands.front().getType()))
+ return failure();
+
+ auto dstType = getTypeConverter()->convertType(xorOp.getType());
+ if (!dstType)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
+ operands);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -1293,7 +1320,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
- SignedRemIOpPattern, XOrOpPattern,
+ SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
// Comparison patterns
BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,