// AccessionRepositoryCustomImpl.java

/*
 * Copyright 2017 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 **/

package org.genesys.server.persistence;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Join;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

import org.apache.commons.lang3.time.StopWatch;
import org.genesys.server.exception.InvalidApiUsageException;
import org.genesys.server.model.genesys.Accession;
import org.genesys.server.model.genesys.AccessionData;
import org.genesys.server.model.genesys.AccessionHistoric;
import org.genesys.server.model.genesys.AccessionId;
import org.genesys.server.model.genesys.QAccession;
import org.genesys.server.model.genesys.QTaxonomy2;
import org.genesys.server.model.genesys.Taxonomy2;
import org.genesys.server.model.impl.AccessionIdentifier3;
import org.genesys.server.model.impl.FaoInstitute;
import org.genesys.server.service.filter.AccessionFilter;
import org.hsqldb.lib.HashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.support.Querydsl;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;

import com.querydsl.core.types.dsl.PathBuilder;
import com.querydsl.core.types.dsl.PathBuilderFactory;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.JPQLQuery;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.impl.JPAQueryFactory;

/**
 * Custom {@link Accession} lookups implemented with the JPA Criteria API and Querydsl.
 *
 * <p>
 * All methods run in a read-only transaction (see the class-level {@link Transactional}).
 * </p>
 */
@Repository
@Transactional(readOnly = true)
public class AccessionRepositoryCustomImpl implements AccessionRepositoryCustom, InitializingBean {
	public static final Logger LOG = LoggerFactory.getLogger(AccessionRepositoryCustomImpl.class);
	private static final Predicate[] EMPTY_PREDICATE_ARRAY = new Predicate[] {};

	/** Number of identifiers from the input list that are matched per round of queries. */
	private static final int CHUNK_SIZE = 1000;

	@Autowired
	private JPAQueryFactory jpaQueryFactory;

	@PersistenceContext
	private EntityManager em;

	/** Obtained from the entity manager in {@link #afterPropertiesSet()}. */
	private CriteriaBuilder criteriaBuilder;

	PathBuilder<Accession> pathBuilder = new PathBuilderFactory().create(Accession.class);

	/**
	 * Caches the {@link CriteriaBuilder} of the injected entity manager.
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		this.criteriaBuilder = em.getCriteriaBuilder();
	}

	/**
	 * Loads the accessions referenced by the given identifiers.
	 *
	 * <p>
	 * The input is processed in chunks of {@value #CHUNK_SIZE} entries. For every chunk the matching accession
	 * IDs are collected by UUID, by DOI, by data provider ID and by the (institute code, accession number
	 * [, genus]) combination, and the accessions with those IDs are then loaded in one query.
	 * </p>
	 *
	 * @param useGenus whether genus is part of the (institute code, accession number) match
	 * @param forUpdate identifiers to resolve; {@code null} entries are ignored
	 * @return the matched accessions, or an empty list when {@code forUpdate} is {@code null} or empty
	 */
	@Override
	public List<Accession> find(boolean useGenus, List<? extends AccessionIdentifier3> forUpdate) {
		if (forUpdate == null || forUpdate.isEmpty()) {
			return Collections.emptyList();
		}

		List<Accession> res = new ArrayList<>(forUpdate.size());
		StopWatch stopWatch = StopWatch.createStarted();
		for (int fromIndex = 0; fromIndex < forUpdate.size(); fromIndex += CHUNK_SIZE) {
			// Timings logged below are per chunk
			stopWatch.reset();
			stopWatch.start();
			List<? extends AccessionIdentifier3> sublist = forUpdate.subList(fromIndex, Math.min(forUpdate.size(), fromIndex + CHUNK_SIZE))
					// Ignore NULL entries
					.stream().filter(entry -> entry != null).collect(Collectors.toList());

			var uniqueUuids = sublist.stream().map(AccessionIdentifier3::getUuid).filter(uuid -> uuid != null).collect(Collectors.toSet());
			var uniqueDois = sublist.stream().map(AccessionIdentifier3::getDoi).filter(doi -> doi != null).collect(Collectors.toSet());
			var uniqueDataProviderIds = sublist.stream().map(AccessionIdentifier3::getDataProviderId).filter(dpId -> dpId != null).collect(Collectors.toSet());

			LOG.debug("Ready to match {} DOIs and {} overall to Accession in {}ms", uniqueDois.size(), sublist.size(), stopWatch.getTime());

			// Set of accession IDs for this batch, collected by executing up to four lookup queries
			Set<Long> accessionIds = new HashSet<>(CHUNK_SIZE);

			if (!uniqueUuids.isEmpty()) {
				var uuidQ = criteriaBuilder.createQuery(Long.class);
				var accessionId = uuidQ.from(AccessionId.class);
				uuidQ.select(accessionId.get("id"));
				uuidQ.where(accessionId.get("uuid").in(uniqueUuids));
				accessionIds.addAll(em.createQuery(uuidQ).getResultList());
				LOG.trace("*** {} uuids={}", uniqueUuids.size(), uniqueUuids);
			}

			// One criteria query is reused for the remaining lookups: each where(..) call replaces the previous restriction
			var accessionQ = criteriaBuilder.createQuery(Long.class);
			var accession = accessionQ.from(Accession.class);
			accessionQ.select(accession.get("id"));

			if (!uniqueDois.isEmpty()) {
				accessionQ.where(accession.get("doi").in(uniqueDois));
				accessionIds.addAll(em.createQuery(accessionQ).getResultList());
				LOG.trace("*** {} dois={}", uniqueDois.size(), uniqueDois);
			}

			if (!uniqueDataProviderIds.isEmpty()) {
				accessionQ.where(accession.get("dataProviderId").in(uniqueDataProviderIds));
				accessionIds.addAll(em.createQuery(accessionQ).getResultList());
				LOG.trace("*** {} dataProviderIds={}", uniqueDataProviderIds.size(), uniqueDataProviderIds);
			}

			// A chunk that contained only NULL entries has nothing to match: do not run an empty OR query
			if (!sublist.isEmpty()) {
				accessionQ.where(matchesAnyIdentifier(accession, useGenus, sublist));
				accessionIds.addAll(em.createQuery(accessionQ).getResultList());
			}

			LOG.info("Matched {} things to {} accession IDs in {}ms", sublist.size(), accessionIds.size(), stopWatch.getTime());

			// Nothing matched in this chunk: do not issue a query with an empty IN list
			if (!accessionIds.isEmpty()) {
				List<Accession> accessions = jpaQueryFactory.selectFrom(QAccession.accession).where(QAccession.accession.id.in(accessionIds))
					// This order by makes sure that we will first get a match by doi, then dataProviderID, etc.
					.orderBy(QAccession.accession.doi.desc(), QAccession.accession.dataProviderId.desc(), QAccession.accession.instituteCode.desc(), QAccession.accession.genus.desc(), QAccession.accession.accessionNumber.desc())
					.fetch();
				res.addAll(accessions);
			}

			LOG.info("Now have Accession size={} in {}ms", res.size(), stopWatch.getTime());
		}

		LOG.trace("*** Loaded accessions {} of {}", res.size(), forUpdate.size());

		LOG.info("Done matching! {} refs to {} Accessions", forUpdate.size(), res.size());
		return res;
	}

	/**
	 * Builds the disjunction {@code (instCode=? and acceNumb=? [and genus=?]) or ...} with one term per identifier.
	 *
	 * @param accession query root the predicate is built against
	 * @param useGenus whether every term also compares the genus
	 * @param identifiers non-null identifiers to match
	 * @return predicate that is satisfied by an accession matching any of the identifiers
	 */
	private Predicate matchesAnyIdentifier(Root<Accession> accession, boolean useGenus, List<? extends AccessionIdentifier3> identifiers) {
		Path<String> theInstCode = accession.get("instituteCode");
		Path<String> theAcceNumb = accession.get("accessionNumber");
		Path<String> theGenus = accession.get("genus");

		List<Predicate> alternatives = new ArrayList<>(identifiers.size());
		for (AccessionIdentifier3 identifier : identifiers) {
			Predicate sameInstitute = criteriaBuilder.equal(theInstCode, identifier.getHoldingInstitute());
			Predicate sameNumber = criteriaBuilder.equal(theAcceNumb, identifier.getAccessionNumber());
			if (useGenus) {
				alternatives.add(criteriaBuilder.and(sameInstitute, sameNumber, criteriaBuilder.equal(theGenus, identifier.getGenus())));
			} else {
				alternatives.add(criteriaBuilder.and(sameInstitute, sameNumber));
			}
		}
		return criteriaBuilder.or(alternatives.toArray(EMPTY_PREDICATE_ARRAY));
	}

	/**
	 * Finds the single accession of an institute.
	 *
	 * <p>
	 * The accession is identified by the DOI when one is given; otherwise by accession number alone when the
	 * institute reports unique accession numbers; otherwise by accession number and taxonomy genus.
	 * </p>
	 *
	 * @param institute the holding institute; always part of the restriction
	 * @param doi accession DOI, takes precedence when not {@code null}
	 * @param acceNumb accession number
	 * @param genus genus, only used when neither the DOI nor the accession number alone is used
	 * @return the matching accession or {@code null} when there is none
	 * @throws IncorrectResultSizeDataAccessException when more than one accession matches
	 */
	@Override
	public Accession findOne(FaoInstitute institute, String doi, String acceNumb, String genus) {
		CriteriaQuery<Accession> cq = criteriaBuilder.createQuery(Accession.class);
		Root<Accession> root = cq.from(Accession.class);
		cq.distinct(true);
		cq.select(root);
		// root.fetch("stoRage", JoinType.LEFT);
		// Inner join: the taxonomy is joined even when the genus is not part of the restriction
		Join<Object, Object> tax = root.join("taxonomy");

		final Predicate identity;
		if (doi != null) {
			identity = criteriaBuilder.equal(root.get("doi"), doi);
		} else if (institute != null && institute.hasUniqueAcceNumbs()) {
			identity = criteriaBuilder.equal(root.get("accessionNumber"), acceNumb);
		} else {
			identity = criteriaBuilder.and(criteriaBuilder.equal(root.get("accessionNumber"), acceNumb), criteriaBuilder.equal(tax.get("genus"), genus));
		}

		cq.where(criteriaBuilder.and(criteriaBuilder.equal(root.get("institute"), institute), identity));

		List<Accession> result = em.createQuery(cq).getResultList();

		if (result.isEmpty()) {
			return null;
		}

		if (result.size() > 1) {
			throw new IncorrectResultSizeDataAccessException(1, result.size());
		}

		return result.get(0);
	}

	/**
	 * Loads active ({@link Accession}) and historic ({@link AccessionHistoric}) records by accession UUID.
	 *
	 * @param accessionUuids UUIDs to look up
	 * @return active records followed by historic records; an empty list when {@code accessionUuids} is
	 *         {@code null} or empty
	 */
	@Override
	public List<AccessionData> findActiveAndHistoric(Collection<UUID> accessionUuids) {
		// Same contract as find(boolean, List): no identifiers, no results
		if (accessionUuids == null || accessionUuids.isEmpty()) {
			return Collections.emptyList();
		}

		List<AccessionData> activeAndHistoric = new ArrayList<>(accessionUuids.size());

		CriteriaQuery<Accession> cq = criteriaBuilder.createQuery(Accession.class);
		Root<Accession> active = cq.from(Accession.class);
		cq.distinct(true);
		cq.select(active);
		cq.where(active.get("accessionId").get("uuid").in(accessionUuids));

		activeAndHistoric.addAll(em.createQuery(cq).getResultList());

		CriteriaQuery<AccessionHistoric> cqhist = criteriaBuilder.createQuery(AccessionHistoric.class);
		Root<AccessionHistoric> historic = cqhist.from(AccessionHistoric.class);
		cqhist.distinct(true);
		cqhist.select(historic);
		cqhist.where(historic.get("accessionId").get("uuid").in(accessionUuids));

		activeAndHistoric.addAll(em.createQuery(cqhist).getResultList());

		return activeAndHistoric;
	}

	/**
	 * Loads accessions by accession UUID.
	 *
	 * @param accessionUuids UUIDs to look up
	 * @return distinct matching accessions; an empty list when {@code accessionUuids} is {@code null} or empty
	 */
	@Override
	public List<Accession> find(Collection<UUID> accessionUuids) {
		// Same contract as find(boolean, List): no identifiers, no results
		if (accessionUuids == null || accessionUuids.isEmpty()) {
			return Collections.emptyList();
		}
		JPAQuery<Accession> query = jpaQueryFactory.selectFrom(QAccession.accession).distinct().where(QAccession.accession.accessionId().uuid.in(accessionUuids));
		return query.fetch();
	}

	/**
	 * Counts the rows of the accession query configured by {@code filter.buildJpaQuery(..)}.
	 *
	 * @param filter the accession filter
	 * @return number of matching rows
	 */
	@Override
	public long count(AccessionFilter filter) {
		Querydsl querydsl = new Querydsl(em, pathBuilder);
		QAccession qAccession = QAccession.accession;
		JPQLQuery<Object> query = querydsl.createQuery(qAccession);

		filter.buildJpaQuery(query, qAccession);

		return query.fetchCount();
	}

	/**
	 * Lists one page of accessions of the accession query configured by {@code filter.buildJpaQuery(..)}.
	 *
	 * @param filter the accession filter; must not be a full-text query
	 * @param page paging and sorting to apply
	 * @return accessions of the requested page
	 * @throws InvalidApiUsageException when the filter is a full-text query
	 */
	@Override
	public List<Accession> findAll(AccessionFilter filter, Pageable page) {
		if (filter.isFulltextQuery()) {
			throw new InvalidApiUsageException("Use elasticService with full-text queries!");
		}

		Querydsl querydsl = new Querydsl(em, pathBuilder);
		QAccession qAccession = QAccession.accession;
		JPQLQuery<Object> query = querydsl.createQuery(qAccession);

		filter.buildJpaQuery(query, qAccession);

		if (page.isPaged()) {
			// applyPagination also applies page.getSort(); calling applySorting again would repeat every ORDER BY term
			querydsl.applyPagination(page, query);
		} else {
			querydsl.applySorting(page.getSort(), query);
		}

		return query.select(qAccession).fetch();
	}

	/**
	 * Lists accession UUIDs restricted by {@code filter.buildPredicate()}, ordered by {@code seqNo} and then
	 * {@code id}, both ascending.
	 *
	 * @param filter the accession filter; must not be a full-text query
	 * @return UUIDs of the matching accessions
	 * @throws InvalidApiUsageException when the filter is a full-text query
	 */
	@Override
	public List<UUID> getUUIDs(AccessionFilter filter) {
		if (filter.isFulltextQuery()) {
			throw new InvalidApiUsageException("Use elasticService with full-text queries!");
		}

		Querydsl querydsl = new Querydsl(em, pathBuilder);
		QAccession qAccession = QAccession.accession;
		JPQLQuery<Object> query = querydsl.createQuery(qAccession);

		// NOTE(review): unlike count(..) and findAll(..), the filter is applied through buildPredicate()
		// and not through buildJpaQuery(query, qAccession) -- confirm this difference is intended.
		return query.select(qAccession.accessionId().uuid).where(filter.buildPredicate()).orderBy(qAccession.seqNo.asc(), qAccession.id.asc()).fetch();
	}

	/**
	 * Lists the distinct taxonomies referenced by accessions, ordered by genus, species and subtaxa.
	 *
	 * @param filter optional accession filter applied to the accession subquery; may be {@code null}
	 * @return taxonomies used by the selected accessions
	 */
	@Override
	public List<Taxonomy2> listSpecies(AccessionFilter filter) {
		JPQLQuery<Taxonomy2> subquery = JPAExpressions.select(QAccession.accession.taxonomy()).from(QAccession.accession).distinct();
		if (filter != null) {
			filter.buildJpaQuery(subquery, QAccession.accession);
		}

		return jpaQueryFactory.selectFrom(QTaxonomy2.taxonomy2).where(QTaxonomy2.taxonomy2.in(subquery))
			// order
			.orderBy(QTaxonomy2.taxonomy2.genus.asc(), QTaxonomy2.taxonomy2.species.asc(), QTaxonomy2.taxonomy2.subtaxa.asc()).fetch();
	}
}