For a while I’ve wanted to build a typed spreadsheet application, and this weekend I started writing an interpreter for it using David Beazley’s PLY. So far it can store data and type information in the cells of a data store, and evaluate operations on numbers or on references to cells. It also supports limited type checking.
#!/usr/bin/python3
# interpreter.py
"""Interpreter for a small typed-spreadsheet language, built on PLY."""

import re

import ply.lex as lex
import ply.yacc as yacc

# Full-string form of a cell address such as "@(12,3)"; groups are row, col.
_ADDRESS_RE = re.compile(r'@\((\d+),(\d+)\)')


class DictStoragePlugin:
    """Storage backend that keeps cells in an in-memory dict keyed by address."""

    def __init__(self):
        self.store = {}

    def set(self, key, value):
        self.store[key] = value

    def get(self, key):
        # Raises KeyError for an unknown key; the parser relies on this to
        # detect cells that have not been created yet.
        return self.store[key]

    def keys(self):
        return self.store.keys()


class TerminalOutputPlugin:
    """Output backend that prints every stored cell to stdout."""

    def show(self, store):
        print()
        for key in store.keys():
            print("{} <- {}".format(key, store.get(key)))


class KeyValueStore:
    """Cell store that delegates persistence and display to plugins."""

    def __init__(self, storage_plugin, output_plugin):
        self.storage = storage_plugin
        self.display = output_plugin

    def set(self, key, value):
        self.storage.set(key, value)

    def get(self, key):
        return self.storage.get(key)

    def view(self):
        self.display.show(self.storage)


class Cell:
    """Base class for all cell types."""


class Currency(Cell):
    """Cell holding a monetary amount in one of the supported currencies."""

    CURRENCIES = ("USD", "EURO", "INR", "CNY")

    def __init__(self, currency):
        super().__init__()
        if currency not in self.CURRENCIES:
            raise RuntimeError("Unknown currency {}".format(currency))
        self.currency = currency
        self.data = None

    def __repr__(self):
        if self.data is None:
            return "({}, {})".format(self.data, self.currency)
        return "({:.2f}, {})".format(self.data, self.currency)


class _UnitCell(Cell):
    """Cell holding a value tagged with a free-form unit string."""

    def __init__(self, unit):
        super().__init__()
        self.unit = unit
        self.data = None

    def __repr__(self):
        return "({}, {})".format(self.data, self.unit)


class Distance(_UnitCell):
    """Cell holding a distance in the unit given to SET_TYPE_DISTANCE."""


class Area(_UnitCell):
    """Cell holding an area in the unit given to SET_TYPE_AREA."""


class Text(Cell):
    """Untyped cell whose data was inferred to be a string."""

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __repr__(self):
        return "({}, TEXT)".format(self.data)


class Scalar(Cell):
    """Untyped cell whose data was inferred to be a number."""

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __repr__(self):
        return "({}, SCALAR)".format(self.data)


# Cell types that keep their type on assignment and only accept numbers.
_TYPED_CELLS = (Currency, Distance, Area)


class Lexer:
    """PLY lexer for the spreadsheet language."""

    tokens = (
        'ADDRESS',
        'STRING',
        'NAME',
        'INT',
        'FLOAT',
        'PLUS',
        'MINUS',
        'TIMES',
        'DIVIDE',
        'ASSIGN',
        'LPAREN',
        'RPAREN',
        'COMMA',
        'RANGE',
    )

    # Cell reference such as @(12,3); row and column may have several digits.
    t_ADDRESS = r'@\(\d+,\d+\)'
    # Identifier (function name): a word character that is not a digit,
    # followed by any word characters.
    t_NAME = r'[^\W\d]\w*'
    t_PLUS = r'\+'
    t_MINUS = r'-'
    t_TIMES = r'\*'
    t_DIVIDE = r'\/'
    t_ASSIGN = r'='
    t_LPAREN = r'\('
    t_RPAREN = r'\)'
    t_COMMA = r','
    t_RANGE = r':'

    # Function rules are tried in definition order, so FLOAT must precede INT.
    def t_FLOAT(self, t):
        r'\d+\.\d*'
        t.value = float(t.value)
        return t

    def t_INT(self, t):
        r'\d+'
        t.value = int(t.value)
        return t

    def t_STRING(self, t):
        r'"(\\.|[^"\\])*"'
        # Drop the surrounding quotes and resolve the \" and \\ escapes;
        # any other backslash is kept literally.
        t.value = re.sub(r'\\(["\\])', r'\1', t.value[1:-1])
        return t

    t_ignore = " \t"

    def t_newline(self, t):
        r'\n+'
        t.lexer.lineno += t.value.count("\n")

    def t_error(self, t):
        # Best-effort recovery: report and skip the offending character.
        print("Illegal character '{}'".format(t.value[0]))
        t.lexer.skip(1)

    def __init__(self):
        self.lexer = lex.lex(module=self)


class Parser:
    """PLY parser that evaluates statements against a cell store.

    Expressions evaluate to a value that ``parse`` returns; assignments and
    SET_TYPE_* calls mutate the store and ``parse`` returns None for them.
    """

    @staticmethod
    def _parse_cell_address(x) -> tuple:
        """Convert an address string such as '@(2,3)' into a (row, col) tuple.

        Raises RuntimeError if ``x`` is not a well-formed address.
        """
        match = _ADDRESS_RE.fullmatch(x)
        if match is None:
            raise RuntimeError("Malformed cell address {!r}".format(x))
        return int(match.group(1)), int(match.group(2))

    @staticmethod
    def _build_cell_address(x) -> str:
        """Format a (row, col) pair as an address string such as '@(2,3)'."""
        return "@({},{})".format(x[0], x[1])

    def _range_addresses(self, first, last):
        """Return the addresses covered by the range ``first:last`` in order.

        Only single-row or single-column ranges are supported, and ``first``
        must not come after ``last``; otherwise RuntimeError is raised.
        """
        beg = self._parse_cell_address(first)
        end = self._parse_cell_address(last)
        if beg[0] == end[0]:
            # Same row: walk across the columns.
            if beg[1] > end[1]:
                raise RuntimeError("Reversed range {}:{}".format(first, last))
            return [self._build_cell_address((beg[0], col))
                    for col in range(beg[1], end[1] + 1)]
        if beg[1] == end[1]:
            # Same column: walk down the rows.
            if beg[0] > end[0]:
                raise RuntimeError("Reversed range {}:{}".format(first, last))
            return [self._build_cell_address((row, beg[1]))
                    for row in range(beg[0], end[0] + 1)]
        raise RuntimeError(
            "Range {}:{} must lie in a single row or column".format(first, last))

    tokens = Lexer.tokens

    # Declared explicitly so the grammar does not depend on rule order.
    start = 'statement'

    precedence = (
        ('left', 'PLUS', 'MINUS'),
        ('left', 'TIMES', 'DIVIDE'),
        ('right', 'UMINUS'),
    )

    # NOTE: PLY reads each p_ rule's docstring as its grammar production, so
    # explanations for the rules live in comments, never in the docstring.

    def p_statement_expr(self, p):
        '''statement : expression'''
        p[0] = p[1]

    def p_expression_string(self, p):
        '''expression : STRING'''
        p[0] = p[1]

    def p_expression_group(self, p):
        '''expression : LPAREN expression RPAREN'''
        p[0] = p[2]

    def p_expression_address(self, p):
        '''expression : ADDRESS'''
        # Resolve a cell reference to the data stored at that address
        # (KeyError if the cell was never created).
        p[0] = self.cells.get(p[1]).data

    def p_statement_cell_assignment(self, p):
        '''statement : ADDRESS ASSIGN expression'''
        # Typed cells (Currency, Distance, Area) keep their type and only
        # accept numeric data. Any other cell -- missing, Text or Scalar --
        # is (re)created with a type inferred from the value.
        address = p[1]
        data = p[3]
        try:
            cell = self.cells.get(address)
        except KeyError:
            cell = None
        if isinstance(cell, _TYPED_CELLS):
            if not isinstance(data, (int, float)):
                raise RuntimeError("Cannot store {!r} in a {} cell.".format(
                    data, type(cell).__name__))
            cell.data = data
        elif isinstance(data, str):
            self.cells.set(address, Text(data))
        elif isinstance(data, (int, float)):
            self.cells.set(address, Scalar(data))
        else:
            raise RuntimeError("Unable to resolve type.")

    def p_statement_single_cellfunc_call_with_argument(self, p):
        '''statement : NAME LPAREN ADDRESS COMMA STRING RPAREN'''
        # Cell functions build a new cell from the string argument; used to
        # give a single cell a special type, e.g. SET_TYPE_CURRENCY.
        func = self.cell_funcs[p[1]]
        self.cells.set(p[3], func(p[5]))

    def p_statement_single_cellfunc_call_without_argument(self, p):
        '''statement : NAME LPAREN ADDRESS RPAREN'''
        # Used for the INFO cell function: evaluates to a description of
        # the cell at the address.
        func = self.cell_funcs[p[1]]
        p[0] = func(self.cells.get(p[3]))

    def p_statement_range_cellfunc_call(self, p):
        '''statement : NAME LPAREN ADDRESS RANGE ADDRESS COMMA STRING RPAREN'''
        # Apply a cell function to every cell of a row/column range and
        # store the resulting cell at each address. The range is validated
        # in full before any cell is changed.
        func = self.cell_funcs[p[1]]
        argument = p[7]
        for address in self._range_addresses(p[3], p[5]):
            self.cells.set(address, func(argument))

    def p_statement_range_datafunc_call(self, p):
        '''expression : NAME LPAREN ADDRESS RANGE ADDRESS RPAREN'''
        # Collect the data of a row/column range and reduce it with a data
        # function such as SUM. Every cell must have the same type as the
        # first one (and, for Currency, the same currency), so amounts in
        # different currencies are never combined.
        func = self.data_funcs[p[1]]
        addresses = self._range_addresses(p[3], p[5])
        initial = self.cells.get(p[3])
        values = []
        for address in addresses:
            cell = self.cells.get(address)
            if type(cell) is not type(initial):
                raise RuntimeError("Inconsistent type.")
            if isinstance(initial, Currency) and cell.currency != initial.currency:
                raise RuntimeError("Inconsistent currency.")
            values.append(cell.data)
        p[0] = func(values)

    def p_expression_binop(self, p):
        '''expression : expression PLUS expression
                      | expression MINUS expression
                      | expression TIMES expression
                      | expression DIVIDE expression'''
        if p[2] == '+':
            p[0] = p[1] + p[3]
        elif p[2] == '-':
            p[0] = p[1] - p[3]
        elif p[2] == '*':
            p[0] = p[1] * p[3]
        elif p[2] == '/':
            p[0] = p[1] / p[3]

    def p_expression_uminus(self, p):
        '''expression : MINUS expression %prec UMINUS'''
        p[0] = -p[2]

    def p_expression_number(self, p):
        '''expression : number'''
        p[0] = p[1]

    def p_number_int(self, p):
        '''number : INT'''
        p[0] = p[1]

    def p_number_float(self, p):
        '''number : FLOAT'''
        p[0] = p[1]

    def p_error(self, p):
        # Syntax errors are reported, not raised; parse() then returns None.
        if p:
            print("Syntax error at '{}'".format(p.value))
        else:
            print("Syntax error at EOF")

    def parse(self, p):
        """Parse and evaluate one statement, returning its value (or None)."""
        # Pass this parser's own lexer; without it PLY falls back to the
        # lexer most recently built anywhere in the process.
        return self.parser.parse(p, lexer=self.lexer.lexer)

    def __init__(self, storage_plugin, output_plugin):
        self.lexer = Lexer()
        self.parser = yacc.yacc(module=self)
        self.cells = KeyValueStore(storage_plugin, output_plugin)
        # Reducers applied to the list of data collected from a cell range.
        self.data_funcs = {
            'SUM': sum,
            'AVG': lambda values: sum(values) / len(values),
            'MAX': max,
            'MIN': min,
        }
        # Cell functions: SET_TYPE_* build a typed cell from their string
        # argument; INFO describes an existing cell.
        self.cell_funcs = {
            'SET_TYPE_CURRENCY': Currency,
            'SET_TYPE_DISTANCE': Distance,
            'SET_TYPE_AREA': Area,
            'INFO': repr,
        }
And here are some limited tests:
#!/usr/bin/python3
# test_interpreter.py
import pytest

import interpreter

arithmetic_data = [
    ("1", 1),
    ("1+1", 2),
    ("1-1", 0),
    ("-1+1", 0),
    ("1.0", 1.0),
    ("1.0+1.0", 2.0),
    ("1.0-1.0", 0.0),
    ("-1.0+1.0", 0.0),
    ("1+1.0", 2.0),
    ("SUM(@(0,0):@(0,1))", 3),
    ("AVG(@(0,0):@(0,1))", 1.5),
    ("SUM(@(0,0):@(0,1))+1", 4),
    ("MAX(@(0,0):@(0,2))", 3),
    ("MIN(@(0,0):@(0,2))", 1),
]

assignment_data = [
    ("@(0,0) = 1", "@(0,0)", 1),
    ("@(0,1) = 2", "@(0,0) + @(0,1)", 3),
    ("@(0,2) = @(0,0) + @(0,1)", "@(0,2)", 3),
    ('@(1,0) = "foo"', "@(1,0)", "foo"),
]

cell_types = [
    ('SET_TYPE_CURRENCY(@(0,0), "USD")', "INFO(@(0,0))", '(None, USD)'),
    ('SET_TYPE_CURRENCY(@(1,0), "INR")', "INFO(@(1,0))", '(None, INR)'),
    ('@(0,0) = 3', "INFO(@(0,0))", '(3.00, USD)'),
    ('@(1,0) = 5', "INFO(@(1,0))", '(5.00, INR)'),
]


class TestStorage:
    def test_store(self):
        store = interpreter.KeyValueStore(interpreter.DictStoragePlugin(),
                                          interpreter.TerminalOutputPlugin())
        address = "@(0,0)"
        store.set(address, 1)
        assert store.get(address) == 1


class TestInterpreter:
    """End-to-end tests that drive one Parser shared by the whole class.

    NOTE: these tests are order-dependent. They share a single cell store
    that earlier tests populate and later tests read; for example, the
    currency tests rely on the cells set up by ``test_cell_types``. Pytest
    runs them in definition order, so keep new tests in sequence.
    """

    @classmethod
    def setup_class(cls):
        cls.storage_plugin = interpreter.DictStoragePlugin()
        cls.output_plugin = interpreter.TerminalOutputPlugin()
        cls.parser = interpreter.Parser(cls.storage_plugin, cls.output_plugin)

    @pytest.mark.parametrize("inp_a, inp_b, outp", assignment_data)
    def test_assignment_expressions(self, inp_a, inp_b, outp):
        self.parser.parse(inp_a)
        assert self.parser.parse(inp_b) == outp

    @pytest.mark.parametrize("inp, outp", arithmetic_data)
    def test_arithmetic_expressions(self, inp, outp):
        assert self.parser.parse(inp) == outp

    @pytest.mark.parametrize("inp_a, inp_b, outp", cell_types)
    def test_cell_types(self, inp_a, inp_b, outp):
        self.parser.parse(inp_a)
        assert self.parser.parse(inp_b) == outp

    def test_currency_mismatch(self):
        # At this point @(0,0) holds USD and @(1,0) holds INR. pytest.raises
        # fails the test if nothing is raised, unlike a try/except that
        # would also swallow the AssertionError from ``assert False``.
        with pytest.raises(RuntimeError, match="Inconsistent currency"):
            self.parser.parse("SUM(@(0,0):@(1,0))")

    def test_currency_sum(self):
        self.parser.parse('SET_TYPE_CURRENCY(@(0,0), "INR")')
        self.parser.parse('@(0,0) = 7')
        assert self.parser.parse("SUM(@(0,0):@(1,0))") == 12

    def test_currency_avg(self):
        assert self.parser.parse("AVG(@(0,0):@(1,0))") == 6

    def test_currency_max(self):
        assert self.parser.parse("MAX(@(0,0):@(1,0))") == 7

    def test_currency_min(self):
        assert self.parser.parse("MIN(@(0,0):@(1,0))") == 5

    def test_view(self, capsys):
        # view() prints the store and returns None, so check what it printed.
        self.parser.cells.view()
        out = capsys.readouterr().out
        assert "@(0,0) <- (7.00, INR)" in out
        assert "@(1,0) <- (5.00, INR)" in out