Writing an Interpreter in Python with PLY

For a while I’ve wanted to work on a typed spreadsheet application. This weekend I started working on an interpreter for it using David Beazley’s PLY. So far, the interpreter can store data and type information for cells in a data store, and perform operations on numbers or on references to cells. It also supports limited type checking.

#!/usr/bin/python3
# interpreter.py
import ply.lex as lex
import ply.yacc as yacc

class DictStoragePlugin:
    """Storage backend that keeps every entry in an in-memory dictionary."""

    def __init__(self):
        # Maps a key to whatever value was last stored under it.
        self.store = dict()

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any earlier value."""
        self.store.update({key: value})

    def get(self, key):
        """Return the value stored under ``key``.

        Raises:
            KeyError: if nothing has been stored under ``key``.
        """
        contents = self.store
        return contents[key]

    def keys(self):
        """Return a view of every key currently stored."""
        contents = self.store
        return contents.keys()


class TerminalOutputPlugin:
    """Output backend that writes the contents of a store to stdout."""

    def __init__(self):
        """Create the plugin; it holds no state."""

    def show(self, store):
        """Print a blank line, then one ``key <- value`` line per stored key.

        ``store`` must provide ``keys()`` and ``get(key)``.
        """
        names = store.keys()
        print()
        for name in names:
            value = store.get(name)
            print("{} <- {}".format(name, value))


class KeyValueStore:
    """Key/value store that delegates persistence and display to plugins.

    The storage plugin must provide ``set(key, value)`` and ``get(key)``;
    the output plugin must provide ``show(storage)``.
    """

    def __init__(self, storage_plugin, output_plugin):
        self.storage, self.display = storage_plugin, output_plugin

    def set(self, key, value):
        """Store ``value`` under ``key`` in the storage plugin."""
        backend = self.storage
        backend.set(key, value)

    def get(self, key):
        """Return whatever the storage plugin holds for ``key``."""
        backend = self.storage
        return backend.get(key)

    def view(self):
        """Hand the storage plugin to the output plugin for display."""
        renderer = self.display
        renderer.show(self.storage)


class Cell:
    """Common base class for every kind of spreadsheet cell."""


class Currency(Cell):
    """Cell holding an amount of money in one of the supported currencies."""

    def __init__(self, currency):
        """Create an empty cell tagged with ``currency``.

        Raises:
            RuntimeError: if ``currency`` is not a supported currency code.
        """
        super().__init__()
        supported = ("USD", "EURO", "INR", "CNY")
        if currency in supported:
            self.currency = currency
            # No amount has been assigned yet.
            self.data = None
        else:
            raise RuntimeError("Unknown currency {}".format(currency))

    def __repr__(self):
        """Return ``(amount, currency)``, with a set amount shown to 2 places."""
        amount = self.data
        if amount is not None:
            return "({:.2f}, {})".format(amount, self.currency)
        return "({}, {})".format(amount, self.currency)


class Distance(Cell):
    """Cell holding a distance expressed in a caller-supplied unit."""

    def __init__(self, unit):
        """Create an empty cell tagged with ``unit``."""
        super().__init__()
        # No value has been assigned yet.
        self.data = None
        self.unit = unit

    def __repr__(self):
        """Return the cell as ``(value, unit)``."""
        value, unit = self.data, self.unit
        return "({}, {})".format(value, unit)


class Area(Cell):
    """Cell holding an area expressed in a caller-supplied unit."""

    def __init__(self, unit):
        """Create an empty cell tagged with ``unit``."""
        super().__init__()
        # No value has been assigned yet.
        self.data = None
        self.unit = unit

    def __repr__(self):
        """Return the cell as ``(value, unit)``."""
        value, unit = self.data, self.unit
        return "({}, {})".format(value, unit)


class Text(Cell):
    """Cell holding a piece of text."""

    def __init__(self, data):
        """Create the cell with ``data`` as its content."""
        super().__init__()
        self.data = data

    def __repr__(self):
        """Return the cell as ``(content, TEXT)``."""
        content = self.data
        return "({}, TEXT)".format(content)


class Scalar(Cell):
    """Cell holding a plain value with no unit attached."""

    def __init__(self, data):
        """Create the cell with ``data`` as its content."""
        super().__init__()
        self.data = data

    def __repr__(self):
        """Return the cell as ``(content, SCALAR)``."""
        content = self.data
        return "({}, SCALAR)".format(content)


class Lexer:
    """Token definitions for the spreadsheet language, built with PLY lex.

    PLY reads the ``t_*`` attributes of this object: string attributes are
    regular expressions, and for ``t_*`` methods the regular expression is
    the method's docstring.  Those docstrings must therefore contain nothing
    but the pattern; commentary goes in ``#`` comments.
    """

    tokens = (
        'ADDRESS',
        'STRING',
        'NAME',
        'INT',
        'FLOAT',
        'PLUS',
        'MINUS',
        'TIMES',
        'DIVIDE',
        'ASSIGN',
        'LPAREN',
        'RPAREN',
        'COMMA',
        'RANGE',
    )

    # A cell address such as @(3,12).  Row and column may each have several
    # digits; the previous pattern only accepted one digit per coordinate.
    t_ADDRESS = r'@\(\d+,\d+\)'
    t_NAME = r'[\w_][\w\d_]*'
    t_PLUS = r'\+'
    t_MINUS = r'-'
    t_TIMES = r'\*'
    t_DIVIDE = r'\/'
    t_ASSIGN = r'='
    t_LPAREN = r'\('
    t_RPAREN = r'\)'
    t_COMMA = r','
    t_RANGE = r':'

    # FLOAT is defined before INT so that "1.5" is not lexed as INT 1 first.
    # Any number of fractional digits is accepted ("1.", "1.5", "1.25"); the
    # previous pattern stopped after one digit and split "1.25" in two.
    def t_FLOAT(self, t):
        r'\d+\.\d*'
        t.value = float(t.value)
        return t

    def t_INT(self, t):
        r'\d+'
        t.value = int(t.value)
        return t

    # Double-quoted string.  The escaped-quote alternative must come first:
    # with [^"] first, the backslash was consumed on its own and the string
    # ended at the escaped quote.  The token value drops the surrounding
    # quotes and turns \" back into a plain quote.
    def t_STRING(self, t):
        r'"(\\"|[^"])*"'
        t.value = t.value[1:-1].replace('\\"', '"')
        return t

    t_ignore = " \t"

    # Newlines produce no token; they only advance the line counter.
    def t_newline(self, t):
        r'\n+'
        t.lexer.lineno += t.value.count("\n")

    # Report the offending character and resume lexing after it.
    def t_error(self, t):
        print("Illegal character '%s'" % t.value[0])
        t.lexer.skip(1)

    def __init__(self):
        """Build the PLY lexer from the token rules defined on this object."""
        self.lexer = lex.lex(module=self)


class Parser:
    """Grammar and evaluator for the spreadsheet language, built with PLY yacc.

    Every ``p_*`` method is one grammar production.  PLY reads the production
    from the method's docstring, so those docstrings contain only grammar
    text and all commentary is written as ``#`` comments.  Expressions are
    evaluated while they are parsed; cell contents live in ``self.cells``.
    """

    @staticmethod
    def _parse_cell_address(x) -> tuple:
        """Convert an address string such as ``"@(1,2)"`` into ``(1, 2)``.

        Raises:
            RuntimeError: if ``x`` does not start with ``@(`` and end with
                ``)``.
            ValueError: if the text between the delimiters is not two
                comma-separated integers.
        """
        if len(x) < 2:
            raise RuntimeError
        # Both delimiters are required.  The check used to join the two
        # tests with ``and``, which only rejected input missing both.
        if x[:2] != '@(' or x[-1] != ')':
            raise RuntimeError
        row, col = map(int, x[2:-1].split(','))
        return row, col

    @staticmethod
    def _build_cell_address(x) -> str:
        """Format a ``(row, col)`` pair as an address string ``@(row,col)``."""
        return "@({},{})".format(x[0], x[1])

    def _addresses_in_range(self, first, last):
        """Return every address from ``first`` to ``last``, inclusive.

        The two addresses must share a row or a column, and ``first`` must
        not come after ``last``.

        Raises:
            RuntimeError: if the range is reversed or spans both several
                rows and several columns.
        """
        beg = self._parse_cell_address(first)
        end = self._parse_cell_address(last)
        # Explicit raises rather than ``assert``: assertions are stripped
        # under ``python -O``, which would turn a reversed range into an
        # empty one.
        if beg[0] == end[0]:
            # Same row: walk the columns.
            if beg[1] > end[1]:
                raise RuntimeError("Range end precedes range start.")
            return [self._build_cell_address((beg[0], col))
                    for col in range(beg[1], end[1] + 1)]
        if beg[1] == end[1]:
            # Same column: walk the rows.
            if beg[0] > end[0]:
                raise RuntimeError("Range end precedes range start.")
            return [self._build_cell_address((row, beg[1]))
                    for row in range(beg[0], end[0] + 1)]
        raise RuntimeError("Range must lie within a single row or column.")

    tokens = Lexer.tokens

    precedence = (
        ('left', 'PLUS', 'MINUS'),
        ('left', 'TIMES', 'DIVIDE'),
        ('right', 'UMINUS'),
    )

    def p_statement_expr(self, p):
        '''statement : expression'''
        p[0] = p[1]

    def p_expression_string(self, p):
        '''expression : STRING'''
        p[0] = p[1]

    def p_expression_group(self, p):
        '''expression : LPAREN expression RPAREN'''
        p[0] = p[2]

    def p_expression_address(self, p):
        '''expression : ADDRESS'''
        # An address used as an expression resolves to the data stored in
        # the cell at that address.
        p[0] = self.cells.get(p[1]).data

    def p_statement_cell_assignment(self, p):
        '''statement : ADDRESS ASSIGN expression'''
        # Assign data to the cell at an address.
        #
        # A Currency, Distance or Area cell keeps its type and only has its
        # data updated.  Any other target (no cell yet, or an untyped
        # Text/Scalar cell) is replaced by a cell whose type is inferred
        # from the data.  Previously a second assignment to an existing
        # Text/Scalar cell was silently dropped.
        address = p[1]
        data = p[3]
        try:
            cell = self.cells.get(address)
        except LookupError:
            # DictStoragePlugin raises KeyError for an address never set.
            cell = None
        if isinstance(cell, (Currency, Distance, Area)):
            cell.data = data
            return
        if isinstance(data, str):
            self.cells.set(address, Text(data))
        elif isinstance(data, (int, float)):
            self.cells.set(address, Scalar(data))
        else:
            raise RuntimeError("Unable to resolve type.")

    def p_statement_single_cellfunc_call_with_argument(self, p):
        '''statement : NAME LPAREN ADDRESS COMMA STRING RPAREN'''
        # Cell functions operate on cells.  This form sets the type of one
        # cell by storing whatever the named function builds from the
        # string argument.
        func_name = p[1]
        address = p[3]
        argument = p[5]
        func = self.cell_funcs[func_name]
        self.cells.set(address, func(argument))

    def p_statement_single_cellfunc_call_without_argument(self, p):
        '''statement : NAME LPAREN ADDRESS RPAREN'''
        # Used for the INFO cell function: the named function receives the
        # cell itself and its result is the value of the statement.
        func_name = p[1]
        address = p[3]
        func = self.cell_funcs[func_name]
        p[0] = func(self.cells.get(address))

    def p_statement_range_cellfunc_call(self, p):
        '''statement : NAME LPAREN ADDRESS RANGE ADDRESS COMMA STRING RPAREN'''
        # Apply a cell function to every cell in part of a row or column.
        # The result is stored at each address; previously the address was
        # computed but the result was thrown away, so nothing was set.
        func = self.cell_funcs[p[1]]
        argument = p[7]
        for address in self._addresses_in_range(p[3], p[5]):
            # Call once per cell so that cells never share one object.
            self.cells.set(address, func(argument))

    def p_statement_range_datafunc_call(self, p):
        '''expression : NAME LPAREN ADDRESS RANGE ADDRESS RPAREN'''
        # Collect the data of every cell in the range and reduce it with
        # the named data function.  All cells must have the type of the
        # first cell, and Currency cells must also share its currency.
        func = self.data_funcs[p[1]]
        addresses = self._addresses_in_range(p[3], p[5])
        initial = self.cells.get(p[3])
        arr = []
        for addr in addresses:
            cell = self.cells.get(addr)
            if type(cell) is not type(initial):
                raise RuntimeError("Inconsistent type.")
            if isinstance(initial, Currency) and cell.currency != initial.currency:
                raise RuntimeError("Inconsistent currency.")
            arr.append(cell.data)
        p[0] = func(arr)

    def p_expression_binop(self, p):
        '''expression : expression PLUS expression
                      | expression MINUS expression
                      | expression TIMES expression
                      | expression DIVIDE expression'''
        if p[2] == '+':
            p[0] = p[1] + p[3]
        elif p[2] == '-':
            p[0] = p[1] - p[3]
        elif p[2] == '*':
            p[0] = p[1] * p[3]
        elif p[2] == '/':
            p[0] = p[1] / p[3]

    def p_expression_uminus(self, p):
        '''expression : MINUS expression %prec UMINUS'''
        p[0] = -p[2]

    def p_expression_number(self, p):
        '''expression : number'''
        p[0] = p[1]

    def p_number_int(self, p):
        '''number : INT'''
        p[0] = p[1]

    def p_number_float(self, p):
        '''number : FLOAT'''
        p[0] = p[1]

    def p_error(self, p):
        """Report a syntax error; ``p`` is the offending token or None at EOF."""
        if p:
            print("Syntax error at '%s'" % p.value)
        else:
            print("Syntax error at EOF")

    def parse(self, p):
        """Parse and evaluate the input string ``p`` and return its value."""
        # Pass this parser's own lexer explicitly.  Without it PLY falls
        # back to the most recently created lexer in the process.
        return self.parser.parse(p, lexer=self.lexer.lexer)

    def __init__(self, storage_plugin, output_plugin):
        """Build the lexer, the parser tables and the cell store.

        Args:
            storage_plugin: backend handed to ``KeyValueStore`` for storage.
            output_plugin: backend handed to ``KeyValueStore`` for display.
        """
        self.lexer = Lexer()
        self.parser = yacc.yacc(module=self)
        self.cells = KeyValueStore(storage_plugin, output_plugin)
        # Reducers applied to the list of data values collected from a range.
        self.data_funcs = {
            'SUM': sum,
            'AVG': lambda x: sum(x) / len(x),
            'MAX': max,
            'MIN': min,
        }
        # Functions applied to a cell, or used to build a typed cell.
        self.cell_funcs = {
            'SET_TYPE_CURRENCY': Currency,
            'SET_TYPE_DISTANCE': Distance,
            'SET_TYPE_AREA': Area,
            'INFO': repr,
        }

And here are some limited tests for it.

#!/usr/bin/python3
# test_interpreter.py
import interpreter
import pytest

# (expression, expected value) pairs for test_arithmetic_expressions.
# NOTE(review): the SUM/AVG/MAX/MIN rows read cells @(0,0)..@(0,2). They only
# hold 1, 2 and 3 if the assignment_data statements have already been run on
# the same parser -- confirm the test order guarantees that.
arithmetic_data = [
    ("1", 1),
    ("1+1", 2),
    ("1-1", 0),
    ("-1+1", 0),
    ("1.0", 1.0),
    ("1.0+1.0", 2.0),
    ("1.0-1.0", 0.0),
    ("-1.0+1.0", 0.0),
    ("1+1.0", 2.0),
    ("SUM(@(0,0):@(0,1))", 3),
    ("AVG(@(0,0):@(0,1))", 1.5),
    ("SUM(@(0,0):@(0,1))+1", 4),
    ("MAX(@(0,0):@(0,2))", 3),
    ("MIN(@(0,0):@(0,2))", 1),
]

# (statement to run, expression to evaluate afterwards, expected value)
# triples for test_assignment_expressions. Later rows read cells written by
# earlier rows, so the order of the rows matters.
assignment_data = [
    ("@(0,0) = 1", "@(0,0)", 1),
    ("@(0,1) = 2", "@(0,0) + @(0,1)", 3),
    ("@(0,2) = @(0,0) + @(0,1)", "@(0,2)", 3),
    ('@(1,0) = "foo"', "@(1,0)", "foo")
]

# (statement to run, INFO query to evaluate afterwards, expected text)
# triples for test_cell_types. The first two rows give the cells a currency
# type; the last two assign data to those typed cells.
cell_types = [
    ('SET_TYPE_CURRENCY(@(0,0), "USD")', "INFO(@(0,0))", '(None, USD)'),
    ('SET_TYPE_CURRENCY(@(1,0), "INR")', "INFO(@(1,0))", '(None, INR)'),
    ('@(0,0) = 3', "INFO(@(0,0))", '(3.00, USD)'),
    ('@(1,0) = 5', "INFO(@(1,0))", '(5.00, INR)'),
]

class TestStorage:
    """Checks on KeyValueStore that do not involve the parser."""

    def test_store(self):
        """A value written under an address can be read back unchanged."""
        key = "@(0,0)"
        backend = interpreter.DictStoragePlugin()
        output = interpreter.TerminalOutputPlugin()
        kv = interpreter.KeyValueStore(backend, output)
        kv.set(key, 1)
        assert kv.get(key) == 1

class Testinterpreter:
    """End-to-end tests that drive the parser with source text.

    One parser, and therefore one cell store, is shared by every test in
    the class.  Several tests read cells written by tests defined above
    them, so they depend on running in definition order.
    """

    @classmethod
    def setup_class(cls):
        """Create the parser shared by all tests in this class."""
        cls.storage_plugin = interpreter.DictStoragePlugin()
        cls.output_plugin = interpreter.TerminalOutputPlugin()
        cls.parser = interpreter.Parser(cls.storage_plugin, cls.output_plugin)

    @pytest.mark.parametrize("inp_a, inp_b, outp", assignment_data)
    def test_assignment_expressions(self, inp_a, inp_b, outp):
        """Run an assignment, then check the value of a follow-up expression."""
        self.parser.parse(inp_a)
        assert self.parser.parse(inp_b) == outp

    @pytest.mark.parametrize("inp, outp", arithmetic_data)
    def test_arithmetic_expressions(self, inp, outp):
        """Each expression evaluates to its expected value."""
        assert self.parser.parse(inp) == outp

    @pytest.mark.parametrize("inp_a, inp_b, outp", cell_types)
    def test_cell_types(self, inp_a, inp_b, outp):
        """Run a typing or assignment statement, then check the INFO text."""
        self.parser.parse(inp_a)
        assert self.parser.parse(inp_b) == outp

    def test_currency_mismatch(self):
        """Reducing over cells with different currencies is rejected."""
        # The old try / assert False / bare except form could never fail,
        # because the bare except also caught the AssertionError.
        with pytest.raises(RuntimeError, match="Inconsistent currency"):
            self.parser.parse("SUM(@(0,0):@(1,0))")

    def test_currency_sum(self):
        """After retyping @(0,0) as INR, both cells can be summed."""
        self.parser.parse('SET_TYPE_CURRENCY(@(0,0), "INR")')
        self.parser.parse('@(0,0) = 7')
        assert self.parser.parse("SUM(@(0,0):@(1,0))") == 12

    def test_currency_avg(self):
        assert self.parser.parse("AVG(@(0,0):@(1,0))") == 6

    def test_currency_max(self):
        assert self.parser.parse("MAX(@(0,0):@(1,0))") == 7

    def test_currency_min(self):
        assert self.parser.parse("MIN(@(0,0):@(1,0))") == 5

    def test_view(self, capsys):
        """view() writes the stored cells to stdout."""
        # view() returns None, so check the captured output rather than
        # printing the return value.
        self.parser.cells.view()
        assert "@(0,0) <- " in capsys.readouterr().out