@@ -1123,6 +1123,8 @@ addReductionDecl(mlir::Location currentLocation,
1123
1123
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1124
1124
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1125
1125
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1126
+ if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1127
+ symVal = declOp.getBase ();
1126
1128
mlir::Type redType =
1127
1129
symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
1128
1130
reductionVars.push_back (symVal);
@@ -1160,6 +1162,8 @@ addReductionDecl(mlir::Location currentLocation,
1160
1162
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1161
1163
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1162
1164
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1165
+ if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1166
+ symVal = declOp.getBase ();
1163
1167
mlir::Type redType =
1164
1168
symVal.getType ().cast <fir::ReferenceType>().getEleTy ();
1165
1169
reductionVars.push_back (symVal);
@@ -3746,6 +3750,8 @@ void Fortran::lower::genOpenMPReduction(
3746
3750
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
3747
3751
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
3748
3752
mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
3753
+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
3754
+ reductionVal = declOp.getBase ();
3749
3755
mlir::Type reductionType =
3750
3756
reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
3751
3757
if (!reductionType.isa <fir::LogicalType>()) {
@@ -3789,6 +3795,9 @@ void Fortran::lower::genOpenMPReduction(
3789
3795
ompObject)}) {
3790
3796
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
3791
3797
mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
3798
+ if (auto declOp =
3799
+ reductionVal.getDefiningOp <hlfir::DeclareOp>())
3800
+ reductionVal = declOp.getBase ();
3792
3801
for (const mlir::OpOperand &reductionValUse :
3793
3802
reductionVal.getUses ()) {
3794
3803
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
@@ -3844,6 +3853,13 @@ mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
3844
3853
return reductionOp;
3845
3854
}
3846
3855
}
3856
+ if (auto assign =
3857
+ mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner ())) {
3858
+ if (assign.getLhs () == *reductionVal) {
3859
+ assign.erase ();
3860
+ return reductionOp;
3861
+ }
3862
+ }
3847
3863
}
3848
3864
}
3849
3865
}
@@ -3899,6 +3915,11 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
3899
3915
if (storeOp.getMemref () == symVal)
3900
3916
storeOp.erase ();
3901
3917
}
3918
+ if (auto assignOp =
3919
+ mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
3920
+ if (assignOp.getLhs () == symVal)
3921
+ assignOp.erase ();
3922
+ }
3902
3923
}
3903
3924
}
3904
3925
}
0 commit comments