ClamAVScanner.java

/*
 * Copyright 2018 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.genesys.filerepository.service.impl;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Map.Entry;

import org.genesys.filerepository.service.VirusFoundException;
import org.genesys.filerepository.service.VirusScanner;
import org.springframework.beans.factory.InitializingBean;

import lombok.extern.slf4j.Slf4j;
import xyz.capybara.clamav.ClamavClient;
import xyz.capybara.clamav.ClamavException;
import xyz.capybara.clamav.Platform;
import xyz.capybara.clamav.commands.scan.result.ScanResult;

/**
 * {@link VirusScanner} implementation backed by a ClamAV daemon ({@code clamd}) addressed by host and port.
 *
 * <p>
 * The client is created lazily. {@link #afterPropertiesSet()} attempts a connection once at startup. If that
 * fails, each {@code scan} call retries until a connection succeeds. Once created, the client instance is kept
 * for the lifetime of the bean; it is never reset.
 *
 * <p>
 * All scan methods fail closed: an error reported by the ClamAV client is propagated to the caller instead of
 * the data being treated as clean.
 */
@Slf4j
public class ClamAVScanner implements VirusScanner, InitializingBean {

	/** Hostname or address of clamd. */
	private String clamAvHost;

	/** TCP port of clamd. */
	private int clamAvPort;

	/**
	 * The ClamAV client. This is {@code null} until the first successful {@link #reconnect()}. It is volatile so
	 * that {@link #ensureClient()} can read it without taking the lock.
	 */
	private volatile ClamavClient client;

	/**
	 * Sets the clamd host.
	 *
	 * @param clamAvHost the clamd hostname or address
	 */
	public void setClamAvHost(final String clamAvHost) {
		this.clamAvHost = clamAvHost;
	}

	/**
	 * Sets the clamd port.
	 *
	 * @param clamAvPort the clamd TCP port
	 */
	public void setClamAvPort(final int clamAvPort) {
		this.clamAvPort = clamAvPort;
	}

	/**
	 * Attempts an eager connection to clamd. A failure here is not fatal: it is logged, and the connection is
	 * retried on the next scan.
	 *
	 * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
	 */
	@Override
	public void afterPropertiesSet() throws Exception {
		try {
			reconnect();
		} catch (final ClamavException e) {
			log.warn("Could not connect to clamd. Will retry on file upload...");
		}
	}

	/**
	 * Creates a new client, verifies it with {@code ping()}, and publishes it to {@link #client}.
	 *
	 * @return the newly published client
	 * @throws ClamavException if clamd cannot be reached
	 */
	private ClamavClient reconnect() throws ClamavException {
		log.info("Connecting to clamd at {}:{}", clamAvHost, clamAvPort);
		final ClamavClient c = new ClamavClient(clamAvHost, clamAvPort, Platform.UNIX);
		c.ping();
		log.info("Connected to clamd at {}:{} version={}", clamAvHost, clamAvPort, c.version());
		synchronized (this) {
			this.client = c;
		}
		return c;
	}

	/**
	 * Returns the current client, connecting first if none exists yet.
	 *
	 * <p>
	 * The fast path is a single volatile read. Only the first caller(s) take the lock. The field is re-checked
	 * under the lock, so concurrent callers do not connect twice.
	 *
	 * @return a non-null client
	 * @throws ClamavException if a connection is needed and cannot be established
	 */
	private ClamavClient ensureClient() throws ClamavException {
		final ClamavClient current = client;
		if (current != null) {
			return current;
		}
		synchronized (this) {
			if (client != null) {
				return client;
			}
			try {
				return reconnect();
			} catch (final ClamavException e) {
				log.error(e.getMessage());
				throw e;
			}
		}
	}

	/**
	 * Interprets a scan result. A clean result is logged at debug level. A detection is logged and turned into
	 * an exception.
	 *
	 * @param scanResult the result returned by the client
	 * @throws VirusFoundException if the result reports one or more viruses; the message lists them,
	 *           separated by {@code "; "}
	 */
	private static void checkResult(final ScanResult scanResult) throws VirusFoundException {
		if (scanResult instanceof ScanResult.OK) {
			log.debug("Data scanned and found virus-free");
		} else if (scanResult instanceof ScanResult.VirusFound) {
			final StringBuilder sb = new StringBuilder();
			for (final Entry<String, Collection<String>> entry : ((ScanResult.VirusFound) scanResult).getFoundViruses().entrySet()) {
				log.error("In file={}: Found virus {}", entry.getKey(), entry.getValue());
				sb.append(entry.getValue()).append("; ");
			}
			throw new VirusFoundException(sb.toString());
		}
	}

	/**
	 * Scans an in-memory byte array.
	 *
	 * @param bytes the data to scan
	 * @throws VirusFoundException if clamd reports a virus
	 * @throws ClamavException if clamd cannot be reached or the scan fails; the data is NOT treated as clean
	 * @throws IOException declared by closing the wrapping stream
	 * @see org.genesys.filerepository.service.VirusScanner#scan(byte[])
	 */
	@Override
	public void scan(final byte[] bytes) throws VirusFoundException, ClamavException, IOException {
		final ClamavClient scanner = ensureClient();

		try (InputStream is = new ByteArrayInputStream(bytes)) {
			checkResult(scanner.scan(is));
		} catch (final ClamavException e) {
			log.warn("Error scanning: {}", e.getMessage(), e);
			// Propagate: swallowing this would let unscanned data through as if it were virus-free.
			throw e;
		}
	}

	/**
	 * Scans the contents of a stream. The stream is not closed; it remains owned by the caller.
	 *
	 * @param inputStream the data to scan
	 * @throws VirusFoundException if clamd reports a virus
	 * @throws RuntimeException wrapping the {@link ClamavException} if the scan fails
	 * @see org.genesys.filerepository.service.VirusScanner#scan(InputStream)
	 */
	@Override
	public void scan(final InputStream inputStream) throws VirusFoundException {
		final ClamavClient scanner = ensureClient();

		try {
			checkResult(scanner.scan(inputStream));
		} catch (final ClamavException e) {
			log.warn("Error scanning: {}", e.getMessage(), e);
			throw new RuntimeException("Error scanning: " + e.getMessage(), e);
		}
	}

	/**
	 * Scans a file by path. A {@code null} or non-existent file is skipped without contacting clamd.
	 *
	 * @param inputFile the file to scan
	 * @throws VirusFoundException if clamd reports a virus
	 * @throws RuntimeException wrapping the {@link ClamavException} if the scan fails
	 * @see org.genesys.filerepository.service.VirusScanner#scan(File)
	 */
	@Override
	public void scan(final File inputFile) throws VirusFoundException {
		if (inputFile == null || !inputFile.exists()) {
			log.debug("Null or non-existing File provided to scanner. Skipping.");
			return;
		}

		final ClamavClient scanner = ensureClient();

		try {
			checkResult(scanner.scan(inputFile.toPath()));
		} catch (final ClamavException e) {
			log.warn("Error scanning: {}", e.getMessage(), e);
			throw new RuntimeException("Error scanning: " + e.getMessage(), e);
		}
	}
}