Skip to content

Commit 6b9b10d

Browse files
Bernd Warmuthwarber
Bernd Warmuth
authored andcommitted
chore: simplify ChennelMonitor methods
Signed-off-by: Bernd Warmuth <[email protected]>
1 parent d2efff5 commit 6b9b10d

File tree

3 files changed

+124
-39
lines changed

3 files changed

+124
-39
lines changed

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import io.grpc.ConnectivityState;
55
import io.grpc.ManagedChannel;
66
import java.util.concurrent.CountDownLatch;
7+
import java.util.concurrent.Executors;
8+
import java.util.concurrent.ScheduledFuture;
79
import java.util.concurrent.TimeUnit;
810
import lombok.extern.slf4j.Slf4j;
911

@@ -32,65 +34,58 @@ public static void monitorChannelState(
3234
ConnectivityState currentState = channel.getState(true);
3335
log.info("Channel state changed to: {}", currentState);
3436
if (currentState == ConnectivityState.READY) {
35-
onConnectionReady.run();
37+
if (onConnectionReady != null) {
38+
onConnectionReady.run();
39+
} else {
40+
log.debug("onConnectionReady is null");
41+
}
3642
} else if (currentState == ConnectivityState.TRANSIENT_FAILURE
3743
|| currentState == ConnectivityState.SHUTDOWN) {
38-
onConnectionLost.run();
44+
if (onConnectionLost != null) {
45+
onConnectionLost.run();
46+
} else {
47+
log.debug("onConnectionLost is null");
48+
}
3949
}
4050
// Re-register the state monitor to watch for the next state transition.
4151
monitorChannelState(currentState, channel, onConnectionReady, onConnectionLost);
4252
});
4353
}
4454

4555
/**
46-
* Waits for the channel to reach a desired state within a specified timeout period.
56+
* Waits for the channel to reach the desired connectivity state within the specified timeout.
4757
*
48-
* @param channel the ManagedChannel to monitor.
49-
* @param desiredState the ConnectivityState to wait for.
50-
* @param connectCallback callback invoked when the desired state is reached.
51-
* @param timeout the maximum amount of time to wait.
52-
* @param unit the time unit of the timeout.
53-
* @throws InterruptedException if the current thread is interrupted while waiting.
58+
* @param desiredState the desired {@link ConnectivityState} to wait for
59+
* @param channel the {@link ManagedChannel} to monitor
60+
* @param connectCallback the {@link Runnable} to execute when the desired state is reached
61+
* @param timeout the maximum time to wait
62+
* @param unit the time unit of the timeout argument
63+
* @throws InterruptedException if the current thread is interrupted while waiting
64+
* @throws GeneralError if the desired state is not reached within the timeout
5465
*/
5566
public static void waitForDesiredState(
56-
ManagedChannel channel,
5767
ConnectivityState desiredState,
58-
Runnable connectCallback,
59-
long timeout,
60-
TimeUnit unit)
61-
throws InterruptedException {
62-
waitForDesiredState(channel, desiredState, connectCallback, new CountDownLatch(1), timeout, unit);
63-
}
64-
65-
private static void waitForDesiredState(
6668
ManagedChannel channel,
67-
ConnectivityState desiredState,
6869
Runnable connectCallback,
69-
CountDownLatch latch,
7070
long timeout,
7171
TimeUnit unit)
7272
throws InterruptedException {
73-
channel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> {
74-
try {
75-
ConnectivityState state = channel.getState(true);
76-
log.debug("Channel state changed to: {}", state);
73+
CountDownLatch latch = new CountDownLatch(1);
7774

78-
if (state == desiredState) {
79-
connectCallback.run();
80-
latch.countDown();
81-
return;
82-
}
83-
waitForDesiredState(channel, desiredState, connectCallback, latch, timeout, unit);
84-
} catch (InterruptedException e) {
85-
Thread.currentThread().interrupt();
86-
log.error("Thread interrupted while waiting for desired state", e);
87-
} catch (Exception e) {
88-
log.error("Error occurred while waiting for desired state", e);
75+
Runnable waitForStateTask = () -> {
76+
ConnectivityState currentState = channel.getState(true);
77+
if (currentState == desiredState) {
78+
connectCallback.run();
79+
latch.countDown();
8980
}
90-
});
81+
};
82+
83+
ScheduledFuture<?> scheduledFuture = Executors.newSingleThreadScheduledExecutor()
84+
.scheduleWithFixedDelay(waitForStateTask, 0, 100, TimeUnit.MILLISECONDS);
9185

92-
// Await the latch or timeout for the state change
93-
if (!latch.await(timeout, unit)) {
86+
boolean success = latch.await(timeout, unit);
87+
scheduledFuture.cancel(true);
88+
if (!success) {
9489
throw new GeneralError(String.format(
9590
"Deadline exceeded. Condition did not complete within the %d " + "deadline", timeout));
9691
}

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/GrpcConnector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public GrpcConnector(
137137
public void initialize() throws Exception {
138138
log.info("Initializing GRPC connection...");
139139
ChannelMonitor.waitForDesiredState(
140-
channel, ConnectivityState.READY, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS);
140+
ConnectivityState.READY, channel, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS);
141141
ChannelMonitor.monitorChannelState(ConnectivityState.READY, channel, this::onReady, this::onConnectionLost);
142142
}
143143

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package dev.openfeature.contrib.providers.flagd.resolver.common;
2+
3+
import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.monitorChannelState;
4+
import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.waitForDesiredState;
5+
import static org.junit.Assert.assertThrows;
6+
import static org.mockito.ArgumentMatchers.anyBoolean;
7+
import static org.mockito.ArgumentMatchers.eq;
8+
import static org.mockito.Mockito.doNothing;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.never;
11+
import static org.mockito.Mockito.times;
12+
import static org.mockito.Mockito.verify;
13+
import static org.mockito.Mockito.when;
14+
15+
import dev.openfeature.sdk.exceptions.GeneralError;
16+
import io.grpc.ConnectivityState;
17+
import io.grpc.ManagedChannel;
18+
import java.util.concurrent.TimeUnit;
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.EnumSource;
22+
import org.mockito.ArgumentCaptor;
23+
import org.mockito.Mockito;
24+
25+
class ChannelMonitorTest {
26+
@Test
27+
void testWaitForDesiredState() throws InterruptedException {
28+
ManagedChannel channel = mock(ManagedChannel.class);
29+
Runnable connectCallback = mock(Runnable.class);
30+
31+
// Set up the desired state
32+
ConnectivityState desiredState = ConnectivityState.READY;
33+
when(channel.getState(anyBoolean())).thenReturn(desiredState);
34+
35+
// Call the method
36+
waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS);
37+
38+
// Verify that the callback was run
39+
verify(connectCallback, times(1)).run();
40+
}
41+
42+
@Test
43+
void testWaitForDesiredStateTimeout() {
44+
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
45+
Runnable connectCallback = mock(Runnable.class);
46+
47+
// Set up the desired state
48+
ConnectivityState desiredState = ConnectivityState.READY;
49+
when(channel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE);
50+
51+
// Call the method and expect a timeout
52+
assertThrows(GeneralError.class, () -> {
53+
waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS);
54+
});
55+
}
56+
57+
@ParameterizedTest
58+
@EnumSource(ConnectivityState.class)
59+
void testMonitorChannelState(ConnectivityState state) {
60+
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
61+
Runnable onConnectionReady = mock(Runnable.class);
62+
Runnable onConnectionLost = mock(Runnable.class);
63+
64+
// Set up the expected state
65+
ConnectivityState expectedState = ConnectivityState.IDLE;
66+
when(channel.getState(anyBoolean())).thenReturn(state);
67+
68+
// Capture the callback
69+
ArgumentCaptor<Runnable> callbackCaptor = ArgumentCaptor.forClass(Runnable.class);
70+
doNothing().when(channel).notifyWhenStateChanged(eq(expectedState), callbackCaptor.capture());
71+
72+
// Call the method
73+
monitorChannelState(expectedState, channel, onConnectionReady, onConnectionLost);
74+
75+
// Simulate state change
76+
callbackCaptor.getValue().run();
77+
78+
// Verify the callbacks based on the state
79+
if (state == ConnectivityState.READY) {
80+
verify(onConnectionReady, times(1)).run();
81+
verify(onConnectionLost, never()).run();
82+
} else if (state == ConnectivityState.TRANSIENT_FAILURE || state == ConnectivityState.SHUTDOWN) {
83+
verify(onConnectionReady, never()).run();
84+
verify(onConnectionLost, times(1)).run();
85+
} else {
86+
verify(onConnectionReady, never()).run();
87+
verify(onConnectionLost, never()).run();
88+
}
89+
}
90+
}

0 commit comments

Comments
 (0)