Skip to content

Decouple SAML 2.0 Single Logout from the authenticated principal's type #11338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.springframework.security.config.annotation.web.configurers.LogoutConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutRequestValidator;
import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml4LogoutResponseValidator;
import org.springframework.security.saml2.provider.service.authentication.logout.OpenSaml5LogoutRequestValidator;
Expand Down Expand Up @@ -531,10 +531,7 @@ private static class Saml2RequestMatcher implements RequestMatcher {
@Override
public boolean matches(HttpServletRequest request) {
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) {
return false;
}
return authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal;
return Saml2AuthenticationInfo.fromAuthentication(authentication) != null;
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter;
import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseFilter;
Expand Down Expand Up @@ -236,10 +236,7 @@ public static class Saml2RequestMatcher implements RequestMatcher {
@Override
public boolean matches(HttpServletRequest request) {
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) {
return false;
}
return authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal;
return Saml2AuthenticationInfo.fromAuthentication(authentication) != null;
}

public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
* @author Clement Stoquart
* @since 5.2.2
*/
public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal, Saml2AuthenticationInfo {

/**
* Get the first value of Saml2 token attribute by name
Expand Down Expand Up @@ -72,10 +72,17 @@ default Map<String, List<Object>> getAttributes() {
* @return the {@link RelyingPartyRegistration} identifier
* @since 5.6
*/
@Override
default String getRelyingPartyRegistrationId() {
return null;
}

@Override
default String getNameId() {
return getName();
}

@Override
default List<String> getSessionIndexes() {
return Collections.emptyList();
}
Expand Down
Copy link
Contributor

@jzheaux jzheaux Mar 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally clear on why this new interface is needed. Maybe the code instead should check for authentication as well as authentication.getPrincipal being of type Saml2AuthenticatedPrincipal. The code could change to:

if (authentication instanceof Saml2AuthenticatedPrincipal principal) {
    return principal.getRelyingPartyRegistrationId();
}
if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
  return principal.getRelyingPartyRegistrationId();
}

It seems like this would achieve what you are wanting without adding a new interface to support. This is nice since there is such a big overlap between Saml2AuthenticatedPrincipal and Saml2AuthenticationInfo's methods.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because an authentication token is not a principal, while both a token and a principal can be sources of SAML information.

You don't want to have Saml2Authentication implements Saml2AuthenticatedPrincipal do you?

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.authentication;

import java.util.List;

import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.SessionIndex;

import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;

/**
* Additional SAML 2.0 authentication information
*
* <p>
* SAML 2.0 Single Logout requires that the {@link Authentication#getPrincipal()
* authenticated principal} or the {@link Authentication} itself implements this
* interface.
*
* @author Christian Schuster
*/
public interface Saml2AuthenticationInfo {

/**
* Get the {@link RelyingPartyRegistration} identifier
* @return the {@link RelyingPartyRegistration} identifier
*/
String getRelyingPartyRegistrationId();

/**
* Get the {@link NameID} value of the authenticated principal
* @return the {@link NameID} value of the authenticated principal
*/
String getNameId();

/**
* Get the {@link SessionIndex} values of the authenticated principal
* @return the {@link SessionIndex} values of the authenticated principal
*/
List<String> getSessionIndexes();

/**
* Try to obtain a {@link Saml2AuthenticationInfo} instance from an
* {@link Authentication}
*
* <p>
* The result is either the {@link Authentication#getPrincipal() authenticated
* principal}, the {@link Authentication} itself, or {@code null}.
*
* <p>
* Returning {@code null} indicates that the given {@link Authentication} does not
* represent a SAML 2.0 authentication.
* @param authentication the {@link Authentication}
* @return the {@link Saml2AuthenticationInfo} or {@code null} if unavailable
*/
static Saml2AuthenticationInfo fromAuthentication(Authentication authentication) {
if (authentication == null) {
return null;
}
Object principal = authentication.getPrincipal();
if (principal instanceof Saml2AuthenticationInfo) {
return (Saml2AuthenticationInfo) principal;
}
if (authentication instanceof Saml2AuthenticationInfo) {
return (Saml2AuthenticationInfo) authentication;
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
Expand Down Expand Up @@ -147,16 +147,19 @@ public Saml2LogoutRequest resolve(HttpServletRequest request, Authentication aut
issuer.setValue(entityId);
logoutRequest.setIssuer(issuer);
NameID nameId = this.nameIdBuilder.buildObject();
nameId.setValue(authentication.getName());
logoutRequest.setNameID(nameId);
if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal) {
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
for (String index : principal.getSessionIndexes()) {
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
nameId.setValue(info.getNameId());
for (String index : info.getSessionIndexes()) {
SessionIndex sessionIndex = this.sessionIndexBuilder.buildObject();
sessionIndex.setValue(index);
logoutRequest.getSessionIndexes().add(sessionIndex);
}
}
else {
nameId.setValue(authentication.getName());
}
logoutRequest.setIssueInstant(Instant.now(this.clock));
this.parametersConsumer
.accept(new LogoutRequestParameters(request, registration, authentication, logoutRequest));
Expand Down Expand Up @@ -191,12 +194,9 @@ private String getRegistrationId(Authentication authentication) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve registrationId from " + authentication);
}
if (authentication == null) {
return null;
}
Object principal = authentication.getPrincipal();
if (principal instanceof Saml2AuthenticatedPrincipal) {
return ((Saml2AuthenticatedPrincipal) principal).getRelyingPartyRegistrationId();
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
return info.getRelyingPartyRegistrationId();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
Expand Down Expand Up @@ -130,11 +130,9 @@ private String getRegistrationId(RequestMatcher.MatchResult result, Authenticati
if (registrationId != null) {
return registrationId;
}
if (authentication == null) {
return null;
}
if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
return principal.getRelyingPartyRegistrationId();
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
return info.getRelyingPartyRegistrationId();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutResponse;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
Expand Down Expand Up @@ -204,12 +204,9 @@ private String getRegistrationId(Authentication authentication) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve registrationId from " + authentication);
}
if (authentication == null) {
return null;
}
Object principal = authentication.getPrincipal();
if (principal instanceof Saml2AuthenticatedPrincipal) {
return ((Saml2AuthenticatedPrincipal) principal).getRelyingPartyRegistrationId();
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
return info.getRelyingPartyRegistrationId();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
Expand Down Expand Up @@ -303,11 +303,9 @@ private String getRegistrationId(RequestMatcher.MatchResult result, Authenticati
if (registrationId != null) {
return registrationId;
}
if (authentication == null) {
return null;
}
if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
return principal.getRelyingPartyRegistrationId();
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
return info.getRelyingPartyRegistrationId();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationInfo;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
Expand Down Expand Up @@ -144,11 +144,9 @@ private String getRegistrationId(RequestMatcher.MatchResult result, Authenticati
if (registrationId != null) {
return registrationId;
}
if (authentication == null) {
return null;
}
if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
return principal.getRelyingPartyRegistrationId();
Saml2AuthenticationInfo info = Saml2AuthenticationInfo.fromAuthentication(authentication);
if (info != null) {
return info.getRelyingPartyRegistrationId();
}
return null;
}
Expand Down