Skip to content

Commit 0c2803a

Browse files
cupofcattoddbaert
andauthored
feat(flagd): Add features to customize auth to Sync API server (authorityOverride and clientInterceptors) (#1260)
Signed-off-by: Maks Osowski <[email protected]> Co-authored-by: Todd Baert <[email protected]>
1 parent 31b1ebc commit 0c2803a

File tree

6 files changed

+157
-0
lines changed

6 files changed

+157
-0
lines changed

providers/flagd/README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Given below are the supported configurations:
110110
| port | FLAGD_PORT | int | 8013 | rpc & in-process |
111111
| targetUri | FLAGD_TARGET_URI | string | null | rpc & in-process |
112112
| tls | FLAGD_TLS | boolean | false | rpc & in-process |
113+
| defaultAuthority | FLAGD_DEFAULT_AUTHORITY | String | null | rpc & in-process |
113114
| socketPath | FLAGD_SOCKET_PATH | String | null | rpc & in-process |
114115
| certPath | FLAGD_SERVER_CERT_PATH | String | null | rpc & in-process |
115116
| deadline | FLAGD_DEADLINE_MS | int | 500 | rpc & in-process & file |
@@ -180,6 +181,50 @@ FlagdProvider flagdProvider = new FlagdProvider(
180181
> There's a [vulnerability](https://security.snyk.io/vuln/SNYK-JAVA-IONETTY-1042268) in [netty](https://github.com/netty/netty), a transitive dependency of the underlying gRPC libraries used in the flagd-provider that fails to correctly validate certificates.
181182
> This will be addressed in netty v5.
182183
184+
### Configuring gRPC credentials and headers
185+
186+
The `clientInterceptors` and `defaultAuthority` are meant for connection of the in-process resolver to a Sync API implementation on a host/port, that might require special credentials or headers.
187+
188+
```java
189+
private static ClientInterceptor createHeaderInterceptor() {
190+
return new ClientInterceptor() {
191+
@Override
192+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
193+
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
194+
@Override
195+
public void start(Listener<RespT> responseListener, Metadata headers) {
196+
headers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "header-value");
197+
super.start(responseListener, headers);
198+
}
199+
};
200+
}
201+
};
202+
}
203+
204+
private static ClientInterceptor createCallCrednetialsInterceptor(CallCredentials callCredentials) throws IOException {
205+
return new ClientInterceptor() {
206+
@Override
207+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
208+
return next.newCall(method, callOptions.withCallCredentials(callCredentials));
209+
}
210+
};
211+
}
212+
213+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>(2);
214+
clientInterceptors.add(createHeaderInterceptor());
215+
CallCredentials myCallCredentals = ...;
216+
clientInterceptors.add(createCallCrednetialsInterceptor(myCallCredentials));
217+
218+
FlagdProvider flagdProvider = new FlagdProvider(
219+
FlagdOptions.builder()
220+
.host("example.com/flagdSyncApi")
221+
.port(443)
222+
.tls(true)
223+
.defaultAuthority("authority-host.sync.example.com")
224+
.clientInterceptors(clientInterceptors)
225+
.build());
226+
```
227+
183228
### Caching (RPC only)
184229

185230
> [!NOTE]

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public final class Config {
2424
static final String HOST_ENV_VAR_NAME = "FLAGD_HOST";
2525
static final String PORT_ENV_VAR_NAME = "FLAGD_PORT";
2626
static final String TLS_ENV_VAR_NAME = "FLAGD_TLS";
27+
static final String DEFAULT_AUTHORITY_ENV_VAR_NAME = "FLAGD_DEFAULT_AUTHORITY";
2728
static final String SOCKET_PATH_ENV_VAR_NAME = "FLAGD_SOCKET_PATH";
2829
static final String SERVER_CERT_PATH_ENV_VAR_NAME = "FLAGD_SERVER_CERT_PATH";
2930
static final String CACHE_ENV_VAR_NAME = "FLAGD_CACHE";

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import dev.openfeature.sdk.EvaluationContext;
88
import dev.openfeature.sdk.ImmutableContext;
99
import dev.openfeature.sdk.Structure;
10+
import io.grpc.ClientInterceptor;
1011
import io.opentelemetry.api.GlobalOpenTelemetry;
1112
import io.opentelemetry.api.OpenTelemetry;
13+
import java.util.List;
1214
import java.util.function.Function;
1315
import lombok.Builder;
1416
import lombok.Getter;
@@ -164,6 +166,18 @@ public class FlagdOptions {
164166
*/
165167
private OpenTelemetry openTelemetry;
166168

169+
/**
170+
* gRPC client interceptors to be used when creating a gRPC channel.
171+
*/
172+
@Builder.Default
173+
private List<ClientInterceptor> clientInterceptors = null;
174+
175+
/**
176+
* Authority header to be used when creating a gRPC channel.
177+
*/
178+
@Builder.Default
179+
private String defaultAuthority = fallBackToEnvOrDefault(Config.DEFAULT_AUTHORITY_ENV_VAR_NAME, null);
180+
167181
/**
168182
* Builder overwrite in order to customize the "build" method.
169183
*

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) {
6363
final NettyChannelBuilder builder =
6464
NettyChannelBuilder.forTarget(targetUri).keepAliveTime(keepAliveMs, TimeUnit.MILLISECONDS);
6565

66+
if (options.getDefaultAuthority() != null) {
67+
builder.overrideAuthority(options.getDefaultAuthority());
68+
}
69+
if (options.getClientInterceptors() != null) {
70+
builder.intercept(options.getClientInterceptors());
71+
}
6672
if (options.isTls()) {
6773
SslContextBuilder sslContext = GrpcSslContexts.forClient();
6874

providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector;
2222
import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.Connector;
23+
import io.grpc.ClientInterceptor;
2324
import io.opentelemetry.api.OpenTelemetry;
25+
import java.util.ArrayList;
26+
import java.util.List;
2427
import java.util.function.Function;
2528
import org.junit.jupiter.api.Nested;
2629
import org.junit.jupiter.api.Test;
@@ -46,12 +49,15 @@ void TestDefaults() {
4649
assertNull(builder.getOfflineFlagSourcePath());
4750
assertEquals(Resolver.RPC, builder.getResolverType());
4851
assertEquals(0, builder.getKeepAlive());
52+
assertNull(builder.getDefaultAuthority());
53+
assertNull(builder.getClientInterceptors());
4954
}
5055

5156
@Test
5257
void TestBuilderOptions() {
5358
OpenTelemetry openTelemetry = Mockito.mock(OpenTelemetry.class);
5459
Connector connector = new MockConnector(null);
60+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();
5561

5662
FlagdOptions flagdOptions = FlagdOptions.builder()
5763
.host("https://hosted-flagd")
@@ -66,6 +72,8 @@ void TestBuilderOptions() {
6672
.resolverType(Resolver.IN_PROCESS)
6773
.targetUri("dns:///localhost:8016")
6874
.keepAlive(1000)
75+
.defaultAuthority("test-authority.sync.example.com")
76+
.clientInterceptors(clientInterceptors)
6977
.build();
7078

7179
assertEquals("https://hosted-flagd", flagdOptions.getHost());
@@ -80,6 +88,8 @@ void TestBuilderOptions() {
8088
assertEquals(Resolver.IN_PROCESS, flagdOptions.getResolverType());
8189
assertEquals("dns:///localhost:8016", flagdOptions.getTargetUri());
8290
assertEquals(1000, flagdOptions.getKeepAlive());
91+
assertEquals("test-authority.sync.example.com", flagdOptions.getDefaultAuthority());
92+
assertEquals(clientInterceptors, flagdOptions.getClientInterceptors());
8393
}
8494

8595
@Test

providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static org.assertj.core.api.Assertions.assertThat;
44
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5+
import static org.mockito.ArgumentMatchers.anyList;
56
import static org.mockito.Mockito.any;
67
import static org.mockito.Mockito.anyLong;
78
import static org.mockito.Mockito.anyString;
@@ -11,6 +12,7 @@
1112
import static org.mockito.Mockito.when;
1213

1314
import dev.openfeature.contrib.providers.flagd.FlagdOptions;
15+
import io.grpc.ClientInterceptor;
1416
import io.grpc.ManagedChannel;
1517
import io.grpc.netty.GrpcSslContexts;
1618
import io.grpc.netty.NettyChannelBuilder;
@@ -20,6 +22,8 @@
2022
import io.netty.channel.unix.DomainSocketAddress;
2123
import io.netty.handler.ssl.SslContextBuilder;
2224
import java.io.File;
25+
import java.util.ArrayList;
26+
import java.util.List;
2327
import java.util.concurrent.TimeUnit;
2428
import javax.net.ssl.SSLKeyException;
2529
import org.junit.jupiter.api.Test;
@@ -113,6 +117,83 @@ void testNettyChannel_withTlsAndCert() {
113117
}
114118
}
115119

120+
@Test
121+
void testNettyChannel_withDefaultAuthority() {
122+
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
123+
// Mocks
124+
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
125+
ManagedChannel mockChannel = mock(ManagedChannel.class);
126+
nettyMock
127+
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
128+
.thenReturn(mockBuilder);
129+
130+
when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
131+
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
132+
when(mockBuilder.overrideAuthority(anyString())).thenReturn(mockBuilder);
133+
when(mockBuilder.build()).thenReturn(mockChannel);
134+
135+
// Input options
136+
FlagdOptions options = FlagdOptions.builder()
137+
.host("localhost")
138+
.port(8080)
139+
.keepAlive(5000)
140+
.tls(true)
141+
.defaultAuthority("test-authority.sync.example.com")
142+
.build();
143+
144+
// Call method under test
145+
ManagedChannel channel = ChannelBuilder.nettyChannel(options);
146+
147+
// Assertions
148+
assertThat(channel).isEqualTo(mockChannel);
149+
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
150+
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
151+
verify(mockBuilder).sslContext(any());
152+
verify(mockBuilder).overrideAuthority("test-authority.sync.example.com");
153+
verify(mockBuilder).build();
154+
}
155+
}
156+
157+
@Test
158+
void testNettyChannel_withClientInterceptors() {
159+
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
160+
// Mocks
161+
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
162+
ManagedChannel mockChannel = mock(ManagedChannel.class);
163+
nettyMock
164+
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
165+
.thenReturn(mockBuilder);
166+
167+
when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
168+
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
169+
when(mockBuilder.intercept(anyList())).thenReturn(mockBuilder);
170+
when(mockBuilder.build()).thenReturn(mockChannel);
171+
172+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();
173+
clientInterceptors.add(mock(ClientInterceptor.class));
174+
175+
// Input options
176+
FlagdOptions options = FlagdOptions.builder()
177+
.host("localhost")
178+
.port(8080)
179+
.keepAlive(5000)
180+
.tls(true)
181+
.clientInterceptors(clientInterceptors)
182+
.build();
183+
184+
// Call method under test
185+
ManagedChannel channel = ChannelBuilder.nettyChannel(options);
186+
187+
// Assertions
188+
assertThat(channel).isEqualTo(mockChannel);
189+
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
190+
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
191+
verify(mockBuilder).sslContext(any());
192+
verify(mockBuilder).intercept(clientInterceptors);
193+
verify(mockBuilder).build();
194+
}
195+
}
196+
116197
@ParameterizedTest
117198
@ValueSource(strings = {"/incorrect/{uri}/;)"})
118199
void testNettyChannel_withInvalidTargetUri(String uri) {

0 commit comments

Comments
 (0)