فهرست منبع

multithreaded decryption using producer/consumer pattern

Sebastian Stenzel 9 سال پیش
والد
کامیت
cdcc1626ce

+ 2 - 2
main/core/src/test/java/org/cryptomator/webdav/jackrabbit/RangeRequestTest.java

@@ -70,9 +70,9 @@ public class RangeRequestTest {
 		final HttpClient client = new HttpClient();
 
 		// prepare 64MiB test data:
-		final byte[] plaintextData = new byte[6777216 * Integer.BYTES];
+		final byte[] plaintextData = new byte[16777216 * Integer.BYTES];
 		final ByteBuffer bbIn = ByteBuffer.wrap(plaintextData);
-		for (int i = 0; i < 6777216; i++) {
+		for (int i = 0; i < 16777216; i++) {
 			bbIn.putInt(i);
 		}
 		final InputStream plaintextDataInputStream = new ByteArrayInputStream(plaintextData);

+ 82 - 88
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/Aes256Cryptor.java

@@ -22,10 +22,14 @@ import java.security.SecureRandom;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletionService;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorCompletionService;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.Lock;
@@ -439,104 +443,94 @@ public class Aes256Cryptor implements Cryptor, AesCryptographicConfiguration {
 		final SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES_KEY_ALGORITHM);
 		System.err.println("init DEC: " + (System.nanoTime() - t0) / 1000 / 1000.0 + "ms");
 
-		final int numWorkers = 1;
-		final ExecutorService executorService = Executors.newFixedThreadPool(numWorkers);
+		// prepare some crypto workers:
+		final int numWorkers = Runtime.getRuntime().availableProcessors();
 		final Lock lock = new ReentrantLock();
-		final Condition blockCondition = lock.newCondition();
-
-		// reading ciphered input and MACs interleaved:
-		final AtomicLong bytesDecrypted = new AtomicLong();
-		final InputStream in = new SeekableByteChannelInputStream(encryptedFile);
-		final byte[] buffer = new byte[(CONTENT_MAC_BLOCK + 32) * numWorkers];
-		int n = 0;
-		final AtomicLong blockNum = new AtomicLong();
-		final AtomicLong numBlocksWritten = new AtomicLong();
-		while ((n = IOUtils.read(in, buffer)) > 0 && bytesDecrypted.get() < fileSize) {
-			t0 = System.nanoTime();
-
-			final int finalN = n;
-			final List<Future<Boolean>> tasks = new ArrayList<>();
-			for (int i = 0; i < numWorkers; i++) {
-				final Future<Boolean> task = executorService.submit(() -> {
-					final long myBlockNum = blockNum.getAndIncrement();
-					final int myBufferOffset = (int) ((myBlockNum % numWorkers) * (CONTENT_MAC_BLOCK + 32));
-					final int myN = Math.min(finalN - myBufferOffset, CONTENT_MAC_BLOCK + 32);
-
-					if (myN <= 0) {
-						// EOF
-					} else if (myN < 32) {
-						throw new DecryptFailedException("Invalid file content, missing MAC.");
-					}
-
-					// check MAC of current block:
-					if (authenticate) {
-						final Mac contentMac = this.hmacSha256(hMacMasterKey);
-						contentMac.update(iv);
-						contentMac.update(longToByteArray(myBlockNum));
-						contentMac.update(buffer, myBufferOffset, myN - 32);
-						final byte[] calculatedMac = contentMac.doFinal();
-						final byte[] storedMac = new byte[32];
-						System.arraycopy(buffer, myBufferOffset + myN - 32, storedMac, 0, 32);
-						if (!MessageDigest.isEqual(calculatedMac, storedMac)) {
-							throw new MacAuthenticationFailedException("Content MAC authentication failed.");
-						}
-					}
+		final Condition blockDone = lock.newCondition();
+		final AtomicLong currentBlock = new AtomicLong();
+		final BlockingQueue<Block> inputQueue = new LinkedBlockingQueue<>(numWorkers);
+		final LengthLimitingOutputStream paddingRemovingOutputStream = new LengthLimitingOutputStream(plaintextFile, fileSize);
+		final List<DecryptWorker> workers = new ArrayList<>();
+		final ExecutorService executorService = Executors.newFixedThreadPool(numWorkers);
+		final CompletionService<Void> completionService = new ExecutorCompletionService<>(executorService);
+		for (int i = 0; i < numWorkers; i++) {
+			final DecryptWorker worker = new DecryptWorker(lock, blockDone, currentBlock, inputQueue, authenticate, paddingRemovingOutputStream) {
+				private final Mac mac = hmacSha256(hMacMasterKey);
 
-					// decrypt block:
+				@Override
+				protected byte[] decrypt(Block block) {
 					final ByteBuffer nonceAndCounterBuf = ByteBuffer.allocate(AES_BLOCK_LENGTH);
 					nonceAndCounterBuf.put(nonce);
-					nonceAndCounterBuf.putLong(myBlockNum * CONTENT_MAC_BLOCK / AES_BLOCK_LENGTH);
+					nonceAndCounterBuf.putLong(block.blockNumber * CONTENT_MAC_BLOCK / AES_BLOCK_LENGTH);
 					final byte[] nonceAndCounter = nonceAndCounterBuf.array();
-					final Cipher cipher = this.aesCtrCipher(fileKey, nonceAndCounter, Cipher.DECRYPT_MODE);
-					final byte[] plaintext = cipher.update(buffer, myBufferOffset, myN - 32);
-
-					// wait for our turn to write the plaintext:
-					lock.lock();
-					try {
-						while (numBlocksWritten.get() != myBlockNum) {
-							blockCondition.await();
-						}
-						final int plaintextLengthWithoutPadding = (int) Math.min(plaintext.length, fileSize - bytesDecrypted.get()); // plaintext.length is known to be a 32 bit int
-						plaintextFile.write(plaintext, 0, plaintextLengthWithoutPadding);
-						bytesDecrypted.addAndGet(plaintextLengthWithoutPadding);
-						numBlocksWritten.incrementAndGet();
-						blockCondition.signalAll();
-					} catch (InterruptedException e) {
-
-					} finally {
-						lock.unlock();
-					}
-					return true;
-				});
-				tasks.add(task);
-			}
+					final Cipher cipher = aesCtrCipher(fileKey, nonceAndCounter, Cipher.DECRYPT_MODE);
+					return cipher.update(block.buffer, 0, block.numBytes - mac.getMacLength());
+				}
 
-			for (Future<Boolean> task : tasks) {
-				try {
-					task.get();
-				} catch (InterruptedException e) {
-					// TODO Auto-generated catch block
-				} catch (ExecutionException e) {
-					final Throwable cause = e.getCause();
-					if (cause instanceof DecryptFailedException) {
-						throw (DecryptFailedException) cause;
-					} else if (cause instanceof MacAuthenticationFailedException) {
-						throw (MacAuthenticationFailedException) cause;
-					} else if (cause instanceof IOException) {
-						throw (IOException) cause;
-					} else {
-						// TODO loggingpower
+				@Override
+				protected void checkMac(Block block) throws MacAuthenticationFailedException {
+					mac.update(iv);
+					mac.update(longToByteArray(block.blockNumber));
+					mac.update(block.buffer, 0, block.numBytes - mac.getMacLength());
+					final byte[] calculatedMac = mac.doFinal();
+					final byte[] storedMac = Arrays.copyOfRange(block.buffer, block.numBytes - mac.getMacLength(), block.numBytes);
+					if (!MessageDigest.isEqual(calculatedMac, storedMac)) {
+						throw new MacAuthenticationFailedException("Content MAC authentication failed.");
 					}
 				}
-			}
+			};
+			workers.add(worker);
+			completionService.submit(worker);
+		}
 
-			// System.err.println("dec " + numWorkers + " blocks: " + (System.nanoTime() - t0) / 1000 / 1000.0 + "ms");
+		System.err.println("initialization of decrypt workers: " + (System.nanoTime() - t0) / 1000 / 1000.0 + "ms");
+		t0 = System.nanoTime();
+
+		// reading ciphered input and MACs interleaved:
+		final InputStream in = new SeekableByteChannelInputStream(encryptedFile);
+		final byte[] buffer = new byte[CONTENT_MAC_BLOCK + 32];
+		int n = 0;
+		int blockNumber = 0;
+		try {
+			// read as many blocks from file as possible, but wait if queue is full:
+			while ((n = IOUtils.read(in, buffer)) > 0) {
+				final boolean consumedInTime = inputQueue.offer(new Block(n, Arrays.copyOf(buffer, n), blockNumber++), 1, TimeUnit.SECONDS);
+				if (!consumedInTime) {
+					// interrupt read loop and make room for some poisons:
+					inputQueue.clear();
+					break;
+				}
+			}
+			// each worker has to swallow some poison:
+			for (int i = 0; i < numWorkers; i++) {
+				inputQueue.put(CryptoWorker.POISON);
+			}
+		} catch (InterruptedException e) {
+			// TODO
+			e.printStackTrace();
 		}
-		destroyQuietly(fileKey);
 
-		System.err.println("cleanup: " + (System.nanoTime() - t0) / 1000 / 1000.0 + "ms");
+		// wait for decryption workers to finish:
+		try {
+			for (int i = 0; i < numWorkers; i++) {
+				completionService.take().get();
+			}
+		} catch (ExecutionException e) {
+			final Throwable cause = e.getCause();
+			if (cause instanceof IOException) {
+				throw (IOException) cause;
+			}
+		} catch (InterruptedException e) {
+			// TODO
+			e.printStackTrace();
+		} finally {
+			// shutdown either after normal decryption or if ANY worker threw an exception:
+			executorService.shutdownNow();
+		}
 
-		return bytesDecrypted.get();
+		destroyQuietly(fileKey);
+		System.err.println("decrypted " + paddingRemovingOutputStream.getBytesWritten() + " bytes in: " + (System.nanoTime() - t0) / 1000 / 1000.0 + "ms");
+		return paddingRemovingOutputStream.getBytesWritten();
 	}
 
 	@Override
@@ -674,7 +668,7 @@ public class Aes256Cryptor implements Cryptor, AesCryptographicConfiguration {
 
 		// add random length padding to obfuscate file length:
 		final byte[] randomPadding = this.randomData(AES_BLOCK_LENGTH);
-		final LengthObfuscationInputStream in = new LengthObfuscationInputStream(plaintextFile, randomPadding);
+		final LengthObfuscatingInputStream in = new LengthObfuscatingInputStream(plaintextFile, randomPadding);
 
 		// content encryption:
 		final SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES_KEY_ALGORITHM);

+ 15 - 0
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/Block.java

@@ -0,0 +1,15 @@
+package org.cryptomator.crypto.aes256;
+
+class Block {
+
+	final int numBytes;
+	final byte[] buffer;
+	final long blockNumber;
+
+	Block(int numBytes, byte[] buffer, long blockNumber) {
+		this.numBytes = numBytes;
+		this.buffer = buffer;
+		this.blockNumber = blockNumber;
+	}
+
+}

+ 63 - 0
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/CryptoWorker.java

@@ -0,0 +1,63 @@
+package org.cryptomator.crypto.aes256;
+
+import java.io.IOException;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+
+import org.cryptomator.crypto.exceptions.CryptingException;
+
+abstract class CryptoWorker implements Callable<Void> {
+
+	static final Block POISON = new Block(0, new byte[0], -1L);
+
+	final Lock lock;
+	final Condition blockDone;
+	final AtomicLong currentBlock;
+	final BlockingQueue<Block> queue;
+
+	public CryptoWorker(Lock lock, Condition blockDone, AtomicLong currentBlock, BlockingQueue<Block> queue) {
+		this.lock = lock;
+		this.blockDone = blockDone;
+		this.currentBlock = currentBlock;
+		this.queue = queue;
+	}
+
+	@Override
+	public final Void call() throws IOException {
+		try {
+			while (!Thread.currentThread().isInterrupted()) {
+				final Block block = queue.take();
+				if (block == POISON) {
+					// put poison back in for other threads:
+					break;
+				}
+				final byte[] processedBytes = this.process(block);
+				lock.lock();
+				try {
+					while (currentBlock.get() != block.blockNumber) {
+						blockDone.await();
+					}
+					assert currentBlock.get() == block.blockNumber;
+					// yay, its my turn!
+					this.write(processedBytes);
+					// signal worker working on next block:
+					currentBlock.set(block.blockNumber + 1);
+					blockDone.signalAll();
+				} finally {
+					lock.unlock();
+				}
+			}
+		} catch (InterruptedException e) {
+			Thread.currentThread().interrupt();
+		}
+		return null;
+	}
+
+	protected abstract byte[] process(Block block) throws CryptingException;
+
+	protected abstract void write(byte[] processedBytes) throws IOException;
+
+}

+ 49 - 0
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/DecryptWorker.java

@@ -0,0 +1,49 @@
+package org.cryptomator.crypto.aes256;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+
+import org.cryptomator.crypto.exceptions.CryptingException;
+import org.cryptomator.crypto.exceptions.DecryptFailedException;
+import org.cryptomator.crypto.exceptions.MacAuthenticationFailedException;
+
+abstract class DecryptWorker extends CryptoWorker implements AesCryptographicConfiguration {
+
+	private final boolean shouldAuthenticate;
+	private final OutputStream out;
+
+	public DecryptWorker(Lock lock, Condition blockDone, AtomicLong currentBlock, BlockingQueue<Block> queue, boolean shouldAuthenticate, OutputStream out) {
+		super(lock, blockDone, currentBlock, queue);
+		this.shouldAuthenticate = shouldAuthenticate;
+		this.out = out;
+	}
+
+	@Override
+	protected byte[] process(Block block) throws CryptingException {
+		if (block.numBytes < 32) {
+			throw new DecryptFailedException("Invalid file content, missing MAC.");
+		}
+
+		// check MAC of current block:
+		if (shouldAuthenticate) {
+			checkMac(block);
+		}
+
+		// decrypt block:
+		return decrypt(block);
+	}
+
+	@Override
+	protected void write(byte[] processedBytes) throws IOException {
+		out.write(processedBytes);
+	}
+
+	protected abstract void checkMac(Block block) throws MacAuthenticationFailedException;
+
+	protected abstract byte[] decrypt(Block block);
+
+}

+ 40 - 0
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/LengthLimitingOutputStream.java

@@ -0,0 +1,40 @@
+package org.cryptomator.crypto.aes256;
+
+import java.io.FilterOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+public class LengthLimitingOutputStream extends FilterOutputStream {
+
+	private final long limit;
+	private volatile long bytesWritten;
+
+	public LengthLimitingOutputStream(OutputStream out, long limit) {
+		super(out);
+		this.limit = limit;
+		this.bytesWritten = 0;
+	}
+
+	@Override
+	public void write(int b) throws IOException {
+		if (bytesWritten < limit) {
+			out.write(b);
+			bytesWritten++;
+		}
+	}
+
+	@Override
+	public void write(byte[] b, int off, int len) throws IOException {
+		final long bytesAvailable = limit - bytesWritten;
+		final int adjustedLen = (int) Math.min(len, bytesAvailable);
+		if (adjustedLen > 0) {
+			out.write(b, off, adjustedLen);
+			bytesWritten += adjustedLen;
+		}
+	}
+
+	public long getBytesWritten() {
+		return bytesWritten;
+	}
+
+}

+ 2 - 2
main/crypto-aes/src/main/java/org/cryptomator/crypto/aes256/LengthObfuscationInputStream.java

@@ -9,14 +9,14 @@ import org.apache.commons.io.IOUtils;
 /**
  * Not thread-safe!
  */
-public class LengthObfuscationInputStream extends FilterInputStream {
+public class LengthObfuscatingInputStream extends FilterInputStream {
 
 	private final byte[] padding;
 	private int paddingLength = -1;
 	private long inputBytesRead = 0;
 	private int paddingBytesRead = 0;
 
-	LengthObfuscationInputStream(InputStream in, byte[] padding) {
+	LengthObfuscatingInputStream(InputStream in, byte[] padding) {
 		super(in);
 		this.padding = padding;
 	}