Skip to content

Commit

Permalink
feature(pu): add muzero_segment_collector.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dyyoungg committed Sep 5, 2024
1 parent 1d010d3 commit fea98ee
Show file tree
Hide file tree
Showing 6 changed files with 851 additions and 9 deletions.
5 changes: 3 additions & 2 deletions lzero/entry/train_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
# from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroSegmentCollector as Collector # TODO
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect

Expand Down Expand Up @@ -107,7 +108,7 @@ def train_unizero(
batch_size = policy._cfg.batch_size

# TODO: for visualize
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

while True:
# Log buffer memory usage
Expand Down
10 changes: 8 additions & 2 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,12 @@ def __init__(
self.activation = activation
self.embedding_dim = embedding_dim

self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)

if self.observation_shape[1] == 64:
self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)

elif self.observation_shape[1] == 96:
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)

self.sim_norm = SimNorm(simnorm_dim=group_size)

Expand All @@ -365,7 +370,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Important: Transform the output feature plane to the latent state.
# For example, for an Atari feature plane of shape (64, 8, 8),
# flattening results in a size of 4096, which is then transformed to 768.
x = self.last_linear(x.reshape(-1, 64 * 8 * 8))
x = self.last_linear(x.view(x.size(0), -1))

x = x.view(-1, self.embedding_dim)

# NOTE: very important for training stability.
Expand Down
1 change: 1 addition & 0 deletions lzero/worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .alphazero_collector import AlphaZeroCollector
from .alphazero_evaluator import AlphaZeroEvaluator
from .muzero_collector import MuZeroCollector
from .muzero_segment_collector import MuZeroSegmentCollector
from .muzero_evaluator import MuZeroEvaluator
Loading

0 comments on commit fea98ee

Please sign in to comment.