Add a concrete executor for BIL

This commit is contained in:
Tim Becker 2015-02-20 11:08:29 -05:00
parent 02bca65965
commit 176b4697ff
2 changed files with 727 additions and 0 deletions

View File

@ -0,0 +1,397 @@
import abc
class BitVector(object):
__metaclass__ = abc.ABCMeta
""" Properties """
@abc.abstractproperty
def get_bits(self, low, high):
return
def get_high_bits(self, num):
size = self.get_size()
assert num <= size, "Cannot get {} bits from {}-bit bitvector".format(num, size)
return self.get_bits(size-num, size-1)
def get_low_bits(self, num):
size = self.get_size()
assert num <= size, "Cannot get {} bits from {}-bit bitvector".format(num, size)
return self.get_bits(0, num-1)
@abc.abstractproperty
def get_size(self):
return
@abc.abstractmethod
def signed(self):
return
""" Operations """
@abc.abstractmethod
def concat(self, other):
return
""" Arithmetic operations """
@abc.abstractmethod
def add(self, other):
return
@abc.abstractmethod
def sub(self, other):
return
@abc.abstractmethod
def mul(self, other):
return
@abc.abstractmethod
def div(self, other):
return
@abc.abstractmethod
def mod(self, other):
return
@abc.abstractmethod
def neg(self):
return
""" Bitwise operations """
@abc.abstractmethod
def band(self, other):
return
@abc.abstractmethod
def xor(self, other):
return
@abc.abstractmethod
def bor(self, other):
return
@abc.abstractmethod
def bnot(self):
return
@abc.abstractmethod
def lshift(self, shift):
return
@abc.abstractmethod
def rshift(self, shift):
return
@abc.abstractmethod
def arshift(self, shift):
return
""" Comparison operations """
@abc.abstractmethod
def eq(self, other):
return
@abc.abstractmethod
def neq(self, other):
return
@abc.abstractmethod
def lt(self, other):
return
@abc.abstractmethod
def le(self, other):
return
@abc.abstractmethod
def gt(self, other):
return
@abc.abstractmethod
def ge(self, other):
return
@abc.abstractmethod
def slt(self, other):
return
@abc.abstractmethod
def sle(self, other):
return
""" Overloading """
def reverse(op):
return lambda a,b : op(a.__class__(a.get_size(), b), a)
def __add__(self, other):
return self.add(other)
__radd__ = __add__
def __sub__(self, other):
return self.sub(other)
__rsub__ = reverse(__sub__)
def __mul__(self, other):
return self.mul(other)
__rmul__ = __mul__
def __div__(self, other):
return self.div(other)
__rdiv__ = reverse(__div__)
def __mod__(self, other):
return self.mod(other)
__rmod__ = reverse(__mod__)
def __neg__(self):
return self.neg()
def __and__(self, other):
return self.band(other)
__rand__ = __and__
def __xor__(self, other):
return self.xor(other)
__rxor__ = __xor__
def __or__(self, other):
return self.bor(other)
__ror__ = __or__
def __invert__(self):
return self.bnot()
def __lshift__(self, other):
return self.lshift(other)
__rlshift__ = reverse(__lshift__)
def __rshift__(self, other):
return self.rshift(other)
__rrshift__ = reverse(__rshift__)
def __eq__(self, other):
return self.eq(other)
def __ne__(self, other):
return self.neq(other)
def __lt__(self, other):
return self.lt(other)
def __le__(self, other):
return self.le(other)
def __gt__(self, other):
return self.gt(other)
def __ge__(self, other):
return self.ge(other)
class ConcreteBitVector(BitVector):
def __init__(self, size, value=0):
super(ConcreteBitVector, self).__init__()
self.size = size
self.value = value
bitmask = (1 << self.size) - 1
self.value &= bitmask
def get_bits(self, low, high):
length = high - low + 1
bitmask = (1 << length) - 1
value = self >> low
return ConcreteBitVector(length, int(value & bitmask))
def get_size(self):
return self.size
def signed(self):
mask = (1 << (self.get_size() - 1))
if self.value & mask:
return -(2**self.get_size() - self.value)
else:
return self.value
""" Operations """
def concat(self, other):
return ConcreteBitVector(self.size+other.size, (self.value << self.size) | other.value)
""" Arithmetic operations """
def add(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value + other.value
else:
size = self.size
value = self.value + other
return ConcreteBitVector(size, value)
def sub(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value - other.value
else:
size = self.size
value = self.value - other
return ConcreteBitVector(size, value)
def mul(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value * other.value
else:
size = self.size
value = self.value * other
return ConcreteBitVector(size, value)
def div(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value / other.value
else:
size = self.size
value = self.value / other
return ConcreteBitVector(size, value)
def mod(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value % other.value
else:
size = self.size
value = self.value % other
return ConcreteBitVector(size, value)
def neg(self, other):
return self.bnot().add(1)
""" Bitwise operations """
def band(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value & other.value
else:
size = self.size
value = self.value & other
return ConcreteBitVector(size, value)
def xor(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value ^ other.value
else:
size = self.size
value = self.value ^ other
return ConcreteBitVector(size, value)
def bor(self, other):
if isinstance(other, ConcreteBitVector):
size = max(self.size, other.size)
value = self.value | other.value
else:
size = self.size
value = self.value | other
return ConcreteBitVector(size, value)
def bnot(self):
bitmask = (1 << self.size) - 1
value = self.value ^ bitmask
size = self.size
return ConcreteBitVector(size, value)
def lshift(self, other):
size = self.size
if isinstance(other, ConcreteBitVector):
value = self.value << other.value
else:
value = self.value << other
return ConcreteBitVector(size, value)
def rshift(self, other):
size = self.size
if isinstance(other, ConcreteBitVector):
value = self.value >> other.value
else:
value = self.value >> other
return ConcreteBitVector(size, value)
def arshift(self, other):
if isinstance(other, ConcreteBitVector):
otherval = other.value
else:
otherval = other
highbit = (self.value & (1 << (self.size - 1))) >> (self.size - 1)
mask = ((1 << otherval) - 1) * highbit #111.. if hb == 1, 000.. otherwise
mask <<= (self.size - otherval) # shift to high bits of bitvector
newval = self.value >> otherval
newval |= mask # set the high bits correctly
return ConcreteBitVector(self.size, newval)
""" Comparison operations """
def eq(self, other):
if isinstance(other, ConcreteBitVector):
return self.value == other.value
else:
return self.value == other
def neq(self, other):
if isinstance(other, ConcreteBitVector):
return self.value != other.value
else:
return self.value != other
def lt(self, other):
if isinstance(other, ConcreteBitVector):
return self.value < other.value
else:
return self.value < other
def le(self, other):
if isinstance(other, ConcreteBitVector):
return self.value <= other.value
else:
return self.value <= other
def gt(self, other):
if isinstance(other, ConcreteBitVector):
return self.value > other.value
else:
return self.value > other
def ge(self, other):
if isinstance(other, ConcreteBitVector):
return self.value >= other.value
else:
return self.value >= other
def slt(self, other):
if isinstance(other, ConcreteBitVector):
return self.signed() < other.signed()
else:
return self.signed() < other
def sle(self, other):
if isinstance(other, ConcreteBitVector):
return self.signed() <= other.signed()
else:
return self.signed() <= other
def __str__(self):
return str(self.value)
def __repr__(self):
return self.__str__()
def __int__(self):
return self.value

View File

@ -0,0 +1,330 @@
#!/usr/bin/env python2.7
from bap import bil
from bap import adt
from functools import partial
from model import BapInsn
from bitvector import ConcreteBitVector
import collections
class Memory(dict):
def __init__(self, fetch_mem, initial=None):
self.fetch_mem = fetch_mem
if initial is not None:
self.update(initial)
def get_mem(self, addr, size, little_endian=True):
addr = int(addr)
result = "".join([self[address] for address in range(addr, addr+size)])
if little_endian: result = result[::-1]
return result
def set_mem(self, addr, size, val, little_endian=True):
for i in range(size):
addr = int(addr)
shift = i if little_endian else size-i-1
byteval = (val >> shift*8) & 0xff
self[addr+i] = chr(byteval)
def __getitem__(self, addr):
if addr not in self:
self[addr] = self.fetch_mem(addr, 1)
return self.get(addr)
class State:
def __init__(self, variables, get_mem, initial_mem=None):
self.variables = variables
self.memory = Memory(get_mem, initial_mem)
def get_mem(self, addr, size, little_endian=True):
return self.memory.get_mem(addr, size, little_endian)
def set_mem(self, addr, size, val, little_endian=True):
return self.memory.set_mem(addr, size, val, little_endian)
def __getitem__(self, name):
if isinstance(name, str):
return self.variables[name]
elif isinstance(name, int):
return self.memory(name)
elif isinstance(name, ConcreteBitVector):
return self.memory(int(name))
def __setitem__(self, name, val):
if isinstance(name, str):
self.variables[name] = val
else:
self.memory[int(name)] = val
def __str__(self):
return str(self.variables)
class VariableException(Exception):
pass
class MemoryException(Exception):
pass
class ConcreteExecutor(adt.Visitor):
def __init__(self, state, pc):
self.state = state
self.pc = pc
self.jumped = False
def visit_Load(self, op):
addr = self.run(op.idx)
mem = self.state.get_mem(addr, op.size / 8, isinstance(op.endian, bil.LittleEndian))
if len(mem) == 0:
raise MemoryException(addr)
return ConcreteBitVector(op.size, int(mem.encode('hex'), 16))
def visit_Store(self, op):
addr = self.run(op.idx)
val = self.run(op.value)
self.state.set_mem(addr, op.size / 8, val, isinstance(op.endian, bil.LittleEndian))
return op.mem
def visit_Var(self, op):
try:
return self.state[op.name]
except KeyError as e:
raise VariableException(op.name)
def visit_Int(self, op):
return ConcreteBitVector(op.size, op.value)
def visit_Let(self, op):
variables = self.state.variables
tmp = variables.get(op.var.name, None)
variables[op.var.name] = self.run(op.value)
result = self.run(op.expr)
if tmp is None:
variables.pop(op.var.name, None)
else:
variables[op.var.name] = tmp
return result
def visit_Unknown(self, op):
return ConcreteBitVector(1,0)
def visit_Ite(self, op):
return self.run(op.true) if self.run(op.cond) else self.run(op.false)
def visit_Extract(self, op):
return self.run(op.expr).get_bits(op.low_bit, op.high_bit)
def visit_Concat(self, op):
return self.run(op.lhs).concat(self.run(op.rhs))
def visit_Move(self, op):
if isinstance(op.var.type, bil.Imm):
self.state[op.var.name] = ConcreteBitVector(op.var.type.size, int(self.run(op.expr)))
else:
self.run(op.expr) # no need to store Mems
def visit_Jmp(self, op):
self.jumped = True
self.state[self.pc] = self.run(op.arg)
def visit_While(self, op):
while self.run(op.cond) == 1:
adt.visit(self, op.stmts)
def visit_If(self, op):
if self.run(op.cond) == 1:
adt.visit(self, op.true)
else:
adt.visit(self, op.false)
def visit_PLUS(self, op):
return self.run(op.lhs) + self.run(op.rhs)
def visit_MINUS(self, op):
return self.run(op.lhs) - self.run(op.rhs)
def visit_TIMES(self, op):
return self.run(op.lhs) * self.run(op.rhs)
def visit_DIVIDE(self, op):
return self.run(op.lhs) / self.run(op.rhs)
def visit_SDIVIDE(self, op):
return self.run(op.lhs) / self.run(op.rhs)
def visit_MOD(self, op):
return self.run(op.lhs) % self.run(op.rhs)
def visit_SMOD(self, op):
return self.run(op.lhs) % self.run(op.rhs)
def visit_LSHIFT(self, op):
return self.run(op.lhs) << self.run(op.rhs)
def visit_RSHIFT(self, op):
return self.run(op.lhs) >> self.run(op.rhs)
def visit_ARSHIFT(self, op):
return self.run(op.lhs).arshift(self.run(op.rhs))
def visit_AND(self, op):
return self.run(op.lhs) & self.run(op.rhs)
def visit_OR(self, op):
return self.run(op.lhs) | self.run(op.rhs)
def visit_XOR(self, op):
return self.run(op.lhs) ^ self.run(op.rhs)
def visit_EQ(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs) == self.run(op.rhs) else 0)
def visit_NEQ(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs) != self.run(op.rhs) else 0)
def visit_LT(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs) < self.run(op.rhs) else 0)
def visit_LE(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs) <= self.run(op.rhs) else 0)
def visit_SLT(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs).slt(self.run(op.rhs)) else 0)
def visit_SLE(self, op):
return ConcreteBitVector(1, 1 if self.run(op.lhs).sle(self.run(op.rhs)) else 0)
def visit_NEG(self, op):
return -self.run(op.arg)
def visit_NOT(self, op):
return ~self.run(op.arg)
def visit_UNSIGNED(self, op):
return ConcreteBitVector(op.size, int(self.run(op.expr)))
def visit_SIGNED(self, op):
return ConcreteBitVector(op.size, int(self.run(op.expr)))
def visit_HIGH(self, op):
return self.run(op.expr).get_high_bits(op.size)
def visit_LOW(self, op):
return self.run(op.expr).get_low_bits(op.size)
class Issue:
def __init__(self, clnum, insn, message):
self.clnum = clnum
self.insn = insn
self.message = message
class Warning(Issue):
pass
class Error(Issue):
pass
def validate_bil(program, flow):
r"""
Runs the concrete executor, validating the the results are consistent with the trace.
Returns a tuple of (Errors, Warnings)
Currently only supports ARM, x86, and x86-64
"""
trace = program.traces[0]
libraries = [(m[3],m[1]) for m in trace.mapped]
registers = program.tregs[0]
regsize = 8 * program.tregs[1]
arch = program.tregs[-1]
if arch == "arm":
cpu_flags = ["ZF", "CF", "NF", "VF"]
PC = "PC"
elif arch == "i386":
cpu_flags = ["CF", "PF", "AF", "ZF", "SF", "OF", "DF"]
PC = "EIP"
elif arch == "x86-64":
cpu_flags = ["CF", "PF", "AF", "ZF", "SF", "OF", "DF"]
PC = "RIP"
else:
print "Architecture not supported"
return [],[]
errors = []
warnings = []
def new_state_for_clnum(clnum, include_flags=True):
flags = cpu_flags if include_flags else []
flagvalues = [0 for f in flags]
varnames = registers + flags
initial_regs = trace.db.fetch_registers(clnum)
varvals = initial_regs + flagvalues
varvals = map(lambda x: ConcreteBitVector(regsize, x), varvals)
initial_vars = dict(zip(varnames, varvals))
initial_mem_get = partial(trace.fetch_raw_memory, clnum)
return State(initial_vars, initial_mem_get)
state = new_state_for_clnum(0)
for (addr,data,clnum,ins) in flow:
instr = program.static[addr]['instruction']
if not isinstance(instr, BapInsn):
errors.append(Error(clnum, instr, "Could not make BAP instruction for %s" % str(instr)))
state = new_state_for_clnum(clnum)
else:
bil_instrs = instr.insn.bil
if bil_instrs is None:
errors.append(Error(clnum, instr, "No BIL for instruction %s" % str(instr)))
state = new_state_for_clnum(clnum)
else:
# this is bad.. fix this
if arch == "arm":
state[PC] += 8 #Qira PC is wrong
executor = ConcreteExecutor(state, PC)
try:
adt.visit(executor, bil_instrs)
except VariableException as e:
errors.append(Error(clnum, instr, "No BIL variable %s!" % str(e.args[0])))
except MemoryException as e:
errors.append(Error(clnum, instr, "Used invalid address %x." % e.args[0]))
if not executor.jumped:
if arch == "arm":
state[PC] -= 4
elif arch == "i386" or arch == "x86-64":
state[PC] += instr.size()
validate = True
PC_val = state[PC]
if PC_val > 0xf0000000 or any([PC_val >= base and PC_val <= base+size for (base,size) in libraries]):
# we are jumping into a library that we can't trace.. reset the state and continue
warnings.append(Warning(clnum, instr, "Jumping into library. Cannot trace this"))
state = new_state_for_clnum(clnum)
continue
error = False
correct_regs = new_state_for_clnum(clnum, include_flags=False).variables
for reg, correct in correct_regs.iteritems():
if state[reg] != correct:
error = True
errors.append(Error(clnum, instr, "%s was incorrect! (%x != %x)." % (reg, state[reg] , correct)))
state[reg] = correct
for (addr, val) in state.memory.items():
realval = trace.fetch_raw_memory(clnum, addr, 1)
if len(realval) == 0 or len(val) == 0:
errors.append(Error(clnum, instr, "Used invalid address %x." % addr))
# this is unfixable, reset state
state = new_state_for_clnum(clnum)
elif val != realval:
error = True
errors.append(Error(clnum, instr, "Value at address %x is wrong! (%x != %x)." % (addr, ord(val), ord(realval))))
state[addr] = realval
return (errors, warnings)