Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
robbie-c committed Mar 21, 2024
1 parent 99132ca commit d6d9760
Showing 1 changed file with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from posthog.hogql import ast
from posthog.hogql.ast import CompareOperationOp, ArithmeticOperationOp
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import DatabaseField
from posthog.hogql.visitor import clone_expr, CloningVisitor, Visitor

SESSION_BUFFER_DAYS = 3
Expand Down Expand Up @@ -287,10 +288,10 @@ def visit_constant(self, node: ast.Constant) -> bool:
def visit_field(self, node: ast.Field) -> bool:
if node.type and isinstance(node.type, ast.FieldType):
resolved_field = node.type.resolve_database_field(self.context)
return resolved_field.name in ["min_timestamp", "timestamp"]
else:
# no type information, so just use the name of the field
return node.chain[-1] in ["min_timestamp", "timestamp"]
if resolved_field and isinstance(resolved_field, DatabaseField):
return resolved_field.name in ["min_timestamp", "timestamp"]
# no type information, so just use the name of the field
return node.chain[-1] in ["min_timestamp", "timestamp"]

def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> bool:
# only allow the min_timestamp field to be used on one side of the arithmetic operation
Expand Down Expand Up @@ -364,12 +365,15 @@ def __init__(self, context: HogQLContext, *args, **kwargs):
def visit_field(self, node: ast.Field) -> ast.Field:
if node.type and isinstance(node.type, ast.FieldType):
resolved_field = node.type.resolve_database_field(self.context)
if resolved_field and resolved_field.name in ["min_timestamp", "timestamp"]:
return ast.Field(chain=["raw_sessions", "min_timestamp"])
else:
# no type information, so just use the name of the field
if node.chain[-1] in ["min_timestamp", "timestamp"]:
if (
resolved_field
and isinstance(resolved_field, DatabaseField)
and resolved_field.name in ["min_timestamp", "timestamp"]
):
return ast.Field(chain=["raw_sessions", "min_timestamp"])
# no type information, so just use the name of the field
if node.chain[-1] in ["min_timestamp", "timestamp"]:
return ast.Field(chain=["raw_sessions", "min_timestamp"])
return node

def visit_alias(self, node: ast.Alias) -> ast.Expr:
Expand Down

0 comments on commit d6d9760

Please sign in to comment.