AccessionProcessor.java

/*
 * Copyright 2018 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.server.service.worker;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import javax.persistence.EntityManager;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.genesys.server.model.genesys.Accession;
import org.genesys.server.model.genesys.AccessionCollect;
import org.genesys.server.model.genesys.AccessionData;
import org.genesys.server.model.genesys.QAccession;
import org.genesys.server.model.genesys.QAccessionCollect;
import org.genesys.server.model.genesys.QAccessionId;
import org.genesys.server.persistence.AccessionRepository;
import org.genesys.server.service.AccessionService;
import org.genesys.server.service.AccessionService.IAccessionBatchAction;
import org.genesys.server.service.ElasticsearchService;
import org.genesys.server.service.filter.AccessionFilter;
import org.hibernate.Hibernate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.task.TaskExecutor;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.Querydsl;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import com.querydsl.core.types.EntityPath;
import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.dsl.NumberPath;
import com.querydsl.core.types.dsl.PathBuilder;
import com.querydsl.core.types.dsl.PathBuilderFactory;
import com.querydsl.core.types.dsl.SetPath;
import com.querydsl.core.types.dsl.StringPath;
import com.querydsl.jpa.JPQLQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;

/**
 * Executes actions on filtered accessions.
 *
 * <p>
 * For database-backed processing the matching accession IDs are fetched first
 * and then handled in fixed-size batches; the {@link EntityManager} is cleared
 * after every batch. Filters flagged as full-text queries are delegated to
 * {@link ElasticsearchService}, which is an optional dependency.
 * </p>
 */
@Component
public class AccessionProcessor {

	private static final Logger LOG = LoggerFactory.getLogger(AccessionProcessor.class);

	/** Number of accessions loaded and handed to an action in one batch. */
	private static final int BATCH_SIZE = 1000;

	/** Number of accession IDs handed to an ID-only consumer in one batch. */
	private static final int ID_BATCH_SIZE = 10000;

	@Autowired
	private TaskExecutor taskExecutor;

	@Autowired
	private AccessionService accessionService;

	@Autowired
	private AccessionRepository accessionRepository;

	@Autowired
	private EntityManager em;

	@Autowired
	private JPAQueryFactory jpaQueryFactory;

	/** Optional: {@code null} when no Elasticsearch service bean is available. */
	@Autowired(required = false)
	private ElasticsearchService elasticSearchService;

	/**
	 * Handles one batch of accession IDs. Unlike {@link Consumer} it may throw
	 * checked exceptions.
	 */
	@FunctionalInterface
	private interface BatchHandler {

		void handle(List<Long> accessionIds) throws Exception;
	}

	/**
	 * Callback receiving one (owner ID, element value) pair read from an element
	 * collection.
	 *
	 * @param <T> type of the collection element
	 */
	@FunctionalInterface
	public interface IAddStuff<T> {

		void accept(Long id, T value);
	}

	/**
	 * Process all accessions matching the filter.
	 *
	 * @param filter the filter
	 * @param action the action applied to each batch of accessions
	 * @throws Exception if querying or the action fails
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void process(AccessionFilter filter, IAccessionBatchAction action) throws Exception {
		process(filter, action, null);
	}

	/**
	 * Process accessions matching the filter. Full-text filters are delegated to
	 * Elasticsearch, all others are resolved with a JPA query ordered by
	 * accession ID.
	 *
	 * @param filter the filter
	 * @param action the action applied to each batch of accessions
	 * @param maxSize maximum number of accessions to process, {@code null} for no
	 * limit
	 * @throws IllegalStateException if the filter is a full-text query and no
	 * Elasticsearch service is available
	 * @throws Exception if querying or the action fails
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void process(AccessionFilter filter, IAccessionBatchAction action, Long maxSize) throws Exception {
		if (filter.isFulltextQuery()) {
			requireElasticsearch().process(Accession.class, filter, action, maxSize);
		} else {
			process(toQuery(filter, null), action, maxSize);
		}
	}

	/**
	 * Pass IDs of accessions matching the filter to the consumer, without loading
	 * the accessions themselves.
	 *
	 * @param filter the filter
	 * @param page optional paging applied to the ID query; {@code null} or unpaged
	 * for all matches
	 * @param action consumer of ID batches
	 * @throws IllegalStateException if the filter is a full-text query and no
	 * Elasticsearch service is available
	 * @throws Exception if querying or the consumer fails
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRED)
	public void processIds(AccessionFilter filter, Pageable page, Consumer<List<Long>> action) throws Exception {
		if (filter.isFulltextQuery()) {
			requireElasticsearch().processById(Accession.class, filter, action::accept, page);
		} else {
			processInBatches(toQuery(filter, page), null, ID_BATCH_SIZE, action::accept);
		}
	}

	/**
	 * Advanced usage.
	 *
	 * @param accessionIdQuery query statement that returns ordered accessionIds
	 * @param action the action applied to each batch of loaded accessions
	 * @param maxSize maximum number of accessions to process; {@code null} or a
	 * negative value means no limit
	 * @throws Exception if querying or the action fails
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void process(JPQLQuery<Long> accessionIdQuery, IAccessionBatchAction action, Long maxSize) throws Exception {
		long elapsedMs = processInBatches(accessionIdQuery, maxSize, BATCH_SIZE, ids -> loadAndProcess(ids, action));
		LOG.info("Processing Accessions took {}ms", elapsedMs);
	}

	/**
	 * Same as {@link #process(JPQLQuery, IAccessionBatchAction, Long)}, but the
	 * element collections of the collecting data and of the accession ID record
	 * are populated with one query per collection for the whole batch before the
	 * action is applied.
	 *
	 * @param accessionIdQuery query statement that returns ordered accessionIds
	 * @param action the action applied to each batch of loaded accessions
	 * @param maxSize maximum number of accessions to process; {@code null} or a
	 * negative value means no limit
	 * @throws Exception if querying or the action fails
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void processMCPD(JPQLQuery<Long> accessionIdQuery, IAccessionBatchAction action, Long maxSize) throws Exception {
		long elapsedMs = processInBatches(accessionIdQuery, maxSize, BATCH_SIZE, ids -> loadAndProcessForMCPD(ids, action));
		LOG.info("Processing Accessions took {}ms", elapsedMs);
	}

	/**
	 * Apply action on accessions matching the provided filter.
	 *
	 * @param filter the filter
	 * @param action the action
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void apply(AccessionFilter filter, IAccessionBatchAction action) {
		apply(filter.buildPredicate(), action);
	}

	/**
	 * Apply action on accessions matching the provided predicate.
	 *
	 * @param predicate JPA query predicate on Accession
	 * @param action the action
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void apply(Predicate predicate, IAccessionBatchAction action) {
		PathBuilder<Accession> builder = new PathBuilderFactory().create(Accession.class);
		Querydsl querydsl = new Querydsl(em, builder);
		JPQLQuery<Long> query = querydsl.createQuery(QAccession.accession)
			// select id only
			.select(QAccession.accession.id)
			// order by id
			.orderBy(QAccession.accession.id.asc())
			// apply filter
			.where(predicate);

		apply(query, action);
	}

	/**
	 * Apply action on accessions selected by the ID query. Each batch of IDs is
	 * submitted to the task executor, which hands it to
	 * {@code AccessionService#processAccessions}; this method does not wait for
	 * the submitted tasks.
	 *
	 * @param accessionIdQuery query statement that returns ordered accessionIds
	 * @param action the action
	 */
	@Transactional(readOnly = true, propagation = Propagation.REQUIRES_NEW)
	public void apply(JPQLQuery<Long> accessionIdQuery, IAccessionBatchAction action) {
		try {
			long elapsedMs = processInBatches(accessionIdQuery, null, BATCH_SIZE, ids -> asyncUpdate(ids, action));
			LOG.info("Processing Accessions took {}ms", elapsedMs);
		} catch (RuntimeException e) {
			throw e;
		} catch (Exception e) {
			// asyncUpdate declares no checked exceptions, so this is not expected
			throw new IllegalStateException(e);
		}
	}

	/**
	 * Return the Elasticsearch service or fail with a descriptive error when the
	 * optional bean was not injected.
	 */
	private ElasticsearchService requireElasticsearch() {
		if (elasticSearchService == null) {
			throw new IllegalStateException("Full-text query requested, but ElasticsearchService is not available");
		}
		return elasticSearchService;
	}

	/**
	 * Build the query selecting IDs of accessions matching the filter, ordered by
	 * ID.
	 *
	 * @param filter optional filter; {@code null} selects all accessions
	 * @param page optional paging; {@code null} or unpaged for all matches
	 */
	private JPQLQuery<Long> toQuery(AccessionFilter filter, Pageable page) throws Exception {
		JPQLQuery<Long> query = jpaQueryFactory.from(QAccession.accession)
			// select id only
			.select(QAccession.accession.id)
			// order by id
			.orderBy(QAccession.accession.id.asc());

		// apply filter
		if (filter != null) {
			filter.buildJpaQuery(query, QAccession.accession);
		}

		if (page != null && !page.isUnpaged()) {
			query.offset(page.getOffset());
			query.limit(page.getPageSize());
		}

		return query;
	}

	/**
	 * Fetch all IDs from the query, hand them to the handler in consecutive
	 * batches and clear the entity manager after each batch.
	 *
	 * @param accessionIdQuery query returning accession IDs
	 * @param maxSize upper bound on the number of IDs handled; {@code null} or a
	 * negative value means no limit
	 * @param size number of IDs per batch
	 * @param handler receives each batch
	 * @return elapsed time in milliseconds, including the ID query
	 */
	private long processInBatches(JPQLQuery<Long> accessionIdQuery, Long maxSize, int size, BatchHandler handler) throws Exception {
		final StopWatch stopWatch = new StopWatch();
		stopWatch.start();

		List<Long> accessionIds = accessionIdQuery.fetch();
		// Truncate in memory so that a limit already set on the caller's query is not overridden
		if (maxSize != null && maxSize >= 0 && maxSize < accessionIds.size()) {
			accessionIds = accessionIds.subList(0, maxSize.intValue());
		}
		final int count = accessionIds.size();

		for (int pos = 0; pos < count; pos += size) {
			if (LOG.isDebugEnabled()) {
				final long elapsedMs = stopWatch.getTime();
				// IDs handled so far per elapsed second; 0 until time has measurably passed
				final double rate = elapsedMs > 0 ? pos * 1000.0 / elapsedMs : 0.0;
				LOG.debug("Reading Accessions. Stopwatch={}s {}+{} of {}. Processing at {} accessions/s", elapsedMs / 1000, pos, size, count, rate);
			}
			handler.handle(accessionIds.subList(pos, Math.min(pos + size, count)));
			// Clear anything cached in the entity manager
			em.clear();
		}

		stopWatch.stop();
		return stopWatch.getTime();
	}

	/**
	 * Load accessions with the given IDs (ordered by ID) and apply the action.
	 * Does nothing for a {@code null} or empty ID list.
	 */
	private void loadAndProcess(List<Long> accessionIds, IAccessionBatchAction action) throws Exception {
		if (CollectionUtils.isEmpty(accessionIds)) {
			return;
		}
		final QAccession qAccession = QAccession.accession;
		List<Accession> accessions = jpaQueryFactory.selectFrom(qAccession).where(qAccession.id.in(accessionIds)).orderBy(qAccession.id.asc()).fetch();
		action.apply(accessions);
	}

	/**
	 * Load accessions with the given IDs through the repository, populate their
	 * element collections in bulk and apply the action. Does nothing for a
	 * {@code null} or empty ID list.
	 */
	private void loadAndProcessForMCPD(List<Long> accessionIds, IAccessionBatchAction action) throws Exception {
		if (CollectionUtils.isEmpty(accessionIds)) {
			return;
		}

		List<Accession> accessions = accessionRepository.findAllById(accessionIds);
		if (!accessions.isEmpty()) {
			prefetchCollectingData(accessions);
			prefetchAccessionIdData(accessions);
		}

		action.apply(accessions);
	}

	/**
	 * For accessions that have collecting data: replace collCode, collName and
	 * collInstAddress with fresh sets and fill them with one query per property.
	 */
	private void prefetchCollectingData(List<Accession> accessions) {
		var withCollect = accessions.stream()
			.filter(accession -> Objects.nonNull(accession.getAccessionId().getColl()))
			.collect(Collectors.toMap(AccessionData::getId, accession -> accession));

		if (withCollect.isEmpty()) {
			return;
		}

		withCollect.values().forEach(accession -> {
			// Work on the real entity instead of a Hibernate proxy
			var collect = Hibernate.unproxy(accession.getAccessionId().getColl(), AccessionCollect.class);
			collect.setCollCode(new HashSet<>());
			collect.setCollName(new HashSet<>());
			collect.setCollInstAddress(new HashSet<>());
			accession.getAccessionId().setColl(collect);
		});

		final QAccessionCollect qCollect = QAccessionCollect.accessionCollect;
		final NumberPath<Long> ownerId = qCollect.accession().id;
		var ids = withCollect.keySet();

		fetchAndAddStrings(qCollect, ownerId, qCollect.collCode, ids,
			(id, value) -> withCollect.get(id).getAccessionId().getColl().getCollCode().add(value));
		fetchAndAddStrings(qCollect, ownerId, qCollect.collName, ids,
			(id, value) -> withCollect.get(id).getAccessionId().getColl().getCollName().add(value));
		fetchAndAddStrings(qCollect, ownerId, qCollect.collInstAddress, ids,
			(id, value) -> withCollect.get(id).getAccessionId().getColl().getCollInstAddress().add(value));
	}

	/**
	 * Replace breederCode, breederName, duplSite and storage of every accession
	 * with fresh sets and fill them with one query per property.
	 */
	private void prefetchAccessionIdData(List<Accession> accessions) {
		var byId = accessions.stream().collect(Collectors.toMap(AccessionData::getId, accession -> accession));

		byId.values().forEach(accession -> {
			var accessionId = accession.getAccessionId();
			accessionId.setBreederCode(new HashSet<>());
			accessionId.setBreederName(new HashSet<>());
			accessionId.setDuplSite(new HashSet<>());
			accessionId.setStorage(new HashSet<>());
		});

		final QAccessionId qAccessionId = QAccessionId.accessionId;
		var ids = byId.keySet();

		fetchAndAddStrings(qAccessionId, qAccessionId.id, qAccessionId.breederCode, ids,
			(id, value) -> byId.get(id).getAccessionId().getBreederCode().add(value));
		fetchAndAddStrings(qAccessionId, qAccessionId.id, qAccessionId.breederName, ids,
			(id, value) -> byId.get(id).getAccessionId().getBreederName().add(value));
		fetchAndAddStrings(qAccessionId, qAccessionId.id, qAccessionId.duplSite, ids,
			(id, value) -> byId.get(id).getAccessionId().getDuplSite().add(value));
		fetchAndAddNumbers(qAccessionId, qAccessionId.id, qAccessionId.storage, ids,
			(id, value) -> byId.get(id).getAccessionId().getStorage().add(value));
	}

	/**
	 * Select (ID, element) pairs of a String element collection for the given
	 * identifiers and pass every pair to the callback.
	 */
	private void fetchAndAddStrings(EntityPath<?> entityPath, NumberPath<Long> entityPathId, SetPath<String, StringPath> setPath, Collection<Long> identifiers, IAddStuff<String> addToCollection) {
		final StringPath element = setPath.any();
		var rows = jpaQueryFactory.from(entityPath)
			.select(entityPathId, element).where(entityPathId.in(identifiers)).fetch();

		for (var row : rows) {
			addToCollection.accept(row.get(entityPathId), row.get(element));
		}
	}

	/**
	 * Select (ID, element) pairs of a numeric element collection for the given
	 * identifiers and pass every pair to the callback.
	 */
	private <T extends Number & Comparable<?>> void fetchAndAddNumbers(EntityPath<?> entityPath, NumberPath<Long> entityPathId, SetPath<T, NumberPath<T>> setPath, Collection<Long> identifiers, IAddStuff<T> addToCollection) {
		final NumberPath<T> element = setPath.any();
		var rows = jpaQueryFactory.from(entityPath)
			.select(entityPathId, element).where(entityPathId.in(identifiers)).fetch();

		for (var row : rows) {
			addToCollection.accept(row.get(entityPathId), row.get(element));
		}
	}

	/**
	 * Submit processing of the given accession IDs to the task executor. Does
	 * nothing for an empty list.
	 */
	private void asyncUpdate(List<Long> accessionIds, IAccessionBatchAction action) {
		if (accessionIds.isEmpty()) {
			return;
		}

		// The argument is a subList view of the full ID list; give the task its own copy
		final ArrayList<Long> copy = new ArrayList<>(accessionIds);

		taskExecutor.execute(() -> {
			LOG.trace("Executing action on {} Accessions.", copy.size());
			accessionService.processAccessions(copy, action);
			LOG.trace("Done executing action on {} accessions.", copy.size());
		});
	}

}