diff --git a/bindings/python/unicorn/unicorn.py b/bindings/python/unicorn/unicorn.py index 58c6ee06..19df82e0 100644 --- a/bindings/python/unicorn/unicorn.py +++ b/bindings/python/unicorn/unicorn.py @@ -3,6 +3,7 @@ import ctypes import ctypes.util import distutils.sysconfig +from functools import wraps import pkg_resources import inspect import os.path @@ -307,6 +308,27 @@ def reg_write(reg_write_func, arch, reg_id, value): return +def _catch_hook_exception(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + """Catches exceptions raised in hook functions. + + If an exception is raised, it is saved to the Uc object and a call to stop + emulation is issued. + """ + try: + return func(self, *args, **kwargs) + except Exception as e: + # If multiple hooks raise exceptions, just use the first one + if self._hook_exception is None: + self._hook_exception = e + + self.emu_stop() + + return wrapper + + + class uc_x86_mmr(ctypes.Structure): """Memory-Management Register for instructions IDTR, GDTR, LDTR, TR.""" _fields_ = [ @@ -410,6 +432,7 @@ class Uc(object): self._ctype_cbs = [] self._callback_count = 0 self._cleanup.register(self) + self._hook_exception = None # The exception raised in a hook @staticmethod def release_handle(uch): @@ -427,6 +450,9 @@ class Uc(object): if status != uc.UC_ERR_OK: raise UcError(status) + if self._hook_exception is not None: + raise self._hook_exception + # stop emulation def emu_stop(self): status = _uc.uc_emu_stop(self._uch) @@ -522,41 +548,49 @@ class Uc(object): raise UcError(status) return result.value + @_catch_hook_exception def _hookcode_cb(self, handle, address, size, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] cb(self, address, size, data) + @_catch_hook_exception def _hook_mem_invalid_cb(self, handle, access, address, size, value, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] return cb(self, access, address, size, value, data) + @_catch_hook_exception def _hook_mem_access_cb(self, handle, access, address, size, value, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] cb(self, access, address, size, value, data) + @_catch_hook_exception def _hook_intr_cb(self, handle, intno, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] cb(self, intno, data) + @_catch_hook_exception def _hook_insn_invalid_cb(self, handle, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] return cb(self, data) + @_catch_hook_exception def _hook_insn_in_cb(self, handle, port, size, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] return cb(self, port, size, data) + @_catch_hook_exception def _hook_insn_out_cb(self, handle, port, size, value, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] cb(self, port, size, value, data) + @_catch_hook_exception def _hook_insn_syscall_cb(self, handle, user_data): # call user's callback with self object (cb, data) = self._callbacks[user_data] diff --git a/tests/regress/hook_raises_exception.py b/tests/regress/hook_raises_exception.py new file mode 100644 index 00000000..0c7b5bc1 --- /dev/null +++ b/tests/regress/hook_raises_exception.py @@ -0,0 +1,39 @@ +import regress +from unicorn import Uc, UC_ARCH_X86, UC_MODE_64, UC_HOOK_CODE + +CODE = b"\x90" * 3 +CODE_ADDR = 0x1000 + + +class HookCounter(object): + """Counts number of hook calls.""" + + def __init__(self): + self.hook_calls = 0 + + def bad_code_hook(self, uc, address, size, data): + self.hook_calls += 1 + raise ValueError("Something went wrong") + + def good_code_hook(self, uc, address, size, data): + self.hook_calls += 1 + + +class TestExceptionInHook(regress.RegressTest): + + def test_exception_in_hook(self): + uc = Uc(UC_ARCH_X86, UC_MODE_64) + uc.mem_map(CODE_ADDR, 0x1000) + uc.mem_write(CODE_ADDR, CODE) + + counter = HookCounter() + uc.hook_add(UC_HOOK_CODE, counter.good_code_hook, begin=CODE_ADDR, end=CODE_ADDR + len(CODE)) + uc.hook_add(UC_HOOK_CODE, counter.bad_code_hook, begin=CODE_ADDR, end=CODE_ADDR + len(CODE)) + + self.assertRaises(ValueError, uc.emu_start, CODE_ADDR, CODE_ADDR + len(CODE)) + # Make sure hooks calls finish before raising (hook_calls == 2) + self.assertEqual(counter.hook_calls, 2) + + +if __name__ == "__main__": + regress.main()