Skip to content

Commit

Permalink
Add FaqGen Accuracy scripts & Refine Ragas (#91)
Browse files Browse the repository at this point in the history
* fix ragas to align latest code

Signed-off-by: Xinyao Wang <[email protected]>

* add FaqGen Accuracy scripts

Signed-off-by: Xinyao Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

---------

Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
XinyaoWa and pre-commit-ci[bot] authored Sep 5, 2024
1 parent 514a6d6 commit 4df6438
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 14 deletions.
27 changes: 17 additions & 10 deletions evals/metrics/ragas/ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def __init__(
self.embeddings = embeddings
self.metrics = metrics
self.validated_list = [
"answer_relevancy",
"faithfulness",
"answer_correctness",
"answer_relevancy",
"answer_similarity",
"context_precision",
"context_relevancy",
"context_recall",
"faithfulness",
"context_utilization",
"reference_free_rubrics_score",
]

async def a_measure(self, test_case: Dict):
Expand All @@ -55,8 +56,9 @@ def measure(self, test_case: Dict):
answer_similarity,
context_precision,
context_recall,
context_relevancy,
context_utilization,
faithfulness,
reference_free_rubrics_score,
)

except ModuleNotFoundError:
Expand All @@ -67,8 +69,14 @@ def measure(self, test_case: Dict):
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install dataset")
self.metrics_instance = {
"answer_correctness": answer_correctness,
"answer_relevancy": answer_relevancy,
"answer_similarity": answer_similarity,
"context_precision": context_precision,
"context_recall": context_recall,
"faithfulness": faithfulness,
"context_utilization": context_utilization,
"reference_free_rubrics_score": reference_free_rubrics_score,
}

# Set LLM model
Expand Down Expand Up @@ -101,7 +109,7 @@ def measure(self, test_case: Dict):
else:
if metric == "answer_relevancy" and self.embeddings is None:
raise ValueError("answer_relevancy metric need provide embeddings model.")
tmp_metrics.append(metric)
tmp_metrics.append(self.metrics_instance[metric])
self.metrics = tmp_metrics
else:
self.metrics = [
Expand All @@ -110,15 +118,14 @@ def measure(self, test_case: Dict):
answer_correctness,
answer_similarity,
context_precision,
context_relevancy,
context_recall,
]

data = {
"question": test_case["input"],
"contexts": test_case["retrieval_context"],
"answer": test_case["actual_output"],
"ground_truth": test_case["expected_output"],
"question": test_case["question"],
"contexts": test_case["contexts"],
"answer": test_case["answer"],
"ground_truth": test_case["ground_truth"],
}
dataset = Dataset.from_dict(data)

Expand Down
61 changes: 61 additions & 0 deletions examples/FaqGen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
## Dataset
We evaluate performance on QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2). Generate FAQs on "context" columns in validation dataset, which contains 1204 unique records.

First download dataset and put at "./data".

Extract unique "context" columns, which will be save to 'data/sqv2_context.json':
```
python get_context.py
```

## Generate FAQs

### Launch FaQGen microservice
Please refer to [FaQGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi), set up an microservice endpoint.
```
export FAQ_ENDPOINT = "http://${your_ip}:9000/v1/faqgen"
```

### Generate FAQs with microservice
Use the microservice endpoint to generate FAQs for dataset.
```
python generate_FAQ.py
```

Post-process the output to get the right data, which will be save to 'data/sqv2_faq.json'.
```
python post_process_FAQ.py
```

## Evaluate with Ragas

### Launch TGI service
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as LLM referee to evaluate the model. First we need to launch a LLM endpoint on Gaudi.
```
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token"
bash launch_tgi.sh
```
Get the endpoint:
```
export LLM_ENDPOINT = "http://${ip_address}:8082"
```

Verify the service:
```bash
curl http://${ip_address}:8082/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \
-H 'Content-Type: application/json'
```

### Evaluate
evaluate the performance with the LLM:
```
python evaluate.py
```

### Performance Result
Here is the tested result for your reference
| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score |
| ---- | ---- |---- |---- |
| 0.7191 | 0.9681 | 0.8964 | 4.4125|
45 changes: 45 additions & 0 deletions examples/FaqGen/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

from langchain_community.embeddings import HuggingFaceBgeEmbeddings

from evals.metrics.ragas import RagasMetric

llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

f = open("data/sqv2_faq.json", "r")
sqv2_faq = json.load(f)

templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
Do not use any prefix or suffix to the FAQ.
"""

number = 1204
question = []
answer = []
ground_truth = ["None"] * number
contexts = []
for i in range(number):
inputs = sqv2_context[str(i)]
inputs_faq = templ.format_map({"text": inputs})
actual_output = sqv2_faq[str(i)]

question.append(inputs_faq)
answer.append(actual_output)
contexts.append([inputs_faq])

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"]
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq)

test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts}

metric.measure(test_case)
print(metric.score)
28 changes: 28 additions & 0 deletions examples/FaqGen/generate_FAQ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
import time

import requests

llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

start_time = time.time()
headers = {"Content-Type": "application/json"}
for i in range(1204):
start_time_tmp = time.time()
print(i)
inputs = sqv2_context[str(i)]
data = {"query": inputs, "max_new_tokens": 128}
response = requests.post(llm_endpoint, json=data, headers=headers)
f = open(f"data/result/sqv2_faq_{i}", "w")
f.write(inputs)
f.write(str(response.content, encoding="utf-8"))
f.close()
print(f"Cost {time.time()-start_time_tmp} seconds")
print(f"\n Finished! \n Totally Cost {time.time()-start_time} seconds\n")
17 changes: 17 additions & 0 deletions examples/FaqGen/get_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

import pandas as pd

data_path = "./data"
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet"))
sq_context = list(data["context"].unique())
sq_context_d = dict()
for i in range(len(sq_context)):
sq_context_d[i] = sq_context[i]

with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile:
json.dump(sq_context_d, outfile)
28 changes: 28 additions & 0 deletions examples/FaqGen/launch_tgi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

max_input_tokens=3072
max_total_tokens=4096
port_number=8082
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
volume="./data"
docker run -it --rm \
--name="tgi_Mixtral" \
-p $port_number:80 \
-v $volume:/data \
--runtime=habana \
--restart always \
-e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
-e HABANA_VISIBLE_DEVICES=all \
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
--cap-add=sys_nice \
--ipc=host \
-e HTTPS_PROXY=$https_proxy \
-e HTTP_PROXY=$https_proxy \
ghcr.io/huggingface/tgi-gaudi:2.0.1 \
--model-id $model_name \
--max-input-tokens $max_input_tokens \
--max-total-tokens $max_total_tokens \
--sharded true \
--num-shard 2
27 changes: 27 additions & 0 deletions examples/FaqGen/post_process_FAQ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json

faq_dict = {}
fails = []
for i in range(1204):
data = open(f"data/result/sqv2_faq_{i}", "r").readlines()
result = data[-6][6:]
# print(result)
if "LLMChain/final_output" not in result:
print(f"error1: fail for {i}")
fails.append(i)
continue
try:
result2 = json.loads(result)
result3 = result2["ops"][0]["value"]["text"]
faq_dict[str(i)] = result3
except:
print(f"error2: fail for {i}")
fails.append(i)
continue
with open("data/sqv2_faq.json", "w") as outfile:
json.dump(faq_dict, outfile)
print("Failure index:")
print(fails)
8 changes: 4 additions & 4 deletions tests/test_ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def test_ragas(self):
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metric = RagasMetric(threshold=0.5, model="http://localhost:8008", embeddings=embeddings)
test_case = {
"input": ["What if these shoes don't fit?"],
"actual_output": [actual_output],
"expected_output": [expected_output],
"retrieval_context": [retrieval_context],
"question": ["What if these shoes don't fit?"],
"answer": [actual_output],
"ground_truth": [expected_output],
"contexts": [retrieval_context],
}

metric.measure(test_case)
Expand Down

0 comments on commit 4df6438

Please sign in to comment.