40
40
import java .net .StandardSocketOptions ;
41
41
import java .nio .ByteBuffer ;
42
42
import java .nio .channels .CompletionHandler ;
43
+ import java .nio .channels .InterruptedByTimeoutException ;
43
44
import java .nio .channels .SelectionKey ;
44
45
import java .nio .channels .Selector ;
45
46
import java .nio .channels .SocketChannel ;
49
50
import java .util .concurrent .ExecutorService ;
50
51
import java .util .concurrent .Future ;
51
52
import java .util .concurrent .TimeUnit ;
53
+ import java .util .concurrent .atomic .AtomicReference ;
52
54
53
55
import static com .mongodb .assertions .Assertions .assertTrue ;
54
56
import static com .mongodb .assertions .Assertions .isTrue ;
@@ -97,21 +99,40 @@ public void close() {
97
99
group .shutdown ();
98
100
}
99
101
102
+ /**
103
+ * Monitors `OP_CONNECT` events for socket connections.
104
+ */
100
105
private static class SelectorMonitor implements Closeable {
101
106
102
- private static final class Pair {
107
+ static final class SocketRegistration {
103
108
private final SocketChannel socketChannel ;
104
- private final Runnable attachment ;
109
+ private final AtomicReference < Runnable > afterConnectAction ;
105
110
106
- private Pair (final SocketChannel socketChannel , final Runnable attachment ) {
111
+ SocketRegistration (final SocketChannel socketChannel , final Runnable afterConnectAction ) {
107
112
this .socketChannel = socketChannel ;
108
- this .attachment = attachment ;
113
+ this .afterConnectAction = new AtomicReference <>(afterConnectAction );
114
+ }
115
+
116
+ boolean tryCancelPendingConnection () {
117
+ return tryTakeAction () != null ;
118
+ }
119
+
120
+ void runAfterConnectActionIfNotCanceled () {
121
+ Runnable afterConnectActionToExecute = tryTakeAction ();
122
+ if (afterConnectActionToExecute != null ) {
123
+ afterConnectActionToExecute .run ();
124
+ }
125
+ }
126
+
127
+ @ Nullable
128
+ private Runnable tryTakeAction () {
129
+ return afterConnectAction .getAndSet (null );
109
130
}
110
131
}
111
132
112
133
private final Selector selector ;
113
134
private volatile boolean isClosed ;
114
- private final ConcurrentLinkedDeque <Pair > pendingRegistrations = new ConcurrentLinkedDeque <>();
135
+ private final ConcurrentLinkedDeque <SocketRegistration > pendingRegistrations = new ConcurrentLinkedDeque <>();
115
136
116
137
SelectorMonitor () {
117
138
try {
@@ -127,17 +148,14 @@ void start() {
127
148
while (!isClosed ) {
128
149
try {
129
150
selector .select ();
130
-
131
151
for (SelectionKey selectionKey : selector .selectedKeys ()) {
132
152
selectionKey .cancel ();
133
- Runnable runnable = (Runnable ) selectionKey .attachment ();
134
- runnable .run ();
153
+ ((SocketRegistration ) selectionKey .attachment ()).runAfterConnectActionIfNotCanceled ();
135
154
}
136
155
137
- for (Iterator <Pair > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
138
- Pair pendingRegistration = iter .next ();
139
- pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT ,
140
- pendingRegistration .attachment );
156
+ for (Iterator <SocketRegistration > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
157
+ SocketRegistration pendingRegistration = iter .next ();
158
+ pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT , pendingRegistration );
141
159
iter .remove ();
142
160
}
143
161
} catch (Exception e ) {
@@ -156,8 +174,8 @@ void start() {
156
174
selectorThread .start ();
157
175
}
158
176
159
- void register (final SocketChannel channel , final Runnable attachment ) {
160
- pendingRegistrations .add (new Pair ( channel , attachment ) );
177
+ void register (final SocketRegistration registration ) {
178
+ pendingRegistrations .add (registration );
161
179
selector .wakeup ();
162
180
}
163
181
@@ -200,44 +218,79 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
200
218
if (getSettings ().getSendBufferSize () > 0 ) {
201
219
socketChannel .setOption (StandardSocketOptions .SO_SNDBUF , getSettings ().getSendBufferSize ());
202
220
}
203
-
221
+ //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
222
+ int connectTimeoutMs = operationContext .getTimeoutContext ().getConnectTimeoutMs ();
204
223
socketChannel .connect (getSocketAddresses (getServerAddress (), inetAddressResolver ).get (0 ));
224
+ SelectorMonitor .SocketRegistration socketRegistration = new SelectorMonitor .SocketRegistration (
225
+ socketChannel , () -> initializeTslChannel (handler , socketChannel ));
205
226
206
- selectorMonitor .register (socketChannel , () -> {
207
- try {
208
- if (!socketChannel .finishConnect ()) {
209
- throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
210
- }
227
+ if (connectTimeoutMs > 0 ) {
228
+ scheduleTimeoutInterruption (handler , socketRegistration , connectTimeoutMs );
229
+ }
230
+ selectorMonitor .register (socketRegistration );
231
+ } catch (IOException e ) {
232
+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
233
+ } catch (Throwable t ) {
234
+ handler .failed (t );
235
+ }
236
+ }
211
237
212
- SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
213
- getServerAddress ().getPort ());
214
- sslEngine .setUseClientMode (true );
238
+ private void scheduleTimeoutInterruption (final AsyncCompletionHandler <Void > handler ,
239
+ final SelectorMonitor .SocketRegistration socketRegistration ,
240
+ final int connectTimeoutMs ) {
241
+ group .getTimeoutExecutor ().schedule (() -> {
242
+ if (socketRegistration .tryCancelPendingConnection ()) {
243
+ closeAndTimeout (handler , socketRegistration .socketChannel );
244
+ }
245
+ }, connectTimeoutMs , TimeUnit .MILLISECONDS );
246
+ }
215
247
216
- SSLParameters sslParameters = sslEngine .getSSLParameters ();
217
- enableSni (getServerAddress ().getHost (), sslParameters );
248
+ private void closeAndTimeout (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
249
+ // We check if this stream was closed before timeout exception.
250
+ boolean streamClosed = isClosed ();
251
+ InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException ();
252
+ try {
253
+ socketChannel .close ();
254
+ } catch (Exception e ) {
255
+ timeoutException .addSuppressed (e );
256
+ }
218
257
219
- if (!sslSettings .isInvalidHostNameAllowed ()) {
220
- enableHostNameVerification (sslParameters );
221
- }
222
- sslEngine .setSSLParameters (sslParameters );
258
+ if (streamClosed ) {
259
+ handler .completed (null );
260
+ } else {
261
+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getAddress (), timeoutException ));
262
+ }
263
+ }
223
264
224
- BufferAllocator bufferAllocator = new BufferProviderAllocator ();
265
+ private void initializeTslChannel (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
266
+ try {
267
+ if (!socketChannel .finishConnect ()) {
268
+ throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
269
+ }
225
270
226
- TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
227
- .withEncryptedBufferAllocator (bufferAllocator )
228
- .withPlainBufferAllocator (bufferAllocator )
229
- .build ();
271
+ SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
272
+ getServerAddress ().getPort ());
273
+ sslEngine .setUseClientMode (true );
230
274
231
- // build asynchronous channel, based in the TLS channel and associated with the global group.
232
- setChannel ( new AsynchronousTlsChannelAdapter ( new AsynchronousTlsChannel ( group , tlsChannel , socketChannel )) );
275
+ SSLParameters sslParameters = sslEngine . getSSLParameters ();
276
+ enableSni ( getServerAddress (). getHost (), sslParameters );
233
277
234
- handler .completed (null );
235
- } catch (IOException e ) {
236
- handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
237
- } catch (Throwable t ) {
238
- handler .failed (t );
239
- }
240
- });
278
+ if (!sslSettings .isInvalidHostNameAllowed ()) {
279
+ enableHostNameVerification (sslParameters );
280
+ }
281
+ sslEngine .setSSLParameters (sslParameters );
282
+
283
+ BufferAllocator bufferAllocator = new BufferProviderAllocator ();
284
+
285
+ TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
286
+ .withEncryptedBufferAllocator (bufferAllocator )
287
+ .withPlainBufferAllocator (bufferAllocator )
288
+ .build ();
289
+
290
+ // build asynchronous channel, based in the TLS channel and associated with the global group.
291
+ setChannel (new AsynchronousTlsChannelAdapter (new AsynchronousTlsChannel (group , tlsChannel , socketChannel )));
292
+
293
+ handler .completed (null );
241
294
} catch (IOException e ) {
242
295
handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
243
296
} catch (Throwable t ) {
0 commit comments