/*
   Copyright (c) 2015 Techmeology

   Permission is hereby granted, free of charge, to any person obtaining a copy
   of this software and associated documentation files (the "Software"), to deal
   in the Software without restriction, including without limitation the rights
   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
   copies of the Software, and to permit persons to whom the Software is
   furnished to do so, subject to the following conditions:

   The above copyright notice and this permission notice shall be included in
   all copies or substantial portions of the Software.

   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
   SOFTWARE.
*/

/* Example build: gcc -Wall -pedantic -O2 -s -o xorimg xorimg.c
   
   xorimg is a program for interleaving multiple disk images (or other similar
   data) together in a manner that is efficient for compressing with an external
   program. The inputs are read in blocks. Each block is processed as follows:
    - The first input's block is written verbatim.
    - The length of each input's block is written as a little-endian uint32
      immediately before its data.
    - Each successive input's block is xor'd with the block of the previous
      still-open input, and the result is written. The idea of using xor is to
      set to zero any runs of bytes that are the same in successive images. If
      the images are similar, the result is long runs of zero bytes, which
      compress easily.
    - A block shorter than the block size (the last block of an input) is
      treated as if it were zero-padded to the full block size.
    - The zero padding is implicit and is not written.
   The output begins with a 16-byte header; its integers are stored in
   little-endian byte order:
    - char[8] magic: {'x', 'o', 'r', 0, 'i', 'm', 'g', 0}
    - uint32 block_size: Size of each block; must be non-zero
    - uint32 input_count: Number of inputs
   
   Resource usage:
   This program opens a file handle for every input and output given to it.
   Unless an extremely large block size is specified, this will probably be
   the limiting factor for the scaling of this program. It also allocates two
   heap buffers, each large enough to hold a complete block. In principle, we
   could get away with only one such buffer, but this would add complexity to
   the program and is unimportant unless the block size is huge.
 */

#include <errno.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Process exit statuses: OK is returned from main(), the others are passed
   to exit() by usage(), syscall_error() and input_error() respectively. */
enum /* return code */
{
	OK = 0,            /* success */
	USAGE_ERROR = 1,   /* bad command line */
	SYSCALL_ERROR = 2, /* a library/system call failed */
	INPUT_ERROR = 3    /* malformed or truncated input data */
};

/* File signature at the start of every xorimg stream; all 8 bytes, both NULs
   included, are written and compared. */
static const char magic[8] = {'x', 'o', 'r', '\0', 'i', 'm', 'g', '\0'};

/* **** System interaction and error handling **** */
/* We don't fclose() or free() handles and memory that are still open at program
   termination, because exit() or return from main should do that for us. Where
   we do fclose(), we don't bother checking the return code, as there is little
   benefit in doing so. */

/* Report a failed library/system call and exit with SYSCALL_ERROR.
   fn:     name of the failing call; perror(fn) appends "fn: <errno text>".
   format: printf-style description of what was being attempted.
   errno is captured on entry and restored just before perror(): the
   vfprintf()/fputs() calls that print the description are permitted to
   modify errno, which would make perror() report the wrong cause. */
static void syscall_error (const char *fn, const char *format, ...)
{
	const int saved_errno = errno;
	va_list as;

	va_start(as, format);
	vfprintf(stderr, format, as);
	va_end(as);
	fputs(": ", stderr);

	errno = saved_errno;
	perror(fn);
	exit(SYSCALL_ERROR);
}

/* Report malformed or truncated input and exit with INPUT_ERROR.
   The printf-style message is prefixed with "Input error: " and terminated
   with a newline on stderr. */
static void input_error (const char *format, ...)
{
	va_list args;

	fputs("Input error: ", stderr);
	va_start(args, format);
	vfprintf(stderr, format, args);
	va_end(args);
	fputc('\n', stderr);
	exit(INPUT_ERROR);
}

/* Return a zero-filled heap buffer of count bytes; never returns null.
   Allocation failure is fatal (via syscall_error()).
   count is size_t so callers' size computations (e.g. n * sizeof(FILE *))
   are not truncated. A request for 0 bytes is rounded up to 1 because
   malloc(0)/calloc(0) may legitimately return null, which would otherwise be
   misreported as out-of-memory (compressing with no inputs asks for a
   zero-length handle array). */
static void *alloc (size_t count)
{
	void *mem = calloc(count ? count : 1, 1);
	if (!mem) syscall_error("calloc",
		"Could not allocate %lu bytes", (unsigned long)count);
	return mem;
}

/* Direction in which open() should open a stream. */
typedef enum
{
	open_mode_read,  /* binary input; "-" means stdin */
	open_mode_write  /* binary output; "-" means stdout */
} open_mode_t;
/* Return a stream for path: stdin/stdout when path is "-", otherwise the
   result of fopen() in "rb"/"wb" mode. A failed fopen() is fatal (via
   syscall_error()), so a named file never yields null. */
static FILE *open (const char *path, open_mode_t mode)
{
	const int for_writing = (mode == open_mode_write);
	FILE *stream = 0;

	if (strcmp(path, "-") == 0)
		return for_writing ? stdout : stdin;

	stream = fopen(path, for_writing ? "wb" : "rb");
	if (!stream) syscall_error("fopen", "Could not open \"%s\"", path);
	return stream;
}

/* Read up to count bytes from f into buffer and return how many arrived.
   The unread tail of the buffer is zero-filled, which supplies the implicit
   zero padding of short blocks. When complete is non-zero, a short read is
   treated as truncated input (input_error()); a stream error is fatal via
   syscall_error(). */
static uint32_t read_bytes (FILE *f, void *buffer, uint32_t count, int complete)
{
	uint8_t *bytes = buffer;
	const size_t got = fread(bytes, 1, count, f);

	if (ferror(f))
		syscall_error("fread", "Could not read %lu bytes",
			(unsigned long)count);
	if (complete && got != count)
		input_error("Unexpected end of file");

	memset(bytes + got, 0, count - got);
	return (uint32_t)got;
}

/* Write count bytes from buffer to f; exit via syscall_error() if the
   stream's error indicator is set afterwards. */
static void write_bytes (FILE *f, const void *buffer, uint32_t count)
{
	(void)fwrite(buffer, 1, count, f);
	if (!ferror(f)) return;
	syscall_error("fwrite", "Could not write %lu bytes", (unsigned long)count);
}

/* Read a little-endian (least significant byte first) uint32 from f.
   A short read is fatal because read_bytes() is asked for a complete read. */
static uint32_t read_uint32 (FILE *f)
{
	uint8_t bytes[4] = {0};
	uint32_t value = 0;
	int k = 0;

	read_bytes(f, bytes, sizeof bytes, 1);
	for (k = 3; k >= 0; k--)
		value = (value << 8) | bytes[k];
	return value;
}

/* Write n to f as 4 bytes, least significant byte first (little-endian). */
static void write_uint32 (FILE *f, uint32_t n)
{
	uint8_t bytes[4];
	unsigned k = 0;

	for (k = 0; k < sizeof bytes; k++)
		bytes[k] = (uint8_t)(n >> (8 * k));
	write_bytes(f, bytes, sizeof bytes);
}

/* **** Program logic **** */
/* Name the program was invoked as; main() sets it from argv[0] so usage()
   can print concrete command lines. */
static const char *exec_name = "";
/* Print msg and the command-line synopsis to stderr, then exit with
   USAGE_ERROR. Never returns. */
static void usage (const char *msg)
{
	static const char compress_help[] =
		"\t\tCompress INs to OUT in BLOCK_SIZE bytes chunks\n"
		"\t\tAll INs and OUT must be unique\n"
		"\t\tBLOCK_SIZE must be non-zero\n"
		"\t\tBLOCK_SIZE must be less than 4 GiB (higher bits are ignored)\n"
		"\t\tIf in doubt, a good value for BLOCK_SIZE is 4096\n";
	static const char decompress_help[] =
		"\t\tIN must be a file of the format produced by this program\n"
		"\t\tDecompress IN to OUTs\n"
		"\t\tAll IN and OUTS must be unique\n"
		"\t\tThere should be an OUT for each image in IN\n";
	static const char info_help[] =
		"\t\tDisplay information about IN without file lengths\n"
		"\t\tIN must be a file of the format produced by this program\n";

	fprintf(stderr, "Error: %s\nUsage:\n", msg);
	fprintf(stderr, "\t%s c BLOCK_SIZE OUT [IN]...\n", exec_name);
	fputs(compress_help, stderr);
	fprintf(stderr, "\t%s d IN [OUT]...\n", exec_name);
	fputs(decompress_help, stderr);
	fprintf(stderr, "\t%s i IN\n", exec_name);
	fputs(info_help, stderr);
	exit(USAGE_ERROR);
}

/* Exit via usage() unless all paths are distinct. extra (the path on the
   other side of the operation) must not appear in strings[0..count), and no
   two entries of strings may be equal. Entries are checked in index order,
   so the first clash found decides which message is shown. */
static void differ_strings (const char *extra,
	const char *const strings[], int count
){
	const char *const *cur = strings;
	const char *const *end = 0;

	if (count <= 0) return;
	end = strings + count;
	for (; cur != end; ++cur){
		const char *const *later = cur + 1;
		if (strcmp(*cur, extra) == 0) usage("Duplicated I/O path");
		while (later != end){
			if (strcmp(*cur, *later) == 0) usage("Duplicated path");
			++later;
		}
	}
}

/* Exchange the buffer pointers *a and *b. */
static void swap_ptr(uint8_t **a, uint8_t **b)
{
	uint8_t *tmp = *a;
	*a = *b;
	*b = tmp;
}
/* result[k] = a[k] ^ b[k] for every k in [0, count).
   result may be the same buffer as a (callers xor in place), so the pointers
   are deliberately not restrict-qualified. The index has the same type as
   count: an int index would overflow (undefined behaviour) for counts of
   2 GiB or more, which a uint32 block size permits. */
static void xor (uint32_t count, uint8_t *result, const uint8_t *a, const uint8_t *b)
{
	uint32_t k = 0;
	for (k = 0; k < count; k++) result[k] = a[k] ^ b[k];
}

/* Interleave the inputs into output in the xorimg format (see file header).
   block_size:  bytes per block; must be non-zero (main() checks this).
   output:      output path, or "-" for stdout.
   inputs:      input_count input paths ("-" means stdin).
   Each round reads one block from every still-open input. The first open
   input of the round is written verbatim; each later one is xor'd with the
   zero-padded block of the previous open input in the same round, which is
   what decompress() undoes. An input is closed after its first short block
   (possibly of length 0). Failures exit via syscall_error(). */
static void compress (uint32_t block_size, const char *output,
	const char *const inputs[], int input_count
){
	int i = 0, inputs_open = input_count;
	FILE *foutput = 0;
	FILE **finputs = 0;
	uint8_t *input_buffer = 0;
	uint8_t *previous_buffer = 0;
	
	/* Open the files */
	foutput = open(output, open_mode_write);
	finputs = alloc(input_count * sizeof(FILE *));
	for (i = 0; i < input_count; i++)
		finputs[i] = open(inputs[i], open_mode_read);
	
	/* Allocate the buffers */
	input_buffer = alloc(block_size);
	previous_buffer = alloc(block_size);
	
	/* Write the header */
	write_bytes(foutput, magic, sizeof(magic));
	write_uint32(foutput, block_size);
	write_uint32(foutput, input_count);
	
	/* Write the blocks, one round per block index */
	while (inputs_open){
		/* Non-null once this round has written a block, i.e. once
		   previous_buffer holds the previous open input's block. It must
		   outlive the inner loop and be reset every round, exactly as in
		   decompress(): declared per input it was always null, so every
		   block was written verbatim and decompress() produced corrupt
		   images. */
		uint8_t *output_buffer = 0;
		for (i = 0; i < input_count; i++){
			uint32_t input_size = 0;
			if (!finputs[i]) continue; /* This input has already EOF'd */
			
			/* Read the current input, and write the block header. The call to
			   read_bytes() zero-pads the rest of input_buffer for us. */
			input_size = read_bytes(finputs[i], input_buffer, block_size, 0);
			write_uint32(foutput, input_size);
			
			/* Perform xor if necessary, and set the output buffer */
			if (output_buffer){ /* Xor with previous, in place */
				output_buffer = previous_buffer;
				xor(input_size, output_buffer, previous_buffer, input_buffer);
			} else output_buffer = input_buffer; /* Verbatim */
			
			/* Write the block. After the swap, previous_buffer holds this
			   input's raw, zero-padded block for xoring the next input; the
			   other buffer is fully overwritten by the next read_bytes(). */
			write_bytes(foutput, output_buffer, input_size);
			swap_ptr(&input_buffer, &previous_buffer);
			
			/* Close an input if we reached its EOF */
			if (input_size != block_size){
				fclose(finputs[i]);
				finputs[i] = 0;
				inputs_open--;
			}
		}
	}
	
	/* Data may still be buffered in foutput; flush it now so that a failed
	   write (e.g. a full disk) is reported instead of lost silently at exit. */
	if (fflush(foutput) == EOF) syscall_error("fflush",
		"Could not write \"%s\"", output);
}

/* Decompress the xorimg file input into outputs, or summarise its header.
   input:        path of the xorimg file ("-" for stdin).
   outputs:      output_count paths, one per image in input ("-" means
                 stdout); if null, the header summary is printed to stdout.
   output_count: number of entries in outputs; must equal the header's
                 input_count (enforced via usage()).
   Each round reads one length-prefixed block per still-open output. The
   round's first block is taken verbatim; each later one is xor'd with the
   reconstructed, zero-padded block of the previous output in that round,
   undoing compress(). An output is closed after its first short block.
   Malformed input exits via input_error(); I/O failures via syscall_error(). */
static void decompress (const char *input,
	const char * const *outputs, /* If null, then display info instead */
	int output_count
){
	int i = 0, outputs_open = output_count;
	uint32_t block_size = 0, expected_output_count = 0;
	FILE *finput = 0;
	FILE **foutputs = 0;
	uint8_t *input_buffer = 0;
	uint8_t *second_buffer = 0;
	char test_magic[sizeof(magic)]; memset(test_magic, 0, sizeof(test_magic));
	
	/* Read and validate the header */
	finput = open(input, open_mode_read);
	read_bytes(finput, test_magic, sizeof(test_magic), 1);
	if (memcmp(magic, test_magic, sizeof(magic)))
		input_error("\"%s\" is an invalid xorimg input", input);
	block_size = read_uint32(finput);
	expected_output_count = read_uint32(finput);
	/* The format requires a non-zero block size. With 0 no block could ever
	   be short, so no output would be closed and decoding would only stop at
	   a misleading "Unexpected end of file". */
	if (!block_size)
		input_error("\"%s\" has a zero block size", input);
	if (outputs && expected_output_count != (uint32_t)output_count)
		usage("Wrong output count"); /* Validate if we're not in info mode */
	
	/* If we're in info mode, print the information, and return */
	if (!outputs){
		printf("Summary of \"%s\":\n\tBlock size: %lu\n\tFile count: %lu\n",
			input,
			(unsigned long)block_size, (unsigned long)expected_output_count);
		return;
	}
	
	/* Open the files */
	foutputs = alloc(output_count * sizeof(FILE *));
	for (i = 0; i < output_count; i++)
		foutputs[i] = open(outputs[i], open_mode_write);
	
	/* Allocate the buffers */
	input_buffer = alloc(block_size);
	second_buffer = alloc(block_size);
	
	/* Read the blocks */
	while (outputs_open){
		/* Reconstructed block of the previous output in this round, or null
		   before the round's first block has been handled. */
		uint8_t *output_buffer = 0;
		for (i = 0; i < output_count; i++){
			uint32_t input_size = 0;
			if (!foutputs[i]) continue; /* This output is already complete */
			
			/* Read the current block's length and data. */
			input_size = read_uint32(finput);
			if (input_size > block_size) input_error("Input block size %lu is "
				"greater than file block size %lu",
					(unsigned long)input_size, (unsigned long)block_size);
			read_bytes(finput, input_buffer, input_size, 1);
			
			/* Perform xor if necessary, and set the output buffer */
			if (output_buffer) /* Xor with previous, in place */
				xor(input_size, output_buffer, output_buffer, input_buffer);
			else { /* Verbatim. Swap: output and input buffers must differ. */
				output_buffer = input_buffer;
				swap_ptr(&input_buffer, &second_buffer);
			}
			
			/* Bytes past input_size are the implicit zero padding. They are
			   not written, but the next output in this round is xor'd against
			   them, so they must be zeroed explicitly (only input_size bytes
			   were read). */
			memset(output_buffer + input_size, 0, block_size - input_size);
			write_bytes(foutputs[i], output_buffer, input_size);
			
			/* Close an output once its final (short) block is written.
			   fclose() flushes buffered data, so its failure means the image
			   was not completely written. */
			if (input_size != block_size){
				if (fclose(foutputs[i]) == EOF) syscall_error("fclose",
					"Could not close \"%s\"", outputs[i]);
				foutputs[i] = 0;
				outputs_open--;
			}
		}
	}
}

/* Program entry point; argv[1] selects the mode:
     c BLOCK_SIZE OUT [IN]...   compress the INs into OUT
     d IN [OUT]...              decompress IN into the OUTs
     i IN                       print IN's header summary
   Returns OK on success; every failure exits through usage(),
   syscall_error() or input_error() with the matching status. */
int main (int argc, const char *const argv[])
{
	uint32_t block_size = 0;
	if (argc) exec_name = argv[0]; /* This is used by usage() */
	
	/* Validate the arguments */
	if (argc < 3) usage("Insufficient arguments");
	if (strlen(argv[1]) != 1) usage("Invalid command");
	
	/* Perform mode specific operations */
	switch (argv[1][0]){
		case 'c': /* Compression */
			/* Read/validate compress specific arguments */
			if (argc < 4) usage("Insufficient arguments for compression");
			/* Keep only the low 32 bits (as usage() documents) and test for
			   zero AFTER truncating: with a 64-bit long, "4294967296" used to
			   pass the zero test yet reach compress() as 0, which never sees
			   a short block and writes empty blocks forever when given any
			   input. strtoul() replaces atol(), whose behaviour is undefined
			   when the value does not fit in a long. */
			block_size = (uint32_t)strtoul(argv[2], NULL, 10);
			if (!block_size) usage("0 block size");
			differ_strings(argv[3], argv + 4, argc - 4);
			
			/* Actually do the compression */
			compress(block_size, argv[3], argv + 4, argc - 4);
			break;
		case 'd': /* Decompression */
			differ_strings(argv[2], argv + 3, argc - 3); /* Validate */
			decompress(argv[2], argv + 3, argc - 3);
			break;
		case 'i': /* Header information */
			if (argc != 3) usage("Incorrect arguments for info");
			decompress(argv[2], 0, 0);
			break;
		default: usage("Unknown command"); /* Invalid mode */
	}
	
	/* If exit() hasn't been called, we were successful */
	return OK;
}
 
