Commit aeb3a618 authored by Maxime Perrotin's avatar Maxime Perrotin
Browse files

Merge branch 'llvm' of https://github.com/dbrabera/opengeode

parents ea3aad83 9b2eea26
......@@ -18,8 +18,9 @@
"""
import logging
from singledispatch import singledispatch
from llvm import core, passes, ee
from llvm import core, ee
import ogAST
import Helper
......@@ -29,33 +30,323 @@ LOG = logging.getLogger(__name__)
__all__ = ['generate']
# LLVM Global variable - Initialized when the generator is invoked
LLVM = {
# The LLVM module, which holds all the IR code.
'module': None,
# Dictionary that keeps track of which values are defined in the current
# scope and what their LLVM representation is.
'named_values': {},
# The function optimization passes manager.
'pass_manager': None,
# The LLVM execution engine.
'executor': None
}
class Context():
def __init__(self, process):
self.name = str(process.processName)
self.module = core.Module.new(self.name)
self.target_data = ee.TargetData.new(self.module.data_layout)
self.dataview = process.dataview
self.procedures = process.procedures
self.scope = Scope(self)
self.global_scope = self.scope
self.states = {}
self.enums = {}
self.structs = {}
self.unions = {}
self.strings = {}
self.funcs = {}
self.lltypes = {}
self.basic_asn1types = {}
# Initialize built-in types
self.i1 = core.Type.int(1)
self.i8 = core.Type.int(8)
self.i32 = core.Type.int(32)
self.i64 = core.Type.int(64)
self.void = core.Type.void()
self.double = core.Type.double()
self.i1_ptr = core.Type.pointer(self.i1)
self.i8_ptr = core.Type.pointer(self.i8)
self.i32_ptr = core.Type.pointer(self.i32)
self.i64_ptr = core.Type.pointer(self.i64)
self.double_ptr = core.Type.pointer(self.double)
# Initialize common constants
self.zero = core.Constant.int(self.i32, 0)
self.one = core.Constant.int(self.i32, 1)
# Intialize built-in functions
ty = core.Type.function(self.void, [self.i8_ptr], True)
self.funcs['printf'] = self.module.add_function(ty, 'printf')
self.funcs['memcpy'] = core.Function.intrinsic(
self.module,
core.INTR_MEMCPY,
[self.i8_ptr, self.i8_ptr, self.i64]
)
self.funcs['powi'] = core.Function.intrinsic(
self.module,
core.INTR_POWI,
[self.double]
)
self.funcs['fabs'] = core.Function.intrinsic(
self.module,
core.INTR_FABS,
[self.double]
)
def open_scope(self):
''' Open a scope '''
self.scope = Scope(self, self.scope)
def close_scope(self):
''' Close the current scope '''
self.scope = self.scope.parent
def basic_asn1type_of(self, asn1ty):
''' Return the ASN.1 basic type of a type '''
if asn1ty.kind != 'ReferenceType':
return asn1ty
asn1ty_name = asn1ty.ReferencedTypeName.lower()
# return the basic type if its cached
if asn1ty_name in self.basic_asn1types:
return self.basic_asn1types[asn1ty_name]
basic_asn1ty = asn1ty
while basic_asn1ty.kind == 'ReferenceType':
for typename in self.dataview.viewkeys():
if typename.lower() == basic_asn1ty.ReferencedTypeName.lower():
basic_asn1ty = self.dataview[typename].type
break
# cache the basic type
self.basic_asn1types[asn1ty_name] = basic_asn1ty
return basic_asn1ty
def lltype_of(self, asn1ty):
''' Return the LL type of a ASN.1 type '''
try:
name = asn1ty.ReferencedTypeName.replace('-', '_')
except AttributeError:
name = None
if name and name in self.lltypes:
return self.lltypes[name]
basic_asn1ty = self.basic_asn1type_of(asn1ty)
if basic_asn1ty.kind == 'IntegerType':
llty = self.i64
elif basic_asn1ty.kind == 'Integer32Type':
llty = self.i32
elif basic_asn1ty.kind == 'BooleanType':
llty = self.i1
elif basic_asn1ty.kind == 'RealType':
llty = self.double
elif basic_asn1ty.kind == 'SequenceOfType':
llty = self._lltype_of_sequenceof(name, basic_asn1ty)
elif basic_asn1ty.kind == 'SequenceType':
llty = self._lltype_of_sequence(name, basic_asn1ty)
elif basic_asn1ty.kind == 'EnumeratedType':
llty = self.i32
elif basic_asn1ty.kind == 'ChoiceType':
llty = self._lltype_of_choice(name, basic_asn1ty)
elif basic_asn1ty.kind == 'OctetStringType':
llty = self._lltype_of_octetstring(name, basic_asn1ty)
elif basic_asn1ty.kind in ('StringType', 'StandardStringType'):
llty = self.i8_ptr
else:
raise CompileError('Unknown basic ASN.1 type "%s"' % basic_asn1ty.kind)
if name:
self.lltypes[name] = llty
return llty
def _lltype_of_sequenceof(self, name, asn1ty):
''' Return the LL type of a SequenceOf ASN.1 type '''
min_size = int(asn1ty.Min)
max_size = int(asn1ty.Max)
is_variable_size = min_size != max_size
elem_llty = self.lltype_of(asn1ty.type)
array_llty = core.Type.array(elem_llty, max_size)
if is_variable_size:
struct = self.decl_struct(['nCount', 'arr'], [self.i32, array_llty], name)
else:
struct = self.decl_struct(['arr'], [array_llty], name)
struct_ptr = core.Type.pointer(struct.llty)
self.decl_func("asn1Scc%s_Equal" % name, self.i1, [struct_ptr, struct_ptr])
return struct.llty
def _lltype_of_sequence(self, name, asn1ty):
''' Return the LL type of a Sequence ASN.1 type '''
field_names = []
field_lltys = []
for field_name in Helper.sorted_fields(asn1ty):
field_names.append(field_name.replace('-', '_'))
field_lltys.append(self.lltype_of(asn1ty.Children[field_name].type))
struct = self.decl_struct(field_names, field_lltys, name)
struct_ptr = core.Type.pointer(struct.llty)
self.decl_func("asn1Scc%s_Equal" % name, self.i1, [struct_ptr, struct_ptr])
return struct.llty
def _lltype_of_choice(self, name, asn1ty):
''' Return the equivalent LL type of a Choice ASN.1 type '''
field_names = []
field_lltys = []
for idx, field_name in enumerate(Helper.sorted_fields(asn1ty)):
# enum values used in choice determinant/present
self.enums[field_name.replace('-', '_')] = core.Constant.int(self.i32, idx)
field_names.append(field_name.replace('-', '_'))
field_lltys.append(self.lltype_of(asn1ty.Children[field_name].type))
union = self.decl_union(field_names, field_lltys, name)
union_ptr = core.Type.pointer(union.llty)
self.decl_func("asn1Scc%s_Equal" % name, self.i1, [union_ptr, union_ptr])
return union.llty
def _lltype_of_octetstring(self, name, asn1ty):
''' Return the equivalent LL type of a OctetString ASN.1 type '''
min_size = int(asn1ty.Min)
max_size = int(asn1ty.Max)
is_variable_size = min_size != max_size
array_llty = core.Type.array(self.i8, max_size)
if is_variable_size:
struct = self.decl_struct(['nCount', 'arr'], [self.i32, array_llty], name)
else:
struct = self.decl_struct(['arr'], [array_llty], name)
struct_ptr = core.Type.pointer(struct.llty)
self.decl_func("asn1Scc%s_Equal" % name, self.i1, [struct_ptr, struct_ptr])
return struct.llty
def string_ptr(self, str):
''' Returns a pointer to a global string with the given value '''
if str in self.strings:
return self.strings[str].gep([self.zero, self.zero])
str_val = core.Constant.stringz(str)
var_name = '.str%s' % len(self.strings)
var_ptr = self.module.add_global_variable(str_val.type, var_name)
var_ptr.initializer = str_val
self.strings[str] = var_ptr
return var_ptr.gep([self.zero, self.zero])
def decl_func(self, name, return_llty, param_lltys, extern=False):
''' Declare a function '''
func_llty = core.Type.function(return_llty, param_lltys)
func_name = ("%s_RI_%s" % (self.name, name)) if extern else name
func = core.Function.new(self.module, func_llty, func_name)
self.funcs[name.lower()] = func
return func
def decl_struct(self, field_names, field_lltys, name=None):
''' Declare a struct '''
name = name if name else "struct.%s" % len(self.structs)
name = name.replace('-', '_')
struct = StructType(name, field_names, field_lltys)
self.structs[name] = struct
return struct
def resolve_struct(self, name):
''' Return the struct associated to a name '''
return self.structs[name.replace('-', '_')]
def decl_union(self, field_names, field_lltys, name=None):
name = name if name else "union.%s" % len(self.structs)
name = name.replace('-', '_')
union = UnionType(name, field_names, field_lltys, self)
self.unions[name] = union
return union
def resolve_union(self, name):
''' Return the union associated to a name '''
return self.unions[name.replace('-', '_')]
class StructType():
def __init__(self, name, field_names, field_lltys):
self.name = name
self.field_names = field_names
self.llty = core.Type.struct(field_lltys, self.name)
def idx(self, field_name):
return self.field_names.index(field_name)
class UnionType():
def __init__(self, name, field_names, field_lltys, ctx):
self.name = name
self.field_names = field_names
self.field_lltys = field_lltys
# Unions are represented a struct with a field indicating the index of its type
# and a byte array with the size of the biggest type in the union
self.size = max([ctx.target_data.size(ty) for ty in field_lltys])
self.llty = core.Type.struct([ctx.i32, core.Type.array(ctx.i8, self.size)], name)
def kind(self, name):
idx = self.field_names.index(name)
return (idx, self.field_lltys[idx])
class Scope:
def __init__(self, ctx, parent=None):
self.ctx = ctx
self.vars = {}
self.labels = {}
self.parent = parent
def define(self, name, var):
self.vars[name.lower()] = var
def resolve(self, name):
var = self.vars.get(name.lower())
if var:
return var
if self.parent:
return self.parent.resolve(name)
else:
raise NameError("name '%s' is not defined" % name)
def label(self, name):
name = name.lower()
label_block = self.labels.get(name)
if not label_block:
func = self.ctx.builder.basic_block.function
label_block = func.append_basic_block('label:%s' % name)
self.labels[name] = label_block
return label_block
class CompileError(Exception):
pass
@singledispatch
def generate(ast):
''' Generate the code for an item of the AST '''
raise TypeError('[Backend] Unsupported AST construct')
def generate(ast, ctx=None):
''' Generate the IR for an AST node '''
raise CompileError('Unsupported AST construct "%s"' % ast.__class__.__name__)
# Processing of the AST
# Processing of the AST
@generate.register(ogAST.Process)
def _process(process):
''' Generate LLVM IR code (incomplete) '''
process_name = process.processName
LOG.info('Generating LLVM IR code for process ' + str(process_name))
def _process(process, ctx=None):
''' Generate the IR for a process '''
process_name = str(process.processName)
LOG.info('Generating LLVM IR code for process ' + process_name)
ctx = Context(process)
# In case model has nested states, flatten everything
Helper.flatten(process)
......@@ -64,241 +355,1565 @@ def _process(process):
# generate the lookup tables for the state machine runtime
mapping = Helper.map_input_state(process)
# Initialise LLVM global structure
LLVM['module'] = core.Module.new(str(process_name))
LLVM['pass_manager'] = passes.FunctionPassManager.new(LLVM['module'])
LLVM['executor'] = ee.ExecutionEngine.new(LLVM['module'])
# Set up the optimizer pipeline.
# Start with registering info about how the
# target lays out data structures.
# LLVM['pass_manager'].add(LLVM['executor'].target_data)
# # Do simple "peephole" optimizations and bit-twiddling optzns.
# LLVM['pass_manager'].add(passes.PASS_INSTRUCTION_COMBINING)
# # Reassociate expressions.
# LLVM['pass_manager'].add(passes.PASS_REASSOCIATE)
# # Eliminate Common SubExpressions.
# LLVM['pass_manager'].add(passes.PASS_GVN)
# # Simplify the control flow graph (deleting unreachable blocks, etc).
# LLVM['pass_manager'].add(passes.PASS_CFG_SIMPLIFICATION)
# LLVM['pass_manager'].initialize()
# Create the runTransition function
run_funct_name = 'run_transition'
run_funct_type = core.Type.function(core.Type.void(), [
core.Type.int()])
run_funct = core.Function.new(
LLVM['module'], run_funct_type, run_funct_name)
# Generate the code of the start transition:
# Clear scope
LLVM['named_values'].clear()
# Create the function name and type
funct_name = str(process_name) + '_startup'
funct_type = core.Type.function(core.Type.void(), [])
# Create a function object
function = core.Function.new(LLVM['module'], funct_type, funct_name)
# Create a new basic block to start insertion into.
block = function.append_basic_block('entry')
builder = core.Builder.new(block)
# Add the body of the function
builder.call(run_funct, (core.Constant.int(
core.Type.int(), 0),))
# Add terminator (mandatory)
builder.ret_void()
# Validate the generated code, checking for consistency.
function.verify()
# Optimize the function (not yet).
# LLVM['pass_manager'].run(function)
print function
def write_statement(param, newline):
''' Generate the code for the special "write" operator '''
pass
# Initialize states
for name, val in process.mapping.viewitems():
if not name.endswith('START'):
cons_val = core.Constant.int(ctx.i32, len(ctx.states))
ctx.states[name.lower()] = cons_val
elif name != 'START':
cons_val = core.Constant.int(ctx.i32, val)
ctx.states[name.lower()] = cons_val
# Generate state var
state_cons = ctx.module.add_global_variable(ctx.i32, '.state')
state_cons.initializer = core.Constant.int(ctx.i32, -1)
ctx.scope.define('.state', state_cons)
# Generare process-level vars
for name, (asn1ty, expr) in process.variables.viewitems():
var_llty = ctx.lltype_of(asn1ty)
global_var = ctx.module.add_global_variable(var_llty, str(name))
global_var.initializer = core.Constant.null(var_llty)
ctx.scope.define(str(name).lower(), global_var)
# Declare set/reset timer functions
for timer in process.timers:
# TODO: Should be uint?
ctx.decl_func("set_%s" % str(timer), ctx.void, [ctx.i64_ptr], True)
ctx.decl_func("reset_%s" % str(timer), ctx.void, [], True)
# Declare output signal functions
for signal in process.output_signals:
if 'type' in signal:
param_lltys = [core.Type.pointer(ctx.lltype_of(signal['type']))]
else:
param_lltys = []
ctx.decl_func(str(signal['name']), ctx.void, param_lltys, True)
# Declare external procedures functions
for proc in [proc for proc in process.procedures if proc.external]:
param_lltys = [core.Type.pointer(ctx.lltype_of(p['type'])) for p in proc.fpar]
ctx.decl_func(str(proc.inputString), ctx.void, param_lltys, True)
# Generate internal procedures
for proc in process.content.inner_procedures:
generate(proc, ctx)
# Generate process functions
generate_runtr_func(process, ctx)
generate_startup_func(process, ctx)
# Generate input signals
for signal in process.input_signals:
generate_input_signal(signal, mapping[signal['name']], ctx)
# Generate timer signal
for timer in process.timers:
generate_input_signal({'name': timer.lower()}, mapping[timer], ctx)
ctx.module.verify()
with open(ctx.name + '.ll', 'w') as ll_file:
ll_file.write(str(ctx.module))
def generate_runtr_func(process, ctx):
''' Generate the IR for the run_transition function '''
func = ctx.decl_func('run_transition', ctx.void, [ctx.i32])
ctx.open_scope()
entry_block = func.append_basic_block('runtr:entry')
cond_block = func.append_basic_block('runtr:cond')
body_block = func.append_basic_block('runtr:body')
exit_block = func.append_basic_block('runtr:exit')
ctx.builder = core.Builder.new(entry_block)
# entry
id_ptr = ctx.builder.alloca(ctx.i32, None, 'id')
ctx.scope.define('id', id_ptr)
ctx.builder.store(func.args[0], id_ptr)
ctx.builder.branch(cond_block)
# cond
ctx.builder.position_at_end(cond_block)
no_tr_cons = core.Constant.int(ctx.i32, -1)
id_val = ctx.builder.load(id_ptr)
cond_val = ctx.builder.icmp(core.ICMP_NE, id_val, no_tr_cons, 'cond')
ctx.builder.cbranch(cond_val, body_block, exit_block)
# body
ctx.builder.position_at_end(body_block)
switch = ctx.builder.switch(id_val, exit_block)
# transitions
for idx, tr in enumerate(process.transitions):
tr_block = func.append_basic_block('runtr:tr%d' % idx)
const = core.Constant.int(ctx.i32, idx)
switch.add_case(const, tr_block)
ctx.builder.position_at_end(tr_block)
generate(tr, ctx)
if not ctx.builder.basic_block.terminator:
ctx.builder.branch(cond_block)
# exit
ctx.builder.position_at_end(exit_block)
ctx.builder.ret_void()
Helper.inner_labels_to_floating(process)
for label in process.content.floating_labels:
generate(label, ctx)
next_tr_label_block = ctx.scope.label('next_transition')
ctx.builder.position_at_end(next_tr_label_block)
ctx.builder.branch(cond_block)
ctx.close_scope()
func.verify()
return func
def generate_startup_func(process, ctx):
''' Generate the IR for the startup function '''
func = ctx.decl_func(ctx.name + '_startup', ctx.void, [])
ctx.open_scope()
entry_block = func.append_basic_block('startup:entry')
ctx.builder = core.Builder.new(entry_block)
# Initialize process level variables
for name, (ty, expr) in process.variables.viewitems():
if expr:
global_var = ctx.scope.resolve(str(name))
sdl_assign(global_var, expression(expr, ctx), ctx)
sdl_call('run_transition', [core.Constant.int(ctx.i32, 0)], ctx)
ctx.builder.ret_void()
ctx.close_scope()
func.verify()
return func
def generate_input_signal(signal, inputs, ctx):
''' Generate the IR for an input signal '''
func_name = ctx.name + "_" + str(signal['name'])
param_lltys = []
if 'type' in signal:
param_lltys.append(core.Type.pointer(ctx.lltype_of(signal['type'])))
func = ctx.decl_func(func_name, ctx.void, param_lltys)
ctx.open_scope()
entry_block = func.append_basic_block('input:entry')
exit_block = func.append_basic_block('input:exit')
ctx.builder = core.Builder.new(entry_block)
g_state_val = ctx.builder.load(ctx.global_scope.resolve('.state'))
switch = ctx.builder.switch(g_state_val, exit_block)
for state_name, state_id in ctx.states.iteritems():
if state_name.endswith('start'):
continue
state_block = func.append_basic_block('input:state_%s' % str(state_name))
switch.add_case(state_id, state_block)
ctx.builder.position_at_end(state_block)
# TODO: Nested states
input = inputs.get(state_name)
if input:
for var_name in input.parameters:
var_ptr = ctx.scope.resolve(str(var_name))
if is_struct_ptr(var_ptr) or is_array_ptr(var_ptr):
sdl_assign(var_ptr, func.args[0], ctx)
else:
sdl_assign(var_ptr, ctx.builder.load(func.args[0]), ctx)
if input.transition:
id_val = core.Constant.int(ctx.i32, input.transition_id)
sdl_call('run_transition', [id_val], ctx)
ctx.builder.ret_void()
ctx.builder.position_at_end(exit_block)
ctx.builder.ret_void()
ctx.close_scope()
func.verify()
@generate.register(ogAST.Output)