Skip to content

Commit

Permalink
Merge pull request #23 from tosemml/tosemml-patch-1
Browse files Browse the repository at this point in the history
Code improvements using Numpy
  • Loading branch information
ysymyth authored Aug 28, 2024
2 parents 17dbef5 + ad98d38 commit 0837da4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
7 changes: 3 additions & 4 deletions tests/web-agent-site/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import shutil
from pathlib import Path
from web_agent_site.utils import *
import numpy as np

def test_random_idx():
random.seed(24)
weights = [random.randint(0, 10) for _ in range(0, 50)]
cml_weights = [0]
for w in weights:
cml_weights.append(cml_weights[-1] + w)
cml_weights = [0] + np.cumsum(weights).tolist()
idx_1, expected_1 = random_idx(cml_weights), 44
idx_2, expected_2 = random_idx(cml_weights), 15
idx_3, expected_3 = random_idx(cml_weights), 36
Expand Down Expand Up @@ -46,4 +45,4 @@ def test_generate_mturk_code():
for session_id, expected in suite:
output = generate_mturk_code(session_id)
assert type(expected) is str
assert output == expected
assert output == expected
10 changes: 4 additions & 6 deletions web_agent_site/envs/web_agent_text_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
import torch

import numpy as np

from bs4 import BeautifulSoup
from bs4.element import Comment
from collections import defaultdict
Expand Down Expand Up @@ -314,9 +316,7 @@ def __init__(
# Imposes `limit` on goals via random selection
if limit_goals != -1 and limit_goals < len(self.goals):
self.weights = [goal['weight'] for goal in self.goals]
self.cum_weights = [0]
for w in self.weights:
self.cum_weights.append(self.cum_weights[-1] + w)
self.cum_weights = [0] + np.cumsum(self.weights).tolist()
idxs = []
while len(idxs) < limit_goals:
idx = random_idx(self.cum_weights)
Expand All @@ -327,9 +327,7 @@ def __init__(

# Set extraneous housekeeping variables
self.weights = [goal['weight'] for goal in self.goals]
self.cum_weights = [0]
for w in self.weights:
self.cum_weights.append(self.cum_weights[-1] + w)
self.cum_weights = [0] + np.cumsum(self.weights).tolist()
self.user_sessions = dict()
self.search_time = 0
self.render_time = 0
Expand Down

0 comments on commit 0837da4

Please sign in to comment.