Skip to content

Commit c68665f

Browse files
authored
Add TransactionListener interface (#71)
Add TransactionListener interface to support notification and conditional action based on {before, after} {commit, rollback} events. ### Use-Case My present need is to dispatch events conditional on the commit of the current transaction. Using a post-commit hook allows for sending a message (e.g. on a AWS SQS queue) only after the transaction commits. This prevents race conditions where the message is picked up for processing before changes have been effected in the database and are visible to other parties. (Concretely, sending a "User Created" event only after the user record has been persisted.) Of course this is possible without post-commit hooks, however hooks provide a modular abstraction that decouples the code (e.g. send message) from the specific commit point. Transactional event listeners are a common pattern, and can be seen in other libraries such as the Spring Framework (see https://docs.spring.io/spring-framework/reference/data-access/transaction/event.html) ### Other Use-Cases Hooks provide modularity to support decoupling in several other cases. Before-Commit Hook - Perform final validation checks before committing changes - Update aggregate or summary tables based on transactional changes - Trigger notifications or events based on the impending commit After Commit Hook - Clean up temporary resources used during the transaction - Update caches or search indexes with newly committed data - Trigger asynchronous processes based on successful commits *(already elaborated above)* Before Rollback Hook - Log detailed information about the state leading to rollback - Notify relevant components/systems about the impending rollback - Eagerly release resources used during the transaction After Rollback Hook - Reset application state or clear caches affected by the rolled-back transaction - Log or report on the reasons for rollback for analysis - Trigger compensating actions or notifications due to the failed transaction
1 parent b6f864f commit c68665f

File tree

4 files changed

+245
-19
lines changed

4 files changed

+245
-19
lines changed

scalasql/core/src/DbApi.scala

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package scalasql.core
22

3+
import DbClient.notifyListeners
4+
35
import geny.Generator
46

57
import java.sql.{PreparedStatement, Statement}
@@ -136,6 +138,50 @@ object DbApi {
136138
flattened.renderSql(castParams)
137139
}
138140

141+
/**
142+
* A listener that can be added to a [[DbApi.Txn]] to be notified of commit and rollback events.
143+
*
144+
* The default implementations of these methods do nothing, but you can override them to
145+
* implement your own behavior.
146+
*/
147+
trait TransactionListener {
148+
149+
/**
150+
* Called when a new transaction is started.
151+
*/
152+
def begin(): Unit = ()
153+
154+
/**
155+
* Called before the transaction is committed.
156+
*
157+
* If this method throws an exception, the transaction will be rolled back and the exception
158+
* will be propagated.
159+
*/
160+
def beforeCommit(): Unit = ()
161+
162+
/**
163+
* Called after the transaction is committed.
164+
*
165+
* If this method throws an exception, it will be propagated.
166+
*/
167+
def afterCommit(): Unit = ()
168+
169+
/**
170+
* Called before the transaction is rolled back.
171+
*
172+
* If this method throws an exception, the transaction will be rolled back and the exception
173+
* will be propagated to the caller of rollback().
174+
*/
175+
def beforeRollback(): Unit = ()
176+
177+
/**
178+
* Called after the transaction is rolled back.
179+
*
180+
* If this method throws an exception, it will be propagated to the caller of rollback().
181+
*/
182+
def afterRollback(): Unit = ()
183+
}
184+
139185
/**
140186
* An interface to a SQL database *transaction*, allowing you to run queries,
141187
* create savepoints, or roll back the transaction.
@@ -151,9 +197,11 @@ object DbApi {
151197
def savepoint[T](block: DbApi.Savepoint => T): T
152198

153199
/**
154-
* Tolls back any active Savepoints and then rolls back this Transaction
200+
* Rolls back any active Savepoints and then rolls back this Transaction
155201
*/
156202
def rollback(): Unit
203+
204+
def addTransactionListener(listener: TransactionListener): Unit
157205
}
158206

159207
/**
@@ -187,9 +235,19 @@ object DbApi {
187235
connection: java.sql.Connection,
188236
config: Config,
189237
dialect: DialectConfig,
190-
autoCommit: Boolean,
191-
rollBack0: () => Unit
238+
defaultListeners: Iterable[TransactionListener],
239+
autoCommit: Boolean
192240
) extends DbApi.Txn {
241+
242+
val listeners =
243+
collection.mutable.ArrayDeque.empty[TransactionListener].addAll(defaultListeners)
244+
245+
override def addTransactionListener(listener: TransactionListener): Unit = {
246+
if (autoCommit)
247+
throw new IllegalStateException("Cannot add listener to auto-commit transaction")
248+
listeners.append(listener)
249+
}
250+
193251
def run[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
194252
implicit qr: Queryable[Q, R],
195253
fileName: sourcecode.FileName,
@@ -218,6 +276,7 @@ object DbApi {
218276
res.toVector.asInstanceOf[R]
219277
}
220278
}
279+
221280
}
222281

223282
def stream[Q, R](query: Q, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
@@ -229,8 +288,8 @@ object DbApi {
229288
streamFlattened0(
230289
r => {
231290
qr.asInstanceOf[Queryable[Q, R]].construct(query, r) match {
232-
case s: Seq[R] => s.head
233-
case r: R => r
291+
case s: Seq[R] @unchecked => s.head
292+
case r: R @unchecked => r
234293
}
235294
},
236295
flattened,
@@ -545,8 +604,13 @@ object DbApi {
545604
}
546605

547606
def rollback() = {
548-
savepointStack.clear()
549-
rollBack0()
607+
try {
608+
notifyListeners(listeners)(_.beforeRollback())
609+
} finally {
610+
savepointStack.clear()
611+
connection.rollback()
612+
notifyListeners(listeners)(_.afterRollback())
613+
}
550614
}
551615

552616
private def cast[T](t: Any): T = t.asInstanceOf[T]

scalasql/core/src/DbClient.scala

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,33 @@ trait DbClient {
3535

3636
object DbClient {
3737

38+
/**
39+
* Calls the given function for each listener, collecting any exceptions and throwing them
40+
* as a single exception if any are thrown.
41+
*/
42+
private[core] def notifyListeners(listeners: Iterable[DbApi.TransactionListener])(
43+
f: DbApi.TransactionListener => Unit
44+
): Unit = {
45+
if (listeners.isEmpty) return
46+
47+
var exception: Throwable = null
48+
listeners.foreach { listener =>
49+
try {
50+
f(listener)
51+
} catch {
52+
case e: Throwable =>
53+
if (exception == null) exception = e
54+
else exception.addSuppressed(e)
55+
}
56+
}
57+
if (exception != null) throw exception
58+
}
59+
3860
class Connection(
3961
connection: java.sql.Connection,
40-
config: Config = new Config {}
62+
config: Config = new Config {},
63+
/** Listeners that are added to all transactions created by this connection */
64+
listeners: Seq[DbApi.TransactionListener] = Seq.empty
4165
)(implicit dialect: DialectConfig)
4266
extends DbClient {
4367

@@ -49,28 +73,57 @@ object DbClient {
4973

5074
def transaction[T](block: DbApi.Txn => T): T = {
5175
connection.setAutoCommit(false)
52-
val txn =
53-
new DbApi.Impl(connection, config, dialect, false, () => connection.rollback())
54-
try block(txn)
55-
catch {
76+
val txn = new DbApi.Impl(connection, config, dialect, listeners, autoCommit = false)
77+
var rolledBack = false
78+
try {
79+
notifyListeners(txn.listeners)(_.begin())
80+
val result = block(txn)
81+
notifyListeners(txn.listeners)(_.beforeCommit())
82+
result
83+
} catch {
5684
case e: Throwable =>
57-
connection.rollback()
85+
rolledBack = true
86+
try {
87+
notifyListeners(txn.listeners)(_.beforeRollback())
88+
} catch {
89+
case e2: Throwable => e.addSuppressed(e2)
90+
} finally {
91+
connection.rollback()
92+
try {
93+
notifyListeners(txn.listeners)(_.afterRollback())
94+
} catch {
95+
case e3: Throwable => e.addSuppressed(e3)
96+
}
97+
}
5898
throw e
59-
} finally connection.setAutoCommit(true)
99+
} finally {
100+
// this commits uncommitted operations, if any
101+
connection.setAutoCommit(true)
102+
if (!rolledBack) {
103+
notifyListeners(txn.listeners)(_.afterCommit())
104+
}
105+
}
60106
}
61107

62108
def getAutoCommitClientConnection: DbApi = {
63109
connection.setAutoCommit(true)
64-
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
110+
new DbApi.Impl(connection, config, dialect, listeners, autoCommit = true)
65111
}
66112
}
67113

68114
class DataSource(
69115
dataSource: javax.sql.DataSource,
70-
config: Config = new Config {}
116+
config: Config = new Config {},
117+
/** Listeners that are added to all transactions created through the [[DataSource]] */
118+
listeners: Seq[DbApi.TransactionListener] = Seq.empty
71119
)(implicit dialect: DialectConfig)
72120
extends DbClient {
73121

122+
/** Returns a new [[DataSource]] with the given listener added */
123+
def withTransactionListener(listener: DbApi.TransactionListener): DbClient = {
124+
new DataSource(dataSource, config, listeners :+ listener)
125+
}
126+
74127
def renderSql[Q, R](query: Q, castParams: Boolean = false)(
75128
implicit qr: Queryable[Q, R]
76129
): String = {
@@ -79,7 +132,7 @@ object DbClient {
79132

80133
private def withConnection[T](f: DbClient.Connection => T): T = {
81134
val connection = dataSource.getConnection
82-
try f(new DbClient.Connection(connection, config))
135+
try f(new DbClient.Connection(connection, config, listeners))
83136
finally connection.close()
84137
}
85138

@@ -88,7 +141,7 @@ object DbClient {
88141
def getAutoCommitClientConnection: DbApi = {
89142
val connection = dataSource.getConnection
90143
connection.setAutoCommit(true)
91-
new DbApi.Impl(connection, config, dialect, autoCommit = true, () => ())
144+
new DbApi.Impl(connection, config, dialect, defaultListeners = Seq.empty, autoCommit = true)
92145
}
93146
}
94147
}

scalasql/test/src/api/TransactionTests.scala

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package scalasql.api
22

33
import scalasql.Purchase
44
import scalasql.utils.{ScalaSqlSuite, SqliteSuite}
5+
import scalasql.DbApi
56
import sourcecode.Text
67
import utest._
78

@@ -12,6 +13,42 @@ trait TransactionTests extends ScalaSqlSuite {
1213
override def utestBeforeEach(path: Seq[String]): Unit = checker.reset()
1314
class FooException extends Exception
1415

16+
class ListenerException(message: String) extends Exception(message)
17+
18+
class StubTransactionListener(
19+
throwOnBeforeCommit: Boolean = false,
20+
throwOnAfterCommit: Boolean = false,
21+
throwOnBeforeRollback: Boolean = false,
22+
throwOnAfterRollback: Boolean = false
23+
) extends DbApi.TransactionListener {
24+
var beginCalled = false
25+
var beforeCommitCalled = false
26+
var afterCommitCalled = false
27+
var beforeRollbackCalled = false
28+
var afterRollbackCalled = false
29+
30+
override def begin(): Unit = {
31+
beginCalled = true
32+
}
33+
34+
override def beforeCommit(): Unit = {
35+
beforeCommitCalled = true
36+
if (throwOnBeforeCommit) throw new ListenerException("beforeCommit")
37+
}
38+
override def afterCommit(): Unit = {
39+
afterCommitCalled = true
40+
if (throwOnAfterCommit) throw new ListenerException("afterCommit")
41+
}
42+
override def beforeRollback(): Unit = {
43+
beforeRollbackCalled = true
44+
if (throwOnBeforeRollback) throw new ListenerException("beforeRollback")
45+
}
46+
override def afterRollback(): Unit = {
47+
afterRollbackCalled = true
48+
if (throwOnAfterRollback) throw new ListenerException("afterRollback")
49+
}
50+
}
51+
1552
def tests = Tests {
1653
test("simple") {
1754
test("commit") - checker.recorded(
@@ -537,5 +574,77 @@ trait TransactionTests extends ScalaSqlSuite {
537574
}
538575
}
539576
}
577+
578+
test("listener") {
579+
test("beforeCommit and afterCommit are called under normal circumstances") {
580+
val listener = new StubTransactionListener()
581+
dbClient.withTransactionListener(listener).transaction { _ =>
582+
// do nothing
583+
}
584+
listener.beginCalled ==> true
585+
listener.beforeCommitCalled ==> true
586+
listener.afterCommitCalled ==> true
587+
listener.beforeRollbackCalled ==> false
588+
listener.afterRollbackCalled ==> false
589+
}
590+
591+
test("if beforeCommit causes an exception, {before,after}Rollback are called") {
592+
val listener = new StubTransactionListener(throwOnBeforeCommit = true)
593+
val e = intercept[ListenerException] {
594+
dbClient.transaction { implicit txn =>
595+
txn.addTransactionListener(listener)
596+
}
597+
}
598+
e.getMessage ==> "beforeCommit"
599+
listener.beforeCommitCalled ==> true
600+
listener.afterCommitCalled ==> false
601+
listener.beforeRollbackCalled ==> true
602+
listener.afterRollbackCalled ==> true
603+
}
604+
605+
test("if afterCommit causes an exception, the exception is propagated") {
606+
val listener = new StubTransactionListener(throwOnAfterCommit = true)
607+
val e = intercept[ListenerException] {
608+
dbClient.transaction { implicit txn =>
609+
txn.addTransactionListener(listener)
610+
}
611+
}
612+
e.getMessage ==> "afterCommit"
613+
listener.beforeCommitCalled ==> true
614+
listener.afterCommitCalled ==> true
615+
listener.beforeRollbackCalled ==> false
616+
listener.afterRollbackCalled ==> false
617+
}
618+
619+
test("if beforeRollback causes an exception, afterRollback is still called") {
620+
val listener = new StubTransactionListener(throwOnBeforeRollback = true)
621+
val e = intercept[FooException] {
622+
dbClient.transaction { implicit txn =>
623+
txn.addTransactionListener(listener)
624+
throw new FooException()
625+
}
626+
}
627+
e.getSuppressed.head.getMessage ==> "beforeRollback"
628+
listener.beforeCommitCalled ==> false
629+
listener.afterCommitCalled ==> false
630+
listener.beforeRollbackCalled ==> true
631+
listener.afterRollbackCalled ==> true
632+
}
633+
634+
test("if afterRollback causes an exception, the exception is propagated") {
635+
val listener = new StubTransactionListener(throwOnAfterRollback = true)
636+
val e = intercept[FooException] {
637+
dbClient.transaction { implicit txn =>
638+
txn.addTransactionListener(listener)
639+
throw new FooException()
640+
}
641+
}
642+
e.getSuppressed.head.getMessage ==> "afterRollback"
643+
listener.beforeCommitCalled ==> false
644+
listener.afterCommitCalled ==> false
645+
listener.beforeRollbackCalled ==> true
646+
listener.afterRollbackCalled ==> true
647+
}
648+
}
540649
}
541650
}

scalasql/test/src/utils/TestChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import scalasql.query.SubqueryRef
66
import scalasql.{DbClient, Queryable, Expr, UtestFramework}
77

88
class TestChecker(
9-
val dbClient: DbClient,
9+
val dbClient: DbClient.DataSource,
1010
testSchemaFileName: String,
1111
testDataFileName: String,
1212
suiteName: String,

0 commit comments

Comments
 (0)