diff --git a/compose/project.py b/compose/project.py index 7c0d19da39..2f3675ffba 100644 --- a/compose/project.py +++ b/compose/project.py @@ -202,17 +202,15 @@ def up(self, running_containers = [] for service in self.get_services(service_names, include_deps=start_deps): if recreate: - for (_, container) in service.recreate_containers( - insecure_registry=insecure_registry, - detach=detach, - do_build=do_build): - running_containers.append(container) + create_func = service.recreate_containers else: - for container in service.start_or_create_containers( - insecure_registry=insecure_registry, - detach=detach, - do_build=do_build): - running_containers.append(container) + create_func = service.start_or_create_containers + + for container in create_func( + insecure_registry=insecure_registry, + detach=detach, + do_build=do_build): + running_containers.append(container) return running_containers diff --git a/compose/service.py b/compose/service.py index ee47142f26..a1c0f9258f 100644 --- a/compose/service.py +++ b/compose/service.py @@ -30,6 +30,7 @@ 'pid', 'privileged', 'restart', + 'volumes_from', ] VALID_NAME_CHARS = '[a-zA-Z0-9]' @@ -175,16 +176,16 @@ def create_container(self, one_off=False, insecure_registry=False, do_build=True, - intermediate_container=None, + previous_container=None, **override_options): """ Create a container for this service. If the image doesn't exist, attempt to pull it. """ + override_options['volumes_from'] = self._get_volumes_from(previous_container) container_options = self._get_container_create_options( override_options, one_off=one_off, - intermediate_container=intermediate_container, ) if (do_build and @@ -213,21 +214,24 @@ def recreate_containers(self, insecure_registry=False, do_build=True, **override do_build=do_build, **override_options) self.start_container(container) - return [(None, container)] - else: - tuples = [] - - for c in containers: - log.info("Recreating %s..." % c.name) - tuples.append(self.recreate_container(c, insecure_registry=insecure_registry, **override_options)) + return [container] - return tuples + return [ + self.recreate_container( + c, + insecure_registry=insecure_registry, + **override_options) + for c in containers + ] def recreate_container(self, container, **override_options): - """Recreate a container. An intermediate container is created so that - the new container has the same name, while still supporting - `volumes-from` the original container. + """Recreate a container. + + The original container is renamed to a temporary name so that data + volumes can be copied to the new container, before the original + container is removed. """ + log.info("Recreating %s..." % container.name) try: container.stop() except APIError as e: @@ -238,29 +242,17 @@ def recreate_container(self, container, **override_options): else: raise - intermediate_container = Container.create( - self.client, - image=container.image, - entrypoint=['/bin/echo'], - command=[], - detach=True, - host_config=create_host_config(volumes_from=[container.id]), - ) - intermediate_container.start() - intermediate_container.wait() - container.remove() - - options = dict(override_options) + # Use a hopefully unique container name by prepending the short id + self.client.rename( + container.id, + '%s_%s' % (container.short_id, container.name)) new_container = self.create_container( do_build=False, - intermediate_container=intermediate_container, - **options - ) + previous_container=container, + **override_options) self.start_container(new_container) - - intermediate_container.remove() - - return (intermediate_container, new_container) + container.remove() + return new_container def start_container_if_stopped(self, container): if container.is_running: @@ -333,7 +325,7 @@ def _get_links(self, link_to_self): links.append((external_link, link_name)) return links - def _get_volumes_from(self, intermediate_container=None): + def _get_volumes_from(self, previous_container=None): volumes_from = [] for volume_source in self.volumes_from: if isinstance(volume_source, Service): @@ -346,8 +338,8 @@ def _get_volumes_from(self, intermediate_container=None): elif isinstance(volume_source, Container): volumes_from.append(volume_source.id) - if intermediate_container: - volumes_from.append(intermediate_container.id) + if previous_container: + volumes_from.append(previous_container.id) return volumes_from @@ -370,7 +362,7 @@ def _get_net(self): return net - def _get_container_create_options(self, override_options, one_off=False, intermediate_container=None): + def _get_container_create_options(self, override_options, one_off=False): container_options = dict( (k, self.options[k]) for k in DOCKER_CONFIG_KEYS if k in self.options) @@ -415,11 +407,13 @@ def _get_container_create_options(self, override_options, one_off=False, interme for key in DOCKER_START_KEYS: container_options.pop(key, None) - container_options['host_config'] = self._get_container_host_config(override_options, one_off=one_off, intermediate_container=intermediate_container) + container_options['host_config'] = self._get_container_host_config( + override_options, + one_off=one_off) return container_options - def _get_container_host_config(self, override_options, one_off=False, intermediate_container=None): + def _get_container_host_config(self, override_options, one_off=False): options = dict(self.options, **override_options) port_bindings = build_port_bindings(options.get('ports') or []) @@ -451,7 +445,7 @@ def _get_container_host_config(self, override_options, one_off=False, intermedia links=self._get_links(link_to_self=one_off), port_bindings=port_bindings, binds=volume_bindings, - volumes_from=self._get_volumes_from(intermediate_container), + volumes_from=options.get('volumes_from'), privileged=privileged, network_mode=self._get_net(), dns=dns, diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py index 678aacdd07..dbb4a609c2 100644 --- a/tests/integration/service_test.py +++ b/tests/integration/service_test.py @@ -249,25 +249,20 @@ def test_recreate_containers(self): num_containers_before = len(self.client.containers(all=True)) service.options['environment']['FOO'] = '2' - tuples = service.recreate_containers() - self.assertEqual(len(tuples), 1) - - intermediate_container = tuples[0][0] - new_container = tuples[0][1] - self.assertEqual(intermediate_container.dictionary['Config']['Entrypoint'], ['/bin/echo']) + new_container, = service.recreate_containers() self.assertEqual(new_container.dictionary['Config']['Entrypoint'], ['sleep']) self.assertEqual(new_container.dictionary['Config']['Cmd'], ['300']) self.assertIn('FOO=2', new_container.dictionary['Config']['Env']) self.assertEqual(new_container.name, 'composetest_db_1') self.assertEqual(new_container.inspect()['Volumes']['/etc'], volume_path) - self.assertIn(intermediate_container.id, new_container.dictionary['HostConfig']['VolumesFrom']) + self.assertIn(old_container.id, new_container.dictionary['HostConfig']['VolumesFrom']) self.assertEqual(len(self.client.containers(all=True)), num_containers_before) self.assertNotEqual(old_container.id, new_container.id) self.assertRaises(APIError, self.client.inspect_container, - intermediate_container.id) + old_container.id) def test_recreate_containers_when_containers_are_stopped(self): service = self.create_service( diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index b9f968db10..583f72ef0b 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -86,7 +86,7 @@ def test_get_volumes_from_container(self): self.assertEqual(service._get_volumes_from(), [container_id]) - def test_get_volumes_from_intermediate_container(self): + def test_get_volumes_from_previous_container(self): container_id = 'aabbccddee' service = Service('test', image='foo') container = mock.Mock(id=container_id, spec=Container, image='foo') @@ -263,6 +263,20 @@ def test_create_container_from_insecure_registry( mock_log.info.assert_called_once_with( 'Pulling foo (someimage:sometag)...') + @mock.patch('compose.service.Container', autospec=True) + def test_recreate_container(self, _): + mock_container = mock.create_autospec(Container) + service = Service('foo', client=self.mock_client, image='someimage') + new_container = service.recreate_container(mock_container) + + mock_container.stop.assert_called_once_with() + self.mock_client.rename.assert_called_once_with( + mock_container.id, + '%s_%s' % (mock_container.short_id, mock_container.name)) + + new_container.start.assert_called_once_with() + mock_container.remove.assert_called_once_with() + def test_parse_repository_tag(self): self.assertEqual(parse_repository_tag("root"), ("root", "")) self.assertEqual(parse_repository_tag("root:tag"), ("root", "tag"))