Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add csv badges back in Quickstart #418

Merged
merged 13 commits into from
Dec 14, 2020
3 changes: 2 additions & 1 deletion databuilder/extractor/csv_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _load_csv(self) -> None:
name=column_dict['name'],
description=column_dict['description'],
col_type=column_dict['col_type'],
sort_order=int(column_dict['sort_order'])
sort_order=int(column_dict['sort_order']),
badges=[column_dict['badges']]
)
parsed_columns[id].append(column)

Expand Down
10 changes: 7 additions & 3 deletions databuilder/models/badge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional
from typing import Any, List, Optional
import re

from databuilder.models.graph_serializable import GraphSerializable
Expand All @@ -18,6 +18,10 @@ def __repr__(self) -> str:
return 'Badge({!r}, {!r})'.format(self.name,
self.category)

def __eq__(self, other: Any) -> bool: # type: ignore[override]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other: Badge

Copy link
Contributor Author

@jornh jornh Dec 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah of course! Thanks - currently AFK but will change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, apart from that it then throws an error

    def __eq__(self, other: Badge) -> bool:
E   NameError: name 'Badge' is not defined

Apparently the Class itself isn't know at that point - which I find a little strange. So I'm keeping Any unless you know a trick.

I got rid if the # type: ignore[override] though 🙂


Oh wait, through this i found out it needs to be other: 'Badge' and then mypy itself gives the full recommended way:

mypy .
databuilder/models/badge.py:21: error: Argument 1 of "__eq__" is incompatible with supertype "object"; supertype defines the argument type as "object"
databuilder/models/badge.py:21: note: This violates the Liskov substitution principle
databuilder/models/badge.py:21: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
databuilder/models/badge.py:21: note: It is recommended for "__eq__" to work with arbitrary objects, for example:
databuilder/models/badge.py:21: note:     def __eq__(self, other: object) -> bool:
databuilder/models/badge.py:21: note:         if not isinstance(other, Badge):
databuilder/models/badge.py:21: note:             return NotImplemented
databuilder/models/badge.py:21: note:         return <logic to compare two Badge instances>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow that's so useful

return self.name == other.name and \
self.category == other.category


class BadgeMetadata(GraphSerializable):
"""
Expand Down Expand Up @@ -86,7 +90,7 @@ def get_metadata_model_key(self) -> str:

def create_nodes(self) -> List[GraphNode]:
"""
Create a list of Neo4j node records
Create a list of `GraphNode` records
:return:
"""
results = []
Expand All @@ -103,7 +107,7 @@ def create_nodes(self) -> List[GraphNode]:
return results

def create_relation(self) -> List[GraphRelationship]:
results = []
results: List[GraphRelationship] = []
for badge in self.badges:
relation = GraphRelationship(
start_label=self.start_label,
Expand Down
42 changes: 22 additions & 20 deletions databuilder/models/table_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL


def _format_as_list(tags: Union[List, str, None]) -> List:
if tags is None:
tags = []
if isinstance(tags, str):
tags = list(filter(None, tags.split(',')))
if isinstance(tags, list):
tags = [tag.lower().strip() for tag in tags]
return tags


class TagMetadata(GraphSerializable):
TAG_NODE_LABEL = 'Tag'
TAG_KEY_FORMAT = '{tag}'
Expand Down Expand Up @@ -157,24 +167,23 @@ def __init__(self,
description: Union[str, None],
col_type: str,
sort_order: int,
badges: Union[List[str], None] = None
badges: Union[List[str], None] = None,
) -> None:
"""
TODO: Add stats
:param name:
:param description:
:param col_type:
:param sort_order:
:param badges: Optional. Column level badges
"""
self.name = name
self.description = DescriptionMetadata.create_description_metadata(source=None,
text=description)
self.type = col_type
self.sort_order = sort_order
if badges:
self.badges = [Badge(badge, 'column') for badge in badges]
else:
self.badges = []
formatted_badges = _format_as_list(badges)
self.badges = [Badge(badge, 'column') for badge in formatted_badges]

def __repr__(self) -> str:
return 'ColumnMetadata({!r}, {!r}, {!r}, {!r}, {!r})'.format(self.name,
Expand Down Expand Up @@ -260,7 +269,7 @@ def __init__(self,
self.is_view = is_view
self.attrs: Optional[Dict[str, Any]] = None

self.tags = TableMetadata.format_tags(tags)
self.tags = _format_as_list(tags)

if kwargs:
self.attrs = copy.deepcopy(kwargs)
Expand Down Expand Up @@ -324,14 +333,7 @@ def _get_col_description_key(self,

@staticmethod
def format_tags(tags: Union[List, str, None]) -> List:
if tags is None:
tags = []
if isinstance(tags, str):
tags = list(filter(None, tags.split(',')))
if isinstance(tags, list):
tags = [tag.lower().strip() for tag in tags]

return tags
return _format_as_list(tags)

def create_next_node(self) -> Union[GraphNode, None]:
try:
Expand All @@ -346,7 +348,7 @@ def _create_next_node(self) -> Iterator[GraphNode]:
node_key = self._get_table_description_key(self.description)
yield self.description.get_node(node_key)

# Create the table tag node
# Create the table tag nodes
if self.tags:
for tag in self.tags:
yield TagMetadata.create_tag_node(tag)
Expand All @@ -368,11 +370,11 @@ def _create_next_node(self) -> Iterator[GraphNode]:
yield col.description.get_node(node_key)

if col.badges:
badge_metadata = BadgeMetadata(start_label=ColumnMetadata.COLUMN_NODE_LABEL,
start_key=self._get_col_key(col),
badges=col.badges)
badge_nodes = badge_metadata.create_nodes()
for node in badge_nodes:
col_badge_metadata = BadgeMetadata(
start_label=ColumnMetadata.COLUMN_NODE_LABEL,
start_key=self._get_col_key(col),
badges=col.badges)
for node in col_badge_metadata.create_nodes():
yield node

# Database, cluster, schema
Expand Down
24 changes: 12 additions & 12 deletions example/sample_data/sample_col.csv
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name,description,col_type,sort_order,database,cluster,schema,table_name
col1,"col1 description","string",1,hive,gold,test_schema,test_table1
col2,"col2 description","string",2,hive,gold,test_schema,test_table1
col3,"col3 description","string",3,hive,gold,test_schema,test_table1
col4,"col4 description","string",4,hive,gold,test_schema,test_table1
col5,"col5 description","float",5,hive,gold,test_schema,test_table1
col1,"col1 description","string",1,dynamo,gold,test_schema,test_table2
col2,"col2 description","string",2,dynamo,gold,test_schema,test_table2
col3,"col3 description","string",3,dynamo,gold,test_schema,test_table2
col4,"col4 description","int",4,dynamo,gold,test_schema,test_table2
col1,"view col description","int",1,hive,gold,test_schema,test_view1
col1,"col1 description","int",1,hive,gold,test_schema,test_table3
name,description,col_type,sort_order,database,cluster,schema,table_name,badges
col1,"col1 description","string",1,hive,gold,test_schema,test_table1,PK
col2,"col2 description","string",2,hive,gold,test_schema,test_table1,PII
col3,"col3 description","string",3,hive,gold,test_schema,test_table1,
col4,"col4 description","string",4,hive,gold,test_schema,test_table1,
col5,"col5 description","float",5,hive,gold,test_schema,test_table1,
col1,"col1 description","string",1,dynamo,gold,test_schema,test_table2,
col2,"col2 description","string",2,dynamo,gold,test_schema,test_table2,
col3,"col3 description","string",3,dynamo,gold,test_schema,test_table2,
col4,"col4 description","int",4,dynamo,gold,test_schema,test_table2,
col1,"view col description","int",1,hive,gold,test_schema,test_view1,
col1,"col1 description","int",1,hive,gold,test_schema,test_table3,
4 changes: 2 additions & 2 deletions example/sample_data/sample_table.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
database,cluster,schema,name,description,tags,is_view,description_source
hive,gold,test_schema,test_table1,"1st test table","tag1,tag2,pii,high_quality",false,
dynamo,gold,test_schema,test_table2,"2nd test table","high_quality,recommended",false,
hive,gold,test_schema,test_table1,"1st test table","tag1,tag2",false,
dynamo,gold,test_schema,test_table2,"2nd test table",recommended,false,
hive,gold,test_schema,test_view1,"1st test view","tag1",true,
hive,gold,test_schema,test_table3,"3rd test","needs_documentation",false,
hive,gold,test_schema,"test's_table4","4th test","needs_documentation",false,
44 changes: 37 additions & 7 deletions tests/unit/extractor/test_csv_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,21 @@
from pyhocon import ConfigFactory

from databuilder import Scoped
from databuilder.extractor.csv_extractor import CsvExtractor
from databuilder.extractor.csv_extractor import CsvExtractor, CsvTableColumnExtractor
from databuilder.models.badge import Badge


class TestCsvExtractor(unittest.TestCase):

def setUp(self) -> None:
def test_extraction_with_model_class(self) -> None:
"""
Test Extraction using model class
"""
config_dict = {
'extractor.csv.{}'.format(CsvExtractor.FILE_LOCATION): 'example/sample_data/sample_table.csv',
'extractor.csv.model_class': 'databuilder.models.table_metadata.TableMetadata',
}
self.conf = ConfigFactory.from_dict(config_dict)

def test_extraction_with_model_class(self) -> None:
"""
Test Extraction using model class
"""
extractor = CsvExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))
Expand All @@ -32,3 +31,34 @@ def test_extraction_with_model_class(self) -> None:
self.assertEqual(result.database, 'hive')
self.assertEqual(result.cluster, 'gold')
self.assertEqual(result.schema, 'test_schema')
self.assertEqual(result.tags, ['tag1', 'tag2'])
self.assertEqual(result.is_view, 'false')

result2 = extractor.extract()
self.assertEqual(result2.name, 'test_table2')
self.assertEqual(result2.is_view, 'false')

result3 = extractor.extract()
self.assertEqual(result3.name, 'test_view1')
self.assertEqual(result3.is_view, 'true')

def test_extraction_of_tablecolumn_badges(self) -> None:
"""
Test Extraction using the combined CsvTableModel model class
"""
config_dict = {
f'extractor.csvtablecolumn.{CsvTableColumnExtractor.TABLE_FILE_LOCATION}':
'example/sample_data/sample_table.csv',
f'extractor.csvtablecolumn.{CsvTableColumnExtractor.COLUMN_FILE_LOCATION}':
'example/sample_data/sample_col.csv',
}
self.conf = ConfigFactory.from_dict(config_dict)

extractor = CsvTableColumnExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))

result = extractor.extract()
self.assertEqual(result.name, 'test_table1')
self.assertEqual(result.columns[0].badges, [Badge('pk', 'column')])
self.assertEqual(result.columns[1].badges, [Badge('pii', 'column')])