Enoriega/issue837 (#844)
## Summary of Changes

Added support for:
- Generalized AMR linking.
- Resolving the AMR type automatically.
- Expanding the PDF annotation endpoint to link to an AMR when one is provided as input (see the example below).
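
For example, a client can now upload one or more AMRs alongside the PDFs and receive the linked AMRs in the same response. A minimal sketch with `requests`; the host and route are placeholders standing in for wherever the `integrated_pdf_extractions` endpoint is mounted in your deployment:

```python
import requests

ENDPOINT = "http://localhost:8000"  # placeholder deployment URL

files = [
    ("pdfs", ("paper.pdf", open("paper.pdf", "rb"), "application/pdf")),
    # Optional: each AMR provided here is linked to the text extractions.
    ("amrs", ("model_amr.json", open("model_amr.json", "rb"), "application/json")),
]
# Placeholder route; substitute the actual path of integrated_pdf_extractions.
response = requests.post(f"{ENDPOINT}/text-reading/integrated-pdf-extractions", files=files)
if response.status_code == 200:
    results = response.json()
    linked_amrs = results["aligned_amrs"]  # empty unless AMRs were provided
```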

### Related issues

Resolves #837
enoriega authored Mar 5, 2024
1 parent 156e193 commit db6f7f4
Showing 11 changed files with 1,285 additions and 38 deletions.
770 changes: 770 additions & 0 deletions skema/metal/model_linker/examples/data/gamr.json

Large diffs are not rendered by default.

21 changes: 12 additions & 9 deletions skema/metal/model_linker/skema_model_linker/linkers/amr_linker.py
```diff
@@ -33,19 +33,22 @@ def _generate_linking_sources(self, elements: Iterable[JsonNode]) -> Dict[str, L

     def _align_texts(self, sources: List[str], targets: List[str], threshold: float) -> List[Tuple[str, str]]:

-        with torch.no_grad():
-            s_embs = self._model.encode(sources)
-            t_embs = self._model.encode(targets)
+        if len(sources) > 0 and len(targets) > 0:
+            with torch.no_grad():
+                s_embs = self._model.encode(sources)
+                t_embs = self._model.encode(targets)

-        similarities = util.pytorch_cos_sim(s_embs, t_embs)
+            similarities = util.pytorch_cos_sim(s_embs, t_embs)

-        indices = (similarities >= threshold).nonzero()
+            indices = (similarities >= threshold).nonzero()

-        ret = list()
-        for ix in indices:
-            ret.append((sources[ix[0]], targets[ix[1]]))
+            ret = list()
+            for ix in indices:
+                ret.append((sources[ix[0]], targets[ix[1]]))

-        return ret
+            return ret
+        else:
+            return []

     def _generate_linking_targets(self, extractions: Iterable[Attribute]) -> Dict[str, List[AnchoredEntity]]:
         """ Will generate candidate texts to link to model elements """
```
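
The new guard returns early because encoding an empty list yields an empty embedding matrix, which breaks the cosine-similarity step. A standalone sketch of the patched logic, assuming `sentence-transformers` is installed:

```python
import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def align_texts(sources, targets, threshold=0.5):
    # Mirrors the patched _align_texts: skip encoding when either side is empty.
    if len(sources) > 0 and len(targets) > 0:
        with torch.no_grad():
            s_embs = model.encode(sources, convert_to_tensor=True)
            t_embs = model.encode(targets, convert_to_tensor=True)
        similarities = util.pytorch_cos_sim(s_embs, t_embs)  # |sources| x |targets|
        indices = (similarities >= threshold).nonzero()
        return [(sources[i], targets[j]) for i, j in indices.tolist()]
    return []

# An AMR with no candidate sources no longer crashes the linker:
assert align_texts([], ["beta, the transmission rate"]) == []
```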
30 changes: 30 additions & 0 deletions skema/metal/model_linker/skema_model_linker/linkers/generalizer_amr_linker.py

```python
from collections import defaultdict
from typing import Iterable, Dict, List, Any

from . import heuristics, AMRLinker
from ..walkers import JsonDictWalker, JsonNode
from ..walkers import GeneralizedAMRWalker


class GeneralizedAMRLinker(AMRLinker):

    def _generate_linking_sources(self, elements: Iterable[JsonNode]) -> Dict[str, List[Any]]:
        ret = defaultdict(list)
        for name, val, ix in elements:
            if (name == "states") or (name == "parameters" and 'name' in val):
                key = val['name'].strip()
                lower_case_key = key.lower()

                if "description" in val:
                    ret[f"{key}: {val['description']}"] = val
                else:
                    if lower_case_key in heuristics:
                        descriptions = heuristics[lower_case_key]
                        for desc in descriptions:
                            ret[f"{key}: {desc}"] = val
                    ret[key] = val

        return ret

    def _build_walker(self, amr_data) -> JsonDictWalker:
        return GeneralizedAMRWalker(amr_data)
```
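
Concretely, `_generate_linking_sources` keys each state or parameter by its `name` and, when the AMR carries no `description`, falls back to the `heuristics` mapping (lower-cased variable names to candidate description strings). A toy run; the element tuples mimic what the walker yields, and `__new__` is used only to skip loading the embedding model for the demo:

```python
# Toy demonstration; real elements come from GeneralizedAMRWalker.
linker = GeneralizedAMRLinker.__new__(GeneralizedAMRLinker)  # bypass model loading
elements = [
    ("states", {"name": "S", "description": "Susceptible population"}, 0),
    ("parameters", {"name": "beta"}, 0),  # no description in the AMR
]
sources = linker._generate_linking_sources(elements)
# 'S: Susceptible population' -> the state dict, plus
# 'beta: <heuristic description>' entries (if "beta" is in heuristics) and a bare 'beta' key.
```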
6 changes: 4 additions & 2 deletions skema/metal/model_linker/skema_model_linker/walkers/__init__.py

```diff
@@ -2,11 +2,13 @@
 from .json import JsonNode, JsonDictWalker
 from .petrinet import PetriNetWalker
 from .regnet import RegNetWalker
+from .generalized_amr import GeneralizedAMRWalker

-__all__ =[
+__all__ = [
     "ModelWalker",
     "JsonNode",
     "JsonDictWalker",
     "PetriNetWalker",
     "RegNetWalker",
-]
+    "GeneralizedAMRWalker",
+]
```
9 changes: 9 additions & 0 deletions skema/metal/model_linker/skema_model_linker/walkers/generalized_amr.py

```python
from typing import Optional, Any

from . import JsonDictWalker


class GeneralizedAMRWalker(JsonDictWalker):

    def _filter(self, obj_name: Optional[str], obj: Any, index: Optional[int]) -> bool:
        return obj_name in {"states", "parameters"}
```
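
The filter admits only objects reached under a `states` or `parameters` key; everything else in the AMR is traversed but not collected. A toy check, assuming the base walker takes the AMR dict as its single constructor argument (as `_build_walker` above suggests):

```python
walker = GeneralizedAMRWalker({})  # toy instance; real input is a full AMR dict
assert walker._filter("states", {"name": "S"}, 0)
assert walker._filter("parameters", {"name": "beta"}, 1)
assert not walker._filter("transitions", {"id": "t1"}, 0)
```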
14 changes: 7 additions & 7 deletions skema/metal/model_linker/skema_model_linker/walkers/json.py
```diff
@@ -33,15 +33,15 @@ def __step(self, obj_name: Optional[str], obj: Any, index: Optional[int] = None,
         if allowed and callback:
             callback(obj_name, obj, index)

-        for prop, val in obj.items():
-            if isinstance(val, list):
-                for ix, elem in enumerate(val):
-                    if type(elem) in (list, dict):
-                        ret.extend(self.__step(prop, elem, ix, callback, **kwargs))
-            elif isinstance(val, dict):
+        if isinstance(obj, list):
+            for ix, elem in enumerate(obj):
+                if type(elem) in (list, dict):
+                    ret.extend(self.__step(obj_name, elem, ix, callback, **kwargs))
+        elif isinstance(obj, dict):
+            for prop, val in obj.items():
                 ret.extend(self.__step(prop, val, None, callback, **kwargs))

-        if allowed:
+        if allowed and not isinstance(obj, list):
             ret.append(JsonNode(obj_name, obj, index))

         return ret
```
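
The fix dispatches on the current object itself instead of always calling `obj.items()` (which assumed every node was a dict), so lists are now walked correctly and list containers are no longer emitted as `JsonNode`s. A simplified standalone sketch of the new traversal, with filtering and callbacks omitted:

```python
from typing import Any, Optional

def step(obj_name: Optional[str], obj: Any, index: Optional[int] = None):
    ret = []
    if isinstance(obj, list):
        for ix, elem in enumerate(obj):
            if type(elem) in (list, dict):
                ret.extend(step(obj_name, elem, ix))
    elif isinstance(obj, dict):
        for prop, val in obj.items():
            ret.extend(step(prop, val, None))
        ret.append((obj_name, obj, index))  # dicts are emitted; list containers are not
    return ret

# Each dict under "states" is visited with its parent key and list index:
nodes = step(None, {"states": [{"name": "S"}, {"name": "I"}]})
# -> [('states', {'name': 'S'}, 0), ('states', {'name': 'I'}, 1),
#     (None, {'states': [...]}, None)]
```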
29 changes: 27 additions & 2 deletions skema/rest/integrated_text_reading_proxy.py
```diff
@@ -15,6 +15,7 @@
 from askem_extractions.data_model import AttributeCollection
 from askem_extractions.importers import import_arizona
 from fastapi import APIRouter, Depends, FastAPI, UploadFile, Response, status
+from langchain.tools.e2b_data_analysis.tool import UploadedFile

 from skema.rest.proxies import SKEMA_TR_ADDRESS, MIT_TR_ADDRESS, OPENAI_KEY, COSMOS_ADDRESS
 from skema.rest.schema import (
@@ -23,7 +24,7 @@
     TextReadingDocumentResults,
     TextReadingError, MiraGroundingInputs, MiraGroundingOutputItem, TextReadingEvaluationResults,
 )
-from skema.rest import utils
+from skema.rest import utils, metal_proxy

 router = APIRouter()
@@ -448,6 +449,7 @@ async def integrated_text_extractions(
 async def integrated_pdf_extractions(
     response: Response,
     pdfs: List[UploadFile],
+    amrs: List[UploadFile] = [],
     annotate_skema: bool = True,
     annotate_mit: bool = True
 ) -> TextReadingAnnotationsOutput:
@@ -481,7 +483,7 @@ async def integrated_pdf_extractions(
     plain_texts = ['\n'.join(block['content'] for block in c) for c in cosmos_data]

     # Run the text extractors
-    return integrated_extractions(
+    extractions = integrated_extractions(
         response,
         annotate_pdfs_with_skema,
         cosmos_data,
@@ -490,6 +492,29 @@ async def integrated_pdf_extractions(
         annotate_mit
     )

+    # Do the alignment
+    aligned_amrs = list()
+    if len(amrs) > 0:
+        # Build an UploadFile instance from the extractions
+        json_extractions = extractions.model_dump_json()
+        extractions_ufile = UploadFile(file=io.BytesIO(json_extractions.encode('utf-8')))
+        for amr in amrs:
+            try:
+                aligned_amr = metal_proxy.link_amr(
+                    amr_file=amr,
+                    text_extractions_file=extractions_ufile)
+                aligned_amrs.append(aligned_amr)
+            except Exception as e:
+                error = TextReadingError(pipeline="AMR Linker", message=f"Error annotating {amr.filename}: {e}")
+                if extractions.generalized_errors is None:
+                    extractions.generalized_errors = [error]
+                else:
+                    extractions.generalized_errors.append(error)
+
+    extractions.aligned_amrs = aligned_amrs
+
+    return extractions
+

 # These are some direct proxies to the SKEMA and MIT APIs
 @router.post(
```
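
Note how the endpoint reuses its own output in-process: the extraction result is serialized and wrapped in an `UploadFile` so `metal_proxy.link_amr` can consume it exactly as if a client had uploaded the file, with no HTTP round-trip. A minimal sketch of that pattern; the payload string is a stand-in for `extractions.model_dump_json()`:

```python
import io
from fastapi import UploadFile

json_extractions = '{"outputs": [], "generalized_errors": null}'  # stand-in payload
extractions_ufile = UploadFile(file=io.BytesIO(json_extractions.encode('utf-8')))

# The wrapped file reads back the serialized extractions, just like a real upload.
assert extractions_ufile.file.read().decode('utf-8') == json_extractions
```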
27 changes: 15 additions & 12 deletions skema/rest/metal_proxy.py
````diff
@@ -1,14 +1,13 @@
-import json
 import itertools as it
+import json

 from askem_extractions.data_model import AttributeCollection
 from fastapi import UploadFile, File, APIRouter, FastAPI
 from pydantic import Json

-
-from skema.metal.model_linker.skema_model_linker.linkers import PetriNetLinker, RegNetLinker
 from skema.metal.model_linker.skema_model_linker.link_amr import replace_xml_codepoints
-from skema.rest.schema import TextReadingAnnotationsOutput, TextReadingEvaluationResults, AMRLinkingEvaluationResults
+from skema.metal.model_linker.skema_model_linker.linkers import PetriNetLinker, RegNetLinker
+from skema.metal.model_linker.skema_model_linker.linkers.generalizer_amr_linker import GeneralizedAMRLinker
+from skema.rest.schema import AMRLinkingEvaluationResults
 from skema.rest.utils import compute_amr_linking_evaluation

 router = APIRouter()
@@ -17,25 +16,21 @@
 @router.post(
     "/link_amr",
 )
-def link_amr(amr_type: str,
-             similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+def link_amr(similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",
              similarity_threshold: float = 0.5,
              amr_file: UploadFile = File(...),
             text_extractions_file: UploadFile = File(...)):
     """ Links an AMR to a text extractions file
     ### Python example
     ```
-    params = {
-        "amr_type": "petrinet"
-    }
     files = {
         "amr_file": ("amr.json", open("amr.json"), "application/json"),
         "text_extractions_file": ("extractions.json", open("extractions.json"), "application/json")
     }
-    response = requests.post(f"{ENDPOINT}/metal/link_amr", params=params, files=files)
+    response = requests.post(f"{ENDPOINT}/metal/link_amr", files=files)
     if response.status_code == 200:
         enriched_amr = response.json()
     ```
@@ -59,13 +54,21 @@ def link_amr(amr_type: str,
     extractions = AttributeCollection.from_json(raw_extractions)
     # text_extractions = TextReadingAnnotationsOutput(**json.load(text_extractions_file.file))

-
+    # Get the AMR type from the header of the json
+    if 'schema_name' in amr:
+        amr_type = amr['schema_name'].lower()
+    elif 'header' in amr and 'schema_name' in amr['header']:
+        amr_type = amr['header']['schema_name'].lower()
+    else:
+        raise Exception("Schema name missing in AMR")
+
     # Link the AMR
     if amr_type == "petrinet":
         Linker = PetriNetLinker
     elif amr_type == "regnet":
         Linker = RegNetLinker
+    elif amr_type == "generalized amr":
+        Linker = GeneralizedAMRLinker
     else:
         raise NotImplementedError(f"{amr_type} AMR currently not supported")
````
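
With this change the client no longer passes `amr_type`; the endpoint derives it from the AMR itself, accepting `schema_name` either at the top level or under `header`. A toy restatement of that detection:

```python
def detect_amr_type(amr: dict) -> str:
    # Mirrors the detection logic added above.
    if 'schema_name' in amr:
        return amr['schema_name'].lower()
    elif 'header' in amr and 'schema_name' in amr['header']:
        return amr['header']['schema_name'].lower()
    raise Exception("Schema name missing in AMR")

assert detect_amr_type({"header": {"schema_name": "PetriNet"}}) == "petrinet"
assert detect_amr_type({"schema_name": "Generalized AMR"}) == "generalized amr"
```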
8 changes: 7 additions & 1 deletion skema/rest/schema.py
```diff
@@ -2,7 +2,7 @@
 """
 Response models for API
 """
-from typing import List, Optional
+from typing import List, Optional, Dict, Any

 from askem_extractions.data_model import AttributeCollection
 from pydantic import BaseModel, Field
@@ -256,3 +256,9 @@ class TextReadingAnnotationsOutput(BaseModel):
         description="Any pipeline-wide errors, not specific to a particular input",
         examples=[[TextReadingError(pipeline="MIT", message="API quota exceeded")]],
     )
+
+    aligned_amrs: List[Dict[str, Any]] = Field(
+        description="An aligned list of AMRs to the text extractions. This field will be populated only if it was"
+                    " provided as part of the input",
+        default_factory=lambda: []
+    )
```
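
A small sketch of reading the new field from a parsed response; the `header`/`name` access assumes the usual ASKEM AMR layout:

```python
output = TextReadingAnnotationsOutput(**response_json)  # response_json from the endpoint
for amr in output.aligned_amrs:
    # Each entry is an enriched AMR dict as returned by /metal/link_amr.
    print(amr.get("header", {}).get("name", "<unnamed AMR>"))
```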