Skip to content

Commit

Permalink
[Sema][CodeGen] Support __builtin_<op>_overflow with __intcap
Browse files Browse the repository at this point in the history
Morello LLVM has downstream support for this, but it's both incomplete
(see https://git.morello-project.org/morello/llvm-project/-/issues/80)
and incorrect with regards to provenance (in that it takes a naive
type-based approach rather than considering the cheri_no_provenance
attribute, meaning it differs from the binary operators in provenance
semantics). This is a from-scratch implementation that aims to not have
the same shortcomings.
  • Loading branch information
jrtc27 committed Aug 2, 2024
1 parent 4f37a7e commit def11ee
Show file tree
Hide file tree
Showing 4 changed files with 5,663 additions and 13 deletions.
130 changes: 117 additions & 13 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/IntrinsicInst.h"
Expand Down Expand Up @@ -698,9 +699,7 @@ static WidthAndSignedness
getIntegerWidthAndSignedness(const clang::ASTContext &context,
const clang::QualType Type) {
assert(Type->isIntegerType() && "Given type is not an integer.");
unsigned Width = Type->isBooleanType() ? 1
: Type->isBitIntType() ? context.getIntWidth(Type)
: context.getTypeInfo(Type).Width;
unsigned Width = context.getIntWidth(Type);
bool Signed = Type->isSignedIntegerType();
return {Width, Signed};
}
Expand Down Expand Up @@ -1925,14 +1924,40 @@ static RValue EmitCheckedUnsignedMultiplySignedResult(
CodeGenFunction &CGF, const clang::Expr *Op1, WidthAndSignedness Op1Info,
const clang::Expr *Op2, WidthAndSignedness Op2Info,
const clang::Expr *ResultArg, QualType ResultQTy,
WidthAndSignedness ResultInfo) {
WidthAndSignedness ResultInfo, SourceLocation Loc) {
assert(isSpecialUnsignedMultiplySignedResult(
Builtin::BI__builtin_mul_overflow, Op1Info, Op2Info, ResultInfo) &&
"Cannot specialize this multiply");

clang::QualType Op1QTy = Op1->getType();
clang::QualType Op2QTy = Op2->getType();
bool Op1IsCap = Op1QTy->isCHERICapabilityType(CGF.getContext());
bool Op2IsCap = Op2QTy->isCHERICapabilityType(CGF.getContext());
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGF.getContext());

llvm::Value *V1 = CGF.EmitScalarExpr(Op1);
llvm::Value *V2 = CGF.EmitScalarExpr(Op2);

llvm::Value *ProvenanceCap = nullptr;
if (ResultIsCap) {
bool Op1NoProvenance =
!Op1IsCap || Op1QTy->hasAttr(attr::CHERINoProvenance);
bool Op2NoProvenance =
!Op2IsCap || Op2QTy->hasAttr(attr::CHERINoProvenance);
if (Op1NoProvenance && Op2NoProvenance)
ProvenanceCap = llvm::ConstantPointerNull::get(CGF.Int8CheriCapTy);
else if (Op1NoProvenance)
ProvenanceCap = V2;
else
ProvenanceCap = V1;
}

if (Op1IsCap)
V1 = CGF.getCapabilityIntegerValue(V1);

if (Op2IsCap)
V2 = CGF.getCapabilityIntegerValue(V2);

llvm::Value *HasOverflow;
llvm::Value *Result = EmitOverflowIntrinsic(
CGF, llvm::Intrinsic::umul_with_overflow, V1, V2, HasOverflow);
Expand All @@ -1946,6 +1971,9 @@ static RValue EmitCheckedUnsignedMultiplySignedResult(
llvm::Value *IntMaxOverflow = CGF.Builder.CreateICmpUGT(Result, IntMaxValue);
HasOverflow = CGF.Builder.CreateOr(HasOverflow, IntMaxOverflow);

if (ResultIsCap)
Result = CGF.setCapabilityIntegerValue(ProvenanceCap, Result, Loc);

bool isVolatile =
ResultArg->getType()->getPointeeType().isVolatileQualified();
Address ResultPtr = CGF.EmitPointerWithAlignment(ResultArg);
Expand All @@ -1971,18 +1999,47 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
WidthAndSignedness Op1Info, const clang::Expr *Op2,
WidthAndSignedness Op2Info,
const clang::Expr *ResultArg, QualType ResultQTy,
WidthAndSignedness ResultInfo) {
WidthAndSignedness ResultInfo,
SourceLocation Loc) {
assert(isSpecialMixedSignMultiply(Builtin::BI__builtin_mul_overflow, Op1Info,
Op2Info, ResultInfo) &&
"Not a mixed-sign multipliction we can specialize");

QualType Op1QTy = Op1->getType();
QualType Op2QTy = Op2->getType();
bool Op1IsCap = Op1QTy->isCHERICapabilityType(CGF.getContext());
bool Op2IsCap = Op2QTy->isCHERICapabilityType(CGF.getContext());
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGF.getContext());

// Emit the signed and unsigned operands.
const clang::Expr *SignedOp = Op1Info.Signed ? Op1 : Op2;
const clang::Expr *UnsignedOp = Op1Info.Signed ? Op2 : Op1;
llvm::Value *Signed = CGF.EmitScalarExpr(SignedOp);
llvm::Value *Unsigned = CGF.EmitScalarExpr(UnsignedOp);
unsigned SignedOpWidth = Op1Info.Signed ? Op1Info.Width : Op2Info.Width;
unsigned UnsignedOpWidth = Op1Info.Signed ? Op2Info.Width : Op1Info.Width;
bool SignedIsCap = Op1Info.Signed ? Op1IsCap : Op2IsCap;
bool UnsignedIsCap = Op1Info.Signed ? Op2IsCap : Op1IsCap;

llvm::Value *ProvenanceCap = nullptr;
if (ResultIsCap) {
bool Op1NoProvenance =
!Op1IsCap || Op1QTy->hasAttr(attr::CHERINoProvenance);
bool Op2NoProvenance =
!Op2IsCap || Op2QTy->hasAttr(attr::CHERINoProvenance);
if (Op1NoProvenance && Op2NoProvenance)
ProvenanceCap = llvm::ConstantPointerNull::get(CGF.Int8CheriCapTy);
else if (Op1NoProvenance)
ProvenanceCap = Op1Info.Signed ? Unsigned : Signed;
else
ProvenanceCap = Op1Info.Signed ? Signed : Unsigned;
}

if (SignedIsCap)
Signed = CGF.getCapabilityIntegerValue(Signed);

if (UnsignedIsCap)
Unsigned = CGF.getCapabilityIntegerValue(Unsigned);

// One of the operands may be smaller than the other. If so, [s|z]ext it.
if (SignedOpWidth < UnsignedOpWidth)
Expand All @@ -1993,7 +2050,9 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
llvm::Type *OpTy = Signed->getType();
llvm::Value *Zero = llvm::Constant::getNullValue(OpTy);
Address ResultPtr = CGF.EmitPointerWithAlignment(ResultArg);
llvm::Type *ResTy = ResultPtr.getElementType();
llvm::Type *ResTy = ResultIsCap ? llvm::IntegerType::get(CGF.getLLVMContext(),
ResultInfo.Width)
: ResultPtr.getElementType();
unsigned OpWidth = std::max(Op1Info.Width, Op2Info.Width);

// Take the absolute value of the signed operand.
Expand Down Expand Up @@ -2032,8 +2091,7 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
IsNegative, CGF.Builder.CreateIsNotNull(UnsignedResult));
Overflow = CGF.Builder.CreateOr(UnsignedOverflow, Underflow);
if (ResultInfo.Width < OpWidth) {
auto IntMax =
llvm::APInt::getMaxValue(ResultInfo.Width).zext(OpWidth);
auto IntMax = llvm::APInt::getMaxValue(ResultInfo.Width).zext(OpWidth);
llvm::Value *TruncOverflow = CGF.Builder.CreateICmpUGT(
UnsignedResult, llvm::ConstantInt::get(OpTy, IntMax));
Overflow = CGF.Builder.CreateOr(Overflow, TruncOverflow);
Expand All @@ -2047,6 +2105,9 @@ EmitCheckedMixedSignMultiply(CodeGenFunction &CGF, const clang::Expr *Op1,
}
assert(Overflow && Result && "Missing overflow or result");

if (ResultIsCap)
Result = CGF.setCapabilityIntegerValue(ProvenanceCap, Result, Loc);

bool isVolatile =
ResultArg->getType()->getPointeeType().isVolatileQualified();
CGF.Builder.CreateStore(CGF.EmitToMemory(Result, ResultQTy), ResultPtr,
Expand Down Expand Up @@ -4493,13 +4554,18 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
const clang::Expr *RightArg = E->getArg(1);
const clang::Expr *ResultArg = E->getArg(2);

clang::QualType LeftQTy = LeftArg->getType();
clang::QualType RightQTy = RightArg->getType();
clang::QualType ResultQTy =
ResultArg->getType()->castAs<PointerType>()->getPointeeType();

bool LeftIsCap = LeftQTy->isCHERICapabilityType(CGM.getContext());
bool RightIsCap = RightQTy->isCHERICapabilityType(CGM.getContext());
bool ResultIsCap = ResultQTy->isCHERICapabilityType(CGM.getContext());
WidthAndSignedness LeftInfo =
getIntegerWidthAndSignedness(CGM.getContext(), LeftArg->getType());
getIntegerWidthAndSignedness(CGM.getContext(), LeftQTy);
WidthAndSignedness RightInfo =
getIntegerWidthAndSignedness(CGM.getContext(), RightArg->getType());
getIntegerWidthAndSignedness(CGM.getContext(), RightQTy);
WidthAndSignedness ResultInfo =
getIntegerWidthAndSignedness(CGM.getContext(), ResultQTy);

Expand All @@ -4508,37 +4574,44 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
if (isSpecialMixedSignMultiply(BuiltinID, LeftInfo, RightInfo, ResultInfo))
return EmitCheckedMixedSignMultiply(*this, LeftArg, LeftInfo, RightArg,
RightInfo, ResultArg, ResultQTy,
ResultInfo);
ResultInfo, E->getExprLoc());

if (isSpecialUnsignedMultiplySignedResult(BuiltinID, LeftInfo, RightInfo,
ResultInfo))
return EmitCheckedUnsignedMultiplySignedResult(
*this, LeftArg, LeftInfo, RightArg, RightInfo, ResultArg, ResultQTy,
ResultInfo);
ResultInfo, E->getExprLoc());

WidthAndSignedness EncompassingInfo =
EncompassingIntegerType({LeftInfo, RightInfo, ResultInfo});

llvm::Type *EncompassingLLVMTy =
llvm::IntegerType::get(CGM.getLLVMContext(), EncompassingInfo.Width);

llvm::Type *ResultLLVMTy = CGM.getTypes().ConvertType(ResultQTy);
llvm::Type *ResultLLVMTy =
ResultIsCap
? llvm::IntegerType::get(CGM.getLLVMContext(), ResultInfo.Width)
: CGM.getTypes().ConvertType(ResultQTy);

llvm::Intrinsic::ID IntrinsicId;
bool Commutative;
switch (BuiltinID) {
default:
llvm_unreachable("Unknown overflow builtin id.");
case Builtin::BI__builtin_add_overflow:
Commutative = true;
IntrinsicId = EncompassingInfo.Signed
? llvm::Intrinsic::sadd_with_overflow
: llvm::Intrinsic::uadd_with_overflow;
break;
case Builtin::BI__builtin_sub_overflow:
Commutative = false;
IntrinsicId = EncompassingInfo.Signed
? llvm::Intrinsic::ssub_with_overflow
: llvm::Intrinsic::usub_with_overflow;
break;
case Builtin::BI__builtin_mul_overflow:
Commutative = true;
IntrinsicId = EncompassingInfo.Signed
? llvm::Intrinsic::smul_with_overflow
: llvm::Intrinsic::umul_with_overflow;
Expand All @@ -4549,6 +4622,33 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm::Value *Right = EmitScalarExpr(RightArg);
Address ResultPtr = EmitPointerWithAlignment(ResultArg);

llvm::Value *ProvenanceCap = nullptr;
if (ResultIsCap) {
if (!Commutative) {
if (LeftIsCap)
ProvenanceCap = Left;
else
ProvenanceCap = llvm::ConstantPointerNull::get(Int8CheriCapTy);
} else {
bool LeftNoProvenance =
!LeftIsCap || LeftQTy->hasAttr(attr::CHERINoProvenance);
bool RightNoProvenance =
!RightIsCap || RightQTy->hasAttr(attr::CHERINoProvenance);
if (LeftNoProvenance && RightNoProvenance)
ProvenanceCap = llvm::ConstantPointerNull::get(Int8CheriCapTy);
else if (LeftNoProvenance)
ProvenanceCap = Right;
else
ProvenanceCap = Left;
}
}

if (LeftIsCap)
Left = getCapabilityIntegerValue(Left);

if (RightIsCap)
Right = getCapabilityIntegerValue(Right);

// Extend each operand to the encompassing type.
Left = Builder.CreateIntCast(Left, EncompassingLLVMTy, LeftInfo.Signed);
Right = Builder.CreateIntCast(Right, EncompassingLLVMTy, RightInfo.Signed);
Expand All @@ -4573,6 +4673,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
Result = ResultTrunc;
}

if (ResultIsCap)
Result =
setCapabilityIntegerValue(ProvenanceCap, Result, E->getExprLoc());

// Finally, store the result using the pointer.
bool isVolatile =
ResultArg->getType()->getPointeeType().isVolatileQualified();
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
#include "clang/AST/TypeLoc.h"
#include "clang/AST/UnresolvedSet.h"
#include "clang/Basic/AddressSpaces.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/Diagnostic.h"
#include "clang/Basic/DiagnosticFrontend.h"
#include "clang/Basic/IdentifierTable.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/LangOptions.h"
Expand Down Expand Up @@ -439,6 +441,18 @@ static bool SemaBuiltinOverflow(Sema &S, CallExpr *TheCall,
}
}

// ScalarExprEmitter::EmitSub's diagnostics aren't included here since
// they're generally unhelpful, grouped under pedantic warnings, and would be
// confusing without also taking into the type of the result.
if (BuiltinID != Builtin::BI__builtin_sub_overflow) {
assert((BuiltinID == Builtin::BI__builtin_add_overflow ||
BuiltinID == Builtin::BI__builtin_mul_overflow) &&
"Unexpected overflow builtin");

S.DiagnoseAmbiguousProvenance(TheCall->getArg(0), TheCall->getArg(1),
TheCall->getExprLoc(), false);
}

return false;
}

Expand Down
Loading

0 comments on commit def11ee

Please sign in to comment.