Skip to content

Commit

Permalink
fix build error due to changed exception messages and UCI repository
Browse files Browse the repository at this point in the history
expired certificate
  • Loading branch information
imatiach-msft committed Nov 13, 2023
1 parent 73dc621 commit 1fa0b6e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"import pandas as pd\n",
"\n",
"from urllib.request import urlretrieve\n",
"import requests\n",
"import zipfile\n",
"\n",
"# Imports for SHAP MimicExplainer with LightGBM surrogate model\n",
Expand All @@ -86,7 +87,13 @@
"source": [
"outdirname = 'superconduct'\n",
"zipfilename = outdirname + '.zip'\n",
"urlretrieve('https://archive.ics.uci.edu/static/public/464/superconductivty+data.zip', zipfilename)\n",
"url = 'https://archive.ics.uci.edu/static/public/464/superconductivty+data.zip'\n",
"# temporary workaround for UCI repository until website SSL certificate\n",
"# is renewed with requests instead of urlretrieve\n",
"# urlretrieve(url, zipfilename)\n",
"content = requests.get(url, verify=False).content\n",
"with open(zipfilename, mode='wb') as localfile:\n",
" localfile.write(content)\n",
"with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
" unzip.extractall('.')\n",
"df = pd.read_csv(r'./train.csv')\n",
Expand Down Expand Up @@ -167,9 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"predictions = model.predict(X_test)\n",
Expand Down Expand Up @@ -244,9 +249,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"ErrorAnalysisDashboard(global_explanation, model, dataset=X_test,\n",
Expand All @@ -262,7 +265,7 @@
}
],
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -276,9 +279,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
5 changes: 5 additions & 0 deletions responsibleai/tests/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
import random
import ssl

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -87,6 +88,9 @@ def fetch(self):


def create_adult_income_dataset(create_small_dataset=True):
# workaround for SSL expiration error from UCI website
default_http_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
fetcher = FetchDiceAdultCensusIncomeDataset()
action_name = "Adult dataset download"
err_msg = "Failed to download adult dataset"
Expand Down Expand Up @@ -122,6 +126,7 @@ def create_adult_income_dataset(create_small_dataset=True):
data_train, data_test, y_train, y_test = train_test_split(
dataset, target, test_size=5000, random_state=7, stratify=target
)
ssl._create_default_https_context = default_http_context
return (
data_train,
data_test,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import shutil
import sys
from pathlib import Path
from tempfile import TemporaryDirectory

Expand All @@ -14,6 +15,7 @@
create_text_classification_pipeline,
load_covid19_emergency_event_dataset,
load_emotion_dataset)
from huggingface_hub.utils._validators import HFValidationError
from rai_text_insights_validator import validate_rai_text_insights

from responsibleai._internal.constants import ManagerNames
Expand Down Expand Up @@ -128,8 +130,13 @@ def test_loading_rai_insights_without_model_file(self):
model_name = 'text-classification-model'
model_pkl_path = Path(tmpdir) / "rai_insights" / model_name
shutil.rmtree(model_pkl_path)
match_msg = 'Can\'t load the configuration'
with pytest.raises(OSError, match=match_msg):
if sys.version_info[:2] == (3, 7):
match_msg = 'Can\'t load the configuration'
expected_error = OSError
else:
match_msg = 'Repo id must be in the form'
expected_error = HFValidationError
with pytest.raises(expected_error, match=match_msg):
without_model_rai_insights = RAITextInsights.load(save_path)
assert without_model_rai_insights.model is None

Expand Down

0 comments on commit 1fa0b6e

Please sign in to comment.