OAuthServiceImpl.java

/*
 * Copyright 2019 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.blocks.oauth.service;

import java.net.URL;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import javax.persistence.EntityNotFoundException;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.genesys.blocks.oauth.model.OAuthRole;
import org.genesys.blocks.oauth.model.QOAuthClient;
import org.genesys.blocks.oauth.persistence.OAuthClientRepository;
import org.genesys.blocks.security.service.impl.CustomAclServiceImpl;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.cache.annotation.Caching;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient.Builder;
import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat;
import org.springframework.security.oauth2.server.authorization.settings.TokenSettings;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import com.querydsl.core.types.Predicate;

import lombok.extern.slf4j.Slf4j;

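/**
 * {@link OAuthClientService} implementation backed by {@link OAuthClientRepository}.
 * It manages {@link OAuthClient} records (creation, updates, removal, secret handling)
 * and converts active clients to {@link RegisteredClient} instances.
 * Methods run in a read-only transaction unless they declare their own {@code @Transactional}.
 */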
@Service
@Transactional(readOnly = true)
@Slf4j
public class OAuthServiceImpl implements OAuthClientService, InitializingBean {

	/** Grant types a client may keep; any other grant is stripped in {@code addClient} and {@code updateClient}. */
	private static final Set<String> SUPPORTED_GRANT_TYPES = Set.of(
		AuthorizationGrantType.AUTHORIZATION_CODE.getValue(),
		AuthorizationGrantType.CLIENT_CREDENTIALS.getValue(),
		AuthorizationGrantType.REFRESH_TOKEN.getValue()
	);

	/** The site URL ({@code base.url}); may be null. Its host is the fallback client id suffix. */
	@Value("${base.url:#{null}}")
	private String baseUrl;

	/**
	 * Suffix appended after {@code @} to generated client ids ({@code oauth.clientId.suffix}).
	 * When not configured, {@code afterPropertiesSet} derives it from {@link #baseUrl} or uses "localhost".
	 */
	@Value("${oauth.clientId.suffix:#{null}}")
	private String clientIdSuffix;

	/** The oauth client repository. */
	@Autowired
	private OAuthClientRepository oauthClientRepository;

	/** The password encoder used to hash client secrets before they are stored. */
	// NOTE(review): this field is public; confirm whether anything outside the class relies on that before narrowing it.
	@Autowired
	public PasswordEncoder passwordEncoder;

	/** The cache manager; optional, so it may be null and must be null-checked before use. */
	@Autowired(required = false)
	private CacheManager cacheManager;

	/** Default access token lifetime in seconds, used when the client does not define its own. */
	@Value("${default.oauth.accessToken.validity:259200}") // 3 days = 60 * 60 * 24 * 3
	private int ttlAccessTokenSeconds;

	/** Default refresh token lifetime in seconds, used when the client does not define its own. */
	@Value("${default.oauth.refreshToken.validity:2592000}") // 30 days = 60 * 60 * 24 * 30
	private int ttlRefreshTokenSeconds;

	/**
	 * Resolves the client id suffix once properties are injected.
	 * A configured suffix is kept; otherwise the host of {@code baseUrl} is used,
	 * and "localhost" when no base URL is configured either.
	 *
	 * @throws Exception if {@code baseUrl} is not a valid URL
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		if (StringUtils.isNotEmpty(clientIdSuffix)) {
			// Explicitly configured, nothing to derive
			return;
		}
		clientIdSuffix = StringUtils.isEmpty(baseUrl) ? "localhost" : new URL(baseUrl).getHost();
	}

	/**
	 * Loads the client with the given client id and populates its runtime authorities.
	 * The runtime authorities are the client's roles followed by {@link OAuthRole#EVERYONE}
	 * as the last element. Non-null results are cached in "oauthclient" by client id.
	 *
	 * @param clientId the OAuth client id
	 * @return the client with runtime authorities set, or {@code null} if there is no such client
	 */
	@Cacheable(cacheNames = { "oauthclient" }, key = "#clientId", unless = "#result == null")
	public OAuthClient loadClientByClientId(final String clientId)  {
		final OAuthClient found = getClient(clientId);
		if (found != null) {
			// Order of authorities is extremely important in ACL, so use an insertion-ordered set
			final Set<GrantedAuthority> ordered = new LinkedHashSet<>(20);

			// Roles stored with the client come first
			ordered.addAll(found.getRoles());

			// EVERYONE must sit at the tail: drop it if present, then append it
			ordered.remove(OAuthRole.EVERYONE);
			ordered.add(OAuthRole.EVERYONE);

			found.setRuntimeAuthorities(new ArrayList<>(ordered));
		}
		return found;
	}

	/**
	 * Touches the roles collection of the client so it is loaded before the client is returned.
	 *
	 * @param client the client, may be {@code null}
	 * @return the same instance that was passed in
	 */
	private OAuthClient lazyLoad(OAuthClient client) {
		if (client == null) {
			return null;
		}
		// size() forces the roles collection to be read
		client.getRoles().size();
		return client;
	}

	/**
	 * Lists all OAuth clients sorted by {@code clientId}.
	 *
	 * @return all clients in the repository
	 */
	@Override
	public List<OAuthClient> listClientDetails() {
		final Sort byClientId = Sort.by("clientId");
		return oauthClientRepository.findAll(byClientId);
	}

	/**
	 * Lists one page of OAuth clients.
	 *
	 * @param pageable paging and sorting request
	 * @return the requested page of clients
	 */
	@Override
	public Page<OAuthClient> listClientDetails(Pageable pageable) {
		final Page<OAuthClient> page = oauthClientRepository.findAll(pageable);
		return page;
	}

	/**
	 * Lists one page of OAuth clients matching a Querydsl predicate.
	 *
	 * @param predicate filter applied to the clients
	 * @param pageable paging and sorting request
	 * @return the requested page of matching clients
	 */
	@Override
	public Page<OAuthClient> listClientDetails(Predicate predicate, Pageable pageable) {
		final Page<OAuthClient> matches = oauthClientRepository.findAll(predicate, pageable);
		return matches;
	}

	/**
	 * Finds a client by its OAuth client id and loads its roles.
	 *
	 * @param clientId the OAuth client id
	 * @return the client, or {@code null} when the repository has no match
	 */
	@Override
	public OAuthClient getClient(final String clientId) {
		final OAuthClient found = oauthClientRepository.findByClientId(clientId);
		if (found == null) {
			return null;
		}
		// size() forces the roles collection to be read
		found.getRoles().size();
		return found;
	}

	/**
	 * Deletes the client and evicts its cached entries (keyed by client id and by record id).
	 *
	 * @param client the client to delete
	 * @return the client that was passed in
	 */
	@Override
	@Transactional
	@Caching(evict = {
		@CacheEvict(
			cacheNames = { "oauthclient", "oauthclient.registered" },
			key = "#client.clientId",
			condition = "#client != null"),
		@CacheEvict(
			cacheNames = { "oauthclient.registered.byid", "oauthclient.byid.active" },
			key = "T(java.lang.String).valueOf(#client.id)",
			condition = "#client != null and #client.id != null")
	})
	public OAuthClient removeClient(final OAuthClient client) {
		oauthClientRepository.delete(client);
		return client;
	}

	/**
	 * Registers a new OAuth client based on the provided data.
	 * <p>
	 * The client id is generated as {@code <5 chars>.<20 chars>@<clientIdSuffix>} in lower case and a
	 * random 32-character secret is stored in encoded form. The plain-text secret is not exposed by
	 * this method. Unsupported grant types and the {@link OAuthRole#EVERYONE} role are removed
	 * before saving.
	 *
	 * @param client source of the new client's properties
	 * @return the persisted client with roles loaded
	 */
	@Override
	@Transactional
	public OAuthClient addClient(OAuthClient client) {
		// Locale.ROOT keeps the id ASCII regardless of the JVM default locale (e.g. Turkish dotless i)
		final String randomPart = RandomStringUtils.randomAlphanumeric(5) + "." + RandomStringUtils.randomAlphanumeric(20);
		final String clientId = randomPart.toLowerCase(Locale.ROOT) + "@" + clientIdSuffix;
		// Secrets must come from a cryptographically strong source: 32 alphanumeric characters
		final String clientSecret = RandomStringUtils.random(32, 0, 0, true, true, null, new SecureRandom());

		final OAuthClient newClient = new OAuthClient();
		newClient.apply(client);
		newClient.setClientId(clientId);
		newClient.setClientSecret(passwordEncoder.encode(clientSecret));
		// Remove any unsupported grants
		newClient.getAuthorizedGrantTypes().removeIf((grant) -> !SUPPORTED_GRANT_TYPES.contains(grant));
		// Remove role EVERYONE, it's assigned automatically
		newClient.getRoles().remove(OAuthRole.EVERYONE);
		return lazyLoad(oauthClientRepository.save(newClient));
	}

	/**
	 * Applies updates to the client identified by record id and version.
	 * Unsupported grant types and the {@link OAuthRole#EVERYONE} role are removed before saving,
	 * and cached entries for the client are evicted.
	 *
	 * @param id record id of the client
	 * @param version expected record version
	 * @param updates the new property values
	 * @return the saved client with roles loaded
	 * @throws EntityNotFoundException if no client matches the id and version
	 */
	@Override
	@Transactional
	@Caching(evict = {
		@CacheEvict(cacheNames = { "oauthclient", "oauthclient.registered" }, key = "#updates.clientId", condition = "#updates != null"),
		@CacheEvict(cacheNames = { "oauthclient.registered.byid", "oauthclient.byid.active" }, key = "T(java.lang.String).valueOf(#id)", condition = "#updates != null")
	})
	public OAuthClient updateClient(final long id, final int version, final OAuthClient updates) {
		final OAuthClient client = oauthClientRepository.findByIdAndVersion(id, version);
		if (client == null) {
			// Fail with a meaningful error instead of a NullPointerException below
			throw new EntityNotFoundException("Record not found.");
		}
		client.apply(updates);
		// Remove any unsupported grants
		client.getAuthorizedGrantTypes().removeIf((grant) -> !SUPPORTED_GRANT_TYPES.contains(grant));
		// Remove role EVERYONE, it's assigned automatically
		client.getRoles().remove(OAuthRole.EVERYONE);
		return lazyLoad(oauthClientRepository.save(client));
	}

	/**
	 * Changes the OAuth client id of an existing client.
	 * Cached entries for the old client id and the record id are evicted, including the SID names
	 * cache which is cleared programmatically.
	 *
	 * @param sourceId the current client id
	 * @param targetId the new client id
	 * @return the saved client with roles loaded
	 * @throws EntityNotFoundException if there is no client with {@code sourceId}
	 */
	@Override
	@Transactional
	@Caching(evict = {
		@CacheEvict(cacheNames = { "oauthclient", "oauthclient.registered" }, key = "#sourceId", condition = "#sourceId != null && #targetId != null"),
		@CacheEvict(cacheNames = { "oauthclient.registered.byid", "oauthclient.byid.active" }, key = "T(java.lang.String).valueOf(#result.id)", condition = "#result != null and #result.id != null")
	})
	public OAuthClient updateClientId(String sourceId, String targetId) {
		final OAuthClient client = getClient(sourceId);
		if (client == null) {
			// Fail with a meaningful error instead of a NullPointerException below
			throw new EntityNotFoundException("Record not found.");
		}
		client.setClientId(targetId);

		if (cacheManager != null) {
			// We need to clear sid names cache manually. Duplicate @CacheEvict annotations are not allowed.
			final Cache sidNamesCache = cacheManager.getCache(CustomAclServiceImpl.CACHE_SID_NAMES);
			if (sidNamesCache != null) {
				sidNamesCache.evict(sourceId);
				sidNamesCache.evict(client.getId());
			}
		}

		return lazyLoad(oauthClientRepository.save(client));
	}

	/**
	 * Suggests clients for an autocomplete field.
	 * A client matches when its title or client id starts with the term (ignoring case) or its
	 * description contains the term. Results are sorted by title.
	 *
	 * @param term the text typed by the user
	 * @param limit maximum number of results; capped at 100
	 * @return matching clients, or an empty list when the term is blank or the limit is below 1
	 */
	@Override
	public List<OAuthClient> autocompleteClients(final String term, int limit) {
		// A page size below 1 is rejected by PageRequest, so treat it as "no results wanted"
		if (StringUtils.isBlank(term) || limit < 1) {
			return Collections.emptyList();
		}

		log.debug("Autocomplete for={}", term);

		Predicate predicate = QOAuthClient.oAuthClient.title.startsWithIgnoreCase(term)
			// clientId
			.or(QOAuthClient.oAuthClient.clientId.startsWithIgnoreCase(term))
			// description contains
			.or(QOAuthClient.oAuthClient.description.contains(term));

		return oauthClientRepository.findAll(predicate, PageRequest.of(0, Math.min(100, limit), Sort.by("title"))).getContent();
	}

	/**
	 * Replaces the client's secret with a newly generated one by calling
	 * {@code setSecret(oauthClient, null)}.
	 * <p>
	 * The method is deliberately not {@code final}: class-based (CGLIB) proxies cannot override final
	 * methods, so {@code @Transactional} and {@code @PreAuthorize} would not be applied to it.
	 *
	 * @param oauthClient the client whose secret is reset
	 * @return the new plain-text secret
	 */
	@Override
	@Transactional
	@PreAuthorize("hasRole('ADMINISTRATOR') or hasPermission(#oauthClient, 'ADMINISTRATION')")
	public String resetSecret(OAuthClient oauthClient) {
		return setSecret(oauthClient, null);
	}

	/**
	 * Stores a new secret for the client in encoded form.
	 * A blank {@code clientSecret} makes the method generate a random 32-character alphanumeric
	 * secret. Cached copies of the client that carry the old secret hash are evicted.
	 * <p>
	 * The method is deliberately not {@code final}: class-based (CGLIB) proxies cannot override final
	 * methods, so {@code @Transactional} and {@code @PreAuthorize} would not be applied to it.
	 *
	 * @param oauthClient the client to update; only its record id is used to reload it
	 * @param clientSecret the plain-text secret to set, or blank to generate one
	 * @return the plain-text secret that was set
	 * @throws EntityNotFoundException if the client record does not exist
	 */
	@Override
	@Transactional
	@PreAuthorize("hasRole('ADMINISTRATOR') or hasPermission(#oauthClient, 'ADMINISTRATION')")
	public String setSecret(OAuthClient oauthClient, String clientSecret) {
		assert oauthClient != null;
		assert oauthClient.getId() != null;

		final OAuthClient stored = oauthClientRepository.findById(oauthClient.getId()).orElseThrow(() -> new EntityNotFoundException("Record not found."));

		final boolean generateSecret = StringUtils.isBlank(clientSecret);
		final SecureRandom secureRandom = new SecureRandom();
		final String oldHash = stored.getClientSecret();

		String secret = clientSecret;
		String newHash;
		do {
			if (generateSecret) {
				// Cryptographically strong, 32 alphanumeric characters
				secret = RandomStringUtils.random(32, 0, 0, true, true, null, secureRandom);
			}
			newHash = passwordEncoder.encode(secret);
			// Retry only for generated secrets: a caller-supplied secret never changes between
			// iterations, so retrying it would loop forever with a deterministic encoder.
		} while (generateSecret && oldHash != null && oldHash.equals(newHash));

		stored.setClientSecret(newHash);
		oauthClientRepository.save(stored);

		// Cached OAuthClient and RegisteredClient copies still hold the old hash
		if (cacheManager != null) {
			final String clientId = stored.getClientId();
			if (clientId != null) {
				for (final String cacheName : List.of("oauthclient", "oauthclient.registered")) {
					final Cache byClientId = cacheManager.getCache(cacheName);
					if (byClientId != null) {
						byClientId.evict(clientId);
					}
				}
			}
			final Cache byRegistrationId = cacheManager.getCache("oauthclient.registered.byid");
			if (byRegistrationId != null) {
				byRegistrationId.evict(String.valueOf(stored.getId()));
			}
		}

		return secret;
	}

	/**
	 * Removes the secret from the client. Clients with the {@code client_credentials} grant
	 * must keep a secret and are rejected. Cached copies of the client that carry the old secret
	 * hash are evicted.
	 * <p>
	 * The method is deliberately not {@code final}: class-based (CGLIB) proxies cannot override final
	 * methods, so {@code @Transactional} and {@code @PreAuthorize} would not be applied to it.
	 *
	 * @param oauthClient the client to update; only its record id is used to reload it
	 * @return the saved client with roles loaded
	 * @throws EntityNotFoundException if the client record does not exist
	 * @throws RuntimeException if the client uses the {@code client_credentials} grant
	 */
	@Override
	@Transactional
	@PreAuthorize("hasRole('ADMINISTRATOR') or hasPermission(#oauthClient, 'ADMINISTRATION')")
	public OAuthClient removeSecret(OAuthClient oauthClient) {
		assert oauthClient != null;
		assert oauthClient.getId() != null;

		OAuthClient stored = oauthClientRepository.findById(oauthClient.getId()).orElseThrow(() -> new EntityNotFoundException("Record not found."));
		if (stored.getAuthorizedGrantTypes().contains(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())) {
			throw new RuntimeException("OAuth Client with client_credentials grant must have a secret");
		}
		stored.setClientSecret(null);
		stored = oauthClientRepository.save(stored);

		// Cached OAuthClient and RegisteredClient copies still hold the old hash
		if (cacheManager != null) {
			final String clientId = stored.getClientId();
			if (clientId != null) {
				for (final String cacheName : List.of("oauthclient", "oauthclient.registered")) {
					final Cache byClientId = cacheManager.getCache(cacheName);
					if (byClientId != null) {
						byClientId.evict(clientId);
					}
				}
			}
			final Cache byRegistrationId = cacheManager.getCache("oauthclient.registered.byid");
			if (byRegistrationId != null) {
				byRegistrationId.evict(String.valueOf(stored.getId()));
			}
		}

		return lazyLoad(stored);
	}

	/**
	 * Checks whether any client lists the given origin among its allowed origins.
	 * The repository query pre-selects candidates whose {@code origins} contain the value;
	 * each candidate is then checked against {@code getAllowedOrigins()}.
	 *
	 * @param origin the origin to look for
	 * @return {@code true} if at least one client allows the origin; {@code false} otherwise,
	 *         including when the origin is blank
	 */
	@Override
	public boolean isOriginRegistered(String origin) {
		if (StringUtils.isBlank(origin)) {
			// Nothing to match against
			return false;
		}

		for (final OAuthClient client : oauthClientRepository.findAll(QOAuthClient.oAuthClient.origins.contains(origin))) {
			if (client.getAllowedOrigins().contains(origin)) {
				// First match is enough
				return true;
			}
		}
		return false;
	}

	/**
	 * Not supported: registered clients cannot be saved through this method.
	 * The attempt is logged and rejected.
	 *
	 * @param registeredClient the client that was to be saved
	 * @throws UnsupportedOperationException always
	 */
	@Override
	public void save(RegisteredClient registeredClient) {
		log.warn("Saving client: {}", registeredClient);
		// Still a RuntimeException for existing callers, but states the intent precisely
		throw new UnsupportedOperationException("Not implemented");
	}

	/**
	 * Loads the registered client by its registration id, which is the numeric record id of the
	 * {@link OAuthClient} in string form. Non-null results are cached in
	 * "oauthclient.registered.byid".
	 *
	 * @param registrationId the record id as a string
	 * @return the registered client, or {@code null} if the client is not active
	 * @throws EntityNotFoundException if no client has this record id
	 * @throws NumberFormatException if {@code registrationId} is not a valid long
	 */
	@Override
	@Cacheable(cacheNames = "oauthclient.registered.byid", key = "#registrationId", unless = "#result == null")
	public RegisteredClient findById(String registrationId) {
		log.warn("Loading OAuth registered client by registrationId {}", registrationId);
		// orElseThrow never yields null, so only the active flag needs checking
		final OAuthClient client = oauthClientRepository.findById(Long.valueOf(registrationId)).orElseThrow(() -> new EntityNotFoundException("No such client."));
		if (!client.isActive()) {
			return null;
		}
		return convertToRegisteredClient(client);
	}

	/**
	 * Loads the registered client by its OAuth client id. Non-null results are cached in
	 * "oauthclient.registered".
	 *
	 * @param clientId the OAuth client id
	 * @return the registered client, or {@code null} if the client does not exist or is not active
	 */
	@Override
	@Cacheable(cacheNames = "oauthclient.registered", key = "#clientId", unless = "#result == null")
	public RegisteredClient findByClientId(String clientId) {
		log.warn("Loading OAuth registered client by clientId {}", clientId);
		final OAuthClient client = loadClientByClientId(clientId);
		final boolean usable = client != null && client.isActive();
		return usable ? convertToRegisteredClient(client) : null;
	}
	
	/**
	 * Tells whether the client with the given registration id exists and is active.
	 * The result is cached in "oauthclient.byid.active".
	 *
	 * @param registrationId the record id as a string
	 * @return {@code true} only if the client exists and is active
	 */
	@Override
	@Cacheable(cacheNames = "oauthclient.byid.active", key = "#registrationId", unless = "#registrationId == null")
	public boolean isClientActive(String registrationId) {
		log.warn("Loading OAuth registered client by registrationId {}", registrationId);
		final Long recordId = Long.valueOf(registrationId);
		// Present only when the record exists and reports itself active
		return oauthClientRepository.findById(recordId).filter(OAuthClient::isActive).isPresent();
	}

	/**
	 * Converts an {@link OAuthClient} record to a {@link RegisteredClient}.
	 * <ul>
	 * <li>The registration id is the record id in string form.</li>
	 * <li>Clients without a secret use {@code NONE} client authentication; clients with a secret
	 * accept {@code client_secret_post}, {@code client_secret_basic} and {@code client_secret_jwt}.</li>
	 * <li>Scopes {@code openid}, {@code profile} and {@code email} are always added, followed by the
	 * client's own scopes.</li>
	 * <li>Token lifetimes come from the client, falling back to the configured defaults.</li>
	 * <li>Additional information of the client is copied into the client settings.</li>
	 * </ul>
	 *
	 * @param client the client record to convert
	 * @return the registered client
	 */
	private RegisteredClient convertToRegisteredClient(OAuthClient client) {
		final Builder registeredClient = RegisteredClient.withId(String.valueOf(client.getId()))
			.clientId(client.getClientId())
			.clientSecret(client.getClientSecret())
			.clientIdIssuedAt(client.getCreatedDate())
			.clientName(client.getTitle());

		// Authentication methods depend on whether the client has a secret
		if (StringUtils.isBlank(client.getClientSecret())) {
			registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.NONE);
		} else {
			registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
			registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
			registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT);
		}

		// Grant types
		if (CollectionUtils.isNotEmpty(client.getAuthorizedGrantTypes())) {
			client.getAuthorizedGrantTypes().stream().map(AuthorizationGrantType::new).forEach(registeredClient::authorizationGrantType);
		}

		// Redirect URIs
		if (CollectionUtils.isNotEmpty(client.getRegisteredRedirectUri())) {
			client.getRegisteredRedirectUri().forEach(registeredClient::redirectUri);
		}

		// Scopes every client gets
		registeredClient
			.scope(OidcScopes.OPENID)
			.scope("profile")
			.scope("email");
		// Scopes declared by the client
		if (CollectionUtils.isNotEmpty(client.getScope())) {
			client.getScope().forEach(registeredClient::scope);
		}

		// Token lifetimes: client-specific value if set, configured default otherwise
		final long accessTokenSeconds = Optional.ofNullable(client.getAccessTokenValidity()).orElse(ttlAccessTokenSeconds).longValue();
		final long refreshTokenSeconds = Optional.ofNullable(client.getRefreshTokenValidity()).orElse(ttlRefreshTokenSeconds).longValue();

		final var token = TokenSettings.builder()
			// Spring only supports Opaque or JWTs; self-contained (JWT) tokens are used here
			.accessTokenFormat(OAuth2TokenFormat.SELF_CONTAINED)
			.accessTokenTimeToLive(Duration.of(accessTokenSeconds, ChronoUnit.SECONDS))
			.refreshTokenTimeToLive(Duration.of(refreshTokenSeconds, ChronoUnit.SECONDS))
			.reuseRefreshTokens(true);

		registeredClient.tokenSettings(token.build());

		// Settings
		final var settings = ClientSettings.builder();
		settings.requireAuthorizationConsent(false);
		// NOTE(review): looks like a leftover debugging value; kept to preserve the emitted settings, confirm it can be removed
		settings.setting("randomSetting", "Is here"); // Random
		// Copy additional settings
		if (MapUtils.isNotEmpty(client.getAdditionalInformation())) {
			client.getAdditionalInformation().entrySet().forEach(setting -> settings.setting(setting.getKey(), setting.getValue()));
		}

		// Build exactly once; the result of an earlier extra build() call was discarded
		return registeredClient
			.clientSettings(settings.build())
			.build();
	}

}