AuditTrailInterceptor.java

/*
 * Copyright 2018 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.blocks.auditlog.component;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.text.Format;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Temporal;
import javax.persistence.TemporalType;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.FastDateFormat;
import org.genesys.blocks.auditlog.annotations.Audited;
import org.genesys.blocks.auditlog.annotations.HideAuditValue;
import org.genesys.blocks.auditlog.annotations.NotAudited;
import org.genesys.blocks.auditlog.model.AuditAction;
import org.genesys.blocks.auditlog.model.AuditLog;
import org.genesys.blocks.auditlog.model.TransactionAuditLog;
import org.genesys.blocks.auditlog.service.AuditTrailService;
import org.genesys.blocks.model.BasicModel;
import org.genesys.blocks.model.EntityId;
import org.hibernate.CallbackException;
import org.hibernate.EmptyInterceptor;
import org.hibernate.Transaction;
import org.hibernate.collection.spi.PersistentCollection;
import org.hibernate.resource.transaction.spi.TransactionStatus;
import org.hibernate.type.Type;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.ResolvableType;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;

import lombok.extern.slf4j.Slf4j;

/**
 * Record changed data using {@link AuditLog} entries.
 *
 * @author Matija Obreza
 */
@Component
@Slf4j
public class AuditTrailInterceptor extends EmptyInterceptor implements InitializingBean {

	/** Serialization version identifier. */
	private static final long serialVersionUID = 1881637304461659508L;

	/** Property names that are skipped on every entity unless replaced through {@link #setIgnoredProperties(Set)}. */
	private static final Set<String> DEFAULT_IGNORED_PROPERTIES = Stream.of("serialVersionUID", "id", "createdDate", "lastModifiedDate", "version", "lastModifiedBy")
		.collect(Collectors.toSet());

	/** Property names ignored on all entities; starts as a copy of {@link #DEFAULT_IGNORED_PROPERTIES}. */
	private Set<String> ignoredProperties = new HashSet<>(DEFAULT_IGNORED_PROPERTIES);

	/** Per entity class: names of fields annotated with {@link NotAudited} (filled by {@link #isAudited(Class)}). */
	private final Map<Class<?>, Set<String>> ignoredClassFields;

	/** Per entity class: names of fields annotated with {@link HideAuditValue} whose values are masked in the log. */
	private final Map<Class<?>, Set<String>> securedClassFields;

	/** Classes (and their subtypes) that are audited even without the {@link Audited} annotation. */
	private Set<Class<?>> auditedClasses = new HashSet<>();

	/**
	 * Decision caches used by {@link #isAudited(Class)}: classes already found to be
	 * excluded from, respectively included in, auditing.
	 */
	private final Set<Class<?>> ignoredClasses, includedClasses;

	/** Service that builds and stores the audit log entries. */
	@Autowired
	private transient AuditTrailService auditTrailService;

	/** The entity manager. */
	@PersistenceContext
	private transient EntityManager entityManager;

	/** Pattern used for {@link TemporalType#DATE} values. */
	private final static String dateFormat = "dd-MMM-yyyy";

	/** Pattern used for {@link TemporalType#TIME} values. */
	private final static String timeFormat = "HH:mm:ss";

	/** Pattern used for {@link TemporalType#TIMESTAMP} values. */
	private final static String dateTimeFormat = "dd-MMM-yyyy HH:mm:ss";

	/** Formatter for {@link #dateFormat}. */
	private final static Format dateFormatter = FastDateFormat.getInstance(dateFormat);

	/** Formatter for {@link #dateTimeFormat}. */
	private final static Format dateTimeFormatter = FastDateFormat.getInstance(dateTimeFormat);

	/** Formatter for {@link #timeFormat}. */
	private final static Format timeFormatter = FastDateFormat.getInstance(timeFormat);

	/**
	 * Per-thread stack of pending audit logs. A fresh set is pushed in
	 * {@link #afterTransactionBegin(Transaction)} and popped (then handed to the
	 * audit trail service) in {@link #beforeTransactionCompletion(Transaction)}.
	 */
	private static final ThreadLocal<Stack<Set<TransactionAuditLog>>> auditLogStack = ThreadLocal.withInitial(Stack::new);

	/**
	 * Creates the interceptor and initializes its internal caches. All four caches
	 * are wrapped in synchronized collections.
	 */
	public AuditTrailInterceptor() {
		log.info("Enabling {}", getClass().getName());

		// Field-level caches keyed by entity class
		ignoredClassFields = Collections.synchronizedMap(new HashMap<>());
		securedClassFields = Collections.synchronizedMap(new HashMap<>());

		// Class-level decision caches
		ignoredClasses = Collections.synchronizedSet(new HashSet<>());
		includedClasses = Collections.synchronizedSet(new HashSet<>());
	}

	/**
	 * Validates the configuration and freezes the configurable sets.
	 *
	 * <p>
	 * Explicit checks are used instead of {@code assert} because assertions are
	 * disabled by default at runtime, which would let a misconfigured interceptor
	 * start and fail only later, when the first audited change is flushed.
	 * </p>
	 *
	 * @throws IllegalStateException when a required property is missing
	 * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		if (ignoredProperties == null) {
			throw new IllegalStateException("ignoredProperties must not be null");
		}
		if (auditedClasses == null) {
			throw new IllegalStateException("auditedClasses must not be null");
		}
		if (auditTrailService == null) {
			throw new IllegalStateException("auditTrailService must be set");
		}

		// Make them unmodifiable
		ignoredProperties = Collections.unmodifiableSet(ignoredProperties);
		auditedClasses = Collections.unmodifiableSet(auditedClasses);
	}

	/**
	 * Configures the entity classes that are audited regardless of annotations.
	 * Classes carrying the {@link Audited} annotation are audited even when they
	 * are missing from this set.
	 *
	 * @param auditedClasses entity classes to audit
	 * @see Audited
	 */
	public void setAuditedClasses(final Set<Class<?>> auditedClasses) {
		this.auditedClasses = auditedClasses;
	}

	/**
	 * Returns the explicitly configured audited classes.
	 *
	 * @return the configured set of audited entity classes
	 */
	public Set<Class<?>> getAuditedClasses() {
		return this.auditedClasses;
	}

	/**
	 * Configures the property names that are never audited, on any entity (e.g.
	 * "password"). When not called, {@link #DEFAULT_IGNORED_PROPERTIES} applies.
	 * Individual fields can also be excluded with the <code>@NotAudited</code>
	 * annotation (see {@link NotAudited}).
	 *
	 * @param ignoredProperties entity property names to exclude from audit trail
	 * @see NotAudited
	 */
	public void setIgnoredProperties(final Set<String> ignoredProperties) {
		this.ignoredProperties = ignoredProperties;
	}

	/**
	 * Returns the property names excluded from auditing on all entities.
	 *
	 * @return the globally ignored property names
	 */
	public Set<String> getIgnoredProperties() {
		return this.ignoredProperties;
	}

	/**
	 * Injects the service used to create and store audit log entries.
	 *
	 * @param auditTrailService the audit trail service to use
	 */
	public void setAuditTrailService(final AuditTrailService auditTrailService) {
		this.auditTrailService = auditTrailService;
	}

	/**
	 * Returns the service used to create and store audit log entries.
	 *
	 * @return the audit trail service
	 */
	public AuditTrailService getAuditTrailService() {
		return this.auditTrailService;
	}

	/**
	 * Compares the previous and current state of a dirty entity and queues an
	 * UPDATE audit entry for every changed property that is either a "primitive"
	 * (see {@link #isPrimitiveType(Class)}) or a reference to an {@link EntityId}.
	 * Other property types are only traced.
	 *
	 * <p>
	 * Hibernate may pass a {@code null} previous state (e.g. for re-attached
	 * detached entities). In that case the changes cannot be determined and no
	 * audit entries are produced.
	 * </p>
	 *
	 * @param entity the entity being flushed
	 * @param id the entity identifier; must be a {@link Long}
	 * @param currentState current property values
	 * @param previousState property values as loaded, possibly {@code null}
	 * @param propertyNames property names, index-aligned with the state arrays
	 * @param types Hibernate types, index-aligned with the state arrays
	 * @return always {@code false}: this method never modifies {@code currentState}
	 * @see org.hibernate.EmptyInterceptor#onFlushDirty(java.lang.Object,
	 *      java.io.Serializable, java.lang.Object[], java.lang.Object[],
	 *      java.lang.String[], org.hibernate.type.Type[])
	 */
	/* We add more stuff to the transaction if that fails we're still good! */
	@Override
	public boolean onFlushDirty(final Object entity, final Serializable id, final Object[] currentState, final Object[] previousState, final String[] propertyNames,
			final Type[] types) {
		final Class<?> entityClass = entity.getClass();
		log.trace("Inspecting Entity.class={} id={}", entityClass, id);

		if (!isAudited(entityClass)) {
			return false;
		}

		if (previousState == null) {
			// Without the loaded state there is nothing to diff against
			log.debug("No previous state for {} id={}, changes cannot be audited", entityClass, id);
			return false;
		}

		final Set<String> entityIgnoredFields = ignoredClassFields.get(entityClass);

		// Identify changed values
		for (int i = 0; i < previousState.length; i++) {
			final String propertyName = propertyNames[i];
			final Object prev = previousState[i];
			final Object curr = currentState[i];

			if (ignoredProperties.contains(propertyName) || (entityIgnoredFields != null && entityIgnoredFields.contains(propertyName))) {
				log.trace("{} property in {} is not audited.", propertyName, entityClass.getSimpleName());
				continue;
			}

			final boolean changed = prev == null ? curr != null : !prev.equals(curr);
			if (!changed) {
				continue;
			}

			final Class<?> returnedClass = types[i].getReturnedClass();
			log.trace("prop={} prev={} curr={} type={}", propertyName, prev, curr, returnedClass);

			if (isPrimitiveType(returnedClass)) {
				final String currentValue = formatValue(curr, types[i], entityClass, propertyName);
				final String previousValue = formatValue(prev, types[i], entityClass, propertyName);
				// Notice cast to Long here!
				recordChange(entity, (Long) id, propertyName, previousValue, currentValue, null, prev, curr);

			} else if (isEntity(returnedClass)) {
				final EntityId prevEntity = (EntityId) prev, currEntity = (EntityId) curr;
				final String previousValue = prevEntity == null ? null : prevEntity.getId().toString();
				final String currentValue = currEntity == null ? null : currEntity.getId().toString();

				// equals() may differ although both references point to the same id
				if (!StringUtils.equals(previousValue, currentValue)) {
					// Notice cast to Long here!
					recordChange(entity, (Long) id, propertyName, previousValue, currentValue, returnedClass, prev, curr);
				}
			} else {
				log.trace("Entity.{} {} is not a primitive. Ignoring value={}", propertyName, prev == null ? null : prev.getClass(), prev);
				// TODO Capture in audit log
			}
		}
		return false;
	}

	/**
	 * Converts a property value to the string stored in the audit log.
	 *
	 * <ul>
	 * <li>{@code null} stays {@code null}.</li>
	 * <li>Fields registered as secured for the entity class are replaced with
	 * {@link AuditLog#FIELD_VALUE_NOT_AUDITED}.</li>
	 * <li>{@link Instant} and {@link LocalDate} use ISO formats.</li>
	 * <li>{@link Date} and {@link Calendar} are formatted according to the
	 * {@link Temporal} annotation of the field (TIMESTAMP when not annotated).</li>
	 * <li>Everything else uses {@code toString()}.</li>
	 * </ul>
	 *
	 * @param someValue the value to format, may be {@code null}
	 * @param type the Hibernate type of the property
	 * @param entityClass the entity class owning the property
	 * @param propertyName the property name
	 * @return the formatted value or {@code null}
	 */
	private String formatValue(final Object someValue, final Type type, final Class<?> entityClass, final String propertyName) {
		if (someValue == null) {
			return null;
		}

		// Check if field should be masked
		final Set<String> securedFields = securedClassFields.get(entityClass);
		if (securedFields != null && securedFields.contains(propertyName)) {
			return AuditLog.FIELD_VALUE_NOT_AUDITED;
		}

		final Class<?> returnedClass = type.getReturnedClass();

		if (Instant.class.equals(returnedClass)) {
			return DateTimeFormatter.ISO_DATE_TIME.withZone(ZoneOffset.systemDefault()).format((Instant)someValue);
		} else if (LocalDate.class.equals(returnedClass)) {
			return DateTimeFormatter.ISO_DATE.format((LocalDate)someValue);
		} else if (Date.class.equals(returnedClass) || Calendar.class.equals(returnedClass)) {
			switch (findTemporalType(entityClass, propertyName)) {
				case TIMESTAMP:
					return dateTimeFormatter.format(someValue);
				case DATE:
					return dateFormatter.format(someValue);
				case TIME:
					return timeFormatter.format(someValue);
			}
		}

		return someValue.toString();
	}

	/**
	 * Looks up the {@link Temporal} annotation of a field, searching the entity
	 * class and its superclasses so that fields declared on a mapped superclass are
	 * found too.
	 *
	 * @param entityClass the class where the search starts
	 * @param propertyName the field name
	 * @return the annotated temporal type, or {@link TemporalType#TIMESTAMP} when
	 *         the field is missing, inaccessible or not annotated
	 */
	private TemporalType findTemporalType(final Class<?> entityClass, final String propertyName) {
		for (Class<?> current = entityClass; current != null; current = current.getSuperclass()) {
			try {
				final Field field = current.getDeclaredField(propertyName);
				final Temporal ta = field.getAnnotation(Temporal.class);
				return ta == null ? TemporalType.TIMESTAMP : ta.value();
			} catch (NoSuchFieldException ignored) {
				// Not declared at this level, continue with the superclass
			} catch (SecurityException e) {
				log.trace("Could not access field {}#{}", entityClass, propertyName);
				return TemporalType.TIMESTAMP;
			}
		}
		log.trace("Could not access field {}#{}", entityClass, propertyName);
		return TemporalType.TIMESTAMP;
	}

	/**
	 * Tells whether the given type implements {@link EntityId}. A trace message is
	 * written when it does not.
	 *
	 * @param clazz the type to test
	 * @return {@code true} when {@code clazz} is assignable to {@link EntityId}
	 */
	private boolean isEntity(final Class<?> clazz) {
		final boolean entityType = EntityId.class.isAssignableFrom(clazz);
		if (!entityType) {
			log.trace("{} is not an EntityId", clazz.getName());
		}
		return entityType;
	}

	/**
	 * Queues a DELETE audit entry for every non-null, non-ignored property of an
	 * audited entity that is being removed. "Primitive" values are logged using
	 * their {@code toString()}, {@link EntityId} references using their id; other
	 * types are only traced.
	 *
	 * @param entity the entity being deleted
	 * @param id the entity identifier; must be a {@link Long}
	 * @param states property values at deletion time
	 * @param propertyNames property names, index-aligned with {@code states}
	 * @param types Hibernate types, index-aligned with {@code states}
	 * @see org.hibernate.EmptyInterceptor#onDelete(java.lang.Object,
	 *      java.io.Serializable, java.lang.Object[], java.lang.String[],
	 *      org.hibernate.type.Type[])
	 */
	@Override
	public void onDelete(final Object entity, final Serializable id, final Object[] states, final String[] propertyNames, final Type[] types) {
		final Class<?> entityClass = entity.getClass();
		log.trace("Inspecting Entity.class={} id={}", entityClass, id);

		if (!isAudited(entityClass)) {
			log.trace("{} is not audited", entityClass);
			return;
		}

		final Set<String> classIgnored = ignoredClassFields.get(entityClass);

		for (int idx = 0; idx < states.length; idx++) {
			final String name = propertyNames[idx];
			final Object value = states[idx];

			final boolean skipped = ignoredProperties.contains(name) || (classIgnored != null && classIgnored.contains(name));
			if (skipped || value == null) {
				continue;
			}

			log.trace("Deleted prop={} state={} type={}", name, value, types[idx].getReturnedClass());

			if (isPrimitiveType(types[idx].getReturnedClass())) {
				// The identifier is expected to be a Long
				recordDelete(entity, (Long) id, name, value.toString(), null, null);
				continue;
			}

			if (isEntity(types[idx].getReturnedClass())) {
				final String referencedId = ((EntityId) value).getId().toString();
				// The identifier is expected to be a Long
				recordDelete(entity, (Long) id, name, referencedId, types[idx].getReturnedClass(), value);
				continue;
			}

			// TODO Capture in audit log (e.g. PersistentBag)
			log.trace("Entity.{} {} is not a primitive. Ignoring value={}", name, value.getClass(), value);
		}
	}

	/**
	 * Traces the recreation of a collection; no audit entry is produced.
	 *
	 * @param collection the recreated collection
	 * @param key the identifier of the owning entity
	 * @see org.hibernate.EmptyInterceptor#onCollectionRecreate(java.lang.Object,
	 *      java.io.Serializable)
	 */
	@Override
	public void onCollectionRecreate(final Object collection, final Serializable key) throws CallbackException {
		log.trace("Collection recreated: key={} coll={}", key, collection);
	}

	/**
	 * Queues a DELETE audit entry listing the elements of a collection that is
	 * being removed from an audited owner. Elements that implement
	 * {@link EntityId} are replaced by their ids. Nothing is recorded when the
	 * collection has no value, is not a {@link Collection} or is empty.
	 *
	 * @param collection the removed collection, a {@link PersistentCollection}
	 * @param key the identifier of the owning entity; must be a {@link Long}
	 * @see org.hibernate.EmptyInterceptor#onCollectionRemove(java.lang.Object,
	 *      java.io.Serializable)
	 */
	@Override
	public void onCollectionRemove(final Object collection, final Serializable key) throws CallbackException {
		final PersistentCollection pc = (PersistentCollection) collection;
		if (!isAudited(pc.getOwner().getClass())) {
			log.trace("Class {} is not audited", pc.getOwner().getClass());
			return;
		}

		final Class<? extends Object> ownerClass = pc.getOwner().getClass();
		final String propertyName = pc.getRole().substring(pc.getRole().lastIndexOf('.') + 1);
		final Class<?> propertyType = findPropertyType(ownerClass, propertyName);
		log.trace("Property class: {}.{}={}", ownerClass.getName(), propertyName, propertyType);

		Collection<Object> deleted = new HashSet<>();

		if (pc.getValue() == null) {
			log.trace("onCollectionRemove is empty, no change key={}", key);
			return;
		}

		if (pc.getValue() instanceof Collection<?>) {
			deleted.addAll((Collection<?>) pc.getValue());
		}

		if (deleted.isEmpty()) {
			log.trace("onCollectionRemove is empty, no change key={}", key);
			return;
		}

		Serializable snap = pc.getStoredSnapshot();
		log.trace("Collection remove: key={} coll={} snap={}", key, collection, snap);

		// If elements are EntityId, convert to ID only. The element type may be
		// unknown (null) when neither field nor getter could be resolved.
		Class<?> referencedEntity = null;
		if (propertyType != null && EntityId.class.isAssignableFrom(propertyType)) {
			log.trace("{} is EntityId, converting values.", propertyType.getName());
			referencedEntity = propertyType;
			deleted = convertEntityId(deleted);
		}

		log.trace("prev={} curr=null", deleted);
		recordDelete(pc.getOwner(), (Long) key, propertyName, deleted.toString(), referencedEntity, deleted);
	}

	/**
	 * Queues an UPDATE audit entry for a modified collection of an audited owner.
	 * The previous content is taken from the stored snapshot (map keys for map
	 * snapshots), the current content from the collection value. Elements that
	 * implement {@link EntityId} are replaced by their ids.
	 *
	 * @param collection the updated collection, a {@link PersistentCollection}
	 * @param key the identifier of the owning entity; must be a {@link Long}
	 * @see org.hibernate.EmptyInterceptor#onCollectionUpdate(java.lang.Object,
	 *      java.io.Serializable)
	 */
	@Override
	public void onCollectionUpdate(final Object collection, final Serializable key) throws CallbackException {
		final PersistentCollection pc = (PersistentCollection) collection;
		if (!isAudited(pc.getOwner().getClass())) {
			log.trace("Class {} is not audited", pc.getOwner().getClass());
			return;
		}
		log.trace("Collection update: key={} coll={}", key, collection);

		final Class<? extends Object> ownerClass = pc.getOwner().getClass();
		log.trace("ownerClass={} role={}", ownerClass.getName(), pc.getRole());
		final String propertyName = pc.getRole().substring(pc.getRole().lastIndexOf('.') + 1);
		final Class<?> propertyType = findPropertyType(ownerClass, propertyName);
		log.trace("Property class: {}.{}={}", ownerClass.getName(), propertyName, propertyType);

		Collection<?> remaining = null;
		final Object value = pc.getValue();
		if (value instanceof Collection<?>) {
			remaining = (Collection<?>) value;
		} else if (value != null) {
			log.trace("Can't handle pc={} val={}", value.getClass(), pc);
		}

		Collection<?> previous = null;
		final Serializable snap = pc.getStoredSnapshot();
		if (snap instanceof Map<?, ?>) {
			final Map<?, ?> snapMap = (Map<?, ?>) snap;
			previous = snapMap.keySet();
		} else if (snap instanceof Collection<?>) {
			final Collection<?> snapList = (Collection<?>) snap;
			previous = snapList;
		} else if (snap != null) {
			log.trace("Can't handle snap={} val={}", snap.getClass(), snap);
		}

		// If elements are EntityId, convert to ID only. The element type may be
		// unknown (null) when neither field nor getter could be resolved.
		Class<?> referencedEntity = null;
		if (propertyType != null && EntityId.class.isAssignableFrom(propertyType)) {
			log.trace("{} is EntityId, converting values.", propertyType.getName());
			referencedEntity = propertyType;
			previous = convertEntityId(previous);
			remaining = convertEntityId(remaining);
		}

		log.trace("prev={} curr={}", previous, remaining);
		try {
			recordChange(pc.getOwner(), (Long) key, propertyName, collectionToStringSorted(previous), collectionToStringSorted(remaining), referencedEntity, previous, remaining);
		} catch (ClassCastException e) {
			// Raised e.g. when set elements cannot be sorted; the change is not recorded
			if (previous != null) log.error("Previous {}: {}", previous.getClass(), previous);
			if (remaining != null) log.error("Remaining {}: {}", remaining.getClass(), remaining);
			log.error("Could not serialize property {}#{}: {}", pc.getOwner().getClass(), propertyName, e.getMessage());
		}
	}


	/**
	 * Produce a string representation of the collection. Set elements are sorted
	 * in their natural order, other collection types return `#toString`.
	 *
	 * @param collection the collection, may be {@code null}
	 * @return the string, or {@code null} for a {@code null} or empty collection
	 * @throws ClassCastException when the elements of a set are not mutually
	 *         comparable
	 */
	private String collectionToStringSorted(Collection<?> collection) {
		if (collection == null || collection.isEmpty()) {
			return null;
		}
		if (collection instanceof Set<?>) {
			// Sort once and reuse the result for both logging and the return value
			final String sorted = collection.stream().sorted().map(Object::toString).collect(Collectors.joining(", ", "[", "]"));
			log.trace("Converting to sorted list {} -> {}", collection, sorted);
			return sorted;
		}
		log.trace("Not sorting {}: {}", collection.getClass(), collection);
		return collection.toString();
	}

	/**
	 * Determines the element type of a property: the first generic parameter when
	 * the property type is generic (e.g. {@code List<Foo>} yields {@code Foo}),
	 * otherwise the property type itself. The field is inspected first; when no
	 * field exists, the public {@code get<PropertyName>()} getter is used.
	 *
	 * @param class1 the class declaring the property
	 * @param propertyName the property name
	 * @return the resolved type, or {@code null} when it could not be determined
	 */
	private Class<?> findPropertyType(final Class<? extends Object> class1, final String propertyName) {
		log.trace("Finding property type for {}.{}", class1.getName(), propertyName);

		// Field
		final Field field = ReflectionUtils.findField(class1, propertyName);
		if (field != null) {
			log.trace("Found field: {}\n\ttype={}\n\tgeneric={}\n\tgenericTN={}", field, field.getType(), field.getGenericType(), field.getGenericType().getTypeName());
			final ResolvableType t = ResolvableType.forField(field, class1);
			if (t.hasGenerics()) {
				final Class<?> resolved = t.resolveGeneric(0);
				log.trace("\tResoved={} returning={}", t, resolved);
				return resolved;
			} else {
				log.trace("Returning class itself={}", t.getRawClass());
				return t.getRawClass();
			}
		}

		// Getter
		try {
			final Method method = class1.getMethod("get" + StringUtils.capitalize(propertyName));
			final ResolvableType t = ResolvableType.forMethodReturnType(method, class1);
			final Class<?> resolved = t.hasGenerics() ? t.resolveGeneric(0) : t.getRawClass();
			log.trace("Didn't find field, found the method: {} returning={}", method, resolved);
			return resolved;
		} catch (SecurityException | NoSuchMethodException e) {
			log.debug("Could not access getter: {}", e.getMessage());
		}
		return null;
	}

	/**
	 * Replaces every {@link EntityId} element with its id and returns the result
	 * as a sorted list. Other elements are kept as they are.
	 *
	 * <p>
	 * Ordering: {@code null} elements first, {@link Long} values by numeric value,
	 * anything else by hash code. Ordering nulls consistently keeps the comparator
	 * within its contract and makes the resulting string deterministic.
	 * </p>
	 *
	 * @param previous the source collection, may be {@code null}
	 * @return the converted list; an immutable empty list for {@code null} or
	 *         empty input
	 */
	private Collection<Object> convertEntityId(final Collection<?> previous) {
		if (previous == null || previous.isEmpty()) return List.of();

		final List<Object> converted = new ArrayList<>(previous.size());
		for (final Object p : previous) {
			if (p instanceof EntityId) {
				converted.add(((EntityId) p).getId());
			} else {
				converted.add(p);
			}
		}
		converted.sort((a, b) -> {
			if (a == null || b == null) {
				// nulls first; two nulls are equal
				return a == b ? 0 : (a == null ? -1 : 1);
			} else if (a instanceof Long && b instanceof Long) {
				return Long.compare((Long) a, (Long) b);
			} else {
				return Integer.compare(a.hashCode(), b.hashCode());
			}
		});
		return converted;
	}

	/**
	 * Tells whether values of the given type are logged directly as simple values.
	 * Accepted are Java primitives, enums, {@link Number} subtypes, {@link String},
	 * {@link Character}, {@link Boolean}, {@link Instant}, {@link LocalDate} and
	 * {@link UUID}. Arrays are never accepted.
	 *
	 * @param class1 the property class
	 * @return true, if is primitive
	 */
	boolean isPrimitiveType(final Class<?> class1) {
		if (class1.isPrimitive() || class1.isEnum()) {
			return true;
		}
		if (class1.isArray()) {
			// Rejected silently, without the trace message below
			return false;
		}

		final boolean simpleValue = Number.class.isAssignableFrom(class1)
			|| String.class.equals(class1)
			|| Character.class.equals(class1)
			|| Boolean.class.equals(class1)
			|| Instant.class.equals(class1)
			|| LocalDate.class.equals(class1)
			|| UUID.class.equals(class1);

		if (!simpleValue) {
			log.trace("Class {} is not a primitive.", class1);
		}
		return simpleValue;
	}

	/**
	 * Queues an UPDATE entry in the audit logs of the current transaction. An
	 * equal entry already queued is replaced. Nothing is queued when both states
	 * are equal, unless they equal {@link AuditLog#FIELD_VALUE_NOT_AUDITED}
	 * (masked values are always recorded).
	 *
	 * @param entity the changed entity
	 * @param id the entity id
	 * @param propertyName the changed property
	 * @param previousState string form of the previous value
	 * @param currentState string form of the current value
	 * @param referencedEntity type of the referenced entity, or {@code null}
	 * @param prev the previous value
	 * @param curr the current value
	 */
	private void recordChange(final Object entity, final Long id, final String propertyName, final String previousState, final String currentState,
			final Class<?> referencedEntity, Object prev, Object curr) {

		final boolean unchanged = StringUtils.equals(previousState, currentState);
		final boolean masked = StringUtils.equals(currentState, AuditLog.FIELD_VALUE_NOT_AUDITED);
		if (unchanged && !masked) {
			log.trace("No state change {}.{} {}=={}", entity.getClass(), id, previousState, currentState);
			return;
		}

		final TransactionAuditLog entry = auditTrailService.auditLogEntry(AuditAction.UPDATE, entity, id, propertyName, previousState, currentState, referencedEntity, prev, curr);
		final Set<TransactionAuditLog> pending = auditLogStack.get().peek();

		// Remove first so that the new entry replaces an equal one
		final boolean replaced = pending.remove(entry);
		if (replaced) {
			log.trace("Replacing existing changelog {}", entry);
		} else {
			log.trace("Adding new changelog {}", entry);
		}
		pending.add(entry);
	}

	/**
	 * Queues a DELETE entry in the audit logs of the current transaction. An equal
	 * entry already queued is replaced. A non-null state of a field registered as
	 * secured for the entity class is logged as
	 * {@link AuditLog#FIELD_VALUE_NOT_AUDITED}.
	 *
	 * @param entity the deleted entity
	 * @param id the entity id
	 * @param propertyName the property name
	 * @param state string form of the deleted value
	 * @param referencedEntity type of the referenced entity, or {@code null}
	 * @param prev the deleted value
	 */
	private void recordDelete(final Object entity, final Long id, final String propertyName, final String state, final Class<?> referencedEntity, final Object prev) {
		String stateToLog = state;

		// Check fields masked with @HideAuditValue
		if (state != null) {
			final Set<String> hiddenFields = securedClassFields.get(entity.getClass());
			if (hiddenFields != null && hiddenFields.contains(propertyName)) {
				stateToLog = AuditLog.FIELD_VALUE_NOT_AUDITED;
			}
		}

		final TransactionAuditLog entry = auditTrailService.auditLogEntry(AuditAction.DELETE, entity, id, propertyName, stateToLog, null, referencedEntity, prev, null);
		final Set<TransactionAuditLog> pending = auditLogStack.get().peek();

		// Remove first so that the new entry replaces an equal one
		final boolean replaced = pending.remove(entry);
		if (replaced) {
			log.trace("Replacing exising changelog {}", entry);
		} else {
			log.trace("Adding new changelog {}", entry);
		}
		pending.add(entry);
	}

	/**
	 * Determines if the entity class should be audited. The decision is cached.
	 *
	 * <ol>
	 * <li>Classes annotated with {@link NotAudited} are never audited.</li>
	 * <li>Classes annotated with {@link Audited} are audited.</li>
	 * <li>Classes assignable to one of the configured audited classes are
	 * audited.</li>
	 * </ol>
	 *
	 * <p>
	 * For every audited class the fields annotated with {@link NotAudited} and
	 * {@link HideAuditValue} are registered <em>before</em> the class is added to
	 * the cache of included classes, so that a concurrent caller that sees the
	 * class as audited also sees its field configuration.
	 * </p>
	 *
	 * @param entityClass entity class
	 * @return <code>true</code> when changes should be audited
	 */
	boolean isAudited(final Class<?> entityClass) {
		// Cache lookup
		if (includedClasses.contains(entityClass)) {
			return true;
		}
		// Cache lookup
		if (ignoredClasses.contains(entityClass)) {
			return false;
		}

		final NotAudited notAuditedAnnotation = entityClass.getAnnotation(NotAudited.class);
		if (notAuditedAnnotation != null) {
			log.trace("{} is excluded from auditing", entityClass);
			// Register ignored class
			ignoredClasses.add(entityClass);
			return false;
		}

		final Audited auditedAnnotation = entityClass.getAnnotation(Audited.class);
		if (auditedAnnotation != null) {
			log.trace("{} is annotated for auditing", entityClass);
			registerAnnotatedFields(entityClass);
			// Register included class
			includedClasses.add(entityClass);
			return true;
		}

		for (final Class<?> auditedClass : auditedClasses) {
			if (auditedClass.isAssignableFrom(entityClass)) {
				log.trace("{} is audited because it is an instance of {}", entityClass, auditedClass);
				registerAnnotatedFields(entityClass);
				// Register included class
				includedClasses.add(entityClass);
				return true;
			}
		}

		log.trace("{} is not audited", entityClass);
		// Register ignored entity class
		ignoredClasses.add(entityClass);
		return false;
	}

	/**
	 * Scans the fields of the class (including inherited ones) and registers those
	 * annotated with {@link NotAudited} as ignored and those annotated with
	 * {@link HideAuditValue} as secured.
	 *
	 * @param entityClass the audited entity class
	 */
	private void registerAnnotatedFields(final Class<?> entityClass) {
		final Set<String> notAuditedFields = new HashSet<>();
		final Set<String> hiddenValueFields = new HashSet<>();

		ReflectionUtils.doWithFields(entityClass, field -> {
			if (field.getAnnotation(NotAudited.class) != null) {
				log.trace("{} property of {} class is excluded from auditing", field.getName(), entityClass);
				notAuditedFields.add(field.getName());
			}
			if (field.getAnnotation(HideAuditValue.class) != null) {
				log.trace("Previous and a new value of {} property of {} class is excluded from persisting", field.getName(), entityClass);
				hiddenValueFields.add(field.getName());
			}
		});

		// Publish fully built, read-only sets
		if (!notAuditedFields.isEmpty()) {
			ignoredClassFields.put(entityClass, Collections.unmodifiableSet(notAuditedFields));
		}
		if (!hiddenValueFields.isEmpty()) {
			securedClassFields.put(entityClass, Collections.unmodifiableSet(hiddenValueFields));
		}
	}

	// @Override
	// public boolean onSave(Object entity, Serializable id, Object[] state,
	// String[] propertyNames, Type[] types) {
	// log.trace("onSave entity={} id={} props={}", entity, id,
	// Arrays.toString(propertyNames));
	// return super.onSave(entity, id, state, propertyNames, types);
	// }

	/**
	 * Opens a new, empty set of audit logs for the transaction that just started
	 * by pushing it onto the stack of the current thread.
	 *
	 * @param tx the started transaction
	 */
	@Override
	public void afterTransactionBegin(final Transaction tx) {
		final Stack<Set<TransactionAuditLog>> pendingByLevel = auditLogStack.get();
		pendingByLevel.push(new HashSet<>());

		log.trace("Starting transaction level={}", pendingByLevel.size());

		super.afterTransactionBegin(tx);
	}

	/**
	 * Called before a transaction is committed (but not before rollback). Pops the
	 * audit logs collected for the transaction from the stack of the current
	 * thread and hands them to the audit trail service.
	 *
	 * <p>
	 * When the stack is empty (no matching {@link #afterTransactionBegin} was seen
	 * on this thread) there is nothing to store; a warning is logged instead of
	 * failing the commit with an {@link java.util.EmptyStackException}.
	 * </p>
	 *
	 * @param tx the transaction about to complete
	 */
	@Override
	public void beforeTransactionCompletion(final Transaction tx) {

		var transactionLogs = auditLogStack.get();
		if (transactionLogs.isEmpty()) {
			log.warn("No audit log set registered for this transaction, nothing to persist");
			super.beforeTransactionCompletion(tx);
			return;
		}

		final long level = transactionLogs.size();
		Set<TransactionAuditLog> currentAuditLogs = transactionLogs.pop(); // pop

		log.trace("beforeTransactionCompletion transaction level={} auditlogs={}", level, currentAuditLogs.size());

		if (!currentAuditLogs.isEmpty()) {
			log.trace("We have {} auditlogs", currentAuditLogs.size());
			currentAuditLogs.forEach(auditLog -> log.debug("Audit log to save: {}", auditLog));
			if (TransactionStatus.ROLLED_BACK == tx.getStatus()) {
				log.warn("Transaction was rolled back. Audit logs likely won't be persisted");
			}
			this.auditTrailService.addAuditLogs(currentAuditLogs);
			currentAuditLogs.clear(); // not required anymore
		}

		super.beforeTransactionCompletion(tx);
	}

	// @SuppressWarnings({ "rawtypes", "unchecked" })
	// @Override
	// public void postFlush(final Iterator entities) {
	// // log.trace("postFlush {}", entities);
	// entities.forEachRemaining(entity -> {
	// log.trace("postFlush {}", entity);
	// });
	//
	// super.postFlush(entities);
	// }

	/**
	 * Called after a transaction is committed or rolled back. Only traces the
	 * outcome together with the current depth of the audit log stack.
	 *
	 * <p>
	 * NOTE(review): the set pushed in {@link #afterTransactionBegin} is popped only
	 * in {@link #beforeTransactionCompletion}; if that callback is skipped on
	 * rollback, the set presumably stays on the thread's stack - confirm and clean
	 * up if so.
	 * </p>
	 *
	 * @param tx the completed transaction
	 */
	@Override
	public void afterTransactionCompletion(Transaction tx) {

		final long level = auditLogStack.get().size();
		log.trace("afterTransactionCompletion transaction level={}", level);

		final boolean committed = TransactionStatus.COMMITTED == tx.getStatus();
		if (committed) {
			log.trace("Transaction was committed, level={}", level);
		} else if (TransactionStatus.ROLLED_BACK == tx.getStatus()) {
			log.trace("Transaction was rolled back, level={}", level);
		}

		super.afterTransactionCompletion(tx);
	}

	/**
	 * Tells Hibernate whether the entity is transient (not yet persisted).
	 * {@link BasicModel} instances answer through {@code isNew()}. For other
	 * entities the result of {@code getVersion()} is inspected, falling back to
	 * {@code getId()} when the entity has no {@code getVersion()} method (see
	 * {@link #tryMethod(Object, String)}).
	 *
	 * @param entity the entity to test
	 * @return {@code true} when the entity is transient
	 * @throws RuntimeException when neither accessor can be invoked; the cause is
	 *         the failure that actually prevented the decision
	 */
	@Override
	public Boolean isTransient(final Object entity) {
		if (entity instanceof BasicModel) {
			return ((BasicModel) entity).isNew();
		}
		// TODO Use Spring field access methods?
		try {
			return tryMethod(entity, "getVersion");
		} catch (final NoSuchMethodException e) {
			try {
				return tryMethod(entity, "getId");
			} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e1) {
				// Report the getId failure as the cause, keep the getVersion failure for context
				final RuntimeException failure = new RuntimeException(e1.getMessage() + " on " + entity.getClass() + " e=" + entity, e1);
				failure.addSuppressed(e);
				throw failure;
			}
		} catch (final IllegalAccessException | InvocationTargetException | RuntimeException e) {
			throw new RuntimeException(e.getMessage() + " on " + entity.getClass() + " e=" + entity, e);
		}
	}

	/**
	 * Invokes a public no-argument method on the entity and interprets the result
	 * as a transient marker: {@code null} or a negative number means transient.
	 *
	 * @param entity the entity
	 * @param methodName the name of the method to invoke
	 * @return {@code true} when the result is {@code null} or a negative
	 *         {@link Number}, {@code false} otherwise
	 * @throws NoSuchMethodException when the entity class has no such public method
	 * @throws IllegalAccessException when the method cannot be accessed
	 * @throws InvocationTargetException when the invoked method throws
	 */
	public boolean tryMethod(final Object entity, final String methodName) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
		final Object value = entity.getClass().getMethod(methodName).invoke(entity);

		if (value == null) {
			log.trace("{} is transient, has {} == null", entity, methodName);
			return true;
		}

		if (value instanceof Number && ((Number) value).longValue() < 0) {
			log.trace("{} is transient, has {} = {} < 0", entity, methodName, value);
			return true;
		}

		// Non-negative number or any other non-null value: not transient
		return false;
	}
}