Skip to content

feat(flagd): Add features to customize auth to Sync API server (authorityOverride and clientInterceptors) #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions providers/flagd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Given below are the supported configurations:
| port | FLAGD_PORT | int | 8013 | rpc & in-process |
| targetUri | FLAGD_TARGET_URI | string | null | rpc & in-process |
| tls | FLAGD_TLS | boolean | false | rpc & in-process |
| defaultAuthority | FLAGD_DEFAULT_AUTHORITY | String | null | rpc & in-process |
| socketPath | FLAGD_SOCKET_PATH | String | null | rpc & in-process |
| certPath | FLAGD_SERVER_CERT_PATH | String | null | rpc & in-process |
| deadline | FLAGD_DEADLINE_MS | int | 500 | rpc & in-process & file |
Expand Down Expand Up @@ -180,6 +181,50 @@ FlagdProvider flagdProvider = new FlagdProvider(
> There's a [vulnerability](https://security.snyk.io/vuln/SNYK-JAVA-IONETTY-1042268) in [netty](https://github.com/netty/netty), a transitive dependency of the underlying gRPC libraries used in the flagd-provider that fails to correctly validate certificates.
> This will be addressed in netty v5.

### Configuring gRPC credentials and headers

The `clientInterceptors` and `defaultAuthority` are meant for connection of the in-process resolver to a Sync API implementation on a host/port, that might require special credentials or headers.

```java
private static ClientInterceptor createHeaderInterceptor() {
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
headers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "header-value");
super.start(responseListener, headers);
}
};
}
};
}

private static ClientInterceptor createCallCrednetialsInterceptor(CallCredentials callCredentials) throws IOException {
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return next.newCall(method, callOptions.withCallCredentials(callCredentials));
}
};
}

List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>(2);
clientInterceptors.add(createHeaderInterceptor());
CallCredentials myCallCredentals = ...;
clientInterceptors.add(createCallCrednetialsInterceptor(myCallCredentials));

FlagdProvider flagdProvider = new FlagdProvider(
FlagdOptions.builder()
.host("example.com/flagdSyncApi")
.port(443)
.tls(true)
.defaultAuthority("authority-host.sync.example.com")
.clientInterceptors(clientInterceptors)
.build());
```

### Caching (RPC only)

> [!NOTE]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public final class Config {
static final String HOST_ENV_VAR_NAME = "FLAGD_HOST";
static final String PORT_ENV_VAR_NAME = "FLAGD_PORT";
static final String TLS_ENV_VAR_NAME = "FLAGD_TLS";
static final String DEFAULT_AUTHORITY_ENV_VAR_NAME = "FLAGD_DEFAULT_AUTHORITY";
static final String SOCKET_PATH_ENV_VAR_NAME = "FLAGD_SOCKET_PATH";
static final String SERVER_CERT_PATH_ENV_VAR_NAME = "FLAGD_SERVER_CERT_PATH";
static final String CACHE_ENV_VAR_NAME = "FLAGD_CACHE";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import dev.openfeature.sdk.EvaluationContext;
import dev.openfeature.sdk.ImmutableContext;
import dev.openfeature.sdk.Structure;
import io.grpc.ClientInterceptor;
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.OpenTelemetry;
import java.util.List;
import java.util.function.Function;
import lombok.Builder;
import lombok.Getter;
Expand Down Expand Up @@ -164,6 +166,18 @@ public class FlagdOptions {
*/
private OpenTelemetry openTelemetry;

/**
* gRPC client interceptors to be used when creating a gRPC channel.
*/
@Builder.Default
private List<ClientInterceptor> clientInterceptors = null;

/**
* Authority header to be used when creating a gRPC channel.
*/
@Builder.Default
private String defaultAuthority = fallBackToEnvOrDefault(Config.DEFAULT_AUTHORITY_ENV_VAR_NAME, null);

/**
* Builder overwrite in order to customize the "build" method.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) {
final NettyChannelBuilder builder =
NettyChannelBuilder.forTarget(targetUri).keepAliveTime(keepAliveMs, TimeUnit.MILLISECONDS);

if (options.getDefaultAuthority() != null) {
builder.overrideAuthority(options.getDefaultAuthority());
}
if (options.getClientInterceptors() != null) {
builder.intercept(options.getClientInterceptors());
}
if (options.isTls()) {
SslContextBuilder sslContext = GrpcSslContexts.forClient();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector;
import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.Connector;
import io.grpc.ClientInterceptor;
import io.opentelemetry.api.OpenTelemetry;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
Expand All @@ -46,12 +49,15 @@ void TestDefaults() {
assertNull(builder.getOfflineFlagSourcePath());
assertEquals(Resolver.RPC, builder.getResolverType());
assertEquals(0, builder.getKeepAlive());
assertNull(builder.getDefaultAuthority());
assertNull(builder.getClientInterceptors());
}

@Test
void TestBuilderOptions() {
OpenTelemetry openTelemetry = Mockito.mock(OpenTelemetry.class);
Connector connector = new MockConnector(null);
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();

FlagdOptions flagdOptions = FlagdOptions.builder()
.host("https://hosted-flagd")
Expand All @@ -66,6 +72,8 @@ void TestBuilderOptions() {
.resolverType(Resolver.IN_PROCESS)
.targetUri("dns:///localhost:8016")
.keepAlive(1000)
.defaultAuthority("test-authority.sync.example.com")
.clientInterceptors(clientInterceptors)
.build();

assertEquals("https://hosted-flagd", flagdOptions.getHost());
Expand All @@ -80,6 +88,8 @@ void TestBuilderOptions() {
assertEquals(Resolver.IN_PROCESS, flagdOptions.getResolverType());
assertEquals("dns:///localhost:8016", flagdOptions.getTargetUri());
assertEquals(1000, flagdOptions.getKeepAlive());
assertEquals("test-authority.sync.example.com", flagdOptions.getDefaultAuthority());
assertEquals(clientInterceptors, flagdOptions.getClientInterceptors());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyString;
Expand All @@ -11,6 +12,7 @@
import static org.mockito.Mockito.when;

import dev.openfeature.contrib.providers.flagd.FlagdOptions;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
Expand All @@ -20,6 +22,8 @@
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLKeyException;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -113,6 +117,83 @@ void testNettyChannel_withTlsAndCert() {
}
}

@Test
void testNettyChannel_withDefaultAuthority() {
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
// Mocks
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
ManagedChannel mockChannel = mock(ManagedChannel.class);
nettyMock
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
.thenReturn(mockBuilder);

when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
when(mockBuilder.overrideAuthority(anyString())).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(mockChannel);

// Input options
FlagdOptions options = FlagdOptions.builder()
.host("localhost")
.port(8080)
.keepAlive(5000)
.tls(true)
.defaultAuthority("test-authority.sync.example.com")
.build();

// Call method under test
ManagedChannel channel = ChannelBuilder.nettyChannel(options);

// Assertions
assertThat(channel).isEqualTo(mockChannel);
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
verify(mockBuilder).sslContext(any());
verify(mockBuilder).overrideAuthority("test-authority.sync.example.com");
verify(mockBuilder).build();
}
}

@Test
void testNettyChannel_withClientInterceptors() {
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
// Mocks
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
ManagedChannel mockChannel = mock(ManagedChannel.class);
nettyMock
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
.thenReturn(mockBuilder);

when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
when(mockBuilder.intercept(anyList())).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(mockChannel);

List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();
clientInterceptors.add(mock(ClientInterceptor.class));

// Input options
FlagdOptions options = FlagdOptions.builder()
.host("localhost")
.port(8080)
.keepAlive(5000)
.tls(true)
.clientInterceptors(clientInterceptors)
.build();

// Call method under test
ManagedChannel channel = ChannelBuilder.nettyChannel(options);

// Assertions
assertThat(channel).isEqualTo(mockChannel);
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
verify(mockBuilder).sslContext(any());
verify(mockBuilder).intercept(clientInterceptors);
verify(mockBuilder).build();
}
}

@ParameterizedTest
@ValueSource(strings = {"/incorrect/{uri}/;)"})
void testNettyChannel_withInvalidTargetUri(String uri) {
Expand Down
Loading