JpaOAuth2AuthorizationService.java

/*
 * Copyright 2022 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.spring.security.service;

import static org.apache.commons.lang3.StringUtils.defaultIfBlank;

import java.io.IOException;
import java.security.Principal;
import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.util.StdDateFormat;
import com.fasterxml.jackson.datatype.hibernate5.Hibernate5Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.genesys.blocks.oauth.model.Authorization;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.genesys.blocks.oauth.model.OAuthRole;
import org.genesys.blocks.oauth.model.QAuthorization;
import org.genesys.blocks.oauth.persistence.AuthorizationRepository;
import org.genesys.blocks.oauth.service.OAuthClientService;
import org.genesys.server.model.UserRole;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cache.annotation.CacheConfig;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * JPA-backed {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}s as
 * {@link Authorization} entities.
 *
 * <p>Access, refresh and OIDC ID tokens are not persisted verbatim. Only the id returned by
 * {@link JwtTokenIdExtractor} is stored, together with the issue/expiry instants and JSON-serialized
 * metadata. Authorizations read back from the database therefore carry tokens whose "value" is that
 * stored id, not the original token string. Authorization codes are stored by value.
 *
 * <p>Lookups are cached in the {@code oauthtoken} cache. Every write evicts the whole cache.
 */
@Service
@CacheConfig(cacheNames = { "oauthtoken" })
public class JpaOAuth2AuthorizationService implements OAuth2AuthorizationService {

	private static final Logger LOG = LoggerFactory.getLogger(JpaOAuth2AuthorizationService.class);

	private final AuthorizationRepository authorizationRepository;
	private final RegisteredClientRepository registeredClientRepository;
	private final ObjectMapper objectMapper;
	private final JwtTokenIdExtractor jwtTokenIdExtractor;

	/**
	 * Creates the service and configures the JSON mapper used for attributes, metadata and claims.
	 *
	 * @param authorizationRepository repository of {@link Authorization} entities; must not be null
	 * @param registeredClientRepository resolves the client of a stored authorization; must not be null
	 * @param jwtTokenIdExtractor derives the persisted id from a token value; must not be null
	 * @param oAuthClientService loads serialized {@link OAuthClient} references by client id
	 */
	public JpaOAuth2AuthorizationService(AuthorizationRepository authorizationRepository, RegisteredClientRepository registeredClientRepository, JwtTokenIdExtractor jwtTokenIdExtractor, OAuthClientService oAuthClientService) {
		Assert.notNull(authorizationRepository, "authorizationRepository cannot be null");
		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
		Assert.notNull(jwtTokenIdExtractor, "jwtTokenIdExtractor cannot be null");
		this.authorizationRepository = authorizationRepository;
		this.registeredClientRepository = registeredClientRepository;
		this.jwtTokenIdExtractor = jwtTokenIdExtractor;

		this.objectMapper = objectMapper();
		ClassLoader classLoader = JpaOAuth2AuthorizationService.class.getClassLoader();
		this.objectMapper.registerModules(SecurityJackson2Modules.getModules(classLoader));
		this.objectMapper.registerModule(new OAuth2AuthorizationServerJackson2Module());
		// NOTE(review): the self mix-ins below presumably exist so that the allowlisting type resolver
		// installed by SecurityJackson2Modules accepts these classes during deserialization -- confirm.
		this.objectMapper.addMixIn(UserRole.class, UserRole.class);
		this.objectMapper.addMixIn(Long.class, Long.class);
		this.objectMapper.addMixIn(CustomPrincipal.class, CustomPrincipal.class);
		this.objectMapper.addMixIn(OAuthRole.class, OAuthRole.class);

		// OAuthClient values are re-loaded through OAuthClientService by their clientId.
		SimpleModule module = new SimpleModule();
		module.addDeserializer(OAuthClient.class, new OAuthClientDeserializer(oAuthClientService));
		this.objectMapper.registerModule(module);
	}

	/**
	 * Builds the base {@link ObjectMapper} (Hibernate 5 and java.time support, lenient reading,
	 * ISO-8601 dates). The Spring Security Jackson modules are registered on top of it by the constructor.
	 */
	private static ObjectMapper objectMapper() {

		final Hibernate5Module hibernateModule = new Hibernate5Module();
		hibernateModule.enable(Hibernate5Module.Feature.REPLACE_PERSISTENT_COLLECTIONS);
		hibernateModule.disable(Hibernate5Module.Feature.FORCE_LAZY_LOADING);
		hibernateModule.disable(Hibernate5Module.Feature.SERIALIZE_IDENTIFIER_FOR_LAZY_NOT_LOADED_OBJECTS);

		// JSR310 java.time
		var javaTimeModule = new JavaTimeModule();

		return JsonMapper.builder()
			.addModule(hibernateModule)
			.addModule(javaTimeModule)
			.disable(SerializationFeature.EAGER_SERIALIZER_FETCH)
			.enable(MapperFeature.DEFAULT_VIEW_INCLUSION)
			.enable(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT)
			.enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
			.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
			// Serialization inclusion is intentionally left at the default: include all properties.
			.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
			.defaultDateFormat(new StdDateFormat().withColonInTimeZone(true))
			.build();
	}

	/**
	 * Deserializes an {@link OAuthClient} by reading its {@code clientId} property and loading the
	 * client through {@link OAuthClientService}. Yields {@code null} when there is no textual
	 * {@code clientId}.
	 */
	private static class OAuthClientDeserializer extends JsonDeserializer<OAuthClient> {

		private final OAuthClientService oAuthClientService;

		public OAuthClientDeserializer(OAuthClientService oAuthClientService) {
			this.oAuthClientService = oAuthClientService;
		}

		@Override
		public OAuthClient deserialize(JsonParser parser, DeserializationContext ctxt) throws IOException {
			JsonNode rootNode = parser.getCodec().readTree(parser);
			// JsonNode.get(...) returns null for a missing property, so guard before isTextual().
			JsonNode idNode = rootNode == null ? null : rootNode.get("clientId");
			if (idNode != null && idNode.isTextual()) {
				return oAuthClientService.loadClientByClientId(idNode.asText());
			}
			return null;
		}
	}

	/**
	 * Persists the authorization, overwriting an existing entity with the same id.
	 *
	 * @param authorization the authorization to store; must not be null
	 */
	@Override
	@CacheEvict(allEntries = true)
	@Transactional
	public void save(OAuth2Authorization authorization) {
		Assert.notNull(authorization, "authorization cannot be null");
		var entityForSave = toEntity(authorization);
		var existedEntity = authorizationRepository.findById(authorization.getId()).orElse(null);
		if (existedEntity != null) {
			// A previously stored token id takes precedence over the id extracted from the token being
			// saved. Authorizations loaded by toObject carry the stored id as the token value, so
			// re-extracting from it is not reliable.
			// NOTE(review): this also means a newly issued token with the same type (e.g. after a
			// refresh_token grant) keeps the old id in the database -- confirm that is intended.
			entityForSave.setAccessTokenId(defaultIfBlank(existedEntity.getAccessTokenId(), entityForSave.getAccessTokenId()));
			entityForSave.setRefreshTokenId(defaultIfBlank(existedEntity.getRefreshTokenId(), entityForSave.getRefreshTokenId()));
			entityForSave.setOidcIdTokenId(defaultIfBlank(existedEntity.getOidcIdTokenId(), entityForSave.getOidcIdTokenId()));
		}
		this.authorizationRepository.save(entityForSave);
	}

	/**
	 * Deletes the stored entity with the authorization's id.
	 *
	 * @param authorization the authorization to remove; must not be null
	 */
	@Override
	@CacheEvict(allEntries = true)
	@Transactional
	public void remove(OAuth2Authorization authorization) {
		Assert.notNull(authorization, "authorization cannot be null");
		this.authorizationRepository.deleteById(authorization.getId());
	}

	/**
	 * Loads an authorization by its id.
	 *
	 * @param id the authorization id; must not be empty
	 * @return the authorization, or {@code null} if none is stored (null results are not cached)
	 */
	@Override
	@Cacheable(unless = "#result == null")
	public OAuth2Authorization findById(String id) {
		Assert.hasText(id, "id cannot be empty");
		return this.authorizationRepository.findById(id).map(this::toObject).orElse(null);
	}

	/**
	 * Finds an authorization by a token value.
	 *
	 * <p>{@code code} tokens are matched against the stored authorization code value. Other tokens
	 * are first reduced to their id by {@link JwtTokenIdExtractor}. That id is then matched only
	 * against the column for the requested type (access, refresh or OIDC ID token), so a token of one
	 * type is not accepted in place of another. With a {@code null} type the code column is tried
	 * first, then all token-id columns. An unrecognised type searches all token-id columns.
	 *
	 * @param token the token value; must not be empty
	 * @param tokenType the expected token type, or {@code null} if unknown
	 * @return the matching authorization, or {@code null} (null results are not cached)
	 */
	@Override
	@Cacheable(unless = "#result == null")
	public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
		Assert.hasText(token, "token cannot be empty");

		var authorization = QAuthorization.authorization;
		Optional<Authorization> result;
		if (tokenType == null) {
			result = authorizationRepository.findOne(authorization.authorizationCodeValue.eq(token));
			if (result.isEmpty()) {
				result = findByTokenId(token, null);
			}
		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
			result = authorizationRepository.findOne(authorization.authorizationCodeValue.eq(token));
		} else {
			result = findByTokenId(token, tokenType);
		}

		return result.map(this::toObject).orElse(null);
	}

	/**
	 * Looks up an authorization by the id extracted from {@code token}, restricted to the token-id
	 * column that corresponds to {@code tokenType}. For a {@code null} or unrecognised type, all three
	 * token-id columns are searched.
	 */
	private Optional<Authorization> findByTokenId(String token, OAuth2TokenType tokenType) {
		String tokenId = jwtTokenIdExtractor.getJwtTokenId(token);
		if (tokenId == null) {
			return Optional.empty();
		}
		var authorization = QAuthorization.authorization;
		String type = tokenType == null ? null : tokenType.getValue();
		if (OAuth2TokenType.ACCESS_TOKEN.getValue().equals(type)) {
			return authorizationRepository.findOne(authorization.accessTokenId.eq(tokenId));
		}
		if (OAuth2TokenType.REFRESH_TOKEN.getValue().equals(type)) {
			return authorizationRepository.findOne(authorization.refreshTokenId.eq(tokenId));
		}
		if (OidcParameterNames.ID_TOKEN.equals(type)) {
			return authorizationRepository.findOne(authorization.oidcIdTokenId.eq(tokenId));
		}
		return authorizationRepository.findOne(authorization.accessTokenId.eq(tokenId).or(authorization.refreshTokenId.eq(tokenId)).or(authorization.oidcIdTokenId.eq(tokenId)));
	}

	/**
	 * Rebuilds an {@link OAuth2Authorization} from its persisted entity.
	 *
	 * <p>Access, refresh and OIDC ID tokens are re-created with the stored token id as their value.
	 * If the {@link Principal} attribute is a {@link CustomPrincipal} wrapping a
	 * {@link DefaultOidcUser}, that user is re-created with the principal's authorities, because
	 * {@link #toEntity} stores it without them.
	 *
	 * @throws DataRetrievalFailureException if the registered client can no longer be found
	 * @throws AuthenticationCredentialsNotFoundException if the stored data cannot be parsed
	 */
	private OAuth2Authorization toObject(Authorization entity) {
		RegisteredClient registeredClient = this.registeredClientRepository.findById(String.valueOf(entity.getRegisteredClientId()));
		if (registeredClient == null) {
			throw new DataRetrievalFailureException(
				"The RegisteredClient with id '" + entity.getRegisteredClientId() + "' was not found in the RegisteredClientRepository.");
		}

		try {
			OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
				.id(entity.getId())
				.principalName(entity.getPrincipalName())
				.authorizationGrantType(resolveAuthorizationGrantType(entity.getAuthorizationGrantType()))
				.attributes(attributes -> attributes.putAll(parseMap(entity.getAttributes())));
			if (entity.getState() != null) {
				builder.attribute(OAuth2ParameterNames.STATE, entity.getState());
			}

			if (entity.getAuthorizationCodeValue() != null) {
				OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
					entity.getAuthorizationCodeValue(),
					entity.getAuthorizationCodeIssuedAt(),
					entity.getAuthorizationCodeExpiresAt());
				builder.token(authorizationCode, metadata -> metadata.putAll(parseMap(entity.getAuthorizationCodeMetadata())));
			}

			if (entity.getAccessTokenId() != null) {
				// The stored id stands in for the token value.
				OAuth2AccessToken accessToken = new OAuth2AccessToken(
					OAuth2AccessToken.TokenType.BEARER,
					entity.getAccessTokenId(),
					entity.getAccessTokenIssuedAt(),
					entity.getAccessTokenExpiresAt(),
					StringUtils.commaDelimitedListToSet(entity.getAccessTokenScopes()));
				builder.token(accessToken, metadata -> metadata.putAll(parseMap(entity.getAccessTokenMetadata())));
			}

			if (entity.getRefreshTokenId() != null) {
				OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
					entity.getRefreshTokenId(),
					entity.getRefreshTokenIssuedAt(),
					entity.getRefreshTokenExpiresAt());
				builder.token(refreshToken, metadata -> metadata.putAll(parseMap(entity.getRefreshTokenMetadata())));
			}

			if (entity.getOidcIdTokenId() != null) {
				OidcIdToken idToken = new OidcIdToken(
					entity.getOidcIdTokenId(),
					entity.getOidcIdTokenIssuedAt(),
					entity.getOidcIdTokenExpiresAt(),
					parseMap(entity.getOidcIdTokenClaims()));
				builder.token(idToken, metadata -> metadata.putAll(parseMap(entity.getOidcIdTokenMetadata())));
			}

			var oauth2Auth = builder.build();
			var principal = oauth2Auth.getAttribute(Principal.class.getName());
			if (principal instanceof CustomPrincipal) {
				var authentication = (CustomPrincipal) principal;
				var authPrincipal = authentication.getPrincipal();
				if (authPrincipal instanceof DefaultOidcUser) {
					// Restore the authorities that toEntity stripped from the OIDC user.
					var oidcUser = (DefaultOidcUser) authPrincipal;
					authPrincipal = new DefaultOidcUser(new HashSet<>(authentication.getAuthorities()), oidcUser.getIdToken(), oidcUser.getUserInfo(), oidcUser.getClaimAsString("userNameAttribute"));
					authentication.setPrincipal(authPrincipal);
				}
			}
			return oauth2Auth;
		} catch (IllegalArgumentException e) {
			LOG.warn("Could not read data: {}", e.getMessage());
			// Keep the parse failure as the cause so it is not lost for diagnostics.
			throw new AuthenticationCredentialsNotFoundException("Cannot parse authorization", e);
		}
	}

	/**
	 * Converts an {@link OAuth2Authorization} into a new {@link Authorization} entity.
	 *
	 * <p>An {@link Authentication} stored under the {@link Principal} attribute is replaced by a
	 * {@link CustomPrincipal}. A {@link DefaultOidcUser} principal is copied without authorities.
	 * For access, refresh and OIDC ID tokens only the id from {@link JwtTokenIdExtractor} is kept,
	 * not the token value.
	 */
	private Authorization toEntity(OAuth2Authorization authorization) {
		Authorization entity = new Authorization();
		entity.setId(authorization.getId());
		entity.setRegisteredClientId(Long.valueOf(authorization.getRegisteredClientId()));
		entity.setPrincipalName(authorization.getPrincipalName());
		entity.setAuthorizationGrantType(authorization.getAuthorizationGrantType().getValue());

		Map<String, Object> customAttributes = new HashMap<>(authorization.getAttributes());
		var principal = customAttributes.get(Principal.class.getName());
		if (principal instanceof Authentication) {
			var authentication = (Authentication) principal;
			var authPrincipal = authentication.getPrincipal();
			if (authPrincipal instanceof DefaultOidcUser) {
				var oidcUser = (DefaultOidcUser) authPrincipal;
				authPrincipal = new DefaultOidcUser(null, oidcUser.getIdToken(), oidcUser.getUserInfo(), oidcUser.getClaimAsString("userNameAttribute"));
			}
			customAttributes.put(Principal.class.getName(), new CustomPrincipal(authPrincipal, new HashSet<>(authentication.getAuthorities()), null));
		}

		entity.setAttributes(writeMap(customAttributes));
		entity.setState(authorization.getAttribute(OAuth2ParameterNames.STATE));

		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
		if (authorizationCode != null) {
			setTokenValues(
				authorizationCode,
				entity::setAuthorizationCodeIssuedAt,
				entity::setAuthorizationCodeExpiresAt,
				entity::setAuthorizationCodeMetadata
			);
			entity.setAuthorizationCodeValue(authorizationCode.getToken().getTokenValue());
		}

		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getToken(OAuth2AccessToken.class);
		if (accessToken != null) {
			entity.setAccessTokenId(jwtTokenIdExtractor.getJwtTokenId(accessToken.getToken().getTokenValue()));
			if (accessToken.getToken().getTokenValue() != null) {
				setTokenValues(
					accessToken,
					entity::setAccessTokenIssuedAt,
					entity::setAccessTokenExpiresAt,
					entity::setAccessTokenMetadata
				);
				if (accessToken.getToken().getScopes() != null) {
					entity.setAccessTokenScopes(StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","));
				}
			}
		}

		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getToken(OAuth2RefreshToken.class);
		if (refreshToken != null) {
			entity.setRefreshTokenId(jwtTokenIdExtractor.getJwtTokenId(refreshToken.getToken().getTokenValue()));
			if (refreshToken.getToken().getTokenValue() != null) {
				setTokenValues(
					refreshToken,
					entity::setRefreshTokenIssuedAt,
					entity::setRefreshTokenExpiresAt,
					entity::setRefreshTokenMetadata
				);
			}
		}

		OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
		if (oidcIdToken != null) {
			entity.setOidcIdTokenId(jwtTokenIdExtractor.getJwtTokenId(oidcIdToken.getToken().getTokenValue()));
			if (oidcIdToken.getToken().getTokenValue() != null) {
				setTokenValues(
					oidcIdToken,
					entity::setOidcIdTokenIssuedAt,
					entity::setOidcIdTokenExpiresAt,
					entity::setOidcIdTokenMetadata
				);
				entity.setOidcIdTokenClaims(writeMap(oidcIdToken.getClaims()));
			}
		}

		return entity;
	}

	/**
	 * Authenticated token used to persist the {@link Principal} attribute of an authorization in JSON.
	 * It holds an arbitrary principal object, its authorities and optional credentials, and is
	 * always marked authenticated.
	 */
	public static class CustomPrincipal extends AbstractAuthenticationToken {
		private static final long serialVersionUID = -4183504039635400160L;

		@JsonIgnoreProperties({ "roles" })
		private Object principal;
		private Object credentials;

		public CustomPrincipal(Object sid, Collection<? extends GrantedAuthority> authorities, Object credentials) {
			super(authorities);
			this.principal = sid;
			this.credentials = credentials;
			this.setAuthenticated(true);
		}

		/** No-arg constructor with no authorities, used when deserializing. */
		public CustomPrincipal() {
			super(null);
			this.setAuthenticated(true);
		}

		@Override
		public Object getCredentials() {
			return credentials;
		}

		@Override
		public Object getPrincipal() {
			return principal;
		}

		public void setPrincipal(Object principal) {
			this.principal = principal;
		}
	}

	/**
	 * Passes a token's issue instant, expiry instant and JSON-serialized metadata to the given
	 * setters. Does nothing when {@code token} is null.
	 */
	private void setTokenValues(
		OAuth2Authorization.Token<?> token,
		Consumer<Instant> issuedAtConsumer,
		Consumer<Instant> expiresAtConsumer,
		Consumer<String> metadataConsumer) {
		if (token != null) {
			OAuth2Token oAuth2Token = token.getToken();
			issuedAtConsumer.accept(oAuth2Token.getIssuedAt());
			expiresAtConsumer.accept(oAuth2Token.getExpiresAt());
			metadataConsumer.accept(writeMap(token.getMetadata()));
		}
	}

	/**
	 * Parses a JSON object into a map.
	 *
	 * @throws IllegalArgumentException if the data cannot be deserialized (the cause is preserved)
	 */
	private Map<String, Object> parseMap(String data) throws IllegalArgumentException {
		try {
			return this.objectMapper.readValue(data, new TypeReference<>() {
			});
		} catch (Exception ex) {
			LOG.error("Exception in deserializing map: {}", ex.getMessage());
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}

	/**
	 * Serializes a map to JSON.
	 *
	 * @throws IllegalArgumentException if the map cannot be serialized (the cause is preserved)
	 */
	private String writeMap(Map<String, Object> metadata) {
		try {
			return this.objectMapper.writeValueAsString(metadata);
		} catch (Exception ex) {
			LOG.error("Exception in serializing map", ex);
			throw new IllegalArgumentException(ex.getMessage(), ex);
		}
	}

	/**
	 * Maps a stored grant-type string to the matching {@link AuthorizationGrantType} constant.
	 * Unrecognised values are wrapped in a new instance.
	 */
	private static AuthorizationGrantType resolveAuthorizationGrantType(String authorizationGrantType) {
		if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.AUTHORIZATION_CODE;
		} else if (AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.CLIENT_CREDENTIALS;
		} else if (AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(authorizationGrantType)) {
			return AuthorizationGrantType.REFRESH_TOKEN;
		}
		return new AuthorizationGrantType(authorizationGrantType);
	}
}