Skip to content

Commit 4bf160e

Browse files
authored
[clang][Interp] Implement Complex-complex multiplication (#94891)
Share the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that.
1 parent 954cb5f commit 4bf160e

File tree

6 files changed

+201
-59
lines changed

6 files changed

+201
-59
lines changed

clang/lib/AST/ExprConstShared.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#ifndef LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
1515
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
1616

17+
namespace llvm {
18+
class APFloat;
19+
}
1720
namespace clang {
1821
class QualType;
1922
class LangOptions;
@@ -56,4 +59,8 @@ enum class GCCTypeClass {
5659
GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
5760
const LangOptions &LangOpts);
5861

62+
void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
63+
llvm::APFloat D, llvm::APFloat &ResR,
64+
llvm::APFloat &ResI);
65+
5966
#endif

clang/lib/AST/ExprConstant.cpp

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15126,6 +15126,62 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1512615126
llvm_unreachable("unknown cast resulting in complex value");
1512715127
}
1512815128

15129+
void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
15130+
APFloat &ResR, APFloat &ResI) {
15131+
// This is an implementation of complex multiplication according to the
15132+
// constraints laid out in C11 Annex G. The implementation uses the
15133+
// following naming scheme:
15134+
// (a + ib) * (c + id)
15135+
15136+
APFloat AC = A * C;
15137+
APFloat BD = B * D;
15138+
APFloat AD = A * D;
15139+
APFloat BC = B * C;
15140+
ResR = AC - BD;
15141+
ResI = AD + BC;
15142+
if (ResR.isNaN() && ResI.isNaN()) {
15143+
bool Recalc = false;
15144+
if (A.isInfinity() || B.isInfinity()) {
15145+
A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
15146+
A);
15147+
B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
15148+
B);
15149+
if (C.isNaN())
15150+
C = APFloat::copySign(APFloat(C.getSemantics()), C);
15151+
if (D.isNaN())
15152+
D = APFloat::copySign(APFloat(D.getSemantics()), D);
15153+
Recalc = true;
15154+
}
15155+
if (C.isInfinity() || D.isInfinity()) {
15156+
C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
15157+
C);
15158+
D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
15159+
D);
15160+
if (A.isNaN())
15161+
A = APFloat::copySign(APFloat(A.getSemantics()), A);
15162+
if (B.isNaN())
15163+
B = APFloat::copySign(APFloat(B.getSemantics()), B);
15164+
Recalc = true;
15165+
}
15166+
if (!Recalc && (AC.isInfinity() || BD.isInfinity() || AD.isInfinity() ||
15167+
BC.isInfinity())) {
15168+
if (A.isNaN())
15169+
A = APFloat::copySign(APFloat(A.getSemantics()), A);
15170+
if (B.isNaN())
15171+
B = APFloat::copySign(APFloat(B.getSemantics()), B);
15172+
if (C.isNaN())
15173+
C = APFloat::copySign(APFloat(C.getSemantics()), C);
15174+
if (D.isNaN())
15175+
D = APFloat::copySign(APFloat(D.getSemantics()), D);
15176+
Recalc = true;
15177+
}
15178+
if (Recalc) {
15179+
ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
15180+
ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
15181+
}
15182+
}
15183+
}
15184+
1512915185
bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
1513015186
if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
1513115187
return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
@@ -15225,55 +15281,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
1522515281
!handleFloatFloatBinOp(Info, E, ResI, BO_Mul, B))
1522615282
return false;
1522715283
} else {
15228-
// In the fully general case, we need to handle NaNs and infinities
15229-
// robustly.
15230-
APFloat AC = A * C;
15231-
APFloat BD = B * D;
15232-
APFloat AD = A * D;
15233-
APFloat BC = B * C;
15234-
ResR = AC - BD;
15235-
ResI = AD + BC;
15236-
if (ResR.isNaN() && ResI.isNaN()) {
15237-
bool Recalc = false;
15238-
if (A.isInfinity() || B.isInfinity()) {
15239-
A = APFloat::copySign(
15240-
APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
15241-
B = APFloat::copySign(
15242-
APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
15243-
if (C.isNaN())
15244-
C = APFloat::copySign(APFloat(C.getSemantics()), C);
15245-
if (D.isNaN())
15246-
D = APFloat::copySign(APFloat(D.getSemantics()), D);
15247-
Recalc = true;
15248-
}
15249-
if (C.isInfinity() || D.isInfinity()) {
15250-
C = APFloat::copySign(
15251-
APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
15252-
D = APFloat::copySign(
15253-
APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
15254-
if (A.isNaN())
15255-
A = APFloat::copySign(APFloat(A.getSemantics()), A);
15256-
if (B.isNaN())
15257-
B = APFloat::copySign(APFloat(B.getSemantics()), B);
15258-
Recalc = true;
15259-
}
15260-
if (!Recalc && (AC.isInfinity() || BD.isInfinity() ||
15261-
AD.isInfinity() || BC.isInfinity())) {
15262-
if (A.isNaN())
15263-
A = APFloat::copySign(APFloat(A.getSemantics()), A);
15264-
if (B.isNaN())
15265-
B = APFloat::copySign(APFloat(B.getSemantics()), B);
15266-
if (C.isNaN())
15267-
C = APFloat::copySign(APFloat(C.getSemantics()), C);
15268-
if (D.isNaN())
15269-
D = APFloat::copySign(APFloat(D.getSemantics()), D);
15270-
Recalc = true;
15271-
}
15272-
if (Recalc) {
15273-
ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
15274-
ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
15275-
}
15276-
}
15284+
HandleComplexComplexMul(A, B, C, D, ResR, ResI);
1527715285
}
1527815286
} else {
1527915287
ComplexValue LHS = Result;

clang/lib/AST/Interp/ByteCodeExprGen.cpp

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
876876
if (const auto *AT = RHSType->getAs<AtomicType>())
877877
RHSType = AT->getValueType();
878878

879+
// For ComplexComplex Mul, we have special ops to make their implementation
880+
// easier.
881+
BinaryOperatorKind Op = E->getOpcode();
882+
if (Op == BO_Mul && LHSType->isAnyComplexType() &&
883+
RHSType->isAnyComplexType()) {
884+
assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
885+
classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
886+
PrimType ElemT =
887+
classifyPrim(LHSType->getAs<ComplexType>()->getElementType());
888+
if (!this->visit(LHS))
889+
return false;
890+
if (!this->visit(RHS))
891+
return false;
892+
return this->emitMulc(ElemT, E);
893+
}
894+
879895
// Evaluate LHS and save value to LHSOffset.
880896
bool LHSIsComplex;
881897
unsigned LHSOffset;
@@ -919,38 +935,37 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
919935
// For both LHS and RHS, either load the value from the complex pointer, or
920936
// directly from the local variable. For index 1 (i.e. the imaginary part),
921937
// just load 0 and do the operation anyway.
922-
auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
923-
unsigned Offset, const Expr *E) -> bool {
938+
auto loadComplexValue = [this](bool IsComplex, bool LoadZero,
939+
unsigned ElemIndex, unsigned Offset,
940+
const Expr *E) -> bool {
924941
if (IsComplex) {
925942
if (!this->emitGetLocal(PT_Ptr, Offset, E))
926943
return false;
927944
return this->emitArrayElemPop(classifyComplexElementType(E->getType()),
928945
ElemIndex, E);
929946
}
930-
if (ElemIndex == 0)
947+
if (ElemIndex == 0 || !LoadZero)
931948
return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
932949
return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
933950
E);
934951
};
935952

936953
// Now we can get pointers to the LHS and RHS from the offsets above.
937-
BinaryOperatorKind Op = E->getOpcode();
938954
for (unsigned ElemIndex = 0; ElemIndex != 2; ++ElemIndex) {
939955
// Result pointer for the store later.
940956
if (!this->DiscardResult) {
941957
if (!this->emitGetLocal(PT_Ptr, ResultOffset, E))
942958
return false;
943959
}
944960

945-
if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
946-
return false;
947-
948-
if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
949-
return false;
950-
951961
// The actual operation.
952962
switch (Op) {
953963
case BO_Add:
964+
if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
965+
return false;
966+
967+
if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
968+
return false;
954969
if (ResultElemT == PT_Float) {
955970
if (!this->emitAddf(getRoundingMode(E), E))
956971
return false;
@@ -960,6 +975,11 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
960975
}
961976
break;
962977
case BO_Sub:
978+
if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
979+
return false;
980+
981+
if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
982+
return false;
963983
if (ResultElemT == PT_Float) {
964984
if (!this->emitSubf(getRoundingMode(E), E))
965985
return false;
@@ -968,6 +988,21 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
968988
return false;
969989
}
970990
break;
991+
case BO_Mul:
992+
if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
993+
return false;
994+
995+
if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
996+
return false;
997+
998+
if (ResultElemT == PT_Float) {
999+
if (!this->emitMulf(getRoundingMode(E), E))
1000+
return false;
1001+
} else {
1002+
if (!this->emitMul(ResultElemT, E))
1003+
return false;
1004+
}
1005+
break;
9711006

9721007
default:
9731008
return false;

clang/lib/AST/Interp/Interp.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LLVM_CLANG_AST_INTERP_INTERP_H
1414
#define LLVM_CLANG_AST_INTERP_INTERP_H
1515

16+
#include "../ExprConstShared.h"
1617
#include "Boolean.h"
1718
#include "Floating.h"
1819
#include "Function.h"
@@ -368,6 +369,62 @@ inline bool Mulf(InterpState &S, CodePtr OpPC, llvm::RoundingMode RM) {
368369
S.Stk.push<Floating>(Result);
369370
return CheckFloatResult(S, OpPC, Result, Status);
370371
}
372+
373+
template <PrimType Name, class T = typename PrimConv<Name>::T>
374+
inline bool Mulc(InterpState &S, CodePtr OpPC) {
375+
const Pointer &RHS = S.Stk.pop<Pointer>();
376+
const Pointer &LHS = S.Stk.pop<Pointer>();
377+
const Pointer &Result = S.Stk.peek<Pointer>();
378+
379+
if constexpr (std::is_same_v<T, Floating>) {
380+
APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
381+
APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
382+
APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
383+
APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();
384+
385+
APFloat ResR(A.getSemantics());
386+
APFloat ResI(A.getSemantics());
387+
HandleComplexComplexMul(A, B, C, D, ResR, ResI);
388+
389+
// Copy into the result.
390+
Result.atIndex(0).deref<Floating>() = Floating(ResR);
391+
Result.atIndex(0).initialize();
392+
Result.atIndex(1).deref<Floating>() = Floating(ResI);
393+
Result.atIndex(1).initialize();
394+
Result.initialize();
395+
} else {
396+
// Integer element type.
397+
const T &LHSR = LHS.atIndex(0).deref<T>();
398+
const T &LHSI = LHS.atIndex(1).deref<T>();
399+
const T &RHSR = RHS.atIndex(0).deref<T>();
400+
const T &RHSI = RHS.atIndex(1).deref<T>();
401+
unsigned Bits = LHSR.bitWidth();
402+
403+
// real(Result) = (real(LHS) * real(RHS)) - (imag(LHS) * imag(RHS))
404+
T A;
405+
if (T::mul(LHSR, RHSR, Bits, &A))
406+
return false;
407+
T B;
408+
if (T::mul(LHSI, RHSI, Bits, &B))
409+
return false;
410+
if (T::sub(A, B, Bits, &Result.atIndex(0).deref<T>()))
411+
return false;
412+
Result.atIndex(0).initialize();
413+
414+
// imag(Result) = (real(LHS) * imag(RHS)) + (imag(LHS) * real(RHS))
415+
if (T::mul(LHSR, RHSI, Bits, &A))
416+
return false;
417+
if (T::mul(LHSI, RHSR, Bits, &B))
418+
return false;
419+
if (T::add(A, B, Bits, &Result.atIndex(1).deref<T>()))
420+
return false;
421+
Result.atIndex(1).initialize();
422+
Result.initialize();
423+
}
424+
425+
return true;
426+
}
427+
371428
/// 1) Pops the RHS from the stack.
372429
/// 2) Pops the LHS from the stack.
373430
/// 3) Pushes 'LHS & RHS' on the stack

clang/lib/AST/Interp/Opcodes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ def Sub : AluOpcode;
526526
def Subf : FloatOpcode;
527527
def Mul : AluOpcode;
528528
def Mulf : FloatOpcode;
529+
def Mulc : Opcode {
530+
let Types = [NumberTypeClass];
531+
let HasGroup = 1;
532+
}
529533
def Rem : IntegerOpcode;
530534
def Div : IntegerOpcode;
531535
def Divf : FloatOpcode;

clang/test/AST/Interp/complex.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,37 @@ static_assert(&__imag z1 == &__real z1 + 1, "");
99
static_assert((*(&__imag z1)) == __imag z1, "");
1010
static_assert((*(&__real z1)) == __real z1, "");
1111

12+
13+
static_assert(((1.25 + 0.5j) * (0.25 - 0.75j)) == (0.6875 - 0.8125j), "");
14+
static_assert(((1.25 + 0.5j) * 0.25) == (0.3125 + 0.125j), "");
15+
static_assert((1.25 * (0.25 - 0.75j)) == (0.3125 - 0.9375j), "");
16+
constexpr _Complex float InfC = {1.0, __builtin_inf()};
17+
constexpr _Complex float InfInf = __builtin_inf() + InfC;
18+
static_assert(__real__(InfInf) == __builtin_inf());
19+
static_assert(__imag__(InfInf) == __builtin_inf());
20+
static_assert(__builtin_isnan(__real__(InfInf * InfInf)));
21+
static_assert(__builtin_isinf_sign(__imag__(InfInf * InfInf)) == 1);
22+
23+
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * 1.0)) == 1);
24+
static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * 1.0)) == 1);
25+
static_assert(__builtin_isinf_sign(__real__(1.0 * (__builtin_inf() + 1.0j))) == 1);
26+
static_assert(__builtin_isinf_sign(__imag__(1.0 * (1.0 + InfC))) == 1);
27+
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (1.0 + 1.0j))) == 1);
28+
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
29+
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
30+
static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + 1.0j))) == -1);
31+
static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * (1.0 + 1.0j))) == 1);
32+
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (1.0 + InfC))) == -1);
33+
static_assert(__builtin_isinf_sign(__imag__((1.0 + 1.0j) * (1.0 + InfC))) == 1);
34+
static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + InfC))) == -1);
35+
static_assert(__builtin_isinf_sign(__real__(InfInf * InfInf)) == 0);
36+
37+
constexpr _Complex int IIMA = {1,2};
38+
constexpr _Complex int IIMB = {10,20};
39+
constexpr _Complex int IIMC = IIMA * IIMB;
40+
static_assert(__real(IIMC) == -30, "");
41+
static_assert(__imag(IIMC) == 40, "");
42+
1243
constexpr _Complex int Comma1 = {1, 2};
1344
constexpr _Complex int Comma2 = (0, Comma1);
1445
static_assert(Comma1 == Comma1, "");

0 commit comments

Comments
 (0)