diff --git a/posthog/api/test/__snapshots__/test_query.ambr b/posthog/api/test/__snapshots__/test_query.ambr index 158a059e51a88..34702611c2ded 100644 --- a/posthog/api/test/__snapshots__/test_query.ambr +++ b/posthog/api/test/__snapshots__/test_query.ambr @@ -348,12 +348,12 @@ # name: TestQuery.test_select_hogql_expressions.1 ' /* user_id:0 request:_snapshot_ */ - SELECT tuple(uuid, event, properties, timestamp, team_id, distinct_id, elements_chain, created_at, person_id, person_created_at, person_properties), + SELECT tuple(uuid, event, properties, timestamp, team_id, distinct_id, elements_chain, created_at), event FROM events WHERE team_id = 2 AND timestamp < '2020-01-10 12:14:05.000000' - ORDER BY tuple(uuid, event, properties, timestamp, team_id, distinct_id, elements_chain, created_at, person_id, person_created_at, person_properties) ASC + ORDER BY tuple(uuid, event, properties, timestamp, team_id, distinct_id, elements_chain, created_at) ASC LIMIT 101 ' --- diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index 36fd7bd32dcfe..4d6d54da10492 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -53,9 +53,14 @@ def has_child(self, name: str) -> bool: return self.table.has_field(name) def get_child(self, name: str) -> Symbol: + if name == "*": + return AsteriskSymbol(table=self) if self.has_child(name): + field = self.table.get_field(name) + if isinstance(field, Table): + return TableSymbol(table=field) return FieldSymbol(name=name, table=self) - raise ValueError(f"Field not found: {name}") + raise ValueError(f'Field "{name}" not found on table {type(self.table).__name__}') class TableAliasSymbol(Symbol): @@ -66,6 +71,8 @@ def has_child(self, name: str) -> bool: return self.table.has_child(name) def get_child(self, name: str) -> Symbol: + if name == "*": + return AsteriskSymbol(table=self) if self.has_child(name): return FieldSymbol(name=name, table=self) return self.table.get_child(name) @@ -84,6 +91,8 @@ class SelectQuerySymbol(Symbol): anonymous_tables: List["SelectQuerySymbol"] = PydanticField(default_factory=list) def get_child(self, name: str) -> Symbol: + if name == "*": + return AsteriskSymbol(table=self) if name in self.columns: return FieldSymbol(name=name, table=self) raise ValueError(f"Column not found: {name}") @@ -97,9 +106,11 @@ class SelectQueryAliasSymbol(Symbol): symbol: SelectQuerySymbol def get_child(self, name: str) -> Symbol: + if name == "*": + return AsteriskSymbol(table=self) if self.symbol.has_child(name): return FieldSymbol(name=name, table=self) - raise ValueError(f"Field not found: {name}") + raise ValueError(f"Field {name} not found on query with alias {self.name}") def has_child(self, name: str) -> bool: return self.symbol.has_child(name) diff --git a/posthog/hogql/constants.py b/posthog/hogql/constants.py index 92c88a8c65df4..27e180bec6574 100644 --- a/posthog/hogql/constants.py +++ b/posthog/hogql/constants.py @@ -110,9 +110,6 @@ "distinct_id", "elements_chain", "created_at", - "person.id", - "person.created_at", - "person.properties", ] # Never return more rows than this in top level HogQL SELECT statements diff --git a/posthog/hogql/database.py b/posthog/hogql/database.py index 186241ef621cf..1b02557be851f 100644 --- a/posthog/hogql/database.py +++ b/posthog/hogql/database.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict from pydantic import BaseModel, Extra @@ -48,17 +48,17 @@ def get_field(self, name: str) -> DatabaseField: def clickhouse_table(self): raise NotImplementedError("Table.clickhouse_table not overridden") - def get_asterisk(self) -> List[str]: - list: List[str] = [] - for field in self.__fields__.values(): + def get_asterisk(self) -> Dict[str, DatabaseField]: + asterisk: Dict[str, DatabaseField] = {} + for key, field in self.__fields__.items(): database_field = field.default if isinstance(database_field, DatabaseField): - list.append(database_field.name) + asterisk[key] = database_field elif isinstance(database_field, Table): - list.extend(database_field.get_asterisk()) + pass # ignore virtual tables for now else: raise ValueError(f"Unknown field type {type(database_field).__name__} for asterisk") - return list + return asterisk class PersonsTable(Table): diff --git a/posthog/hogql/parser.py b/posthog/hogql/parser.py index 5b7f4b9eff6fc..b7be7d161938a 100644 --- a/posthog/hogql/parser.py +++ b/posthog/hogql/parser.py @@ -296,6 +296,9 @@ def visitColumnExprList(self, ctx: HogQLParser.ColumnExprListContext): return [self.visit(c) for c in ctx.columnsExpr()] def visitColumnsExprAsterisk(self, ctx: HogQLParser.ColumnsExprAsteriskContext): + if ctx.tableIdentifier(): + table = self.visit(ctx.tableIdentifier()) + return ast.Field(chain=table + ["*"]) return ast.Field(chain=["*"]) def visitColumnsExprSubquery(self, ctx: HogQLParser.ColumnsExprSubqueryContext): @@ -500,6 +503,9 @@ def visitColumnExprFunction(self, ctx: HogQLParser.ColumnExprFunctionContext): return ast.Call(name=name, args=args) def visitColumnExprAsterisk(self, ctx: HogQLParser.ColumnExprAsteriskContext): + if ctx.tableIdentifier(): + table = self.visit(ctx.tableIdentifier()) + return ast.Field(chain=table + ["*"]) return ast.Field(chain=["*"]) def visitColumnArgList(self, ctx: HogQLParser.ColumnArgListContext): diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index 20930c183a06b..287e9c1a9d6ff 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -8,6 +8,7 @@ from posthog.hogql.database import Table, database from posthog.hogql.print_string import print_clickhouse_identifier, print_hogql_identifier from posthog.hogql.resolver import ResolverException, lookup_field_by_name, resolve_symbols +from posthog.hogql.transforms import expand_asterisks from posthog.hogql.visitor import Visitor from posthog.models.property import PropertyName, TableColumn @@ -40,6 +41,8 @@ def print_ast( # modify the cloned tree as needed if dialect == "clickhouse": + expand_asterisks(node) + # TODO: add team_id checks (currently done in the printer) # TODO: add joins to person and group tables pass @@ -472,16 +475,7 @@ def visit_field_alias_symbol(self, symbol: ast.SelectQueryAliasSymbol): return self._print_identifier(symbol.name) def visit_asterisk_symbol(self, symbol: ast.AsteriskSymbol): - table = symbol.table - while isinstance(table, ast.TableAliasSymbol): - table = table.table - if not isinstance(table, ast.TableSymbol): - raise ValueError(f"Unknown AsteriskSymbol table type: {type(table).__name__}") - asterisk_fields = table.table.get_asterisk() - prefix = ( - f"{self._print_identifier(symbol.table.name)}." if isinstance(symbol.table, ast.TableAliasSymbol) else "" - ) - return f"tuple({', '.join(f'{prefix}{self._print_identifier(field)}' for field in asterisk_fields)})" + raise ValueError("Unexpected ast.AsteriskSymbol. Make sure AsteriskExpander has run on the AST.") def visit_unknown(self, node: ast.AST): raise ValueError(f"Unknown AST node {type(node).__name__}") diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 341235000607e..bc2910fa20392 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -165,7 +165,7 @@ def visit_field(self, node): if table_count == 0: raise ResolverException("Cannot use '*' when there are no tables in the query") if table_count > 1: - raise ResolverException("Cannot use '*' when there are multiple tables in the query") + raise ResolverException("Cannot use '*' without table name when there are multiple tables in the query") table = scope.anonymous_tables[0] if len(scope.anonymous_tables) > 0 else list(scope.tables.values())[0] symbol = ast.AsteriskSymbol(table=table) diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index a076006d70a77..ba61f0304e47f 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -329,17 +329,6 @@ def test_comments(self): context = HogQLContext() self.assertEqual(self._expr("event -- something", context), "event") - def test_special_root_properties(self): - self.assertEqual( - self._expr("*"), - "tuple(uuid, event, properties, timestamp, team_id, distinct_id, elements_chain, created_at, person_id, person_created_at, person_properties)", - ) - context = HogQLContext() - self.assertEqual( - self._expr("person", context), - "tuple(person_id, person_created_at, person_properties)", - ) - def test_values(self): context = HogQLContext() self.assertEqual(self._expr("event == 'E'", context), "equals(event, %(hogql_val_0)s)") diff --git a/posthog/hogql/test/test_transforms.py b/posthog/hogql/test/test_transforms.py new file mode 100644 index 0000000000000..b1f7864d43f4f --- /dev/null +++ b/posthog/hogql/test/test_transforms.py @@ -0,0 +1,158 @@ +from posthog.hogql import ast +from posthog.hogql.database import database +from posthog.hogql.parser import parse_select +from posthog.hogql.resolver import ResolverException, resolve_symbols +from posthog.hogql.transforms import expand_asterisks +from posthog.test.base import BaseTest + + +class TestTransforms(BaseTest): + def test_asterisk_expander_table(self): + node = parse_select("select * from events") + resolve_symbols(node) + expand_asterisks(node) + events_table_symbol = ast.TableSymbol(table=database.events) + self.assertEqual( + node.select, + [ + ast.Field(chain=["uuid"], symbol=ast.FieldSymbol(name="uuid", table=events_table_symbol)), + ast.Field(chain=["event"], symbol=ast.FieldSymbol(name="event", table=events_table_symbol)), + ast.Field(chain=["properties"], symbol=ast.FieldSymbol(name="properties", table=events_table_symbol)), + ast.Field(chain=["timestamp"], symbol=ast.FieldSymbol(name="timestamp", table=events_table_symbol)), + ast.Field(chain=["team_id"], symbol=ast.FieldSymbol(name="team_id", table=events_table_symbol)), + ast.Field(chain=["distinct_id"], symbol=ast.FieldSymbol(name="distinct_id", table=events_table_symbol)), + ast.Field( + chain=["elements_chain"], symbol=ast.FieldSymbol(name="elements_chain", table=events_table_symbol) + ), + ast.Field(chain=["created_at"], symbol=ast.FieldSymbol(name="created_at", table=events_table_symbol)), + ], + ) + + def test_asterisk_expander_table_alias(self): + node = parse_select("select * from events e") + resolve_symbols(node) + expand_asterisks(node) + events_table_symbol = ast.TableSymbol(table=database.events) + events_table_alias_symbol = ast.TableAliasSymbol(table=events_table_symbol, name="e") + self.assertEqual( + node.select, + [ + ast.Field(chain=["uuid"], symbol=ast.FieldSymbol(name="uuid", table=events_table_alias_symbol)), + ast.Field(chain=["event"], symbol=ast.FieldSymbol(name="event", table=events_table_alias_symbol)), + ast.Field( + chain=["properties"], symbol=ast.FieldSymbol(name="properties", table=events_table_alias_symbol) + ), + ast.Field( + chain=["timestamp"], symbol=ast.FieldSymbol(name="timestamp", table=events_table_alias_symbol) + ), + ast.Field(chain=["team_id"], symbol=ast.FieldSymbol(name="team_id", table=events_table_alias_symbol)), + ast.Field( + chain=["distinct_id"], symbol=ast.FieldSymbol(name="distinct_id", table=events_table_alias_symbol) + ), + ast.Field( + chain=["elements_chain"], + symbol=ast.FieldSymbol(name="elements_chain", table=events_table_alias_symbol), + ), + ast.Field( + chain=["created_at"], symbol=ast.FieldSymbol(name="created_at", table=events_table_alias_symbol) + ), + ], + ) + + def test_asterisk_expander_subquery(self): + node = parse_select("select * from (select 1 as a, 2 as b)") + resolve_symbols(node) + expand_asterisks(node) + select_subquery_symbol = ast.SelectQuerySymbol( + aliases={ + "a": ast.FieldAliasSymbol(name="a", symbol=ast.ConstantSymbol(value=1)), + "b": ast.FieldAliasSymbol(name="b", symbol=ast.ConstantSymbol(value=2)), + }, + columns={ + "a": ast.FieldAliasSymbol(name="a", symbol=ast.ConstantSymbol(value=1)), + "b": ast.FieldAliasSymbol(name="b", symbol=ast.ConstantSymbol(value=2)), + }, + tables={}, + anonymous_tables=[], + ) + self.assertEqual( + node.select, + [ + ast.Field(chain=["a"], symbol=ast.FieldSymbol(name="a", table=select_subquery_symbol)), + ast.Field(chain=["b"], symbol=ast.FieldSymbol(name="b", table=select_subquery_symbol)), + ], + ) + + def test_asterisk_expander_subquery_alias(self): + node = parse_select("select x.* from (select 1 as a, 2 as b) x") + resolve_symbols(node) + expand_asterisks(node) + select_subquery_symbol = ast.SelectQueryAliasSymbol( + name="x", + symbol=ast.SelectQuerySymbol( + aliases={ + "a": ast.FieldAliasSymbol(name="a", symbol=ast.ConstantSymbol(value=1)), + "b": ast.FieldAliasSymbol(name="b", symbol=ast.ConstantSymbol(value=2)), + }, + columns={ + "a": ast.FieldAliasSymbol(name="a", symbol=ast.ConstantSymbol(value=1)), + "b": ast.FieldAliasSymbol(name="b", symbol=ast.ConstantSymbol(value=2)), + }, + tables={}, + anonymous_tables=[], + ), + ) + self.assertEqual( + node.select, + [ + ast.Field(chain=["a"], symbol=ast.FieldSymbol(name="a", table=select_subquery_symbol)), + ast.Field(chain=["b"], symbol=ast.FieldSymbol(name="b", table=select_subquery_symbol)), + ], + ) + + def test_asterisk_expander_from_subquery_table(self): + node = parse_select("select * from (select * from events)") + resolve_symbols(node) + expand_asterisks(node) + + events_table_symbol = ast.TableSymbol(table=database.events) + inner_select_symbol = ast.SelectQuerySymbol( + tables={"events": events_table_symbol}, + anonymous_tables=[], + aliases={}, + columns={ + "uuid": ast.FieldSymbol(name="uuid", table=events_table_symbol), + "event": ast.FieldSymbol(name="event", table=events_table_symbol), + "properties": ast.FieldSymbol(name="properties", table=events_table_symbol), + "timestamp": ast.FieldSymbol(name="timestamp", table=events_table_symbol), + "team_id": ast.FieldSymbol(name="team_id", table=events_table_symbol), + "distinct_id": ast.FieldSymbol(name="distinct_id", table=events_table_symbol), + "elements_chain": ast.FieldSymbol(name="elements_chain", table=events_table_symbol), + "created_at": ast.FieldSymbol(name="created_at", table=events_table_symbol), + }, + ) + + self.assertEqual( + node.select, + [ + ast.Field(chain=["uuid"], symbol=ast.FieldSymbol(name="uuid", table=inner_select_symbol)), + ast.Field(chain=["event"], symbol=ast.FieldSymbol(name="event", table=inner_select_symbol)), + ast.Field(chain=["properties"], symbol=ast.FieldSymbol(name="properties", table=inner_select_symbol)), + ast.Field(chain=["timestamp"], symbol=ast.FieldSymbol(name="timestamp", table=inner_select_symbol)), + ast.Field(chain=["team_id"], symbol=ast.FieldSymbol(name="team_id", table=inner_select_symbol)), + ast.Field(chain=["distinct_id"], symbol=ast.FieldSymbol(name="distinct_id", table=inner_select_symbol)), + ast.Field( + chain=["elements_chain"], + symbol=ast.FieldSymbol(name="elements_chain", table=inner_select_symbol), + ), + ast.Field(chain=["created_at"], symbol=ast.FieldSymbol(name="created_at", table=inner_select_symbol)), + ], + ) + + def test_asterisk_expander_multiple_table_error(self): + node = parse_select("select * from (select 1 as a, 2 as b) x left join (select 1 as a, 2 as b) y on x.a = y.a") + with self.assertRaises(ResolverException) as e: + resolve_symbols(node) + self.assertEqual( + str(e.exception), "Cannot use '*' without table name when there are multiple tables in the query" + ) diff --git a/posthog/hogql/transforms.py b/posthog/hogql/transforms.py new file mode 100644 index 0000000000000..b3d8c91ee5cf0 --- /dev/null +++ b/posthog/hogql/transforms.py @@ -0,0 +1,49 @@ +from typing import List + +from posthog.hogql import ast +from posthog.hogql.visitor import TraversingVisitor + + +def expand_asterisks(node: ast.Expr): + AsteriskExpander().visit(node) + + +class AsteriskExpander(TraversingVisitor): + def visit_select_query(self, node: ast.SelectQuery): + super().visit_select_query(node) + + columns: List[ast.Expr] = [] + for column in node.select: + if isinstance(column.symbol, ast.AsteriskSymbol): + asterisk = column.symbol + if isinstance(asterisk.table, ast.TableSymbol) or isinstance(asterisk.table, ast.TableAliasSymbol): + table = asterisk.table + while isinstance(table, ast.TableAliasSymbol): + table = table.table + if isinstance(table, ast.TableSymbol): + database_fields = table.table.get_asterisk() + for key in database_fields.keys(): + symbol = ast.FieldSymbol(name=key, table=asterisk.table) + columns.append(ast.Field(chain=[key], symbol=symbol)) + node.symbol.columns[key] = symbol + else: + raise ValueError("Can't expand asterisk (*) on table") + elif isinstance(asterisk.table, ast.SelectQuerySymbol) or isinstance( + asterisk.table, ast.SelectQueryAliasSymbol + ): + select = asterisk.table + while isinstance(select, ast.SelectQueryAliasSymbol): + select = select.symbol + if isinstance(select, ast.SelectQuerySymbol): + for name in select.columns.keys(): + symbol = ast.FieldSymbol(name=name, table=asterisk.table) + columns.append(ast.Field(chain=[name], symbol=symbol)) + node.symbol.columns[name] = symbol + else: + raise ValueError("Can't expand asterisk (*) on subquery") + else: + raise ValueError(f"Can't expand asterisk (*) on a symbol of type {type(asterisk.table).__name__}") + + else: + columns.append(column) + node.select = columns diff --git a/posthog/models/event/query_event_list.py b/posthog/models/event/query_event_list.py index 878bf00e4ed32..2a4c8f5f5b483 100644 --- a/posthog/models/event/query_event_list.py +++ b/posthog/models/event/query_event_list.py @@ -210,6 +210,8 @@ def run_events_query( for expr in select: hogql_context.found_aggregation = False + if expr == "*": + expr = f'tuple({", ".join(SELECT_STAR_FROM_EVENTS_FIELDS)})' clickhouse_sql = translate_hogql(expr, hogql_context) select_columns.append(clickhouse_sql) if not hogql_context.found_aggregation: @@ -273,13 +275,6 @@ def run_events_query( results[index] = list(result) results[index][star] = convert_star_select_to_dict(result[star]) - # Convert person field from tuple to dict in each result - if "person" in select: - person = select.index("person") - for index, result in enumerate(results): - results[index] = list(result) - results[index][person] = convert_person_select_to_dict(result[person]) - received_extra_row = len(results) == limit # limit was +=1'd above return EventsQueryResponse( @@ -293,23 +288,6 @@ def run_events_query( def convert_star_select_to_dict(select: Tuple[Any]) -> Dict[str, Any]: new_result = dict(zip(SELECT_STAR_FROM_EVENTS_FIELDS, select)) new_result["properties"] = json.loads(new_result["properties"]) - new_result["person"] = { - "id": new_result["person.id"], - "created_at": new_result["person.created_at"], - "properties": json.loads(new_result["person.properties"]), - } - new_result.pop("person.id") - new_result.pop("person.created_at") - new_result.pop("person.properties") if new_result["elements_chain"]: new_result["elements"] = ElementSerializer(chain_to_elements(new_result["elements_chain"]), many=True).data return new_result - - -def convert_person_select_to_dict(select: Tuple[str, str, str, str, str]) -> Dict[str, Any]: - return { - "id": select[1], - "created_at": select[2], - "properties": {"name": select[3], "email": select[4]}, - "distinct_ids": [select[0]], - }