Skip to content

Commit

Permalink
Removed experimental features from main (#31)
Browse files Browse the repository at this point in the history
* Removed treePool

* Updated version

* Added tests for chat models
  • Loading branch information
maykcaldas committed Sep 28, 2023
1 parent 44bfe35 commit 1a186c2
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 391 deletions.
2 changes: 1 addition & 1 deletion bolift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
print("GPR Packages not installed. Do `pip install bolift[gpr]` to install them")
from .asktellRidgeRegression import AskTellRidgeKernelRegression
from .asktellNearestNeighbor import AskTellNearestNeighbor
from .pool import Pool, TreeNode, TreePool
from .pool import Pool
from .tool import BOLiftTool
53 changes: 13 additions & 40 deletions bolift/asktell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
upper_confidence_bound,
greedy,
)
from .pool import Pool, TreePool
from .pool import Pool
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import FAISS, Chroma
Expand Down Expand Up @@ -53,7 +53,6 @@ def __init__(
model: str = "text-curie-001",
temperature: Optional[float] = None,
prefix: Optional[str] = None,
inv_prefix: Optional[str] = None,
x_formatter: Callable[[str], str] = lambda x: x,
y_formatter: Callable[[float], str] = lambda y: f"{y:0.2f}",
y_name: str = "output",
Expand Down Expand Up @@ -94,7 +93,6 @@ def __init__(
self._prompt_template = prompt_template
self._suffix = suffix
self._prefix = prefix
self._inv_prefix = inv_prefix
self._model = model
self._example_count = 0
self._temperature = temperature
Expand Down Expand Up @@ -130,11 +128,7 @@ def _setup_inv_llm(self, model: str, temperature: Optional[float] = None):
temperature=0.05 if temperature is None else temperature,
)

def _setup_inverse_prompt(self,
example: Dict,
prefix: Optional[str] = None):
if prefix is None:
prefix = ""
def _setup_inverse_prompt(self, example: Dict):
prompt_template = PromptTemplate(
input_variables=["x", "y", "y_name", "x_name"],
template="If {y_name} is {y}, then {x_name} is @@@\n{x}###",
Expand Down Expand Up @@ -163,7 +157,6 @@ def _setup_inverse_prompt(self,
example_prompt=prompt_template,
example_selector=example_selector,
suffix="If {y_name} is {y}, then {x_name} is @@@",
prefix=prefix,
input_variables=["y", "y_name", "x_name"],
)

Expand Down Expand Up @@ -270,7 +263,7 @@ def tell(self, x: str, y: float, alt_ys: Optional[List[float]] = None) -> None:
self.prompt = self._setup_prompt(
example_dict, self._prompt_template, self._suffix, self._prefix
)
self.inv_prompt = self._setup_inverse_prompt(inv_example, self._inv_prefix)
self.inv_prompt = self._setup_inverse_prompt(inv_example)
self.llm = self._setup_llm(self._model, self._temperature)
self.inv_llm = self._setup_inv_llm(self._model, self._temperature)
self._ready = True
Expand Down Expand Up @@ -321,7 +314,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]
self.prompt = self._setup_prompt(
None, self._prompt_template, self._suffix, self._prefix
)
self.inv_prompt = self._setup_inverse_prompt(None, self._inv_prefix)
self.inv_prompt = self._setup_inverse_prompt(None)
self.llm = self._setup_llm(self._model)
self._ready = True

Expand Down Expand Up @@ -367,7 +360,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]

def ask(
self,
possible_x: Union[Pool, List[str], TreePool, OrderedDict[str, Any]],
possible_x: Union[Pool, List[str]],
aq_fxn: str = "upper_confidence_bound",
k: int = 1,
inv_filter: int = 16,
Expand All @@ -383,16 +376,14 @@ def ask(
aq_fxn: Acquisition function to use.
k: Number of x values to return.
inv_filter: Reduce pool size to this number with inverse model. If 0, not used
aug_random_filter: Add this many random examples to the pool to increase diversity after reducing pool with inverse model
aug_random_filter: Add this man y random examples to the pool to increase diversity after reducing pool with inverse model
_lambda: Lambda value to use for UCB
Return:
The selected x values, their acquisition function values, and the predicted y modes.
Sorted by acquisition function value (descending)
"""
if type(possible_x) == type([]):
possible_x = Pool(possible_x, self.format_x)
elif type(possible_x) == type(OrderedDict()):
possible_x = TreePool(possible_x, self._prompt_template.prompt, self.format_x) #need to input the string for the prompt template

if aq_fxn == "probability_of_improvement":
aq_fxn = probability_of_improvement
Expand All @@ -416,33 +407,15 @@ def ask(
else:
best = np.max(self._ys)

if isinstance(possible_x, Pool):
if inv_filter+aug_random_filter < len(possible_x):
possible_x_l = []
print(inv_filter, aug_random_filter)
if inv_filter:
approx_x = self.inv_predict(best * np.random.normal(1.0, 0.05))
possible_x_l.extend(possible_x.approx_sample(approx_x, inv_filter))
if aug_random_filter:
possible_x_l.extend(possible_x.sample(aug_random_filter))
else:
possible_x_l = list(possible_x)
elif isinstance(possible_x, TreePool):
if inv_filter+aug_random_filter < len(possible_x):
possible_x_l = []
while len(possible_x_l) < k:
node = possible_x._root
while not node.is_leaf():
partial_possible_x = [possible_x.partial_format_prompt(child.get_branch()) for child in node.get_children_list()]
node_retriever = dict(zip(partial_possible_x, node.get_children_list()))
selected_child = self._ask(partial_possible_x, best, aq_fxn, 1)
selected_child = selected_child[0][0]
node = node_retriever[selected_child]
selected = possible_x.format_prompt(node.get_branch())
while selected in possible_x_l:
selected = possible_x.sample(1)[0]
possible_x_l.append(selected)
if inv_filter:
approx_x = self.inv_predict(best * np.random.normal(1.0, 0.05))
possible_x_l.extend(possible_x.approx_sample(approx_x, inv_filter))
if aug_random_filter:
possible_x_l.extend(possible_x.sample(aug_random_filter))
else:
raise ValueError("Unknown pool type")
possible_x_l = list(possible_x)

results = self._ask(possible_x_l, best, aq_fxn, k)
if len(results[0]) == 0 and len(possible_x_l) != 0:
Expand Down
Loading

0 comments on commit 1a186c2

Please sign in to comment.