Commit a7f94aec authored by dbarbera's avatar dbarbera
Browse files

Removed context from global scope

parent 3a64ba58
......@@ -30,10 +30,6 @@ LOG = logging.getLogger(__name__)
__all__ = ['generate']
# Global context
ctx = None
class Context():
def __init__(self, process):
self.name = str(process.processName)
......@@ -41,7 +37,7 @@ class Context():
self.target_data = ee.TargetData.new(self.module.data_layout)
self.dataview = process.dataview
self.scope = Scope()
self.scope = Scope(self)
self.global_scope = self.scope
self.states = {}
self.enums = {}
......@@ -92,7 +88,7 @@ class Context():
def open_scope(self):
''' Open a scope '''
self.scope = Scope(self.scope)
self.scope = Scope(self, self.scope)
def close_scope(self):
''' Close the current scope '''
......@@ -108,28 +104,28 @@ class Context():
if name and name in self.lltypes:
return self.lltypes[name]
basic_asn1ty = find_basic_type(asn1ty)
basic_asn1ty = find_basic_type(asn1ty, self)
if basic_asn1ty.kind == 'IntegerType':
llty = ctx.i64
llty = self.i64
elif basic_asn1ty.kind == 'Integer32Type':
llty = ctx.i32
llty = self.i32
elif basic_asn1ty.kind == 'BooleanType':
llty = ctx.i1
llty = self.i1
elif basic_asn1ty.kind == 'RealType':
llty = ctx.double
llty = self.double
elif basic_asn1ty.kind == 'SequenceOfType':
llty = self._type_of_sequenceof(name, basic_asn1ty)
elif basic_asn1ty.kind == 'SequenceType':
llty = self._type_of_sequence(name, basic_asn1ty)
elif basic_asn1ty.kind == 'EnumeratedType':
llty = ctx.i32
llty = self.i32
elif basic_asn1ty.kind == 'ChoiceType':
llty = self._type_of_choice(name, basic_asn1ty)
elif basic_asn1ty.kind == 'OctetStringType':
llty = self._type_of_octetstring(name, basic_asn1ty)
elif basic_asn1ty.kind in ('StringType', 'StandardStringType'):
llty = ctx.i8_ptr
llty = self.i8_ptr
else:
raise NotImplementedError
......@@ -148,7 +144,7 @@ class Context():
array_ty = core.Type.array(elem_ty, max_size)
if is_variable_size:
struct = self.decl_struct(['nCount', 'arr'], [ctx.i32, array_ty], name)
struct = self.decl_struct(['nCount', 'arr'], [self.i32, array_ty], name)
else:
struct = self.decl_struct(['arr'], [array_ty], name)
......@@ -198,10 +194,10 @@ class Context():
max_size = int(octetstring_ty.Max)
is_variable_size = min_size != max_size
array_ty = core.Type.array(ctx.i8, max_size)
array_ty = core.Type.array(self.i8, max_size)
if is_variable_size:
struct = self.decl_struct(['nCount', 'arr'], [ctx.i32, array_ty], name)
struct = self.decl_struct(['nCount', 'arr'], [self.i32, array_ty], name)
else:
struct = self.decl_struct(['arr'], [array_ty], name)
......@@ -245,7 +241,7 @@ class Context():
def decl_union(self, field_names, field_types, name=None):
name = name if name else "union.%s" % len(self.structs)
name = name.replace('-', '_')
union = UnionType(name, field_names, field_types)
union = UnionType(name, field_names, field_types, self)
self.unions[name] = union
return union
......@@ -265,7 +261,7 @@ class StructType():
class UnionType():
def __init__(self, name, field_names, field_types):
def __init__(self, name, field_names, field_types, ctx):
self.name = name
self.field_names = field_names
self.field_types = field_types
......@@ -280,7 +276,8 @@ class UnionType():
class Scope:
def __init__(self, parent=None):
def __init__(self, ctx, parent=None):
self.ctx = ctx
self.vars = {}
self.labels = {}
self.parent = parent
......@@ -301,7 +298,7 @@ class Scope:
name = name.lower()
label_block = self.labels.get(name)
if not label_block:
func = ctx.builder.basic_block.function
func = self.ctx.builder.basic_block.function
label_block = func.append_basic_block('label:%s' % name)
self.labels[name] = label_block
return label_block
......@@ -312,19 +309,18 @@ class CompileError(Exception):
@singledispatch
def generate(ast):
def generate(ast, ctx=None):
''' Generate the IR for an AST node '''
raise CompileError('Unsupported AST construct "%s"' % ast.__class__.__name__)
# Processing of the AST
@generate.register(ogAST.Process)
def _process(process):
def _process(process, ctx=None):
''' Generate the IR for a process '''
process_name = str(process.processName)
LOG.info('Generating LLVM IR code for process ' + process_name)
global ctx
ctx = Context(process)
# In case model has nested states, flatten everything
......@@ -376,19 +372,19 @@ def _process(process):
# Generate internal procedures
for proc in process.content.inner_procedures:
generate(proc)
generate(proc, ctx)
# Generate process functions
generate_runtr_func(process)
generate_startup_func(process)
generate_runtr_func(process, ctx)
generate_startup_func(process, ctx)
# Generate input signals
for signal in process.input_signals:
generate_input_signal(signal, mapping[signal['name']])
generate_input_signal(signal, mapping[signal['name']], ctx)
# Generate timer signal
for timer in process.timers:
generate_input_signal({'name': timer.lower()}, mapping[timer])
generate_input_signal({'name': timer.lower()}, mapping[timer], ctx)
ctx.module.verify()
......@@ -396,7 +392,7 @@ def _process(process):
ll_file.write(str(ctx.module))
def generate_runtr_func(process):
def generate_runtr_func(process, ctx):
''' Generate the IR for the run_transition function '''
func = ctx.decl_func('run_transition', ctx.void, [ctx.i32])
......@@ -432,7 +428,7 @@ def generate_runtr_func(process):
const = core.Constant.int(ctx.i32, idx)
switch.add_case(const, tr_block)
ctx.builder.position_at_end(tr_block)
generate(tr)
generate(tr, ctx)
if not ctx.builder.basic_block.terminator:
ctx.builder.branch(cond_block)
......@@ -442,7 +438,7 @@ def generate_runtr_func(process):
Helper.inner_labels_to_floating(process)
for label in process.content.floating_labels:
generate(label)
generate(label, ctx)
# TODO: Use defined cond_block instead?
next_tr_label_block = ctx.scope.label('next_transition')
......@@ -455,7 +451,7 @@ def generate_runtr_func(process):
return func
def generate_startup_func(process):
def generate_startup_func(process, ctx):
''' Generate the IR for the startup function '''
func = ctx.decl_func(ctx.name + '_startup', ctx.void, [])
......@@ -468,7 +464,7 @@ def generate_startup_func(process):
for name, (ty, expr) in process.variables.viewitems():
if expr:
global_var = ctx.scope.resolve(str(name))
generate_assign(global_var, expression(expr))
generate_assign(global_var, expression(expr, ctx), ctx)
ctx.builder.call(ctx.funcs['run_transition'], [core.Constant.int(ctx.i32, 0)])
ctx.builder.ret_void()
......@@ -479,7 +475,7 @@ def generate_startup_func(process):
return func
def generate_input_signal(signal, inputs):
def generate_input_signal(signal, inputs, ctx):
''' Generate the IR for an input signal '''
func_name = ctx.name + "_" + str(signal['name'])
param_tys = []
......@@ -511,9 +507,9 @@ def generate_input_signal(signal, inputs):
for var_name in input.parameters:
var_ptr = ctx.scope.resolve(str(var_name))
if is_struct_ptr(var_ptr) or is_array_ptr(var_ptr):
generate_assign(var_ptr, func.args[0])
generate_assign(var_ptr, func.args[0], ctx)
else:
generate_assign(var_ptr, ctx.builder.load(func.args[0]))
generate_assign(var_ptr, ctx.builder.load(func.args[0]), ctx)
if input.transition:
id_val = core.Constant.int(ctx.i32, input.transition_id)
ctx.builder.call(ctx.funcs['run_transition'], [id_val])
......@@ -530,29 +526,29 @@ def generate_input_signal(signal, inputs):
@generate.register(ogAST.Output)
@generate.register(ogAST.ProcedureCall)
def _call_external_function(output):
def _call_external_function(output, ctx):
''' Generate the IR for an output or procedure call '''
for out in output.output:
name = out['outputName'].lower()
if name == 'write':
generate_write(out['params'])
generate_write(out['params'], ctx)
continue
elif name == 'writeln':
generate_writeln(out['params'])
generate_writeln(out['params'], ctx)
continue
elif name == 'reset_timer':
generate_reset_timer(out['params'])
generate_reset_timer(out['params'], ctx)
continue
elif name == 'set_timer':
generate_set_timer(out['params'])
generate_set_timer(out['params'], ctx)
continue
func = ctx.funcs[str(name).lower()]
params = []
for p in out.get('params', []):
p_val = expression(p)
p_val = expression(p, ctx)
# Pass by reference
if p_val.type.kind != core.TYPE_POINTER:
p_var = ctx.builder.alloca(p_val.type, None)
......@@ -564,11 +560,11 @@ def _call_external_function(output):
ctx.builder.call(func, params)
def generate_write(params):
def generate_write(params, ctx):
''' Generate the IR for the write operator '''
for param in params:
basic_ty = find_basic_type(param.exprType)
expr_val = expression(param)
basic_ty = find_basic_type(param.exprType, ctx)
expr_val = expression(param, ctx)
if basic_ty.kind in ['IntegerType', 'Integer32Type']:
fmt_str_ptr = ctx.string_ptr('% d')
......@@ -597,15 +593,15 @@ def generate_write(params):
raise NotImplementedError
def generate_writeln(params):
def generate_writeln(params, ctx):
''' Generate the IR for the writeln operator '''
generate_write(params)
generate_write(params, ctx)
str_ptr = ctx.string_ptr('\n')
ctx.builder.call(ctx.funcs['printf'], [str_ptr])
def generate_reset_timer(params):
def generate_reset_timer(params, ctx):
''' Generate the IR for the reset timer operator '''
timer_id = params[0]
reset_func_name = 'reset_%s' % timer_id.value[0]
......@@ -614,13 +610,13 @@ def generate_reset_timer(params):
ctx.builder.call(reset_func, [])
def generate_set_timer(params):
def generate_set_timer(params, ctx):
''' Generate the IR for the set timer operator '''
timer_expr, timer_id = params
set_func_name = 'set_%s' % timer_id.value[0]
set_func = ctx.funcs[set_func_name.lower()]
expr_val = expression(timer_expr)
expr_val = expression(timer_expr, ctx)
tmp_ptr = ctx.builder.alloca(expr_val.type)
ctx.builder.store(expr_val, tmp_ptr)
......@@ -629,29 +625,29 @@ def generate_set_timer(params):
@generate.register(ogAST.TaskAssign)
def _task_assign(task):
def _task_assign(task, ctx):
''' Generate the IR for a list of assignments '''
for expr in task.elems:
expression(expr)
expression(expr, ctx)
@generate.register(ogAST.TaskInformalText)
def _task_informal_text(task):
def _task_informal_text(task, ctx):
''' Generate comments for informal text '''
pass
@generate.register(ogAST.TaskForLoop)
def _task_forloop(task):
def _task_forloop(task, ctx):
''' Generate the IRfor a for loop '''
for loop in task.elems:
if loop['range']:
generate_for_range(loop)
generate_for_range(loop, ctx)
else:
generate_for_iterable(loop)
generate_for_iterable(loop, ctx)
def generate_for_range(loop):
def generate_for_range(loop, ctx):
''' Generate the IR for a for x in range loop '''
func = ctx.builder.basic_block.function
cond_block = func.append_basic_block('for:cond')
......@@ -665,12 +661,12 @@ def generate_for_range(loop):
ctx.scope.define(str(loop['var']), loop_var)
if loop['range']['start']:
start_val = expression(loop['range']['start'])
start_val = expression(loop['range']['start'], ctx)
ctx.builder.store(start_val, loop_var)
else:
ctx.builder.store(core.Constant.int(ctx.i64, 0), loop_var)
stop_val = expression(loop['range']['stop'])
stop_val = expression(loop['range']['stop'], ctx)
ctx.builder.branch(cond_block)
ctx.builder.position_at_end(cond_block)
......@@ -679,7 +675,7 @@ def generate_for_range(loop):
ctx.builder.cbranch(cond_val, body_block, end_block)
ctx.builder.position_at_end(body_block)
generate(loop['transition'])
generate(loop['transition'], ctx)
ctx.builder.branch(inc_block)
ctx.builder.position_at_end(inc_block)
......@@ -694,9 +690,9 @@ def generate_for_range(loop):
ctx.close_scope()
def generate_for_iterable(loop):
def generate_for_iterable(loop, ctx):
''' Generate the IR for a for x in iterable loop '''
seqof_asn1ty = find_basic_type(loop['list'].exprType)
seqof_asn1ty = find_basic_type(loop['list'].exprType, ctx)
is_variable_size = seqof_asn1ty.Min != seqof_asn1ty.Max
func = ctx.builder.basic_block.function
......@@ -714,7 +710,7 @@ def generate_for_iterable(loop):
idx_ptr = ctx.builder.alloca(ctx.i32)
ctx.builder.store(core.Constant.int(ctx.i32, 0), idx_ptr)
seqof_struct_ptr = expression(loop['list'])
seqof_struct_ptr = expression(loop['list'], ctx)
if is_variable_size:
# In variable size SequenceOfs the array values are in the second field
......@@ -739,14 +735,16 @@ def generate_for_iterable(loop):
ctx.builder.position_at_end(load_block)
idx_var = ctx.builder.load(idx_ptr)
if element_typ.kind == core.TYPE_STRUCT:
generate_assign(var_ptr, ctx.builder.gep(array_ptr, [ctx.zero, idx_var]))
elem_ptr = ctx.builder.gep(array_ptr, [ctx.zero, idx_var])
generate_assign(var_ptr, elem_ptr, ctx)
else:
generate_assign(var_ptr, ctx.builder.load(ctx.builder.gep(array_ptr, [ctx.zero, idx_var])))
elem_val = ctx.builder.load(ctx.builder.gep(array_ptr, [ctx.zero, idx_var]))
generate_assign(var_ptr, elem_val, ctx)
ctx.builder.branch(body_block)
# body block
ctx.builder.position_at_end(body_block)
generate(loop['transition'])
generate(loop['transition'], ctx)
ctx.builder.branch(cond_block)
# cond block
......@@ -762,21 +760,21 @@ def generate_for_iterable(loop):
@singledispatch
def reference(prim):
def reference(prim, ctx):
''' Generate the IR for a reference '''
raise CompileError('Unsupported reference "%s"' % prim.__class__.__name__)
@reference.register(ogAST.PrimVariable)
def _prim_var_reference(prim):
def _prim_var_reference(prim, ctx):
''' Generate the IR for a variable reference '''
return ctx.scope.resolve(str(prim.value[0]))
@reference.register(ogAST.PrimSelector)
def _prim_selector_reference(prim):
def _prim_selector_reference(prim, ctx):
''' Generate the IR for a field selector referece '''
receiver_ptr = reference(prim.value[0])
receiver_ptr = reference(prim.value[0], ctx)
field_name = prim.value[1]
if receiver_ptr.type.pointee.name in ctx.structs:
......@@ -792,10 +790,10 @@ def _prim_selector_reference(prim):
@reference.register(ogAST.PrimIndex)
def _prim_index_reference(prim):
def _prim_index_reference(prim, ctx):
''' Generate the IR for an index reference '''
receiver_ptr = reference(prim.value[0])
idx_val = expression(prim.value[1]['index'][0])
receiver_ptr = reference(prim.value[0], ctx)
idx_val = expression(prim.value[1]['index'][0], ctx)
array_ptr = ctx.builder.gep(receiver_ptr, [ctx.zero, ctx.zero])
......@@ -809,7 +807,7 @@ def _prim_index_reference(prim):
@singledispatch
def expression(expr):
def expression(expr, ctx):
''' Generate the IR for an expression node '''
raise CompileError('Unsupported expression "%s"' % expr.__class__.__name__)
......@@ -820,12 +818,12 @@ def expression(expr):
@expression.register(ogAST.ExprDiv)
@expression.register(ogAST.ExprMod)
@expression.register(ogAST.ExprRem)
def _expr_arith(expr):
def _expr_arith(expr, ctx):
''' Generate the IR for an arithmetic expression '''
left_val = expression(expr.left)
right_val = expression(expr.right)
left_val = expression(expr.left, ctx)
right_val = expression(expr.right, ctx)
expr_bty = find_basic_type(expr.exprType)
expr_bty = find_basic_type(expr.exprType, ctx)
if expr_bty.kind in ('IntegerType', 'Integer32Type'):
if expr.operand == '+':
......@@ -868,12 +866,12 @@ def _expr_arith(expr):
@expression.register(ogAST.ExprGe)
@expression.register(ogAST.ExprLt)
@expression.register(ogAST.ExprLe)
def _expr_rel(expr):
def _expr_rel(expr, ctx):
''' Generate the IR for a relational expression '''
left_val = expression(expr.left)
right_val = expression(expr.right)
left_val = expression(expr.left, ctx)
right_val = expression(expr.right, ctx)
operands_bty = find_basic_type(expr.left.exprType)
operands_bty = find_basic_type(expr.left.exprType, ctx)
if operands_bty.kind in ('IntegerType', 'Integer32Type'):
if expr.operand == '<':
......@@ -946,9 +944,9 @@ def _expr_rel(expr):
@expression.register(ogAST.ExprNeg)
def _expr_neg(expr):
def _expr_neg(expr, ctx):
''' Generate the IR for a negative expression '''
expr_val = expression(expr.expr)
expr_val = expression(expr.expr, ctx)
if expr_val.type.kind == core.TYPE_INTEGER:
zero_val = core.Constant.int(ctx.i64, 0)
return ctx.builder.sub(zero_val, expr_val)
......@@ -958,12 +956,12 @@ def _expr_neg(expr):
@expression.register(ogAST.ExprAssign)
def _expr_assign(expr):
def _expr_assign(expr, ctx):
''' Generate the IR for an assign expression '''
generate_assign(reference(expr.left), expression(expr.right))
generate_assign(reference(expr.left, ctx), expression(expr.right, ctx), ctx)
def generate_assign(left, right):
def generate_assign(left, right, ctx):
''' Generate the IR for an assign from two LLVM values '''
# This is extracted as an standalone function because is used by
# multiple generation rules
......@@ -983,9 +981,9 @@ def generate_assign(left, right):
@expression.register(ogAST.ExprOr)
@expression.register(ogAST.ExprAnd)
@expression.register(ogAST.ExprXor)
def _expr_logic(expr):
def _expr_logic(expr, ctx):
''' Generate the IR for a logic expression '''
bty = find_basic_type(expr.exprType)
bty = find_basic_type(expr.exprType, ctx)
if expr.shortcircuit:
if bty.kind != 'BooleanType':
......@@ -998,7 +996,7 @@ def _expr_logic(expr):
end_block = func.append_basic_block('%s:end' % expr.operand)
res_ptr = ctx.builder.alloca(ctx.i1)
left_val = expression(expr.left)
left_val = expression(expr.left, ctx)
ctx.builder.store(left_val, res_ptr)
if expr.operand == 'and':
......@@ -1009,7 +1007,7 @@ def _expr_logic(expr):
raise CompileError('Unknown shortcircuit operator "%s"' % expr.operand)
ctx.builder.position_at_end(right_block)
right_val = expression(expr.right)
right_val = expression(expr.right, ctx)
ctx.builder.store(right_val, res_ptr)
ctx.builder.branch(end_block)
......@@ -1017,8 +1015,8 @@ def _expr_logic(expr):
return ctx.builder.load(res_ptr)
elif bty.kind == 'BooleanType':
left_val = expression(expr.left)
right_val = expression(expr.right)
left_val = expression(expr.left, ctx)
right_val = expression(expr.right, ctx)
if expr.operand == 'and':
return ctx.builder.and_(left_val, right_val)
elif expr.operand == 'or':
......@@ -1034,8 +1032,8 @@ def _expr_logic(expr):
next_block = func.append_basic_block('%s:next' % expr.operand)
end_block = func.append_basic_block('%s:end' % expr.operand)
left_ptr = expression(expr.left)
right_ptr = expression(expr.right)
left_ptr = expression(expr.left, ctx)
right_ptr = expression(expr.right, ctx)
res_ptr = ctx.builder.alloca(left_ptr.type.pointee)
array_ty = res_ptr.type.pointee.elements[0]
......@@ -1085,12 +1083,12 @@ def _expr_logic(expr):
@expression.register(ogAST.ExprNot)
def _expr_not(expr):
def _expr_not(expr, ctx):
''' Generate the IR for a not expression '''
bty = find_basic_type(expr.exprType)
bty = find_basic_type(expr.exprType, ctx)
if bty.kind == 'BooleanType':
return ctx.builder.not_(expression(expr.expr))
return ctx.builder.not_(expression(expr.expr, ctx))
elif bty.kind == 'SequenceOfType' and bty.Min == bty.Max:
func = ctx.builder.basic_block.function
......@@ -1102,7 +1100,7 @@ def _expr_not(expr):
idx_ptr = ctx.builder.alloca(ctx.i32)
ctx.builder.store(core.Constant.int(ctx.i32, 0), idx_ptr)
struct_ptr = expression(expr.expr)
struct_ptr = expression(expr.expr, ctx)
res_struct_ptr = ctx.builder.alloca(struct_ptr.type.pointee)
array_ty = struct_ptr.type.pointee.elements[0]
......@@ -1136,9 +1134,9 @@ def _expr_not(expr):
@expression.register(ogAST.ExprAppend)
def _expr_append(expr):
def _expr_append(expr, ctx):
''' Generate the IR for a append expression '''
bty = find_basic_type(expr.exprType)
bty = find_basic_type(expr.exprType, ctx)
if bty.kind in ('SequenceOfType', 'OctetStringType'):
res_ty = ctx.type_of(expr.exprType)
......@@ -1149,12 +1147,12 @@ def _expr_append(expr):
res_len_ptr = ctx.builder.gep(res_ptr, [ctx.zero, ctx.zero])
res_arr_ptr = ctx.builder.gep(res_ptr, [ctx.zero, ctx.one])
left_ptr = expression(expr.left)
left_ptr = expression(expr.left, ctx)
left_len_ptr = ctx.builder.gep(left_ptr, [ctx.zero, ctx.zero])
left_arr_ptr = ctx.builder.gep(left_ptr, [ctx.zero, ctx.one])
left_len_val = ctx.builder.load(left_len_ptr)
right_ptr = expression(expr.right)
right_ptr = expression(expr.right, ctx)
right_len_ptr = ctx.builder.gep(right_ptr, [ctx.zero, ctx.zero])
right_arr_ptr = ctx.builder.gep(right_ptr, [ctx.zero, ctx.one])
right_len_val = ctx.builder.load(right_len_ptr)
......@@ -1187,7 +1185,7 @@ def _expr_append(expr):