Make evaluation run a context manager instead of a singleton. (#3529)
# Description

In this PR we are doing four things:
1. Making the evaluation run a context manager (see the usage sketch below).
2. Moving the function that writes properties to the run history into the
evaluation run as a method.
3. Removing the code that kept the evaluation run a singleton, since we generally
do not need it.
4. Doing some unit test refactoring and adding/removing the appropriate
tests.
See work item 3334452.
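
Below is a minimal, hedged sketch of how the reworked `EvalRun` is meant to be used as a context manager; the constructor arguments shown are placeholders inferred from this diff rather than the exact `promptflow-evals` signature.

```python
# Hypothetical usage sketch; the parameter names and values below are assumptions based on this diff.
from promptflow.evals.evaluate._eval_run import EvalRun

# Entering the context calls _start_run() (or marks the run BROKEN when no tracking_uri
# is available); leaving it calls _end_run("FINISHED"), replacing the old singleton flow.
with EvalRun(
    run_name="my-evaluation",       # assumed display name
    tracking_uri=None,              # None -> results are kept locally, not logged to Azure
    subscription_id="<sub-id>",     # placeholder workspace scope values
    group_name="<resource-group>",
    workspace_name="<workspace>",
    ml_client=None,
) as run:
    run.log_metric("example_metric", 1.0)
    run.write_properties_to_run_history({"example_property": "value"})
```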

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [x] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [x] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
nick863 authored Jul 11, 2024
1 parent 5a3396f commit 2ba76d1
Showing 4 changed files with 487 additions and 435 deletions.
229 changes: 133 additions & 96 deletions src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
@@ -1,14 +1,16 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import contextlib
import dataclasses
import enum
import logging
import os
import posixpath
import requests
import time
import uuid
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Set
from urllib.parse import urlparse

from requests.adapters import HTTPAdapter
@@ -52,28 +54,15 @@ def generate(run_name: Optional[str]) -> 'RunInfo':
)


class Singleton(type):
"""Singleton class, which will be used as a metaclass."""
class RunStatus(enum.Enum):
"""Run states."""
NOT_STARTED = 0
STARTED = 1
BROKEN = 2
TERMINATED = 3

_instances = {}

def __call__(cls, *args, **kwargs):
"""Redefinition of call to return one instance per type."""
if cls not in Singleton._instances:
Singleton._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return Singleton._instances[cls]

@staticmethod
def destroy(cls: Type) -> None:
"""
Destroy the singleton instance.
:param cls: The class to be destroyed.
"""
Singleton._instances.pop(cls, None)


class EvalRun(metaclass=Singleton):
class EvalRun(contextlib.AbstractContextManager):
"""
The simple singleton run class, used for accessing artifact store.
@@ -119,25 +108,18 @@ def __init__(self,
self._workspace_name: str = workspace_name
self._ml_client: Any = ml_client
self._is_promptflow_run: bool = promptflow_run is not None
self._is_broken = False
if self._tracking_uri is None:
LOGGER.warning("tracking_uri was not provided, "
"The results will be saved locally, but will not be logged to Azure.")
self._url_base = None
self._is_broken = True
self.info = RunInfo.generate(run_name)
else:
self._url_base = urlparse(self._tracking_uri).netloc
if promptflow_run is not None:
self.info = RunInfo(
promptflow_run.name,
promptflow_run._experiment_name,
promptflow_run.name
)
else:
self._is_broken = self._start_run(run_name)
self._run_name = run_name
self._promptflow_run = promptflow_run
self._status = RunStatus.NOT_STARTED

self._is_terminated = False
@property
def status(self) -> RunStatus:
"""
Return the run status.
:return: The status of the run.
"""
return self._status

def _get_scope(self) -> str:
"""
@@ -156,76 +138,97 @@ def _get_scope(self) -> str:
self._workspace_name,
)

def _start_run(self, run_name: Optional[str]) -> bool:
def _start_run(self) -> None:
"""
Make a request to start the mlflow run. If the run will not start, it will be
marked as broken and the logging will be switched off.
:param run_name: The display name for the run.
:type run_name: Optional[str]
:returns: True if the run has started and False otherwise.
Start the run, or, if it is not applicable (for example, if tracking is not enabled), mark it as started.
"""
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/create"
body = {
"experiment_id": "0",
"user_id": "promptflow-evals",
"start_time": int(time.time() * 1000),
"tags": [{"key": "mlflow.user", "value": "promptflow-evals"}],
}
if run_name:
body["run_name"] = run_name
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
if response.status_code != 200:
self.info = RunInfo.generate(run_name)
LOGGER.warning(f"The run failed to start: {response.status_code}: {response.text}."
self._check_state_and_log('start run',
{v for v in RunStatus if v != RunStatus.NOT_STARTED},
True)
self._status = RunStatus.STARTED
if self._tracking_uri is None:
LOGGER.warning("tracking_uri was not provided, "
"The results will be saved locally, but will not be logged to Azure.")
return True
parsed_response = response.json()
self.info = RunInfo(
run_id=parsed_response['run']['info']['run_id'],
experiment_id=parsed_response['run']['info']['experiment_id'],
run_name=parsed_response['run']['info']['run_name']
)
return False

def end_run(self, status: str) -> None:
self._url_base = None
self._status = RunStatus.BROKEN
self.info = RunInfo.generate(self._run_name)
else:
self._url_base = urlparse(self._tracking_uri).netloc
if self._promptflow_run is not None:
self.info = RunInfo(
self._promptflow_run.name,
self._promptflow_run._experiment_name,
self._promptflow_run.name
)
else:
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/create"
body = {
"experiment_id": "0",
"user_id": "promptflow-evals",
"start_time": int(time.time() * 1000),
"tags": [{"key": "mlflow.user", "value": "promptflow-evals"}],
}
if self._run_name:
body["run_name"] = self._run_name
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
if response.status_code != 200:
self.info = RunInfo.generate(self._run_name)
LOGGER.warning(f"The run failed to start: {response.status_code}: {response.text}."
"The results will be saved locally, but will not be logged to Azure.")
self._status = RunStatus.BROKEN
else:
parsed_response = response.json()
self.info = RunInfo(
run_id=parsed_response['run']['info']['run_id'],
experiment_id=parsed_response['run']['info']['experiment_id'],
run_name=parsed_response['run']['info']['run_name']
)
self._status = RunStatus.STARTED

def _end_run(self, reason: str) -> None:
"""
Terminate the run.
:param status: One of "FINISHED" "FAILED" and "KILLED"
:type status: str
:param reason: One of "FINISHED" "FAILED" and "KILLED"
:type reason: str
:raises: ValueError if the run is not in ("FINISHED", "FAILED", "KILLED")
"""
if not self._check_state_and_log('stop run',
{RunStatus.BROKEN, RunStatus.NOT_STARTED, RunStatus.TERMINATED},
False):
return
if self._is_promptflow_run:
# This run is already finished, we just add artifacts/metrics to it.
Singleton.destroy(EvalRun)
self._status = RunStatus.TERMINATED
return
if status not in ("FINISHED", "FAILED", "KILLED"):
if reason not in ("FINISHED", "FAILED", "KILLED"):
raise ValueError(
f"Incorrect terminal status {status}. " 'Valid statuses are "FINISHED", "FAILED" and "KILLED".'
f"Incorrect terminal status {reason}. " 'Valid statuses are "FINISHED", "FAILED" and "KILLED".'
)
if self._is_terminated:
LOGGER.warning("Unable to stop run because it was already terminated.")
return
if self._is_broken:
LOGGER.warning("Unable to stop run because the run failed to start.")
return
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/update"
body = {
"run_uuid": self.info.run_id,
"status": status,
"status": reason,
"end_time": int(time.time() * 1000),
"run_id": self.info.run_id,
}
response = self.request_with_retry(url=url, method="POST", json_dict=body)
if response.status_code != 200:
LOGGER.warning("Unable to terminate the run.")
Singleton.destroy(EvalRun)
self._is_terminated = True
self._status = RunStatus.TERMINATED

def __enter__(self):
"""The Context Manager enter call."""
self._start_run()
return self

def __exit__(self, exc_type, exc_value, exc_tb):
"""The context manager exit call."""
self._end_run("FINISHED")

def get_run_history_uri(self) -> str:
"""
@@ -306,6 +309,33 @@ def _log_warning(self, failed_op: str, response: requests.Response) -> None:
f"{response.text=}."
)

def _check_state_and_log(
self,
action: str,
bad_states: Set[RunStatus],
should_raise: bool) -> bool:
"""
Check that the run is in the correct state and log a warning if it is not.
:param action: Action, which caused this check. For example if it is "log artifact",
the log message will start "Unable to log artifact."
:type action: str
:param bad_states: The states, considered invalid for given action.
:type bad_states: set
:param should_raise: Should we raise an error if the bad state has been encountered?
:type should_raise: bool
:raises: RuntimeError if should_raise is True and invalid state was encountered.
:return: boolean saying if run is in the correct state.
"""
if self._status in bad_states:
msg = f"Unable to {action} due to Run status={self._status}."
if should_raise:
raise RuntimeError(msg)
else:
LOGGER.warning(msg)
return False
return True

def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ARTIFACT) -> None:
"""
The local implementation of mlflow-like artifact logging.
@@ -316,8 +346,7 @@ def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ARTIFACT) -> None:
:param artifact_folder: The folder with artifacts to be uploaded.
:type artifact_folder: str
"""
if self._is_broken:
LOGGER.warning("Unable to log artifact because the run failed to start.")
if not self._check_state_and_log('log artifact', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
# Check if artifact directory is empty or does not exist.
if not os.path.isdir(artifact_folder):
@@ -404,8 +433,7 @@ def log_metric(self, key: str, value: float) -> None:
:param value: The value to be logged.
:type value: float
"""
if self._is_broken:
LOGGER.warning("Unable to log metric because the run failed to start.")
if not self._check_state_and_log('log metric', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
body = {
"run_uuid": self.info.run_id,
@@ -423,11 +451,20 @@ def log_metric(self, key: str, value: float) -> None:
if response.status_code != 200:
self._log_warning("save metrics", response)

@staticmethod
def get_instance(*args, **kwargs) -> "EvalRun":
def write_properties_to_run_history(self, properties: Dict[str, Any]) -> None:
"""
The convenience method to get the EvalRun instance.
Write properties to the RunHistory service.
:return: The EvalRun instance.
:param properties: The properties to be written to run history.
:type properties: dict
"""
return EvalRun(*args, **kwargs)
if not self._check_state_and_log('write properties', {RunStatus.BROKEN, RunStatus.NOT_STARTED}, False):
return
# update host to run history and request PATCH API
response = self.request_with_retry(
url=self.get_run_history_uri(),
method="PATCH",
json_dict={"runId": self.info.run_id, "properties": properties},
)
if response.status_code != 200:
LOGGER.error("Fail writing properties '%s' to run history: %s", properties, response.text)
