S3StorageServiceImpl.java

/*
 * Copyright 2018 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.genesys.filerepository.service.impl;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.TimeZone;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.DatatypeConverter;

import org.apache.commons.lang3.StringUtils;
import org.genesys.filerepository.InvalidRepositoryPathException;
import org.genesys.filerepository.service.BytesStorageService;
import org.genesys.filerepository.service.s3.ListBucketResult;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.converter.xml.MappingJackson2XmlHttpMessageConverter;
import org.springframework.stereotype.Service;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;

import com.fasterxml.jackson.module.jaxb.JaxbAnnotationModule;

import lombok.extern.slf4j.Slf4j;

/**
 * Amazon S3 storage implementation.
 */
@Service("S3Storage")
@Slf4j
public class S3StorageServiceImpl implements BytesStorageService, InitializingBean {

	/** Character set used whenever strings are turned into bytes for hashing and signing. */
	private static final Charset CHARSET_UTF8 = StandardCharsets.UTF_8;

	/** Name of the HTTP header that receives the computed signature. */
	private static final String HTTP_AUTHORIZATION = "Authorization";
	/** Separator placed between the elements of the canonical request and of the string to sign. */
	private static final String LINE_SEPARATOR = "\n";
	
	/** Name of the header set to the hex-encoded SHA-256 hash of the request body. */
	private static final String AMZ_CONTENT_SHA256 = "X-Amz-Content-SHA256";
	/** Name of the header set to the formatted request timestamp. */
	private static final String AMZ_DATE = "X-Amz-Date";

	/** Algorithm for AWS V4 */
	private static final String AWS_SIGN_ALG = "HmacSHA256";

	/**
	 * Per-thread formatter for the request timestamp ({@code yyyyMMdd'T'HHmmss'Z'}, UTC, US locale).
	 * Each thread gets its own instance because {@link SimpleDateFormat} is not thread-safe.
	 */
	private static final ThreadLocal<SimpleDateFormat> HEADER_DATE_FORMAT = ThreadLocal.withInitial(() -> {
		final SimpleDateFormat timestampFormat = new SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'", Locale.US);
		timestampFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
		return timestampFormat;
	});

	/**
	 * Per-thread formatter for the {@code yyyyMMdd} date (UTC) used in the credential scope
	 * and in the signing key derivation.
	 *
	 * The locale is pinned to {@link Locale#US}: without it {@link SimpleDateFormat} uses the
	 * JVM default locale, which may select a non-Gregorian calendar or non-ASCII digits and
	 * produce a date that AWS does not accept. Each thread gets its own instance because
	 * {@link SimpleDateFormat} is not thread-safe.
	 */
	private static final ThreadLocal<SimpleDateFormat> YYYYMMDD = new ThreadLocal<SimpleDateFormat>() {
		@Override
		protected SimpleDateFormat initialValue() {
			var sdf = new SimpleDateFormat("yyyyMMdd", Locale.US);
			sdf.setTimeZone(TimeZone.getTimeZone("UTC"));
			return sdf;
		}
	};

	/** The rest template; built by {@code initializeRestTemplate()} when this instance is constructed. */
	private final RestTemplate restTemplate = initializeRestTemplate();

	/** The access key; injected from {@code s3.accessKey} and used in the credential part of the Authorization header. */
	@Value("${s3.accessKey}")
	private String accessKey;

	/** The secret key; injected from {@code s3.secretKey} and used to derive the request signing key. */
	@Value("${s3.secretKey}")
	private String secretKey;

	/** The bucket; injected from {@code s3.bucket} and used as the first label of the S3 host name. */
	@Value("${s3.bucket}")
	private String bucket;

	/** The region; injected from {@code s3.region} and used in the S3 host name and in the credential scope. */
	@Value("${s3.region}")
	private String region;

	/** The prefix; injected from {@code s3.prefix}. A blank prefix is treated as "/". */
	@Value("${s3.prefix}")
	private String prefix;
	
	// Base path derived from the prefix in afterPropertiesSet(); every repository path is placed under it
	private Path awsBasePath;

	/**
	 * Derives the base path from the configured prefix (falling back to "/" when the prefix
	 * is blank) and logs the effective S3 configuration together with a sample URL.
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		final String effectivePrefix = StringUtils.isBlank(this.prefix) ? "/" : this.prefix;
		this.awsBasePath = Paths.get(effectivePrefix);

		final String sampleUrl = getAwsUrl(Paths.get("/dummy", "filename.txt"));
		log.warn("S3 region={} bucket={} prefix={} dummy={}", region, bucket, prefix, sampleUrl);
	}

	/**
	 * Stores the given bytes at the repository path by issuing an HTTP PUT to S3.
	 *
	 * @param bytesFile repository path of the object
	 * @param data the bytes to store
	 * @throws IOException when {@code data} is {@code null}
	 */
	@Override
	public void upsert(final Path bytesFile, final byte[] data) throws IOException {
		final Path target = bytesFile.normalize().toAbsolutePath();

		if (data == null) {
			throw new IOException("File bytes are null");
		}

		log.debug("Putting to path={} len={}", bytesFile, data.length);

		final String s3Url = getAwsUrl(target);
		try {
			restTemplate.put(s3Url, data);
		} catch (final HttpClientErrorException failure) {
			// Log the S3 error document, then let the caller see the original exception
			log.error("Upserting file failed with error\n{}", failure.getResponseBodyAsString());
			throw failure;
		}
	}

	/**
	 * Uploads the contents of a local file to the repository path by issuing an HTTP PUT to S3.
	 *
	 * @param bytesFile repository path of the object
	 * @param fileWithData local file holding the data to upload
	 * @throws IOException when the file is {@code null} or does not exist
	 */
	@Override
	public void upsert(Path bytesFile, File fileWithData) throws IOException {
		final boolean sourceUsable = fileWithData != null && fileWithData.exists();
		if (!sourceUsable) {
			throw new IOException("File is null or does not exist.");
		}

		final Path target = bytesFile.normalize().toAbsolutePath();
		log.debug("Putting to path={} len={}", bytesFile, fileWithData.length());

		final String s3Url = getAwsUrl(target);
		try {
			final ResponseEntity<String> uploadResult = restTemplate.exchange(s3Url, HttpMethod.PUT, new HttpEntity<>(new FileSystemResource(fileWithData)), String.class);
			log.info("Upload status code: {}", uploadResult.getStatusCode());
			log.debug("Upload response: {}", uploadResult.getBody());
		} catch (final HttpClientErrorException failure) {
			// Log the S3 error document, then let the caller see the original exception
			log.error("Upserting file failed with error\n{}", failure.getResponseBodyAsString());
			throw failure;
		}
	}

	/**
	 * Deletes the object at the repository path by issuing an HTTP DELETE to S3.
	 *
	 * @param bytesFile repository path of the object
	 * @throws IOException declared by the interface
	 */
	@Override
	public void remove(final Path bytesFile) throws IOException {
		final Path target = bytesFile.normalize().toAbsolutePath();
		final String s3Url = getAwsUrl(target);

		log.debug("Deleting from path={} url={}", target, s3Url);

		try {
			restTemplate.delete(s3Url);
		} catch (final HttpClientErrorException failure) {
			// Log the S3 error document, then let the caller see the original exception
			log.error("Deleting file failed with error\n{}", failure.getResponseBodyAsString());
			throw failure;
		}
	}

	/**
	 * Fetches the object at the repository path from S3 and returns its bytes.
	 *
	 * @param bytesFile repository path of the object
	 * @return the object bytes, or {@code null} when S3 answers 404 Not Found
	 * @throws IOException declared by the interface
	 * @throws HttpClientErrorException for any 4xx response other than 404
	 */
	@Override
	public byte[] get(final Path bytesFile) throws IOException {
		final Path normalPath = bytesFile.normalize().toAbsolutePath();

		log.debug("Getting bytes path={} filename={}", normalPath.getParent(), normalPath.getFileName());
		final String url = getAwsUrl(normalPath);

		try {
			return restTemplate.getForObject(url, byte[].class);
		} catch (final HttpClientErrorException e) {
			if (e.getStatusCode() == HttpStatus.NOT_FOUND) {
				// A missing object is an expected outcome, not an error: keep it out of the error log
				log.debug("No bytes found at url={}", url);
				return null; // Match behavior of FilesystemStorageServiceImpl
			}
			log.error("Getting bytes failed with {} {} error\n{}", e.getStatusCode(), e.getStatusText(), e.getResponseBodyAsString());
			throw e;
		}
	}

	/**
	 * Fetches the object at the repository path from S3 and hands the response body stream
	 * to the consumer. The stream is closed once the consumer returns.
	 *
	 * @param bytesFile repository path of the object
	 * @param consumerOfStream receives the response body stream
	 * @throws IOException declared by the interface
	 */
	@Override
	public void get(Path bytesFile, Consumer<InputStream> consumerOfStream) throws IOException {
		final Path target = bytesFile.normalize().toAbsolutePath();

		if (log.isDebugEnabled()) {
			log.debug("Getting bytes path={} filename={}", target.getParent(), target.getFileName());
		}
		final String s3Url = getAwsUrl(target);

		try {
			restTemplate.execute(s3Url, HttpMethod.GET, null, s3Response -> {
				try (InputStream content = s3Response.getBody()) {
					consumerOfStream.accept(content);
				}
				return null;
			});
		} catch (final HttpClientErrorException failure) {
			// Log the S3 error document, then let the caller see the original exception
			log.error("Getting bytes failed with error\n{}", failure.getResponseBodyAsString());
			throw failure;
		}
	}

	/**
	 * Builds the HTTPS URL of an S3 resource from the bucket host and the prefixed path.
	 *
	 * @param bytesFile the repository path of the resource
	 * @return the url
	 */
	private String getAwsUrl(final Path bytesFile) {
		final String resourceUrl = "https://" + getHost() + getAwsPath(bytesFile);
		log.trace("getUrl path={} result={}", bytesFile, resourceUrl);
		return resourceUrl;
	}

	/**
	 * Places the repository path under the configured base path and returns the normalized,
	 * absolute result as a string.
	 *
	 * @param path the repository path
	 * @return the path of the resource including the base path
	 */
	private String getAwsPath(final Path path) {
		final Path underBase = Paths.get(awsBasePath.toString(), path.toString());
		final Path absolute = underBase.normalize().toAbsolutePath();
		return absolute.toString();
	}

	/**
	 * Builds the host name of the S3 endpoint from the configured bucket and region.
	 *
	 * @return the host in the form {@code <bucket>.s3-<region>.amazonaws.com}
	 */
	private String getHost() {
		return bucket + ".s3-" + region + ".amazonaws.com";
	}

	/**
	 * Builds the canonical request of AWS Signature Version 4 as specified at
	 * http://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
	 *
	 * <pre>
	 * HTTPMethod \n CanonicalURI \n CanonicalQueryString \n CanonicalHeaders \n SignedHeaders \n HashedPayload
	 * </pre>
	 *
	 * As a side effect the {@code X-Amz-Content-SHA256} header of the request is set to the
	 * hex-encoded SHA-256 hash of the body, so that it takes part in the signature.
	 *
	 * @param request the outgoing request; its headers are read and the content hash header is set
	 * @param body the request body, {@code null} is hashed as an empty body
	 * @return the canonical request
	 * @throws NoSuchAlgorithmException when SHA-256 is not available
	 */
	private String buildCanonicalRequest(final HttpRequest request, final byte[] body) throws NoSuchAlgorithmException {
		final StringBuilder sb = new StringBuilder();

		// Content hash, must also be sent as header
		final String hashedPayload = printHex(hashSha256(body == null ? new byte[0] : body));
		request.getHeaders().set(AMZ_CONTENT_SHA256, hashedPayload);

		// HTTP-Verb
		sb.append(request.getMethod()).append(LINE_SEPARATOR);

		// CanonicalURI: every segment of the decoded path is URI-encoded once (RFC 3986),
		// slashes are kept. URLEncoder produces form encoding, so its output is adjusted.
		final String decodedPath = StringUtils.defaultIfEmpty(request.getURI().getPath(), "/");
		final String canonicalUri = Arrays.stream(decodedPath.split("/", -1))
			.map(segment -> URLEncoder.encode(segment, CHARSET_UTF8).replace("+", "%20").replace("*", "%2A").replace("%7E", "~"))
			.collect(Collectors.joining("/"));
		sb.append(canonicalUri).append(LINE_SEPARATOR);

		// CanonicalQueryString
		sb.append(buildQueryString(StringUtils.defaultIfBlank(request.getURI().getQuery(), ""))).append(LINE_SEPARATOR);

		// CanonicalHeaders: sorted lowercase names, skipping headers without values
		request.getHeaders().keySet().stream().map(String::toLowerCase).sorted()
			.filter(headerName -> !request.getHeaders().getValuesAsList(headerName).isEmpty())
			.forEach(headerName -> {
				// Values are trimmed, inner whitespace is collapsed, multiple values are joined by comma
				final String headerValue = request.getHeaders().get(headerName).stream()
					.map(value -> value == null ? "" : value.trim().replaceAll("\\s+", " "))
					.collect(Collectors.joining(","));
				sb.append(headerName).append(':').append(headerValue).append(LINE_SEPARATOR);
			});
		sb.append(LINE_SEPARATOR);

		// SignedHeaders
		sb.append(request.getHeaders().keySet().stream().map(String::toLowerCase).sorted().collect(Collectors.joining(";")));
		sb.append(LINE_SEPARATOR);

		// HashedPayload is the hexadecimal value of the SHA256 hash of the request payload
		sb.append(hashedPayload);

		log.trace("canonicalRequest\n{}", sb);
		return sb.toString();
	}

	/**
	 * Builds the canonical query string of AWS Signature Version 4: every parameter name and
	 * value is URI-encoded (UTF-8, RFC 3986), parameters are sorted by encoded name and then
	 * by encoded value, and joined with {@code &}. A parameter without a value is emitted as
	 * {@code name=}.
	 *
	 * @param query the decoded S3 query string, may be {@code null} or empty
	 * @return the canonical query string, empty when there are no parameters
	 */
	public static String buildQueryString(String query) {
		log.trace("Encoding query string: {}", query);
		if (query == null || query.isEmpty()) {
			return "";
		}
		return Arrays.stream(query.split("&"))
			// ignore empty parts such as in "a=1&&b=2"
			.filter(part -> !part.isEmpty())
			// split into name and optional value
			.map(part -> part.split("=", 2))
			// encode parts, missing value becomes empty string
			.map(pair -> new String[] { awsUriEncode(pair[0]), pair.length == 1 ? "" : awsUriEncode(pair[1]) })
			// must be sorted by name, then by value
			.sorted((left, right) -> {
				final int byName = left[0].compareTo(right[0]);
				return byName != 0 ? byName : left[1].compareTo(right[1]);
			})
			.map(pair -> pair[0] + '=' + pair[1])
			// debug
			.peek(part -> log.trace("Querystring part: {}", part))
			// merge. Do not &amp; the ampersands!
			.collect(Collectors.joining("&"));
	}

	/**
	 * URI-encodes a value the way AWS expects. {@link URLEncoder} produces form encoding, which
	 * differs in three places: space must be {@code %20}, {@code *} must be encoded and
	 * {@code ~} must stay as is.
	 *
	 * @param value the decoded value
	 * @return the URI-encoded value
	 */
	private static String awsUriEncode(final String value) {
		return URLEncoder.encode(value, StandardCharsets.UTF_8).replace("+", "%20").replace("*", "%2A").replace("%7E", "~");
	}

	/**
	 * Computes the SHA-256 digest of the given bytes.
	 *
	 * @param bytes the input to hash
	 * @return the 32-byte digest
	 * @throws NoSuchAlgorithmException when SHA-256 is not available
	 */
	public static byte[] hashSha256(final byte[] bytes) throws NoSuchAlgorithmException {
		return MessageDigest.getInstance("SHA-256").digest(bytes);
	}

	/**
	 * Formats bytes as a lowercase hexadecimal string, two characters per byte.
	 *
	 * Implemented with the standard library only, so it neither depends on the JAXB API
	 * (not part of the JDK since Java 11) nor on the default locale.
	 *
	 * @param bytes the bytes to format
	 * @return the lowercase hex string, empty for an empty array
	 */
	public static String printHex(final byte[] bytes) {
		final char[] hexDigits = "0123456789abcdef".toCharArray();
		final char[] out = new char[bytes.length * 2];
		for (int i = 0; i < bytes.length; i++) {
			// mask to treat the byte as unsigned
			final int unsigned = bytes[i] & 0xFF;
			out[2 * i] = hexDigits[unsigned >>> 4];
			out[2 * i + 1] = hexDigits[unsigned & 0x0F];
		}
		return new String(out);
	}

	/**
	 * Builds the string to sign of AWS Signature Version 4:
	 *
	 * <pre>
	 * "AWS4-HMAC-SHA256" \n timestamp \n scope \n Hex(SHA256Hash(canonicalRequest))
	 * </pre>
	 *
	 * @param canonicalRequest the canonical request
	 * @param date the request date
	 * @param region the AWS region used in the scope
	 * @param awsService the AWS service name used in the scope
	 * @return the string to sign
	 * @throws NoSuchAlgorithmException when SHA-256 is not available
	 */
	private static String buildStringToSign(final String canonicalRequest, final Date date, final String region, final String awsService) throws NoSuchAlgorithmException {
		final String timestamp = HEADER_DATE_FORMAT.get().format(date);

		// Scope looks like 20130606/us-east-1/s3/aws4_request
		final String scope = YYYYMMDD.get().format(date) + '/' + region + '/' + awsService + "/aws4_request";

		final String hashedRequest = printHex(hashSha256(canonicalRequest.getBytes(CHARSET_UTF8)));

		final String stringToSign = String.join(LINE_SEPARATOR, "AWS4-HMAC-SHA256", timestamp, scope, hashedRequest);
		log.trace("stringToSign\n{}", stringToSign);
		return stringToSign;
	}

	/**
	 * Derives the AWS Signature Version 4 signing key by chaining four HMAC-SHA256 operations
	 * over the date, the region, the service and the literal "aws4_request".
	 *
	 * @param secretKey the AWS secret access key
	 * @param date the request date as yyyyMMdd
	 * @param region the AWS region
	 * @param service the AWS service name
	 * @return the signing key
	 * @throws InvalidKeyException when a key is rejected by the MAC
	 * @throws NoSuchAlgorithmException when HmacSHA256 is not available
	 */
	private static byte[] calculateSigningKey(final String secretKey, final String date, final String region, final String service) throws InvalidKeyException,
			NoSuchAlgorithmException {
		log.trace("sign date={} region={} service={}", date, region, service);

		// DateKey = HMAC-SHA256("AWS4"+"<SecretAccessKey>", "<YYYYMMDD>")
		final byte[] dateKey = hmacSha256(("AWS4" + secretKey).getBytes(CHARSET_UTF8), date);
		// DateRegionKey = HMAC-SHA256(<DateKey>, "<aws-region>")
		final byte[] dateRegionKey = hmacSha256(dateKey, region);
		// DateRegionServiceKey = HMAC-SHA256(<DateRegionKey>, "<aws-service>")
		final byte[] dateRegionServiceKey = hmacSha256(dateRegionKey, service);
		// SigningKey = HMAC-SHA256(<DateRegionServiceKey>, "aws4_request")
		return hmacSha256(dateRegionServiceKey, "aws4_request");
	}

	/**
	 * Computes the HMAC-SHA256 of a string, using its UTF-8 bytes.
	 *
	 * @param key the MAC key
	 * @param data the text to authenticate
	 * @return the MAC
	 * @throws InvalidKeyException when the key is rejected by the MAC
	 * @throws NoSuchAlgorithmException when HmacSHA256 is not available
	 */
	private static byte[] hmacSha256(final byte[] key, final String data) throws InvalidKeyException, NoSuchAlgorithmException {
		final byte[] dataBytes = data.getBytes(CHARSET_UTF8);
		return hmacSha256(key, dataBytes);
	}

	/**
	 * Computes the HMAC-SHA256 of the given bytes.
	 *
	 * @param key the MAC key
	 * @param data the bytes to authenticate
	 * @return the MAC
	 * @throws NoSuchAlgorithmException when HmacSHA256 is not available
	 * @throws InvalidKeyException when the key is rejected by the MAC
	 */
	private static byte[] hmacSha256(final byte[] key, final byte[] data) throws NoSuchAlgorithmException, InvalidKeyException {
		final Mac hmac = Mac.getInstance(AWS_SIGN_ALG);
		final SecretKeySpec keySpec = new SecretKeySpec(key, AWS_SIGN_ALG);
		hmac.init(keySpec);
		return hmac.doFinal(data);
	}

	/**
	 * Assembles the value of the AWS Signature Version 4 Authorization header from the
	 * credential, the sorted lowercase names of the request headers and the signature.
	 *
	 * http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
	 * http://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
	 *
	 * @param finalSignature the signature
	 * @param request the request whose header names are listed as signed headers
	 * @param date the request date used in the credential
	 * @return the authorization header
	 */
	private String getAuthorizationHeader(final byte[] finalSignature, final HttpRequest request, final Date date) {
		// credential
		final String credential = getAWSCredential(date);

		// signed headers
		final String signedHeaders = request.getHeaders().keySet().stream().map(String::toLowerCase).sorted().collect(Collectors.joining(";"));

		// request signature
		final String signature = printHex(finalSignature);

		final String authorization = "AWS4-HMAC-SHA256" + " Credential=" + credential + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature;

		log.trace("authorizationHeader=\n{}", authorization);
		return authorization;
	}

	/**
	 * Builds the credential scope in the form
	 * {@code <accessKey>/<yyyyMMdd>/<region>/s3/aws4_request}.
	 *
	 * @param date the request date
	 * @return the credential
	 */
	private String getAWSCredential(final Date date) {
		final String day = YYYYMMDD.get().format(date);
		return accessKey + "/" + day + "/" + region + "/" + "s3" + "/aws4_request";
	}

	/**
	 * Creates the RestTemplate used for all S3 calls. XML converters get the JAXB annotation
	 * module registered, and a single interceptor signs every request with the AWS V4
	 * signature method before it is executed.
	 *
	 * @return the rest template
	 */
	private RestTemplate initializeRestTemplate() {
		final RestTemplate template = new RestTemplate();

		// Let the Jackson XML converters understand JAXB annotations
		final JaxbAnnotationModule jaxbModule = new JaxbAnnotationModule();
		for (final var converter : template.getMessageConverters()) {
			if (converter instanceof MappingJackson2XmlHttpMessageConverter) {
				((MappingJackson2XmlHttpMessageConverter) converter).getObjectMapper().registerModule(jaxbModule);
			}
		}

		final ClientHttpRequestInterceptor signingInterceptor = (request, body, execution) -> {
			final Date now = new Date();
			final HttpHeaders requestHeaders = request.getHeaders();

			requestHeaders.set("Host", getHost());
			// This avoids date formatting problems
			requestHeaders.add(AMZ_DATE, HEADER_DATE_FORMAT.get().format(now));

			// DELETE has no Content-length
			final HttpMethod method = request.getMethod();
			final boolean carriesBody = method == HttpMethod.POST || method == HttpMethod.PUT;
			if (!carriesBody) {
				requestHeaders.remove(HttpHeaders.CONTENT_LENGTH);
			}

			try {
				final String canonicalRequest = buildCanonicalRequest(request, body);
				final String stringToSign = buildStringToSign(canonicalRequest, now, region, "s3");
				final byte[] signingKey = calculateSigningKey(secretKey, YYYYMMDD.get().format(now), region, "s3");
				final byte[] signature = hmacSha256(signingKey, stringToSign);

				request.getHeaders().set(HTTP_AUTHORIZATION, getAuthorizationHeader(signature, request, now));
			} catch (NoSuchAlgorithmException | InvalidKeyException e) {
				// The request is still executed, without the Authorization header
				log.error("Could not sign AWS request.", e);
			}

			final ClientHttpResponse s3Response = execution.execute(request, body);

			if (s3Response.getStatusCode() != HttpStatus.OK) {
				log.trace("S3 HTTP {} {} status={} {}", request.getMethod(), request.getURI(), s3Response.getRawStatusCode(), s3Response.getStatusText());
			}

			return s3Response;
		};

		final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();
		interceptors.add(signingInterceptor);
		template.setInterceptors(interceptors);

		return template;
	}

	/**
	 * Tests whether an object exists at the repository path by issuing an HTTP HEAD to S3.
	 *
	 * http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectHEAD.html
	 *
	 * @param bytesFile the bytes file
	 * @return {@code true} when S3 answers the HEAD request successfully, {@code false} when
	 * it answers 404 Not Found
	 * @throws IOException when other stuff is bad
	 * @throws InvalidRepositoryPathException when path or filename are weird
	 * @throws HttpClientErrorException for any 4xx response other than 404
	 */
	@Override
	public boolean exists(final Path bytesFile) throws IOException, InvalidRepositoryPathException {
		final Path normalPath = bytesFile.normalize().toAbsolutePath();

		try {
			// Build the URL once and use it for both logging and the request
			final String url = getAwsUrl(normalPath);
			log.trace("Fetching HEAD for url={}", url);

			final HttpHeaders headers = restTemplate.headForHeaders(url);
			if (log.isDebugEnabled()) {
				headers.forEach((header, values) -> log.debug("{}: {}", header, values));
			}
			return true;

		} catch (final HttpClientErrorException e) {
			if (e.getStatusCode() == HttpStatus.NOT_FOUND) {
				// A missing object is the expected negative answer
				return false;
			}
			log.error("Testing for file failed with error\n{}", e.getResponseBodyAsString());
			throw e;
		} catch (final RuntimeException e) {
			// Other failures (server errors, connectivity) are logged with context and propagated
			log.warn("Unexpected failure testing for file at path={}", normalPath, e);
			throw e;
		}
	}

	/**
	 * List bucket contents as per
	 * http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGET.html
	 *
	 * The listing is requested with "/" as delimiter and the prefixed path followed by "/" as
	 * prefix. Returned names are the object keys with the leading prefixed path removed.
	 *
	 * @param path the repository path
	 * @return list of filenames at specified path, empty when S3 returns no contents
	 * @throws InvalidRepositoryPathException when path is messed up
	 * @throws HttpClientErrorException when S3 answers with a 4xx response
	 */
	@Override
	public List<String> listFiles(final Path path) throws InvalidRepositoryPathException {

		PathValidator.checkValidPath(path);

		// S3 keys do not start with a slash
		final String s3prefix = getAwsPath(path).substring(1);
		log.debug("Listing S3 bucket for host={} path={} prefix={}", getHost(), path, s3prefix);

		try {
			final ListBucketResult listBucketResult = restTemplate.getForObject("https://" + getHost() + "/?list-type=2&delimiter=/&prefix={path}/", ListBucketResult.class, s3prefix);

			// Check for an empty response before touching the result, also in the debug output below
			if (listBucketResult == null) {
				return Collections.emptyList();
			}

			if (log.isDebugEnabled()) {
				log.debug("Bucket name={} maxKeys={} delimiter={} prefix={}", listBucketResult.getName(), listBucketResult.getMaxKeys(), listBucketResult.getDelimiter(), listBucketResult
					.getPrefix());

				if (listBucketResult.getCommonPrefixes() != null) {
					listBucketResult.getCommonPrefixes().forEach(commonPrefix -> {
						log.debug("Subprefix={}", commonPrefix.getPrefix());
					});
				}

				if (listBucketResult.getContents() != null) {
					listBucketResult.getContents().forEach(content -> {
						log.debug("Object prefix={} len={} filename={}", content.getKey(), content.getSize(), content.getKey().substring(s3prefix.length()));
					});
				}
			}

			if (listBucketResult.getContents() == null) {
				return Collections.emptyList();
			}
			return listBucketResult.getContents().stream().map(content -> content.getKey().substring(s3prefix.length())).collect(Collectors.toList());
		} catch (HttpClientErrorException e) {
			log.error("Error listing files at path={}\n{}", path, e.getResponseBodyAsString());
			throw e;
		}
	}

}