cc: add support for structs

- now we keep track of expression type
- added '|' and '|=' operators
- implemented '->' operator
- minor cleanups
This commit is contained in:
Paul Mathieu 2021-03-13 15:42:39 -08:00
parent 3b56750a73
commit a3a67105eb
2 changed files with 178 additions and 30 deletions

View File

@ -64,7 +64,7 @@ initializer_list: "{" [init_list_field ("," init_list_field)* ","? ] "}"
| prec2_expr ">>=" prec14_expr | prec2_expr ">>=" prec14_expr
| prec2_expr "&=" prec14_expr | prec2_expr "&=" prec14_expr
| prec2_expr "^=" prec14_expr | prec2_expr "^=" prec14_expr
| prec2_expr "|=" prec14_expr | prec2_expr "|=" prec14_expr -> or_ass
?prec13_expr: prec12_expr ?prec13_expr: prec12_expr
| prec12_expr "?" prec13_expr ":" prec13_expr | prec12_expr "?" prec13_expr ":" prec13_expr
@ -74,7 +74,7 @@ initializer_list: "{" [init_list_field ("," init_list_field)* ","? ] "}"
?prec11_expr: prec10_expr ?prec11_expr: prec10_expr
| prec11_expr "&&" prec10_expr | prec11_expr "&&" prec10_expr
?prec10_expr: prec9_expr ?prec10_expr: prec9_expr
| prec10_expr "|" prec9_expr | prec10_expr "|" prec9_expr -> _or
?prec9_expr: prec8_expr ?prec9_expr: prec8_expr
| prec9_expr "^" prec8_expr | prec9_expr "^" prec8_expr
?prec8_expr: prec7_expr ?prec8_expr: prec7_expr
@ -127,7 +127,7 @@ litteral: SIGNED_NUMBER | ESCAPED_STRING | HEX_LITTERAL | CHARACTER
field: IDENTIFIER field: IDENTIFIER
identifier: IDENTIFIER identifier: IDENTIFIER
?symbol: IDENTIFIER ?symbol: IDENTIFIER
?type: type_qualifier* IDENTIFIER type: type_qualifier* IDENTIFIER
| struct_type | struct_type
| type "*" -> pointer | type "*" -> pointer
?array_size: INT ?array_size: INT

View File

@ -1,6 +1,7 @@
import argparse import argparse
import collections import collections
import contextlib import contextlib
from dataclasses import dataclass
import importlib import importlib
import io import io
import re import re
@ -19,13 +20,13 @@ class Scope:
def __init__(self): def __init__(self):
self.symbols = {} self.symbols = {}
self.parent = None self.parent = None
self.structs = {}
class Variable: class Variable:
def __init__(self, type, name, volatile=False, addr_reg=None): def __init__(self, type, name, addr_reg=None):
self.type = type self.type = type
self.name = name self.name = name
self.volatile = volatile
self.addr_reg = addr_reg self.addr_reg = addr_reg
def __repr__(self): def __repr__(self):
@ -33,18 +34,13 @@ class Variable:
@classmethod @classmethod
def from_def(cls, tree): def from_def(cls, tree):
volatile = False
type = tree.children[0] type = tree.children[0]
if isinstance(type, lark.Tree):
for c in type.children:
if c == "volatile":
volatile = True
name = tree.children[1] name = tree.children[1]
return cls(type, name, volatile=volatile) return cls(type, name)
@classmethod @classmethod
def from_dereference(cls, reg): def from_dereference(cls, reg, tree):
return cls('deref', 'deref', addr_reg=reg) return cls(tree.type.pointed, 'deref', addr_reg=reg)
# - all registers pointing to unwritten stuff can be dropped as soon as we're # - all registers pointing to unwritten stuff can be dropped as soon as we're
@ -213,6 +209,10 @@ class FunctionSpec:
params = ', '.join(self.param_types) params = ', '.join(self.param_types)
return f'<Function: {self.return_type} {self.name}({params})>' return f'<Function: {self.return_type} {self.name}({params})>'
@property
def type(self):
    """Describe this function as a FunType (return type plus parameter types)."""
    ret = self.return_type
    params = self.param_types
    return FunType(ret=ret, params=params)
class Function: class Function:
def __init__(self, fun_prot): def __init__(self, fun_prot):
@ -261,6 +261,49 @@ class Function:
f'{self.spec.name}:'] + indented f'{self.spec.name}:'] + indented
class Struct:
    """A struct definition: its name and its ordered list of fields.

    Built from a parsed struct definition node whose first child is the
    struct name and whose second child holds one node per field.
    """

    def __init__(self, struct_def):
        self.name = struct_def.children[0]
        members = []
        for member in struct_def.children[1].children:
            # each field node's children are passed straight to Variable
            members.append(Variable(*member.children))
        self.fields = members

    def offset(self, field):
        """Return the offset of `field`: twice its position in the struct.

        Raises ValueError when the struct has no field with that name.
        """
        names = [member.name for member in self.fields]
        position = names.index(field)
        return 2 * position

    def field_type(self, field):
        """Return the type of `field`, or None when there is no such field."""
        matches = (member.type for member in self.fields if member.name == field)
        return next(matches, None)
@dataclass
class CType:
    """Base class of the C type descriptions; holds the shared qualifiers."""

    volatile: bool = False
    const: bool = False


@dataclass
class StructType(CType):
    """A struct type, identified by the struct's name."""

    struct: str = ''


@dataclass
class PointerType(CType):
    """A pointer type."""

    # type of the object being pointed to
    pointed: CType = None


@dataclass
class FunType(CType):
    """A function type: return type plus parameter types."""

    ret: CType = None
    # Spelled as a string so it is a real annotation rather than the list
    # literal `[CType]`, and so it does not require a newer Python to evaluate.
    params: 'list[CType]' = None


@dataclass
class PodType(CType):
    """A plain scalar type, named by `pod` (e.g. 'int', 'char', 'float')."""

    pod: str = 'int'
    # Annotated so that it is an actual dataclass field: settable through the
    # constructor and taken into account by __eq__ and __repr__.
    signed: bool = True


@dataclass
class VoidType(CType):
    """The `void` type."""
class CcTransform(lark.visitors.Transformer): class CcTransform(lark.visitors.Transformer):
def _binary_op(litt): def _binary_op(litt):
@lark.v_args(tree=True) @lark.v_args(tree=True)
@ -272,8 +315,37 @@ class CcTransform(lark.visitors.Transformer):
return tree return tree
return _f return _f
def field(self, children):
    """Replace a `field` node by its single child.

    Raises ValueError unless there is exactly one child.
    """
    [only_child] = children
    return only_child
def struct_type(self, children):
    """Turn a `struct_type` node into a StructType naming the struct.

    Raises ValueError unless there is exactly one child (the struct name).
    """
    [struct_name] = children
    return StructType(struct=struct_name)
def pointer(self, children):
    """Turn a `pointer` node into a PointerType to its last child.

    The pointer is marked volatile when the string 'volatile' is among the
    children; no other qualifier is looked at here.
    """
    return PointerType(volatile='volatile' in children, pointed=children[-1])
def volatile(*_):
    """Collapse a `volatile` node into the plain string 'volatile'."""
    return 'volatile'
def const(*_):
    """Collapse a `const` node into the plain string 'const'."""
    return 'const'
def type(self, children):
    """Turn a `type` node into a type object.

    The last child is the type itself. When it is a string it names either
    `void` (giving a bare VoidType) or a scalar type (giving a PodType that
    carries the 'volatile' / 'const' qualifiers found among the children).
    Anything else is assumed to be an already-built type and is returned
    unchanged; the qualifiers are not applied to it.
    """
    typ = children[-1]
    if not isinstance(typ, str):
        return typ
    if typ == 'void':
        return VoidType()
    return PodType(
        volatile='volatile' in children,
        const='const' in children,
        pod=typ,
    )
def array_item(self, children): def array_item(self, children):
# transform blarg[foo] into *(blarg + foo) because reasons # transform blarg[foo] into *(blarg + foo) because reasons
# TODO: bring this over to the main parser, we need to know the type
# of blarg in order to compute the actual offset
addop = lark.Tree('add', children) addop = lark.Tree('add', children)
return lark.Tree('dereference', [addop]) return lark.Tree('dereference', [addop])
@ -335,7 +407,7 @@ MulOp = make_cpu_bin_op('mul')
# no div # no div
# no mod either # no mod either
AndOp = make_cpu_bin_op('and') AndOp = make_cpu_bin_op('and')
orOp = make_cpu_bin_op('or') OrOp = make_cpu_bin_op('or')
XorOp = make_cpu_bin_op('xor') XorOp = make_cpu_bin_op('xor')
class ShlOp(BinOp): class ShlOp(BinOp):
@ -634,6 +706,23 @@ class Delayed(AsmOp):
return [] return []
def litt_type(val):
    """Return the C type of a literal value.

    Strings are pointers to const char, floats are const float, and every
    other value is const int.
    """
    # NOTE(review): this relies on numeric literals already being Python
    # numbers when they get here; a raw token that is still a str would be
    # typed as a string literal -- confirm against the transformer.
    if isinstance(val, str):
        return PointerType(pointed=PodType(pod='char', const=True))
    pod_name = 'float' if isinstance(val, float) else 'int'
    return PodType(pod=pod_name, const=True)
def return_type(type):
    """Return the `ret` attribute (the return type) of a function type."""
    returned = type.ret
    return returned
def deref_type(type):
    """Return the `pointed` attribute (the pointee type) of a pointer type."""
    pointee = type.pointed
    return pointee
class CcInterp(lark.visitors.Interpreter): class CcInterp(lark.visitors.Interpreter):
def __init__(self): def __init__(self):
self.global_scope = Scope() self.global_scope = Scope()
@ -659,6 +748,14 @@ class CcInterp(lark.visitors.Interpreter):
self.next_marker += 1 self.next_marker += 1
return mark return mark
@contextlib.contextmanager
def _nosynth(self):
old = self._synth
self._synth = lambda *_: None
yield
self._synth = old
def _synth(self, op): def _synth(self, op):
with self.cur_fun.regs.borrow(op.scratch_need) as scratches: with self.cur_fun.regs.borrow(op.scratch_need) as scratches:
self._log(f'{op.__class__.__name__}') self._log(f'{op.__class__.__name__}')
@ -671,7 +768,7 @@ class CcInterp(lark.visitors.Interpreter):
reg = self._get_reg() reg = self._get_reg()
return SetAddr(self.cur_fun, [reg, ident]) return SetAddr(self.cur_fun, [reg, ident])
else: else:
if s.volatile: if s.type.volatile:
self._log(f'loading volatile {s}') self._log(f'loading volatile {s}')
return Load(self.cur_fun, [reg, s]) return Load(self.cur_fun, [reg, s])
reg, created = self.cur_fun.regs.load(s) reg, created = self.cur_fun.regs.load(s)
@ -695,6 +792,8 @@ class CcInterp(lark.visitors.Interpreter):
tree.op = Delayed(delayed_load) tree.op = Delayed(delayed_load)
tree.var = self._lookup_symbol(tree.children[0]) tree.var = self._lookup_symbol(tree.children[0])
assert tree.var is not None, f'undefined identifier: {tree.children[0]}'
tree.type = tree.var.type
def litteral(self, tree): def litteral(self, tree):
imm = tree.children[0] imm = tree.children[0]
@ -702,19 +801,14 @@ class CcInterp(lark.visitors.Interpreter):
assert self.cur_fun is not None assert self.cur_fun is not None
tree.op = Set16Imm(self.cur_fun, [reg, imm]) tree.op = Set16Imm(self.cur_fun, [reg, imm])
self._synth(tree.op) self._synth(tree.op)
tree.type = litt_type(tree.children[0])
def _assign(self, left, right): def _assign(self, left, right):
# need to make sure there's a variable and either:
# - to store it, or
# - to mark a register for it
if left.volatile:
self._synth(Store(self.cur_fun, [right, left]))
self.cur_fun.stored(left)
return
self._log(f'assigning {left} = {right}') self._log(f'assigning {left} = {right}')
# cleanup = lambda reg: self._synth(Store(self.cur_fun, [reg, left]))
self.cur_fun.regs.assign(left, right) self.cur_fun.regs.assign(left, right)
self._synth(Store(self.cur_fun, [right, left])) op = Store(self.cur_fun, [right, left])
self._synth(op)
return op.out
def assignment(self, tree): def assignment(self, tree):
self.visit_children(tree) self.visit_children(tree)
@ -722,7 +816,7 @@ class CcInterp(lark.visitors.Interpreter):
# left hand side is an lvalue, retrieve it # left hand side is an lvalue, retrieve it
var = tree.children[0].var var = tree.children[0].var
self._assign(var, val) tree.op = self._assign(var, val)
def global_var(self, tree): def global_var(self, tree):
self.visit_children(tree) self.visit_children(tree)
@ -763,7 +857,12 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.fun_calls += 1 self.cur_fun.fun_calls += 1
self.visit_children(tree) self.visit_children(tree)
fn_reg = tree.children[0].op.out fn_reg = tree.children[0].op.out
param_regs = [c.op.out for c in tree.children[1:]] if len(tree.children) == 1:
param_regs = []
elif tree.children[1].data == 'comma_expr':
param_regs = [c.op.out for c in tree.children[1].children]
else:
param_regs = [tree.children[1].op.out]
self.cur_fun.regs.flush_all() self.cur_fun.regs.flush_all()
for i in range(len(param_regs)): for i in range(len(param_regs)):
self.cur_fun.regs.take(f'r{i}') self.cur_fun.regs.take(f'r{i}')
@ -773,6 +872,7 @@ class CcInterp(lark.visitors.Interpreter):
self._synth(tree.op) self._synth(tree.op)
self.cur_fun.regs.reset() self.cur_fun.regs.reset()
self.cur_fun.regs.take('r0') self.cur_fun.regs.take('r0')
tree.type = return_type(tree.children[0].type)
def statement(self, tree): def statement(self, tree):
self.visit_children(tree) self.visit_children(tree)
@ -840,10 +940,9 @@ class CcInterp(lark.visitors.Interpreter):
self.cur_fun.regs.give(operand) self.cur_fun.regs.give(operand)
return _f return _f
def dereference(self, tree): def _deref(self, tree):
self.visit_children(tree)
reg = tree.children[0].op.out reg = tree.children[0].op.out
var = Variable.from_dereference(reg) var = Variable.from_dereference(reg, tree.children[0])
self._log(f'making var {var} from derefing reg {reg}') self._log(f'making var {var} from derefing reg {reg}')
tree.var = var tree.var = var
def delayed_load(): def delayed_load():
@ -852,6 +951,25 @@ class CcInterp(lark.visitors.Interpreter):
self._synth(op) self._synth(op)
return op.out return op.out
tree.op = Delayed(delayed_load) tree.op = Delayed(delayed_load)
tree.type = var.type
def dereference(self, tree):
    """Handle a dereference node: visit its operand, then apply `_deref`."""
    self.visit_children(tree)
    self._deref(tree)
def pointer_access(self, tree):
    """Handle `ptr->field` by rewriting it as a dereference of `ptr + offset`.

    The children are `(pointer expression, field name)`. The struct is found
    through the pointee type of the pointer expression; the node's children
    are then replaced by a single `add` node typed as a pointer to the field,
    and the node is finished by `_deref`.

    Raises KeyError when no visible scope defines the struct.
    """
    # Visit the children with _synth disabled; this pass is only needed for
    # the attributes the visit leaves on them (the pointer's type).
    with self._nosynth():
        self.visit_children(tree)
    ch_strct, ch_field = tree.children
    pointee = ch_strct.type.pointed
    # struct_def registers a struct in the scope that is current at that
    # point, so search from the current scope outwards instead of looking
    # only at the global scope.
    scope = self.cur_scope
    while scope is not None and pointee.struct not in scope.structs:
        scope = scope.parent
    if scope is None:
        # nothing found along the chain: fall back to the global table,
        # which also raises the KeyError for an unknown struct
        scope = self.global_scope
    strct = scope.structs[pointee.struct]
    offs = strct.offset(ch_field)
    load_offs = lark.Tree('litteral', [offs])
    addop = lark.Tree('add', [tree.children[0], load_offs])
    # this visit runs with _synth enabled again
    self.visit(addop)
    # override whatever type the visit assigned: this is the field's address
    addop.type = PointerType(pointed=strct.field_type(ch_field))
    tree.children = [addop]
    self._deref(tree)
def post_increment(self, tree): def post_increment(self, tree):
self.visit_children(tree) self.visit_children(tree)
@ -860,6 +978,7 @@ class CcInterp(lark.visitors.Interpreter):
reg = tree.op.out reg = tree.op.out
self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg])) self.cur_fun.deferred_ops.append(Incr(self.cur_fun, [reg, reg]))
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type
def post_decrement(self, tree): def post_decrement(self, tree):
self.visit_children(tree) self.visit_children(tree)
@ -868,6 +987,7 @@ class CcInterp(lark.visitors.Interpreter):
reg = tree.op.out reg = tree.op.out
self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg])) self.cur_fun.deferred_ops.append(Decr(self.cur_fun, [reg, reg]))
self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var])) self.cur_fun.deferred_ops.append(Store(self.cur_fun, [reg, var]))
tree.type = var.type
pre_increment = _unary_op(Incr) pre_increment = _unary_op(Incr)
bool_not = _unary_op(BoolNot) bool_not = _unary_op(BoolNot)
@ -881,6 +1001,7 @@ class CcInterp(lark.visitors.Interpreter):
self._synth(tree.op) self._synth(tree.op)
self.cur_fun.regs.give(left) self.cur_fun.regs.give(left)
self.cur_fun.regs.give(right) self.cur_fun.regs.give(right)
tree.type = tree.children[0].type # because uhm reasons
return _f return _f
def _combo(uop, bop): def _combo(uop, bop):
@ -895,12 +1016,33 @@ class CcInterp(lark.visitors.Interpreter):
sub = _binary_op(SubOp) sub = _binary_op(SubOp)
mul = _binary_op(MulOp) mul = _binary_op(MulOp)
_and = _binary_op(AndOp) _and = _binary_op(AndOp)
_or = _binary_op(OrOp)
# ... # ...
gt = _binary_op(GtOp) gt = _binary_op(GtOp)
lt = _binary_op(LtOp) lt = _binary_op(LtOp)
neq = _binary_op(NeqOp) neq = _binary_op(NeqOp)
eq = _combo(bool_not, 'neq') eq = _combo(bool_not, 'neq')
def _assign(self, left, right):
    """Assign register `right` to variable `left` and synthesize the store.

    Returns the `out` of the Store operation.
    """
    # NOTE(review): an identical `_assign` also appears earlier in this
    # class; being later, this is the definition that takes effect --
    # confirm and drop one of the two.
    self._log(f'assigning {left} = {right}')
    self.cur_fun.regs.assign(left, right)
    store = Store(self.cur_fun, [right, left])
    self._synth(store)
    return store.out
def _assign_op(op):
    """Build the visitor for a compound assignment such as `|=`.

    `op` is the name of the binary-operation visitor used to compute the
    value that is then assigned to the left operand.
    """
    def visit_assign_op(self, tree):
        # compute `lhs <op> rhs` by visiting a synthetic node that reuses
        # the operands of the assignment
        binary = lark.Tree(op, tree.children)
        self.visit(binary)
        result = binary.op.out
        # the left operand is an lvalue: store the result into its variable
        target = tree.children[0].var
        tree.op = self._assign(target, result)
    return visit_assign_op
or_ass = _assign_op('_or')
def _forward_op(self, tree): def _forward_op(self, tree):
self.visit_children(tree) self.visit_children(tree)
tree.op = tree.children[0].op tree.op = tree.children[0].op
@ -908,6 +1050,7 @@ class CcInterp(lark.visitors.Interpreter):
def cast(self, tree): def cast(self, tree):
self.visit_children(tree) self.visit_children(tree)
tree.op = tree.children[1].op tree.op = tree.children[1].op
tree.type = tree.children[0]
def _log(self, line): def _log(self, line):
self.cur_fun.log(line) self.cur_fun.log(line)
@ -957,7 +1100,7 @@ class CcInterp(lark.visitors.Interpreter):
if fun.spec.name == 'main': if fun.spec.name == 'main':
self._synth(Set16Imm(fun, ['r0', 0])) self._synth(Set16Imm(fun, ['r0', 0]))
self._synth(ReturnReg(fun, ['r0'])) self._synth(ReturnReg(fun, ['r0']))
elif fun.spec.return_type == 'void': elif fun.spec.return_type == VoidType():
self._synth(ReturnReg(fun, ['r0'])) self._synth(ReturnReg(fun, ['r0']))
else: else:
assert fun.ret is not None assert fun.ret is not None
@ -979,6 +1122,11 @@ class CcInterp(lark.visitors.Interpreter):
tree.op = ReturnReg(self.cur_fun, [expr_reg]) tree.op = ReturnReg(self.cur_fun, [expr_reg])
self._synth(tree.op) self._synth(tree.op)
def struct_def(self, tree):
    """Visit a struct definition and register it, by name, in the current scope."""
    self.visit_children(tree)
    definition = Struct(tree)
    self.cur_scope.structs[definition.name] = definition
preamble = [f'_start:', preamble = [f'_start:',
f'xor r0, r0, r0', f'xor r0, r0, r0',
f'xor r1, r1, r1', f'xor r1, r1, r1',