// ElasticQueryBuilder.java

package org.genesys.server.component.elastic;

import static com.google.common.collect.Lists.*;
import static org.elasticsearch.index.query.QueryBuilders.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchPhrasePrefixQueryBuilder;
import org.elasticsearch.index.query.MatchPhraseQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.elasticsearch.annotations.Field;
import org.springframework.data.elasticsearch.annotations.FieldType;
import org.springframework.util.ReflectionUtils;

import com.google.common.collect.ImmutableList;
import com.querydsl.core.QueryMetadata;
import com.querydsl.core.types.Constant;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.FactoryExpression;
import com.querydsl.core.types.Operation;
import com.querydsl.core.types.OperationImpl;
import com.querydsl.core.types.Operator;
import com.querydsl.core.types.Ops;
import com.querydsl.core.types.ParamExpression;
import com.querydsl.core.types.Path;
import com.querydsl.core.types.PathImpl;
import com.querydsl.core.types.PathMetadata;
import com.querydsl.core.types.PathType;
import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.SubQueryExpression;
import com.querydsl.core.types.TemplateExpression;
import com.querydsl.core.types.Visitor;
import com.querydsl.core.types.dsl.NumberPath;

/**
 * Converter from a Querydsl predicate to Elasticsearch query.
 * <p>
 * The builder is a Querydsl {@link Visitor}: pass it to
 * {@code predicate.accept(builder, EntityClass.class)} and then call
 * {@link #getQuery()}. The visitor context is the class whose fields are
 * inspected (via reflection) to decide whether a path needs a nested query; it
 * may be {@code null}, in which case every path is treated as not nested.
 * <p>
 * Positive conditions are collected as {@code filter} clauses and negative ones
 * as {@code must_not} clauses of a single {@code bool} query. An instance
 * accumulates state while visiting and is not safe for concurrent use; only
 * the static nested-path cache is shared between instances.
 * 
 * @author Matija Obreza
 * @author Maxym Borodenko
 */
public class ElasticQueryBuilder implements Visitor<Void, Class<?>> {
	private static final Logger LOG = LoggerFactory.getLogger(ElasticQueryBuilder.class);

	/** Prefix of paths that are unwrapped into the parent document. */
	private static final String ACCESSION_ID_PREFIX = "accessionId.";

	/**
	 * This caches the results of {@link #isNested(Class, Path)}. The cache is
	 * static and therefore shared by all builders, so it must be a concurrent map.
	 */
	private static final Map<Class<?>, Map<Path<?>, Boolean>> NESTED_CACHE = new ConcurrentHashMap<>();

	/** Clauses that must match (added as {@code filter} clauses). */
	private final List<QueryBuilder> mustClauses = new ArrayList<>();
	/** Clauses that must not match. */
	private final List<QueryBuilder> mustNotClauses = new ArrayList<>();

	/** Range queries of this builder keyed by {@code path.toString()}, so that bounds on one path are merged. */
	private final Map<String, RangeQueryBuilder> ranges = new HashMap<>();

	/**
	 * Assemble the collected clauses into one {@code bool} query.
	 *
	 * @return a new bool query with all positive clauses as filters and all
	 *         negative clauses as must_not
	 */
	public QueryBuilder getQuery() {
		BoolQueryBuilder root = QueryBuilders.boolQuery();
		for (QueryBuilder must : mustClauses) {
			root.filter(must);
		}
		for (QueryBuilder mustNot : mustNotClauses) {
			root.mustNot(mustNot);
		}
		return root;
	}

	/**
	 * Translate a Querydsl path string to the name of the indexed field.
	 * The first path segment (the entity name) is dropped and a few hard-coded
	 * rewrites are applied for paths rooted at {@code accession}.
	 *
	 * @param path the full Querydsl path, e.g. {@code accession.datasets.uuid}
	 * @return the field name to use in the Elasticsearch query
	 */
	private String customizedPath(String path) {
		// Just remove the entity name from the path -- hopefully that's fine
		int firstDot = path.indexOf('.');
		if (firstDot < 0) {
			return path;
		}
		String root = path.substring(0, firstDot);
		String relative = path.substring(firstDot + 1);

		// FIXME We should try to build this from annotations on startup. This is very
		// hacky.
		if (!"accession".equals(root)) {
			return relative;
		}

		switch (relative) {
		// Examples of @JsonIdentityReference(alwaysAsId = true)
		case "accessionId.lists.uuid":
			return "lists";
		case "datasets.uuid":
			return "datasets";
		case "subsets.uuid":
			return "subsets";
		case "diversityTrees.uuid":
			return "diversityTrees";
		case "institute.networks.slug":
			return "institute.networks";
		default:
			break;
		}

		if (relative.startsWith(ACCESSION_ID_PREFIX)) {
			// Example of @JsonUnwrapped: strip only the leading prefix
			return relative.substring(ACCESSION_ID_PREFIX.length());
		}
		return relative;
	}

	/** Raw (not customized) dotted path of the element described by {@code pmd}. */
	private static String rawPath(PathMetadata pmd) {
		return getParentPath(pmd.getParent()) + "." + pmd.getName();
	}

	/** Index field name of the element described by {@code pmd}. */
	private String fieldName(PathMetadata pmd) {
		return customizedPath(rawPath(pmd));
	}

	/** Index field name of the parent of the element described by {@code pmd}. */
	private String parentFieldName(PathMetadata pmd) {
		return customizedPath(getParentPath(pmd.getParent()));
	}

	/** Constants are only logged; they are consumed by the enclosing operation. */
	@Override
	public Void visit(Constant<?> c, Class<?> context) {
		LOG.debug("+Constant: {}", c.getConstant());
		return null;
	}

	/** Factory expressions are only logged. */
	@Override
	public Void visit(FactoryExpression<?> expr, Class<?> context) {
		LOG.debug("+FactoryExpression: {}", expr.getArgs());
		return null;
	}

	/**
	 * Convert one Querydsl operation into query clauses of this builder.
	 *
	 * @param expr the operation
	 * @param context the class used to detect nested fields, may be {@code null}
	 */
	@Override
	public Void visit(Operation<?> expr, Class<?> context) {
		LOG.debug("+Operation context={}: {} {} {}", context, expr.getType(), expr.getOperator(), expr.getArgs());
		visitOperation(context, expr.getType(), expr.getOperator(), expr.getArgs());
		return null;
	}

	/**
	 * Dispatch on the operator and add the matching clause(s). Operators that
	 * are not handled are logged at error level and otherwise ignored.
	 *
	 * @param context the class used to detect nested fields, may be {@code null}
	 * @param type the result type of the operation (currently unused)
	 * @param operator the Querydsl operator
	 * @param args the operation arguments
	 */
	private void visitOperation(Class<?> context, Class<?> type, Operator operator, List<Expression<?>> args) {
		if (operator == Ops.AND) {
			handleAnd(context, args);
		} else if (operator == Ops.OR) {
			handleOr(context, args);
		} else if (operator == Ops.EQ || operator == Ops.IN) {
			if (args.get(0) instanceof Path) {
				LOG.debug("EQUALS: {}", args);
				printExpressions("EQUALS.. ", args);
				handleEquals(context, (Path<?>) args.get(0), args.get(1));
			} else {
				// NOTE(review): the left side is an operation on a path (presumably a
				// collection size check); it is translated to "field does not exist".
				mustNotClauses.add(existsQuery(fieldName(operandPath(args.get(0)).getMetadata())));
			}
		} else if (operator == Ops.LOE || operator == Ops.GOE || operator == Ops.BETWEEN || operator == Ops.LT || operator == Ops.GT) {
			if (args.get(0) instanceof Path) {
				LOG.debug("Range: {}", args);
				printExpressions("LOE.. ", args);
				handleRange(operator, (Path<?>) args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
			} else {
				// NOTE(review): same as above, translated to "field exists".
				mustClauses.add(existsQuery(fieldName(operandPath(args.get(0)).getMetadata())));
			}
		} else if (operator == Ops.STRING_CONTAINS || operator == Ops.STARTS_WITH) {
			logOperation(operator, args);
			handleLike(operator, (Path<?>) args.get(0), args.get(1));
		} else if (operator == Ops.NOT) {
			logOperation(operator, args);
			handleNot(context, args.get(0));
		} else if (operator == Ops.IS_NOT_NULL) {
			logOperation(operator, args);
			mustClauses.add(existsClause(context, (Path<?>) args.get(0), false));
		} else if (operator == Ops.IS_NULL) {
			logOperation(operator, args);
			mustNotClauses.add(existsClause(context, (Path<?>) args.get(0), false));
		} else if (operator == Ops.COL_IS_EMPTY) {
			logOperation(operator, args);
			Path<?> path = (Path<?>) args.get(0);
			PathMetadata pmd = path.getMetadata();

			LOG.debug("Path in context={} {} {} for {} {} NULL/EMPTY", context, pmd.getPathType(), pmd.getName(), pmd.getParent(), pmd.getParent().getMetadata().getPathType());

			// For an empty-collection test the collection itself is the nested path
			mustNotClauses.add(existsClause(context, path, true));
		} else {
			LOG.error("Op {}: {}", operator, args);
		}
	}

	/**
	 * Extract the path that is the first argument of an operation.
	 * Uses the {@link Operation} interface so that any Querydsl operation
	 * implementation is accepted.
	 *
	 * @param operation an expression that must be an {@link Operation} whose first argument is a {@link Path}
	 * @return the path the operation is applied to
	 * @throws ClassCastException if the expression does not have that shape
	 */
	private static Path<?> operandPath(Expression<?> operation) {
		return (Path<?>) ((Operation<?>) operation).getArg(0);
	}

	/**
	 * Build an {@code exists} query for the path, wrapped in a nested query when
	 * the path is nested.
	 *
	 * @param context the class used to detect nested fields, may be {@code null}
	 * @param path the path to test
	 * @param nestedAtField when {@code true} the field itself is used as the
	 *        nested path, otherwise its parent
	 * @return the exists clause
	 */
	private QueryBuilder existsClause(Class<?> context, Path<?> path, boolean nestedAtField) {
		PathMetadata pmd = path.getMetadata();
		if (!isNested(context, path)) {
			return existsQuery(fieldName(pmd));
		}
		String nestedPath = nestedAtField ? fieldName(pmd) : parentFieldName(pmd);
		return nestedQuery(nestedPath, existsQuery(fieldName(pmd)), ScoreMode.Avg);
	}

	/** Log the operator and each of its arguments at debug level. */
	private static void logOperation(Operator operator, List<Expression<?>> args) {
		LOG.debug("{}: {}", operator, args);
		printExpressions(operator + ".. ", args);
	}

	/** Log every expression of the list with the same prefix. */
	private static void printExpressions(String prefix, List<Expression<?>> args) {
		for (Expression<?> expr : args) {
			printExpression(prefix, expr);
		}
	}

	/**
	 * Figure out if the path is pointing to a nested entity or not.
	 * Results are cached per context class and path.
	 *
	 * @param context the context
	 * @param path the path
	 * @return true if nested
	 */
	private boolean isNested(Class<?> context, Path<?> path) {
		if (context == null) {
			LOG.info("Context not provided, assuming not-nested");
			return false;
		}

		return NESTED_CACHE
			.computeIfAbsent(context, (k) -> new ConcurrentHashMap<>())
			.computeIfAbsent(path, (p) -> resolveNested(context, p));
	}

	/**
	 * Walk the fields named by the path, starting at {@code context}, and look at
	 * their {@link Field} annotations.
	 *
	 * @param context the class of the root of the path
	 * @param path the path
	 * @return {@code true} when a field of type {@link FieldType#Nested} is found
	 *         before a {@link FieldType#Keyword} field or a missing field
	 */
	private static boolean resolveNested(Class<?> context, Path<?> path) {
		var pathParts = rawPath(path.getMetadata()).split("\\.");
		LOG.debug("Checking if {}/{} is nested or not", context.getSimpleName(), pathParts);

		// Convention: pathParts[0] equals Clazz.simpleName.toLowerCase(), so it is skipped
		Class<?> prop = context; // Root
		for (var i = 1; i < pathParts.length; i++) {
			LOG.debug("  Checking {}.{}", prop, pathParts[i]);
			var propField = ReflectionUtils.findField(prop, pathParts[i]);
			if (propField == null) {
				LOG.warn("No field {} in {}", pathParts[i], prop);
				return false;
			}

			LOG.debug("   Found {}.{} = {}", prop, pathParts[i], propField);

			var fieldAnnotation = propField.getAnnotation(Field.class);
			if (fieldAnnotation != null) {
				if (fieldAnnotation.type() == FieldType.Nested) {
					return true;
				}
				if (fieldAnnotation.type() == FieldType.Keyword) {
					return false;
				}
			}

			// Object fields and fields without a decisive annotation: go deeper
			prop = propField.getType();
		}
		return false;
	}

	/**
	 * @return the number of clauses (positive and negative) collected so far
	 */
	public int size() {
		return this.mustClauses.size() + this.mustNotClauses.size();
	}

	/**
	 * Handle a conjunction: all arguments are visited with one sub-builder. A
	 * single resulting clause is moved into this builder as is, anything else is
	 * added as one nested bool query.
	 */
	private void handleAnd(Class<?> context, List<Expression<?>> args) {
		LOG.debug("AND expr: {}", args);
		ElasticQueryBuilder andBuilder = new ElasticQueryBuilder();
		for (Expression<?> a : args) {
			a.accept(andBuilder, context);
		}
		if (andBuilder.size() == 1) {
			// Exactly one of the two lists holds the clause
			mustClauses.addAll(andBuilder.mustClauses);
			mustNotClauses.addAll(andBuilder.mustNotClauses);
		} else {
			mustClauses.add(andBuilder.getQuery());
		}
	}

	/**
	 * Handle a disjunction: every argument becomes one {@code should} clause of a
	 * bool query that requires at least one match.
	 * <p>
	 * Each argument is visited with its own sub-builder so that alternatives do
	 * not share state (e.g. merged range bounds), and negative alternatives are
	 * kept as {@code bool.must_not} should-clauses instead of being dropped.
	 * Arguments that produce no clause are skipped.
	 */
	private void handleOr(Class<?> context, List<Expression<?>> args) {
		LOG.debug("OR expr: {}", args);

		BoolQueryBuilder orQuery = boolQuery();
		orQuery.minimumShouldMatch(1);
		for (Expression<?> a : args) {
			ElasticQueryBuilder alternative = new ElasticQueryBuilder();
			a.accept(alternative, context);

			if (alternative.size() == 0) {
				continue;
			}
			if (alternative.size() == 1 && !alternative.mustClauses.isEmpty()) {
				orQuery.should(alternative.mustClauses.get(0));
			} else {
				// Negative or multi-clause alternative: wrap in its own bool query
				orQuery.should(alternative.getQuery());
			}
		}
		mustClauses.add(orQuery);
	}

	/**
	 * Handle a negation: the expression is visited with a sub-builder and its
	 * clauses are added to this builder with inverted polarity.
	 */
	private void handleNot(Class<?> context, Expression<?> notExp) {
		LOG.debug("NOT expr: {}", notExp);
		ElasticQueryBuilder notBuilder = new ElasticQueryBuilder();
		notExp.accept(notBuilder, context);

		if (notBuilder.size() > 1) {
			// Several clauses form a conjunction; negate it as a whole
			mustNotClauses.add(notBuilder.getQuery());
			return;
		}
		mustNotClauses.addAll(notBuilder.mustClauses);
		mustClauses.addAll(notBuilder.mustNotClauses);
	}

	/**
	 * Handle {@code STARTS_WITH} (match-phrase-prefix) and
	 * {@code STRING_CONTAINS} (match-phrase).
	 *
	 * @throws RuntimeException for any other operator or a non-constant value
	 */
	private void handleLike(Operator operator, Path<?> path, Expression<?> val) {
		PathMetadata pmd = path.getMetadata();
		if (operator == Ops.STARTS_WITH) {
			MatchPhrasePrefixQueryBuilder prefixQuery = matchPhrasePrefixQuery(fieldName(pmd), toValue(val));
			mustClauses.add(prefixQuery);
		} else if (operator == Ops.STRING_CONTAINS) {
			MatchPhraseQueryBuilder phraseQuery = matchPhraseQuery(fieldName(pmd), toValue(val));
			mustClauses.add(phraseQuery);
		} else {
			throw new RuntimeException("Unsupported ES handleLike operator: " + operator);
		}
	}

	/**
	 * Handle a comparison. Bounds for the same path are merged into one range
	 * query, which is added to the positive clauses only once.
	 *
	 * @param operator one of LOE, LT, GOE, GT or BETWEEN
	 * @param path the compared path
	 * @param val1 the bound (lower bound for BETWEEN)
	 * @param val2 the upper bound for BETWEEN, otherwise unused
	 * @throws RuntimeException if a required bound is not a constant
	 */
	private void handleRange(Operator operator, Path<?> path, Expression<?> val1, Expression<?> val2) {
		PathMetadata pmd = path.getMetadata();
		String key = path.toString();

		RangeQueryBuilder rq = ranges.get(key);
		if (rq == null) {
			rq = rangeQuery(fieldName(pmd));
			ranges.put(key, rq);
			// Register the clause once; later bounds update the same instance
			mustClauses.add(rq);
		}

		if (operator == Ops.LOE) {
			rq.lte(toValue(val1));
		} else if (operator == Ops.LT) {
			rq.lt(toValue(val1));
		} else if (operator == Ops.GOE) {
			rq.gte(toValue(val1));
		} else if (operator == Ops.GT) {
			rq.gt(toValue(val1));
		} else if (operator == Ops.BETWEEN) {
			rq.gte(toValue(val1));
			rq.lte(toValue(val2));
		}
	}

	/**
	 * Handle EQ and IN as a {@code terms} query. For an {@code any()} path the
	 * terms are matched against the parent collection field; nested paths are
	 * wrapped in a nested query on the parent.
	 */
	private void handleEquals(Class<?> context, Path<?> path, Expression<?> value) {
		PathMetadata pmd = path.getMetadata();
		if (pmd.getPathType() == PathType.COLLECTION_ANY) {
			LOG.debug("Path ANY for {}={}", pmd.getParent(), value);
			mustClauses.add(termsQuery(parentFieldName(pmd), toValues(value)));
			return;
		}

		LOG.debug("Path for {} {}={}", pmd.getParent(), pmd.getParent().getMetadata().getPathType(), value);
		boolean nested = isNested(context, path);
		QueryBuilder terms = termsQuery(fieldName(pmd), toValues(value));
		mustClauses.add(nested ? nestedQuery(parentFieldName(pmd), terms, ScoreMode.Avg) : terms);
	}

	/** Get a clean path (without "any(parent)") */
	private static String getParentPath(Path<?> path) {
		String pathValue = path.toString();

		if (pathValue.startsWith("any")) {
			return getParentPath(path.getMetadata().getParent());
		}

		return pathValue;
	}

	/**
	 * Extract the value of a constant expression.
	 *
	 * @throws RuntimeException if the expression is not a {@link Constant}
	 */
	private static Object toValue(Expression<?> value) {
		if (value instanceof Constant<?>) {
			return convertValue(((Constant<?>) value).getConstant());
		}

		throw new RuntimeException("Unhandled value " + value);
	}

	/** Convert enums and UUIDs to their string form; other values are returned unchanged. */
	private static Object convertValue(Object obj) {
		if (obj == null) {
			return null;
		}
		Class<? extends Object> objClass = obj.getClass();
		LOG.debug("toValue of {}: c={}", obj, objClass);
		if (objClass.isEnum() || UUID.class.isAssignableFrom(objClass)) {
			return obj.toString();
		}
		return obj;
	}

	/**
	 * Extract the value(s) of a constant expression as a collection.
	 *
	 * @return the converted elements of a collection constant, or a one-element
	 *         list for any other constant
	 * @throws RuntimeException if the expression is not a {@link Constant}
	 */
	private static Collection<?> toValues(Expression<?> value) {
		if (value instanceof Constant<?>) {
			Object obj = ((Constant<?>) value).getConstant();
			LOG.debug("toValues of {}", obj);

			if (obj instanceof Collection<?>) {
				Collection<?> c = (Collection<?>) obj;
				return c.stream().map(ElasticQueryBuilder::convertValue).collect(Collectors.toList());
			} else {
				return newArrayList(convertValue(obj));
			}
		}

		throw new RuntimeException("Unhandled value " + value);
	}

	/** Log a description of the expression at debug level. */
	private static void printExpression(String prefix, Expression<?> expr) {
		if (expr instanceof NumberPath<?>) {
			NumberPath<?> path = (NumberPath<?>) expr;
			LOG.debug("{}: NumberPath {} {}", prefix, path.getRoot(), path.getType());

		} else if (expr instanceof Path<?>) {
			// Use the Path interface: not every Path implementation extends PathImpl
			Path<?> path = (Path<?>) expr;
			PathMetadata pmd = path.getMetadata();
			Path<?> parent = pmd.getParent();
			if (pmd.getPathType() == PathType.COLLECTION_ANY) {
				LOG.debug("{}: {} {} parent={}", prefix, pmd.getPathType(), pmd.getElement(), parent);
			} else {
				// A root path has no parent
				LOG.debug("{}: {} {}/{} parent={} {}", prefix, pmd.getPathType(), pmd.getName(), pmd.getElement(), parent, parent == null ? null : parent.getMetadata().getPathType());
			}

		} else if (expr instanceof Constant<?>) {
			Constant<?> cons = (Constant<?>) expr;
			LOG.debug("{}: Constant {} {}", prefix, cons.getConstant(), cons.getType());

		} else if (expr instanceof Predicate) {
			Predicate pred = (Predicate) expr;
			LOG.debug("{}: should visit Predicate {}", prefix, pred);

		} else {
			LOG.debug("{}: {} {}", prefix, expr.getClass(), expr.getType());
		}
	}

	/** Parameters are only logged. */
	@Override
	public Void visit(ParamExpression<?> param, Class<?> context) {
		LOG.debug("+ParamExpression: {} {} {}", param.getType(), param.isAnon(), param.getName());
		return null;
	}

	/** Paths are only logged; they are consumed by the enclosing operation. */
	@Override
	public Void visit(Path<?> path, Class<?> context) {
		final PathType pathType = path.getMetadata().getPathType();
		final Object element = path.getMetadata().getElement();
		List<Object> args;
		if (path.getMetadata().getParent() != null) {
			args = ImmutableList.of(path.getMetadata().getParent(), element);
		} else {
			args = ImmutableList.of(element);
		}
		LOG.debug("+Path: {} {} {}", pathType, pathType.name(), args);
		return null;
	}

	/** Sub-queries are only logged. */
	@Override
	public Void visit(SubQueryExpression<?> query, Class<?> context) {
		QueryMetadata qm = query.getMetadata();
		LOG.debug("+SubQueryExpression: {}", qm);
		return null;
	}

	/** Template expressions are only logged. */
	@Override
	public Void visit(TemplateExpression<?> expr, Class<?> context) {
		LOG.debug("+TemplateExpr: {} {}", expr.getTemplate(), expr.getArgs());
		return null;
	}

}