Commit

ProVe_main_process update
dignityc committed Nov 12, 2024
1 parent c45c136 commit b846b89
Showing 4 changed files with 77 additions and 51 deletions.
40 changes: 33 additions & 7 deletions ProVe_main_process.py
@@ -1,11 +1,28 @@
from wikidata_parser import WikidataParser
from refs_html_collection import HTMLFetcher
from refs_html_to_evidences import HTMLSentenceProcessor
from refs_html_to_evidences import EvidenceSelector
from claim_entailment import process_entailment
from refs_html_to_evidences import HTMLSentenceProcessor, EvidenceSelector
from claim_entailment import ClaimEntailmentChecker
from utils.textual_entailment_module import TextualEntailmentModule
from utils.sentence_retrieval_module import SentenceRetrievalModule
from utils.verbalisation_module import VerbModule

if __name__ == "__main__":
qid = 'Q44'
def initialize_models():
"""Initialize all required models once"""
text_entailment = TextualEntailmentModule()
sentence_retrieval = SentenceRetrievalModule(max_len=512)
verb_module = VerbModule()
return text_entailment, sentence_retrieval, verb_module

def process_entity(qid: str, models: tuple) -> tuple:
"""
Process a single entity with pre-loaded models
"""
text_entailment, sentence_retrieval, verb_module = models

# Initialize processors with pre-loaded models
selector = EvidenceSelector(sentence_retrieval=sentence_retrieval,
verb_module=verb_module)
checker = ClaimEntailmentChecker(text_entailment=text_entailment)

# Get URLs and claims
parser = WikidataParser()
@@ -20,9 +37,18 @@
sentences_df = processor.process_html_to_sentences(html_df)

# Process evidence selection
selector = EvidenceSelector()
evidence_df = selector.process_evidence(sentences_df, parser_result)

# Check entailment with metadata
entailment_results = process_entailment(evidence_df, html_df, qid)
entailment_results = checker.process_entailment(evidence_df, html_df, qid)

return html_df, evidence_df, entailment_results

if __name__ == "__main__":
# Initialize models once
models = initialize_models()

# Process entity
qid = 'Q44'
html_df, evidence_df, entailment_results = process_entity(qid, models)
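
The payoff of this refactor is that the heavy models are loaded once by initialize_models() and then reused by every process_entity() call. A minimal sketch of that usage, assuming the script is imported as a module; the batch loop and the extra QIDs (Q64, Q84) are illustrative and not part of this commit:

from ProVe_main_process import initialize_models, process_entity

models = initialize_models()                       # load the three heavy models a single time
for qid in ['Q44', 'Q64', 'Q84']:                  # Q64 and Q84 are hypothetical example QIDs
    html_df, evidence_df, entailment_results = process_entity(qid, models)
    print(qid, len(entailment_results), 'evidence rows checked')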

Empty file added ProVe_main_service.py
Empty file.
81 changes: 40 additions & 41 deletions claim_entailment.py
@@ -8,10 +8,11 @@
from datetime import datetime

class ClaimEntailmentChecker:
def __init__(self, config_path: str = 'config.yaml'):
def __init__(self, config_path: str = 'config.yaml', text_entailment=None):
self.logger = logging.getLogger(__name__)
self.config = self.load_config(config_path)
self.te_module = TextualEntailmentModule()
# Use provided model or create new one
self.te_module = text_entailment or TextualEntailmentModule()

@staticmethod
def load_config(config_path: str) -> Dict:
@@ -168,45 +169,43 @@ def process_evidence(self, sentences_df: pd.DataFrame, parser_result: Dict) -> pd.DataFrame:

return evidence_df

def process_entailment(evidence_df: pd.DataFrame, html_df: pd.DataFrame, qid: str) -> pd.DataFrame:
"""
Main function to process entailment checking
"""
checker = ClaimEntailmentChecker()

# Add URLs from html_df using reference_id
evidence_df = evidence_df.merge(
html_df[['reference_id', 'url']],
on='reference_id',
how='left'
)

# Check entailment and keep original probabilities
entailment_results = checker.check_entailment(evidence_df)
probabilities = entailment_results['evidence_TE_prob'].copy()  # keep the original probability values

# Format results
aggregated_results = checker.format_results(entailment_results)

# Get final verdict
final_verdict = checker.get_final_verdict(aggregated_results)
aggregated_results = pd.concat([aggregated_results, final_verdict], axis=1)

# Keep only necessary columns and drop 'Results'
final_results = aggregated_results[['text_entailment_score', 'similarity_score',
'processed_timestamp', 'result',
'result_sentence', 'reference_id']]

# Add label probabilities using the saved probabilities
final_results['label_probabilities'] = probabilities.apply(
lambda x: {
'SUPPORTS': float(x[0][0]),
'REFUTES': float(x[0][1]),
'NOT ENOUGH INFO': float(x[0][2])
}
)

return final_results
def process_entailment(self, evidence_df: pd.DataFrame, html_df: pd.DataFrame, qid: str) -> pd.DataFrame:
"""
Main function to process entailment checking
"""
# Add URLs from html_df using reference_id
evidence_df = evidence_df.merge(
html_df[['reference_id', 'url']],
on='reference_id',
how='left'
)

# Check entailment and keep original probabilities
entailment_results = self.check_entailment(evidence_df)
probabilities = entailment_results['evidence_TE_prob'].copy()

# Format results
aggregated_results = self.format_results(entailment_results)

# Get final verdict
final_verdict = self.get_final_verdict(aggregated_results)
aggregated_results = pd.concat([aggregated_results, final_verdict], axis=1)

# Keep only necessary columns and drop 'Results'
final_results = aggregated_results[['text_entailment_score', 'similarity_score',
'processed_timestamp', 'result',
'result_sentence', 'reference_id']]

# Add label probabilities using the saved probabilities
final_results['label_probabilities'] = probabilities.apply(
lambda x: {
'SUPPORTS': float(x[0][0]),
'REFUTES': float(x[0][1]),
'NOT ENOUGH INFO': float(x[0][2])
}
)

return final_results

if __name__ == "__main__":
qid = 'Q44'
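
process_entailment is no longer a module-level function that constructs its own checker; it is now a method of ClaimEntailmentChecker, which can be handed an already-loaded TextualEntailmentModule. A hedged sketch of the new call pattern; evidence_df and html_df come from the upstream HTML and evidence stages and are not built here:

from claim_entailment import ClaimEntailmentChecker
from utils.textual_entailment_module import TextualEntailmentModule

shared_te = TextualEntailmentModule()                        # loaded once, reusable across entities
checker = ClaimEntailmentChecker(text_entailment=shared_te)  # config_path keeps its 'config.yaml' default
# entailment_results = checker.process_entailment(evidence_df, html_df, 'Q44')
# Leaving text_entailment unset preserves the old behaviour: the checker builds its own module.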
7 changes: 4 additions & 3 deletions refs_html_to_evidences.py
@@ -45,14 +45,15 @@ def slide_sentences(sentences, window_size=2):
return valid_html_df[['reference_id', 'url', 'nlp_sentences', 'nlp_sentences_slide_2']]

class EvidenceSelector:
def __init__(self):
def __init__(self, sentence_retrieval=None, verb_module=None):
self.logger = logging.getLogger(__name__)
self.endpoint_url = "https://query.wikidata.org/sparql"
self.headers = {
'User-Agent': 'Mozilla/5.0 (compatible; MyBot/1.0; mailto:[email protected])'
}
self.verb_module = VerbModule()
self.sentence_retrieval = SentenceRetrievalModule(max_len=512)
# Use provided models or create new ones
self.verb_module = verb_module or VerbModule()
self.sentence_retrieval = sentence_retrieval or SentenceRetrievalModule(max_len=512)
self.top_k = 5

def get_labels_from_sparql(self, property_ids: List[str], entity_ids: List[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
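
EvidenceSelector gets the same treatment: the sentence-retrieval and verbalisation models can now be injected, and omitting either argument falls back to constructing a fresh module exactly as before. An illustrative sketch of the new constructor in use; the process_evidence inputs come from earlier pipeline stages and are not built here:

from refs_html_to_evidences import EvidenceSelector
from utils.sentence_retrieval_module import SentenceRetrievalModule
from utils.verbalisation_module import VerbModule

retrieval = SentenceRetrievalModule(max_len=512)   # same max_len the selector would pick by default
verbaliser = VerbModule()
selector = EvidenceSelector(sentence_retrieval=retrieval, verb_module=verbaliser)
# evidence_df = selector.process_evidence(sentences_df, parser_result)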
