JpaOAuth2AuthorizationService.java
/*
* Copyright 2022 Global Crop Diversity Trust
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.genesys.spring.security.service;
import static org.apache.commons.lang3.StringUtils.defaultIfBlank;
import java.io.Serializable;
import java.security.Principal;
import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import org.genesys.blocks.oauth.model.Authorization;
import org.genesys.blocks.oauth.model.OAuthRole;
import org.genesys.blocks.oauth.model.QAuthorization;
import org.genesys.blocks.oauth.persistence.AuthorizationRepository;
import org.genesys.server.model.UserRole;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cache.annotation.CacheConfig;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.util.StdDateFormat;
import com.fasterxml.jackson.datatype.hibernate5.Hibernate5Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
/**
 * {@link OAuth2AuthorizationService} backed by the JPA {@link Authorization} entity.
 * <p>
 * Access, refresh and OIDC ID tokens are not persisted by value: only the id obtained from
 * {@link JwtTokenIdExtractor} is stored, and lookups by token value go through the same extractor.
 * When an authorization is read back, the stored ids are used as the token values. Authorization
 * attributes, token metadata and ID token claims are stored as JSON.
 * <p>
 * Lookups are cached in the {@code oauthtoken} cache; every write evicts the whole cache.
 */
@Service
@CacheConfig(cacheNames = { "oauthtoken" })
public class JpaOAuth2AuthorizationService implements OAuth2AuthorizationService {
	private static final Logger LOG = LoggerFactory.getLogger(JpaOAuth2AuthorizationService.class);

	private final AuthorizationRepository authorizationRepository;
	private final RegisteredClientRepository registeredClientRepository;
	private final ObjectMapper objectMapper;
	private final JwtTokenIdExtractor jwtTokenIdExtractor;

	/**
	 * Creates the service and configures the JSON mapper used for attributes, metadata and claims.
	 *
	 * @param authorizationRepository repository of {@link Authorization} entities
	 * @param registeredClientRepository used to resolve the client of a stored authorization
	 * @param jwtTokenIdExtractor derives the stored token id from a token value
	 * @throws IllegalArgumentException if any argument is {@code null}
	 */
	public JpaOAuth2AuthorizationService(AuthorizationRepository authorizationRepository, RegisteredClientRepository registeredClientRepository, JwtTokenIdExtractor jwtTokenIdExtractor) {
		Assert.notNull(authorizationRepository, "authorizationRepository cannot be null");
		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
		Assert.notNull(jwtTokenIdExtractor, "jwtTokenIdExtractor cannot be null");
		this.authorizationRepository = authorizationRepository;
		this.registeredClientRepository = registeredClientRepository;
		this.jwtTokenIdExtractor = jwtTokenIdExtractor;

		this.objectMapper = objectMapper();
		ClassLoader classLoader = JpaOAuth2AuthorizationService.class.getClassLoader();
		var securityModules = SecurityJackson2Modules.getModules(classLoader);
		this.objectMapper.registerModules(securityModules);
		this.objectMapper.registerModule(new OAuth2AuthorizationServerJackson2Module());
		// Each type is registered as its own mix-in. Presumably this is so that Spring Security's
		// Jackson type allowlist accepts these application types inside stored attributes -- confirm.
		this.objectMapper.addMixIn(UserRole.class, UserRole.class);
		this.objectMapper.addMixIn(Long.class, Long.class);
		this.objectMapper.addMixIn(CustomPrincipal.class, CustomPrincipal.class);
		this.objectMapper.addMixIn(OAuthRole.class, OAuthRole.class);
	}

	/**
	 * Builds the base JSON mapper: Hibernate-aware (no lazy loading), java.time support, lenient
	 * deserialization, and ISO dates instead of timestamps.
	 */
	private ObjectMapper objectMapper() {
		final Hibernate5Module hibernateModule = new Hibernate5Module();
		hibernateModule.enable(Hibernate5Module.Feature.REPLACE_PERSISTENT_COLLECTIONS);
		hibernateModule.disable(Hibernate5Module.Feature.FORCE_LAZY_LOADING);
		hibernateModule.disable(Hibernate5Module.Feature.SERIALIZE_IDENTIFIER_FOR_LAZY_NOT_LOADED_OBJECTS);

		// JSR310 java.time
		var javaTimeModule = new JavaTimeModule();

		final ObjectMapper mapper = JsonMapper.builder()
			.addModule(hibernateModule)
			.addModule(javaTimeModule)
			.disable(SerializationFeature.EAGER_SERIALIZER_FETCH)
			.enable(MapperFeature.DEFAULT_VIEW_INCLUSION)
			.enable(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT)
			.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
			.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
			.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
			.defaultDateFormat(new StdDateFormat().withColonInTimeZone(true))
			.build();
		return mapper;
	}

	/**
	 * Persists the authorization, evicting all cached lookups.
	 * <p>
	 * Tokens of an authorization that was loaded through this service carry the stored token id
	 * as their value. If no id can be extracted from such a value, the id already stored for
	 * that token is kept. A newly issued token, such as a fresh access token after a refresh
	 * grant, replaces the stored id.
	 *
	 * @param authorization the authorization to save
	 * @throws IllegalArgumentException if {@code authorization} is {@code null} or cannot be serialized
	 */
	@Override
	@CacheEvict(allEntries = true)
	@Transactional
	public void save(OAuth2Authorization authorization) {
		Assert.notNull(authorization, "authorization cannot be null");
		var entityForSave = toEntity(authorization);
		var existedEntity = authorizationRepository.findById(authorization.getId()).orElse(null);
		if (existedEntity != null) {
			// Prefer the id of the token being saved; fall back to the stored id only when none was extracted
			entityForSave.setAccessTokenId(defaultIfBlank(entityForSave.getAccessTokenId(), existedEntity.getAccessTokenId()));
			entityForSave.setRefreshTokenId(defaultIfBlank(entityForSave.getRefreshTokenId(), existedEntity.getRefreshTokenId()));
			entityForSave.setOidcIdTokenId(defaultIfBlank(entityForSave.getOidcIdTokenId(), existedEntity.getOidcIdTokenId()));
		}
		this.authorizationRepository.save(entityForSave);
	}

	/**
	 * Deletes the stored authorization with the same id, evicting all cached lookups.
	 *
	 * @param authorization the authorization to remove
	 * @throws IllegalArgumentException if {@code authorization} is {@code null}
	 */
	@Override
	@CacheEvict(allEntries = true)
	@Transactional
	public void remove(OAuth2Authorization authorization) {
		Assert.notNull(authorization, "authorization cannot be null");
		this.authorizationRepository.deleteById(authorization.getId());
	}

	/**
	 * Loads an authorization by its id.
	 *
	 * @param id the authorization id
	 * @return the authorization, or {@code null} if none is stored (null results are not cached)
	 */
	@Override
	@Cacheable(unless = "#result == null")
	public OAuth2Authorization findById(String id) {
		Assert.hasText(id, "id cannot be empty");
		return this.authorizationRepository.findById(id).map(this::toObject).orElse(null);
	}

	/**
	 * Loads an authorization by a token value.
	 * <ul>
	 * <li>{@code code}: matched against the stored authorization code value.</li>
	 * <li>{@code state}: matched against the stored state.</li>
	 * <li>{@code access_token} / {@code refresh_token}: the token id is extracted and matched
	 * against that token's column only, so one token kind cannot stand in for another.</li>
	 * <li>Any other type: the token id is matched against every token-id column.</li>
	 * <li>{@code null} type: code, then state, then any token-id column.</li>
	 * </ul>
	 *
	 * @param token the token value
	 * @param tokenType the token type, or {@code null} if unknown
	 * @return the authorization, or {@code null} if none matches (null results are not cached)
	 */
	@Override
	@Cacheable(unless = "#result == null")
	public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
		Assert.hasText(token, "token cannot be empty");
		var authorization = QAuthorization.authorization;
		Optional<Authorization> result;
		if (tokenType == null) {
			result = authorizationRepository.findOne(authorization.authorizationCodeValue.eq(token));
			if (result.isEmpty()) {
				result = authorizationRepository.findOne(authorization.state.eq(token));
			}
			if (result.isEmpty()) {
				result = findByTokenId(token, null);
			}
		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
			result = authorizationRepository.findOne(authorization.authorizationCodeValue.eq(token));
		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
			result = authorizationRepository.findOne(authorization.state.eq(token));
		} else {
			result = findByTokenId(token, tokenType);
		}
		return result.map(this::toObject).orElse(null);
	}

	/**
	 * Looks up an authorization by the id extracted from {@code token}.
	 *
	 * @param token the token value
	 * @param tokenType restricts the match to the access- or refresh-token column when it is one of
	 *        those types; otherwise, including {@code null}, any token-id column may match
	 * @return the matching entity, or empty if no id could be extracted or nothing matches
	 */
	private Optional<Authorization> findByTokenId(String token, OAuth2TokenType tokenType) {
		String tokenId = jwtTokenIdExtractor.getJwtTokenId(token);
		if (tokenId == null) {
			return Optional.empty();
		}
		var authorization = QAuthorization.authorization;
		if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
			return authorizationRepository.findOne(authorization.accessTokenId.eq(tokenId));
		}
		if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
			return authorizationRepository.findOne(authorization.refreshTokenId.eq(tokenId));
		}
		return authorizationRepository.findOne(authorization.accessTokenId.eq(tokenId).or(authorization.refreshTokenId.eq(tokenId)).or(authorization.oidcIdTokenId.eq(tokenId)));
	}

	/**
	 * Converts a stored entity back into an {@link OAuth2Authorization}.
	 * <p>
	 * Access, refresh and ID tokens are rebuilt with the stored token id as their value. If the
	 * stored principal is a {@link CustomPrincipal} wrapping a {@link DefaultOidcUser}, the OIDC
	 * user is recreated with the principal's authorities.
	 *
	 * @throws DataRetrievalFailureException if the registered client no longer exists
	 * @throws AuthenticationCredentialsNotFoundException if the stored data cannot be parsed or is invalid
	 */
	private OAuth2Authorization toObject(Authorization entity) {
		RegisteredClient registeredClient = this.registeredClientRepository.findById(String.valueOf(entity.getRegisteredClientId()));
		if (registeredClient == null) {
			throw new DataRetrievalFailureException(
				"The RegisteredClient with id '" + entity.getRegisteredClientId() + "' was not found in the RegisteredClientRepository.");
		}

		try {
			OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
				.id(entity.getId())
				.principalName(entity.getPrincipalName())
				.authorizationGrantType(resolveAuthorizationGrantType(entity.getAuthorizationGrantType()))
				.attributes(attributes -> attributes.putAll(parseMap(entity.getAttributes())));
			if (entity.getState() != null) {
				builder.attribute(OAuth2ParameterNames.STATE, entity.getState());
			}

			if (entity.getAuthorizationCodeValue() != null) {
				OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
					entity.getAuthorizationCodeValue(),
					entity.getAuthorizationCodeIssuedAt(),
					entity.getAuthorizationCodeExpiresAt());
				builder.token(authorizationCode, metadata -> metadata.putAll(parseMap(entity.getAuthorizationCodeMetadata())));
			}

			if (entity.getAccessTokenId() != null) {
				// Only the token id is persisted, so it stands in for the token value
				OAuth2AccessToken accessToken = new OAuth2AccessToken(
					OAuth2AccessToken.TokenType.BEARER,
					entity.getAccessTokenId(),
					entity.getAccessTokenIssuedAt(),
					entity.getAccessTokenExpiresAt(),
					StringUtils.commaDelimitedListToSet(entity.getAccessTokenScopes()));
				builder.token(accessToken, metadata -> metadata.putAll(parseMap(entity.getAccessTokenMetadata())));
			}

			if (entity.getRefreshTokenId() != null) {
				OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
					entity.getRefreshTokenId(),
					entity.getRefreshTokenIssuedAt(),
					entity.getRefreshTokenExpiresAt());
				builder.token(refreshToken, metadata -> metadata.putAll(parseMap(entity.getRefreshTokenMetadata())));
			}

			if (entity.getOidcIdTokenId() != null) {
				OidcIdToken idToken = new OidcIdToken(
					entity.getOidcIdTokenId(),
					entity.getOidcIdTokenIssuedAt(),
					entity.getOidcIdTokenExpiresAt(),
					parseMap(entity.getOidcIdTokenClaims()));
				builder.token(idToken, metadata -> metadata.putAll(parseMap(entity.getOidcIdTokenMetadata())));
			}

			var oauth2Auth = builder.build();

			// toEntity stores the OIDC user without authorities; restore them from the wrapping principal
			var principal = oauth2Auth.getAttribute(Principal.class.getName());
			if (principal instanceof CustomPrincipal) {
				var authentication = (CustomPrincipal) principal;
				var authPrincipal = authentication.getPrincipal();
				if (authPrincipal instanceof DefaultOidcUser) {
					var oidcUser = (DefaultOidcUser) authPrincipal;
					authPrincipal = new DefaultOidcUser(new HashSet<>(authentication.getAuthorities()), oidcUser.getIdToken(), oidcUser.getUserInfo(), oidcUser.getClaimAsString("userNameAttribute"));
					authentication.setPrincipal(authPrincipal);
				}
			}
			return oauth2Auth;
		} catch (IllegalArgumentException e) {
			LOG.warn("Could not read data: {}", e.getMessage());
			throw new AuthenticationCredentialsNotFoundException("Cannot parse authorization");
		}
	}

	/**
	 * Converts an {@link OAuth2Authorization} into a new entity for persistence.
	 * <p>
	 * An {@link Authentication} stored under the {@link Principal} attribute is replaced by a
	 * {@link CustomPrincipal}. A wrapped {@link DefaultOidcUser} is stored without authorities.
	 * For access, refresh and ID tokens only the extracted token id is stored, not the value.
	 *
	 * @throws IllegalArgumentException if attributes, metadata or claims cannot be serialized
	 */
	private Authorization toEntity(OAuth2Authorization authorization) {
		Authorization entity = new Authorization();
		entity.setId(authorization.getId());
		entity.setRegisteredClientId(Long.valueOf(authorization.getRegisteredClientId()));
		entity.setPrincipalName(authorization.getPrincipalName());
		entity.setAuthorizationGrantType(authorization.getAuthorizationGrantType().getValue());

		// Copy so the authorization's own attribute map is not modified
		var attributes = authorization.getAttributes();
		Map<String, Object> customAttributes = new HashMap<>(attributes);
		var principal = customAttributes.get(Principal.class.getName());
		if (principal instanceof Authentication) {
			var authentication = (Authentication) principal;
			var authPrincipal = authentication.getPrincipal();
			if (authPrincipal instanceof DefaultOidcUser) {
				var oidcUser = (DefaultOidcUser) authPrincipal;
				authPrincipal = new DefaultOidcUser(null, oidcUser.getIdToken(), oidcUser.getUserInfo(), oidcUser.getClaimAsString("userNameAttribute"));
			}
			customAttributes.put(Principal.class.getName(), new CustomPrincipal(authPrincipal, new HashSet<>(authentication.getAuthorities()), null));
		}
		entity.setAttributes(writeMap(customAttributes));
		entity.setState(authorization.getAttribute(OAuth2ParameterNames.STATE));

		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
		if (authorizationCode != null) {
			setTokenValues(
				authorizationCode,
				entity::setAuthorizationCodeIssuedAt,
				entity::setAuthorizationCodeExpiresAt,
				entity::setAuthorizationCodeMetadata
			);
			entity.setAuthorizationCodeValue(authorizationCode.getToken().getTokenValue());
		}

		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getToken(OAuth2AccessToken.class);
		if (accessToken != null) {
			entity.setAccessTokenId(jwtTokenIdExtractor.getJwtTokenId(accessToken.getToken().getTokenValue()));
			if (accessToken.getToken().getTokenValue() != null) {
				setTokenValues(
					accessToken,
					entity::setAccessTokenIssuedAt,
					entity::setAccessTokenExpiresAt,
					entity::setAccessTokenMetadata
				);
				if (accessToken.getToken().getScopes() != null) {
					entity.setAccessTokenScopes(StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","));
				}
			}
		}

		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getToken(OAuth2RefreshToken.class);
		if (refreshToken != null) {
			entity.setRefreshTokenId(jwtTokenIdExtractor.getJwtTokenId(refreshToken.getToken().getTokenValue()));
			if (refreshToken.getToken().getTokenValue() != null) {
				setTokenValues(
					refreshToken,
					entity::setRefreshTokenIssuedAt,
					entity::setRefreshTokenExpiresAt,
					entity::setRefreshTokenMetadata
				);
			}
		}

		OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
		if (oidcIdToken != null) {
			entity.setOidcIdTokenId(jwtTokenIdExtractor.getJwtTokenId(oidcIdToken.getToken().getTokenValue()));
			if (oidcIdToken.getToken().getTokenValue() != null) {
				setTokenValues(
					oidcIdToken,
					entity::setOidcIdTokenIssuedAt,
					entity::setOidcIdTokenExpiresAt,
					entity::setOidcIdTokenMetadata
				);
				entity.setOidcIdTokenClaims(writeMap(oidcIdToken.getClaims()));
			}
		}

		return entity;
	}

	/**
	 * Authentication stored in the authorization attributes in place of the original
	 * {@link Authentication}. It holds the principal, authorities and credentials, and is
	 * always marked authenticated.
	 */
	public static class CustomPrincipal extends AbstractAuthenticationToken {
		private static final long serialVersionUID = -4183504039635400160L;

		@JsonIgnoreProperties({ "roles" })
		private Object principal;
		private Object credentials;

		/**
		 * @param sid the principal
		 * @param authorities the granted authorities
		 * @param credentials the credentials, may be {@code null}
		 */
		public CustomPrincipal(Object sid, Collection<? extends GrantedAuthority> authorities, Object credentials) {
			super(authorities);
			this.principal = sid;
			this.credentials = credentials;
			this.setAuthenticated(true);
		}

		/** No-argument constructor, presumably for JSON deserialization. */
		public CustomPrincipal() {
			super(null);
			this.setAuthenticated(true);
		}

		@Override
		public Object getCredentials() {
			return credentials;
		}

		@Override
		public Object getPrincipal() {
			return principal;
		}

		public void setPrincipal(Object principal) {
			this.principal = principal;
		}
	}

	/**
	 * Passes a token's issued-at, expires-at and serialized metadata to the given setters.
	 * Does nothing when {@code token} is {@code null}.
	 */
	private void setTokenValues(
		OAuth2Authorization.Token<?> token,
		Consumer<Instant> issuedAtConsumer,
		Consumer<Instant> expiresAtConsumer,
		Consumer<String> metadataConsumer) {
		if (token != null) {
			OAuth2Token oAuth2Token = token.getToken();
			issuedAtConsumer.accept(oAuth2Token.getIssuedAt());
			expiresAtConsumer.accept(oAuth2Token.getExpiresAt());
			metadataConsumer.accept(writeMap(token.getMetadata()));
		}
	}

	/**
	 * Parses a JSON object into a map.
	 *
	 * @throws IllegalArgumentException wrapping any parse failure (including {@code null} input)
	 */
	private Map<String, Object> parseMap(String data) throws IllegalArgumentException {
		try {
			return this.objectMapper.readValue(data, new TypeReference<>() {
			});
		} catch (Exception ex) {
			LOG.error("Exception in deserializing map: {}", ex.getMessage());
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}

	/**
	 * Serializes a map to JSON.
	 *
	 * @throws IllegalArgumentException wrapping any serialization failure
	 */
	private String writeMap(Map<String, Object> metadata) {
		try {
			return this.objectMapper.writeValueAsString(metadata);
		} catch (Exception ex) {
			LOG.error("Exception in serializing map", ex);
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}

	/**
	 * Maps a stored grant type string to the shared {@link AuthorizationGrantType} constant when
	 * one matches, otherwise wraps it in a new instance.
	 */
	private static AuthorizationGrantType resolveAuthorizationGrantType(String authorizationGrantType) {
		if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.AUTHORIZATION_CODE;
		} else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.CLIENT_CREDENTIALS;
		} else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.REFRESH_TOKEN;
		}
		return new AuthorizationGrantType(authorizationGrantType); // Custom authorization grant type
	}
}