Skip to content

Commit

Permalink
Add graph visualization support.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcCote committed Oct 12, 2019
1 parent f1ac489 commit 00c3fc1
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 11 deletions.
11 changes: 1 addition & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,8 @@ tatsu>=4.3.0
hashids>=1.2.0
jericho>=1.1.5

# For visualization
pybars3>=0.9.3
flask>=1.0.2
selenium>=3.12.0
greenlet==0.4.13
gevent==1.3.5
pillow>=5.1.0
pydot>=1.2.4

# For advanced prompt
prompt_toolkit<2.1.0,>=2.0.0
prompt_toolkit

# For gym support
gym>=0.10.11
14 changes: 14 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,18 @@ def run(self):
tests_require=[
'nose==1.3.7',
],
extras_require={
'vis': [
'pybars3>=0.9.3',
'flask>=1.0.2',
'selenium>=3.12.0',
'greenlet==0.4.13',
'gevent==1.3.5',
'pillow>=5.1.0',
'plotly>=4.0.0',
'pydot>=1.2.4',
'psutil',
'matplotlib',
],
},
)
4 changes: 4 additions & 0 deletions textworld/envs/glulx/git_glulx_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ def entities(self):
def extras(self):
return self._game.extras

@property
def game(self):
return self._game.serialize()


class GitGlulxMLEnvironment(textworld.Environment):
""" Environment to support playing Glulx games generated by TextWorld.
Expand Down
3 changes: 3 additions & 0 deletions textworld/envs/wrappers/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class EnvInfos:
'entities', 'verbs', 'command_templates',
'admissible_commands', 'intermediate_reward',
'policy_commands',
'game',
'extras']

def __init__(self, **kwargs):
Expand Down Expand Up @@ -77,6 +78,8 @@ def __init__(self, **kwargs):
#: bool: Templates for commands understood by the the game.
#: This information *doesn't* change from one step to another.
self.command_templates = kwargs.get("command_templates", False)
#: bool: Current game in its serialized form. Use with `textworld.Game.deserialize`.
self.game = kwargs.get("game", False)
#: List[str]: Names of extra information which are game specific.
self.extras = kwargs.get("extras", [])

Expand Down
5 changes: 4 additions & 1 deletion textworld/generator/inform7/world2inform7.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,10 @@ def gen_commands_from_actions(self, actions: Iterable[Action]) -> List[str]:
return commands

def get_human_readable_fact(self, fact: Proposition) -> Proposition:
arguments = [Variable(self.entity_infos[var.name].name, var.type) for var in fact.arguments]
def _get_name(info):
return info.name if info.name else info.id

arguments = [Variable(_get_name(self.entity_infos[var.name]), var.type) for var in fact.arguments]
return Proposition(fact.name, arguments)

def get_human_readable_action(self, action: Action) -> Action:
Expand Down
1 change: 1 addition & 0 deletions textworld/render/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@


from textworld.render.render import load_state, load_state_from_game_state, visualize
from textworld.render.graph import show_graph
205 changes: 205 additions & 0 deletions textworld/render/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import Iterable, Tuple, Optional

import numpy as np
import networkx as nx

from textworld.logic import Proposition


def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
""" Builds a graph from a collection of facts.
Arguments:
facts: Collection of facts representing a state of a game.
Returns:
The underlying graph representation.
"""
G = nx.DiGraph()
labels = {}
for fact in facts:
# Extract relation triplet from fact (subject, object, relation)
triplet = (*fact.names, fact.name)
triplet = triplet if len(triplet) >= 3 else triplet + ("is",)

src = triplet[0]
dest = triplet[1]
relation = triplet[-1]
if relation in {"is"}:
# For entity properties and states, we artificially
# add unique node for better visualization.
dest = src + "-" + dest

labels[src] = triplet[0]
labels[dest] = triplet[1]
G.add_edge(src, dest, type=triplet[-1])

nx.set_node_attributes(G, labels, 'label')
return G


def show_graph(facts: Iterable[Proposition],
title: str = "Knowledge Graph",
renderer:Optional[str] = None,
save:Optional[str] = None) -> "plotly.graph_objs._figure.Figure":

r""" Visualizes the graph made from a collection of facts.
Arguments:
facts: Collection of facts representing a state of a game.
title: Title for the figure
renderer:
Which Plotly's renderer to use (e.g., 'browser').
save:
If provided, path where to save a PNG version of the graph.
Returns:
The Plotly's figure representing the graph.
Example:
>>> import textworld
>>> options = textworld.GameOptions()
>>> options.seeds = 1234
>>> game_file, game = textworld.make(options)
>>> import gym
>>> import textworld.gym
>>> from textworld import EnvInfos
>>> request_infos = EnvInfos(facts=True)
>>> env_id = textworld.gym.register_game(game_file, request_infos)
>>> env = gym.make(env_id)
>>> _, infos = env.reset()
>>> textworld.render.show_graph(infos["facts"])
"""

# Local imports for optional dependencies
try:
import plotly.graph_objects as go
import matplotlib.pylab as plt
except:
raise ImportError('Visualization dependencies not installed. Try running `pip install textworld[vis]`')

G = build_graph_from_facts(facts)

plt.figure(figsize=(16, 9))
pos = nx.drawing.nx_pydot.pydot_layout(G, prog="fdp")

edge_labels_pos = {}
trace3_list = []
for edge in G.edges(data=True):
trace3 = go.Scatter(
x=[],
y=[],
mode='lines',
line=dict(width=0.5, color='#888', shape='spline', smoothing=1),
hoverinfo='none'
)
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
rvec = (x0-x1, y0-y1) # Vector from dest -> src.
length = np.sqrt(rvec[0] ** 2 + rvec[1] ** 2)
mid = ((x0+x1)/2., (y0+y1)/2.)
orthogonal = (rvec[1] / length, -rvec[0] / length)

trace3['x'] += (x0, mid[0] + 0 * orthogonal[0], x1, None)
trace3['y'] += (y0, mid[1] + 0 * orthogonal[1], y1, None)
trace3_list.append(trace3)

offset_ = 5
edge_labels_pos[(pos[edge[0]], pos[edge[1]])] = (mid[0] + offset_ * orthogonal[0],
mid[1] + offset_ * orthogonal[1])

node_x = []
node_y = []
node_labels = []
for node, data in G.nodes(data=True):
x, y = pos[node]
node_x.append(x)
node_y.append(y)
node_labels.append("<b>{}</b>".format(data['label'].replace(" ", "<br>")))

node_trace = go.Scatter(
x=node_x,
y=node_y,
mode='text',
text=node_labels,
textfont=dict(
family="sans serif",
size=12,
color="black"
),
hoverinfo='none',
marker=dict(
showscale=True,
color=[],
size=10,
line_width=2
)
)

fig = go.Figure(
data=[*trace3_list, node_trace],
layout=go.Layout(
title=title,
titlefont_size=16,
showlegend=False,
hovermode='closest',
margin=dict(b=20,l=5,r=5,t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
)
)

def _get_angle(p0, p1):
x0, y0 = p0
x1, y1 = p1
if x1 == x0:
return 0

angle = -np.rad2deg(np.arctan((y1-y0)/(x1-x0)/(16/9)))
return angle

# Add relation names and relation arrows.
annotations = []
for edge in G.edges(data=True):
annotations.append(
go.layout.Annotation(
x=pos[edge[1]][0],
y=pos[edge[1]][1],
ax=(pos[edge[0]][0]+pos[edge[1]][0])/2,
ay=(pos[edge[0]][1]+pos[edge[1]][1])/2,
axref="x",
ayref="y",
showarrow=True,
arrowhead=2,
arrowsize=3,
arrowwidth=0.5,
arrowcolor="#888",
standoff=5 + np.log(90 / abs(_get_angle(pos[edge[0]], pos[edge[1]]))) * max(map(len, G.nodes[edge[1]]['label'].split())),
)
)
annotations.append(
go.layout.Annotation(
x=edge_labels_pos[(pos[edge[0]], pos[edge[1]])][0],
y=edge_labels_pos[(pos[edge[0]], pos[edge[1]])][1],
showarrow=False,
text="<i>{}</i>".format(edge[2]['type']),
textangle=_get_angle(pos[edge[0]], pos[edge[1]]),
font=dict(
family="sans serif",
size=12,
color="blue"
),
)
)

fig.update_layout(annotations=annotations)

if renderer:
fig.show(renderer=renderer)

if save:
fig.write_image(save, width=1920, height=1080, scale=4)

return fig

0 comments on commit 00c3fc1

Please sign in to comment.