Commit 8336bd84 authored by dbarbera's avatar dbarbera
Browse files

Refactor generate_type function into type_of context method

parent bdbf7d1f
......@@ -48,6 +48,7 @@ class Context():
self.unions = {}
self.strings = {}
self.funcs = {}
self.lltypes = {}
# Initialize built-in types
self.i1 = core.Type.int(1)
......@@ -94,6 +95,87 @@ class Context():
def close_scope(self):
self.scope = self.scope.parent
def type_of(self, asn1ty):
''' Return the LL type of a ASN.1 type '''
try:
name = asn1ty.ReferencedTypeName.replace('-', '_')
except AttributeError:
name = None
if name and name in self.lltypes:
return self.lltypes[name]
basic_asn1ty = find_basic_type(asn1ty)
if basic_asn1ty.kind in ['IntegerType', 'Integer32Type']:
llty = ctx.i32
elif basic_asn1ty.kind == 'BooleanType':
llty = ctx.i1
elif basic_asn1ty.kind == 'RealType':
llty = ctx.double
elif basic_asn1ty.kind == 'SequenceOfType':
llty = self._type_of_sequenceof(name, basic_asn1ty)
elif basic_asn1ty.kind == 'SequenceType':
llty = self._type_of_sequence(name, basic_asn1ty)
elif basic_asn1ty.kind == 'EnumeratedType':
llty = ctx.i32
elif basic_asn1ty.kind == 'ChoiceType':
llty = self._type_of_choice(name, basic_asn1ty)
elif basic_asn1ty.kind == 'OctetStringType':
llty = self._type_of_octetstring(name, basic_asn1ty)
else:
raise NotImplementedError
if name:
self.lltypes[name] = llty
return llty
def _type_of_sequenceof(self, name, sequenceof_ty):
''' Return the LL type of a SequenceOf ASN.1 type '''
min_size = int(sequenceof_ty.Min)
max_size = int(sequenceof_ty.Max)
is_variable_size = min_size != max_size
elem_ty = self.type_of(sequenceof_ty.type)
array_ty = core.Type.array(elem_ty, max_size)
if is_variable_size:
struct = decl_struct(['nCount', 'arr'], [ctx.i32, array_ty], name)
else:
struct = decl_struct(['arr'], [array_ty], name)
return struct.ty
def _type_of_sequence(self, name, sequence_ty):
''' Return the LL type of a Sequence ASN.1 type '''
field_names = []
field_types = []
for field_name in Helper.sorted_fields(sequence_ty):
field_names.append(field_name.replace('-', '_'))
field_types.append(self.type_of(sequence_ty.Children[field_name].type))
struct = decl_struct(field_names, field_types, name)
return struct.ty
def _type_of_choice(self, name, choice_ty):
''' Return the equivalent LL type of a Choice ASN.1 type '''
field_names = []
field_types = []
for name, t in choice_ty.Children.viewitems():
field_names.append(name)
field_types.append(self.type_of(t.type))
union = decl_union(field_names, field_types, name)
return union.ty
def _type_of_octetstring(self, name, octetstring_ty):
''' Return the equivalent LL type of a OcterString ASN.1 type '''
max_size = int(octetstring_ty.Max)
arr_ty = core.Type.array(ctx.i8, max_size)
struct = decl_struct(['nCount', 'arr'], [ctx.i32, arr_ty], name)
return struct.ty
class StructType():
def __init__(self, name, field_names, field_types):
......@@ -187,7 +269,7 @@ def _process(process):
# Generare process-level vars
for name, (ty, expr) in process.variables.viewitems():
var_ty = generate_type(ty)
var_ty = ctx.type_of(ty)
global_var = ctx.module.add_global_variable(var_ty, str(name))
global_var.initializer = core.Constant.null(var_ty)
ctx.scope.define(str(name).lower(), global_var)
......@@ -201,14 +283,14 @@ def _process(process):
# Declare output signal functions
for signal in process.output_signals:
if 'type' in signal:
param_tys = [core.Type.pointer(generate_type(signal['type']))]
param_tys = [core.Type.pointer(ctx.type_of(signal['type']))]
else:
param_tys = []
decl_func(str(signal['name']), ctx.void, param_tys, True)
# Declare external procedures functions
for proc in [proc for proc in process.procedures if proc.external]:
param_tys = [core.Type.pointer(generate_type(p['type'])) for p in proc.fpar]
param_tys = [core.Type.pointer(ctx.type_of(p['type'])) for p in proc.fpar]
decl_func(str(proc.inputString), ctx.void, param_tys, True)
# Generate internal procedures
......@@ -317,7 +399,7 @@ def generate_input_signal(signal, inputs):
func_name = ctx.name + "_" + str(signal['name'])
param_tys = []
if 'type' in signal:
param_tys.append(core.Type.pointer(generate_type(signal['type'])))
param_tys.append(core.Type.pointer(ctx.type_of(signal['type'])))
func = decl_func(func_name, ctx.void, param_tys)
......@@ -1031,7 +1113,7 @@ def _if_then_else(ifthen):
else_block = func.append_basic_block('ternary:else')
end_block = func.append_basic_block('')
res_ptr = ctx.builder.alloca(generate_type(ifthen.exprType))
res_ptr = ctx.builder.alloca(ctx.type_of(ifthen.exprType))
cond_val = expression(ifthen.value['if'])
ctx.builder.cbranch(cond_val, if_block, else_block)
......@@ -1066,7 +1148,7 @@ def _sequence(seq):
def _sequence_of(seqof):
''' Generate the code for an ASN.1 SEQUENCE OF '''
basic_ty = find_basic_type(seqof.exprType)
ty = generate_type(seqof.exprType)
ty = ctx.type_of(seqof.exprType)
struct_ptr = ctx.builder.alloca(ty)
is_variable_size = basic_ty.Min != basic_ty.Max
......@@ -1266,7 +1348,7 @@ def _floating_label(label):
@generate.register(ogAST.Procedure)
def _inner_procedure(proc):
''' Generate the code for a procedure '''
param_tys = [core.Type.pointer(generate_type(p['type'])) for p in proc.fpar]
param_tys = [core.Type.pointer(ctx.type_of(p['type'])) for p in proc.fpar]
func = decl_func(str(proc.inputString), ctx.void, param_tys)
if proc.external:
......@@ -1281,7 +1363,7 @@ def _inner_procedure(proc):
ctx.builder = core.Builder.new(entry_block)
for name, (ty, expr) in proc.variables.viewitems():
var_ty = generate_type(ty)
var_ty = ctx.type_of(ty)
var_ptr = ctx.builder.alloca(var_ty)
ctx.scope.define(name, var_ptr)
if expr:
......@@ -1302,95 +1384,6 @@ def _inner_procedure(proc):
func.verify()
def generate_type(ty):
''' Generate the equivalent LLVM type of a ASN.1 type '''
basic_ty = find_basic_type(ty)
try:
name = ty.ReferencedTypeName.replace('-', '_')
except AttributeError:
name = None
if basic_ty.kind in ['IntegerType', 'Integer32Type']:
return ctx.i32
elif basic_ty.kind == 'BooleanType':
return ctx.i1
elif basic_ty.kind == 'RealType':
return ctx.double
elif basic_ty.kind == 'SequenceOfType':
return generate_sequenceof_type(name, basic_ty)
elif basic_ty.kind == 'SequenceType':
return generate_sequence_type(name, basic_ty)
elif basic_ty.kind == 'EnumeratedType':
return ctx.i32
elif basic_ty.kind == 'ChoiceType':
return generate_choice_type(name, basic_ty)
elif basic_ty.kind == 'OctetStringType':
return generate_octetstring_type(name, basic_ty)
else:
raise NotImplementedError
def generate_sequenceof_type(name, sequenceof_ty):
''' Generate the equivalent LLVM type of a SequenceOf type '''
if name and name in ctx.structs:
return ctx.structs[name].ty
min_size = int(sequenceof_ty.Min)
max_size = int(sequenceof_ty.Max)
is_variable_size = min_size != max_size
elem_ty = generate_type(sequenceof_ty.type)
array_ty = core.Type.array(elem_ty, max_size)
if is_variable_size:
struct = decl_struct(['nCount', 'arr'], [ctx.i32, array_ty], name)
else:
struct = decl_struct(['arr'], [array_ty], name)
return struct.ty
def generate_sequence_type(name, sequence_ty):
''' Generate the equivalent LLVM type of a Sequence type '''
if name in ctx.structs:
return ctx.structs[name].ty
field_names = []
field_types = []
for field_name in Helper.sorted_fields(sequence_ty):
field_names.append(field_name.replace('-', '_'))
field_types.append(generate_type(sequence_ty.Children[field_name].type))
struct = decl_struct(field_names, field_types, name)
return struct.ty
def generate_choice_type(name, choice_ty):
''' Generate the equivalent LLVM type of a Choice type '''
if name in ctx.unions:
return ctx.unions[name].ty
field_names = []
field_types = []
for name, t in choice_ty.Children.viewitems():
field_names.append(name)
field_types.append(generate_type(t.type))
union = decl_union(field_names, field_types, name)
return union.ty
def generate_octetstring_type(name, octetstring_ty):
''' Generate the equivalent LLVM type of a OcterString type '''
if name in ctx.structs:
return ctx.structs[name].ty
max_size = int(octetstring_ty.Max)
arr_ty = core.Type.array(ctx.i8, max_size)
struct = decl_struct(['nCount', 'arr'], [ctx.i32, arr_ty], name)
return struct.ty
def get_string_cons(str):
''' Returns a reference to a global string constant with the given value '''
if str in ctx.strings:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment