#define _GNU_SOURCE

#include <dlfcn.h>
#include <errno.h>
#include <link.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

/* Prefix used in this library's diagnostic messages. */
#define LIB_NAME                       "pbss_rce_fix"
/* File name (compared with basename(dlpi_name)) of the shared object whose
   fopen() call sites are guarded. */
#define PBSV_LIB                       "pbsv.so"
/*
 * Addresses inside PBSV_LIB that dl_iterate_phdr_callback() compares against
 * the fopen() caller's return address, after translating that address into
 * the object's own address space.
 * NOTE(review): presumably the return addresses of the pbsv.so calls that
 * write screenshot .png/.htm files; they are only valid for one specific
 * pbsv.so build - confirm against the targeted binary.
 */
#define PBSS_PNG_RETURN_ADDR  (void *) 0x11e696
#define PBSS_HTM1_RETURN_ADDR (void *) 0x0db819
#define PBSS_HTM2_RETURN_ADDR (void *) 0x11e7a8
#define PBSS_HTM3_RETURN_ADDR (void *) 0x0ba03a
/* Longest file name (in characters) check_ss_name() will accept. */
#define MAX_OSPATH                     256
/* Flags passed to Cvar_Get() when looking up fs_homepath.
   NOTE(review): presumably the engine's CVAR_INIT flag value - confirm. */
#define CVAR_INIT                      16
/* Value init() expects to read from the first 4 bytes at FUNC_COM_PRINTF;
   any other value makes init() abort with "incompatible etded binary".
   NOTE(review): presumably the start of Com_Printf's machine code in the
   supported etded build. */
#define VER_MEM_CHECK                  0x8d1cec83
/* Absolute addresses in the host process that init() binds to the
   Com_Printf and Cvar_Get function pointers. */
#define FUNC_COM_PRINTF                0x806C450
#define FUNC_CVAR_GET                  0x8071E40
/* When nonzero, check_ss_name() rejects every name, so every fopen() from a
   guarded pbsv.so call site fails with EACCES; the cvar_t/Cvar_Get
   declarations and the Com_Printf log line are compiled out. */
#define DISABLE_ALTOGETHER             1

#if !DISABLE_ALTOGETHER

/*
 * Console variable record, as returned by the engine's Cvar_Get().
 *
 * NOTE(review): the field order and types presumably mirror the cvar_t
 * compiled into the targeted etded binary. Code in this file reads `string`
 * through a pointer returned by the engine, so the layout must match the
 * engine's for that offset to be correct. Do not reorder, add or resize
 * fields.
 */
typedef struct cvar_s {
	char          *name;
	char          *string;
	char          *resetString;
	char          *latchedString;
	int           flags;
	int           modified;
	int           modificationCount;
	float         value;
	int           integer;
	struct cvar_s *next;
	struct cvar_s *hashNext;
} cvar_t;

/*
 * Engine Cvar_Get(), bound by init() to the fixed address FUNC_CVAR_GET.
 * Kept static so this preload library does not export a global data symbol
 * that could interpose on a same-named symbol in another loaded object.
 */
static cvar_t* (*Cvar_Get)(const char *var_name, const char *var_value, int flags);

#endif

/*
 * Real libc fopen(), resolved at load time; fopen() calls that are not
 * blocked are forwarded to it. Static so this preload library does not
 * export it as a global data symbol.
 */
static FILE*   (*orig_fopen)(const char *filename, const char* mode);
/*
 * Engine Com_Printf, bound by init() to the fixed address FUNC_COM_PRINTF.
 * Static for the same reason: an exported data symbol named Com_Printf could
 * be bound by another loaded object that expects a function of that name.
 */
static void    (*Com_Printf)(const char *msg, ...);

/*
 * dl_iterate_phdr() callback: decide whether an address lies at one of the
 * guarded call sites inside PBSV_LIB.
 *
 * `data` is the address to classify (fopen() passes its caller's return
 * address). For each loaded object:
 * - Returns 1, which also stops the iteration, when the object's file name
 *   is PBSV_LIB and the address falls inside one of its PT_LOAD segments.
 *   The address must also, once translated to the object's link-time
 *   virtual address, equal one of the PBSS_*_RETURN_ADDR values.
 * - Returns 0 otherwise, so iteration continues.
 */
static int dl_iterate_phdr_callback(struct dl_phdr_info *info, size_t size, void *data) {

	(void) size;

	/* The name test does not depend on the segment; do it once per object. */
	if (info->dlpi_name == NULL || strcmp(PBSV_LIB, basename(info->dlpi_name)) != 0) {
		return 0;
	}

	/* uintptr_t rather than unsigned int: no pointer truncation on LP64. */
	const uintptr_t addr = (uintptr_t) data;

	const uintptr_t guarded_sites[] = {
		(uintptr_t) PBSS_PNG_RETURN_ADDR,
		(uintptr_t) PBSS_HTM1_RETURN_ADDR,
		(uintptr_t) PBSS_HTM2_RETURN_ADDR,
		(uintptr_t) PBSS_HTM3_RETURN_ADDR,
	};

	for (ElfW(Half) i = 0; i < info->dlpi_phnum; i++) {

		const ElfW(Phdr) *ph = &info->dlpi_phdr[i];

		/* Only PT_LOAD entries describe mapped segments; other header
		   types (PT_DYNAMIC, PT_NOTE, ...) may lie inside them. */
		if (ph->p_type != PT_LOAD) {
			continue;
		}

		const uintptr_t seg_start = (uintptr_t) (info->dlpi_addr + ph->p_vaddr);

		if (addr < seg_start || addr - seg_start >= ph->p_memsz) {
			continue;
		}

		/*
		 * dlpi_addr is the load bias added to every p_vaddr, so subtracting
		 * only the bias yields the link-time address. Adding p_vaddr on top
		 * would be correct only for a segment linked at virtual address 0.
		 */
		const uintptr_t link_addr = addr - (uintptr_t) info->dlpi_addr;

		for (size_t j = 0; j < sizeof guarded_sites / sizeof guarded_sites[0]; j++) {
			if (link_addr == guarded_sites[j]) {
				return 1;
			}
		}

	}

	return 0;

}

/*
 * Decide whether `filename` is a screenshot path that the guarded pbsv.so
 * call sites are allowed to open.
 *
 * Accepted form:  <fs_homepath>/pb/svss/pb<digits><ext>
 *   - <fs_homepath> is the current string value of the fs_homepath cvar
 *     (looked up via Cvar_Get),
 *   - <digits> is zero or more characters '0'..'9',
 *   - <ext> is ".png" or ".htm", compared case-insensitively,
 *   - the whole name is at most MAX_OSPATH characters long.
 *
 * Returns 1 if `filename` matches, 0 otherwise (including a NULL name).
 * When DISABLE_ALTOGETHER is nonzero every name is rejected.
 */
static int check_ss_name(const char *filename) {

#if DISABLE_ALTOGETHER
	(void) filename;
	return 0;
#else

	static const char ss_dir[]   = "/pb/svss/pb";
	const size_t      ss_dir_len = sizeof ss_dir - 1;
	const size_t      ext_len    = 4;   /* strlen(".png") == strlen(".htm") */

	if (filename == NULL) {
		return 0;
	}

	const size_t len = strlen(filename);

	if (len < ext_len || len > MAX_OSPATH) {
		return 0;
	}

	const char *ext = filename + len - ext_len;

	if (strcasecmp(ext, ".png") != 0 && strcasecmp(ext, ".htm") != 0) {
		return 0;
	}

	cvar_t *fs_homepath = Cvar_Get("fs_homepath", "", CVAR_INIT);

	if (fs_homepath == NULL || fs_homepath->string == NULL) {
		return 0;
	}

	const char  *home     = fs_homepath->string;
	const size_t home_len = strlen(home);

	/*
	 * Require room for the home path, the "/pb/svss/pb" part and the
	 * extension. Without this, a name that is merely a prefix of (or equal
	 * to) fs_homepath and ends in ".png"/".htm" would pass without ever
	 * reaching the directory check.
	 */
	if (len < home_len + ss_dir_len + ext_len) {
		return 0;
	}

	if (   strncmp(filename, home, home_len) != 0
	    || strncmp(filename + home_len, ss_dir, ss_dir_len) != 0) {
		return 0;
	}

	/* Everything between the directory part and the extension must be digits. */
	for (const char *p = filename + home_len + ss_dir_len; p < ext; p++) {
		if (*p < '0' || *p > '9') {
			return 0;
		}
	}

	return 1;

#endif

}

FILE* fopen(const char* filename, const char* mode) {

	if (dl_iterate_phdr(dl_iterate_phdr_callback, __builtin_return_address(0)) && !check_ss_name(filename)) {

#if !DISABLE_ALTOGETHER
		Com_Printf("%s: intercepted possibly malicious fopen(%s)\n", LIB_NAME, filename);
#endif

		errno = EACCES;
		return NULL;

	}

	return orig_fopen(filename, mode);

}

void __attribute__ ((constructor)) init(void) {

	if (*(int *) FUNC_COM_PRINTF != VER_MEM_CHECK) {
		fprintf(stderr, "%s: memory check failed - incompatible etded binary\n", LIB_NAME);
		exit(1);
	}

	void *libc_handle = dlopen("libc.so.6", RTLD_LAZY);
	orig_fopen        = dlsym(libc_handle,"fopen");

	Com_Printf  = (void *) FUNC_COM_PRINTF;

#if !DISABLE_ALTOGETHER
	Cvar_Get    = (void *) FUNC_CVAR_GET;
#endif

}