"""Recursive-descent parser for the schemashift DSL.
Converts a DSL expression string into an AST composed of nodes from
:mod:`schemashift.dsl.ast_nodes`. Raises :class:`schemashift.errors.DSLSyntaxError`
for any invalid input.
"""
import ast
import re
from enum import Enum, auto
from typing import NamedTuple
from schemashift.errors import DSLSyntaxError
from .ast_nodes import (
ASTNode,
BinaryOp,
Coalesce,
ColRef,
CustomLookup,
Literal,
Lookup,
MethodCall,
UnaryOp,
WhenChain,
WhenClause,
)
# ---------------------------------------------------------------------------
# Allowlisted methods
# ---------------------------------------------------------------------------
# Only these method names may appear in a '.method(...)' chain; anything
# else is rejected by the parser with a DSLSyntaxError (see
# _Parser._method_chain_step / _is_allowed_method below).
# Methods callable directly on an expression, e.g. col("x").round(2).
_DIRECT_METHODS: frozenset[str] = frozenset({"round", "abs", "cast", "fill_null", "is_null"})
# Methods under the "str" namespace, e.g. col("x").str.lower().
_STR_METHODS: frozenset[str] = frozenset(
    {
        "strip",
        "lower",
        "upper",
        "to_lowercase",
        "to_uppercase",
        "slice",
        "replace",
        "replace_regex",
        "contains",
        "starts_with",
        "ends_with",
        "to_datetime",
        "lengths",
        "extract",
    }
)
# Methods under the "dt" namespace, e.g. col("x").dt.year().
_DT_METHODS: frozenset[str] = frozenset({"year", "month", "day", "hour", "minute", "second", "strftime", "timestamp"})
def _is_allowed_method(namespace: str, name: str) -> bool:
    """Return True when *name* is an allowlisted method of *namespace*.

    The empty namespace denotes a direct (un-namespaced) method call;
    any namespace other than "", "str" or "dt" is rejected.
    """
    allowlist = {
        "": _DIRECT_METHODS,
        "str": _STR_METHODS,
        "dt": _DT_METHODS,
    }.get(namespace)
    return allowlist is not None and name in allowlist
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
class TT(Enum):
    """Token type emitted by :func:`tokenize`."""
    # Literals and names
    NUMBER = auto()
    STRING = auto()
    IDENT = auto()
    # Punctuation
    DOT = auto()
    LPAREN = auto()
    RPAREN = auto()
    COMMA = auto()
    # Arithmetic operators
    PLUS = auto()
    MINUS = auto()
    STAR = auto()
    SLASH = auto()
    PERCENT = auto()
    # Comparison operators
    EQ = auto()  # ==
    NE = auto()  # !=
    GT = auto()  # >
    LT = auto()  # <
    GE = auto()  # >=
    LE = auto()  # <=
    # Boolean connectives
    AMP = auto()  # &
    PIPE = auto()  # |
    # Mapping literals (custom_lookup)
    LBRACE = auto()  # {
    RBRACE = auto()  # }
    COLON = auto()  # :
    # End-of-input sentinel appended by tokenize()
    EOF = auto()
class Token(NamedTuple):
    """A single lexical token produced by :func:`tokenize`.

    ``value`` holds the decoded payload: the raw text for identifiers
    and operators, the unescaped content for strings, and an ``int`` or
    ``float`` for numbers.
    """
    type: TT
    value: object  # str | int | float
    pos: int  # character offset in source
# Regex patterns for tokenizing (order matters).
_TOKEN_RE = re.compile(
r"""
(?P<FLOAT> \d+\.\d* | \.\d+) # float before int
|(?P<INT> \d+)
|(?P<STRING> '[^'\\]*(?:\\.[^'\\]*)*' | "[^"\\]*(?:\\.[^"\\]*)*")
|(?P<IDENT> [A-Za-z_][A-Za-z0-9_]*)
|(?P<EQ> ==)
|(?P<NE> !=)
|(?P<GE> >=)
|(?P<LE> <=)
|(?P<GT> >)
|(?P<LT> <)
|(?P<DOT> \.)
|(?P<LPAREN> \()
|(?P<RPAREN> \))
|(?P<COMMA> ,)
|(?P<PLUS> \+)
|(?P<MINUS> -)
|(?P<STAR> \*)
|(?P<SLASH> /)
|(?P<PERCENT> %)
|(?P<AMP> &)
|(?P<PIPE> \|)
|(?P<LBRACE> \{)
|(?P<RBRACE> \})
|(?P<COLON> :)
|(?P<WS> \s+) # ignored whitespace
""",
re.VERBOSE,
)
_TT_MAP: dict[str, TT] = {
"EQ": TT.EQ,
"NE": TT.NE,
"GE": TT.GE,
"LE": TT.LE,
"GT": TT.GT,
"LT": TT.LT,
"DOT": TT.DOT,
"LPAREN": TT.LPAREN,
"RPAREN": TT.RPAREN,
"COMMA": TT.COMMA,
"PLUS": TT.PLUS,
"MINUS": TT.MINUS,
"STAR": TT.STAR,
"SLASH": TT.SLASH,
"PERCENT": TT.PERCENT,
"AMP": TT.AMP,
"PIPE": TT.PIPE,
"LBRACE": TT.LBRACE,
"RBRACE": TT.RBRACE,
"COLON": TT.COLON,
"IDENT": TT.IDENT,
}
def _unescape_string(raw: str) -> str:
"""Remove surrounding quotes and process escape sequences.
Delegates to :func:`ast.literal_eval` so that ``\\n``, ``\\t``,
``\\uXXXX``, etc. are interpreted the same way Python would.
"""
# ast.literal_eval handles both single- and double-quoted strings.
return ast.literal_eval(raw)
def tokenize(expression: str) -> list[Token]:
    """Convert *expression* to a flat list of tokens (excluding whitespace).

    The returned list always ends with a single ``TT.EOF`` token whose
    position is ``len(expression)``.

    Raises:
        DSLSyntaxError: if a character cannot start any known token.
    """
    tokens: list[Token] = []
    pos = 0
    length = len(expression)
    while pos < length:
        m = _TOKEN_RE.match(expression, pos)
        if m is None:
            raise DSLSyntaxError(
                f"Unexpected character {expression[pos]!r}",
                expression=expression,
                position=pos,
            )
        kind = m.lastgroup
        text = m.group()
        if kind == "WS":
            # Whitespace never produces a token; just skip past it.
            pos = m.end()
            continue
        if kind == "FLOAT":
            tokens.append(Token(TT.NUMBER, float(text), pos))
        elif kind == "INT":
            tokens.append(Token(TT.NUMBER, int(text), pos))
        elif kind == "STRING":
            # Store the decoded payload (quotes stripped, escapes resolved).
            tokens.append(Token(TT.STRING, _unescape_string(text), pos))
        else:
            if kind is None:  # pragma: no cover - defensive; a match always names a group
                raise DSLSyntaxError(f"Unexpected token at position {pos}", expression=expression, position=pos)
            tt = _TT_MAP[kind]
            tokens.append(Token(tt, text, pos))
        pos = m.end()
    tokens.append(Token(TT.EOF, "", length))
    return tokens
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
# Reserved words of the DSL grammar.
# NOTE(review): not referenced within this module; presumably consumed by
# other modules (e.g. validation or editor tooling) — confirm before removing.
_KEYWORDS: frozenset[str] = frozenset(
    {"col", "when", "otherwise", "true", "false", "null", "coalesce", "lookup", "custom_lookup", "not"}
)
class _Parser:
    """Hand-written recursive-descent parser.

    One method per grammar production. Tokens are produced up-front by
    :func:`tokenize`; ``self._pos`` is the cursor into that list.
    Precedence, loosest to tightest: ``&``/``|``, comparisons,
    ``+``/``-``, ``*``/``/``/``%``, unary ``-``/``not``, then atoms
    with trailing ``.method(...)`` chains.
    """
    def __init__(self, expression: str) -> None:
        # Keep the raw source text so every error can carry a position into it.
        self._expr = expression
        self._tokens = tokenize(expression)
        self._pos = 0
    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _peek(self) -> Token:
        """Return the current token without consuming it."""
        return self._tokens[self._pos]
    def _advance(self) -> Token:
        """Consume and return the current token."""
        tok = self._tokens[self._pos]
        self._pos += 1
        return tok
    def _expect(self, tt: TT, *, hint: str = "") -> Token:
        """Consume and return the current token, requiring its type to be *tt*.

        Raises :class:`DSLSyntaxError` on a mismatch, using *hint* as the
        message when given (otherwise a generic "Expected ..." message).
        """
        tok = self._peek()
        if tok.type != tt:
            msg = hint or f"Expected {tt.name}, got {tok.value!r}"
            raise DSLSyntaxError(msg, expression=self._expr, position=tok.pos)
        return self._advance()
    def _match(self, *types: TT) -> bool:
        """Return True if the current token's type is one of *types* (no consume)."""
        return self._peek().type in types
    def _current_pos(self) -> int:
        """Character offset of the current token, for error reporting."""
        return self._peek().pos
    # ------------------------------------------------------------------
    # Grammar productions
    # ------------------------------------------------------------------
    def parse(self) -> ASTNode:
        """Parse one complete expression and require EOF afterwards."""
        node = self._expression()
        eof = self._peek()
        if eof.type != TT.EOF:
            raise DSLSyntaxError(
                f"Unexpected token {eof.value!r} after expression",
                expression=self._expr,
                position=eof.pos,
            )
        return node
    # expression := logical
    def _expression(self) -> ASTNode:
        """Top-level production; currently an alias for :meth:`_logical`."""
        return self._logical()
    # logical := comparison (('&' | '|') comparison)*
    def _logical(self) -> ASTNode:
        """Left-associative '&' / '|' chains (lowest precedence)."""
        left = self._comparison()
        while self._match(TT.AMP, TT.PIPE):
            op_tok = self._advance()
            op = "&" if op_tok.type == TT.AMP else "|"
            right = self._comparison()
            left = BinaryOp(op, left, right)
        return left
    # comparison := additive (('==' | '!=' | '>' | '<' | '>=' | '<=') additive)?
    def _comparison(self) -> ASTNode:
        """At most one comparison operator; comparisons deliberately do not chain."""
        left = self._additive()
        _CMP = {TT.EQ: "==", TT.NE: "!=", TT.GT: ">", TT.LT: "<", TT.GE: ">=", TT.LE: "<="}
        if self._peek().type in _CMP:
            op_tok = self._advance()
            op = _CMP[op_tok.type]
            right = self._additive()
            left = BinaryOp(op, left, right)
        return left
    # additive := multiplicative (('+' | '-') multiplicative)*
    def _additive(self) -> ASTNode:
        """Left-associative '+' / '-' chains."""
        left = self._multiplicative()
        while self._match(TT.PLUS, TT.MINUS):
            op_tok = self._advance()
            op = "+" if op_tok.type == TT.PLUS else "-"
            right = self._multiplicative()
            left = BinaryOp(op, left, right)
        return left
    # multiplicative := unary (('*' | '/' | '%') unary)*
    def _multiplicative(self) -> ASTNode:
        """Left-associative '*' / '/' / '%' chains."""
        left = self._unary()
        _MUL = {TT.STAR: "*", TT.SLASH: "/", TT.PERCENT: "%"}
        while self._peek().type in _MUL:
            op_tok = self._advance()
            op = _MUL[op_tok.type]
            right = self._unary()
            left = BinaryOp(op, left, right)
        return left
    # unary := '-' unary | 'not' unary | atom_with_methods
    def _unary(self) -> ASTNode:
        """Prefix '-' and 'not'; both recurse so '--x' / 'not not x' nest."""
        if self._match(TT.MINUS):
            self._advance()
            operand = self._unary()
            return UnaryOp("-", operand)
        # 'not' is lexed as an IDENT, so match on the token's text.
        if self._match(TT.IDENT) and str(self._peek().value) == "not":
            self._advance()
            operand = self._unary()
            return UnaryOp("!", operand)
        return self._atom_with_methods()
    # atom_with_methods := atom ('.' method_chain)*
    def _atom_with_methods(self) -> ASTNode:
        """An atom followed by zero or more chained '.method(...)' calls."""
        node = self._atom()
        while self._match(TT.DOT):
            node = self._method_chain_step(node)
        return node
    # method_chain := IDENT '(' args? ')' | IDENT '.' IDENT '(' args? ')'
    def _method_chain_step(self, obj: ASTNode) -> ASTNode:
        """Parse one '.name(...)' or '.ns.name(...)' step applied to *obj*.

        Method names are validated against the allowlists; anything else
        raises :class:`DSLSyntaxError` at the offending name.
        """
        self._expect(TT.DOT)
        name_tok = self._expect(TT.IDENT, hint="Expected method name after '.'")
        name = str(name_tok.value)
        # Sub-namespace: str.xxx or dt.xxx
        if self._match(TT.DOT):
            # namespace.method
            self._advance()  # consume second dot
            sub_tok = self._expect(TT.IDENT, hint=f"Expected method name after '{name}.'")
            sub_name = str(sub_tok.value)
            full_method = f"{name}.{sub_name}"
            if not _is_allowed_method(name, sub_name):
                raise DSLSyntaxError(
                    f"Unknown method '{full_method}'",
                    expression=self._expr,
                    position=sub_tok.pos,
                )
            self._expect(TT.LPAREN, hint=f"Expected '(' after '{full_method}'")
            args = self._args() if not self._match(TT.RPAREN) else ()
            self._expect(TT.RPAREN, hint=f"Expected ')' to close '{full_method}(...'")
            return MethodCall(obj, full_method, tuple(args))
        # Direct method
        if not _is_allowed_method("", name):
            raise DSLSyntaxError(
                f"Unknown method '{name}'",
                expression=self._expr,
                position=name_tok.pos,
            )
        self._expect(TT.LPAREN, hint=f"Expected '(' after '{name}'")
        args = self._args() if not self._match(TT.RPAREN) else ()
        self._expect(TT.RPAREN, hint=f"Expected ')' to close '{name}(...'")
        return MethodCall(obj, name, tuple(args))
    # atom := NUMBER | STRING | BOOLEAN | NULL | col_ref | when_expr | '(' expression ')'
    def _atom(self) -> ASTNode: # noqa: C901
        """Dispatch on the current token to the appropriate leaf production."""
        tok = self._peek()
        if tok.type == TT.NUMBER:
            self._advance()
            return Literal(tok.value)
        if tok.type == TT.STRING:
            self._advance()
            return Literal(tok.value)
        if tok.type == TT.IDENT:
            ident = str(tok.value)
            if ident == "true":
                self._advance()
                return Literal(True)
            if ident == "false":
                self._advance()
                return Literal(False)
            if ident == "null":
                self._advance()
                return Literal(None)
            # The keyword productions below consume their own leading IDENT.
            if ident == "col":
                return self._col_ref()
            if ident == "when":
                return self._when_expr()
            if ident == "coalesce":
                return self._coalesce_expr()
            if ident == "lookup":
                return self._lookup_expr()
            if ident == "custom_lookup":
                return self._custom_lookup_expr()
            if ident == "not":
                raise DSLSyntaxError(
                    "'not' must prefix an expression (e.g. not col(\"x\").is_null())",
                    expression=self._expr,
                    position=tok.pos,
                )
            # Unknown identifier
            raise DSLSyntaxError(
                f"Unknown identifier {ident!r}",
                expression=self._expr,
                position=tok.pos,
            )
        if tok.type == TT.LPAREN:
            self._advance()
            node = self._expression()
            self._expect(TT.RPAREN, hint="Expected ')' to close parenthesised expression")
            return node
        raise DSLSyntaxError(
            f"Unexpected token {tok.value!r}",
            expression=self._expr,
            position=tok.pos,
        )
    # col_ref := 'col' '(' STRING ')'
    def _col_ref(self) -> ColRef:
        """Parse a column reference; the name must be a string literal."""
        self._expect(TT.IDENT, hint="Expected 'col'")
        self._expect(TT.LPAREN, hint="Expected '(' after 'col'")
        name_tok = self._expect(TT.STRING, hint="Expected column name string in col(...)")
        self._expect(TT.RPAREN, hint="Expected ')' after column name")
        return ColRef(str(name_tok.value))
    # when_expr := 'when' '(' expr ',' expr ')' ('.when(...)*)* '.otherwise' '(' expr ')'
    def _when_expr(self) -> WhenChain:
        """Parse a full when-chain; every chain must end with '.otherwise(...)'."""
        # Consume leading 'when'
        start_tok = self._expect(TT.IDENT, hint="Expected 'when'")
        if str(start_tok.value) != "when":
            raise DSLSyntaxError(
                "Expected 'when'",
                expression=self._expr,
                position=start_tok.pos,
            )
        whens: list[WhenClause] = []
        # First when(..., ...)
        self._expect(TT.LPAREN, hint="Expected '(' after 'when'")
        cond = self._expression()
        self._expect(TT.COMMA, hint="Expected ',' between when condition and value")
        val = self._expression()
        self._expect(TT.RPAREN, hint="Expected ')' to close 'when(...'")
        whens.append(WhenClause(cond, val))
        # Chain: .when(...) | .otherwise(...)
        while self._match(TT.DOT):
            self._advance()  # consume '.'
            kw_tok = self._expect(TT.IDENT, hint="Expected 'when' or 'otherwise' after '.'")
            kw = str(kw_tok.value)
            if kw == "when":
                self._expect(TT.LPAREN, hint="Expected '(' after '.when'")
                cond2 = self._expression()
                self._expect(TT.COMMA, hint="Expected ',' between when condition and value")
                val2 = self._expression()
                self._expect(TT.RPAREN, hint="Expected ')' to close '.when(...'")
                whens.append(WhenClause(cond2, val2))
            elif kw == "otherwise":
                self._expect(TT.LPAREN, hint="Expected '(' after '.otherwise'")
                default = self._expression()
                self._expect(TT.RPAREN, hint="Expected ')' to close '.otherwise(...'")
                return WhenChain(tuple(whens), default)
            else:
                raise DSLSyntaxError(
                    f"Expected 'when' or 'otherwise', got {kw!r}",
                    expression=self._expr,
                    position=kw_tok.pos,
                )
        # Fell off the end without .otherwise(...)
        raise DSLSyntaxError(
            "Expected '.otherwise(...)' to close when-chain",
            expression=self._expr,
            position=self._current_pos(),
        )
    # coalesce_expr := 'coalesce' '(' expression ',' expression (',' expression)* ')'
    def _coalesce_expr(self) -> Coalesce:
        """Parse coalesce(...) and require at least two arguments."""
        self._expect(TT.IDENT, hint="Expected 'coalesce'")
        self._expect(TT.LPAREN, hint="Expected '(' after 'coalesce'")
        exprs = self._args()
        self._expect(TT.RPAREN, hint="Expected ')' to close 'coalesce(...'")
        if len(exprs) < 2:
            raise DSLSyntaxError(
                "coalesce() requires at least 2 arguments",
                expression=self._expr,
                position=self._current_pos(),
            )
        return Coalesce(tuple(exprs))
    # lookup := 'lookup' '(' expression ',' STRING ')'
    def _lookup_expr(self) -> Lookup:
        """Parse lookup(expr, "table"); the table name must be a string literal."""
        self._expect(TT.IDENT, hint="Expected 'lookup'")
        self._expect(TT.LPAREN, hint="Expected '(' after 'lookup'")
        expr = self._expression()
        self._expect(TT.COMMA, hint="Expected ',' after expression in lookup()")
        table_tok = self._expect(TT.STRING, hint="lookup() table name must be a string literal")
        self._expect(TT.RPAREN, hint="Expected ')' to close 'lookup(...'")
        return Lookup(expr, str(table_tok.value))
    # custom_lookup := 'custom_lookup' '(' expression ',' map_literal (',' STRING)? ')'
    def _custom_lookup_expr(self) -> CustomLookup:
        """Parse custom_lookup(expr, {...}[, "base_table"]).

        The mapping must contain at least one entry; the trailing base
        table name is optional.
        """
        self._expect(TT.IDENT, hint="Expected 'custom_lookup'")
        self._expect(TT.LPAREN, hint="Expected '(' after 'custom_lookup'")
        expr = self._expression()
        self._expect(TT.COMMA, hint="Expected ',' after expression in custom_lookup()")
        mapping = self._map_literal()
        if not mapping:
            raise DSLSyntaxError(
                "custom_lookup() mapping must not be empty",
                expression=self._expr,
                position=self._current_pos(),
            )
        # Optional: , "base_table_name"
        base_table: str | None = None
        if self._match(TT.COMMA):
            self._advance()
            tbl_tok = self._expect(TT.STRING, hint="custom_lookup() base table name must be a string literal")
            base_table = str(tbl_tok.value)
        self._expect(TT.RPAREN, hint="Expected ')' to close 'custom_lookup(...'")
        return CustomLookup(expr, mapping, base_table)
    # map_literal := '{' (literal ':' literal (',' literal ':' literal)* ','?)? '}'
    def _map_literal(self) -> tuple[tuple[Literal, Literal], ...]:
        """Parse a '{key: value, ...}' literal; a trailing comma is allowed."""
        self._expect(TT.LBRACE, hint="Expected '{' to start mapping")
        pairs: list[tuple[Literal, Literal]] = []
        while not self._match(TT.RBRACE):
            if pairs:
                self._expect(TT.COMMA, hint="Expected ',' between mapping entries")
                if self._match(TT.RBRACE):  # trailing comma
                    break
            key = self._literal_atom()
            self._expect(TT.COLON, hint="Expected ':' between key and value")
            val = self._literal_atom()
            pairs.append((key, val))
        self._expect(TT.RBRACE, hint="Expected '}' to close mapping")
        return tuple(pairs)
    def _literal_atom(self) -> Literal:
        """Parse a scalar literal (string, number, bool, null) for map keys/values."""
        tok = self._peek()
        if tok.type == TT.NUMBER:
            self._advance()
            return Literal(tok.value)
        if tok.type == TT.STRING:
            self._advance()
            return Literal(tok.value)
        if tok.type == TT.IDENT and str(tok.value) in ("true", "false", "null"):
            self._advance()
            if str(tok.value) == "true":
                return Literal(True)
            if str(tok.value) == "false":
                return Literal(False)
            return Literal(None)
        raise DSLSyntaxError(
            f"Expected a literal value (string, number, true, false, null), got {tok.value!r}",
            expression=self._expr,
            position=tok.pos,
        )
    # args := expression (',' expression)*
    def _args(self) -> list[ASTNode]:
        """Parse one or more comma-separated expressions."""
        args: list[ASTNode] = [self._expression()]
        while self._match(TT.COMMA):
            self._advance()
            args.append(self._expression())
        return args
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def parse_dsl(expression: str) -> ASTNode:
    """Parse *expression* and return the root AST node.

    Args:
        expression: DSL source text; must contain at least one
            non-whitespace character.

    Raises:
        DSLSyntaxError: on empty/blank input or any syntax error.
    """
    if not expression or not expression.strip():
        raise DSLSyntaxError(
            "Expression must not be empty",
            expression=expression,
            position=0,
        )
    return _Parser(expression).parse()