Skip to content

Commit

Permalink
cleanup: __fetch_adb_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jan 19, 2024
1 parent 70dbc0b commit 0197b4b
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions adbdgl_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def udf_v1_x(v1_df):

# 1. Fetch ArangoDB vertices
v_col_cursor, v_col_size = self.__fetch_adb_docs(
v_col, meta, **adb_export_kwargs
v_col, False, meta, **adb_export_kwargs
)

# 2. Process ArangoDB vertices
Expand All @@ -294,7 +294,7 @@ def udf_v1_x(v1_df):

# 1. Fetch ArangoDB edges
e_col_cursor, e_col_size = self.__fetch_adb_docs(
e_col, meta, **adb_export_kwargs
e_col, True, meta, **adb_export_kwargs
)

# 2. Process ArangoDB edges
Expand Down Expand Up @@ -614,6 +614,7 @@ def y_tensor_to_2_column_dataframe(dgl_tensor):
def __fetch_adb_docs(
self,
col: str,
is_edge: bool,
meta: Union[Set[str], Dict[str, ADBMetagraphValues]],
**adb_export_kwargs: Any,
) -> Tuple[Cursor, int]:
Expand All @@ -622,6 +623,8 @@ def __fetch_adb_docs(
:param col: The ArangoDB collection.
:type col: str
:param is_edge: True if **col** is an edge collection.
:type is_edge: bool
:param meta: The MetaGraph associated to **col**
:type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues]
:param adb_export_kwargs: Keyword arguments to specify AQL query options
Expand All @@ -631,42 +634,36 @@ def __fetch_adb_docs(
:rtype: pandas.DataFrame
"""

def get_aql_return_value(
meta: Union[Set[str], Dict[str, ADBMetagraphValues]]
) -> str:
def get_aql_return_value() -> str:
"""Helper method to formulate the AQL `RETURN` value based on
the document attributes specified in **meta**
"""
attributes = []
attributes = ["_key"]
attributes += ["_from", "_to"] if is_edge else []

if type(meta) is set:
attributes = list(meta)
attributes += list(meta)

elif type(meta) is dict:
for value in meta.values():
if type(value) is str:
attributes.append(value)
elif type(value) is dict:
attributes.extend(list(value.keys()))
attributes += list(value.keys())
elif callable(value):
# Cannot determine which attributes to extract if UDFs are used
# Therefore we just return the entire document
return "doc"

return f"""
MERGE(
{{ _key: doc._key, _from: doc._from, _to: doc._to }},
KEEP(doc, {list(attributes)})
)
"""
return f"KEEP(doc, {attributes})"

col_size: int = self.__db.collection(col).count()

with get_export_spinner_progress(f"ADB Export: '{col}' ({col_size})") as p:
p.add_task(col)

cursor: Cursor = self.__db.aql.execute(
f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}",
f"FOR doc IN @@col RETURN {get_aql_return_value()}",
bind_vars={"@col": col},
**{**adb_export_kwargs, **{"stream": True}},
)
Expand Down

0 comments on commit 0197b4b

Please sign in to comment.