Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
model: Predict only yields repo
Browse files Browse the repository at this point in the history
Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
John Andersen authored and pdxjohnny committed Oct 3, 2019
1 parent d3bf6c9 commit 65e4ce4
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- shouldi example runs bandit now in addition to safety
- The way safety gets called
- Switched documentation to Read The Docs theme
- Models yield only a repo object instead of the value and confidence of the
prediction as well. Models are not responsible for calling the predicted
method on the repo. This will ease the process of making predict feature
specific.
### Fixed
- Docs get version from dffml.version.VERSION.
- FileSource zipfiles are wrapped with TextIOWrapper because CSVSource expects
Expand Down
4 changes: 1 addition & 3 deletions dffml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
raise NotImplementedError()

@abc.abstractmethod
async def predict(
self, repos: AsyncIterator[Repo]
) -> AsyncIterator[Tuple[Repo, Any, float]]:
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
"""
Uses trained data to make a prediction about the quality of a repo.
"""
Expand Down
14 changes: 12 additions & 2 deletions dffml/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Information on the software to evaluate is stored in a Repo instance.
"""
import os
import warnings
from datetime import datetime
from typing import Optional, List, Dict, Any, AsyncIterator

Expand All @@ -21,7 +22,7 @@ class RepoPrediction(dict):

EXPORTED = ["value", "confidence"]

def __init__(self, *, confidence: float = 0.0, value: Any = "") -> None:
def __init__(self, *, confidence: float = 0.0, value: Any = None) -> None:
self["confidence"] = confidence
self["value"] = value

Expand All @@ -39,7 +40,7 @@ def dict(self):
return self

def __len__(self):
if self["confidence"] == 0.0 and not self["value"]:
if self["confidence"] == 0.0 and self["value"] is None:
return 0
return 2

Expand Down Expand Up @@ -128,6 +129,15 @@ def __init__(
self.extra = extra

def dict(self):
# TODO Remove dict method in favor of export
warnings.warn(
"dict method will be removed in favor of export",
DeprecationWarning,
stacklevel=2,
)
return self.export()

def export(self):
data = self.data.dict()
data["extra"] = self.extra
return data
Expand Down
7 changes: 3 additions & 4 deletions model/scikit/dffml_model_scikit/scikit_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
self.logger.debug("Model Accuracy: {}".format(self.confidence))
return self.confidence

async def predict(
self, repos: AsyncIterator[Repo]
) -> AsyncIterator[Tuple[Repo, Any, float]]:
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
if self.confidence is None:
raise ValueError("Model Not Trained")
async for repo in repos:
Expand All @@ -107,7 +105,8 @@ async def predict(
self.clf.predict(predict),
)
)
yield repo, self.clf.predict(predict)[0], self.confidence
repo.predicted(self.clf.predict(predict)[0], self.confidence)
yield repo


class Scikit(Model):
Expand Down
5 changes: 2 additions & 3 deletions model/scikit/tests/test_scikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ async def test_01_accuracy(self):
async def test_02_predict(self):
async with self.sources as sources, self.features as features, self.model as model:
async with sources() as sctx, model(features) as mctx:
async for repo, prediction, confidence in mctx.predict(
sctx.repos()
):
async for repo in mctx.predict(sctx.repos()):
prediction = repo.prediction().value
if self.MODEL_TYPE is "CLASSIFICATION":
self.assertIn(prediction, [2, 4])
elif self.MODEL_TYPE is "REGRESSION":
Expand Down
12 changes: 6 additions & 6 deletions model/scratch/dffml_model_scratch/slr.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ async def accuracy(self, sources: Sources) -> Accuracy:
accuracy_value = self.regression_line[2]
return Accuracy(accuracy_value)

async def predict(
self, repos: AsyncIterator[Repo]
) -> AsyncIterator[Tuple[Repo, Any, float]]:
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
async for repo in repos:
feature_data = repo.features(self.features)
yield repo, await self.predict_input(
feature_data[self.features[0]]
), self.regression_line[2]
repo.predicted(
await self.predict_input(feature_data[self.features[0]]),
self.regression_line[2],
)
yield repo


@entry_point("slr")
Expand Down
10 changes: 4 additions & 6 deletions model/scratch/tests/test_slr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@ async def test_context(self):
res = await mctx.accuracy(sctx)
self.assertTrue(0.0 <= res < 1.0)
# Test predict
async for repo, prediction, confidence in mctx.predict(
sctx.repos()
):
async for repo in mctx.predict(sctx.repos()):
correct = FEATURE_DATA[int(repo.src_url)][1]
# Comparison of correct to prediction to make sure prediction is within a reasonable range
prediction = repo.prediction().value
self.assertGreater(prediction, correct - (correct * 0.10))
self.assertLess(prediction, correct + (correct * 0.10))

Expand All @@ -90,9 +89,8 @@ async def test_01_accuracy(self):
async def test_02_predict(self):
async with self.sources as sources, self.features as features, self.model as model:
async with sources() as sctx, model(features) as mctx:
async for repo, prediction, confidence in mctx.predict(
sctx.repos()
):
async for repo in mctx.predict(sctx.repos()):
correct = FEATURE_DATA[int(repo.src_url)][1]
prediction = repo.prediction().value
self.assertGreater(prediction, correct - (correct * 0.10))
self.assertLess(prediction, correct + (correct * 0.10))
7 changes: 3 additions & 4 deletions model/tensorflow/dffml_model_tensorflow/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
accuracy_score = self.model.evaluate(input_fn=input_fn)
return Accuracy(accuracy_score["accuracy"])

async def predict(
self, repos: AsyncIterator[Repo]
) -> AsyncIterator[Tuple[Repo, Any, float]]:
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
"""
Uses trained data to make a prediction about the quality of a repo.
"""
Expand All @@ -323,7 +321,8 @@ async def predict(
for repo, pred_dict in zip(predict, predictions):
class_id = pred_dict["class_ids"][0]
probability = pred_dict["probabilities"][class_id]
yield repo, self.cids[class_id], probability
repo.predicted(self.cids[class_id], probability)
yield repo


@entry_point("tfdnnc")
Expand Down
4 changes: 2 additions & 2 deletions model/tensorflow/tests/test_dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ async def test_02_predict(self):
async with sources() as sctx, model(features) as mctx:
res = [repo async for repo in mctx.predict(sctx.repos())]
self.assertEqual(len(res), 1)
self.assertEqual(res[0][0].src_url, a.src_url)
self.assertTrue(res[0][1])
self.assertEqual(res[0].src_url, a.src_url)
self.assertTrue(res[0].prediction().value)

0 comments on commit 65e4ce4

Please sign in to comment.