LocaleURLFilter.java

/*
 * Copyright 2018 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.genesys.server.servlet.filter;

import java.io.IOException;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.Locale.FilteringMode;
import java.util.regex.Matcher;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.springframework.web.filter.GenericFilterBean;

import lombok.extern.slf4j.Slf4j;

/**
 * Handles the very important locale selection by URL.
 * <p>
 * For every non-excluded request, the context-relative path is matched against a
 * {@link LocaleURLMatcher}. When it matches, group 1 is taken as the language prefix and
 * group 2 as the remaining path:
 * <ul>
 * <li>a prefix that is not one of the configured allowed locales is temporarily redirected
 * (HTTP 302) to the un-prefixed URL;</li>
 * <li>a prefix equal to the default locale is permanently redirected (HTTP 301) to the
 * un-prefixed URL;</li>
 * <li>any other allowed prefix is recorded in request attributes and the request continues
 * down the chain wrapped in {@link LocaleWrappedServletRequest} /
 * {@link LocaleWrappedServletResponse}.</li>
 * </ul>
 * Requests that do not match continue with only the response wrapped.
 */
@Slf4j
public class LocaleURLFilter extends GenericFilterBean {

	/** Request attribute holding the {@link Locale} selected from the URL prefix. */
	public static final String REQUEST_LOCALE_ATTR = LocaleURLFilter.class.getName() + ".LOCALE";
	private static final String REQUEST_INTERNAL_URL = LocaleURLFilter.class.getName() + ".INTERNALURL";
	private static final String REQUEST_LOCALE_LANGUAGE = LocaleURLFilter.class.getName() + ".LANGUAGE";

	/**
	 * Per-instance matcher. It is configured by {@link #setExcludePaths(String)}, so a shared static
	 * instance would let one filter's configuration overwrite another's.
	 */
	private final LocaleURLMatcher localeUrlMatcher = new LocaleURLMatcher();

	/** Locales accepted as URL prefixes. Never {@code null}; empty until configured. */
	private Set<Locale> supportedLocales = Set.of();

	/** Locale served without a URL prefix; the JVM default until configured. */
	private Locale defaultLocale = Locale.getDefault();

	public LocaleURLFilter() {
		addRequiredProperty("excludePaths");
		addRequiredProperty("allowedLocales");
		addRequiredProperty("defaultLocale");
	}

	/**
	 * Configures the paths that bypass locale handling entirely.
	 *
	 * @param excludePaths comma-separated list; whitespace around entries is ignored. Blank input
	 *          leaves the current configuration unchanged.
	 */
	public void setExcludePaths(String excludePaths) {
		if (StringUtils.isBlank(excludePaths)) {
			return;
		}
		// Trim first so leading/trailing whitespace does not end up inside the first/last path
		final String[] paths = excludePaths.trim().split("\\s*,\\s*");
		for (final String path : paths) {
			log.info("Excluding path: {}", path);
		}
		localeUrlMatcher.setExcludedPaths(paths);
	}

	/**
	 * Configures the locales accepted as URL prefixes.
	 *
	 * @param allowedLocales comma-separated BCP 47 language tags; whitespace and empty entries are
	 *          ignored. Blank input leaves the current configuration unchanged.
	 */
	public void setAllowedLocales(String allowedLocales) {
		if (StringUtils.isBlank(allowedLocales)) {
			return;
		}
		final Set<Locale> locales = new HashSet<>();
		// Trim first: " en" is an ill-formed tag and Locale.forLanguageTag would silently yield "und"
		for (final String tag : allowedLocales.trim().split("\\s*,\\s*")) {
			if (StringUtils.isBlank(tag)) {
				continue;
			}
			log.info("Allowed locale: {}", tag);
			locales.add(Locale.forLanguageTag(tag));
		}
		this.supportedLocales = Set.copyOf(locales);
		log.info("Supported locales: {}", this.supportedLocales);
	}

	/**
	 * Configures the locale that is served without a URL prefix.
	 *
	 * @param defaultLocale BCP 47 language tag; {@code null} or blank selects the JVM default locale
	 */
	public void setDefaultLocale(String defaultLocale) {
		log.info("Default locale: {}", defaultLocale);
		if (StringUtils.isNotBlank(defaultLocale)) {
			this.defaultLocale = Locale.forLanguageTag(defaultLocale.trim());
		} else {
			this.defaultLocale = Locale.getDefault();
		}
		log.info("Using default locale: {}", this.defaultLocale);
	}

	@Override
	public void destroy() {
		log.info("Destroying LocaleURLFilter");
	}

	@Override
	public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
		final HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
		final HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
		// Context-relative path: matching and the INTERNALURL attribute ignore the deployment context path
		final String url = httpRequest.getRequestURI().substring(httpRequest.getContextPath().length());

		if (localeUrlMatcher.isExcludedPath(url)) {
			log.debug("Excluded {}", url);
			filterChain.doFilter(servletRequest, servletResponse);
			return;
		}

		log.debug("Incoming URL: {}", url);
		logRequestAttributes(httpRequest);

		// The attribute is already present when this filter has processed the same request before
		// (presumably a forward/error dispatch — confirm); reuse that language instead of re-matching.
		final String existingUrlLanguage = (String) httpRequest.getAttribute(REQUEST_LOCALE_LANGUAGE);
		if (existingUrlLanguage != null) {
			final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, existingUrlLanguage, defaultLocale.toLanguageTag());
			log.debug("Found REQUEST_LOCALE_LANGUAGE {} in request", existingUrlLanguage);
			filterChain.doFilter(servletRequest, localeResponse);
			return;
		}

		final Matcher matcher = localeUrlMatcher.matcher(url);
		if (!matcher.matches()) {
			final String internalUrl = getInternalUrl(url, httpRequest.getQueryString());
			log.debug("No match on url {} setting {}", url, internalUrl);
			httpRequest.setAttribute(REQUEST_INTERNAL_URL, internalUrl);
			httpRequest.setAttribute(REQUEST_LOCALE_LANGUAGE, "");
			final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, null, defaultLocale.toLanguageTag());
			filterChain.doFilter(servletRequest, localeResponse);
			return;
		}

		final String urlLanguage = matcher.group(1);
		final String remainingUrl = matcher.group(2);

		final Locale urlLocale = findSupportedLocale(urlLanguage);
		if (urlLocale == null) {
			log.warn("Locale {} not allowed. Temporary redirect to default locale.", urlLanguage);
			httpResponse.sendRedirect(getRedirectUrl(httpRequest, remainingUrl));
			return;
		}
		log.debug("Using locale:{}", urlLocale);

		if (urlLocale.equals(this.defaultLocale)) {
			// The default locale is only served without a prefix
			final String defaultLocaleUrl = getRedirectUrl(httpRequest, remainingUrl);
			log.info("Default locale requested, permanent-redirect to {}", defaultLocaleUrl);

			httpResponse.reset();
			httpResponse.setStatus(HttpServletResponse.SC_MOVED_PERMANENTLY);
			httpResponse.setHeader("Location", defaultLocaleUrl);
			return;
		}

		httpRequest.setAttribute(REQUEST_LOCALE_ATTR, urlLocale);
		httpRequest.setAttribute(REQUEST_LOCALE_LANGUAGE, urlLanguage);
		httpRequest.setAttribute(REQUEST_INTERNAL_URL, getInternalUrl(remainingUrl, httpRequest.getQueryString()));

		if (log.isTraceEnabled()) {
			log.trace("URL matches! lang={} remaining={}", urlLanguage, remainingUrl);
			log.trace("Country: {} Lang: {} locale={}", urlLocale.getCountry(), urlLocale.getLanguage(), urlLocale);
			logRequestAttributes(httpRequest);
			log.trace("Proxying request to remaining URL {}", remainingUrl);
		}

		final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, urlLanguage, defaultLocale.toLanguageTag());
		final LocaleWrappedServletRequest localeRequest = new LocaleWrappedServletRequest(httpRequest, url, remainingUrl);

		filterChain.doFilter(localeRequest, localeResponse);
	}

	/**
	 * Looks up the configured locale that best matches a URL language prefix.
	 *
	 * @param languageTag the language prefix taken from the URL
	 * @return the matching supported locale, or {@code null} if none matches or the prefix is not a
	 *         well-formed language range
	 */
	private Locale findSupportedLocale(final String languageTag) {
		try {
			final var range = new Locale.LanguageRange(languageTag, Locale.LanguageRange.MAX_WEIGHT);
			return Locale.lookup(List.of(range), supportedLocales);
		} catch (final IllegalArgumentException e) {
			// Ill-formed prefix: handle it like any unsupported locale rather than failing the request
			log.debug("Invalid language range {}: {}", languageTag, e.getMessage());
			return null;
		}
	}

	/**
	 * Builds a redirect target for a context-relative path.
	 * <p>
	 * Both {@code sendRedirect} and a raw {@code Location} header resolve a leading "/" against the
	 * server root, so the context path is prepended. Otherwise the redirect breaks whenever the
	 * application is not deployed at "/". An empty or {@code null} path maps to "/".
	 */
	private String getRedirectUrl(final HttpServletRequest request, final String contextRelativePath) {
		final String path = StringUtils.isEmpty(contextRelativePath) ? "/" : contextRelativePath;
		return request.getContextPath() + getInternalUrl(path, request.getQueryString());
	}

	/** Logs every request attribute at TRACE level; does nothing unless TRACE is enabled. */
	private void logRequestAttributes(final HttpServletRequest request) {
		if (!log.isTraceEnabled()) {
			return;
		}
		final Enumeration<String> attrNames = request.getAttributeNames();
		while (attrNames.hasMoreElements()) {
			final String attrName = attrNames.nextElement();
			log.trace("Request attr {} = {}", attrName, request.getAttribute(attrName));
		}
	}

	/** Appends {@code queryString} to {@code url} when it is not blank. */
	private String getInternalUrl(final String url, final String queryString) {
		if (StringUtils.isBlank(queryString)) {
			return url;
		}
		return url + "?" + queryString;
	}

}