Skip to content

Commit

Permalink
Python: Reduce memory usage when restoring snapshot
Browse files Browse the repository at this point in the history
Previously we paid for two copies of the snapshot memory: one copy in the wasm
linear memory itself and a second copy in `BUNDLE_MEMORY_SNAPSHOT`. This ensures
that we never have more memory than one copy of the linear memory heap by
copying the memory directly from the snapshot to the linear memory. We also
release the C++ memory when we are done with it.
  • Loading branch information
hoodmane committed Mar 21, 2024
1 parent 371a876 commit 35d5ea1
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 62 deletions.
7 changes: 6 additions & 1 deletion src/pyodide/internal/metadata.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import { default as MetadataReader } from "pyodide-internal:runtime-generated/metadata";
export { default as LOCKFILE } from "pyodide-internal:generated/pyodide-lock.json";
import { default as PYODIDE_BUCKET } from "pyodide-internal:generated/pyodide-bucket.json";
import { default as ArtifactBundler } from "pyodide-internal:artifacts";

export const IS_WORKERD = MetadataReader.isWorkerd();
export const IS_TRACING = MetadataReader.isTracing();
export const WORKERD_INDEX_URL = PYODIDE_BUCKET.PYODIDE_PACKAGE_BUCKET_URL;
export const REQUIREMENTS = MetadataReader.getRequirements();
export const MAIN_MODULE_NAME = MetadataReader.getMainModule();
export const BUNDLE_MEMORY_SNAPSHOT = MetadataReader.getMemorySnapshot();
export const MEMORY_SNAPSHOT_READER = MetadataReader.hasMemorySnapshot()
? MetadataReader
: ArtifactBundler.hasMemorySnapshot()
? ArtifactBundler
: undefined;
58 changes: 29 additions & 29 deletions src/pyodide/internal/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
} from "pyodide-internal:setupPackages";
import { default as TarReader } from "pyodide-internal:packages_tar_reader";
import processScriptImports from "pyodide-internal:process_script_imports.py";
import { BUNDLE_MEMORY_SNAPSHOT } from "pyodide-internal:metadata";
import { MEMORY_SNAPSHOT_READER } from "pyodide-internal:metadata";

/**
* This file is a simplified version of the Pyodide loader:
Expand Down Expand Up @@ -38,7 +38,8 @@ import pyodideWasmModule from "pyodide-internal:generated/pyodide.asm.wasm";
*/
import stdlib from "pyodide-internal:generated/python_stdlib.zip";

const SHOULD_UPLOAD_SNAPSHOT = ArtifactBundler.isEnabled() || ArtifactBundler.isEwValidating();
const SHOULD_UPLOAD_SNAPSHOT =
ArtifactBundler.isEnabled() || ArtifactBundler.isEwValidating();
const DEDICATED_SNAPSHOT = true;

/**
Expand All @@ -47,7 +48,7 @@ const DEDICATED_SNAPSHOT = true;
* which is quite slow. Startup with snapshot is 3-5 times faster than without
* it.
*/
let MEMORY = undefined;
let READ_MEMORY = undefined;
/**
* Record the dlopen handles that are needed by the MEMORY.
*/
Expand Down Expand Up @@ -234,7 +235,7 @@ function getEmscriptenSettings(lockfile, indexURL) {
// important because the file system lives outside of linear memory.
preRun: [prepareFileSystem, setEnv, preloadDynamicLibs],
instantiateWasm,
noInitialRun: !!MEMORY, // skip running main() if we have a snapshot
noInitialRun: !!READ_MEMORY, // skip running main() if we have a snapshot
API, // Pyodide requires we pass this in.
};
}
Expand Down Expand Up @@ -278,16 +279,8 @@ async function instantiateEmscriptenModule(emscriptenSettings) {
async function prepareWasmLinearMemory(Module) {
// Note: if we are restoring from a snapshot, runtime is not initialized yet.
mountLib(Module, SITE_PACKAGES_INFO);
if (MEMORY) {
if (!(MEMORY instanceof Uint8Array)) {
throw new TypeError("Expected MEMORY to be a Uint8Array");
}
// resize linear memory to fit our snapshot. I think `growMemory` only
// exists if `-sALLOW_MEMORY_GROWTH` is passed to the linker but we'll
// probably always do that.
Module.growMemory(MEMORY.byteLength);
// restore memory from snapshot
Module.HEAP8.set(MEMORY);
if (READ_MEMORY) {
READ_MEMORY(Module);
// Don't call adjustSysPath here: it was called in the other branch when we
// were creating the snapshot so the outcome of that is already baked in.
return;
Expand Down Expand Up @@ -440,13 +433,22 @@ function encodeSnapshot(heap, dsoJSON) {
/**
* Decode heap and dsoJSON from the memory snapshot artifact we downloaded
*/
function decodeSnapshot(memorySnapshot) {
const uint32View = new Uint32Array(memorySnapshot);
const snapshotOffset = uint32View[0];
const jsonLength = uint32View[1];
const jsonView = new Uint8Array(memorySnapshot, 8, jsonLength);
DSO_METADATA = JSON.parse(new TextDecoder().decode(jsonView));
MEMORY = new Uint8Array(memorySnapshot, snapshotOffset);
function decodeSnapshot() {
const buf = new Uint32Array(2);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(0, buf);
const snapshotOffset = buf[0];
const snapshotSize = MEMORY_SNAPSHOT_READER.getMemorySnapshotSize() - snapshotOffset;
const jsonLength = buf[1];
const jsonBuf = new Uint8Array(jsonLength);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(8, jsonBuf);
DSO_METADATA = JSON.parse(new TextDecoder().decode(jsonBuf));
READ_MEMORY = function(Module) {
// resize linear memory to fit our snapshot.
Module.growMemory(snapshotSize);
// restore memory from snapshot
MEMORY_SNAPSHOT_READER.readMemorySnapshot(snapshotOffset, Module.HEAP8);
MEMORY_SNAPSHOT_READER.disposeMemorySnapshot();
}
}

/**
Expand Down Expand Up @@ -492,24 +494,22 @@ function simpleRunPython(emscriptenModule, code) {
let TEST_SNAPSHOT = undefined;
(function () {
// Lookup memory snapshot from artifact store.
const memorySnapshot = BUNDLE_MEMORY_SNAPSHOT || ArtifactBundler.getMemorySnapshot();
if (!memorySnapshot) {
if (!MEMORY_SNAPSHOT_READER) {
// snapshots are disabled or there isn't one yet
return;
}
if (memorySnapshot.constructor.name !== "ArrayBuffer") {
throw new TypeError("Expected snapshot to be an ArrayBuffer");
}

// Simple sanity check to ensure this snapshot isn't corrupted.
//
// TODO(later): we need better detection when this is corrupted. Right now the isolate will
// just die.
if (memorySnapshot.byteLength <= 100) {
TEST_SNAPSHOT = memorySnapshot;
const snapshotSize = MEMORY_SNAPSHOT_READER.getMemorySnapshotSize();
if (snapshotSize <= 100) {
TEST_SNAPSHOT = new Uint8Array(snapshotSize);
MEMORY_SNAPSHOT_READER.readMemorySnapshot(0, TEST_SNAPSHOT);
return;
}
decodeSnapshot(memorySnapshot);
decodeSnapshot();
})();

export async function loadPyodide(lockfile, indexURL) {
Expand Down
6 changes: 2 additions & 4 deletions src/pyodide/internal/setupPackages.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ const STDLIB_PACKAGES = Object.values(LOCKFILE.packages)
.filter(({ install_dir }) => install_dir === "stdlib")
.map(({ name }) => canonicalizePackageName(name));


/**
* This stitches together the view of the site packages directory. Each
* requirement corresponds to a folder in the original tar file. For each
Expand Down Expand Up @@ -186,6 +185,5 @@ function addPackageToLoad(lockfile, name, toLoad) {

export { REQUIREMENTS };
export const TRANSITIVE_REQUIREMENTS = getTransitiveRequirements();
export const [SITE_PACKAGES_INFO, SITE_PACKAGES_SO_FILES, USE_LOAD_PACKAGE] = buildSitePackages(
TRANSITIVE_REQUIREMENTS,
);
export const [SITE_PACKAGES_INFO, SITE_PACKAGES_SO_FILES, USE_LOAD_PACKAGE] =
buildSitePackages(TRANSITIVE_REQUIREMENTS);
10 changes: 8 additions & 2 deletions src/pyodide/python-entrypoint-helper.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// This file is a BUILTIN module that provides the actual implementation for the
// python-entrypoint.js USER module.

import { loadPyodide, uploadArtifacts, getMemoryToUpload } from "pyodide-internal:python";
import {
loadPyodide,
uploadArtifacts,
getMemoryToUpload,
} from "pyodide-internal:python";
import { enterJaegerSpan } from "pyodide-internal:jaeger";
import {
REQUIREMENTS,
Expand Down Expand Up @@ -109,7 +113,9 @@ function getMainModule() {
mainModulePromise = (async function () {
const pyodide = await getPyodide();
await setupPackages(pyodide);
return enterJaegerSpan("pyimport_main_module", () => pyimportMainModule(pyodide));
return enterJaegerSpan("pyimport_main_module", () =>
pyimportMainModule(pyodide),
);
})();
return mainModulePromise;
});
Expand Down
28 changes: 28 additions & 0 deletions src/workerd/api/pyodide/pyodide.c++
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "pyodide.h"
#include "kj/array.h"
#include "kj/common.h"

namespace workerd::api::pyodide {

Expand Down Expand Up @@ -56,6 +58,32 @@ int PyodideMetadataReader::read(jsg::Lock& js, int index, int offset, kj::Array<
return toCopy;
}

int PyodideMetadataReader::readMemorySnapshot(int offset, kj::Array<kj::byte> buf) {
int snapshotSize = memorySnapshot.size();
if (offset >= snapshotSize || offset < 0) {
return 0;
}
int toCopy = buf.size();
if (snapshotSize - offset < toCopy) {
toCopy = snapshotSize - offset;
}
memcpy(buf.begin(), &memorySnapshot[0] + offset, toCopy);
return toCopy;
}

int ArtifactBundler::readMemorySnapshot(int offset, kj::Array<kj::byte> buf) {
int snapshotSize = existingSnapshot.size();
if (offset >= snapshotSize || offset < 0) {
return 0;
}
int toCopy = buf.size();
if (snapshotSize - offset < toCopy) {
toCopy = snapshotSize - offset;
}
memcpy(buf.begin(), &existingSnapshot[0] + offset, toCopy);
return toCopy;
}

jsg::Ref<PyodideMetadataReader> makePyodideMetadataReader(Worker::Reader conf) {
auto modules = conf.getModules();
auto mainModule = kj::str(modules.begin()->getName());
Expand Down
77 changes: 51 additions & 26 deletions src/workerd/api/pyodide/pyodide.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "kj/array.h"
#include <kj/common.h>
#include <pyodide/generated/pyodide_extra.capnp.h>
#include <pyodide/pyodide.capnp.h>
Expand All @@ -23,6 +24,13 @@ class PackagesTarReader : public jsg::Object {
}
};

inline kj::Array<kj::byte> maybeArrayToArray(kj::Maybe<kj::Array<kj::byte>> maybeArray) {
KJ_IF_SOME(array, maybeArray) {
return kj::mv(array);
}
return kj::heapArray<kj::byte>(0);
}

// A function to read a segment of the tar file into a buffer
// Set up this way to avoid copying files that aren't accessed.
class PyodideMetadataReader : public jsg::Object {
Expand All @@ -33,16 +41,16 @@ class PyodideMetadataReader : public jsg::Object {
kj::Array<kj::String> requirements;
bool isWorkerdFlag;
bool isTracingFlag;
kj::Maybe<kj::Array<kj::byte>> memorySnapshot;
kj::Array<kj::byte> memorySnapshot;

public:
PyodideMetadataReader(kj::String mainModule, kj::Array<kj::String> names,
kj::Array<kj::Array<kj::byte>> contents, kj::Array<kj::String> requirements,
bool isWorkerd, bool isTracing,
kj::Maybe<kj::Array<kj::byte>> memorySnapshot)
kj::Maybe<kj::Array<kj::byte>> snapshot)
: mainModule(kj::mv(mainModule)), names(kj::mv(names)), contents(kj::mv(contents)),
requirements(kj::mv(requirements)), isWorkerdFlag(isWorkerd), isTracingFlag(isTracing),
memorySnapshot(kj::mv(memorySnapshot)) {}
memorySnapshot(maybeArrayToArray(kj::mv(memorySnapshot))) {}

bool isWorkerd() {
return this->isWorkerdFlag;
Expand All @@ -56,10 +64,6 @@ class PyodideMetadataReader : public jsg::Object {
return kj::str(this->mainModule);
}

kj::Maybe<kj::Array<kj::byte>> getMemorySnapshot() {
return kj::mv(memorySnapshot);
}

kj::Array<jsg::JsRef<jsg::JsString>> getNames(jsg::Lock& js);

kj::Array<jsg::JsRef<jsg::JsString>> getRequirements(jsg::Lock& js);
Expand All @@ -68,6 +72,18 @@ class PyodideMetadataReader : public jsg::Object {

int read(jsg::Lock& js, int index, int offset, kj::Array<kj::byte> buf);

bool hasMemorySnapshot() {
return memorySnapshot.size() > 0;
}
int getMemorySnapshotSize() {
return memorySnapshot.size();
}

void disposeMemorySnapshot() {
memorySnapshot = kj::heapArray<kj::byte>(0);
}
int readMemorySnapshot(int offset, kj::Array<kj::byte> buf);

JSG_RESOURCE_TYPE(PyodideMetadataReader) {
JSG_METHOD(isWorkerd);
JSG_METHOD(isTracing);
Expand All @@ -76,7 +92,10 @@ class PyodideMetadataReader : public jsg::Object {
JSG_METHOD(getNames);
JSG_METHOD(getSizes);
JSG_METHOD(read);
JSG_METHOD(getMemorySnapshot);
JSG_METHOD(hasMemorySnapshot);
JSG_METHOD(getMemorySnapshotSize);
JSG_METHOD(readMemorySnapshot);
JSG_METHOD(disposeMemorySnapshot);
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
Expand All @@ -101,22 +120,24 @@ class ArtifactBundler : public jsg::Object {

ArtifactBundler(kj::Maybe<kj::Array<kj::byte>> existingSnapshot,
kj::Function<kj::Promise<bool>(kj::Array<kj::byte> snapshot)> uploadMemorySnapshotCb)
: storedSnapshot(kj::none),
existingSnapshot(kj::mv(existingSnapshot)),
:
storedSnapshot(kj::none),
existingSnapshot(maybeArrayToArray(kj::mv(existingSnapshot))),
uploadMemorySnapshotCb(kj::mv(uploadMemorySnapshotCb)),
hasUploaded(false),
isValidating(false) {};
isValidating(false)
{};

ArtifactBundler(kj::Maybe<kj::Array<kj::byte>> existingSnapshot)
: storedSnapshot(kj::none),
existingSnapshot(kj::mv(existingSnapshot)),
existingSnapshot(maybeArrayToArray(kj::mv(existingSnapshot))),
uploadMemorySnapshotCb(kj::none),
hasUploaded(false),
isValidating(false) {};

ArtifactBundler(bool isValidating = false)
: storedSnapshot(kj::none),
existingSnapshot(kj::none),
existingSnapshot(kj::heapArray<kj::byte>(0)),
uploadMemorySnapshotCb(kj::none),
hasUploaded(false),
isValidating(isValidating) {};
Expand Down Expand Up @@ -144,21 +165,24 @@ class ArtifactBundler : public jsg::Object {
storedSnapshot = kj::mv(snapshot);
}

jsg::Optional<kj::Array<kj::byte>> getMemorySnapshot(jsg::Lock& js) {
KJ_IF_SOME(val, existingSnapshot) {
return kj::mv(val);
}
return kj::none;
}

bool isEnabled() {
return uploadMemorySnapshotCb != kj::none;
}

bool hasMemorySnapshot() {
return existingSnapshot != kj::none;
return existingSnapshot.size() > 0;
}

int getMemorySnapshotSize() {
return existingSnapshot.size();
}

int readMemorySnapshot(int offset, kj::Array<kj::byte> buf);
void disposeMemorySnapshot() {
existingSnapshot = kj::heapArray<kj::byte>(0);
}


// Determines whether this ArtifactBundler was created inside the validator.
bool isEwValidating() {
return isValidating;
Expand All @@ -169,14 +193,15 @@ class ArtifactBundler : public jsg::Object {
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
KJ_IF_SOME(snapshot, existingSnapshot) {
tracker.trackFieldWithSize("snapshot", snapshot.size());
}
tracker.trackFieldWithSize("snapshot", existingSnapshot.size());
}

JSG_RESOURCE_TYPE(ArtifactBundler) {
JSG_METHOD(uploadMemorySnapshot);
JSG_METHOD(getMemorySnapshot);
JSG_METHOD(hasMemorySnapshot);
JSG_METHOD(getMemorySnapshotSize);
JSG_METHOD(readMemorySnapshot);
JSG_METHOD(disposeMemorySnapshot);
JSG_METHOD(isEnabled);
JSG_METHOD(isEwValidating);
JSG_METHOD(storeMemorySnapshot);
Expand All @@ -185,7 +210,7 @@ class ArtifactBundler : public jsg::Object {
private:
// A memory snapshot of the state of the Python interpreter after initialisation. Used to speed
// up cold starts.
kj::Maybe<kj::Array<kj::byte>> existingSnapshot;
kj::Array<kj::byte> existingSnapshot;
kj::Maybe<kj::Function<kj::Promise<bool>(kj::Array<kj::byte> snapshot)>> uploadMemorySnapshotCb;
bool hasUploaded;
bool isValidating;
Expand Down

0 comments on commit 35d5ea1

Please sign in to comment.