Skip to content

Commit

Permalink
Merge pull request #1479 from fetchai/feature/gym_async_and_tests
Browse files Browse the repository at this point in the history
[AEA-685] Gym connection fully async and full test coverage
  • Loading branch information
DavidMinarsch authored Jul 8, 2020
2 parents 94028f8 + 79601d8 commit 23aa8d6
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 97 deletions.
14 changes: 7 additions & 7 deletions docs/gym-skill.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Follow the <a href="../quickstart/#preliminaries">Preliminaries</a> and <a href=

First, fetch the gym AEA:
``` bash
aea fetch fetchai/gym_aea:0.4.0 --alias my_gym_aea
aea fetch fetchai/gym_aea:0.5.0 --alias my_gym_aea
cd my_gym_aea
aea install
```
Expand All @@ -34,15 +34,15 @@ aea create my_gym_aea
cd my_gym_aea
```

### Add the gym skill
### Add the gym skill
``` bash
aea add skill fetchai/gym:0.4.0
```

### Add a gym connection
``` bash
aea add connection fetchai/gym:0.3.0
aea config set agent.default_connection fetchai/gym:0.3.0
aea add connection fetchai/gym:0.4.0
aea config set agent.default_connection fetchai/gym:0.4.0
```

### Install the skill dependencies
Expand Down Expand Up @@ -90,13 +90,13 @@ aea delete my_gym_aea
```

## Communication
This diagram shows the communication between the AEA and the gym environment
This diagram shows the communication between the AEA and the gym environment

<div class="mermaid">
sequenceDiagram
participant AEA
participant Environment

activate AEA
activate Environment
AEA->>Environment: reset
Expand All @@ -105,7 +105,7 @@ This diagram shows the communication between the AEA and the gym environment
Environment->>AEA: percept
end
AEA->>Environment: close

deactivate AEA
deactivate Environment
</div>
Expand Down
6 changes: 3 additions & 3 deletions packages/fetchai/agents/gym_aea/aea-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
agent_name: gym_aea
author: fetchai
version: 0.4.0
version: 0.5.0
description: The gym aea demos the interaction between a skill containing a RL agent
and a gym connection.
license: Apache-2.0
aea_version: '>=0.5.0, <0.6.0'
fingerprint: {}
fingerprint_ignore_patterns: []
connections:
- fetchai/gym:0.3.0
- fetchai/gym:0.4.0
- fetchai/stub:0.6.0
contracts: []
protocols:
Expand All @@ -17,7 +17,7 @@ protocols:
skills:
- fetchai/error:0.3.0
- fetchai/gym:0.4.0
default_connection: fetchai/gym:0.3.0
default_connection: fetchai/gym:0.4.0
default_ledger: fetchai
ledger_apis: {}
logging_config:
Expand Down
108 changes: 52 additions & 56 deletions packages/fetchai/connections/gym/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@

import asyncio
import logging
import threading
from asyncio import CancelledError
from typing import Dict, Optional, cast
from asyncio.events import AbstractEventLoop
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, cast

import gym

Expand All @@ -39,57 +40,58 @@

"""default 'to' field for Gym envelopes."""
DEFAULT_GYM = "gym"
PUBLIC_ID = PublicId.from_str("fetchai/gym:0.3.0")
PUBLIC_ID = PublicId.from_str("fetchai/gym:0.4.0")


class GymChannel:
"""A wrapper of the gym environment."""

THREAD_POOL_SIZE = 3

def __init__(self, address: Address, gym_env: gym.Env):
"""Initialize a gym channel."""
self.address = address
self.gym_env = gym_env
self._lock = threading.Lock()

self._queues = {} # type: Dict[str, asyncio.Queue]
self._loop: Optional[AbstractEventLoop] = None
self._queue: Optional[asyncio.Queue] = None
self._threaded_pool: ThreadPoolExecutor = ThreadPoolExecutor(
self.THREAD_POOL_SIZE
)

@property
def queue(self) -> asyncio.Queue:
"""Check queue is set and return queue."""
if self._queue is None: # pragma: nocover
raise ValueError("Channel is not connected")
return self._queue

def connect(self) -> Optional[asyncio.Queue]:
async def connect(self) -> None:
"""
Connect an address to the gym.
:return: an asynchronous queue, that constitutes the communication channel.
"""
if self.address in self._queues:
if self._queue: # pragma: nocover
return None
self._loop = asyncio.get_event_loop()
self._queue = asyncio.Queue()

assert len(self._queues.keys()) == 0, "Only one address can register to a gym."
q = asyncio.Queue() # type: asyncio.Queue
self._queues[self.address] = q
return q

def send(self, envelope: Envelope) -> None:
async def send(self, envelope: Envelope) -> None:
"""
Process the envelopes to the gym.
:return: None
"""
sender = envelope.sender
logger.debug("Processing message from {}: {}".format(sender, envelope))
self._decode_envelope(envelope)

def _decode_envelope(self, envelope: Envelope) -> None:
"""
Decode the envelope.
:param envelope: the envelope
:return: None
"""
if envelope.protocol_id == GymMessage.protocol_id:
self.handle_gym_message(envelope)
else:
if envelope.protocol_id != GymMessage.protocol_id:
raise ValueError("This protocol is not valid for gym.")
await self.handle_gym_message(envelope)

def handle_gym_message(self, envelope: Envelope) -> None:
async def _run_in_executor(self, fn, *args):
return await self._loop.run_in_executor(self._threaded_pool, fn, *args)

async def handle_gym_message(self, envelope: Envelope) -> None:
"""
Forward a message to gym.
Expand All @@ -103,7 +105,11 @@ def handle_gym_message(self, envelope: Envelope) -> None:
if gym_message.performative == GymMessage.Performative.ACT:
action = gym_message.action.any
step_id = gym_message.step_id
observation, reward, done, info = self.gym_env.step(action) # type: ignore

observation, reward, done, info = await self._run_in_executor(
self.gym_env.step, action
)

msg = GymMessage(
performative=GymMessage.Performative.PERCEPT,
observation=GymMessage.AnyObject(observation),
Expand All @@ -118,29 +124,34 @@ def handle_gym_message(self, envelope: Envelope) -> None:
protocol_id=GymMessage.protocol_id,
message=msg,
)
self._send(envelope)
await self._send(envelope)
elif gym_message.performative == GymMessage.Performative.RESET:
self.gym_env.reset() # type: ignore
await self._run_in_executor(self.gym_env.reset)
elif gym_message.performative == GymMessage.Performative.CLOSE:
self.gym_env.close() # type: ignore
await self._run_in_executor(self.gym_env.close)

def _send(self, envelope: Envelope) -> None:
async def _send(self, envelope: Envelope) -> None:
"""Send a message.
:param envelope: the envelope
:return: None
"""
destination = envelope.to
self._queues[destination].put_nowait(envelope)
assert envelope.to == self.address, "Invalid destination address"
await self.queue.put(envelope)

def disconnect(self) -> None:
async def disconnect(self) -> None:
"""
Disconnect.
:return: None
"""
with self._lock:
self._queues.pop(self.address, None)
if self._queue is not None:
await self._queue.put(None)
self._queue = None

async def get(self) -> Optional[Envelope]:
"""Get incoming envelope."""
return await self.queue.get()


class GymConnection(Connection):
Expand Down Expand Up @@ -172,7 +183,7 @@ async def connect(self) -> None:
"""
if not self.connection_status.is_connected:
self.connection_status.is_connected = True
self._connection = self.channel.connect()
await self.channel.connect()

async def disconnect(self) -> None:
"""
Expand All @@ -181,12 +192,8 @@ async def disconnect(self) -> None:
:return: None
"""
if self.connection_status.is_connected:
assert self._connection is not None
self.connection_status.is_connected = False
await self._connection.put(None)
self.channel.disconnect()
self._connection = None
self.stop()
await self.channel.disconnect()

async def send(self, envelope: Envelope) -> None:
"""
Expand All @@ -199,7 +206,7 @@ async def send(self, envelope: Envelope) -> None:
raise ConnectionError(
"Connection not established yet. Please use 'connect()'."
)
self.channel.send(envelope)
await self.channel.send(envelope)

async def receive(self, *args, **kwargs) -> Optional["Envelope"]:
"""Receive an envelope."""
Expand All @@ -208,18 +215,7 @@ async def receive(self, *args, **kwargs) -> Optional["Envelope"]:
"Connection not established yet. Please use 'connect()'."
)
try:
assert self._connection is not None
envelope = await self._connection.get()
if envelope is None:
return None
envelope = await self.channel.get()
return envelope
except CancelledError: # pragma: no cover
return None

def stop(self) -> None:
"""
Tear down the connection.
:return: None
"""
self._connection = None
4 changes: 2 additions & 2 deletions packages/fetchai/connections/gym/connection.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: gym
author: fetchai
version: 0.3.0
version: 0.4.0
description: The gym connection wraps an OpenAI gym.
license: Apache-2.0
aea_version: '>=0.5.0, <0.6.0'
fingerprint:
__init__.py: QmWwxj1hGGZNteCvRtZxwtY9PuEKsrWsEmMWCKwiYCdvRR
connection.py: QmU7asAG4fddYm5K8YKLKrrAvg1CY147r9yH6KwE7u3aPJ
connection.py: QmV2REDadG36ogXD3eW4Ms82gUfWdjAQJcNJ6ik48P1CC4
fingerprint_ignore_patterns: []
protocols:
- fetchai/gym:0.3.0
Expand Down
4 changes: 2 additions & 2 deletions packages/hashes.csv
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fetchai/agents/erc1155_client,QmYSpr6ZTEXWheNxUD5Z7dXUS1eZxbe498JuHqwwjKs888
fetchai/agents/erc1155_deployer,QmevosZhB78HTPQAb62v8hLCdtcSqdoSQnKWwKkbXT55L4
fetchai/agents/generic_buyer,QmPAdWvKuw3VFxxQi9NkMPAC4ymAwVSftaYbc5upBTtPtf
fetchai/agents/generic_seller,QmUF18HoArCHf6mLdXjq1zXCuJKY7JwXXSYTdfsWCwPWKn
fetchai/agents/gym_aea,QmWAx6DS9ZNLwabo4cmJamx4nUDPWktSm9vq895zMk6szL
fetchai/agents/gym_aea,QmbEzUY4VeTgaBjQQYkuwDJaoUeSjTwpEyB1EN1GvMvM9d
fetchai/agents/ml_data_provider,QmZ8bArz2gkm8CRenSQMgmUYQo2cHHgUcy5q2rPSp2Ukka
fetchai/agents/ml_model_trainer,QmNtPQewjgUHQaBFxvBLL5MjHvZyTEh2paTBk1pg1cZB9L
fetchai/agents/my_first_aea,QmPEUS71Z2BXchXADVzTjEFLzyi6Pbvn1U6s5hC2mAGcCk
Expand All @@ -18,7 +18,7 @@ fetchai/agents/thermometer_aea,QmXwmPDtZ3Q7t5u3k1ounzDg5rtFD4vsTBTH43UGrmbdvq
fetchai/agents/thermometer_client,QmRMKu9hAzSZQyuSPGg9umQGDRrq1miwrVKo7SFMKDqQV4
fetchai/agents/weather_client,Qmah4VhqdoH6k95xUZk9VREjG4iX5drKvUj2cypiAugoXK
fetchai/agents/weather_station,QmfD44aXS4TmcZFMASb8vDxYK6eNFsQMkSTBmTdcqzGPhc
fetchai/connections/gym,QmbAr8uBUs9g4ZCpbACAvwwb8NLBgYwB6qWcZpFo3MhtpB
fetchai/connections/gym,QmZNEJvgi9n5poswQrHav3fvSv5vA1nbxxkTzWENCoCdrc
fetchai/connections/http_client,QmXQrA6gA4hMEMkMQsEp1MQwDEqRw5BnnqR4gCrP5xqVD2
fetchai/connections/http_server,QmPMSyX1iaWM7mWqFtW8LnSyR9r88RzYbGtyYmopT6tshC
fetchai/connections/ledger,QmezMgaJkk9wbQ4nzURERnNJdrzkQyvV5PiieH6uGbVzc3
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli/test_eject.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def test_eject_commands_positive(self):

self.set_agent_context(agent_name)
cwd = os.path.join(self.t, agent_name)
self.add_item("connection", "fetchai/gym:0.3.0")
self.add_item("connection", "fetchai/gym:0.4.0")
self.add_item("skill", "fetchai/gym:0.4.0")
self.add_item("contract", "fetchai/erc1155:0.6.0")

self.run_cli_command("eject", "connection", "fetchai/gym:0.3.0", cwd=cwd)
self.run_cli_command("eject", "connection", "fetchai/gym:0.4.0", cwd=cwd)
assert "gym" not in os.listdir(
(os.path.join(cwd, "vendor", "fetchai", "connections"))
)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_docs/test_bash_yaml/md_files/bash-gym-skill.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
``` bash
aea fetch fetchai/gym_aea:0.4.0 --alias my_gym_aea
aea fetch fetchai/gym_aea:0.5.0 --alias my_gym_aea
cd my_gym_aea
aea install
```
Expand All @@ -11,8 +11,8 @@ cd my_gym_aea
aea add skill fetchai/gym:0.4.0
```
``` bash
aea add connection fetchai/gym:0.3.0
aea config set agent.default_connection fetchai/gym:0.3.0
aea add connection fetchai/gym:0.4.0
aea config set agent.default_connection fetchai/gym:0.4.0
```
``` bash
aea install
Expand Down
Loading

0 comments on commit 23aa8d6

Please sign in to comment.