Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[defog] Postgres table name format. #9

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

JeevansSP
Copy link

@JeevansSP JeevansSP commented Sep 13, 2023

Fixes #8

else:
for table in tables:
if not table or len(table.split("."))!=2:
raise ValueError(f"PostgreSQL table names should be of the following format <schema>.<table> which is violated by '{table}`")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might not be a good idea to raise an exception here if users just pass us their table names without schema (for backward compatibility purposes). Could we default to the current logic of working with table names directly when no schema is provided (ie not raise an exception)?

Copy link
Author

@JeevansSP JeevansSP Sep 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I agree with backward compatibility
The problem arises in these two sql queries

    def generate_postgres_schema(self, tables: list, upload: bool = True) -> str:
        # when upload is True, we send the schema to the defog servers and generate a Google Sheet
        # when its false, we return the schema as a dict
        try:
            import psycopg2
        except ImportError:
            raise ImportError(
                "psycopg2 not installed. Please install it with `pip install psycopg2-binary`."
            )

        conn = psycopg2.connect(**self.db_creds)
        cur = conn.cursor()
        schemas = {}

        if tables == [""]:
            # get all tables
            cur.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"  
            )
            tables = [row[0] for row in cur.fetchall()]
        print("Retrieved the following tables:")
        for t in tables:
            print(f"\t{t}")

        print("Getting schema for each table in your database...")
        # get the schema for each table
        for table_name in tables:
            cur.execute(
                "SELECT CAST(column_name AS TEXT), CAST(data_type AS TEXT) FROM information_schema.columns WHERE table_name::text = %s;",
                (table_name,),
            ) #--> Two tables can have the same name but belong to different schemas hence schema check is absolutely necessary
            rows = cur.fetchall()
            rows = [row for row in rows]
            rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
            schemas[table_name] = rows

        # get foreign key relationships
        print("Getting foreign keys for each table in your database...")
        tables_regclass_str = ", ".join(
            [f"'{table_name}'::regclass" for table_name in tables]
        ) #--> Table name here needs to have schema prefixed or else the next query will not work
        query = f"""SELECT
                conrelid::regclass AS table_from,
                pg_get_constraintdef(oid) AS foreign_key_definition
                FROM pg_constraint
                WHERE contype = 'f'
                AND conrelid::regclass IN ({tables_regclass_str})
                AND confrelid::regclass IN ({tables_regclass_str});
                """
        cur.execute(query)
        foreign_keys = list(cur.fetchall())
        foreign_keys = [fk[0] + " " + fk[1] for fk in foreign_keys]

        # get indexes for each table
        print("Getting indexes for each table in your database...")
        tables_str = ", ".join([f"'{table_name}'" for table_name in tables])
        query = (
            f"""SELECT indexdef FROM pg_indexes WHERE tablename IN ({tables_str});"""
        )
        cur.execute(query)
        indexes = list(cur.fetchall())
        if len(indexes) > 0:
            indexes = [index[0] for index in indexes]
        else:
            indexes = []
            # print("No indexes found.")
        conn.close()

        print(
            "Sending the schema to the defog servers and generating a Google Sheet. This might take up to 2 minutes..."
        )
        if upload:
            # send the schemas dict to the defog servers
            r = requests.post(
                "https://api.defog.ai/get_postgres_schema_gsheets",
                json={
                    "api_key": self.api_key,
                    "schemas": schemas,
                    "foreign_keys": foreign_keys,
                    "indexes": indexes,
                },
            )
            resp = r.json()
            if "sheet_url" in resp:
                gsheet_url = resp["sheet_url"]
                return gsheet_url
            else:
                print(f"We got an error!")
                if "message" in resp:
                    print(f"Error message: {resp['message']}")
                print(
                    f"Please feel free to open a github issue if this a generic library issue, or email [email protected] if you need dedicated customer-specific support."
                )

Copy link
Author

@JeevansSP JeevansSP Sep 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the exception is raised only if the provided tables to generate the schema does not contain scheme prefixed
when no table names are provided then it defaults to using the tables from the public schema

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating this. A number of our customers rely on just passing in the table names (without schema) and we can't raise an exception here.

else:
for table in tables:
if not table or len(table.split("."))!=2:
raise ValueError(f"PostgreSQL table names should be of the following format <schema>.<table> which is violated by '{table}`")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating this. A number of our customers rely on just passing in the table names (without schema) and we can't raise an exception here.

for table_name in tables:
for table in tables:

schema,table_name = table.split(".")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid needing to throw an exception, one suggestion is to first check if there is a ., and if there isn't, to default the schema to public. eg:

for table in tables:
  if '.' in table:
    schema, table_name = table.split(".",1)
  else:
   schema, table_name = "public", table

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[defog] Postgres table names should be passed along with their respective schemas
2 participants