Commit eb216185 authored by dbarbera's avatar dbarbera
Browse files

Refactor get_string_constant function into string_ptr context method

parent 078ca0b3
...@@ -178,6 +178,18 @@ class Context(): ...@@ -178,6 +178,18 @@ class Context():
struct = decl_struct(['nCount', 'arr'], [ctx.i32, arr_ty], name) struct = decl_struct(['nCount', 'arr'], [ctx.i32, arr_ty], name)
return struct.ty return struct.ty
def string_ptr(self, str):
''' Returns a pointer to a global string with the given value '''
if str in self.strings:
return self.strings[str].gep([self.zero, self.zero])
str_val = core.Constant.stringz(str)
var_name = '.str%s' % len(self.strings)
var_ptr = self.module.add_global_variable(str_val.type, var_name)
var_ptr.initializer = str_val
self.strings[str] = var_ptr
return var_ptr.gep([self.zero, self.zero])
class StructType(): class StructType():
def __init__(self, name, field_names, field_types): def __init__(self, name, field_names, field_types):
...@@ -488,25 +500,20 @@ def generate_write(params): ...@@ -488,25 +500,20 @@ def generate_write(params):
expr_val = expression(param) expr_val = expression(param)
if basic_ty.kind in ['IntegerType', 'Integer32Type']: if basic_ty.kind in ['IntegerType', 'Integer32Type']:
fmt_val = get_string_cons('% d') fmt_str_ptr = ctx.string_ptr('% d')
fmt_ptr = ctx.builder.gep(fmt_val, [ctx.zero, ctx.zero]) ctx.builder.call(ctx.funcs['printf'], [fmt_str_ptr, expr_val])
ctx.builder.call(ctx.funcs['printf'], [fmt_ptr, expr_val])
elif basic_ty.kind == 'RealType': elif basic_ty.kind == 'RealType':
fmt_val = get_string_cons('% .14E') fmt_str_ptr = ctx.string_ptr('% .14E')
fmt_ptr = ctx.builder.gep(fmt_val, [ctx.zero, ctx.zero]) ctx.builder.call(ctx.funcs['printf'], [fmt_str_ptr, expr_val])
ctx.builder.call(ctx.funcs['printf'], [fmt_ptr, expr_val])
elif basic_ty.kind == 'BooleanType': elif basic_ty.kind == 'BooleanType':
true_str_val = get_string_cons('TRUE') true_str_ptr = ctx.string_ptr('TRUE')
true_str_ptr = ctx.builder.gep(true_str_val, [ctx.zero, ctx.zero]) false_str_ptr = ctx.string_ptr('FALSE')
false_str_val = get_string_cons('FALSE')
false_str_ptr = ctx.builder.gep(false_str_val, [ctx.zero, ctx.zero])
str_ptr = ctx.builder.select(expr_val, true_str_ptr, false_str_ptr) str_ptr = ctx.builder.select(expr_val, true_str_ptr, false_str_ptr)
ctx.builder.call(ctx.funcs['printf'], [str_ptr]) ctx.builder.call(ctx.funcs['printf'], [str_ptr])
elif basic_ty.kind in ['StringType', 'OctetStringType']: elif basic_ty.kind in ['StringType', 'OctetStringType']:
fmt_val = get_string_cons('%s') fmt_str_ptr = ctx.string_ptr('%s')
fmt_ptr = ctx.builder.gep(fmt_val, [ctx.zero, ctx.zero])
arr_ptr = ctx.builder.gep(expr_val, [ctx.zero, ctx.one]) arr_ptr = ctx.builder.gep(expr_val, [ctx.zero, ctx.one])
ctx.builder.call(ctx.funcs['printf'], [fmt_ptr, arr_ptr]) ctx.builder.call(ctx.funcs['printf'], [fmt_str_ptr, arr_ptr])
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -515,9 +522,7 @@ def generate_writeln(params): ...@@ -515,9 +522,7 @@ def generate_writeln(params):
''' Generate the code for the writeln operator ''' ''' Generate the code for the writeln operator '''
generate_write(params) generate_write(params)
zero = core.Constant.int(ctx.i32, 0) str_ptr = ctx.string_ptr('\n')
str_cons = get_string_cons('\n')
str_ptr = ctx.builder.gep(str_cons, [zero, zero])
ctx.builder.call(ctx.funcs['printf'], [str_ptr]) ctx.builder.call(ctx.funcs['printf'], [str_ptr])
...@@ -1065,8 +1070,7 @@ def _empty_string(primary): ...@@ -1065,8 +1070,7 @@ def _empty_string(primary):
@expression.register(ogAST.PrimStringLiteral) @expression.register(ogAST.PrimStringLiteral)
def _string_literal(primary): def _string_literal(primary):
''' Generate code for a string (Octet String) ''' ''' Generate code for a string (Octet String) '''
str_val = get_string_cons(str(primary.value[1:-1])) str_ptr = ctx.string_ptr(str(primary.value[1:-1]))
str_ptr = ctx.builder.gep(str_val, [ctx.zero, ctx.zero])
# Allocate anonymous OctetString struct # Allocate anonymous OctetString struct
str_len = len(str(primary.value[1:-1])) + 1 str_len = len(str(primary.value[1:-1])) + 1
...@@ -1386,19 +1390,6 @@ def _inner_procedure(proc): ...@@ -1386,19 +1390,6 @@ def _inner_procedure(proc):
func.verify() func.verify()
def get_string_cons(str):
''' Returns a reference to a global string constant with the given value '''
if str in ctx.strings:
return ctx.strings[str]
str_val = core.Constant.stringz(str)
gvar_name = '.str%s' % len(ctx.strings)
gvar_val = ctx.module.add_global_variable(str_val.type, gvar_name)
gvar_val.initializer = str_val
ctx.strings[str] = gvar_val
return gvar_val
def decl_func(name, return_ty, param_tys, extern=False): def decl_func(name, return_ty, param_tys, extern=False):
''' Declare a function ''' ''' Declare a function '''
func_ty = core.Type.function(return_ty, param_tys) func_ty = core.Type.function(return_ty, param_tys)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment