Commit 664beb99 authored by dbarbera's avatar dbarbera
Browse files

Refactor global state into global state object

parent f82830f4
......@@ -19,7 +19,7 @@
import logging
from singledispatch import singledispatch
from llvm import core, passes, ee
from llvm import core
import ogAST
import Helper
......@@ -29,29 +29,41 @@ LOG = logging.getLogger(__name__)
__all__ = ['generate']
# LLVM Global variable - Initialized when the generator is invoked
LLVM = {
# The LLVM module, which holds all the IR code.
'module': None,
# Dictionary that keeps track of which values are defined in the current
# scope and what their LLVM representation is.
'named_values': {},
# Dictionary that keeps track of the defined states and its integer
# constant representation
'states': {},
# The builder used for the current function generation.
'builder': None,
# The function optimization passes manager.
'pass_manager': None,
# The LLVM execution engine.
'executor': None,
# ASN.1 data view
'dataview': None,
# Generated types
'types': {},
# Global strings
'strings': {}
}
# Global state
g = None
class GlobalState():
def __init__(self, process):
self.module = core.Module.new(str(process.processName))
self.dataview = process.dataview
self.scope = {}
self.states = {}
self.types = {}
self.strings = {}
# Initialize built-in types
self.i1 = core.Type.int(1)
self.i8 = core.Type.int(8)
self.i32 = core.Type.int(32)
self.i64 = core.Type.int(64)
self.void = core.Type.void()
self.double = core.Type.double()
self.i1_ptr = core.Type.pointer(self.i1)
self.i8_ptr = core.Type.pointer(self.i8)
self.i32_ptr = core.Type.pointer(self.i32)
self.i64_ptr = core.Type.pointer(self.i64)
self.double_ptr = core.Type.pointer(self.double)
# Intialize built-in functions
ty = core.Type.function(self.void, [core.Type.pointer(self.i8)], True)
self.printf = self.module.add_function(ty, 'printf')
self.memcpy = core.Function.intrinsic(
self.module, core.INTR_MEMCPY,
[self.i8_ptr, self.i8_ptr, self.i64]
)
@singledispatch
......@@ -61,39 +73,13 @@ def generate(ast):
# Processing of the AST
@generate.register(ogAST.Process)
def _process(process):
''' Generate LLVM IR code '''
LOG.info('Generating LLVM IR code for process ' + str(process.processName))
process_name = str(process.processName)
LOG.info('Generating LLVM IR code for process ' + str(process_name))
# Initialize LLVM global structure
LLVM['module'] = core.Module.new(process_name)
LLVM['pass_manager'] = passes.FunctionPassManager.new(LLVM['module'])
LLVM['executor'] = ee.ExecutionEngine.new(LLVM['module'])
LLVM['dataview'] = process.dataview
# Set up the optimizer pipeline.
# Start with registering info about how the
# target lays out data structures.
# LLVM['pass_manager'].add(LLVM['executor'].target_data)
# # Do simple "peephole" optimizations and bit-twiddling optzns.
# LLVM['pass_manager'].add(passes.PASS_INSTRUCTION_COMBINING)
# # Reassociate expressions.
# LLVM['pass_manager'].add(passes.PASS_REASSOCIATE)
# # Eliminate Common SubExpressions.
# LLVM['pass_manager'].add(passes.PASS_GVN)
# # Simplify the control flow graph (deleting unreachable blocks, etc).
# LLVM['pass_manager'].add(passes.PASS_CFG_SIMPLIFICATION)
# LLVM['pass_manager'].initialize()
# Initialize built-in functions
printf_type = core.Type.function(
core.Type.void(),
[core.Type.pointer(core.Type.int(8))], True
)
LLVM['module'].add_function(printf_type, 'printf')
global g
g = GlobalState(process)
# In case model has nested states, flatten everything
Helper.flatten(process)
......@@ -105,47 +91,47 @@ def _process(process):
# Initialize states enum
for name in process.mapping.iterkeys():
if not name.endswith('START'):
cons = core.Constant.int(core.Type.int(), len(LLVM['states']))
LLVM['states'][name] = cons
cons = core.Constant.int(g.i32, len(g.states))
g.states[name] = cons
# Generate state var
LLVM['module'].add_global_variable(core.Type.int(), 'state')
g.module.add_global_variable(g.i32, 'state')
# Initialize output signals
for signal in process.output_signals:
param_tys = [core.Type.pointer(_generate_type(signal['type']))]
func_ty = core.Type.function(core.Type.void(), param_tys)
core.Function.new(LLVM['module'], func_ty, str(signal['name']))
func_ty = core.Type.function(g.void, param_tys)
core.Function.new(g.module, func_ty, str(signal['name']))
# Initialize external procedures
for proc in [proc for proc in process.procedures if proc.external]:
param_tys = [core.Type.pointer(_generate_type(p['type'])) for p in proc.fpar]
func_ty = core.Type.function(core.Type.void(), param_tys)
core.Function.new(LLVM['module'], func_ty, str(proc.inputString))
func_ty = core.Type.function(g.void, param_tys)
core.Function.new(g.module, func_ty, str(proc.inputString))
# Generare process-level vars
for var_name, (var_asn1_type, def_value) in process.variables.viewitems():
var_type = _generate_type(var_asn1_type)
LLVM['module'].add_global_variable(var_type, str(var_name).lower())
g.module.add_global_variable(var_type, str(var_name).lower())
if def_value:
raise NotImplementedError
# Generate process functions
runtr_func = _generate_runtr_func(process)
_generate_startup_func(process, process_name, runtr_func)
_generate_startup_func(process, str(process.processName), runtr_func)
# Generate input signals
for signal in process.input_signals:
_generate_input_signal(signal, mapping[signal['name']])
print LLVM['module']
print g.module
def _generate_runtr_func(process):
''' Generate code for the run_transition function '''
func_name = 'run_transition'
func_type = core.Type.function(core.Type.void(), [core.Type.int()])
func = core.Function.new(LLVM['module'], func_type, func_name)
func_type = core.Type.function(g.void, [g.i32])
func = core.Function.new(g.module, func_type, func_name)
entry_block = func.append_basic_block('entry')
cond_block = func.append_basic_block('cond')
......@@ -153,18 +139,18 @@ def _generate_runtr_func(process):
exit_block = func.append_basic_block('exit')
builder = core.Builder.new(entry_block)
LLVM['builder'] = builder
g.builder = builder
# entry
id_ptr = builder.alloca(core.Type.int(), None, 'id')
LLVM['named_values']['id'] = id_ptr
id_ptr = builder.alloca(g.i32, None, 'id')
g.scope['id'] = id_ptr
builder.store(func.args[0], id_ptr)
builder.branch(cond_block)
# cond
builder.position_at_end(cond_block)
id_ptr = func.args[0]
no_tr_cons = core.Constant.int(core.Type.int(), -1)
no_tr_cons = core.Constant.int(g.i32, -1)
cond_val = builder.icmp(core.ICMP_NE, id_ptr, no_tr_cons, 'cond')
builder.cbranch(cond_val, body_block, exit_block)
......@@ -175,7 +161,7 @@ def _generate_runtr_func(process):
# transitions
for idx, tr in enumerate(process.transitions):
tr_block = func.append_basic_block('tr%d' % idx)
const = core.Constant.int(core.Type.int(), idx)
const = core.Constant.int(g.i32, idx)
switch.add_case(const, tr_block)
builder.position_at_end(tr_block)
generate(tr)
......@@ -186,19 +172,19 @@ def _generate_runtr_func(process):
builder.ret_void()
func.verify()
LLVM['named_values'].clear()
g.scope.clear()
return func
def _generate_startup_func(process, process_name, runtr_func):
''' Generate code for the startup function '''
func_name = process_name + '_startup'
func_type = core.Type.function(core.Type.void(), [])
func = core.Function.new(LLVM['module'], func_type, func_name)
func_type = core.Type.function(g.void, [])
func = core.Function.new(g.module, func_type, func_name)
entry_block = func.append_basic_block('entry')
builder = core.Builder.new(entry_block)
LLVM['builder'] = builder
g.builder = builder
# entry
builder.call(runtr_func, [core.Constant.int(core.Type.int(), 0)])
......@@ -214,39 +200,38 @@ def _generate_input_signal(signal, inputs):
param_tys = []
if 'type' in signal:
param_tys.append(core.Type.pointer(_generate_type(signal['type'])))
func_type = core.Type.function(core.Type.void(), param_tys)
func = core.Function.new(LLVM['module'], func_type, func_name)
func_type = core.Type.function(g.void, param_tys)
func = core.Function.new(g.module, func_type, func_name)
entry_block = func.append_basic_block('entry')
exit_block = func.append_basic_block('exit')
builder = core.Builder.new(entry_block)
LLVM['builder'] = builder
g.builder = core.Builder.new(entry_block)
runtr_func = LLVM['module'].get_function_named('run_transition')
runtr_func = g.module.get_function_named('run_transition')
g_state_val = builder.load(LLVM['module'].get_global_variable_named('state'))
switch = builder.switch(g_state_val, exit_block)
g_state_val = g.builder.load(g.module.get_global_variable_named('state'))
switch = g.builder.switch(g_state_val, exit_block)
for state_name, state_id in LLVM['states'].iteritems():
for state_name, state_id in g.states.iteritems():
state_block = func.append_basic_block('state_%s' % str(state_name))
switch.add_case(state_id, state_block)
builder.position_at_end(state_block)
g.builder.position_at_end(state_block)
# TODO: Nested states
input = inputs.get(state_name)
if input:
for var_name in input.parameters:
var_val = LLVM['module'].get_global_variable_named(str(var_name).lower())
var_val = g.module.get_global_variable_named(str(var_name).lower())
_generate_assign(var_val, func.args[0])
if input.transition:
id_val = core.Constant.int(core.Type.int(), input.transition_id)
builder.call(runtr_func, [id_val])
id_val = core.Constant.int(g.i32, input.transition_id)
g.builder.call(runtr_func, [id_val])
builder.ret_void()
g.builder.ret_void()
builder.position_at_end(exit_block)
builder.ret_void()
g.builder.position_at_end(exit_block)
g.builder.ret_void()
func.verify()
......@@ -269,32 +254,31 @@ def _call_external_function(output):
_generate_set_timer(out['params'])
continue
func = LLVM['module'].get_function_named(str(name))
LLVM['builder'].call(func, [expression(p) for p in out.get('params', [])])
func = g.module.get_function_named(str(name))
g.builder.call(func, [expression(p) for p in out.get('params', [])])
def _generate_write(params):
''' Generate the code for the write operator '''
zero = core.Constant.int(core.Type.int(), 0)
zero = core.Constant.int(g.i32, 0)
for param in params:
basic_ty = find_basic_type(param.exprType)
expr_val = expression(param)
printf_func = LLVM['module'].get_function_named('printf')
if basic_ty.kind == 'IntegerType':
fmt_val = _get_string_cons('%d')
fmt_ptr = LLVM['builder'].gep(fmt_val, [zero, zero])
LLVM['builder'].call(printf_func, [fmt_ptr, expr_val])
fmt_ptr = g.builder.gep(fmt_val, [zero, zero])
g.builder.call(g.printf, [fmt_ptr, expr_val])
elif basic_ty.kind == 'RealType':
fmt_val = _get_string_cons('%.14E')
fmt_ptr = LLVM['builder'].gep(fmt_val, [zero, zero])
LLVM['builder'].call(printf_func, [fmt_ptr, expr_val])
fmt_ptr = g.builder.gep(fmt_val, [zero, zero])
g.builder.call(g.printf, [fmt_ptr, expr_val])
elif basic_ty.kind == 'BooleanType':
true_str_val = _get_string_cons('true')
true_str_ptr = LLVM['builder'].gep(true_str_val, [zero, zero])
true_str_ptr = g.builder.gep(true_str_val, [zero, zero])
false_str_val = _get_string_cons('false')
false_str_ptr = LLVM['builder'].gep(false_str_val, [zero, zero])
str_ptr = LLVM['builder'].select(expr_val, true_str_ptr, false_str_ptr)
LLVM['builder'].call(printf_func, [str_ptr])
false_str_ptr = g.builder.gep(false_str_val, [zero, zero])
str_ptr = g.builder.select(expr_val, true_str_ptr, false_str_ptr)
g.builder.call(g.printf, [str_ptr])
else:
raise NotImplementedError
......@@ -302,9 +286,11 @@ def _generate_write(params):
def _generate_writeln(params):
''' Generate the code for the writeln operator '''
_generate_write(params)
zero = core.Constant.int(g.i32, 0)
str_cons = _get_string_cons('\n')
str_ptr = LLVM['builder'].gep(str_cons, [zero, zero])
LLVM['builder'].call(printf_func, [str_ptr])
str_ptr = g.builder.gep(str_cons, [zero, zero])
g.builder.call(g.printf, [str_ptr])
def _generate_reset_timer(params):
......@@ -347,7 +333,7 @@ def expression(expr):
@expression.register(ogAST.PrimVariable)
def _primary_variable(prim):
''' Generate the code for a single variable reference '''
return LLVM['module'].get_global_variable_named(str(prim.value[0]).lower())
return g.module.get_global_variable_named(str(prim.value[0]).lower())
@expression.register(ogAST.PrimPath)
......@@ -370,77 +356,76 @@ def _prim_path(primary_id):
@expression.register(ogAST.ExprRem)
def _basic(expr):
''' Generate the code for an arithmetic of relational expression '''
builder = LLVM['builder']
lefttmp = expression(expr.left)
righttmp = expression(expr.right)
# load the value of the expression if it is a pointer
if lefttmp.type.kind == core.TYPE_POINTER:
lefttmp = builder.load(lefttmp, 'lefttmp')
lefttmp = g.builder.load(lefttmp, 'lefttmp')
if righttmp.type.kind == core.TYPE_POINTER:
righttmp = builder.load(righttmp, 'lefttmp')
righttmp = g.builder.load(righttmp, 'lefttmp')
if lefttmp.type.kind != righttmp.type.kind:
raise NotImplementedError
if lefttmp.type.kind == core.TYPE_INTEGER:
if expr.operand == '+':
return builder.add(lefttmp, righttmp, 'addtmp')
return g.builder.add(lefttmp, righttmp, 'addtmp')
elif expr.operand == '-':
return builder.sub(lefttmp, righttmp, 'subtmp')
return g.builder.sub(lefttmp, righttmp, 'subtmp')
elif expr.operand == '*':
return builder.mul(lefttmp, righttmp, 'multmp')
return g.builder.mul(lefttmp, righttmp, 'multmp')
elif expr.operand == '/':
return builder.sdiv(lefttmp, righttmp, 'divtmp')
return g.builder.sdiv(lefttmp, righttmp, 'divtmp')
elif expr.operand == 'mod':
# l mod r == (((l rem r) + r) rem r)
remtmp = builder.srem(lefttmp, righttmp)
addtmp = builder.add(remtmp, righttmp)
return builder.srem(addtmp, righttmp, 'modtmp')
remtmp = g.builder.srem(lefttmp, righttmp)
addtmp = g.builder.add(remtmp, righttmp)
return g.builder.srem(addtmp, righttmp, 'modtmp')
elif expr.operand == 'rem':
return builder.srem(lefttmp, righttmp, 'remtmp')
return g.builder.srem(lefttmp, righttmp, 'remtmp')
elif expr.operand == '<':
return builder.icmp(core.ICMP_SLT, lefttmp, righttmp, 'lttmp')
return g.builder.icmp(core.ICMP_SLT, lefttmp, righttmp, 'lttmp')
elif expr.operand == '<=':
return builder.icmp(core.ICMP_SLE, lefttmp, righttmp, 'letmp')
return g.builder.icmp(core.ICMP_SLE, lefttmp, righttmp, 'letmp')
elif expr.operand == '=':
return builder.icmp(core.ICMP_EQ, lefttmp, righttmp, 'eqtmp')
return g.builder.icmp(core.ICMP_EQ, lefttmp, righttmp, 'eqtmp')
elif expr.operand == '/=':
return builder.icmp(core.ICMP_NE, lefttmp, righttmp, 'netmp')
return g.builder.icmp(core.ICMP_NE, lefttmp, righttmp, 'netmp')
elif expr.operand == '>=':
return builder.icmp(core.ICMP_SGE, lefttmp, righttmp, 'getmp')
return g.builder.icmp(core.ICMP_SGE, lefttmp, righttmp, 'getmp')
elif expr.operand == '>':
return builder.icmp(core.ICMP_SGT, lefttmp, righttmp, 'gttmp')
return g.builder.icmp(core.ICMP_SGT, lefttmp, righttmp, 'gttmp')
else:
raise NotImplementedError
elif lefttmp.type.kind == core.TYPE_DOUBLE:
if expr.operand == '+':
return builder.fadd(lefttmp, righttmp, 'addtmp')
return g.builder.fadd(lefttmp, righttmp, 'addtmp')
elif expr.operand == '-':
return builder.fsub(lefttmp, righttmp, 'subtmp')
return g.builder.fsub(lefttmp, righttmp, 'subtmp')
elif expr.operand == '*':
return builder.fmul(lefttmp, righttmp, 'multmp')
return g.builder.fmul(lefttmp, righttmp, 'multmp')
elif expr.operand == '/':
return builder.fdiv(lefttmp, righttmp, 'divtmp')
return g.builder.fdiv(lefttmp, righttmp, 'divtmp')
elif expr.operand == 'mod':
# l mod r == (((l rem r) + r) rem r)
remtmp = builder.frem(lefttmp, righttmp)
addtmp = builder.fadd(remtmp, righttmp)
return builder.frem(addtmp, righttmp, 'modtmp')
remtmp = g.builder.frem(lefttmp, righttmp)
addtmp = g.builder.fadd(remtmp, righttmp)
return g.builder.frem(addtmp, righttmp, 'modtmp')
elif expr.operand == 'rem':
return builder.frem(lefttmp, righttmp, 'remtmp')
return g.builder.frem(lefttmp, righttmp, 'remtmp')
elif expr.operand == '<':
return builder.icmp(core.FCMP_OLT, lefttmp, righttmp, 'lttmp')
return g.builder.icmp(core.FCMP_OLT, lefttmp, righttmp, 'lttmp')
elif expr.operand == '<=':
return builder.icmp(core.FCMP_OLE, lefttmp, righttmp, 'letmp')
return g.builder.icmp(core.FCMP_OLE, lefttmp, righttmp, 'letmp')
elif expr.operand == '=':
return builder.icmp(core.FCMP_OEQ, lefttmp, righttmp, 'eqtmp')
return g.builder.icmp(core.FCMP_OEQ, lefttmp, righttmp, 'eqtmp')
elif expr.operand == '/=':
return builder.icmp(core.FCMP_ONE, lefttmp, righttmp, 'netmp')
return g.builder.icmp(core.FCMP_ONE, lefttmp, righttmp, 'netmp')
elif expr.operand == '>=':
return builder.icmp(core.FCMP_OGE, lefttmp, righttmp, 'getmp')
return g.builder.icmp(core.FCMP_OGE, lefttmp, righttmp, 'getmp')
elif expr.operand == '>':
return builder.icmp(core.FCMP_OGT, lefttmp, righttmp, 'gttmp')
return g.builder.icmp(core.FCMP_OGT, lefttmp, righttmp, 'gttmp')
else:
raise NotImplementedError
else:
......@@ -460,20 +445,17 @@ def _generate_assign(left, right):
''' Generate code for an assign from two LLVM values'''
# This is extracted as an standalone function because is used by
# multiple generation rules
builder = LLVM['builder']
if left.type.kind == core.TYPE_POINTER and left.type.pointee.kind == core.TYPE_STRUCT:
memcpy = _get_memcpy_intrinsic()
size = core.Constant.int(core.Type.int(64), 2)
align = core.Constant.int(core.Type.int(32), 1)
volatile = core.Constant.int(core.Type.int(1), 0)
size = core.Constant.int(g.i64, 2)
align = core.Constant.int(g.i32, 1)
volatile = core.Constant.int(g.i1, 0)
right_ptr = builder.bitcast(right, core.Type.pointer(core.Type.int(8)))
left_ptr = builder.bitcast(left, core.Type.pointer(core.Type.int(8)))
right_ptr = g.builder.bitcast(right, g.i8_ptr)
left_ptr = g.builder.bitcast(left, g.i8_ptr)
builder.call(memcpy, [left_ptr, right_ptr, size, align, volatile])
g.builder.call(g.memcpy, [left_ptr, right_ptr, size, align, volatile])
else:
builder.store(right, left)
g.builder.store(right, left)
@expression.register(ogAST.ExprOr)
......@@ -481,8 +463,6 @@ def _generate_assign(left, right):
@expression.register(ogAST.ExprXor)
def _logical(expr):
''' Generate the code for a logical expression '''
builder = LLVM['builder']
lefttmp = expression(expr.left)
righttmp = expression(expr.right)
......@@ -492,16 +472,16 @@ def _logical(expr):
# load the value of the expression if it is a pointer
if lefttmp.type.kind == core.TYPE_POINTER:
lefttmp = builder.load(lefttmp, 'lefttmp')
lefttmp = g.builder.load(lefttmp, 'lefttmp')
if righttmp.type.kind == core.TYPE_POINTER:
righttmp = builder.load(righttmp, 'lefttmp')
righttmp = g.builder.load(righttmp, 'lefttmp')
if expr.operand == '&&':
return builder.and_(lefttmp, righttmp, 'ortmp')
return g.builder.and_(lefttmp, righttmp, 'ortmp')
elif expr.operand == '||':
return builder.or_(lefttmp, righttmp, 'ortmp')
return g.builder.or_(lefttmp, righttmp, 'ortmp')
else:
return builder.xor(lefttmp, righttmp, 'xortmp')
return g.builder.xor(lefttmp, righttmp, 'xortmp')
@expression.register(ogAST.ExprAppend)
......@@ -531,22 +511,22 @@ def _choice_determinant(primary):
@expression.register(ogAST.PrimInteger)
def _integer(primary):
''' Generate code for a raw integer value '''
return core.Constant.int(core.Type.int(), primary.value[0])
return core.Constant.int(g.i32, primary.value[0])
@expression.register(ogAST.PrimReal)
def _real(primary):
''' Generate code for a raw real value '''
return core.Constant.real(core.Type.double(), primary.value[0])
return core.Constant.real(g.double, primary.value[0])
@expression.register(ogAST.PrimBoolean)
def _boolean(primary):
''' Generate code for a raw boolean value '''
if primary.value[0].lower() == 'true':
return core.Constant.int(core.Type.int(1), 1)
return core.Constant.int(g.i1, 1)
else:
return core.Constant.int(core.Type.int(1), 0)
return core.Constant.int(g.i1, 0)
@expression.register(ogAST.PrimEmptyString)
......@@ -588,18 +568,16 @@ def _sequence(seq):
@expression.register(ogAST.PrimSequenceOf)
def _sequence_of(seqof):
''' Generate the code for an ASN.1 SEQUENCE OF '''
builder = LLVM['builder']
ty = _generate_type(seqof.exprType)
struct_ptr = builder.alloca(ty)
zero_cons = core.Constant.int(core.Type.int(), 0)
array_ptr = builder.gep(struct_ptr, [zero_cons, zero_cons])
struct_ptr = g.builder.alloca(ty)
zero_cons = core.Constant.int(g.i32, 0)
array_ptr = g.builder.gep(struct_ptr, [zero_cons, zero_cons])
for idx, expr in enumerate(seqof.value):
idx_cons = core.Constant.int(core.Type.int(), idx)
idx_cons = core.Constant.int(g.i32, idx)
expr_val = expression(expr)
pos_ptr = builder.gep(array_ptr, [zero_cons, idx_cons])
builder.store(expr_val, pos_ptr)
pos_ptr = g.builder.gep(array_ptr, [zero_cons, idx_cons])
g.builder.store(expr_val, pos_ptr)
return struct_ptr
......@@ -613,19 +591,18 @@ def _choiceitem(choice):
@generate.register(ogAST.Decision)
def _decision(dec):
''' Generate the code for a decision '''
builder = LLVM['builder']
func = builder.basic_block.function
func = g.builder.basic_block.function
ans_cond_blocks = [func.append_basic_block('ans_cond') for ans in dec.answers]
end_block = func.append_basic_block('end')
builder.branch(ans_cond_blocks[0])
g.builder.branch(ans_cond_blocks[0])
for idx, ans in enumerate(dec.answers):
ans_cond_block = ans_cond_blocks[idx]
if ans.transition:
ans_tr_block = func.append_basic_block('ans_tr')
builder.position_at_end(ans_cond_block)
g.builder.position_at_end(ans_cond_block)
if ans.kind == 'constant':
next_block = ans_cond_blocks[idx+1] if idx < len(ans_cond_blocks) else end_block
......@@ -635,23 +612,23 @@ def _decision(dec):
<