Skip to content

Commit

Permalink
Refactoring of the rendering stuff.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcCote committed Nov 29, 2019
1 parent 6497ebb commit 593307d
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 78 deletions.
3 changes: 2 additions & 1 deletion textworld/envs/wrappers/tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import textworld

from textworld.utils import make_temp_directory, get_webdriver
from textworld.utils import make_temp_directory
from textworld.generator import compile_game
from textworld.envs.wrappers import HtmlViewer
from textworld.render import get_webdriver


def test_html_viewer():
Expand Down
3 changes: 2 additions & 1 deletion textworld/envs/wrappers/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Tuple

from textworld.core import Environment, GameState, Wrapper
from textworld.render import WebdriverNotFoundError


class HtmlViewer(Wrapper):
Expand Down Expand Up @@ -88,7 +89,7 @@ def reset(self) -> GameState:
from textworld.render.serve import VisualizationService
self._server = VisualizationService(game_state, self.open_automatically)
self._server.start(threading.current_thread(), port=self._port)
except ModuleNotFoundError as e:
except WebdriverNotFoundError as e:
print("Importing HtmlViewer without installed dependencies. Try re-installing textworld.")
raise e

Expand Down
5 changes: 4 additions & 1 deletion textworld/render/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.


from textworld.render.render import WebdriverNotFoundError
from textworld.render.render import get_webdriver
from textworld.render.render import load_state, load_state_from_game_state, visualize
from textworld.render.graph import show_graph
from textworld.render.graph import show_graph
82 changes: 81 additions & 1 deletion textworld/render/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@


import io
import os
import time
import json
import tempfile
from os.path import join as pjoin
Expand All @@ -15,7 +17,7 @@
from textworld.logic import Proposition, Action
from textworld.logic import State
from textworld.generator import World, Game
from textworld.utils import maybe_mkdir, get_webdriver
from textworld.utils import maybe_mkdir

from textworld.generator.game import EntityInfo
from textworld.generator.data import KnowledgeBase
Expand Down Expand Up @@ -355,6 +357,84 @@ def concat_images(*images):
return new_im


class WebdriverNotFoundError(Exception):
pass


def which(program):
"""
helper to see if a program is in PATH
:param program: name of program
:return: path of program or None
"""
def is_exe(fpath):
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

fpath, _ = os.path.split(program)
if fpath:
if is_exe(program):
return program
else:
for path in os.environ["PATH"].split(os.pathsep):
exe_file = os.path.join(path, program)
if is_exe(exe_file):
return exe_file

return None


def get_webdriver(path=None):
"""
Get the driver and options objects.
:param path: path to browser binary.
:return: driver
"""
from selenium import webdriver

def chrome_driver(path=None):
import urllib3
from selenium.webdriver.chrome.options import Options
options = Options()
options.add_argument('headless')
options.add_argument('ignore-certificate-errors')
options.add_argument("test-type")
options.add_argument("no-sandbox")
options.add_argument("disable-gpu")
if path is not None:
options.binary_location = path

SELENIUM_RETRIES = 10
SELENIUM_DELAY = 3 # seconds
for _ in range(SELENIUM_RETRIES):
try:
return webdriver.Chrome(chrome_options=options)
except urllib3.exceptions.ProtocolError: # https://github.com/SeleniumHQ/selenium/issues/5296
time.sleep(SELENIUM_DELAY)

raise ConnectionResetError('Cannot connect to Chrome, giving up after {SELENIUM_RETRIES} attempts.')

def firefox_driver(path=None):
from selenium.webdriver.firefox.options import Options
options = Options()
options.add_argument('headless')
driver = webdriver.Firefox(firefox_binary=path, options=options)
return driver


driver_mapping = {
'geckodriver': firefox_driver,
'chromedriver': chrome_driver,
'chromium-driver': chrome_driver
}

for driver in driver_mapping.keys():
found = which(driver)
if found is not None:
return driver_mapping.get(driver, None)(path)

raise WebdriverNotFoundError("Chrome/Chromium/FireFox Webdriver not found.")


def visualize(world: Union[Game, State, GameState, World],
interactive: bool = False):
"""
Expand Down
74 changes: 0 additions & 74 deletions textworld/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
import re
import time
import shutil
import tempfile
import itertools
Expand Down Expand Up @@ -39,79 +38,6 @@ def next(self):
return rng


def which(program):
"""
helper to see if a program is in PATH
:param program: name of program
:return: path of program or None
"""
def is_exe(fpath):
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

fpath, fname = os.path.split(program)
if fpath:
if is_exe(program):
return program
else:
for path in os.environ["PATH"].split(os.pathsep):
exe_file = os.path.join(path, program)
if is_exe(exe_file):
return exe_file

return None


def get_webdriver(path=None):
"""
Get the driver and options objects.
:param path: path to browser binary.
:return: driver
"""
from selenium import webdriver

def chrome_driver(path=None):
import urllib3
from selenium.webdriver.chrome.options import Options
options = Options()
options.add_argument('headless')
options.add_argument('ignore-certificate-errors')
options.add_argument("test-type")
options.add_argument("no-sandbox")
options.add_argument("disable-gpu")
if path is not None:
options.binary_location = path

SELENIUM_RETRIES = 10
SELENIUM_DELAY = 3 # seconds
for _ in range(SELENIUM_RETRIES):
try:
return webdriver.Chrome(chrome_options=options)
except urllib3.exceptions.ProtocolError: # https://github.com/SeleniumHQ/selenium/issues/5296
time.sleep(SELENIUM_DELAY)

raise ConnectionResetError('Cannot connect to Chrome, giving up after {SELENIUM_RETRIES} attempts.')

def firefox_driver(path=None):
from selenium.webdriver.firefox.options import Options
options = Options()
options.add_argument('headless')
driver = webdriver.Firefox(firefox_binary=path, options=options)
return driver

driver_mapping = {
'geckodriver': firefox_driver,
'chromedriver': chrome_driver,
'chromium-driver': chrome_driver
}

for driver in driver_mapping.keys():
found = which(driver)
if found is not None:
return driver_mapping.get(driver, None)(path)

raise ModuleNotFoundError("Chrome/Chromium/FireFox Webdriver not found.")


class RegexDict(OrderedDict):
""" Ordered dictionary that supports querying with regex.
Expand Down

0 comments on commit 593307d

Please sign in to comment.