ApiTokenAuthenticationFilter.java

/*
 * Copyright 2023 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.blocks.tokenauth.spring;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.filter.GenericFilterBean;

import lombok.extern.slf4j.Slf4j;

/**
 * Authentication filter for API tokens.
 *
 * <p>For requests selected by the configured {@link RequestMatcher} whose {@code Authorization}
 * header uses the {@value #AUTHORIZATION_TYPE} scheme (e.g. {@code Authorization: API-Token <token>}),
 * the token is extracted and passed to the {@link AuthenticationManager} wrapped in an
 * {@link ApiTokenAuthenticationToken}.
 * <ul>
 * <li>On success the result is stored in a fresh {@link SecurityContext} and the filter chain continues.</li>
 * <li>On failure the security context is cleared and a {@code 401 Unauthorized} status is set; the chain
 * is not continued.</li>
 * <li>All other requests pass through untouched.</li>
 * </ul>
 */
@Slf4j
public class ApiTokenAuthenticationFilter extends GenericFilterBean {

	/** Name of the HTTP request header that carries the credentials. */
	private static final String AUTHORIZATION_HEADER = HttpHeaders.AUTHORIZATION;

	/** Authorization scheme that identifies an API token in the {@code Authorization} header. */
	public static final String AUTHORIZATION_TYPE = "API-Token";

	private final RequestMatcher requiresAuthenticationRequestMatcher;

	private final AuthenticationManager authenticationManager;

	private final SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder.getContextHolderStrategy();

	/**
	 * Create the filter.
	 *
	 * @param requiresAuth matcher selecting the requests this filter should inspect
	 * @param authenticationManager manager that validates the extracted {@link ApiTokenAuthenticationToken}
	 * @throws NullPointerException if either argument is {@code null}
	 */
	public ApiTokenAuthenticationFilter(final RequestMatcher requiresAuth, final AuthenticationManager authenticationManager) {
		// Fail fast at construction instead of with an NPE on the first request
		this.requiresAuthenticationRequestMatcher = Objects.requireNonNull(requiresAuth, "requiresAuth");
		this.authenticationManager = Objects.requireNonNull(authenticationManager, "authenticationManager");
	}

	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
		doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
	}

	/**
	 * HTTP-typed filter logic: authenticate matching requests, pass everything else through.
	 */
	private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
		if (!requiresAuthentication(request, response)) {
			chain.doFilter(request, response);
			return;
		}

		final Authentication authenticationResult;
		try {
			authenticationResult = attemptAuthentication(request, response);
		} catch (InternalAuthenticationServiceException failed) {
			log.error("An internal error occurred while trying to authenticate the user.", failed);
			unsuccessfulAuthentication(request, response, failed);
			return;
		} catch (AuthenticationException ex) {
			// Authentication failed
			unsuccessfulAuthentication(request, response, ex);
			return;
		}

		if (authenticationResult == null) {
			// An overriding attemptAuthentication() indicated it has not completed; stop here
			return;
		}

		// Authentication success. This is deliberately outside the try block: successfulAuthentication()
		// continues the filter chain, and an AuthenticationException raised further down the chain
		// must not be reported as an invalid API token.
		successfulAuthentication(request, response, chain, authenticationResult);
	}

	/**
	 * Store the authentication result in a new security context and continue the filter chain.
	 *
	 * @param request the current request
	 * @param response the current response
	 * @param chain the filter chain to continue
	 * @param authResult the authentication returned by the {@link AuthenticationManager}
	 */
	protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
			Authentication authResult) throws IOException, ServletException {
		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
		context.setAuthentication(authResult);
		this.securityContextHolderStrategy.setContext(context);
		// The context is not cleared here; this relies on SecurityContextPersistenceFilter (or an
		// equivalent filter earlier in the chain) to clear it after the request completes.
		chain.doFilter(request, response);
	}

	/**
	 * Clear the security context and respond with {@code 401 Unauthorized}. The filter chain is not continued.
	 *
	 * @param request the current request
	 * @param response the response that receives the 401 status and {@code WWW-Authenticate} header
	 * @param failed the reason authentication failed
	 */
	protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) {
		this.securityContextHolderStrategy.clearContext();
		log.trace("Failed to process authentication request", failed);
		log.trace("Cleared SecurityContextHolder");
		log.trace("Handling authentication failure");

		// NOTE(review): this value is not an RFC 7235 challenge (auth-scheme [params]); kept as-is for
		// compatibility with existing clients.
		response.addHeader(HttpHeaders.WWW_AUTHENTICATE, "Invalid API-Token");
		response.setStatus(HttpStatus.UNAUTHORIZED.value());
	}

	/**
	 * Decide whether this filter should authenticate the request.
	 *
	 * @return {@code true} when the request matcher matches and the {@code Authorization} header uses
	 *         the {@value #AUTHORIZATION_TYPE} scheme
	 */
	protected boolean requiresAuthentication(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
		return this.requiresAuthenticationRequestMatcher.matches(httpServletRequest)
				&& hasApiTokenScheme(httpServletRequest.getHeader(AUTHORIZATION_HEADER));
	}

	/**
	 * Check that a header value starts with the {@value #AUTHORIZATION_TYPE} scheme as a whole word.
	 * For example, {@code "API-Token abc"} matches but {@code "API-TokenV2 abc"} does not.
	 *
	 * @param header the raw header value, may be {@code null}
	 * @return {@code true} if the scheme is present, case-insensitively
	 */
	private static boolean hasApiTokenScheme(final String header) {
		if (!StringUtils.startsWithIgnoreCase(header, AUTHORIZATION_TYPE)) {
			return false; // also covers a missing (null) header
		}
		final int schemeLength = AUTHORIZATION_TYPE.length();
		return header.length() == schemeLength || Character.isWhitespace(header.charAt(schemeLength));
	}

	/**
	 * Extract the API token from the {@code Authorization} header and authenticate it.
	 *
	 * @return the authentication produced by the {@link AuthenticationManager}
	 * @throws AuthenticationCredentialsNotFoundException if the header is missing or carries no token
	 * @throws AuthenticationException if the {@link AuthenticationManager} rejects the token
	 */
	public Authentication attemptAuthentication(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws AuthenticationException, IOException, ServletException {

		final String header = httpServletRequest.getHeader(AUTHORIZATION_HEADER);
		log.trace("Processing {} header", AUTHORIZATION_HEADER);

		// trimToEmpty tolerates a missing header (removeStartIgnoreCase returns null for null input)
		final String token = StringUtils.trimToEmpty(StringUtils.removeStartIgnoreCase(header, AUTHORIZATION_TYPE));
		if (token.isEmpty()) {
			throw new AuthenticationCredentialsNotFoundException("Invalid API token"); // Bail fast
		}
		// Never log the token itself: it is a bearer credential
		log.debug("Received {} token ({} characters)", AUTHORIZATION_TYPE, token.length());
		// Authenticate by token
		return this.authenticationManager.authenticate(new ApiTokenAuthenticationToken(token));
	}
}