Commit 1a45a200 authored by dbarbera's avatar dbarbera
Browse files

Refactor find_basic_type function into ctx.basic_type_of

parent a7f94aec
......@@ -94,6 +94,17 @@ class Context():
''' Close the current scope '''
self.scope = self.scope.parent
def basic_type_of(self, asn1ty):
''' Return the ASN.1 basic type of a type '''
basic_type = asn1ty
while basic_type.kind == 'ReferenceType':
# Find type with proper case in the data view
for typename in self.dataview.viewkeys():
if typename.lower() == basic_type.ReferencedTypeName.lower():
basic_type = self.dataview[typename].type
break
return basic_type
def type_of(self, asn1ty):
''' Return the LL type of a ASN.1 type '''
try:
......@@ -104,7 +115,7 @@ class Context():
if name and name in self.lltypes:
return self.lltypes[name]
basic_asn1ty = find_basic_type(asn1ty, self)
basic_asn1ty = self.basic_type_of(asn1ty)
if basic_asn1ty.kind == 'IntegerType':
llty = self.i64
......@@ -563,7 +574,7 @@ def _call_external_function(output, ctx):
def generate_write(params, ctx):
''' Generate the IR for the write operator '''
for param in params:
basic_ty = find_basic_type(param.exprType, ctx)
basic_ty = ctx.basic_type_of(param.exprType)
expr_val = expression(param, ctx)
if basic_ty.kind in ['IntegerType', 'Integer32Type']:
......@@ -692,7 +703,7 @@ def generate_for_range(loop, ctx):
def generate_for_iterable(loop, ctx):
''' Generate the IR for a for x in iterable loop '''
seqof_asn1ty = find_basic_type(loop['list'].exprType, ctx)
seqof_asn1ty = ctx.basic_type_of(loop['list'].exprType)
is_variable_size = seqof_asn1ty.Min != seqof_asn1ty.Max
func = ctx.builder.basic_block.function
......@@ -823,7 +834,7 @@ def _expr_arith(expr, ctx):
left_val = expression(expr.left, ctx)
right_val = expression(expr.right, ctx)
expr_bty = find_basic_type(expr.exprType, ctx)
expr_bty = ctx.basic_type_of(expr.exprType)
if expr_bty.kind in ('IntegerType', 'Integer32Type'):
if expr.operand == '+':
......@@ -871,7 +882,7 @@ def _expr_rel(expr, ctx):
left_val = expression(expr.left, ctx)
right_val = expression(expr.right, ctx)
operands_bty = find_basic_type(expr.left.exprType, ctx)
operands_bty = ctx.basic_type_of(expr.left.exprType)
if operands_bty.kind in ('IntegerType', 'Integer32Type'):
if expr.operand == '<':
......@@ -983,7 +994,7 @@ def generate_assign(left, right, ctx):
@expression.register(ogAST.ExprXor)
def _expr_logic(expr, ctx):
''' Generate the IR for a logic expression '''
bty = find_basic_type(expr.exprType, ctx)
bty = ctx.basic_type_of(expr.exprType)
if expr.shortcircuit:
if bty.kind != 'BooleanType':
......@@ -1085,7 +1096,7 @@ def _expr_logic(expr, ctx):
@expression.register(ogAST.ExprNot)
def _expr_not(expr, ctx):
''' Generate the IR for a not expression '''
bty = find_basic_type(expr.exprType, ctx)
bty = ctx.basic_type_of(expr.exprType)
if bty.kind == 'BooleanType':
return ctx.builder.not_(expression(expr.expr, ctx))
......@@ -1136,7 +1147,7 @@ def _expr_not(expr, ctx):
@expression.register(ogAST.ExprAppend)
def _expr_append(expr, ctx):
''' Generate the IR for a append expression '''
bty = find_basic_type(expr.exprType, ctx)
bty = ctx.basic_type_of(expr.exprType)
if bty.kind in ('SequenceOfType', 'OctetStringType'):
res_ty = ctx.type_of(expr.exprType)
......@@ -1193,7 +1204,7 @@ def _expr_in(expr, ctx):
check_block = func.append_basic_block('in:check')
end_block = func.append_basic_block('in:end')
seq_asn1_ty = find_basic_type(expr.left.exprType, ctx)
seq_asn1_ty = ctx.basic_type_of(expr.left.exprType)
is_variable_size = seq_asn1_ty.Min != seq_asn1_ty.Max
......@@ -1264,7 +1275,7 @@ def _prim_index(prim, ctx):
@expression.register(ogAST.PrimSubstring)
def _prim_substring(prim, ctx):
''' Generate the IR for a substring expression '''
bty = find_basic_type(prim.exprType, ctx)
bty = ctx.basic_type_of(prim.exprType)
if bty.Min == bty.Max:
raise NotImplementedError
......@@ -1326,7 +1337,7 @@ def generate_length(params, ctx):
''' Generate the IR for the length operator '''
seq_ptr = reference(params[0], ctx)
bty = find_basic_type(params[0].exprType, ctx)
bty = ctx.basic_type_of(params[0].exprType)
if bty.Min != bty.Max:
len_ptr = ctx.builder.gep(seq_ptr, [ctx.zero, ctx.zero])
return ctx.builder.zext(ctx.builder.load(len_ptr), ctx.i64)
......@@ -1389,7 +1400,7 @@ def generate_num(params, ctx):
def _prim_enumerated_value(prim, ctx):
''' Generate the IR for an enumerated value '''
enumerant = prim.value[0].replace('_', '-')
basic_ty = find_basic_type(prim.exprType, ctx)
basic_ty = ctx.basic_type_of(prim.exprType)
return core.Constant.int(ctx.i32, basic_ty.EnumValues[enumerant].IntValue)
......@@ -1434,7 +1445,7 @@ def _prim_empty_string(prim, ctx):
@expression.register(ogAST.PrimStringLiteral)
def _prim_string_literal(prim, ctx):
''' Generate the IR for a string'''
bty = find_basic_type(prim.exprType, ctx)
bty = ctx.basic_type_of(prim.exprType)
str_len = len(str(prim.value[1:-1]))
str_ptr = ctx.string_ptr(str(prim.value[1:-1]))
......@@ -1535,7 +1546,7 @@ def _prim_sequence(prim, ctx):
@expression.register(ogAST.PrimSequenceOf)
def _prim_sequence_of(prim, ctx):
''' Generate the IR for an ASN.1 SEQUENCE OF '''
basic_ty = find_basic_type(prim.exprType, ctx)
basic_ty = ctx.basic_type_of(prim.exprType)
ty = ctx.type_of(prim.exprType)
struct_ptr = ctx.builder.alloca(ty)
......@@ -1792,16 +1803,3 @@ def is_struct_ptr(val):
def is_array_ptr(val):
return val.type.kind == core.TYPE_POINTER and val.type.pointee.kind == core.TYPE_ARRAY
# TODO: Refactor this into the helper module
def find_basic_type(a_type, ctx):
''' Return the ASN.1 basic type of a_type '''
basic_type = a_type
while basic_type.kind == 'ReferenceType':
# Find type with proper case in the data view
for typename in ctx.dataview.viewkeys():
if typename.lower() == basic_type.ReferencedTypeName.lower():
basic_type = ctx.dataview[typename].type
break
return basic_type
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment