# ===== pytch/__init__.py =====

__version__ = "0.1"

ISSUE_TRACKER_URL = "https://github.com/arxanas/pytch/issues/new/choose"


# ===== pytch/__main__.py =====

import sys
from typing import Sequence, TextIO

import click

from .lexer import lex
from .parser import dump_syntax_tree, parse
from .repl import compile_file, interact, print_errors, run_file
from .utils import FileInfo


@click.group()
def cli() -> None:
    pass


@cli.command("compile")
@click.argument("source_files", type=click.File(), nargs=-1)
@click.option("--dump-tree", is_flag=True)
def compile(source_files: Sequence[TextIO], dump_tree: bool) -> None:
    for source_file in source_files:
        file_info = FileInfo(
            file_path=source_file.name, source_code=source_file.read()
        )
        if dump_tree:
            errors = []
            lexation = lex(file_info=file_info)
            errors.extend(lexation.errors)
            parsation = parse(file_info=file_info, tokens=lexation.tokens)
            errors.extend(parsation.errors)
            print_errors(errors)
            (offset, lines) = dump_syntax_tree(
                file_info.source_code, ast_node=parsation.green_cst
            )
            sys.stdout.write("".join(line + "\n" for line in lines))
        else:
            (compiled_output, errors) = compile_file(file_info=file_info)
            print_errors(errors)
            if compiled_output is not None and source_file is sys.stdin:
                sys.stdout.write(compiled_output)


@cli.command("run")
@click.argument("source_file", type=click.File())
def run(source_file: TextIO) -> None:
    run_file(
        file_info=FileInfo(file_path=source_file.name, source_code=source_file.read())
    )


@cli.command("repl")
def repl() -> None:
    interact()
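# Usage sketch (illustrative; not in the original archive). The `cli` group
# above can be exercised with click's test runner; "example.pytch" is a
# hypothetical input file:
#
#     >>> from click.testing import CliRunner
#     >>> from pytch.__main__ import cli
#     >>> runner = CliRunner()
#     >>> result = runner.invoke(cli, ["compile", "--dump-tree", "example.pytch"])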
# ===== pytch/binder.py =====

"""Binds name references to variable declarations in the AST.

In a valid program, every `IdentifierExpr` refers to at least one
`VariablePattern` somewhere. (In a pattern-match, there may be more than one
source `VariablePattern`.)
"""
from typing import Dict, List, Mapping, Optional, Tuple

import attr
import distance

from .errors import Error, ErrorCode, Note, Severity
from .redcst import IdentifierExpr, LetExpr, Node, Pattern, SyntaxTree, VariablePattern
from .utils import FileInfo, Range

GLOBAL_SCOPE: Mapping[str, List[VariablePattern]] = {
    "map": [],
    "filter": [],
    "print": [],
    "True": [],
    "False": [],
    "None": [],
}


@attr.s(auto_attribs=True, frozen=True)
class Bindation:
    bindings: Mapping[IdentifierExpr, List[VariablePattern]]
    errors: List[Error]

    def get(self, node: IdentifierExpr) -> Optional[List[VariablePattern]]:
        return self.bindings.get(node)


def get_names_bound_for_let_expr_value(
    n_let_expr: LetExpr,
) -> Mapping[str, List[VariablePattern]]:
    """Get the names bound in a let-expression's value.

    That is, for function let expressions, get the names of the parameters
    that should be bound inside the function's definition. For example:

        let foo(bar, baz) =
          bar + baz  # bar and baz should be bound here...
        foo(1, 2)    # ...but not here.

    TODO: Additionally, if the function is marked as `rec`, bind the function
    name itself inside the function body.
    """
    n_parameter_list = n_let_expr.n_parameter_list
    if n_parameter_list is None:
        return {}
    parameters = n_parameter_list.parameters
    if parameters is None:
        return {}

    bindings: Dict[str, List[VariablePattern]] = {}
    for parameter in parameters:
        n_pattern = parameter.n_pattern
        if n_pattern is not None:
            # TODO: warn about overlapping name-bindings.
            bindings.update(get_names_bound_by_pattern(n_pattern))
    return bindings


def get_names_bound_for_let_expr_body(
    n_let_expr: LetExpr,
) -> Mapping[str, List[VariablePattern]]:
    if n_let_expr.n_pattern is None:
        return {}
    return get_names_bound_by_pattern(n_let_expr.n_pattern)


def get_names_bound_by_pattern(
    n_pattern: Pattern,
) -> Mapping[str, List[VariablePattern]]:
    if isinstance(n_pattern, VariablePattern):
        t_identifier = n_pattern.origin.t_identifier
        if t_identifier is None:
            return {}
        name = t_identifier.text
        return {name: [n_pattern]}
    else:
        assert False, f"Unhandled pattern type: {n_pattern.__class__.__name__}"


def bind(
    file_info: FileInfo,
    syntax_tree: SyntaxTree,
    global_scope: Mapping[str, List[VariablePattern]],
) -> Bindation:
    def get_binding_referred_to_by_name(
        node: Node, name: str, names_in_scope: Mapping[str, List[VariablePattern]]
    ) -> Tuple[Optional[List[VariablePattern]], List[Error]]:
        binding = names_in_scope.get(name)
        if binding is not None:
            return (binding, [])

        suggestions = [
            candidate
            for candidate in names_in_scope
            if distance.levenshtein(name, candidate) <= 2
        ]
        notes = []
        for suggestion in suggestions:
            suggestion_nodes = names_in_scope.get(suggestion)
            range: Optional[Range]
            if suggestion_nodes:
                range = file_info.get_range_from_offset_range(
                    suggestion_nodes[0].offset_range
                )
                location = ", defined here"
            else:
                range = None
                location = " (a builtin)"
            notes.append(
                Note(
                    file_info=file_info,
                    message=f"Did you mean '{suggestion}'{location}?",
                    range=range,
                )
            )
        errors = [
            Error(
                file_info=file_info,
                code=ErrorCode.UNBOUND_NAME,
                severity=Severity.ERROR,
                message=(
                    "I couldn't find a binding "
                    + f"in the current scope with the name '{name}'."
                ),
                notes=notes,
                range=file_info.get_range_from_offset_range(node.offset_range),
            )
        ]
        return (None, errors)

    def bind_node(
        node: Node, names_in_scope: Mapping[str, List[VariablePattern]]
    ) -> Tuple[Mapping[IdentifierExpr, List[VariablePattern]], List[Error]]:
        bindings = {}
        errors = []
        if isinstance(node, IdentifierExpr):
            node_identifier = node.t_identifier
            if node_identifier is not None:
                name = node_identifier.text
                (
                    identifier_binding,
                    identifier_errors,
                ) = get_binding_referred_to_by_name(
                    node=node, name=name, names_in_scope=names_in_scope
                )
                if identifier_binding is not None:
                    bindings[node] = identifier_binding
                errors.extend(identifier_errors)

        if isinstance(node, LetExpr):
            if node.n_value is not None:
                value_names_in_scope = {
                    **names_in_scope,
                    **get_names_bound_for_let_expr_value(node),
                }
                (value_bindings, value_errors) = bind_node(
                    node=node.n_value, names_in_scope=value_names_in_scope
                )
                bindings.update(value_bindings)
                errors.extend(value_errors)
            if node.n_body is not None:
                body_names_in_scope = {
                    **names_in_scope,
                    **get_names_bound_for_let_expr_body(node),
                }
                (body_bindings, body_errors) = bind_node(
                    node=node.n_body, names_in_scope=body_names_in_scope
                )
                bindings.update(body_bindings)
                errors.extend(body_errors)
        else:
            for child in node.children:
                if isinstance(child, Node):
                    (child_bindings, child_errors) = bind_node(
                        node=child, names_in_scope=names_in_scope
                    )
                    bindings.update(child_bindings)
                    errors.extend(child_errors)
        return (bindings, errors)

    (bindings, errors) = bind_node(node=syntax_tree, names_in_scope=global_scope)
    return Bindation(bindings=bindings, errors=errors)
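# Usage sketch (illustrative; not in the original archive). `bind` is driven
# from the compilation pipeline; `red_syntax_tree` stands in for the red CST
# produced from the parser's output (see pytch.redcst, not included here):
#
#     bindation = bind(
#         file_info=file_info,
#         syntax_tree=red_syntax_tree,
#         global_scope=GLOBAL_SCOPE,
#     )
#     for error in bindation.errors:
#         print(error.message)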
# ===== pytch/codegen/__init__.py =====

import keyword
from typing import Dict, Iterator, List, Optional, Set, Tuple

import attr

from .py3ast import (
    PyArgument,
    PyAssignmentStmt,
    PyBinaryExpr,
    PyExpr,
    PyExprStmt,
    PyFunctionCallExpr,
    PyFunctionStmt,
    PyIdentifierExpr,
    PyIfStmt,
    PyLiteralExpr,
    PyParameter,
    PyStmtList,
    PyUnavailableExpr,
)
from ..binder import Bindation
from ..errors import Error
from ..lexer import TokenKind
from ..redcst import (
    BinaryExpr,
    Expr,
    FunctionCallExpr,
    IdentifierExpr,
    IfExpr,
    IntLiteralExpr,
    LetExpr,
    Pattern,
    SyntaxTree,
    VariablePattern,
)
from ..typesystem import Typeation


@attr.s(auto_attribs=True, frozen=True)
class Scope:
    pytch_bindings: Dict[VariablePattern, str]
    python_bindings: Set[str]

    def update(self, **kwargs) -> "Scope":
        return attr.evolve(self, **kwargs)

    @staticmethod
    def empty() -> "Scope":
        return Scope(pytch_bindings={}, python_bindings=set())


@attr.s(auto_attribs=True, frozen=True)
class Env:
    """Environment for codegen with the Python 3 backend.

    We keep track of the emitted variable bindings here. We may need to emit
    extra variables that don't exist in the source code as temporaries, and
    we need to account for the differences in scoping between Pytch and
    Python (for example, function bindings aren't recursive by default in
    Pytch, but are in Python).
    """

    bindation: Bindation
    scopes: List[Scope]

    def _update(self, **kwargs) -> "Env":
        return attr.evolve(self, **kwargs)

    def push_scope(self) -> "Env":
        return self._update(scopes=self.scopes + [Scope.empty()])

    def pop_scope(self) -> "Env":
        assert self.scopes
        return self._update(scopes=self.scopes[:-1])

    def add_binding(
        self, variable_pattern: VariablePattern, preferred_name: str
    ) -> Tuple["Env", str]:
        """Add a binding for a variable that exists in the source code.

        `preferred_name` is used as the preferred Python variable name, but a
        non-colliding name will be generated if there is already such a name
        in the current Python scope.
        """
        python_name = self._get_name(preferred_name)
        current_pytch_bindings = dict(self.scopes[-1].pytch_bindings)
        current_pytch_bindings[variable_pattern] = python_name
        current_scope = self.scopes[-1].update(pytch_bindings=current_pytch_bindings)
        return (self._update(scopes=self.scopes[:-1] + [current_scope]), python_name)

    def make_temporary(self, preferred_name: str) -> Tuple["Env", str]:
        python_name = self._get_name(preferred_name)
        current_python_bindings = set(self.scopes[-1].python_bindings)
        assert python_name not in current_python_bindings
        current_python_bindings.add(python_name)
        current_scope = self.scopes[-1].update(python_bindings=current_python_bindings)
        return (self._update(scopes=self.scopes[:-1] + [current_scope]), python_name)

    def lookup_binding(self, variable_pattern: VariablePattern) -> Optional[str]:
        for scope in reversed(self.scopes):
            if variable_pattern in scope.pytch_bindings:
                return scope.pytch_bindings[variable_pattern]
        return None

    def _get_name(self, preferred_name: str) -> str:
        for suggested_name in self._suggest_names(preferred_name):
            if (
                not keyword.iskeyword(suggested_name)
                and suggested_name not in self.scopes[-1].python_bindings
                and suggested_name not in self.scopes[-1].pytch_bindings.values()
            ):
                return suggested_name
        assert False, "`_suggest_names` should loop forever"

    def _suggest_names(self, preferred_name: str) -> Iterator[str]:
        yield preferred_name
        i = 2
        while True:
            yield preferred_name + str(i)
            i += 1


@attr.s(auto_attribs=True, frozen=True)
class Codegenation:
    statements: PyStmtList
    errors: List[Error]

    def get_compiled_output(self) -> str:
        compiled_output_lines = []
        for statement in self.statements:
            compiled_output_lines.extend(statement.compile())
        return "".join(line + "\n" for line in compiled_output_lines)
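# Usage sketch (illustrative; not in the original archive): `Env` generates
# non-colliding Python names by appending a numeric suffix.
#
#     >>> env = Env(bindation=Bindation(bindings={}, errors=[]), scopes=[Scope.empty()])
#     >>> (env, first) = env.make_temporary("_tmp_if")
#     >>> (env, second) = env.make_temporary("_tmp_if")
#     >>> (first, second)
#     ('_tmp_if', '_tmp_if2')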
def compile_expr(
    env: Env, expr: Expr
) -> Tuple[
    Env,
    # A Python expression that evaluates to its corresponding Pytch
    # expression.
    PyExpr,
    # Any setup code that needs to be run in order to evaluate the Python
    # expression (since not everything is an expression in Python). For
    # example,
    #
    #     def helper(x):
    #         foo()
    #         return x + 1
    #
    # could later be used in the expression
    #
    #     map(helper, some_list)
    #
    # In this case, the expression `helper` would be the `PyExpr` above, and
    # the definition of the helper function would be the `PyStmtList`.
    PyStmtList,
]:
    if isinstance(expr, LetExpr):
        return compile_let_expr(env, expr)
    elif isinstance(expr, IfExpr):
        return compile_if_expr(env, expr)
    elif isinstance(expr, FunctionCallExpr):
        return compile_function_call_expr(env, expr)
    elif isinstance(expr, BinaryExpr):
        return compile_binary_expr(env, expr)
    elif isinstance(expr, IdentifierExpr):
        return compile_identifier_expr(env, expr)
    elif isinstance(expr, IntLiteralExpr):
        return compile_int_literal_expr(env, expr)
    else:
        assert False, f"Unhandled expr type {expr.__class__.__name__}"


PY_EXPR_NO_TARGET = PyUnavailableExpr("should have been directly stored already")


def compile_expr_target(
    env: Env, expr: Expr, target: PyIdentifierExpr, preferred_name: str
) -> Tuple[Env, PyStmtList]:
    """Like `compile_expr`, but store the result in the given target.

    This cleans up the generated code by avoiding temporary stores that make
    it hard to read. For example, this code:

    ```
    let foo = if True then 1 else 2
    print(foo)
    ```

    may compile into this, if we don't elide intermediate stores:

    ```
    if True:
        _tmp_if = 1
    else:
        _tmp_if = 2
    foo = _tmp_if
    print(foo)
    ```

    But we can write this more succinctly by noting that the result of the
    `if`-expression should be directly assigned to `foo`:

    ```
    if True:
        foo = 1
    else:
        foo = 2
    print(foo)
    ```
    """
    if isinstance(expr, LetExpr):
        (env, _py_expr, statements) = compile_let_expr(
            env, let_expr=expr, target=target
        )
        return (env, statements)
    elif isinstance(expr, IfExpr):
        (env, _py_expr, statements) = compile_if_expr(env, if_expr=expr, target=target)
        return (env, statements)
    elif isinstance(expr, IntLiteralExpr):
        (env, _py_expr, statements) = compile_int_literal_expr(env, expr, target=target)
        return (env, statements)
    else:
        (env, py_expr, statements) = compile_expr(env, expr)
        statements = statements + [PyAssignmentStmt(lhs=target, rhs=py_expr)]
        return (env, statements)
def compile_let_expr(
    env: Env, let_expr: LetExpr, target: Optional[PyIdentifierExpr] = None
) -> Tuple[Env, PyExpr, PyStmtList]:
    n_pattern = let_expr.n_pattern
    n_value = let_expr.n_value
    py_binding_statements: PyStmtList = []
    if n_pattern is not None and n_value is not None:
        n_parameter_list = None
        if let_expr.n_parameter_list is not None:
            n_parameter_list = let_expr.n_parameter_list.parameters
        if n_parameter_list is None:
            if target is not None:
                (env, py_binding_statements) = compile_expr_target(
                    env, n_value, target=target, preferred_name="_tmp_let"
                )
            else:
                (env, py_binding_statements) = compile_assign_to_pattern(
                    env, expr=n_value, pattern=n_pattern
                )
        else:
            assert isinstance(
                n_pattern, VariablePattern
            ), f"Bad pattern type {n_pattern.__class__.__name__} for function"
            t_identifier = n_pattern.t_identifier
            if t_identifier is None:
                return (env, PyUnavailableExpr("missing let-binding function name"), [])
            function_name = t_identifier.text

            env = env.push_scope()
            py_parameters = []
            for n_parameter in n_parameter_list:
                n_parameter_pattern = n_parameter.n_pattern
                if n_parameter_pattern is None:
                    continue
                assert isinstance(n_parameter_pattern, VariablePattern), (
                    "Unhandled pattern type "
                    + f"{n_parameter_pattern.__class__.__name__}"
                )
                t_pattern_identifier = n_parameter_pattern.t_identifier
                if t_pattern_identifier is None:
                    continue
                parameter_name = t_pattern_identifier.text
                (env, parameter_name) = env.add_binding(
                    variable_pattern=n_parameter_pattern, preferred_name=parameter_name
                )
                py_parameters.append(PyParameter(name=parameter_name))
            (
                env,
                py_function_body_return_expr,
                py_function_body_statements,
            ) = compile_expr(env, n_value)
            env = env.pop_scope()

            (env, actual_function_name) = env.add_binding(
                n_pattern, preferred_name=function_name
            )
            py_binding_statements = [
                PyFunctionStmt(
                    name=actual_function_name,
                    parameters=py_parameters,
                    body_statements=py_function_body_statements,
                    return_expr=py_function_body_return_expr,
                )
            ]

    if let_expr.n_body is not None:
        (env, body_expr, body_statements) = compile_expr(env, let_expr.n_body)
    else:
        body_expr = PyUnavailableExpr("missing let-expr body")
        body_statements = []
    return (env, body_expr, py_binding_statements + body_statements)
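# For reference (illustrative; not in the original archive): a function
# let-binding such as
#
#     let add(x, y) = x + y
#     add(1, 2)
#
# takes the `PyFunctionStmt` branch above, so the generated Python is roughly:
#
#     def add(x, y):
#         return x + y
#     add(1, 2)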
def compile_if_expr(
    env: Env, if_expr: IfExpr, target: Optional[PyIdentifierExpr] = None
) -> Tuple[Env, PyExpr, PyStmtList]:
    n_if_expr = if_expr.n_if_expr
    n_then_expr = if_expr.n_then_expr
    n_else_expr = if_expr.n_else_expr
    if n_if_expr is None:
        return (env, PyUnavailableExpr("missing if condition"), [])
    (env, py_if_expr, py_if_statements) = compile_expr(env, n_if_expr)

    # Check `n_then_expr` here to avoid making a temporary and not using it.
    if target is None and n_then_expr is not None:
        (env, target_name) = env.make_temporary("_tmp_if")
        target = PyIdentifierExpr(name=target_name)

    # Compile the `then`-clause.
    if n_then_expr is None:
        return (env, PyUnavailableExpr("missing then expression"), [])
    if n_else_expr is not None:
        assert target is not None
        (env, py_then_statements) = compile_expr_target(
            env, n_then_expr, target=target, preferred_name="_tmp_if"
        )
    else:
        # Avoid storing the result of the `then`-clause into anything if
        # there is no corresponding `else`-clause. This makes code like this:
        #
        #     if True
        #     then print(1)
        #
        # produce code like this:
        #
        #     if True:
        #         print(1)
        #
        # instead of code like this:
        #
        #     if True:
        #         _tmp_if = print(1)
        #     else:
        #         _tmp_if = None
        #     _tmp_if
        (env, py_body_expr, py_then_statements) = compile_expr(env, n_then_expr)
        py_then_statements = py_then_statements + [PyExprStmt(expr=py_body_expr)]
        target = None

    py_else_statements: Optional[PyStmtList] = None
    if n_else_expr is not None:
        assert target is not None
        (env, py_else_statements) = compile_expr_target(
            env, n_else_expr, target=target, preferred_name="_tmp_if"
        )

    statements = py_if_statements + [
        PyIfStmt(
            if_expr=py_if_expr,
            then_statements=py_then_statements,
            else_statements=py_else_statements,
        )
    ]
    if isinstance(target, PyIdentifierExpr):
        return (env, target, statements)
    else:
        return (env, PY_EXPR_NO_TARGET, statements)


def compile_assign_to_pattern(
    env: Env, expr: Expr, pattern: Pattern
) -> Tuple[Env, PyStmtList]:
    if isinstance(pattern, VariablePattern):
        t_identifier = pattern.t_identifier
        if t_identifier is None:
            return (
                env,
                [
                    PyExprStmt(
                        expr=PyUnavailableExpr(
                            "missing identifier for variable pattern"
                        )
                    )
                ],
            )
        preferred_name = t_identifier.text
        (env, name) = env.add_binding(pattern, preferred_name=preferred_name)
        target = PyIdentifierExpr(name=name)
        return compile_expr_target(
            env, expr=expr, target=target, preferred_name=preferred_name
        )
    else:
        assert False, f"unimplemented pattern: {pattern.__class__.__name__}"


def compile_function_call_expr(
    env: Env, function_call_expr: FunctionCallExpr
) -> Tuple[Env, PyExpr, PyStmtList]:
    n_callee = function_call_expr.n_callee
    if n_callee is not None:
        (env, py_callee_expr, py_receiver_statements) = compile_expr(env, n_callee)
    else:
        return (env, PyUnavailableExpr("missing function callee"), [])

    n_argument_list = function_call_expr.n_argument_list
    if n_argument_list is None or n_argument_list.arguments is None:
        return (env, PyUnavailableExpr("missing function argument list"), [])

    py_arguments = []
    py_argument_list_statements: PyStmtList = []
    for argument in n_argument_list.arguments:
        if argument.n_expr is None:
            return (env, PyUnavailableExpr("missing argument"), [])
        (env, py_argument_expr, py_argument_statements) = compile_expr(
            env, argument.n_expr
        )
        py_arguments.append(PyArgument(value=py_argument_expr))
        py_argument_list_statements.extend(py_argument_statements)
    py_function_call_expr = PyFunctionCallExpr(
        callee=py_callee_expr, arguments=py_arguments
    )
    return (
        env,
        py_function_call_expr,
        py_receiver_statements + py_argument_list_statements,
    )


def compile_binary_expr(
    env: Env, binary_expr: BinaryExpr
) -> Tuple[Env, PyExpr, PyStmtList]:
    n_lhs = binary_expr.n_lhs
    if n_lhs is None:
        return (env, PyUnavailableExpr("missing lhs"), [])
    t_operator = binary_expr.t_operator
    if t_operator is None:
        return (env, PyUnavailableExpr("missing operator"), [])
    n_rhs = binary_expr.n_rhs
    if n_rhs is None:
        return (env, PyUnavailableExpr("missing rhs"), [])

    (env, py_lhs_expr, lhs_statements) = compile_expr(env, expr=n_lhs)
    (env, py_rhs_expr, rhs_statements) = compile_expr(env, expr=n_rhs)
    if t_operator.kind == TokenKind.DUMMY_SEMICOLON:
        statements = lhs_statements + [PyExprStmt(expr=py_lhs_expr)] + rhs_statements
        return (env, py_rhs_expr, statements)
    else:
        assert not t_operator.is_dummy
        return (
            env,
            PyBinaryExpr(lhs=py_lhs_expr, operator=t_operator.text, rhs=py_rhs_expr),
            lhs_statements + rhs_statements,
        )
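# For reference (illustrative; not in the original archive): the dummy `;`
# operator branch above sequences its operands rather than combining them, so
# `print(1); print(2)` lowers to roughly:
#
#     print(1)
#     print(2)
#
# with the value of the whole expression being that of the right-hand side.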
def compile_identifier_expr(
    env: Env, identifier_expr: IdentifierExpr
) -> Tuple[Env, PyExpr, PyStmtList]:
    sources = env.bindation.get(identifier_expr)
    if not sources:
        t_identifier = identifier_expr.t_identifier
        if t_identifier is not None:
            return (env, PyIdentifierExpr(name=t_identifier.text), [])
        else:
            return (env, PyUnavailableExpr("unknown identifier"), [])

    python_identifiers = []
    for source in sources:
        python_identifier = env.lookup_binding(source)
        assert python_identifier is not None
        python_identifiers.append(python_identifier)
    assert all(
        python_identifier == python_identifiers[0]
        for python_identifier in python_identifiers
    )
    return (env, PyIdentifierExpr(name=python_identifiers[0]), [])


def compile_int_literal_expr(
    env: Env,
    int_literal_expr: IntLiteralExpr,
    target: Optional[PyIdentifierExpr] = None,
) -> Tuple[Env, PyExpr, PyStmtList]:
    t_int_literal = int_literal_expr.t_int_literal
    if t_int_literal is None:
        return (env, PyUnavailableExpr("missing int literal"), [])
    value = t_int_literal.text
    py_expr = PyLiteralExpr(value=str(value))
    if target is None:
        return (env, py_expr, [])
    else:
        statements: PyStmtList = [PyAssignmentStmt(lhs=target, rhs=py_expr)]
        return (env, PY_EXPR_NO_TARGET, statements)


def codegen(
    syntax_tree: SyntaxTree, bindation: Bindation, typeation: Typeation
) -> Codegenation:
    env = Env(bindation=bindation, scopes=[Scope.empty()])
    if syntax_tree.n_expr is None:
        return Codegenation(statements=[], errors=[])
    (env, expr, statements) = compile_expr(env, syntax_tree.n_expr)
    return Codegenation(statements=statements + [PyExprStmt(expr=expr)], errors=[])
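# Usage sketch (illustrative; not in the original archive), using the
# `compile_file` helper seen in pytch/__main__.py:
#
#     from pytch.repl import compile_file
#     from pytch.utils import FileInfo
#
#     file_info = FileInfo(
#         file_path="<example>", source_code="let foo = 1\nprint(foo)\n"
#     )
#     (compiled_output, errors) = compile_file(file_info=file_info)
#     if compiled_output is not None:
#         print(compiled_output)  # roughly: "foo = 1\nprint(foo)"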
# ===== pytch/codegen/py3ast.py =====

from typing import List, Optional

import attr

from pytch import ISSUE_TRACKER_URL

CompiledOutput = List[str]


class PyExpr:
    def compile(self) -> str:
        raise NotImplementedError(
            f"`PyExpr.compile` not implemented by {self.__class__.__name__}"
        )


@attr.s(auto_attribs=True, frozen=True)
class PyUnavailableExpr(PyExpr):
    """Indicates a value deriving from malformed source code."""

    reason: str

    def compile(self) -> str:
        return '""'


@attr.s(auto_attribs=True, frozen=True)
class PyIdentifierExpr(PyExpr):
    name: str

    def compile(self) -> str:
        return self.name


@attr.s(auto_attribs=True, frozen=True)
class PyLiteralExpr(PyExpr):
    value: str

    def compile(self) -> str:
        return self.value


@attr.s(auto_attribs=True, frozen=True)
class PyArgument:
    value: PyExpr

    def compile(self) -> str:
        return self.value.compile()


@attr.s(auto_attribs=True, frozen=True)
class PyFunctionCallExpr(PyExpr):
    callee: PyExpr
    arguments: List[PyArgument]

    def compile(self) -> str:
        compiled_arguments = []
        for argument in self.arguments:
            compiled_arguments.append(argument.compile())
        compiled_arguments_str = ", ".join(compiled_arguments)
        return f"{self.callee.compile()}({compiled_arguments_str})"


@attr.s(auto_attribs=True, frozen=True)
class PyBinaryExpr(PyExpr):
    lhs: PyExpr
    operator: str
    rhs: PyExpr

    def compile(self) -> str:
        return f"{self.lhs.compile()} {self.operator} {self.rhs.compile()}"


class PyStmt:
    def compile(self) -> CompiledOutput:
        raise NotImplementedError(
            f"`PyStmt.compile` not implemented by {self.__class__.__name__}"
        )


PyStmtList = List[PyStmt]


@attr.s(auto_attribs=True, frozen=True)
class PyIndentedStmt:
    statement: PyStmt

    def compile(self) -> CompiledOutput:
        return ["    " + line for line in self.statement.compile()]


@attr.s(auto_attribs=True, frozen=True)
class PyAssignmentStmt(PyStmt):
    lhs: PyIdentifierExpr
    rhs: PyExpr

    def compile(self) -> CompiledOutput:
        return [f"{self.lhs.compile()} = {self.rhs.compile()}"]


@attr.s(auto_attribs=True, frozen=True)
class PyReturnStmt(PyStmt):
    expr: PyExpr

    def compile(self) -> CompiledOutput:
        return [f"return {self.expr.compile()}"]


@attr.s(auto_attribs=True, frozen=True)
class PyIfStmt(PyStmt):
    if_expr: PyExpr  # noqa: E701
    then_statements: PyStmtList
    else_statements: Optional[PyStmtList]  # noqa: E701

    def compile(self) -> CompiledOutput:
        if_statements = [f"if {self.if_expr.compile()}:"]
        for statement in self.then_statements:
            if_statements.extend(PyIndentedStmt(statement=statement).compile())
        else_statements = []
        if self.else_statements is not None:
            assert self.else_statements
            else_statements.append("else:")
            for statement in self.else_statements:
                else_statements.extend(PyIndentedStmt(statement=statement).compile())
        return if_statements + else_statements


@attr.s(auto_attribs=True, frozen=True)
class PyParameter:
    name: str

    def compile(self) -> str:
        return self.name


@attr.s(auto_attribs=True, frozen=True)
class PyFunctionStmt(PyStmt):
    name: str
    parameters: List[PyParameter]
    body_statements: PyStmtList
    return_expr: PyExpr

    def compile(self) -> CompiledOutput:
        parameters = ", ".join(parameter.compile() for parameter in self.parameters)
        body_statements = []
        for statement in self.body_statements:
            body_statements.extend(PyIndentedStmt(statement=statement).compile())
        return_statement = PyIndentedStmt(statement=PyReturnStmt(expr=self.return_expr))
        body_statements.extend(return_statement.compile())
        return [f"def {self.name}({parameters}):", *body_statements]


@attr.s(auto_attribs=True, frozen=True)
class PyExprStmt(PyStmt):
    expr: PyExpr

    def compile(self) -> CompiledOutput:
        if isinstance(self.expr, PyUnavailableExpr):
            return []
        return [f"{self.expr.compile()}"]
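# Usage sketch (illustrative; not in the original archive): the statement
# nodes above compile to lists of source lines.
#
#     >>> stmt = PyFunctionStmt(
#     ...     name="add",
#     ...     parameters=[PyParameter(name="x"), PyParameter(name="y")],
#     ...     body_statements=[],
#     ...     return_expr=PyBinaryExpr(
#     ...         lhs=PyIdentifierExpr(name="x"),
#     ...         operator="+",
#     ...         rhs=PyIdentifierExpr(name="y"),
#     ...     ),
#     ... )
#     >>> print("\n".join(stmt.compile()))
#     def add(x, y):
#         return x + y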
# ===== pytch/containers.py =====

from typing import (
    AbstractSet,
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    overload,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import pyrsistent as p
from pyrsistent import pmap, pset, pvector

Tk = TypeVar("Tk")
Tv = TypeVar("Tv")
Tv_in = TypeVar("Tv_in")
Tv_out = TypeVar("Tv_out")


class PSet(AbstractSet[Tk]):
    def __init__(self, iterable: Optional[Iterable[Tk]] = None) -> None:
        self._container: p.PSet[Tk] = pset(iterable or [])

    # TODO: tighten up `__contains__` to only accept `Tk`.
    def __contains__(self, key: object) -> bool:
        return key in self._container

    def __iter__(self) -> Iterator[Tk]:
        return iter(self._container)

    def __len__(self) -> int:
        return len(self._container)

    def __repr__(self) -> str:
        elements = ", ".join(repr(element) for element in self._container)
        return f"PSet([{elements}])"

    def add(self, key: Tk) -> "PSet[Tk]":
        return PSet(self._container.add(key))


class PVector(Sequence[Tv]):
    def __init__(self, iterable: Optional[Iterable[Tv]] = None) -> None:
        self._container: p.PVector = pvector(iterable or [])

    @overload
    def __getitem__(self, item: int) -> Tv:
        pass

    @overload  # noqa: F811
    def __getitem__(self, item: slice) -> Sequence[Tv]:
        pass

    def __getitem__(  # noqa: F811
        self, index: Union[int, slice]
    ) -> Union[Tv, Sequence[Tv]]:
        return self._container[index]

    def __len__(self) -> int:
        return len(self._container)

    def __repr__(self) -> str:
        elements = ", ".join(repr(element) for element in self._container)
        return f"PVector([{elements}])"

    def append(self, element: Tv) -> "PVector[Tv]":
        return PVector(self._container.append(element))

    def map(self, f: Callable[[Tv_in], Tv_out]) -> "PVector[Tv_out]":
        return PVector(self._container.transform(None, f))


class PMap(Mapping[Tk, Tv]):
    def __init__(self, mapping: Optional[Mapping[Tk, Tv]] = None) -> None:
        self._container: p.PMap[Tk, Tv] = pmap(mapping or {})

    @classmethod
    def of_entries(
        cls, iterable: Optional[Iterable[Tuple[Tk, Tv]]] = None
    ) -> "PMap[Tk, Tv]":
        mapping = dict(iterable or [])
        return cls(mapping)

    def __getitem__(self, index: Tk) -> Tv:
        return self._container[index]

    def __iter__(self) -> Iterator[Tk]:
        return iter(self._container)

    def __len__(self) -> int:
        return len(self._container)

    def __repr__(self) -> str:
        elements = ", ".join(f"{k!r}: {v!r}" for k, v in self._container.items())
        return f"PMap({{{elements}}})"

    def set(self, key: Tk, value: Tv) -> "PMap[Tk, Tv]":
        return PMap(self._container.set(key, value))

    def update(self, bindings: Mapping[Tk, Tv]) -> "PMap[Tk, Tv]":
        return PMap(self._container.update(bindings))


def find(iterable: Iterable[Tv], pred: Callable[[Tv], bool]) -> Optional[Tv]:
    for i in iterable:
        if pred(i):
            return i
    return None


def take_while(iterable: Iterable[Tv], pred: Callable[[Tv], bool]) -> Iterator[Tv]:
    for i in iterable:
        if pred(i):
            yield i
        else:
            return
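# Usage sketch (illustrative; not in the original archive): these wrappers
# are persistent, so mutating operations return new containers.
#
#     >>> v = PVector([1, 2])
#     >>> v.append(3)
#     PVector([1, 2, 3])
#     >>> v
#     PVector([1, 2])
#     >>> m = PMap({"a": 1})
#     >>> m2 = m.set("b", 2)
#     >>> (m2["b"], len(m), len(m2))
#     (2, 1, 2)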
# ===== pytch/cstquery.py =====

from typing import Iterable, Type, TypeVar, Union

from .redcst import Node, SyntaxTree, Token

T_node = TypeVar("T_node", bound=Node)


class Query:
    def __init__(self, syntax_tree: SyntaxTree) -> None:
        self._syntax_tree = syntax_tree

    def find_instances(self, node_type: Type[T_node]) -> Iterable[T_node]:
        for node in self._walk_all():
            if isinstance(node, node_type):
                yield node

    def _walk(self, node: Union[Node, Token]) -> Iterable[Union[Node, Token]]:
        yield node
        if isinstance(node, Node):
            for child in node.children:
                if child is not None:
                    yield from self._walk(child)

    def _walk_all(self) -> Iterable[Union[Node, Token]]:
        if self._syntax_tree.n_expr is not None:
            yield from self._walk(self._syntax_tree.n_expr)
        if self._syntax_tree.t_eof is not None:
            yield from self._walk(self._syntax_tree.t_eof)


# ===== pytch/errors.py =====

"""Error types and pretty-printing.

## Grammar

* Use the first-person: the compiler should be a tool that works with you,
  rather than against you. For example, see the Elm error messages:
  http://elm-lang.org/blog/compiler-errors-for-humans

* Don't be terse. Specify subjects and pronouns explicitly.

      BAD:  Missing ')'.
      GOOD: I was expecting a ')' here.

* Use the past progressive rather than the simple past tense. We avoid the
  present tense, to indicate that the compilation already happened -- it's
  not still in the process of happening. We avoid the simple past tense
  simply so as not to remind the user of terser compilers which use it.

      BAD:  Expected X.
      GOOD: I was expecting an X.

## Word choice

* Use articles where possible.

      BAD:  I was expecting ')' here.
      GOOD: I was expecting a ')' here.

* Prefer the term "binding" over the term "variable".

## Typography

* Use single-quotes (') instead of backticks (`) or fancy Unicode quotes.

      BAD:  I was expecting a `)` here.
      GOOD: I was expecting a ')' here.
"""
import collections
from enum import Enum
import itertools
import re
from typing import (
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import attr
import click
from typing_extensions import Protocol

from .utils import FileInfo, Range

T = TypeVar("T")


class ErrorCode(Enum):
    INVALID_TOKEN = 1000
    UNEXPECTED_TOKEN = 1001

    EXPECTED_EXPRESSION = 1010
    EXPECTED_LPAREN = 1011
    EXPECTED_RPAREN = 1012
    EXPECTED_PATTERN = 1013
    EXPECTED_EQUALS = 1014
    EXPECTED_DUMMY_IN = 1015
    EXPECTED_LET_EXPRESSION = 1016
    EXPECTED_COMMA = 1017
    EXPECTED_END_OF_ARGUMENT_LIST = 1018
    EXPECTED_END_OF_PARAMETER_LIST = 1019

    UNBOUND_NAME = 2000

    INCOMPATIBLE_TYPES = 3000
    EXPECTED_VOID = 3001
    CANNOT_BIND_TO_VOID = 3002

    PARSED_LENGTH_MISMATCH = 9000
    NOT_A_REAL_ERROR = 9001
    """Not a real error code, just for testing purposes."""
    SHOULD_END_WITH_EOF = 9002
    LET_IN_MISMATCH = 9003
    IF_ENDIF_MISMATCH = 9004


@attr.s(auto_attribs=True, frozen=True)
class Glyphs:
    """The set of glyphs to be used when printing out error messages."""

    make_colored: Callable[[str, str], str]
    make_bold: Callable[[str], str]
    make_inverted: Callable[[str], str]
    box_vertical: str
    box_horizontal: str
    box_upper_left: str
    box_upper_right: str
    box_lower_left: str
    box_lower_right: str
    box_continuation_left: str
    box_continuation_right: str
    underline_start_character: str
    underline_character: str
    underline_end_character: str
    underline_point_character: str
    vertical_colon: str


@attr.s(auto_attribs=True, frozen=True)
class OutputEnv:
    glyphs: Glyphs
    max_width: int


class Diagnostic(Protocol):
    @property
    def file_info(self) -> FileInfo:
        ...

    @property
    def color(self) -> str:
        ...

    @property
    def preamble_message(self) -> str:
        ...

    @property
    def message(self) -> str:
        ...

    @property
    def range(self) -> Optional[Range]:
        ...
@attr.s(auto_attribs=True, frozen=True)
class _DiagnosticContext:
    file_info: FileInfo
    line_ranges: Optional[List[Tuple[int, int]]]


@attr.s(auto_attribs=True, frozen=True)
class Note:
    color = "blue"
    preamble_message = "Note"

    file_info: FileInfo
    message: str
    range: Optional[Range] = attr.ib(default=None)


class Severity(Enum):
    ERROR = "error"
    WARNING = "warning"


@attr.s(auto_attribs=True, frozen=True)
class Error:
    file_info: FileInfo
    code: ErrorCode
    severity: Severity
    message: str
    notes: List[Note]
    range: Optional[Range] = attr.ib(default=None)

    @property
    def color(self) -> str:
        if self.severity == Severity.ERROR:
            return "red"
        elif self.severity == Severity.WARNING:
            return "yellow"
        else:
            assert False, f"Unhandled severity: {self.severity}"

    @property
    def preamble_message(self) -> str:
        return self.severity.value.title()


def get_full_diagnostic_message(diagnostic: Diagnostic) -> str:
    return f"{diagnostic.preamble_message}: {diagnostic.message}"


def get_glyphs(ascii: bool) -> Glyphs:
    if ascii:
        return Glyphs(
            make_colored=lambda text, color: text,
            make_bold=lambda text: text,
            make_inverted=lambda text: text,
            box_vertical="|",
            box_horizontal="-",
            box_upper_left="+",
            box_upper_right="+",
            box_lower_left="+",
            box_lower_right="+",
            box_continuation_left="+",
            box_continuation_right="+",
            underline_start_character="^",
            underline_character="~",
            underline_end_character="~",
            underline_point_character="^",
            vertical_colon=":",
        )
    else:
        return Glyphs(
            make_colored=lambda text, color: click.style(text, fg=color, bold=True),
            make_bold=lambda text: click.style(text, bold=True),
            make_inverted=lambda text: click.style(text, reverse=True),
            box_vertical="│",
            box_horizontal="─",
            box_upper_left="┌",
            box_upper_right="┐",
            box_lower_left="└",
            box_lower_right="┘",
            box_continuation_left="├",
            box_continuation_right="┤",
            underline_start_character="┕",
            underline_character="━",
            underline_end_character="┙",
            underline_point_character="↑",
            vertical_colon=":",  # TODO: use Unicode vertical colon
        )


def get_output_env(ascii: bool) -> OutputEnv:
    glyphs = get_glyphs(ascii=ascii)
    if ascii:
        max_width = 79
    else:
        (terminal_width, _terminal_height) = click.get_terminal_size()
        max_width = terminal_width - 1
    return OutputEnv(glyphs=glyphs, max_width=max_width)


@attr.s(auto_attribs=True, frozen=True)
class _MessageLine:
    text: str
    color: Optional[str]
    is_wrappable: bool

    def wrap(self, max_width: int) -> List[str]:
        match = re.match(r"\s*\S+", self.text)
        if not match:
            return [self.text]
        prefix = match.group()
        text = self.text[len(prefix):]
        wrapped_text = click.wrap_text(text, width=max_width, initial_indent=prefix)
        if wrapped_text:
            return wrapped_text.splitlines()
        else:
            return [prefix]

    def get_wrapped_width(self, max_width: int) -> int:
        return max(map(len, self.wrap(max_width)))


@attr.s(auto_attribs=True, frozen=True)
class Segment:
    """A box-enclosed segment of the error display.

    For example, the message

        +------------------+
        | Error: something |
        +------------------+

    or the code fragment

          +------------+
        1 | let foo =  |
        2 |   bar(baz) |
          +------------+

    constitute "segments".
    """
""" output_env: OutputEnv header: Optional[str] gutter_lines: Optional[List[str]] = attr.ib() message_lines: List[_MessageLine] @gutter_lines.validator def check(self, attribute, value) -> None: if self.gutter_lines is not None: assert len(self.gutter_lines) == len(self.message_lines) is_context_continuation: bool = attr.ib(default=False) """"Whether this segment is a vertical-colon-delimited continuation of the previous segment.""" @property def gutter_width(self) -> int: if not self.gutter_lines: return 0 num_padding_characters = len(" ") max_gutter_line_length = max(len(line) for line in self.gutter_lines) return num_padding_characters + max_gutter_line_length def get_box_width(self, gutter_width: int) -> int: num_box_characters = len("||") num_padding_characters = len(" ") max_message_line_length = max( line.get_wrapped_width( self.output_env.max_width - num_box_characters - num_padding_characters - gutter_width ) for line in self.message_lines ) if self.header is not None: max_message_line_length = max(max_message_line_length, len(self.header)) return max_message_line_length + num_box_characters + num_padding_characters def render_lines( self, is_first: bool, is_last: bool, gutter_width: int, box_width: int ) -> List[str]: if self.gutter_lines is None: gutter_lines = [""] * len(self.message_lines) else: gutter_lines = self.gutter_lines empty_gutter = " " * gutter_width lines = [] glyphs = self.output_env.glyphs # if self._is_context_continuation: # assert not is_first, ( # "The first context should not be a continuation context, " # + "since it has no previous context to continue." # ) # top_line = # else: # if self._is_context_continuation: # is_first = False top_line = "" if is_first: top_line += glyphs.box_upper_left else: top_line += glyphs.box_continuation_left top_line = top_line.ljust(box_width - 1, glyphs.box_horizontal) if is_first: top_line += glyphs.box_upper_right else: top_line += glyphs.box_continuation_right lines.append(empty_gutter + top_line) if self.header: header_line = (" " + self.header).ljust(box_width - 2) header_line = ( glyphs.box_vertical + glyphs.make_bold(header_line) + glyphs.box_vertical ) lines.append(empty_gutter + header_line) num_box_characters = len("||") num_padding_characters = len(" ") padding = num_box_characters + num_padding_characters for gutter_line, message_line in zip(gutter_lines, self.message_lines): if message_line.is_wrappable: wrapped_message_lines = message_line.wrap(box_width - padding) else: wrapped_message_lines = [message_line.text] for wrapped_message_line in wrapped_message_lines: gutter = gutter_line.rjust(gutter_width - num_padding_characters) gutter = " " + gutter + " " line_length = len(wrapped_message_line) if message_line.color is not None: wrapped_message_line = glyphs.make_colored( wrapped_message_line, message_line.color ) message = ( glyphs.box_vertical + " " + wrapped_message_line + (" " * (box_width - line_length - padding)) + " " + glyphs.box_vertical ) lines.append(gutter + message) if is_last: footer = "" footer += glyphs.box_lower_left footer = footer.ljust(box_width - 1, glyphs.box_horizontal) footer += glyphs.box_lower_right lines.append(empty_gutter + footer) return lines def get_error_lines(error: Error, ascii: bool = False) -> List[str]: output_env = get_output_env(ascii=ascii) glyphs = output_env.glyphs output_lines = [] if error.range is not None: line = str(error.range.start.line + 1) character = str(error.range.start.character + 1) output_lines.append( 
            glyphs.make_bold(f"{error.code.name}[{error.code.value}]")
            + f" in {glyphs.make_bold(error.file_info.file_path)}, "
            + f"line {glyphs.make_bold(line)}, "
            + f"character {glyphs.make_bold(character)}:"
        )
    else:
        output_lines.append(
            glyphs.make_bold(f"{error.code.name}[{error.code.value}] ")
            + f"in {glyphs.make_bold(error.file_info.file_path)}:"
        )
    output_lines.append(
        glyphs.make_colored(
            click.wrap_text(
                text=f"{error.preamble_message}: {error.message}",
                width=output_env.max_width,
            ),
            error.color,
        )
    )

    segments = get_error_segments(output_env=output_env, error=error)
    if segments:
        gutter_width = max(segment.gutter_width for segment in segments)
        box_width = max(segment.get_box_width(gutter_width) for segment in segments)
        for i, segment in enumerate(segments):
            is_first = i == 0
            is_last = i == len(segments) - 1
            output_lines.extend(
                segment.render_lines(
                    is_first=is_first,
                    is_last=is_last,
                    gutter_width=gutter_width,
                    box_width=box_width,
                )
            )
    return output_lines


def get_error_segments(output_env: OutputEnv, error: Error) -> List[Segment]:
    diagnostics: List[Diagnostic] = [error]
    diagnostics.extend(error.notes)
    diagnostic_contexts = [
        get_context(file_info=diagnostic.file_info, range=diagnostic.range)
        for diagnostic in diagnostics
    ]

    def key(
        context: _DiagnosticContext,
    ) -> List[Tuple[Union[int, float], Union[int, float]]]:
        if context.line_ranges is not None:
            return cast(
                List[Tuple[Union[int, float], Union[int, float]]], context.line_ranges
            )
        else:
            return [(float("inf"), float("inf"))]

    sorted_diagnostic_contexts = sorted(diagnostic_contexts, key=key)
    partitioned_diagnostic_contexts = itertools.groupby(
        sorted_diagnostic_contexts, lambda context: context.file_info.file_path
    )
    segments: List[Segment] = []
    for _file_path, contexts in partitioned_diagnostic_contexts:
        for context in _merge_contexts(list(contexts)):
            context_segments = get_context_segments(
                output_env=output_env, context=context, diagnostics=diagnostics
            )
            if context_segments is not None:
                segments.extend(context_segments)
    segments.extend(
        get_segments_without_ranges(output_env=output_env, diagnostics=diagnostics)
    )
    return segments


def get_context(file_info: FileInfo, range: Optional[Range]) -> _DiagnosticContext:
    """Get the diagnostic context, including the line before and after the
    given range.
    """
    if range is None:
        return _DiagnosticContext(file_info=file_info, line_ranges=None)

    start_line_index = max(0, range.start.line - 1)

    # The range is exclusive, but `range.end.line` is inclusive, so add 1.
    # Then add 1 again because we want to include the line after
    # `range.end.line`, if there is one.
    end_line_index = min(len(file_info.lines), range.end.line + 1 + 1)

    return _DiagnosticContext(
        file_info=file_info, line_ranges=[(start_line_index, end_line_index)]
    )
def _merge_contexts(contexts: List[_DiagnosticContext]) -> List[_DiagnosticContext]:
    """Combine adjacent contexts with ranges into a list of contexts sorted
    by range.

    For example, convert the list of contexts with ranges

        [(2, 4), (1, 3), None, (6, 8)]

    into the result

        [(1, 4), (6, 8), None]
    """
    file_info = contexts[0].file_info
    assert all(context.file_info == file_info for context in contexts)

    contexts_with_ranges = [
        context for context in contexts if context.line_ranges is not None
    ]
    contexts_without_ranges = [
        context for context in contexts if context.line_ranges is None
    ]
    line_ranges = sorted(
        line_range
        for context in contexts_with_ranges
        for line_range in context.line_ranges  # type: ignore
    )
    merged_line_ranges: List[Tuple[int, int]] = []
    for line_range in line_ranges:
        if not merged_line_ranges:
            merged_line_ranges.append(line_range)
        elif line_range[0] <= merged_line_ranges[-1][1]:
            # Extend the current range, keeping the later of the two
            # endpoints in case the new range is nested inside it.
            current_line_range = merged_line_ranges.pop()
            merged_line_ranges.append(
                (current_line_range[0], max(current_line_range[1], line_range[1]))
            )
        else:
            merged_line_ranges.append(line_range)
    merged_diagnostic_context = _DiagnosticContext(
        file_info=file_info, line_ranges=merged_line_ranges
    )
    return [merged_diagnostic_context] + contexts_without_ranges


def _ranges_overlap(
    lhs: Optional[Tuple[int, int]], rhs: Optional[Tuple[int, int]]
) -> bool:
    if lhs is None or rhs is None:
        return False
    assert lhs[0] <= lhs[1]
    assert rhs[0] <= rhs[1]
    return not (lhs[1] < rhs[0] or rhs[1] < lhs[0])


def _group_by_pred(seq: Iterable[T], pred: Callable[[T, T], bool]) -> Iterable[List[T]]:
    current_group: List[T] = []
    for i in seq:
        if current_group and not pred(current_group[-1], i):
            yield current_group
            current_group = []
        current_group.append(i)
    if current_group:
        yield current_group


def get_context_segments(
    output_env: OutputEnv, context: _DiagnosticContext, diagnostics: List[Diagnostic]
) -> Optional[List[Segment]]:
    diagnostics = [
        diagnostic
        for diagnostic in diagnostics
        if diagnostic.file_info == context.file_info
    ]
    diagnostic_lines_to_insert = _get_diagnostic_lines_to_insert(
        output_env=output_env, context=context, diagnostics=diagnostics
    )

    line_ranges = context.line_ranges
    if line_ranges is None:
        return None

    segments = []
    is_first = True
    for line_range in line_ranges:
        gutter_lines = []
        message_lines = []
        (start_line, end_line) = line_range
        lines = context.file_info.lines[start_line:end_line]
        for line_num, line in enumerate(lines, start_line):
            # 1-index the line number for display.
            gutter_lines.append(str(line_num + 1))
            message_lines.append(
                _MessageLine(
                    text=line,
                    color=None,
                    # Code segment -- print this verbatim, do not wrap.
                    is_wrappable=False,
                )
            )
            diagnostic_lines = diagnostic_lines_to_insert.get(line_num, [])
            for diagnostic_line in diagnostic_lines:
                gutter_lines.append("")
                message_lines.append(diagnostic_line)
        if not is_first:
            header = None
        else:
            header = context.file_info.file_path
        segments.append(
            Segment(
                output_env=output_env,
                header=header,
                gutter_lines=gutter_lines,
                message_lines=message_lines,
                is_context_continuation=(not is_first),
            )
        )
        is_first = False
    return segments
def get_segments_without_ranges(
    output_env: OutputEnv, diagnostics: List[Diagnostic]
) -> List[Segment]:
    segments = []
    for diagnostic in diagnostics:
        if diagnostic.range is None:
            segments.append(
                Segment(
                    output_env=output_env,
                    header=None,
                    gutter_lines=[""],
                    message_lines=[
                        _MessageLine(
                            text=get_full_diagnostic_message(diagnostic),
                            color=diagnostic.color,
                            # Diagnostic message, wrap this if necessary.
                            is_wrappable=True,
                        )
                    ],
                )
            )
    return segments


def _get_diagnostic_lines_to_insert(
    output_env: OutputEnv,
    context: _DiagnosticContext,
    diagnostics: Sequence[Diagnostic],
) -> Mapping[int, Sequence[_MessageLine]]:
    result: Dict[int, List[_MessageLine]] = collections.defaultdict(list)
    if context.line_ranges is None:
        return result
    for line_range in context.line_ranges:
        context_lines = context.file_info.lines[line_range[0]:line_range[1]]
        for diagnostic in diagnostics:
            diagnostic_range = diagnostic.range
            if diagnostic_range is None:
                continue
            underlined_lines = underline_lines(
                output_env=output_env,
                start_line_index=line_range[0],
                context_lines=context_lines,
                underline_range=diagnostic_range,
                underline_color=diagnostic.color,
            )
            if underlined_lines:
                last_line = underlined_lines.pop().text
                last_line += " " + get_full_diagnostic_message(diagnostic)
                underlined_lines.append(
                    _MessageLine(
                        text=last_line,
                        color=diagnostic.color,
                        # Diagnostic message, wrap this if necessary.
                        is_wrappable=True,
                    )
                )
            for line_num, line in enumerate(
                underlined_lines, diagnostic_range.start.line
            ):
                result[line_num].append(line)
    return result


def underline_lines(
    output_env: OutputEnv,
    start_line_index: int,
    context_lines: List[str],
    underline_range: Range,
    underline_color: str,
) -> List[_MessageLine]:
    start_position = underline_range.start
    end_position = underline_range.end
    message_lines = []
    for line_num, line in enumerate(context_lines, start=start_line_index):
        underline_start: Optional[int] = None
        has_underline_start = False
        if line_num == start_position.line:
            underline_start = start_position.character
            has_underline_start = True
        elif start_position.line <= line_num <= end_position.line:
            non_whitespace_characters = [
                i for i, c in enumerate(line) if not c.isspace()
            ]
            if non_whitespace_characters:
                underline_start = non_whitespace_characters[0]

        underline_end: Optional[int] = None
        has_underline_end = False
        if line_num == end_position.line:
            underline_end = end_position.character
            has_underline_end = True
        elif start_position.line <= line_num <= end_position.line:
            if underline_start is not None:
                underline_end = len(line)

        if underline_start is not None and underline_end is not None:
            underline_line = " " * underline_start
            underline_width = underline_end - underline_start
            if underline_width == 0:
                # In the event that we have a zero-length range, we want to
                # render it as an underline of width one. This could happen
                # if we're flagging the EOF token, for example.
                underline_width = 1
            assert underline_width > 0, (
                f"The index of the end of the underline ({underline_end}) on "
                + f"line #{line_num} was before the index of the first "
                + f"non-whitespace character ({underline_start}) on this "
                + "line. It's unclear how this should be rendered. This may "
                + "be a bug in the caller, or it's possible that the "
                + "rendering logic should be changed to handle this case."
            )

            glyphs = output_env.glyphs
            if underline_width == 1:
                if has_underline_start and has_underline_end:
                    underline = glyphs.underline_point_character
                elif has_underline_start:
                    underline = glyphs.underline_start_character
                elif has_underline_end:
                    underline = glyphs.underline_end_character
                else:
                    underline = glyphs.underline_character
            else:
                underline = glyphs.underline_character * (
                    underline_end - underline_start - 2
                )
                if has_underline_start:
                    underline = glyphs.underline_start_character + underline
                else:
                    underline = glyphs.underline_character + underline
                if has_underline_end:
                    underline = underline + glyphs.underline_end_character
                else:
                    underline = underline + glyphs.underline_character
            underline_line += underline
            message_lines.append(
                _MessageLine(
                    text=underline_line,
                    color=underline_color,
                    # Underline, do not wrap (although it should never
                    # require wrapping, since the code should not be
                    # wrapped).
                    is_wrappable=False,
                )
            )
    return message_lines
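# Usage sketch (illustrative; not in the original archive): rendering a
# diagnostic with the ASCII glyph set. Offsets 10-13 cover "bar" in the
# source below.
#
#     from pytch.errors import Error, ErrorCode, Severity, get_error_lines
#     from pytch.utils import FileInfo, OffsetRange
#
#     file_info = FileInfo(file_path="example.pytch", source_code="let foo = bar\n")
#     error = Error(
#         file_info=file_info,
#         code=ErrorCode.UNBOUND_NAME,
#         severity=Severity.ERROR,
#         message="I couldn't find a binding in the current scope "
#         + "with the name 'bar'.",
#         notes=[],
#         range=file_info.get_range_from_offset_range(OffsetRange(start=10, end=13)),
#     )
#     print("\n".join(get_error_lines(error, ascii=True)))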
# ===== pytch/fuzz.py =====

import os
import sys

import afl

from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo


def check_for_buggy_parse(file_info: FileInfo) -> None:
    lexation = lex(file_info=file_info)
    parsation = parse(file_info=file_info, tokens=lexation.tokens)
    if parsation.is_buggy:
        raise ValueError("found buggy parse")


def main() -> None:
    afl.init()
    with open(sys.argv[1]) as f:
        # afl-fuzz will often generate invalid Unicode and count that as a
        # crash. See
        # https://barro.github.io/2018/01/taking-a-look-at-python-afl/
        try:
            file_info = FileInfo(file_path="", source_code=f.read())
        except UnicodeDecodeError:
            pass
        else:
            check_for_buggy_parse(file_info)
    os._exit(0)


if __name__ == "__main__":
    main()
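# Usage sketch (illustrative; not in the original archive; the exact flags
# depend on your afl/python-afl installation):
#
#     $ py-afl-fuzz -i fuzz-inputs/ -o fuzz-results/ -- python -m pytch.fuzz @@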
""" from typing import List, Optional, Sequence, Union from .lexer import Token class Node: def __init__(self, children: Sequence[Union["Node", Optional["Token"]]]) -> None: self._children = children @property def children(self) -> Sequence[Union["Node", Optional["Token"]]]: return self._children @property def leading_text(self) -> str: first_child = self.first_present_child if first_child is None: return "" else: return first_child.leading_text @property def text(self) -> str: if len(self._children) == 0: return "" elif len(self._children) == 1: child = self._children[0] if child is None: return "" else: return child.text else: text = "" [first, *middle, last] = self._children if first is not None: text += first.text + first.trailing_text for child in middle: if child is not None: text += child.full_text if last is not None: text += last.leading_text + last.text return text @property def trailing_text(self) -> str: last_child = self.last_present_child if last_child is None: return "" else: return last_child.trailing_text @property def full_text(self) -> str: return "".join(child.full_text for child in self._children if child is not None) @property def first_present_child(self) -> Optional[Union["Node", "Token"]]: for child in self.children: if child is None: continue if isinstance(child, Token): if not child.is_dummy: return child else: child_first_present_child = child.first_present_child if child_first_present_child is not None: return child_first_present_child return None @property def last_present_child(self) -> Optional[Union["Node", "Token"]]: for child in reversed(self.children): if child is None: continue if isinstance(child, Token): if not child.is_dummy: return child else: child_last_present_child = child.last_present_child if child_last_present_child is not None: return child_last_present_child return None @property def leading_width(self) -> int: child = self.first_present_child if child is None: return 0 return child.leading_width @property def trailing_width(self) -> int: child = self.last_present_child if child is None: return 0 return child.trailing_width @property def width(self) -> int: if not self.children: return 0 return self.full_width - self.leading_width - self.trailing_width @property def full_width(self) -> int: return sum( child.full_width if child is not None else 0 for child in self.children ) class Expr(Node): pass class SyntaxTree(Node): def __init__(self, n_expr: Optional[Expr], t_eof: Optional[Token]) -> None: super().__init__([n_expr, t_eof]) self._n_expr = n_expr self._t_eof = t_eof @property def n_expr(self) -> Optional[Expr]: return self._n_expr @property def t_eof(self) -> Optional[Token]: return self._t_eof class Pattern(Node): pass class VariablePattern(Pattern): def __init__(self, t_identifier: Optional[Token]) -> None: super().__init__([t_identifier]) self._t_identifier = t_identifier @property def t_identifier(self) -> Optional[Token]: return self._t_identifier class Parameter(Node): def __init__(self, n_pattern: Optional[Pattern], t_comma: Optional[Token]) -> None: super().__init__([n_pattern, t_comma]) self._n_pattern = n_pattern self._t_comma = t_comma @property def n_pattern(self) -> Optional[Pattern]: return self._n_pattern @property def t_comma(self) -> Optional[Token]: return self._t_comma class ParameterList(Node): def __init__( self, t_lparen: Optional[Token], parameters: Optional[List[Parameter]], t_rparen: Optional[Token], ) -> None: super().__init__( [t_lparen, *(parameters if parameters is not None else []), t_rparen] ) 
class ParameterList(Node):
    def __init__(
        self,
        t_lparen: Optional[Token],
        parameters: Optional[List[Parameter]],
        t_rparen: Optional[Token],
    ) -> None:
        super().__init__(
            [t_lparen, *(parameters if parameters is not None else []), t_rparen]
        )
        self._t_lparen = t_lparen
        self._parameters = parameters
        self._t_rparen = t_rparen

    @property
    def t_lparen(self) -> Optional[Token]:
        return self._t_lparen

    @property
    def parameters(self) -> Optional[List[Parameter]]:
        return self._parameters

    @property
    def t_rparen(self) -> Optional[Token]:
        return self._t_rparen


class LetExpr(Expr):
    def __init__(
        self,
        t_let: Optional[Token],
        n_pattern: Optional[Pattern],
        n_parameter_list: Optional[ParameterList],
        t_equals: Optional[Token],
        n_value: Optional[Expr],
        t_in: Optional[Token],
        n_body: Optional[Expr],
    ) -> None:
        super().__init__(
            [t_let, n_pattern, n_parameter_list, t_equals, n_value, t_in, n_body]
        )
        self._t_let = t_let
        self._n_pattern = n_pattern
        self._n_parameter_list = n_parameter_list
        self._t_equals = t_equals
        self._n_value = n_value
        self._t_in = t_in
        self._n_body = n_body

    @property
    def t_let(self) -> Optional[Token]:
        return self._t_let

    @property
    def n_pattern(self) -> Optional[Pattern]:
        return self._n_pattern

    @property
    def n_parameter_list(self) -> Optional[ParameterList]:
        return self._n_parameter_list

    @property
    def t_equals(self) -> Optional[Token]:
        return self._t_equals

    @property
    def n_value(self) -> Optional[Expr]:
        return self._n_value

    @property
    def t_in(self) -> Optional[Token]:
        return self._t_in

    @property
    def n_body(self) -> Optional[Expr]:
        return self._n_body


class IfExpr(Expr):
    def __init__(
        self,
        t_if: Optional[Token],
        n_if_expr: Optional[Expr],
        t_then: Optional[Token],
        n_then_expr: Optional[Expr],
        t_else: Optional[Token],
        n_else_expr: Optional[Expr],
        t_endif: Optional[Token],
    ) -> None:
        super().__init__(
            [t_if, n_if_expr, t_then, n_then_expr, t_else, n_else_expr, t_endif]
        )
        self._t_if = t_if
        self._n_if_expr = n_if_expr
        self._t_then = t_then
        self._n_then_expr = n_then_expr
        self._t_else = t_else
        self._n_else_expr = n_else_expr
        self._t_endif = t_endif

    @property
    def t_if(self) -> Optional[Token]:
        return self._t_if

    @property
    def n_if_expr(self) -> Optional[Expr]:
        return self._n_if_expr

    @property
    def t_then(self) -> Optional[Token]:
        return self._t_then

    @property
    def n_then_expr(self) -> Optional[Expr]:
        return self._n_then_expr

    @property
    def t_else(self) -> Optional[Token]:
        return self._t_else

    @property
    def n_else_expr(self) -> Optional[Expr]:
        return self._n_else_expr

    @property
    def t_endif(self) -> Optional[Token]:
        return self._t_endif


class IdentifierExpr(Expr):
    def __init__(self, t_identifier: Optional[Token]) -> None:
        super().__init__([t_identifier])
        self._t_identifier = t_identifier

    @property
    def t_identifier(self) -> Optional[Token]:
        return self._t_identifier


class IntLiteralExpr(Expr):
    def __init__(self, t_int_literal: Optional[Token]) -> None:
        super().__init__([t_int_literal])
        self._t_int_literal = t_int_literal

    @property
    def t_int_literal(self) -> Optional[Token]:
        return self._t_int_literal


class BinaryExpr(Expr):
    def __init__(
        self, n_lhs: Optional[Expr], t_operator: Optional[Token], n_rhs: Optional[Expr]
    ) -> None:
        super().__init__([n_lhs, t_operator, n_rhs])
        self._n_lhs = n_lhs
        self._t_operator = t_operator
        self._n_rhs = n_rhs

    @property
    def n_lhs(self) -> Optional[Expr]:
        return self._n_lhs

    @property
    def t_operator(self) -> Optional[Token]:
        return self._t_operator

    @property
    def n_rhs(self) -> Optional[Expr]:
        return self._n_rhs
class Argument(Node):
    def __init__(self, n_expr: Optional[Expr], t_comma: Optional[Token]) -> None:
        super().__init__([n_expr, t_comma])
        self._n_expr = n_expr
        self._t_comma = t_comma

    @property
    def n_expr(self) -> Optional[Expr]:
        return self._n_expr

    @property
    def t_comma(self) -> Optional[Token]:
        return self._t_comma


class ArgumentList(Node):
    def __init__(
        self,
        t_lparen: Optional[Token],
        arguments: Optional[List[Argument]],
        t_rparen: Optional[Token],
    ) -> None:
        super().__init__(
            [t_lparen, *(arguments if arguments is not None else []), t_rparen]
        )
        self._t_lparen = t_lparen
        self._arguments = arguments
        self._t_rparen = t_rparen

    @property
    def t_lparen(self) -> Optional[Token]:
        return self._t_lparen

    @property
    def arguments(self) -> Optional[List[Argument]]:
        return self._arguments

    @property
    def t_rparen(self) -> Optional[Token]:
        return self._t_rparen


class FunctionCallExpr(Expr):
    def __init__(
        self, n_callee: Optional[Expr], n_argument_list: Optional[ArgumentList]
    ) -> None:
        super().__init__([n_callee, n_argument_list])
        self._n_callee = n_callee
        self._n_argument_list = n_argument_list

    @property
    def n_callee(self) -> Optional[Expr]:
        return self._n_callee

    @property
    def n_argument_list(self) -> Optional[ArgumentList]:
        return self._n_argument_list
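# Usage sketch (illustrative; not in the original archive): green nodes
# compute their text and widths from their children.
#
#     >>> from pytch.lexer import Token, TokenKind
#     >>> from pytch.greencst import IdentifierExpr
#     >>> token = Token(
#     ...     kind=TokenKind.IDENTIFIER,
#     ...     text="foo",
#     ...     leading_trivia=[],
#     ...     trailing_trivia=[],
#     ... )
#     >>> expr = IdentifierExpr(t_identifier=token)
#     >>> (expr.text, expr.full_width)
#     ('foo', 3)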
] = { TokenKind.PLUS: (4, Associativity.LEFT), TokenKind.MINUS: (4, Associativity.LEFT), TokenKind.AND: (3, Associativity.LEFT), TokenKind.OR: (2, Associativity.LEFT), TokenKind.DUMMY_SEMICOLON: (1, Associativity.RIGHT), } BINARY_OPERATOR_PRECEDENCES = set( precedence for precedence, associativity in BINARY_OPERATORS.values() ) assert all(precedence > 0 for precedence in BINARY_OPERATOR_PRECEDENCES) assert BINARY_OPERATOR_PRECEDENCES == set( range(min(BINARY_OPERATOR_PRECEDENCES), max(BINARY_OPERATOR_PRECEDENCES) + 1) ) BINARY_OPERATOR_KINDS = list(BINARY_OPERATORS.keys()) @attr.s(auto_attribs=True, frozen=True) class Trivium: kind: TriviumKind text: str @property def width(self) -> int: return len(self.text) @attr.s(auto_attribs=True, frozen=True) class Token: kind: TokenKind text: str leading_trivia: List[Trivium] trailing_trivia: List[Trivium] def update(self, **kwargs) -> "Token": return attr.evolve(self, **kwargs) @property def is_dummy(self): return self.kind == TokenKind.EOF or self.kind.name.lower().startswith("dummy") @property def full_text(self) -> str: return self.leading_text + self.text + self.trailing_text @property def width(self) -> int: return len(self.text) @property def full_width(self) -> int: """The width of the token, including leading and trailing trivia.""" return self.leading_width + self.width + self.trailing_width @property def leading_width(self) -> int: return sum(trivium.width for trivium in self.leading_trivia) @property def leading_text(self) -> str: return "".join(trivium.text for trivium in self.leading_trivia) @property def trailing_width(self) -> int: return sum(trivium.width for trivium in self.trailing_trivia) @property def trailing_text(self) -> str: return "".join(trivium.text for trivium in self.trailing_trivia) @property def is_followed_by_newline(self) -> bool: return any( trivium.kind == TriviumKind.NEWLINE for trivium in self.trailing_trivia ) @attr.s(auto_attribs=True, frozen=True) class State: file_info: FileInfo offset: int def update(self, **kwargs): return attr.evolve(self, **kwargs) def advance_offset(self, offset_delta: int) -> "State": assert offset_delta >= 0 return self.update(offset=self.offset + offset_delta) @attr.s(auto_attribs=True, frozen=True) class Lexation: tokens: List[Token] errors: List[Error] @property def full_width(self) -> int: return sum(token.full_width for token in self.tokens) WHITESPACE_RE = re.compile(r"[ \t]+") NEWLINE_RE = re.compile(r"\n") IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*") INT_LITERAL_RE = re.compile(r"[0-9]+") EQUALS_RE = re.compile(r"=") LET_RE = re.compile(r"let") COMMA_RE = re.compile(r",") LPAREN_RE = re.compile(r"\(") RPAREN_RE = re.compile(r"\)") IF_RE = re.compile(r"if") THEN_RE = re.compile(r"then") ELSE_RE = re.compile(r"else") PLUS_RE = re.compile(r"\+") MINUS_RE = re.compile(r"-") OR_RE = re.compile(r"or") AND_RE = re.compile(r"and") UNKNOWN_TOKEN_RE = re.compile(r"[^ \n\t\ra-zA-Z0-9]+") class Lexer: def lex(self, file_info: FileInfo) -> Lexation: state = State(file_info=file_info, offset=0) errors = [] tokens = [] while True: last_offset = state.offset (state, token) = self.lex_token(state) tokens.append(token) if token.kind == TokenKind.ERROR: errors.append( Error( file_info=file_info, code=ErrorCode.INVALID_TOKEN, severity=Severity.ERROR, message=f"Invalid token '{token.text}'.", notes=[], range=file_info.get_range_from_offset_range( OffsetRange( start=state.offset - token.trailing_width - token.width, end=state.offset - token.trailing_width, ) ), ) ) if token.kind == 
TokenKind.EOF: break assert state.offset >= last_offset, "No progress made in lexing" return Lexation(tokens=tokens, errors=errors) def lex_leading_trivia(self, state: State) -> Tuple[State, List[Trivium]]: leading_trivia = self.lex_next_trivia_by_patterns( state, {TriviumKind.WHITESPACE: WHITESPACE_RE} ) state = state.advance_offset(sum(trivium.width for trivium in leading_trivia)) return (state, leading_trivia) def lex_trailing_trivia(self, state: State) -> Tuple[State, List[Trivium]]: trailing_trivia = self.lex_next_trivia_by_patterns( state, {TriviumKind.WHITESPACE: WHITESPACE_RE, TriviumKind.NEWLINE: NEWLINE_RE}, ) newline_indices = [ i for (i, trivium) in enumerate(trailing_trivia) if trivium.kind == TriviumKind.NEWLINE ] if newline_indices: last_newline_index = newline_indices[-1] + 1 else: last_newline_index = 0 # Avoid consuming whitespace or other trivia, after the last # newline. We'll consume that as the leading trivia of the next # token. trailing_trivia = trailing_trivia[:last_newline_index] state = state.advance_offset(sum(trivium.width for trivium in trailing_trivia)) return (state, trailing_trivia) def lex_next_trivia_by_patterns( self, state: State, trivia_patterns: Mapping[TriviumKind, Pattern] ) -> List[Trivium]: trivia: List[Trivium] = [] offset = state.offset while True: matches = [ (trivium_kind, regex.match(state.file_info.source_code, pos=offset)) for trivium_kind, regex in trivia_patterns.items() ] filtered_matches = [ (trivium_kind, match) for trivium_kind, match in matches if match is not None ] if not filtered_matches: return trivia assert ( len(filtered_matches) == 1 ), "More than one possible type of trivia found" trivium_kind, match = filtered_matches[0] trivium = Trivium(kind=trivium_kind, text=match.group()) trivia.append(trivium) offset += trivium.width def lex_token(self, state: State) -> Tuple[State, Token]: (state, leading_trivia) = self.lex_leading_trivia(state) token_info = None if token_info is None: (maybe_state, token_info) = self.lex_next_token_by_patterns( state, { TokenKind.INT_LITERAL: INT_LITERAL_RE, TokenKind.EQUALS: EQUALS_RE, TokenKind.LET: LET_RE, TokenKind.COMMA: COMMA_RE, TokenKind.LPAREN: LPAREN_RE, TokenKind.RPAREN: RPAREN_RE, TokenKind.IF: IF_RE, TokenKind.THEN: THEN_RE, TokenKind.ELSE: ELSE_RE, TokenKind.PLUS: PLUS_RE, TokenKind.MINUS: MINUS_RE, TokenKind.OR: OR_RE, TokenKind.AND: AND_RE, TokenKind.IDENTIFIER: IDENTIFIER_RE, }, ) if token_info is not None: state = maybe_state if token_info is None: (maybe_state, token_info) = self.lex_next_token_by_patterns( state, {TokenKind.ERROR: UNKNOWN_TOKEN_RE} ) if token_info is not None: state = maybe_state if token_info is None: # We can't find any match at all? Then there must be only # trivia remaining in the stream, so just produce the EOF # token. 
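A minimal usage sketch of the trivia rule described in the lexer's module docstring (an editorial example, not part of the original source; it assumes the module-level `lex` entry point defined later in this file, whose output also includes the pre-parser's dummy tokens):

from pytch.lexer import lex
from pytch.utils import FileInfo

# "<sketch>" is a placeholder file path for illustration only.
lexation = lex(file_info=FileInfo(file_path="<sketch>", source_code="let foo = 1\n"))
for token in lexation.tokens:
    print(
        token.kind.name,
        repr(token.text),
        [trivium.text for trivium in token.leading_trivia],
        [trivium.text for trivium in token.trailing_trivia],
    )
# Expected: 'foo', '=', and '1' each carry one leading space; '1' carries
# the trailing "\n"; the DUMMY_IN and EOF tokens are zero-width.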
token_info = (TokenKind.EOF, "") (state, trailing_trivia) = self.lex_trailing_trivia(state) (token_kind, token_text) = token_info return ( state, Token( kind=token_kind, text=token_text, leading_trivia=leading_trivia, trailing_trivia=trailing_trivia, ), ) def lex_next_token_by_patterns( self, state: State, token_patterns: Mapping[TokenKind, Pattern] ) -> Tuple[State, Optional[Tuple[TokenKind, str]]]: matches = [ (token_kind, regex.match(state.file_info.source_code, pos=state.offset)) for token_kind, regex in token_patterns.items() ] filtered_matches = [ (token_kind, match) for token_kind, match in matches if match is not None ] if not filtered_matches: return (state, None) (kind, match) = max(filtered_matches, key=lambda x: len(x[1].group())) token_text = match.group() state = state.advance_offset(len(token_text)) return (state, (kind, token_text)) def with_indentation_levels(tokens: Iterable[Token],) -> Iterator[Tuple[int, Token]]: indentation_level = 0 is_first_token_on_line = True for token in tokens: if is_first_token_on_line: indentation_level = token.leading_width is_first_token_on_line = False if token.is_followed_by_newline: is_first_token_on_line = True yield (indentation_level, token) def make_dummy_token(kind: TokenKind) -> Token: token = Token(kind=kind, text="", leading_trivia=[], trailing_trivia=[]) assert token.is_dummy return token def preparse(tokens: Iterable[Token]) -> Iterator[Token]: """Insert dummy tokens for lightweight constructs into the token stream. This technique is based off of the "pre-parsing" step as outlined in the F# 4.0 spec, section 15: Lightweight Syntax: http://fsharp.org/specs/language-spec/4.0/FSharpSpec-4.0-latest.pdf The pre-parser inserts dummy tokens into the token stream where we would expect the token to go in the non-lightweight token stream. For example, it might convert this: let foo = 1 foo into this: let foo = 1 $in foo We do the same thing, although with significantly fewer restrictions on the source code's indentation. """ stack: List[Tuple[int, int, Token]] = [] def unwind( indentation_level: int, unwind_statements: bool, kind: TokenKind = None, kind_indentation_level: int = None, ) -> Iterator[Token]: while stack: (top_indentation_level, top_line, top_token) = stack[-1] stack.pop() # If we're unwinding to a specific token kind, only stop once we've # reached that token kind. if kind is not None and top_token.kind == kind: if ( kind_indentation_level is None or top_indentation_level <= kind_indentation_level ): return can_be_followed_by_new_statement = True if top_token.kind == TokenKind.LET: # If we see something of the form # # ``` # let foo = bar # baz # ``` # # then no matter what, we will treat the following `baz` as the # `let` body, not a new statement. 
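A sketch of the pre-parser's observable effect (editorial example, not in the original source): lexing a two-line program yields an inferred DUMMY_IN token closing the let-binding, mirroring the `$in` example in the docstring above.

from pytch.lexer import lex
from pytch.utils import FileInfo

source_code = "let foo = 1\nfoo\n"
tokens = lex(file_info=FileInfo(file_path="<sketch>", source_code=source_code)).tokens
print([token.kind.name for token in tokens])
# Expected: ['LET', 'IDENTIFIER', 'EQUALS', 'INT_LITERAL', 'DUMMY_IN',
#            'IDENTIFIER', 'EOF']. Dummy tokens are zero-width, so the
# token widths still sum to the source length.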
can_be_followed_by_new_statement = False yield make_dummy_token(TokenKind.DUMMY_IN) elif ( top_token.kind == TokenKind.IF or top_token.kind == TokenKind.THEN or top_token.kind == TokenKind.ELSE ): yield make_dummy_token(TokenKind.DUMMY_ENDIF) if ( unwind_statements and can_be_followed_by_new_statement and indentation_level == top_indentation_level ): yield make_dummy_token(TokenKind.DUMMY_SEMICOLON) if kind is None and top_indentation_level <= indentation_level: return is_first_token = True current_line = 0 eof_token = None previous_line = None previous_token = None for indentation_level, token in with_indentation_levels(tokens): if token.kind == TokenKind.EOF: eof_token = token break if stack: (previous_indentation_level, _, _) = stack[-1] else: previous_indentation_level = 0 maybe_expr_continuation = True maybe_new_statement = False if previous_line is not None: assert previous_line <= current_line if current_line > previous_line: maybe_new_statement = True if indentation_level <= previous_indentation_level: maybe_expr_continuation = False is_part_of_binary_expr = token.kind in BINARY_OPERATOR_KINDS or ( previous_token is not None and previous_token.kind in BINARY_OPERATOR_KINDS ) has_comma = token.kind == TokenKind.COMMA or ( previous_token is not None and previous_token.kind == TokenKind.COMMA ) if token.kind == TokenKind.LPAREN: # Pass `0` as the indentation level to reset the indentation level # in the stack until we've exited the parenthesized tokens. stack.append((0, current_line, token)) elif token.kind == TokenKind.RPAREN: yield from unwind( indentation_level, unwind_statements=False, kind=TokenKind.LPAREN ) elif token.kind == TokenKind.LET: if not maybe_expr_continuation: yield from unwind(indentation_level, unwind_statements=False) stack.append((indentation_level, current_line, token)) elif token.kind == TokenKind.IF: if not maybe_expr_continuation: yield from unwind(indentation_level, unwind_statements=True) stack.append((indentation_level, current_line, token)) elif token.kind == TokenKind.THEN: yield from unwind( indentation_level, unwind_statements=False, kind=TokenKind.IF ) stack.append((indentation_level, current_line, token)) elif token.kind == TokenKind.ELSE: yield from unwind( indentation_level, unwind_statements=False, kind=TokenKind.THEN, kind_indentation_level=indentation_level, ) stack.append((indentation_level, current_line, token)) elif maybe_new_statement and not is_part_of_binary_expr and not has_comma: if indentation_level <= previous_indentation_level: yield from unwind(indentation_level, unwind_statements=True) stack.append((indentation_level, current_line, token)) elif is_first_token: stack.append((indentation_level, current_line, token)) yield token is_first_token = False previous_line = current_line previous_token = token current_line += sum( len(trivium.text) for trivium in token.trailing_trivia if trivium.kind == TriviumKind.NEWLINE ) yield from unwind(indentation_level=-1, unwind_statements=False) assert eof_token is not None yield eof_token def lex(file_info: FileInfo) -> Lexation: lexer = Lexer() lexation = lexer.lex(file_info=file_info) tokens = list(preparse(lexation.tokens)) errors = lexation.errors source_code_length = len(file_info.source_code) tokens_length = sum(token.full_width for token in lexation.tokens) if source_code_length != tokens_length: errors.append( Error( file_info=file_info, code=ErrorCode.PARSED_LENGTH_MISMATCH, severity=Severity.WARNING, message=( f"Mismatch between source code length ({source_code_length}) " + f"and total 
length of parsed tokens ({tokens_length}). " + f"The parse tree for this file is probably incorrect. " + f"This is a bug. Please report it!" ), notes=[], ) ) num_lets = 0 num_ins = 0 for token in tokens: if token.kind == TokenKind.LET: num_lets += 1 elif token.kind == TokenKind.DUMMY_IN: num_ins += 1 if num_lets != num_ins: errors.append( Error( file_info=file_info, code=ErrorCode.LET_IN_MISMATCH, severity=Severity.WARNING, message=( f"Mismatch between the number of 'let' bindings ({num_lets}) " + f"and the number of inferred ends " + f"of these 'let' bindings ({num_ins}). " + f"The parse tree for this file is probably incorrect. " + f"This is a bug. Please report it!" ), notes=[], ) ) num_ifs = 0 num_endifs = 0 for token in tokens: if token.kind == TokenKind.IF: num_ifs += 1 elif token.kind == TokenKind.DUMMY_ENDIF: num_endifs += 1 if num_ifs != num_endifs: errors.append( Error( file_info=file_info, code=ErrorCode.IF_ENDIF_MISMATCH, severity=Severity.WARNING, message=( f"Mismatch between the number of 'if' expressions ({num_ifs}) " + f"and the number of inferred ends " + f"of these 'if' expressions ({num_endifs}). " + f"The parse tree for this file is probably incorrect. " + f"This is a bug. Please report it!" ), notes=[], ) ) return Lexation(tokens=tokens, errors=errors) PK!%dAApytch/parser.py"""Parses a series of tokens into a concrete syntax tree (CST). The concrete syntax tree is not quite an abstract syntax tree: the tokens contained therein are enough to reconstitute the source code. The non-meaningful parts of the program are contained within "trivia" nodes. See the lexer for more information. The *green* CST is considered to be immutable and must not be modified. The *red* CST is based off of the green syntax tree. It is also immutable, but its nodes are generated lazily (since they contain `parent` pointers and therefore reference cycles). """ from typing import Iterator, List, Optional, Tuple, Union import attr from .errors import Error, ErrorCode, Note, Severity from .greencst import ( Argument, ArgumentList, BinaryExpr, Expr, FunctionCallExpr, IdentifierExpr, IfExpr, IntLiteralExpr, LetExpr, Node, Parameter, ParameterList, Pattern, SyntaxTree, VariablePattern, ) from .lexer import ( Associativity, BINARY_OPERATOR_KINDS, BINARY_OPERATORS, Token, TokenKind, Trivium, TriviumKind, ) from .utils import FileInfo, OffsetRange, Range def walk_tokens(node: Node) -> Iterator[Token]: for child in node.children: if child is None: continue if isinstance(child, Token): yield child elif isinstance(child, Node): yield from walk_tokens(child) else: assert False, f"Unexpected node child type: {child!r}" @attr.s(auto_attribs=True, frozen=True) class Parsation: green_cst: SyntaxTree errors: List[Error] @property def is_buggy(self) -> bool: """Return whether the parse tree violates any known invariants.""" assert ErrorCode.PARSED_LENGTH_MISMATCH.value == 9000 return any( error.code.value >= ErrorCode.PARSED_LENGTH_MISMATCH.value for error in self.errors ) class ParseException(Exception): def __init__(self, error: Error) -> None: self.error = error @attr.s(auto_attribs=True, frozen=True) class State: file_info: FileInfo tokens: List[Token] = attr.ib() """The list of tokens that make up the file.""" @tokens.validator def check(self, attribute, value) -> None: assert len(self.tokens) > 0, "Expected at least one token (the EOF token)." assert ( self.tokens[-1].kind == TokenKind.EOF ), "Token stream must end with an EOF token." 
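The green/red split described in the module docstring can be exercised end-to-end with a short sketch (editorial example; the red tree is defined in pytch/redcst.py below):

from pytch.lexer import lex
from pytch.parser import parse
from pytch.redcst import SyntaxTree as RedSyntaxTree
from pytch.utils import FileInfo

file_info = FileInfo(file_path="<sketch>", source_code="let foo = 1\nfoo\n")
lexation = lex(file_info=file_info)
parsation = parse(file_info=file_info, tokens=lexation.tokens)
# The red tree lazily wraps the immutable green tree, adding parent
# pointers and absolute offsets.
red_tree = RedSyntaxTree(parent=None, origin=parsation.green_cst, offset=0)
assert red_tree.full_width == len(file_info.source_code)  # holds for a clean parse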
token_index: int """The index into the token list indicating where we currently are in the process of parsing.""" offset: int """The offset into the source file. Must be kept in sync with `token_index`.""" errors: List[Error] """A list of errors that have occurred during parsing so far.""" is_recovering: bool """Whether or not we are in the process of recovering from a parser error. While recovering, we'll consume tokens blindly until we find a token of a kind that we're expecting (a synchronization token), and resume parsing from there.""" error_tokens: List[Token] """A list of tokens that have been consumed during error recovery.""" sync_token_kinds: List[List[TokenKind]] """A stack of collections of tokens. Some callers will push a set of tokens into this stack. This set indicates tokens that can be synchronized to. If a function deeper in the stack encounters an error, then parsing will synchronize to the next token that appears somewhere in this stack, and unwind to the its caller.""" # assert token_index < len(tokens) @property def end_of_file_offset_range(self) -> OffsetRange: last_offset = len(self.file_info.source_code) last_non_empty_token = next( (token for token in reversed(self.tokens) if token.full_width > 0), None ) if last_non_empty_token is None: start = 0 end = 0 else: first_trailing_newline_index = 0 for trivium in last_non_empty_token.trailing_trivia: if trivium.kind == TriviumKind.NEWLINE: break first_trailing_newline_index += 1 trailing_trivia_up_to_newline = last_non_empty_token.trailing_trivia[ : first_trailing_newline_index + 1 ] trailing_trivia_up_to_newline_length = sum( trivium.width for trivium in trailing_trivia_up_to_newline ) start = last_offset - trailing_trivia_up_to_newline_length end = start return OffsetRange(start=start, end=end) def get_current_token(self) -> Token: assert 0 <= self.token_index < len(self.tokens) token = self.tokens[self.token_index] error_trivia = [ Trivium(kind=TriviumKind.ERROR, text=error_token.full_text) for error_token in self.error_tokens ] return token.update(leading_trivia=[*error_trivia, *token.leading_trivia]) @property def current_token_offset_range(self) -> OffsetRange: current_token = self.tokens[self.token_index] # We usually don't want to point to a dummy token, so rewind until # we find a non-dummy token. token_index = self.token_index offset = self.offset did_rewind = False while token_index > 0 and current_token.is_dummy: did_rewind = True token_index -= 1 current_token = self.tokens[token_index] offset -= current_token.full_width start = offset + current_token.leading_width end = start + current_token.width if did_rewind: # If we rewound, point to the location immediately after the # token we rewound to, rather than that token itself. 
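How this recovery machinery behaves from the outside, as a sketch (editorial example): parsing a binding with a missing right-hand side returns a Parsation whose errors describe the problem instead of raising.

from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo

file_info = FileInfo(file_path="<sketch>", source_code="let foo =\n")
parsation = parse(file_info=file_info, tokens=lex(file_info=file_info).tokens)
for error in parsation.errors:
    print(error.code.name, error.message)
# Expected to include EXPECTED_EXPRESSION; the parser synchronizes to a
# token from `sync_token_kinds` (here, the inferred DUMMY_IN) and keeps going.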
start = end return OffsetRange(start=start, end=end) @property def current_token_kind(self) -> TokenKind: return self.tokens[self.token_index].kind @property def current_token_range(self) -> Range: return self.file_info.get_range_from_offset_range( self.current_token_offset_range ) @property def next_token(self) -> Token: assert ( self.tokens[self.token_index].kind != TokenKind.EOF ), "Tried to look at the token after the EOF token" return self.tokens[self.token_index + 1] def update(self, **kwargs) -> "State": return attr.evolve(self, **kwargs) def add_error(self, error: Error) -> "State": return self.update(errors=self.errors + [error]) def assert_(self, condition: bool, code: ErrorCode, message: str) -> "State": if not condition: return self.add_error( Error( file_info=self.file_info, code=code, severity=Severity.WARNING, message=f"Assertion failure -- please report this! {message}", notes=[], ) ) return self def start_recovery(self) -> "State": assert ( not self.is_recovering ), "Tried to start parser error recovery while already recovering" return self.update(is_recovering=True) def finish_recovery(self) -> "State": assert ( self.is_recovering ), "Tried to finish parser error recovery while not recovering" return self.update(is_recovering=False) def push_sync_token_kinds(self, token_kinds: List[TokenKind]) -> "State": return self.update(sync_token_kinds=self.sync_token_kinds + [token_kinds]) def pop_sync_token_kinds(self) -> "State": assert self.sync_token_kinds return self.update(sync_token_kinds=self.sync_token_kinds[:-1]) def consume_token(self, token: Token) -> "State": assert ( self.get_current_token().kind != TokenKind.EOF ), "Tried to consume the EOF token." # We may have added leading error tokens as trivia, but we don't want # to double-count their width, since they've already been consumed. full_width_without_errors = ( token.width + sum( trivium.width for trivium in token.leading_trivia if trivium.kind != TriviumKind.ERROR ) + sum( trivium.width for trivium in token.trailing_trivia if trivium.kind != TriviumKind.ERROR ) ) return self.update( token_index=self.token_index + 1, offset=self.offset + full_width_without_errors, error_tokens=[], ) def consume_error_token(self, token: Token) -> "State": # Make sure not to use `self.current_token`, since that would duplicate # the error tokens. assert 0 <= self.token_index < len(self.tokens) token = self.tokens[self.token_index] assert ( token.kind != TokenKind.EOF ), "Tried to consume the EOF token as an error token." return self.update( token_index=self.token_index + 1, offset=self.offset + token.full_width, error_tokens=self.error_tokens + [token], ) class UnhandledParserException(Exception): def __init__(self, state: State) -> None: self._state = state def __str__(self) -> str: file_contents = "" for i, token in enumerate(self._state.tokens): if i == self._state.token_index: file_contents += "" file_contents += token.full_text error_messages = "\n".join( f"{error.code.name}[{error.code.value}]: {error.message}" for error in self._state.errors ) return f"""All tokens: {self._state.tokens} Parser location: {file_contents} There are {len(self._state.tokens)} tokens total, and we are currently at token #{self._state.token_index}, which is: {self._state.get_current_token()}. 
Errors so far: {error_messages or ""} Original exception: {self.__cause__.__class__.__name__}: {self.__cause__} """ class Parser: def parse(self, file_info: FileInfo, tokens: List[Token]) -> Parsation: state = State( file_info=file_info, tokens=tokens, token_index=0, offset=0, errors=[], is_recovering=False, error_tokens=[], sync_token_kinds=[[TokenKind.EOF]], ) # File with only whitespace. if state.get_current_token().kind == TokenKind.EOF: syntax_tree = SyntaxTree(n_expr=None, t_eof=state.get_current_token()) return Parsation(green_cst=syntax_tree, errors=state.errors) try: (state, n_expr) = self.parse_expr(state, allow_naked_lets=True) (state, t_eof) = self.expect_token(state, [TokenKind.EOF]) syntax_tree = SyntaxTree(n_expr=n_expr, t_eof=t_eof) source_code_length = len(file_info.source_code) tokens_length = sum(token.full_width for token in walk_tokens(syntax_tree)) state = state.assert_( source_code_length == tokens_length, code=ErrorCode.PARSED_LENGTH_MISMATCH, message=( f"Mismatch between source code length " + f"({source_code_length}) " + f"and total length of parsed tokens " + f"({tokens_length}). " + f"The parse tree for this file is probably incorrect." ), ) return Parsation(green_cst=syntax_tree, errors=state.errors) except UnhandledParserException: raise except Exception as e: raise UnhandledParserException(state) from e def parse_let_expr( self, state: State, allow_naked_lets=False ) -> Tuple[State, Optional[LetExpr]]: t_let_range = state.current_token_range (state, t_let) = self.expect_token(state, [TokenKind.LET]) if not t_let: return (state, None) state = state.push_sync_token_kinds([TokenKind.DUMMY_IN]) let_note = Note( file_info=state.file_info, message="This is the beginning of the let-binding.", range=t_let_range, ) notes = [let_note] (state, n_let_expr) = self.parse_let_expr_binding( state, allow_naked_lets=allow_naked_lets, t_let=t_let, notes=notes ) (state, t_in) = self.expect_token(state, [TokenKind.DUMMY_IN], notes=notes) if allow_naked_lets and state.get_current_token().kind == TokenKind.EOF: n_body = None else: (state, n_body) = self.parse_expr(state, allow_naked_lets=allow_naked_lets) n_pattern = n_let_expr.n_pattern if n_let_expr is not None else None n_parameter_list = ( n_let_expr.n_parameter_list if n_let_expr is not None else None ) t_equals = n_let_expr.t_equals if n_let_expr is not None else None n_value = n_let_expr.n_value if n_let_expr is not None else None state = state.pop_sync_token_kinds() return ( state, LetExpr( t_let=t_let, n_pattern=n_pattern, n_parameter_list=n_parameter_list, t_equals=t_equals, n_value=n_value, t_in=t_in, n_body=n_body, ), ) def parse_let_expr_binding( self, state: State, allow_naked_lets: bool, t_let: Token, notes: List[Note] ) -> Tuple[State, Optional[LetExpr]]: n_pattern: Optional[Pattern] n_parameter_list: Optional[ParameterList] = None if state.get_current_token().kind == TokenKind.EQUALS: # If the token is an equals sign, assume that the name is missing # (e.g. during editing, the user is renaming the variable), but # that the rest of the let-binding is present. n_pattern = None state = state.add_error( Error( file_info=state.file_info, code=ErrorCode.EXPECTED_PATTERN, severity=Severity.ERROR, message="I was expecting a pattern after 'let'.", notes=notes, range=state.current_token_range, ) ) elif ( state.get_current_token().kind == TokenKind.IDENTIFIER and state.next_token.kind == TokenKind.LPAREN ): # Assume it's a function definition. 
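What this branch produces, sketched (editorial example): a function-style binding gets a VariablePattern for the function name plus a ParameterList, which the binder later uses to scope the parameters over the function's value.

from pytch.greencst import LetExpr
from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo

file_info = FileInfo(
    file_path="<sketch>", source_code="let add(x, y) = x + y\nadd(1, 2)\n"
)
parsation = parse(file_info=file_info, tokens=lex(file_info=file_info).tokens)
n_let_expr = parsation.green_cst.n_expr
assert isinstance(n_let_expr, LetExpr)
assert n_let_expr.n_pattern is not None         # VariablePattern for 'add'
assert n_let_expr.n_parameter_list is not None  # holds the 'x' and 'y' parameters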
(state, t_identifier) = self.expect_token(state, [TokenKind.IDENTIFIER]) n_pattern = VariablePattern(t_identifier=t_identifier) (state, n_parameter_list) = self.parse_parameter_list(state) else: (state, n_pattern) = self.parse_pattern( state, error=Error( file_info=state.file_info, code=ErrorCode.EXPECTED_PATTERN, severity=Severity.ERROR, message="I was expecting a pattern after 'let'.", notes=notes, range=state.current_token_range, ), ) (state, t_equals) = self.expect_token(state, [TokenKind.EQUALS], notes=notes) (state, n_value) = self.parse_expr(state, allow_naked_lets=False) return ( state, LetExpr( t_let=t_let, n_pattern=n_pattern, n_parameter_list=n_parameter_list, t_equals=t_equals, n_value=n_value, t_in=None, # Parsed by caller. n_body=None, # Parsed by caller. ), ) def parse_if_expr(self, state: State) -> Tuple[State, Optional[IfExpr]]: (state, t_if) = self.expect_token(state, [TokenKind.IF]) if not t_if: return (state, None) state = state.push_sync_token_kinds([TokenKind.DUMMY_ENDIF]) (state, n_if_expr) = self.parse_expr(state) (state, t_then) = self.expect_token(state, [TokenKind.THEN]) (state, n_then_expr) = self.parse_expr(state) if state.current_token_kind == TokenKind.ELSE: (state, t_else) = self.expect_token(state, [TokenKind.ELSE]) (state, n_else_expr) = self.parse_expr(state) else: t_else = None n_else_expr = None (state, t_endif) = self.expect_token(state, [TokenKind.DUMMY_ENDIF]) state = state.pop_sync_token_kinds() return ( state, IfExpr( t_if=t_if, n_if_expr=n_if_expr, t_then=t_then, n_then_expr=n_then_expr, t_else=t_else, n_else_expr=n_else_expr, t_endif=t_endif, ), ) def parse_pattern( self, state: State, error: Error = None ) -> Tuple[State, Optional[Pattern]]: (state, t_identifier) = self.expect_token( state, [TokenKind.IDENTIFIER], error=error ) if t_identifier: return (state, VariablePattern(t_identifier=t_identifier)) else: return (state, None) def parse_expr( self, state: State, min_precedence: int = 0, # Set when we allow let-bindings without associated expressions. For # example, this at the top-level: # # # Non-naked let; has the expression `let bar = 2` # let foo = # # Non-naked let; has the expression `bar` # let bar = 2 # bar # # # Naked let: no expression for this let-binding. # let bar = 2 allow_naked_lets: bool = False, ) -> Tuple[State, Optional[Expr]]: """Parse an expression, even if that parse involves left-recursion. This parses the expression using precedence-climbing to account for operator precedence and associativity. 
See this excellent article: https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing """ (state, n_expr) = self.parse_non_binary_expr( state, allow_naked_lets=allow_naked_lets ) if n_expr is None: return (state, None) while state.current_token_kind in BINARY_OPERATORS: (precedence, associativity) = BINARY_OPERATORS[state.current_token_kind] if precedence < min_precedence: break if associativity is Associativity.LEFT: next_min_precedence = precedence + 1 elif associativity is Associativity.RIGHT: next_min_precedence = precedence else: assert False, "Invalid associativity" (state, t_operator) = self.expect_token(state, BINARY_OPERATOR_KINDS) assert ( t_operator is not None ), "Should have been checked by the while-loop condition" (state, n_rhs) = self.parse_expr( state, min_precedence=next_min_precedence, allow_naked_lets=allow_naked_lets, ) n_expr = BinaryExpr(n_lhs=n_expr, t_operator=t_operator, n_rhs=n_rhs) return (state, n_expr) def parse_non_binary_expr( self, state: State, allow_naked_lets: bool ) -> Tuple[State, Optional[Expr]]: (state, n_expr) = self.parse_atom(state, allow_naked_lets=allow_naked_lets) while n_expr is not None: token = state.get_current_token() if token.kind == TokenKind.EOF: break elif token.kind == TokenKind.LPAREN: (state, n_expr) = self.parse_function_call( state, current_token=token, n_callee=n_expr ) else: break return (state, n_expr) def skip_past(self, state: State, kind: TokenKind) -> State: while state.current_token_kind != kind: state = state.consume_error_token(state.get_current_token()) state = state.consume_error_token(state.get_current_token()) return state def add_error_and_recover(self, state: State, error: Error) -> State: if state.is_recovering: return state state = state.start_recovery() sync_token_kinds = set( token_kind for sync_token_kinds in state.sync_token_kinds for token_kind in sync_token_kinds ) state = state.add_error(error) while state.current_token_kind != TokenKind.EOF: current_token = state.get_current_token() if current_token.kind == TokenKind.LET: # 'let' is *always* paired with a dummy 'in', thanks to the # pre-parser, so make sure to synchronize past that 'in'. # Otherwise we end up with too many 'in's for our 'let's state = self.skip_past(state, TokenKind.DUMMY_IN) continue if current_token.kind in sync_token_kinds: return state state = state.consume_error_token(state.get_current_token()) return state def parse_atom( self, state: State, allow_naked_lets: bool = False ) -> Tuple[State, Optional[Expr]]: token = state.get_current_token() if token.kind == TokenKind.IDENTIFIER: return self.parse_identifier_expr(state) elif token.kind == TokenKind.INT_LITERAL: return self.parse_int_literal(state) elif token.kind == TokenKind.LET: return self.parse_let_expr(state, allow_naked_lets=allow_naked_lets) elif token.kind == TokenKind.IF: return self.parse_if_expr(state) else: state = self.add_error_and_recover( state, Error( file_info=state.file_info, severity=Severity.ERROR, code=ErrorCode.EXPECTED_EXPRESSION, message=( "I was expecting an expression, but instead got " + self.describe_token(state.get_current_token()) + "." 
), range=state.current_token_range, notes=[], ), ) return (state, None) raise UnhandledParserException(state) from ValueError( f"tried to parse expression of unsupported token kind {token.kind}" ) def parse_function_call( self, state: State, current_token: Token, n_callee: Expr ) -> Tuple[State, Optional[FunctionCallExpr]]: (state, n_argument_list) = self.parse_argument_list(state) return ( state, FunctionCallExpr(n_callee=n_callee, n_argument_list=n_argument_list), ) def parse_argument_list(self, state: State) -> Tuple[State, Optional[ArgumentList]]: t_lparen_range = state.current_token_range (state, t_lparen) = self.expect_token(state, [TokenKind.LPAREN]) if t_lparen is None: state = self.add_error_and_recover( state, Error( file_info=state.file_info, code=ErrorCode.EXPECTED_LPAREN, severity=Severity.ERROR, message=( "I was expecting a '(' to indicate the start of a " + "function argument list, but instead got " + self.describe_token(state.get_current_token()) + "." ), notes=[], range=state.current_token_range, ), ) return (state, None) state = state.push_sync_token_kinds([TokenKind.RPAREN]) arguments: List[Argument] = [] while state.current_token_kind not in [TokenKind.RPAREN, TokenKind.EOF]: (state, n_argument) = self.parse_argument(state) if n_argument is None: break arguments.append(n_argument) if n_argument.t_comma is None: break state = state.pop_sync_token_kinds() (state, t_rparen) = self.expect_token( state, [TokenKind.RPAREN], error=Error( file_info=state.file_info, code=ErrorCode.EXPECTED_RPAREN, severity=Severity.ERROR, message=( "I was expecting a ')' to indicate the end of this " + "function argument list, but instead got " + self.describe_token(state.get_current_token()) + "." ), notes=[ Note( file_info=state.file_info, message="The beginning of the argument list is here.", range=t_lparen_range, ) ], range=state.current_token_range, ), ) return ( state, ArgumentList(t_lparen=t_lparen, arguments=arguments, t_rparen=t_rparen), ) def parse_argument(self, state: State) -> Tuple[State, Optional[Argument]]: argument_start_offset = state.offset (state, n_expr) = self.parse_expr(state) if n_expr is None: return (state, None) token = state.get_current_token() if token.kind == TokenKind.RPAREN: return (state, Argument(n_expr=n_expr, t_comma=None)) if token.kind == TokenKind.COMMA: (state, t_comma) = self.expect_token(state, [TokenKind.COMMA]) return (state, Argument(n_expr=n_expr, t_comma=t_comma)) argument_end_offset = ( argument_start_offset + n_expr.leading_width + n_expr.width ) # The end offset is exclusive, so when the position is used as the # start offset, it's one character after the argument (where you # would expect the comma to go). 
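A sketch of the diagnostic this produces (editorial example): omitting the comma between arguments reports EXPECTED_END_OF_ARGUMENT_LIST at the zero-width position just past the previous argument, exactly where the comma should go.

from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo

file_info = FileInfo(file_path="<sketch>", source_code="print(1 2)\n")
parsation = parse(file_info=file_info, tokens=lex(file_info=file_info).tokens)
print([error.code.name for error in parsation.errors])
# Expected to include 'EXPECTED_END_OF_ARGUMENT_LIST'; recovery then
# synchronizes to the ')' and parsing continues.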
argument_position = state.file_info.get_position_for_offset(argument_end_offset) expected_comma_range = Range(start=argument_position, end=argument_position) error = Error( file_info=state.file_info, code=ErrorCode.EXPECTED_END_OF_ARGUMENT_LIST, severity=Severity.ERROR, message=("I was expecting a ',' or ')' after the previous argument."), notes=[], range=expected_comma_range, ) (state, t_comma) = self.expect_token(state, [TokenKind.COMMA], error=error) return (state, Argument(n_expr=n_expr, t_comma=t_comma)) def parse_parameter_list( self, state: State ) -> Tuple[State, Optional[ParameterList]]: t_lparen_range = state.current_token_range (state, t_lparen) = self.expect_token(state, [TokenKind.LPAREN]) if t_lparen is None: state = self.add_error_and_recover( state, Error( file_info=state.file_info, code=ErrorCode.EXPECTED_LPAREN, severity=Severity.ERROR, message=( "I was expecting a '(' to indicate the start of a " + "function parameter list, but instead got " + self.describe_token(state.get_current_token()) + "." ), notes=[], range=state.current_token_range, ), ) return (state, None) parameters: List[Parameter] = [] while state.current_token_kind not in [TokenKind.RPAREN, TokenKind.EOF]: (state, n_parameter) = self.parse_parameter(state) if n_parameter is None: break parameters.append(n_parameter) if n_parameter.t_comma is None: break (state, t_rparen) = self.expect_token( state, [TokenKind.RPAREN], error=Error( file_info=state.file_info, code=ErrorCode.EXPECTED_RPAREN, severity=Severity.ERROR, message=( "I was expecting a ')' to indicate the end of this " + "function parameter list, but instead got " + self.describe_token(state.get_current_token()) + "." ), notes=[ Note( file_info=state.file_info, message="The beginning of the parameter list is here.", range=t_lparen_range, ) ], range=state.current_token_range, ), ) return ( state, ParameterList(t_lparen=t_lparen, parameters=parameters, t_rparen=t_rparen), ) def parse_parameter(self, state: State) -> Tuple[State, Optional[Parameter]]: parameter_start_offset = state.offset (state, n_pattern) = self.parse_pattern(state) if n_pattern is None: return (state, None) token = state.get_current_token() if token.kind == TokenKind.RPAREN: return (state, Parameter(n_pattern=n_pattern, t_comma=None)) if token.kind == TokenKind.COMMA: (state, t_comma) = self.expect_token(state, [TokenKind.COMMA]) return (state, Parameter(n_pattern=n_pattern, t_comma=t_comma)) parameter_end_offset = ( parameter_start_offset + n_pattern.leading_width + n_pattern.width ) # The end offset is exclusive, so when the position is used as the # start offset, it's one character after the argument (where you # would expect the comma to go). 
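One consequence of the parameter-list loop above, sketched (editorial example): a trailing comma before the ')' parses cleanly, because the loop re-checks for RPAREN after each comma-terminated parameter.

from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo

file_info = FileInfo(file_path="<sketch>", source_code="let f(x,) = x\nf(1)\n")
parsation = parse(file_info=file_info, tokens=lex(file_info=file_info).tokens)
assert parsation.errors == []  # the trailing comma is accepted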
parameter_position = state.file_info.get_position_for_offset( parameter_end_offset ) expected_comma_range = Range(start=parameter_position, end=parameter_position) error = Error( file_info=state.file_info, code=ErrorCode.EXPECTED_END_OF_PARAMETER_LIST, severity=Severity.ERROR, message=("I was expecting a ',' or ')' after the previous parameter."), notes=[], range=expected_comma_range, ) (state, t_comma) = self.expect_token(state, [TokenKind.COMMA], error=error) return (state, Parameter(n_pattern=n_pattern, t_comma=t_comma)) def parse_identifier_expr( self, state: State ) -> Tuple[State, Optional[IdentifierExpr]]: (state, t_identifier) = self.expect_token(state, [TokenKind.IDENTIFIER]) if t_identifier is None: return (state, None) return (state, IdentifierExpr(t_identifier=t_identifier)) def parse_int_literal(self, state: State) -> Tuple[State, Optional[IntLiteralExpr]]: (state, t_int_literal) = self.expect_token(state, [TokenKind.INT_LITERAL]) if t_int_literal is None: return (state, None) return (state, IntLiteralExpr(t_int_literal=t_int_literal)) def expect_token( self, state: State, possible_token_kinds: List[TokenKind], *, notes: List[Note] = None, error: Error = None, ) -> Tuple[State, Optional[Token]]: token = state.get_current_token() if token.kind in possible_token_kinds: if state.is_recovering: state = state.finish_recovery() if token.kind != TokenKind.EOF: state = state.consume_token(token) return (state, token) if state.is_recovering: return (state, None) assert len(possible_token_kinds) > 0 if len(possible_token_kinds) == 1: possible_tokens_str = self.describe_token_kind(possible_token_kinds[0]) elif len(possible_token_kinds) == 2: possible_tokens_str = " or ".join( [ self.describe_token_kind(possible_token_kinds[0]), possible_token_kinds[1].value, ] ) else: possible_tokens_str = ", ".join( token.value for token in possible_token_kinds[:-1] ) possible_tokens_str += ", or " + possible_token_kinds[-1].value if error is None: message = ( f"I was expecting {possible_tokens_str}, " + f"but instead got {self.describe_token(token)}." ) error = Error( file_info=state.file_info, code=ErrorCode.UNEXPECTED_TOKEN, severity=Severity.ERROR, message=message, notes=[], range=state.current_token_range, ) state = self.add_error_and_recover(state, error) token = state.get_current_token() if token.kind in possible_token_kinds: # We recovered to a token that the caller happens to be able to # handle, so return it directly. 
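A sketch of the messages this generates (editorial example): a missing 'then' yields an UNEXPECTED_TOKEN error whose wording comes from describe_token_kind below, including the 'a'/'an' article choice.

from pytch.lexer import lex
from pytch.parser import parse
from pytch.utils import FileInfo

file_info = FileInfo(file_path="<sketch>", source_code="if 1 2\n")
parsation = parse(file_info=file_info, tokens=lex(file_info=file_info).tokens)
print([error.message for error in parsation.errors])
# Expected: ["I was expecting a 'then', but instead got an integer literal."]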
if token.kind != TokenKind.EOF: state = state.consume_token(token) state = state.finish_recovery() return (state, token) return (state, None) def describe_token(self, token: Token) -> str: if token.kind == TokenKind.ERROR: return f"the invalid token '{token.text}'" return self.describe_token_kind(token.kind) def describe_token_kind(self, token_kind: TokenKind) -> str: if token_kind.value.startswith("the "): return token_kind.value vowels = ["a", "e", "i", "o", "u"] if any(token_kind.value.strip("'").startswith(vowel) for vowel in vowels): return f"an {token_kind.value}" else: return f"a {token_kind.value}" def parse(file_info: FileInfo, tokens: List[Token]) -> Parsation: parser = Parser() return parser.parse(file_info=file_info, tokens=tokens) def dump_syntax_tree( source_code: str, ast_node: Union[Node, Token, None], offset: int = 0 ) -> Tuple[int, List[str]]: if ast_node is None: return (offset, [""]) elif isinstance(ast_node, Token): token = ast_node lines = [] for trivium in token.leading_trivia: offset += trivium.width lines.append(f"Leading {trivium.text!r}") offset += token.width if token.is_dummy: lines.append(f"Token {token.kind.name} {token.text!r}") else: lines.append(f"Token {token.text!r}") for trivium in token.trailing_trivia: offset += trivium.width lines.append(f"Trailing {trivium.text!r}") return (offset, lines) else: lines = [f"{ast_node.__class__.__name__}"] for child in ast_node.children: (offset, rendered_child) = dump_syntax_tree(source_code, child, offset) lines.extend(f" {subline}" for subline in rendered_child) return (offset, lines) PK!Y6O\O\pytch/redcst.py"""NOTE: This file auto-generated from ast.txt. Run `bin/generate_syntax_trees.sh` to re-generate. Do not edit! """ from typing import List, Optional, Sequence, Union import pytch.greencst as greencst from .lexer import Token from .utils import OffsetRange class Node: def __init__(self, parent: Optional["Node"]) -> None: self._parent = parent @property def parent(self) -> Optional["Node"]: return self._parent @property def text(self) -> str: raise NotImplementedError( f"class {self.__class__.__name__} should implement `text`" ) @property def full_text(self) -> str: raise NotImplementedError( f"class {self.__class__.__name__} should implement `full_text`" ) @property def children(self) -> Sequence[Union["Node", Optional["Token"]]]: raise NotImplementedError( f"class {self.__class__.__name__} should implement `children`" ) @property def full_width(self) -> int: raise NotImplementedError( f"class {self.__class__.__name__} should implement `full_width`" ) @property def offset_range(self) -> OffsetRange: raise NotImplementedError( f"class {self.__class__.__name__} should implement `offset_range`" ) class Expr(Node): pass class SyntaxTree(Node): def __init__( self, parent: Optional[Node], origin: greencst.SyntaxTree, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_expr: Optional[Expr] = None @property def n_expr(self) -> Optional[Expr]: if self.origin.n_expr is None: return None if self._n_expr is not None: return self._n_expr offset = self.offset result = GREEN_TO_RED_NODE_MAP[self.origin.n_expr.__class__]( parent=self, origin=self.origin.n_expr, offset=offset ) self._n_expr = result return result @property def t_eof(self) -> Optional[Token]: return self.origin.t_eof @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property 
def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.n_expr, self.t_eof] class Pattern(Node): pass class VariablePattern(Pattern): def __init__( self, parent: Optional[Node], origin: greencst.VariablePattern, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset @property def t_identifier(self) -> Optional[Token]: return self.origin.t_identifier @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.t_identifier] class Parameter(Node): def __init__( self, parent: Optional[Node], origin: greencst.Parameter, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_pattern: Optional[Pattern] = None @property def n_pattern(self) -> Optional[Pattern]: if self.origin.n_pattern is None: return None if self._n_pattern is not None: return self._n_pattern offset = self.offset result = GREEN_TO_RED_NODE_MAP[self.origin.n_pattern.__class__]( parent=self, origin=self.origin.n_pattern, offset=offset ) self._n_pattern = result return result @property def t_comma(self) -> Optional[Token]: return self.origin.t_comma @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.n_pattern, self.t_comma] class ParameterList(Node): def __init__( self, parent: Optional[Node], origin: greencst.ParameterList, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._parameters: Optional[List[Parameter]] = None @property def t_lparen(self) -> Optional[Token]: return self.origin.t_lparen @property def parameters(self) -> Optional[List[Parameter]]: if self.origin.parameters is None: return None if self._parameters is not None: return self._parameters offset = self.offset + ( self.t_lparen.full_width if self.t_lparen is not None else 0 ) result = [] for child in self.origin.parameters: result.append(Parameter(parent=self, origin=child, offset=offset)) offset += child.full_width self._parameters = result return result @property def t_rparen(self) -> Optional[Token]: return self.origin.t_rparen @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [ self.t_lparen, *(self.parameters if self.parameters is not None else []), self.t_rparen, ] class LetExpr(Expr): def __init__( self, parent: Optional[Node], origin: greencst.LetExpr, 
offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_pattern: Optional[Pattern] = None self._n_parameter_list: Optional[ParameterList] = None self._n_value: Optional[Expr] = None self._n_body: Optional[Expr] = None @property def t_let(self) -> Optional[Token]: return self.origin.t_let @property def n_pattern(self) -> Optional[Pattern]: if self.origin.n_pattern is None: return None if self._n_pattern is not None: return self._n_pattern offset = self.offset + (self.t_let.full_width if self.t_let is not None else 0) result = GREEN_TO_RED_NODE_MAP[self.origin.n_pattern.__class__]( parent=self, origin=self.origin.n_pattern, offset=offset ) self._n_pattern = result return result @property def n_parameter_list(self) -> Optional[ParameterList]: if self.origin.n_parameter_list is None: return None if self._n_parameter_list is not None: return self._n_parameter_list offset = ( self.offset + (self.t_let.full_width if self.t_let is not None else 0) + (self.n_pattern.full_width if self.n_pattern is not None else 0) ) result = ParameterList( parent=self, origin=self.origin.n_parameter_list, offset=offset ) self._n_parameter_list = result return result @property def t_equals(self) -> Optional[Token]: return self.origin.t_equals @property def n_value(self) -> Optional[Expr]: if self.origin.n_value is None: return None if self._n_value is not None: return self._n_value offset = ( self.offset + (self.t_let.full_width if self.t_let is not None else 0) + (self.n_pattern.full_width if self.n_pattern is not None else 0) + ( self.n_parameter_list.full_width if self.n_parameter_list is not None else 0 ) + (self.t_equals.full_width if self.t_equals is not None else 0) ) result = GREEN_TO_RED_NODE_MAP[self.origin.n_value.__class__]( parent=self, origin=self.origin.n_value, offset=offset ) self._n_value = result return result @property def t_in(self) -> Optional[Token]: return self.origin.t_in @property def n_body(self) -> Optional[Expr]: if self.origin.n_body is None: return None if self._n_body is not None: return self._n_body offset = ( self.offset + (self.t_let.full_width if self.t_let is not None else 0) + (self.n_pattern.full_width if self.n_pattern is not None else 0) + ( self.n_parameter_list.full_width if self.n_parameter_list is not None else 0 ) + (self.t_equals.full_width if self.t_equals is not None else 0) + (self.n_value.full_width if self.n_value is not None else 0) + (self.t_in.full_width if self.t_in is not None else 0) ) result = GREEN_TO_RED_NODE_MAP[self.origin.n_body.__class__]( parent=self, origin=self.origin.n_body, offset=offset ) self._n_body = result return result @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [ self.t_let, self.n_pattern, self.n_parameter_list, self.t_equals, self.n_value, self.t_in, self.n_body, ] class IfExpr(Expr): def __init__( self, parent: Optional[Node], origin: greencst.IfExpr, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_if_expr: Optional[Expr] = None self._n_then_expr: Optional[Expr] = None self._n_else_expr: Optional[Expr] = None @property def t_if(self) -> Optional[Token]: return 
self.origin.t_if @property def n_if_expr(self) -> Optional[Expr]: if self.origin.n_if_expr is None: return None if self._n_if_expr is not None: return self._n_if_expr offset = self.offset + (self.t_if.full_width if self.t_if is not None else 0) result = GREEN_TO_RED_NODE_MAP[self.origin.n_if_expr.__class__]( parent=self, origin=self.origin.n_if_expr, offset=offset ) self._n_if_expr = result return result @property def t_then(self) -> Optional[Token]: return self.origin.t_then @property def n_then_expr(self) -> Optional[Expr]: if self.origin.n_then_expr is None: return None if self._n_then_expr is not None: return self._n_then_expr offset = ( self.offset + (self.t_if.full_width if self.t_if is not None else 0) + (self.n_if_expr.full_width if self.n_if_expr is not None else 0) + (self.t_then.full_width if self.t_then is not None else 0) ) result = GREEN_TO_RED_NODE_MAP[self.origin.n_then_expr.__class__]( parent=self, origin=self.origin.n_then_expr, offset=offset ) self._n_then_expr = result return result @property def t_else(self) -> Optional[Token]: return self.origin.t_else @property def n_else_expr(self) -> Optional[Expr]: if self.origin.n_else_expr is None: return None if self._n_else_expr is not None: return self._n_else_expr offset = ( self.offset + (self.t_if.full_width if self.t_if is not None else 0) + (self.n_if_expr.full_width if self.n_if_expr is not None else 0) + (self.t_then.full_width if self.t_then is not None else 0) + (self.n_then_expr.full_width if self.n_then_expr is not None else 0) + (self.t_else.full_width if self.t_else is not None else 0) ) result = GREEN_TO_RED_NODE_MAP[self.origin.n_else_expr.__class__]( parent=self, origin=self.origin.n_else_expr, offset=offset ) self._n_else_expr = result return result @property def t_endif(self) -> Optional[Token]: return self.origin.t_endif @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [ self.t_if, self.n_if_expr, self.t_then, self.n_then_expr, self.t_else, self.n_else_expr, self.t_endif, ] class IdentifierExpr(Expr): def __init__( self, parent: Optional[Node], origin: greencst.IdentifierExpr, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset @property def t_identifier(self) -> Optional[Token]: return self.origin.t_identifier @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.t_identifier] class IntLiteralExpr(Expr): def __init__( self, parent: Optional[Node], origin: greencst.IntLiteralExpr, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset @property def t_int_literal(self) -> Optional[Token]: return self.origin.t_int_literal @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return 
self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.t_int_literal] class BinaryExpr(Expr): def __init__( self, parent: Optional[Node], origin: greencst.BinaryExpr, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_lhs: Optional[Expr] = None self._n_rhs: Optional[Expr] = None @property def n_lhs(self) -> Optional[Expr]: if self.origin.n_lhs is None: return None if self._n_lhs is not None: return self._n_lhs offset = self.offset result = GREEN_TO_RED_NODE_MAP[self.origin.n_lhs.__class__]( parent=self, origin=self.origin.n_lhs, offset=offset ) self._n_lhs = result return result @property def t_operator(self) -> Optional[Token]: return self.origin.t_operator @property def n_rhs(self) -> Optional[Expr]: if self.origin.n_rhs is None: return None if self._n_rhs is not None: return self._n_rhs offset = ( self.offset + (self.n_lhs.full_width if self.n_lhs is not None else 0) + (self.t_operator.full_width if self.t_operator is not None else 0) ) result = GREEN_TO_RED_NODE_MAP[self.origin.n_rhs.__class__]( parent=self, origin=self.origin.n_rhs, offset=offset ) self._n_rhs = result return result @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.n_lhs, self.t_operator, self.n_rhs] class Argument(Node): def __init__( self, parent: Optional[Node], origin: greencst.Argument, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._n_expr: Optional[Expr] = None @property def n_expr(self) -> Optional[Expr]: if self.origin.n_expr is None: return None if self._n_expr is not None: return self._n_expr offset = self.offset result = GREEN_TO_RED_NODE_MAP[self.origin.n_expr.__class__]( parent=self, origin=self.origin.n_expr, offset=offset ) self._n_expr = result return result @property def t_comma(self) -> Optional[Token]: return self.origin.t_comma @property def text(self) -> str: return self.origin.text @property def full_text(self) -> str: return self.origin.full_text @property def full_width(self) -> int: return self.origin.full_width @property def offset_range(self) -> OffsetRange: start = self.offset + self.origin.leading_width return OffsetRange(start=start, end=start + self.origin.width) @property def children(self) -> List[Optional[Union[Token, Node]]]: return [self.n_expr, self.t_comma] class ArgumentList(Node): def __init__( self, parent: Optional[Node], origin: greencst.ArgumentList, offset: int ) -> None: super().__init__(parent) self.origin = origin self.offset = offset self._arguments: Optional[List[Argument]] = None @property def t_lparen(self) -> Optional[Token]: return self.origin.t_lparen @property def arguments(self) -> Optional[List[Argument]]: if self.origin.arguments is None: return None if self._arguments is not None: return self._arguments offset = self.offset + ( self.t_lparen.full_width if self.t_lparen is not None else 0 ) result = [] for child in self.origin.arguments: result.append(Argument(parent=self, 
class ArgumentList(Node):
    def __init__(
        self, parent: Optional[Node], origin: greencst.ArgumentList, offset: int
    ) -> None:
        super().__init__(parent)
        self.origin = origin
        self.offset = offset
        self._arguments: Optional[List[Argument]] = None

    @property
    def t_lparen(self) -> Optional[Token]:
        return self.origin.t_lparen

    @property
    def arguments(self) -> Optional[List[Argument]]:
        if self.origin.arguments is None:
            return None
        if self._arguments is not None:
            return self._arguments
        offset = self.offset + (
            self.t_lparen.full_width if self.t_lparen is not None else 0
        )
        result = []
        for child in self.origin.arguments:
            result.append(Argument(parent=self, origin=child, offset=offset))
            offset += child.full_width
        self._arguments = result
        return result

    @property
    def t_rparen(self) -> Optional[Token]:
        return self.origin.t_rparen

    @property
    def text(self) -> str:
        return self.origin.text

    @property
    def full_text(self) -> str:
        return self.origin.full_text

    @property
    def full_width(self) -> int:
        return self.origin.full_width

    @property
    def offset_range(self) -> OffsetRange:
        start = self.offset + self.origin.leading_width
        return OffsetRange(start=start, end=start + self.origin.width)

    @property
    def children(self) -> List[Optional[Union[Token, Node]]]:
        return [
            self.t_lparen,
            *(self.arguments if self.arguments is not None else []),
            self.t_rparen,
        ]


class FunctionCallExpr(Expr):
    def __init__(
        self, parent: Optional[Node], origin: greencst.FunctionCallExpr, offset: int
    ) -> None:
        super().__init__(parent)
        self.origin = origin
        self.offset = offset
        self._n_callee: Optional[Expr] = None
        self._n_argument_list: Optional[ArgumentList] = None

    @property
    def n_callee(self) -> Optional[Expr]:
        if self.origin.n_callee is None:
            return None
        if self._n_callee is not None:
            return self._n_callee
        offset = self.offset
        result = GREEN_TO_RED_NODE_MAP[self.origin.n_callee.__class__](
            parent=self, origin=self.origin.n_callee, offset=offset
        )
        self._n_callee = result
        return result

    @property
    def n_argument_list(self) -> Optional[ArgumentList]:
        if self.origin.n_argument_list is None:
            return None
        if self._n_argument_list is not None:
            return self._n_argument_list
        offset = self.offset + (
            self.n_callee.full_width if self.n_callee is not None else 0
        )
        result = ArgumentList(
            parent=self, origin=self.origin.n_argument_list, offset=offset
        )
        self._n_argument_list = result
        return result

    @property
    def text(self) -> str:
        return self.origin.text

    @property
    def full_text(self) -> str:
        return self.origin.full_text

    @property
    def full_width(self) -> int:
        return self.origin.full_width

    @property
    def offset_range(self) -> OffsetRange:
        start = self.offset + self.origin.leading_width
        return OffsetRange(start=start, end=start + self.origin.width)

    @property
    def children(self) -> List[Optional[Union[Token, Node]]]:
        return [self.n_callee, self.n_argument_list]


GREEN_TO_RED_NODE_MAP = {
    greencst.Expr: Expr,
    greencst.SyntaxTree: SyntaxTree,
    greencst.Pattern: Pattern,
    greencst.VariablePattern: VariablePattern,
    greencst.Parameter: Parameter,
    greencst.ParameterList: ParameterList,
    greencst.LetExpr: LetExpr,
    greencst.IfExpr: IfExpr,
    greencst.IdentifierExpr: IdentifierExpr,
    greencst.IntLiteralExpr: IntLiteralExpr,
    greencst.BinaryExpr: BinaryExpr,
    greencst.Argument: Argument,
    greencst.ArgumentList: ArgumentList,
    greencst.FunctionCallExpr: FunctionCallExpr,
}
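# Added commentary: GREEN_TO_RED_NODE_MAP is only consulted lazily inside the
# property getters above, so defining it at the end of the module is safe. A
# manual conversion sketch, assuming `green_node` is any green CST node:
#
#     red_node = GREEN_TO_RED_NODE_MAP[green_node.__class__](
#         parent=None, origin=green_node, offset=0
#     )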
PK! pytch/repl.py
from code import InteractiveConsole
import re
import readline
import sys
from typing import Any, Dict, List, Optional, Sequence, Tuple

from . import __version__
from .binder import bind, GLOBAL_SCOPE as BINDER_GLOBAL_SCOPE
from .codegen import codegen
from .errors import Error, get_error_lines, Severity
from .lexer import lex
from .parser import parse
from .redcst import SyntaxTree as RedSyntaxTree
from .typesystem import typecheck
from .typesystem.builtins import GLOBAL_SCOPE as TYPESYSTEM_GLOBAL_SCOPE
from .utils import FileInfo

NO_MORE_INPUT_REQUIRED = False
MORE_INPUT_REQUIRED = True
LEADING_WHITESPACE_RE = re.compile(r"^\s*")


class PytchRepl(InteractiveConsole):
    def __init__(self) -> None:
        super().__init__()
        self.buffer: List[str] = []
        self.locals: Dict[str, Any] = {}
        self.all_source_code = ""
        readline.set_completer(lambda text, state: text + "foo")

    def push(self, line: str) -> bool:
        readline.insert_text("foo")
        if line:
            self.buffer.append(line)
            match = LEADING_WHITESPACE_RE.match(line)
            if match is not None:
                readline.insert_text(match.group())
            return MORE_INPUT_REQUIRED
        source_code = "\n".join(self.buffer)
        self.resetbuffer()
        run_file(
            FileInfo(file_path="", source_code=self.all_source_code + source_code)
        )
        return NO_MORE_INPUT_REQUIRED


def interact() -> None:
    PytchRepl().interact(banner=f"Pytch version {__version__} REPL", exitmsg="")


def run_file(file_info: FileInfo) -> None:
    (compiled_output, errors) = compile_file(file_info=file_info)
    print_errors(errors)
    if compiled_output is not None:
        exec(compiled_output)


def compile_file(file_info: FileInfo) -> Tuple[Optional[str], List[Error]]:
    all_errors = []
    lexation = lex(file_info=file_info)
    all_errors.extend(lexation.errors)
    parsation = parse(file_info=file_info, tokens=lexation.tokens)
    all_errors.extend(parsation.errors)
    if has_fatal_error(all_errors):
        return (None, all_errors)
    syntax_tree = RedSyntaxTree(parent=None, origin=parsation.green_cst, offset=0)
    bindation = bind(
        file_info=file_info, syntax_tree=syntax_tree, global_scope=BINDER_GLOBAL_SCOPE
    )
    all_errors.extend(bindation.errors)
    if has_fatal_error(all_errors):
        return (None, all_errors)
    typeation = typecheck(
        file_info=file_info,
        syntax_tree=syntax_tree,
        bindation=bindation,
        global_scope=TYPESYSTEM_GLOBAL_SCOPE,
    )
    all_errors.extend(typeation.errors)
    if has_fatal_error(all_errors):
        return (None, all_errors)
    codegenation = codegen(
        syntax_tree=syntax_tree, bindation=bindation, typeation=typeation
    )
    all_errors.extend(codegenation.errors)
    if has_fatal_error(all_errors):
        return (None, all_errors)
    return (codegenation.get_compiled_output(), all_errors)


def has_fatal_error(errors: Sequence[Error]) -> bool:
    return any(error.severity == Severity.ERROR for error in errors)


def print_errors(errors: List[Error]) -> None:
    ascii = not sys.stderr.isatty()
    for error in errors:
        sys.stderr.write("\n".join(get_error_lines(error, ascii=ascii)) + "\n")
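# Example usage (a sketch; the file name and source text below are
# hypothetical):
#
#     (output, errors) = compile_file(
#         file_info=FileInfo(file_path="example.pytch", source_code="print(1)")
#     )
#     print_errors(errors)
#     if output is not None:
#         exec(output)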
PK! pytch/syntax_tree.txt
Expr(Node)

SyntaxTree(Node)
    n_expr: Optional[Expr]
    t_eof: Optional[Token]

Pattern(Node)

VariablePattern(Pattern)
    t_identifier: Optional[Token]

Parameter(Node)
    n_pattern: Optional[Pattern]
    t_comma: Optional[Token]

ParameterList(Node)
    t_lparen: Optional[Token]
    parameters: Optional[List[Parameter]]
    t_rparen: Optional[Token]

LetExpr(Expr)
    t_let: Optional[Token]
    n_pattern: Optional[Pattern]
    # Only present if the `let`-expr is a function declaration. If present,
    # `n_pattern` is expected to be a `VariablePattern`.
    n_parameter_list: Optional[ParameterList]
    t_equals: Optional[Token]
    n_value: Optional[Expr]
    t_in: Optional[Token]
    n_body: Optional[Expr]

IfExpr(Expr)
    t_if: Optional[Token]
    n_if_expr: Optional[Expr]
    t_then: Optional[Token]
    n_then_expr: Optional[Expr]
    # The `else` case is optional.
    t_else: Optional[Token]
    n_else_expr: Optional[Expr]
    t_endif: Optional[Token]

IdentifierExpr(Expr)
    t_identifier: Optional[Token]

IntLiteralExpr(Expr)
    t_int_literal: Optional[Token]

BinaryExpr(Expr)
    n_lhs: Optional[Expr]
    t_operator: Optional[Token]
    n_rhs: Optional[Expr]

Argument(Node)
    n_expr: Optional[Expr]
    t_comma: Optional[Token]

ArgumentList(Node)
    t_lparen: Optional[Token]
    arguments: Optional[List[Argument]]
    t_rparen: Optional[Token]

FunctionCallExpr(Expr)
    n_callee: Optional[Expr]
    n_argument_list: Optional[ArgumentList]

PK! pytch/typesystem/__init__.py
"""Type inference and typechecking.

The Pytch type system is a bidirectional typechecking system, based off of the
system described in [Dunfield 2013] (see Figures 9-11 for the algorithmic
typing rules for the system). A standard Hindley-Milner type system would be
difficult to reconcile in the presence of subtyping, which will naturally
occur when interfacing with Python code.

[Dunfield 2013]: https://www.cl.cam.ac.uk/~nk480/bidir.pdf

Terminology used in this module:

  * `ty`: "type".
  * `ctx`: "context".
  * `env`: "environment". This only refers to the usual sort of global
    configuration that's passed around, rather than a typing environment (Γ),
    which is called a "context" instead.
  * `var`: "variable", specifically a type variable of some sort.
  * The spelling "judgment" is preferred over "judgement".
"""
from .typecheck import Typeation, typecheck

__all__ = ["Typeation", "typecheck"]

PK! pytch/typesystem/builtins.py
from pytch.containers import PMap, PVector

from .reason import BuiltinReason
from .types import BaseTy, FunctionTy, Ty, TyVar, UniversalTy

ERR_TY = BaseTy(name="", reason=BuiltinReason(name=""))
"""Error type.

Produced when there is a typechecking error, in order to prevent cascading
failure messages.
"""

NONE_TY = BaseTy(name="None", reason=BuiltinReason(name="None"))
"""None type, corresponding to Python's `None` value."""

OBJECT_TY = BaseTy(name="object", reason=BuiltinReason(name="object"))
"""Object type. This is the top type since everything is an object."""

VOID_TY = BaseTy(name="", reason=BuiltinReason(name=""))
"""Void type. Denotes the lack of a value.

The Python runtime has no concept of "void": functions which don't `return`
anything implicitly return `None`. However, there are some cases where it
would be dangerous to allow the user to return `None` implicitly. For example,
implicitly assigning `None` to `foo` here was probably not intended:

```
let foo = if cond() then "some value"
```
"""

INT_TY = BaseTy(name="int", reason=BuiltinReason(name="int"))
"""Integer type, corresponding to Python's `int` type."""

top_ty_reason = BuiltinReason(name="")
top_ty_var = TyVar(name="top_ty", reason=top_ty_reason)
TOP_TY = UniversalTy(quantifier_ty=top_ty_var, ty=top_ty_var, reason=top_ty_reason)
"""The top type.

All types are a subtype of this type, including void. You likely want to use
the `object` type instead.
"""
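# For example, a binary builtin could be typed along the same lines as `print`
# below (a sketch; `plus` is hypothetical and is not registered in
# GLOBAL_SCOPE):
#
#     def _make_plus() -> FunctionTy:
#         reason = BuiltinReason(name="plus")
#         return FunctionTy(
#             domain=PVector([INT_TY, INT_TY]), codomain=INT_TY, reason=reason
#         )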
def _make_print() -> FunctionTy:
    print_reason = BuiltinReason(name="print")
    # TODO: this may have to be some kind of `ArgumentTy` instead, so that it
    # can have its own reason, and so that it can take on a label.
    domain: PVector[Ty] = PVector([OBJECT_TY])
    codomain = NONE_TY
    return FunctionTy(domain=domain, codomain=codomain, reason=print_reason)


# TODO: add the Python builtins to the global scope.
GLOBAL_SCOPE: PMap[str, Ty] = PMap({"None": NONE_TY, "print": _make_print()})

PK! pytch/typesystem/judgments.py
import attr

from pytch.redcst import Pattern

from .types import ExistentialTyVar, Ty, TyVar


class TypingJudgment:
    """Abstract base class for typing judgments.

    See Dunfield 2013 Figure 6 for the list of possible judgments.
    """

    pass


@attr.s(auto_attribs=True, frozen=True)
class DeclareVarJudgment(TypingJudgment):
    variable: TyVar


@attr.s(auto_attribs=True, frozen=True)
class PatternHasTyJudgment(TypingJudgment):
    pattern: Pattern
    ty: Ty


@attr.s(auto_attribs=True, frozen=True)
class DeclareExistentialVarJudgment(TypingJudgment):
    existential_ty_var: ExistentialTyVar


@attr.s(auto_attribs=True, frozen=True)
class ExistentialVariableHasTyJudgment(TypingJudgment):
    existential_ty_var: ExistentialTyVar
    ty: Ty


@attr.s(auto_attribs=True, frozen=True)
class ExistentialVariableMarkerJudgment(TypingJudgment):
    """Creates a new 'scope' in which to solve existential type variables."""

    existential_ty_var: ExistentialTyVar

PK! pytch/typesystem/reason.py
import attr

from .types import ExistentialTyVar, Ty


class Reason:
    def __str__(self) -> str:
        raise NotImplementedError(
            f"__str__ not implemented for reason class {self.__class__.__name__}"
        )


@attr.s(auto_attribs=True, frozen=True)
class TodoReason(Reason):
    todo: str

    def __str__(self) -> str:
        return f"reason not implemented for case {self.todo}"


@attr.s(auto_attribs=True, frozen=True)
class BuiltinReason(Reason):
    name: str

    def __str__(self) -> str:
        return f"{self.name} is a built-in"


@attr.s(auto_attribs=True, frozen=True)
class InvalidSyntaxReason(Reason):
    def __str__(self) -> str:
        return (
            "there was invalid syntax, so I assumed "
            + "that bit of code typechecked and proceeded "
            + "with checking the rest of the program"
        )


@attr.s(auto_attribs=True, frozen=True)
class EqualTysReason(Reason):
    lhs: Ty
    rhs: Ty

    def __str__(self) -> str:
        return "because the two types were equal"


@attr.s(auto_attribs=True, frozen=True)
class InstantiateExistentialReason(Reason):
    existential_ty_var: ExistentialTyVar
    to: Ty

    def __str__(self) -> str:
        return "because the type was determined to be the other type"


@attr.s(auto_attribs=True, frozen=True)
class SubtypeOfObjectReason(Reason):
    def __str__(self) -> str:
        return "all types are subtypes of object"


@attr.s(auto_attribs=True, frozen=True)
class SubtypeOfUnboundedGenericReason(Reason):
    def __str__(self) -> str:
        return "it was checked to be the subtype of a generic type parameter"


@attr.s(auto_attribs=True, frozen=True)
class NoneIsSubtypeOfVoidReason(Reason):
    def __str__(self) -> str:
        return "None is the only value that can be used where no value is expected"
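# Added commentary: a reason is attached to every `Ty` and rendered in
# diagnostics to explain why the typechecker made a decision. A sketch:
#
#     reason = BuiltinReason(name="print")
#     assert str(reason) == "print is a built-in"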
PK! pytch/typesystem/typecheck.py
from typing import List, Optional, Tuple

import attr

from pytch.binder import Bindation
from pytch.containers import find, PMap, PVector, take_while
from pytch.errors import Error, ErrorCode, Note, Severity
from pytch.lexer import TokenKind
from pytch.redcst import (
    Argument,
    BinaryExpr,
    Expr,
    FunctionCallExpr,
    IdentifierExpr,
    IfExpr,
    IntLiteralExpr,
    LetExpr,
    Node,
    Parameter,
    SyntaxTree,
    VariablePattern,
)
from pytch.utils import FileInfo, Range

from .builtins import ERR_TY, INT_TY, NONE_TY, OBJECT_TY, TOP_TY, VOID_TY
from .judgments import (
    DeclareExistentialVarJudgment,
    DeclareVarJudgment,
    ExistentialVariableHasTyJudgment,
    ExistentialVariableMarkerJudgment,
    PatternHasTyJudgment,
    TypingJudgment,
)
from .reason import (
    EqualTysReason,
    InstantiateExistentialReason,
    InvalidSyntaxReason,
    NoneIsSubtypeOfVoidReason,
    Reason,
    SubtypeOfObjectReason,
    SubtypeOfUnboundedGenericReason,
    TodoReason,
)
from .types import BaseTy, ExistentialTyVar, FunctionTy, MonoTy, Ty, TyVar, UniversalTy


@attr.s(auto_attribs=True, frozen=True)
class Env:
    file_info: FileInfo
    bindation: Bindation
    global_scope: PMap[str, Ty]
    errors: PVector[Error]

    def get_range_for_node(self, node: Node) -> Range:
        """Get the range corresponding to `node`.

        Note that for `let`-expressions, we don't want to flag the entire
        range. Instead, we only want to flag the innermost `let`-expression
        body.
        """
        while isinstance(node, LetExpr):
            n_body = node.n_body
            if n_body is None:
                break
            else:
                node = n_body
        return self.file_info.get_range_from_offset_range(node.offset_range)

    def add_error(
        self,
        code: ErrorCode,
        severity: Severity,
        message: str,
        notes: Optional[List[Note]] = None,
        range: Optional[Range] = None,
    ) -> "Env":
        error = Error(
            file_info=self.file_info,
            code=code,
            severity=severity,
            message=message,
            notes=notes if notes is not None else [],
            range=range,
        )
        return attr.evolve(self, errors=self.errors.append(error))


@attr.s(auto_attribs=True, frozen=True)
class TypingContext:
    judgments: PVector[TypingJudgment]
    inferred_tys: PMap[Expr, Ty]

    def add_judgment(self, judgment: TypingJudgment) -> "TypingContext":
        return attr.evolve(self, judgments=self.judgments.append(judgment))

    def ty_to_string(self, ty: Ty) -> str:
        if isinstance(ty, BaseTy):
            return ty.name
        else:
            raise NotImplementedError(f"ty_to_string not implemented for type: {ty!r}")

    def take_until_before_judgment(self, judgment: TypingJudgment) -> "TypingContext":
        judgments = PVector(take_while(self.judgments, lambda x: x != judgment))
        assert len(judgments) < len(
            self.judgments
        ), f"take_until_before_judgment: expected to find judgment {judgment!r} in context {self.judgments!r}"
        return attr.evolve(self, judgments=judgments)

    def apply_as_substitution(self, ty: Ty) -> Ty:
        """See Dunfield 2013 Figure 7."""
        # NOTE: `FunctionTy` and `TyVar` are subclasses of `MonoTy`, so the
        # blanket `MonoTy` case must come *after* the `FunctionTy` case, or
        # the function-type substitution below would be unreachable.
        if isinstance(ty, BaseTy):
            return ty
        elif isinstance(ty, ExistentialTyVar):

            def is_substitution_for_existential_ty_variable(
                judgment: TypingJudgment,
            ) -> bool:
                if isinstance(judgment, ExistentialVariableHasTyJudgment):
                    return judgment.existential_ty_var == ty
                else:
                    return False

            existential_ty_variable_substitution = find(
                self.judgments, is_substitution_for_existential_ty_variable
            )
            if existential_ty_variable_substitution is not None:
                assert isinstance(
                    existential_ty_variable_substitution,
                    ExistentialVariableHasTyJudgment,
                )
                return existential_ty_variable_substitution.ty
            else:
                return ty
        elif isinstance(ty, FunctionTy):
            domain = ty.domain.map(self.apply_as_substitution)
            codomain = self.apply_as_substitution(ty.codomain)
            return FunctionTy(domain=domain, codomain=codomain, reason=ty.reason)
        elif isinstance(ty, UniversalTy):
            return UniversalTy(
                quantifier_ty=ty.quantifier_ty,
                ty=self.apply_as_substitution(ty.ty),
                reason=ty.reason,
            )
        elif isinstance(ty, MonoTy):
            return ty
        else:
            assert (
                False
            ), f"Unhandled case for typing context substitution: {ty.__class__.__name__}"

    def instantiate_existential(
        self, existential_ty_var: ExistentialTyVar, to: Ty
    ) -> "TypingContext":
        def f(x: TypingJudgment) -> TypingJudgment:
            if isinstance(x, DeclareExistentialVarJudgment):
                if x.existential_ty_var == existential_ty_var:
                    return ExistentialVariableHasTyJudgment(
                        existential_ty_var=existential_ty_var, ty=to
                    )
            return x

        return attr.evolve(self, judgments=self.judgments.map(f))
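    # Added commentary: `instantiate_existential` rewrites the declaration
    # judgment for an existential variable into a solution judgment, after
    # which `apply_as_substitution` resolves that variable. A sketch, assuming
    # `a_hat` is an `ExistentialTyVar` already declared in this context:
    #
    #     ctx = ctx.instantiate_existential(existential_ty_var=a_hat, to=INT_TY)
    #     assert ctx.apply_as_substitution(a_hat) is INT_TY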
    def record_infers(self, expr: Expr, ty: Ty) -> "TypingContext":
        return attr.evolve(self, inferred_tys=self.inferred_tys.set(expr, ty))

    def get_infers(self, expr: Expr) -> Optional[Ty]:
        return self.inferred_tys.get(expr)

    def add_pattern_ty(self, pattern: VariablePattern, ty: Ty) -> "TypingContext":
        judgment = PatternHasTyJudgment(pattern=pattern, ty=ty)
        return attr.evolve(self, judgments=self.judgments.append(judgment))

    def get_pattern_ty(self, pattern: VariablePattern) -> Optional[Ty]:
        for judgment in self.judgments:
            if isinstance(judgment, PatternHasTyJudgment):
                if judgment.pattern is pattern:
                    return judgment.ty
        return None

    def push_existential_ty_var_marker(
        self, existential_ty_var: ExistentialTyVar
    ) -> "TypingContext":
        return attr.evolve(
            self,
            judgments=self.judgments.append(
                ExistentialVariableMarkerJudgment(existential_ty_var=existential_ty_var)
            ),
        )

    def pop_existential_ty_var_marker(
        self, existential_ty_var: ExistentialTyVar
    ) -> "TypingContext":
        raise NotImplementedError("pop existential ty var not implemented")


@attr.s(auto_attribs=True, frozen=True)
class Typeation:
    ctx: TypingContext
    errors: List[Error]


def tys_equal(lhs: Ty, rhs: Ty) -> bool:
    return lhs == rhs


def do_infer(env: Env, ctx: TypingContext, expr: Expr) -> Tuple[Env, TypingContext, Ty]:
    if isinstance(expr, IntLiteralExpr):
        return (env, ctx, INT_TY)
    elif isinstance(expr, LetExpr):
        raise ValueError("should not be trying to infer the type of a let-expr (?)")
    elif isinstance(expr, FunctionCallExpr):
        # Γ ⊢ e1 ⇒ A ⊣ Θ    Θ ⊢ [Θ]A • e2 ⇒⇒ C ⊣ ∆
        # ----------------------------------------- →E
        #            Γ ⊢ e1 e2 ⇒ C ⊣ ∆
        n_callee = expr.n_callee
        if n_callee is None:
            raise NotImplementedError("TODO(missing): handle missing callee")
        (env, ctx, callee_ty) = infer(env, ctx=ctx, expr=n_callee)
        n_argument_list = expr.n_argument_list
        if n_argument_list is None:
            raise NotImplementedError("TODO(missing): handle missing argument list")
        arguments = n_argument_list.arguments
        if arguments is None:
            raise NotImplementedError("TODO(missing): handle missing argument list")
        callee_ty = ctx.apply_as_substitution(callee_ty)
        return function_application_infer(
            env, ctx=ctx, ty=callee_ty, arguments=PVector(arguments)
        )
    elif isinstance(expr, IdentifierExpr):
        target = env.bindation.get(expr)
        if target is None:
            raise NotImplementedError("TODO: handle absent type for identifier")
        if len(target) == 0:
            # Binding exists, but there is no definition. It must be a global
            # binding.
            result_ty = env.global_scope.get(expr.text)
            if result_ty is not None:
                return (env, ctx, result_ty)
            raise NotImplementedError(
                f"TODO: handle absent global variable type for {expr.text}"
            )
        elif len(target) == 1:
            pattern = target[0]
            result_ty = ctx.get_pattern_ty(pattern)
            if result_ty is None:
                raise NotImplementedError("TODO: handle absent type for expr")
            return (env, ctx, result_ty)
        else:
            raise NotImplementedError(
                "TODO: handle multiple possible source definitions"
            )
    elif isinstance(expr, BinaryExpr):
        t_operator = expr.t_operator
        if t_operator is None:
            raise NotImplementedError(
                "TODO(missing): handle missing operator in binary expression"
            )
        if t_operator.kind == TokenKind.DUMMY_SEMICOLON:
            n_lhs = expr.n_lhs
            if n_lhs is not None:
                (env, ctx, _reason) = check(env, ctx, expr=n_lhs, ty=TOP_TY)
            n_rhs = expr.n_rhs
            if n_rhs is not None:
                return infer(env, ctx, expr=n_rhs)
            else:
                return (env, ctx, ERR_TY)
        elif t_operator.kind == TokenKind.PLUS:
            # TODO: do something more sophisticated.
            n_lhs = expr.n_lhs
            if n_lhs is not None:
                (env, ctx, checks) = check(env, ctx, expr=n_lhs, ty=INT_TY)
                if not checks:
                    raise NotImplementedError("TODO: handle + on non-int LHS operand")
            n_rhs = expr.n_rhs
            if n_rhs is not None:
                (env, ctx, checks) = check(env, ctx, expr=n_rhs, ty=INT_TY)
                if not checks:
                    raise NotImplementedError("TODO: handle + on non-int RHS operand")
            return (env, ctx, INT_TY)
        else:
            raise NotImplementedError(
                f"`infer` not yet implemented for binary expression operator kind {t_operator.kind}"
            )
    elif isinstance(expr, IfExpr):
        n_then_expr = expr.n_then_expr
        if n_then_expr is None:
            return (env, ctx, ERR_TY)
        # TODO: should we actually mint a new existential type variable, then
        # infer it against the two clauses? We could also mint two existential
        # type variables, and then produce the union of them.
        n_else_expr = expr.n_else_expr
        if n_else_expr is None:
            result_ty = VOID_TY
        else:
            (env, ctx, result_ty) = infer(env, ctx, n_else_expr)
        (env, ctx, _reason) = check(env, ctx, n_then_expr, result_ty)
        return (env, ctx, result_ty)
    else:
        raise NotImplementedError(
            f"TODO: `infer` not yet implemented for expression type: {expr.__class__.__name__}"
        )


def infer(env: Env, ctx: TypingContext, expr: Expr) -> Tuple[Env, TypingContext, Ty]:
    (env, ctx, ty) = do_infer(env, ctx, expr)
    ctx = ctx.record_infers(expr, ty)
    return (env, ctx, ty)


def infer_function_definition(
    env: Env, ctx: TypingContext, expr: LetExpr
) -> Tuple[Env, TypingContext, Ty]:
    def error(ctx: TypingContext) -> Tuple[Env, TypingContext, Ty]:
        function_ty = ERR_TY
        n_pattern = expr.n_pattern
        assert isinstance(
            n_pattern, VariablePattern
        ), "Function let-exprs should be VariablePatterns"
        ctx = ctx.add_pattern_ty(pattern=n_pattern, ty=function_ty)
        return (env, ctx, function_ty)

    n_parameter_list = expr.n_parameter_list
    if n_parameter_list is None:
        return error(ctx)
    parameter_list = n_parameter_list.parameters
    if parameter_list is None:
        return error(ctx)
    parameters: PVector[Optional[Parameter]] = PVector()
    for n_parameter in parameter_list:
        if n_parameter is None:
            return error(ctx)
        parameters = parameters.append(n_parameter)
    n_value = expr.n_value
    if n_value is None:
        return error(ctx)
    return infer_lambda(env, ctx, parameters=parameters, body=n_value)


def function_application_infer(
    env: Env, ctx: TypingContext, ty: Ty, arguments: PVector[Argument]
) -> Tuple[Env, TypingContext, Ty]:
    """The function-application relation ⇒⇒, discussed in Dunfield 2013."""
    if isinstance(ty, FunctionTy):
        if len(arguments) != len(ty.domain):
            raise NotImplementedError("TODO: handle argument number mismatch")
        # TODO: use `izip_longest` here instead (not known to Mypy?)
        for argument, argument_ty in zip(arguments, ty.domain):
            n_expr = argument.n_expr
            if n_expr is None:
                raise NotImplementedError("TODO(missing): handle missing argument")
            (env, ctx, _reason) = check(env, ctx, expr=n_expr, ty=argument_ty)
        return (env, ctx, ty.codomain)
    elif isinstance(ty, ExistentialTyVar):
        raise NotImplementedError()
    else:
        assert (
            False
        ), f"Unexpected function_application_infer type: {ty.__class__.__name__}"


def check_lambda() -> None:
    """Check the type of a lambda or function definition.

    The typing rule is

        Γ, x:A ⊢ e ⇐ B ⊣ ∆, x:A, Θ
        -------------------------- →I
           Γ ⊢ λx.e ⇐ A→B ⊣ ∆
    """
    raise NotImplementedError("check_lambda")
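# Added commentary: for a call such as `f(1)` where `f : int -> None`,
# `function_application_infer` checks each argument against the corresponding
# domain type and yields the codomain. A sketch (the names are hypothetical):
#
#     (env, ctx, result_ty) = function_application_infer(
#         env, ctx=ctx, ty=f_ty, arguments=PVector([argument])
#     )
#     # result_ty is f_ty.codomain, i.e. NONE_TY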
def infer_lambda(
    env: Env, ctx: TypingContext, parameters: PVector[Optional[Parameter]], body: Expr
) -> Tuple[Env, TypingContext, Ty]:
    """Infer the type of a lambda or function definition.

    The typing rule is

        Γ, â, b̂, x:â ⊢ e ⇐ b̂ ⊣ ∆, x:â, Θ
        --------------------------------- →I⇒
             Γ ⊢ λx.e ⇒ â→b̂ ⊣ ∆

    which must be generalized here to handle multiple parameters.
    """
    until_judgment = None
    parameter_tys = []
    for i, parameter in enumerate(parameters):
        if parameter is None:
            continue
        n_pattern = parameter.n_pattern
        if n_pattern is None:
            continue
        if not isinstance(n_pattern, VariablePattern):
            raise NotImplementedError(
                "TODO: patterns other than VariablePattern not supported"
            )
        parameter_ty = ExistentialTyVar(
            name=f"param_{i}", reason=TodoReason(todo="parameter ty")
        )
        parameter_tys.append(parameter_ty)
        judgment = DeclareExistentialVarJudgment(existential_ty_var=parameter_ty)
        ctx = ctx.add_judgment(judgment)
        ctx = ctx.add_pattern_ty(pattern=n_pattern, ty=parameter_ty)
        if until_judgment is None:
            until_judgment = judgment
    return_ty = ExistentialTyVar(name="return", reason=TodoReason(todo="return ty"))
    return_judgment = DeclareExistentialVarJudgment(existential_ty_var=return_ty)
    if until_judgment is None:
        until_judgment = return_judgment
    ctx = ctx.add_judgment(return_judgment)
    (env, ctx, checks) = check(env, ctx=ctx, expr=body, ty=return_ty)
    ctx = ctx.take_until_before_judgment(judgment=until_judgment)
    function_ty = FunctionTy(
        domain=PVector(parameter_tys),
        codomain=return_ty,
        reason=TodoReason(todo="infer_lambda"),
    )
    return (env, ctx, function_ty)
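# Added commentary: `infer_lambda` mints one existential type variable per
# parameter plus one for the return type, checks the body against the return
# existential, and then discards the judgments introduced for the parameters.
# For `let f(x) = x + 1`, the intent is that both the parameter and return
# existentials get solved to `int` while checking the body (a sketch of the
# intended behavior, not a tested transcript).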
def check(
    env: Env, ctx: TypingContext, expr: Expr, ty: Ty
) -> Tuple[Env, TypingContext, Optional[Reason]]:
    if isinstance(expr, LetExpr):
        # The typing rule for let-bindings is
        #
        #     Ψ ⊢ e ⇒ A    Ψ, x:A ⊢ e' ⇐ C
        #     ---------------------------- let
        #       Ψ ⊢ let x = e in e' ⇐ C
        #
        # Note that we have to adapt this for function definitions by also
        # using the rule for typing lambdas.
        n_pattern = expr.n_pattern
        if n_pattern is None:
            return (env, ctx, InvalidSyntaxReason())
        if expr.n_parameter_list is None:
            n_value = expr.n_value
            if n_value is None:
                return (env, ctx, InvalidSyntaxReason())
            (env, ctx, value_ty) = infer(env, ctx, n_value)
            if not isinstance(n_pattern, VariablePattern):
                raise NotImplementedError(
                    "TODO: patterns other than VariablePattern not supported"
                )
            ctx = ctx.add_pattern_ty(n_pattern, value_ty)
            if tys_equal(value_ty, VOID_TY):
                env = env.add_error(
                    code=ErrorCode.CANNOT_BIND_TO_VOID,
                    severity=Severity.ERROR,
                    message=(
                        f"This expression has type {ctx.ty_to_string(VOID_TY)}, "
                        + "so it cannot be bound to a variable."
                    ),
                    range=env.get_range_for_node(n_value),
                    notes=[
                        Note(
                            file_info=env.file_info,
                            message="This is the variable it's being bound to.",
                            range=env.get_range_for_node(n_pattern),
                        )
                    ],
                )
            n_body = expr.n_body
            if n_body is None:
                return (env, ctx, InvalidSyntaxReason())
            return check(env, ctx, expr=n_body, ty=ty)
        else:
            n_parameter_list = expr.n_parameter_list
            if n_parameter_list is None:
                # Unreachable given the branch above; return a reason rather
                # than a bare boolean so the return type stays consistent.
                return (env, ctx, InvalidSyntaxReason())
            parameters = n_parameter_list.parameters
            if parameters is None:
                raise NotImplementedError(
                    "TODO(missing): raise error for missing parameters"
                )
            (env, ctx, function_ty) = infer_function_definition(env, ctx, expr)
            n_body = expr.n_body
            if n_body is None:
                raise NotImplementedError(
                    "TODO(missing): raise error for missing function body"
                )
            # parameters_with_tys: PVector[CtxElemExprHasTy] = PVector()
            # for n_argument, argument_ty in zip(parameters, ty.domain):
            #     n_expr = n_argument.n_expr
            #     if n_expr is None:
            #         return (env, ctx, True)
            #     parameters_with_tys = parameters_with_tys.append(
            #         CtxElemExprHasTy(expr=n_expr, ty=argument_ty)
            #     )
            # ctx = ctx.add_elem(CtxElemExprsHaveTys(expr_tys=parameters_with_tys))
            assert isinstance(n_pattern, VariablePattern)
            ctx = ctx.add_pattern_ty(n_pattern, function_ty)
            return check(env, ctx, expr=n_body, ty=ty)
    else:
        (env, ctx, actual_ty) = infer(env, ctx, expr=expr)
        (env, ctx, reason) = check_subtype(env, ctx, lhs=actual_ty, rhs=ty)
        if reason is not None:
            return (env, ctx, reason)
        else:
            env = env.add_error(
                code=ErrorCode.INCOMPATIBLE_TYPES,
                severity=Severity.ERROR,
                message=(
                    f"I was expecting this expression to have type {ctx.ty_to_string(ty)}, "
                    + f"but it actually had type {ctx.ty_to_string(actual_ty)}."
                ),
                range=env.get_range_for_node(expr),
            )
            return (env, ctx, reason)
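# Added commentary: in the fallback case above, `check` uses the subsumption
# rule: infer the expression's type, then require it to be a subtype of the
# expected type. A sketch using the builtins:
#
#     (env, ctx, reason) = check_subtype(env, ctx, lhs=INT_TY, rhs=OBJECT_TY)
#     # reason is a SubtypeOfObjectReason(); a reason of None signals failure.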
def check_subtype(
    env: Env, ctx: TypingContext, lhs: Ty, rhs: Ty
) -> Tuple[Env, TypingContext, Optional[Reason]]:
    if tys_equal(lhs, rhs):
        return (env, ctx, EqualTysReason(lhs=lhs, rhs=rhs))
    elif isinstance(lhs, UniversalTy):
        # TODO: implement
        return (env, ctx, TodoReason(todo="UniversalTy"))
    elif isinstance(rhs, UniversalTy):
        judgment = DeclareVarJudgment(variable=rhs.quantifier_ty)
        ctx = ctx.add_judgment(judgment)
        (env, ctx, checks) = check_subtype(env, ctx, lhs=lhs, rhs=rhs.ty)
        ctx = ctx.take_until_before_judgment(judgment)
        return (env, ctx, checks)
    elif isinstance(lhs, ExistentialTyVar):
        # <:InstantiateL
        # TODO: check to see that the existential type variable is not in the
        # free variables of the right-hand side.
        return instantiate_lhs_existential(env, ctx, lhs=lhs, rhs=rhs)
    elif isinstance(rhs, ExistentialTyVar):
        # <:InstantiateR
        # TODO: check to see that the existential type variable is not in the
        # free variables of the left-hand side.
        return instantiate_rhs_existential(env, ctx, lhs=lhs, rhs=rhs)
    elif isinstance(rhs, TyVar):
        return (env, ctx, SubtypeOfUnboundedGenericReason())
    elif tys_equal(rhs, VOID_TY):
        assert lhs != VOID_TY, "should be handled in tys_equal case"
        (env, ctx, reason) = check_subtype(env, ctx, lhs=lhs, rhs=NONE_TY)
        if reason is not None:
            return (env, ctx, NoneIsSubtypeOfVoidReason())
        else:
            return (env, ctx, None)
    elif tys_equal(rhs, OBJECT_TY):
        if tys_equal(lhs, VOID_TY):
            return (env, ctx, None)
        else:
            return (env, ctx, SubtypeOfObjectReason())
    elif isinstance(lhs, BaseTy) and isinstance(rhs, BaseTy):
        assert not tys_equal(lhs, rhs), "should have been handled in tys_equal case"
        return (env, ctx, None)
    elif isinstance(lhs, FunctionTy) or isinstance(rhs, FunctionTy):
        if not isinstance(lhs, FunctionTy):
            raise NotImplementedError(
                f"TODO: handle subtype failure for non-function type {lhs!r} and function type {rhs!r}"
            )
        if not isinstance(rhs, FunctionTy):
            raise NotImplementedError(
                f"TODO: handle subtype failure for function type {lhs!r} and non-function type {rhs!r}"
            )
        raise NotImplementedError("TODO: implement subtype checking for functions")
    # TODO: implement the rest of the subtyping from Figure 9.
    raise NotImplementedError(
        f"TODO: subtype checking for lhs {lhs!r} and rhs {rhs!r} not implemented"
    )


def instantiate_lhs_existential(
    env: Env, ctx: TypingContext, lhs: ExistentialTyVar, rhs: Ty
) -> Tuple[Env, TypingContext, Reason]:
    if isinstance(rhs, (MonoTy, ExistentialTyVar)):
        ctx = ctx.instantiate_existential(existential_ty_var=lhs, to=rhs)
        return (env, ctx, InstantiateExistentialReason(existential_ty_var=lhs, to=rhs))
    raise NotImplementedError(
        f"TODO: LHS existential instantiation for lhs {lhs!r} and rhs {rhs!r} not implemented"
    )


def instantiate_rhs_existential(
    env: Env, ctx: TypingContext, lhs: Ty, rhs: ExistentialTyVar
) -> Tuple[Env, TypingContext, Reason]:
    if isinstance(lhs, MonoTy):
        ctx = ctx.instantiate_existential(existential_ty_var=rhs, to=lhs)
        return (env, ctx, InstantiateExistentialReason(existential_ty_var=rhs, to=lhs))
    raise NotImplementedError(
        f"TODO: RHS existential instantiation for lhs {lhs!r} and rhs {rhs!r} not implemented"
    )


def typecheck(
    file_info: FileInfo,
    syntax_tree: SyntaxTree,
    bindation: Bindation,
    global_scope: PMap[str, Ty],
) -> Typeation:
    ctx = TypingContext(judgments=PVector(), inferred_tys=PMap())
    if syntax_tree.n_expr is None:
        return Typeation(ctx=ctx, errors=[])
    env = Env(
        file_info=file_info,
        bindation=bindation,
        global_scope=global_scope,
        errors=PVector(),
    )
    (env, ctx, checks) = check(env, ctx, expr=syntax_tree.n_expr, ty=TOP_TY)
    assert checks, "The program should always check against the top type"
    return Typeation(ctx=ctx, errors=list(env.errors))
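# End-to-end sketch (added commentary): `typecheck` is the entry point used by
# pytch.repl.compile_file. Assuming a red tree and a bindation produced by
# pytch.binder:
#
#     typeation = typecheck(
#         file_info=file_info,
#         syntax_tree=red_tree,
#         bindation=bindation,
#         global_scope=GLOBAL_SCOPE,  # from .builtins
#     )
#     for error in typeation.errors:
#         ...  # report diagnostics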
PK! pytch/typesystem/types.py
import attr

from pytch.containers import PVector

import pytch.typesystem.reason


@attr.s(auto_attribs=True, frozen=True)
class Ty:
    reason: "pytch.typesystem.reason.Reason"

    def __eq__(self, other: object) -> bool:
        # TODO: does this work? Do we need to assign types unique IDs instead?
        return self is other

    def __ne__(self, other: object) -> bool:
        return not (self == other)


class MonoTy(Ty):
    pass


@attr.s(auto_attribs=True, frozen=True)
class BaseTy(MonoTy):
    name: str


@attr.s(auto_attribs=True, frozen=True)
class FunctionTy(MonoTy):
    domain: PVector[Ty]
    codomain: Ty


@attr.s(auto_attribs=True, frozen=True)
class TyVar(MonoTy):
    """Type variable.

    This represents an indeterminate type which is to be symbolically
    manipulated, such as in a universally-quantified type. We don't
    instantiate it to a concrete type during typechecking.

    Usually type variables are denoted by Greek letters, such as "α" rather
    than "a".
    """

    name: str


@attr.s(auto_attribs=True, frozen=True)
class UniversalTy(Ty):
    """Universally-quantified type.

    For example, the type

        ∀α. α → unit

    is a function type which takes a value of any type and returns the unit
    value.
    """

    quantifier_ty: TyVar
    ty: Ty


@attr.s(auto_attribs=True, frozen=True)
class ExistentialTyVar(Ty):
    """Existential type variable.

    This represents an unsolved type that should be solved during type
    inference. The rules for doing so are covered in Dunfield 2013 Figure 10.
    """

    name: str
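# For example, the polymorphic identity type ∀α. α → α could be constructed as
# follows (a sketch; the `some_reason` value is a placeholder):
#
#     alpha = TyVar(name="α", reason=some_reason)
#     identity_ty = UniversalTy(
#         quantifier_ty=alpha,
#         ty=FunctionTy(domain=PVector([alpha]), codomain=alpha, reason=some_reason),
#         reason=some_reason,
#     )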
PK! pytch/utils.py
from typing import List

import attr

Offset = int
"""A zero-indexed offset into a file."""


@attr.s(auto_attribs=True, frozen=True)
class OffsetRange:
    start: Offset
    """The inclusive start offset of the range."""

    end: Offset
    """The exclusive end offset of the range."""


@attr.s(auto_attribs=True, frozen=True)
class Position:
    line: int
    character: int


@attr.s(auto_attribs=True, frozen=True)
class Range:
    start: Position
    end: Position


@attr.s(auto_attribs=True)
class FileInfo:
    file_path: str
    source_code: str
    lines: List[str] = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        self.lines = splitlines(self.source_code)

    def get_position_for_offset(self, offset: int) -> Position:
        # 0-based index ranges are inclusive on the left and exclusive on the
        # right, which means that the length of the source code is a valid
        # index for constructing a range.
        assert (
            0 <= offset <= len(self.source_code)
        ), f"offset {offset} is not in range [0, {len(self.source_code)}]"
        current_offset = 0
        current_line = 0
        # Add 1 to the length of the line to account for the removed "\n"
        # character.
        while (
            current_line < len(self.lines)
            and current_offset + len(self.lines[current_line]) + 1 <= offset
        ):
            current_offset += len(self.lines[current_line]) + 1
            current_line += 1
        character = offset - current_offset
        return Position(line=current_line, character=character)

    def get_range_from_offset_range(self, offset_range: OffsetRange) -> Range:
        return Range(
            start=self.get_position_for_offset(offset_range.start),
            end=self.get_position_for_offset(offset_range.end),
        )


def splitlines(s: str) -> List[str]:
    """Don't use `str.splitlines`.

    That function splits on multiple Unicode newline-like characters, which we
    don't want to include. See
    https://docs.python.org/3/library/stdtypes.html#str.splitlines
    """
    lines = s.split("\n")
    if lines[-1] == "":
        lines = lines[:-1]
    return lines
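# Example (added commentary): for source_code == "let x = 1\nx", offset 10
# falls at the start of the second line, so get_position_for_offset(10)
# returns Position(line=1, character=0). Offsets and positions are both
# zero-based.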