@@ -103,21 +103,6 @@ static fir::GlobalOp globalInitialization(
103
103
return global;
104
104
}
105
105
106
- static mlir::Operation *getCompareFromReductionOp (mlir::Operation *reductionOp,
107
- mlir::Value loadVal) {
108
- for (mlir::Value reductionOperand : reductionOp->getOperands ()) {
109
- if (mlir::Operation *compareOp = reductionOperand.getDefiningOp ()) {
110
- if (compareOp->getOperand (0 ) == loadVal ||
111
- compareOp->getOperand (1 ) == loadVal)
112
- assert ((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
113
- mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
114
- " Expected comparison not found in reduction intrinsic" );
115
- return compareOp;
116
- }
117
- }
118
- return nullptr ;
119
- }
120
-
121
106
// Get the extended value for \p val by extracting additional variable
122
107
// information from \p base.
123
108
static fir::ExtendedValue getExtendedValue (fir::ExtendedValue base,
@@ -237,213 +222,6 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
237
222
return storeOp;
238
223
}
239
224
240
- static mlir::Operation *
241
- findReductionChain (mlir::Value loadVal, mlir::Value *reductionVal = nullptr ) {
242
- for (mlir::OpOperand &loadOperand : loadVal.getUses ()) {
243
- if (mlir::Operation *reductionOp = loadOperand.getOwner ()) {
244
- if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
245
- for (mlir::OpOperand &convertOperand : convertOp.getRes ().getUses ()) {
246
- if (mlir::Operation *reductionOp = convertOperand.getOwner ())
247
- return reductionOp;
248
- }
249
- }
250
- for (mlir::OpOperand &reductionOperand : reductionOp->getUses ()) {
251
- if (auto store =
252
- mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner ())) {
253
- if (store.getMemref () == *reductionVal) {
254
- store.erase ();
255
- return reductionOp;
256
- }
257
- }
258
- if (auto assign =
259
- mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner ())) {
260
- if (assign.getLhs () == *reductionVal) {
261
- assign.erase ();
262
- return reductionOp;
263
- }
264
- }
265
- }
266
- }
267
- }
268
- return nullptr ;
269
- }
270
-
271
- // for a logical operator 'op' reduction X = X op Y
272
- // This function returns the operation responsible for converting Y from
273
- // fir.logical<4> to i1
274
- static fir::ConvertOp getConvertFromReductionOp (mlir::Operation *reductionOp,
275
- mlir::Value loadVal) {
276
- for (mlir::Value reductionOperand : reductionOp->getOperands ()) {
277
- if (auto convertOp =
278
- mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp ())) {
279
- if (convertOp.getOperand () == loadVal)
280
- continue ;
281
- return convertOp;
282
- }
283
- }
284
- return nullptr ;
285
- }
286
-
287
- static void updateReduction (mlir::Operation *op,
288
- fir::FirOpBuilder &firOpBuilder,
289
- mlir::Value loadVal, mlir::Value reductionVal,
290
- fir::ConvertOp *convertOp = nullptr ) {
291
- mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint ();
292
- firOpBuilder.setInsertionPoint (op);
293
-
294
- mlir::Value reductionOp;
295
- if (convertOp)
296
- reductionOp = convertOp->getOperand ();
297
- else if (op->getOperand (0 ) == loadVal)
298
- reductionOp = op->getOperand (1 );
299
- else
300
- reductionOp = op->getOperand (0 );
301
-
302
- firOpBuilder.create <mlir::omp::ReductionOp>(op->getLoc (), reductionOp,
303
- reductionVal);
304
- firOpBuilder.restoreInsertionPoint (insertPtDel);
305
- }
306
-
307
- static void removeStoreOp (mlir::Operation *reductionOp, mlir::Value symVal) {
308
- for (mlir::Operation *reductionOpUse : reductionOp->getUsers ()) {
309
- if (auto convertReduction =
310
- mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
311
- for (mlir::Operation *convertReductionUse :
312
- convertReduction.getRes ().getUsers ()) {
313
- if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
314
- if (storeOp.getMemref () == symVal)
315
- storeOp.erase ();
316
- }
317
- if (auto assignOp =
318
- mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
319
- if (assignOp.getLhs () == symVal)
320
- assignOp.erase ();
321
- }
322
- }
323
- }
324
- }
325
- }
326
-
327
- // Generate an OpenMP reduction operation.
328
- // TODO: Currently assumes it is either an integer addition/multiplication
329
- // reduction, or a logical and reduction. Generalize this for various reduction
330
- // operation types.
331
- // TODO: Generate the reduction operation during lowering instead of creating
332
- // and removing operations since this is not a robust approach. Also, removing
333
- // ops in the builder (instead of a rewriter) is probably not the best approach.
334
- static void
335
- genOpenMPReduction (Fortran::lower::AbstractConverter &converter,
336
- Fortran::semantics::SemanticsContext &semaCtx,
337
- const Fortran::parser::OmpClauseList &clauseList) {
338
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
339
-
340
- List<Clause> clauses{makeClauses (clauseList, semaCtx)};
341
-
342
- for (const Clause &clause : clauses) {
343
- if (const auto &reductionClause =
344
- std::get_if<clause::Reduction>(&clause.u )) {
345
- const auto &redOperatorList{
346
- std::get<clause::Reduction::ReductionIdentifiers>(
347
- reductionClause->t )};
348
- assert (redOperatorList.size () == 1 && " Expecting single operator" );
349
- const auto &redOperator = redOperatorList.front ();
350
- const auto &objects{std::get<ObjectList>(reductionClause->t )};
351
- if (const auto *reductionOp =
352
- std::get_if<clause::DefinedOperator>(&redOperator.u )) {
353
- const auto &intrinsicOp{
354
- std::get<clause::DefinedOperator::IntrinsicOperator>(
355
- reductionOp->u )};
356
-
357
- switch (intrinsicOp) {
358
- case clause::DefinedOperator::IntrinsicOperator::Add:
359
- case clause::DefinedOperator::IntrinsicOperator::Multiply:
360
- case clause::DefinedOperator::IntrinsicOperator::AND:
361
- case clause::DefinedOperator::IntrinsicOperator::EQV:
362
- case clause::DefinedOperator::IntrinsicOperator::OR:
363
- case clause::DefinedOperator::IntrinsicOperator::NEQV:
364
- break ;
365
- default :
366
- continue ;
367
- }
368
- for (const Object &object : objects) {
369
- if (const Fortran::semantics::Symbol *symbol = object.id ()) {
370
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
371
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
372
- reductionVal = declOp.getBase ();
373
- mlir::Type reductionType =
374
- reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
375
- if (!reductionType.isa <fir::LogicalType>()) {
376
- if (!reductionType.isIntOrIndexOrFloat ())
377
- continue ;
378
- }
379
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
380
- if (auto loadOp =
381
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
382
- mlir::Value loadVal = loadOp.getRes ();
383
- if (reductionType.isa <fir::LogicalType>()) {
384
- mlir::Operation *reductionOp = findReductionChain (loadVal);
385
- fir::ConvertOp convertOp =
386
- getConvertFromReductionOp (reductionOp, loadVal);
387
- updateReduction (reductionOp, firOpBuilder, loadVal,
388
- reductionVal, &convertOp);
389
- removeStoreOp (reductionOp, reductionVal);
390
- } else if (mlir::Operation *reductionOp =
391
- findReductionChain (loadVal, &reductionVal)) {
392
- updateReduction (reductionOp, firOpBuilder, loadVal,
393
- reductionVal);
394
- }
395
- }
396
- }
397
- }
398
- }
399
- } else if (const auto *reductionIntrinsic =
400
- std::get_if<clause::ProcedureDesignator>(&redOperator.u )) {
401
- if (!ReductionProcessor::supportedIntrinsicProcReduction (
402
- *reductionIntrinsic))
403
- continue ;
404
- ReductionProcessor::ReductionIdentifier redId =
405
- ReductionProcessor::getReductionType (*reductionIntrinsic);
406
- for (const Object &object : objects) {
407
- if (const Fortran::semantics::Symbol *symbol = object.id ()) {
408
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
409
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
410
- reductionVal = declOp.getBase ();
411
- for (const mlir::OpOperand &reductionValUse :
412
- reductionVal.getUses ()) {
413
- if (auto loadOp =
414
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
415
- mlir::Value loadVal = loadOp.getRes ();
416
- // Max is lowered as a compare -> select.
417
- // Match the pattern here.
418
- mlir::Operation *reductionOp =
419
- findReductionChain (loadVal, &reductionVal);
420
- if (reductionOp == nullptr )
421
- continue ;
422
-
423
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
424
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
425
- assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
426
- " Selection Op not found in reduction intrinsic" );
427
- mlir::Operation *compareOp =
428
- getCompareFromReductionOp (reductionOp, loadVal);
429
- updateReduction (compareOp, firOpBuilder, loadVal,
430
- reductionVal);
431
- }
432
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
433
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
434
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
435
- updateReduction (reductionOp, firOpBuilder, loadVal,
436
- reductionVal);
437
- }
438
- }
439
- }
440
- }
441
- }
442
- }
443
- }
444
- }
445
- }
446
-
447
225
struct OpWithBodyGenInfo {
448
226
// / A type for a code-gen callback function. This takes as argument the op for
449
227
// / which the code is being generated and returns the arguments of the op's
@@ -2197,7 +1975,6 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2197
1975
// 2.9.3.1 SIMD construct
2198
1976
createSimdLoop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
2199
1977
currentLocation);
2200
- genOpenMPReduction (converter, semaCtx, loopOpClauseList);
2201
1978
} else {
2202
1979
createWsloop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
2203
1980
endClauseList, currentLocation);
0 commit comments