#include "common.h"
#include "hook.h"
#include "game.h"
#include "etpro.h"
#include "detours.h"

int hook_vmMain(int command, int a0, int a1, int a2, int a3, int a4, int a5, int a6, int a7, int a8, int a9, int aa, int ab) {

	if (game.ETPro_AC) {
		*(void **) ETPRO_AC2_VM_LOCATION = (void *) ((unsigned int) game.cgame + ETPRO_AC2_VM_ORIG);
	}

	int ret = game.cg_vmMain(command, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aa, ab);

	if (game.ETPro_AC) {
		*(void **) ETPRO_AC2_VM_LOCATION = game_vmMain;
	}

	return ret;
}

/*
 * Syscall dispatcher handed to cgame in place of the engine's. Currently a
 * transparent relay: every request goes to the engine pointer saved in
 * game.syscall, and its result is returned unchanged.
 */
int hook_syscall(int command, int a0, int a1, int a2, int a3, int a4, int a5, int a6, int a7, int a8, int a9, int aa) {
	const int result = game.syscall(command,
	                                a0, a1, a2, a3, a4, a5,
	                                a6, a7, a8, a9, aa);
	return result;
}

/*
 * Detour target for the engine's CL_CreateCmd (see DllMain). Builds the
 * command via game_CreateCmd and returns it unmodified.
 */
static usercmd_t hook_CreateCmd(void) {
	usercmd_t cmd = game_CreateCmd();
	return cmd;
}

/*
 * Detour target for the engine's IN_JoyMove (see DllMain). No interception
 * yet: control is handed straight to game_JoyMove.
 */
static void hook_JoyMove(void) {
	game_JoyMove();
	return;
}

/*
 * Detour for ETPro's first anti-cheat routine inside the etpro cgame.
 * Relays all arguments to the original (game.ETPro_AC), except that the
 * caller-supplied checksum is replaced by the fixed ETPRO_AC_CHECKSUM.
 */
static void hook_ETPro_AC(void *a, void *b, void *c, int checksum, void *e, char *orig_guid) {
	/* The incoming value is deliberately discarded in favour of the constant. */
	(void) checksum;
	game.ETPro_AC(a, b, c,
	              ETPRO_AC_CHECKSUM,
	              e, orig_guid);
}

/*
 * Detour for ETPro's second anti-cheat routine inside the etpro cgame.
 *
 * Around the call to the original (game.ETPro_AC2), the pointer at
 * ETPRO_AC2_VM_LOCATION is set to the cgame's original entry
 * (cgame base + ETPRO_AC2_VM_ORIG). Afterwards it is pointed back at
 * game_vmMain. The int at cgame base + ETPRO_AC2_SECURITY is then zeroed.
 * NOTE(review): that int is presumably a flag written by the AC routine —
 * confirm against the ETPro offsets.
 *
 * a: opaque argument, passed through unchanged.
 * Returns the original routine's result.
 */
static void* hook_ETPro_AC2(DWORD a) {
	/* Module-relative addresses use PBYTE arithmetic (as in hook_LoadLibraryA)
	 * instead of round-tripping the base through an unsigned int. */
	*(void **) ETPRO_AC2_VM_LOCATION = (void *) ((PBYTE) game.cgame + ETPRO_AC2_VM_ORIG);
	void *p = game.ETPro_AC2(a);
	*(void **) ETPRO_AC2_VM_LOCATION = game_vmMain;
	*(int *) ((PBYTE) game.cgame + ETPRO_AC2_SECURITY) = 0;
	return p;
}

/*
 * Replacement for cgame's exported dllEntry. Saves the engine's syscall
 * dispatcher in game.syscall (later used by hook_syscall) and then calls
 * the real dllEntry, passing game_syscall in place of the engine pointer.
 */
static void hook_CG_dllEntry(int(*syscallptr)(int arg, ...)) {
	int (*const substitute)(int, ...) = (int (*)(int, ...)) game_syscall;

	game.syscall = syscallptr;
	game.cg_dllEntry(substitute);
}

/*
 * Detour for kernel32!LoadLibraryA.
 *
 * Forwards to the original LoadLibraryA. When a "...\cgame_mp_x86.dll" is
 * loaded successfully, its handle is recorded in game.cgame. When that cgame
 * comes from an "etpro\" directory, ETPro's two anti-cheat routines (at
 * ETPRO_AC_LOCATION / ETPRO_AC2_LOCATION inside the module) are also
 * detoured to hook_ETPro_AC / hook_ETPro_AC2.
 *
 * lpLibName: library name/path as given by the caller.
 * Returns the handle from the original LoadLibraryA (NULL on failure).
 */
static HINSTANCE __stdcall hook_LoadLibraryA(LPCSTR lpLibName) {

	HMODULE ret = game.LoadLibraryA(lpLibName);

	/*
	 * A failed load must not overwrite a previously recorded cgame handle.
	 * Detouring at NULL + offset would also crash. strstr needs a real
	 * string, so a NULL name is returned as-is too.
	 */
	if (ret == NULL || lpLibName == NULL) {
		return ret;
	}

	if (strstr(lpLibName, "\\cgame_mp_x86.dll")) {
		game.cgame = ret;
		/*
		 * A new cgame image is now current. Drop any trampolines into a
		 * previously loaded etpro cgame so that the non-NULL game.ETPro_AC
		 * test in hook_vmMain reflects the module that is actually loaded.
		 * They are re-installed below if this is an etpro cgame.
		 */
		game.ETPro_AC  = NULL;
		game.ETPro_AC2 = NULL;
	}

	if (strstr(lpLibName, "etpro\\cgame_mp_x86.dll")) {
		game.ETPro_AC = (void (*)(void *, void *, void *, int, void *, char *)) DetourFunction((PBYTE) ret + ETPRO_AC_LOCATION, (PBYTE) hook_ETPro_AC);
		game.ETPro_AC2 = (void* (*)(DWORD)) DetourFunction((PBYTE) ret + ETPRO_AC2_LOCATION, (PBYTE) hook_ETPro_AC2);
	}

	return ret;

}

/*
 * Detour for kernel32!GetProcAddress.
 *
 * Forwards to the original GetProcAddress. When a named export is resolved
 * successfully from the recorded cgame module:
 *   - "vmMain":   the real address is saved in game.cg_vmMain and the caller
 *                 receives game_vmMain instead;
 *   - "dllEntry": the real address is saved in game.cg_dllEntry and the
 *                 caller receives hook_CG_dllEntry instead.
 * All other lookups are returned untouched.
 *
 * hModule:    module being queried.
 * lpProcName: export name, or an ordinal in the low word (high word zero).
 * Returns the (possibly substituted) export address, or NULL on failure.
 */
static FARPROC __stdcall hook_GetProcAddress(HMODULE hModule, LPCSTR lpProcName) {

	FARPROC ret = game.GetProcAddress(hModule, lpProcName);

	/*
	 * Leave failed lookups alone. Substituting a hook for a missing export
	 * would record a NULL "original" that the hook later calls through.
	 * Ordinal lookups pass a small integer instead of a string pointer, and
	 * that integer must never reach strcmp.
	 */
	if (ret == NULL || hModule != game.cgame || HIWORD((ULONG_PTR) lpProcName) == 0) {
		return ret;
	}

	if (!strcmp(lpProcName, "vmMain")) {
		game.cg_vmMain = (int(*)(int, ...)) ret;
		ret = (FARPROC) game_vmMain;
	} else if (!strcmp(lpProcName, "dllEntry")) {
		game.cg_dllEntry = (void(*)(int(*)(int, ...))) ret;
		ret = (FARPROC) hook_CG_dllEntry;
	}

	return ret;

}

/*
 * DLL entry point.
 *
 * On DLL_PROCESS_ATTACH:
 *   - detours LoadLibraryA, GetProcAddress and the engine routines at
 *     ADDR_CREATE_CMD / ADDR_JOY_MOVE, keeping the trampolines in game.*;
 *   - binds engine functions and data (Sys_QueEvent, Com_Printf, Cmd_Argv,
 *     Cvar_Get, cvar_vars, mouse X/Y) to their fixed addresses.
 * Every other notification is ignored. Always returns TRUE.
 */
BOOL __stdcall DllMain(HMODULE module, DWORD reason, LPVOID reserved) {

	if (reason != DLL_PROCESS_ATTACH) {
		return TRUE;
	}

	/* Detours: the original entry points stay reachable via the returned trampolines. */
	game.LoadLibraryA   = (void *) DetourFunction((void *) LoadLibraryA, (void *) hook_LoadLibraryA);
	game.GetProcAddress = (void *) DetourFunction((void *) GetProcAddress, (void *) hook_GetProcAddress);
	game.CL_CreateCmd   = (void *) DetourFunction((PBYTE) ADDR_CREATE_CMD, (void *) hook_CreateCmd);
	game.IN_JoyMove     = (void *) DetourFunction((PBYTE) ADDR_JOY_MOVE, (void *) hook_JoyMove);

	/* Engine functions and variables at hard-coded addresses. */
	game.Sys_QueEvent   = (void (*)(int, sysEventType_t, int, int, int, void *)) ADDR_QUE_EVENT;
	Com_Printf          = (void (*)(const char *, ...))                          ADDR_COM_PRINTF;
	Cmd_Argv            = (char* (*)(int))                                       ADDR_COM_ARGV;
	Cvar_Get            = (cvar_t *(*)(const char *, const char *, int))         ADDR_CVAR_GET;
	cvar_vars           = (void *)                                               ADDR_CVAR_VARS;
	game.mouseX         = (float *)                                              ADDR_MOUSE_X;
	game.mouseY         = (float *)                                              ADDR_MOUSE_Y;

	return TRUE;

}