Skip to content

Commit

Permalink
fix: use sqlalchemy_url property in get_uri for postgresql provid…
Browse files Browse the repository at this point in the history
…er (apache#38831)

* update get_uri

* update get_uri

Signed-off-by: kalyanr <[email protected]>

* update docstring

Signed-off-by: kalyanr <[email protected]>

* add and use sa_uri property

* update database in sa_uri

* update tests

* remove client_encoding from test_get_uri

* use sqlalchemy_url property

* add default port

* update tests

* update usage of ports

* revert client_encoding updates

---------

Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar authored and RodrigoGanancia committed May 10, 2024
1 parent 84f63b4 commit 40ba7c3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import psycopg2.extras
from deprecated import deprecated
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook
Expand Down Expand Up @@ -113,6 +114,18 @@ def schema(self):
def schema(self, value):
self.database = value

@property
def sqlalchemy_url(self) -> URL:
conn = self.get_connection(getattr(self, self.conn_name_attr))
return URL.create(
drivername="postgresql",
username=conn.login,
password=conn.password,
host=conn.host,
port=conn.port,
database=self.database or conn.schema,
)

def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
cursor_types = {
Expand Down Expand Up @@ -186,12 +199,9 @@ def copy_expert(self, sql: str, filename: str) -> None:
def get_uri(self) -> str:
"""Extract the URI from the connection.
:return: the extracted uri.
:return: the extracted URI in Sqlalchemy URI format.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn.schema = self.database or conn.schema
uri = conn.get_uri().replace("postgres://", "postgresql://")
return uri
return self.sqlalchemy_url.render_as_string(hide_password=False)

def bulk_load(self, table: str, tmp_file: str) -> None:
"""Load a tab-delimited file into a database table."""
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def test_get_conn(self, mock_connect):

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_uri(self, mock_connect):
self.connection.extra = json.dumps({"client_encoding": "utf-8"})
self.connection.conn_type = "postgres"
self.connection.port = 5432
self.db_hook.get_conn()
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8"
assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database"

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_cursor(self, mock_connect):
Expand Down

0 comments on commit 40ba7c3

Please sign in to comment.