From 365f070619c12f2b4f6093ab8b145c3c9db2d4bd Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 24 Jan 2024 00:47:03 -0600 Subject: [PATCH] Unify stack reduction Currently there are two codepaths for resuming computations: `k()` and `k.tail()` where the first hase a return value, but is recursive and is subject to stack overflow, while `k.tail()` is not allowed to have a return value, but has a more resilient stack. There are a number of problems with this. The first is the cognitive overhead on the API side. You have to constantly remember that you need to call `k.tail()` instead of `k()` or you might just introduce a stack overflow bug. The second is the complexity of the implementation. In order to make the `k.tail()` work, we needed to manage the call stack to be weirdly re-entrant with a separate `Stack` object, and a check to see if the stack was currently reducing (Stack.reducing). This also comes with a _memory complexity_ penalty because we're creating a lot of intermediate objects in the form of the stack itself, and also in the form of the thunks that we are pushing and popping off the stack. There was also memory overhead in creating four continuation functions (`k()`, `k.tail()`, `reject()` and `reject.tail()`) instead of just two. This avoids all that complexity by making the entire `evaluate()` method into a single while loop that only takes up one native JavaScript stack frame to execute no matter how deeply nested the computation you want to run is. However, this does require a key breaking change: `k()` is a Computation, not a callback === the `k()` passed to `shift()` is of type `(value) => Compuation` instead of `(value) => T`. This means that in order to use it, you have to either `yield* k()` or re-enter a reduction with evaluate. In other words, reduction stacks must be explicitly created with `evaluate()` in order to use them, for example, from a `setTimeot()`. Otherwise, they can be used seamlessly if you are already inside of a delimited continuation: ```ts let result = evaluate(function* run() { let sum = 0; for (let i = 0; i < 100_000; i++) { sum += yield* shift<1>(function* incr(k) { return yield* k(1); }); } return sum; }); console.dir({ result }); ``` the above will print => `{ result: 10000 }`; The tradeoff for this breaking change is drastically reduced complexity of implementation. --- mod.ts | 263 ++++++++++++++++------------------------- t/continuation.test.ts | 67 ++++++----- t/error.test.ts | 48 ++------ 3 files changed, 149 insertions(+), 229 deletions(-) diff --git a/mod.ts b/mod.ts index 20fae10..038a6af 100644 --- a/mod.ts +++ b/mod.ts @@ -1,190 +1,135 @@ // deno-lint-ignore-file no-explicit-any -interface Thunk { - method: "next" | "throw"; - iterator: Iterator; - value?: unknown | Error; +export interface Computation { + [Symbol.iterator](): Iterator; } -interface Stack { - reducing: boolean; - push(...thunks: Thunk[]): number; - pop(): Thunk | undefined; - value?: any; +export interface Continuation { + (value: T extends void ? void : T): Computation; } -export interface Computation { - [Symbol.iterator](): Iterator; -} +export type Result = { ok: true; value: T } | { ok?: false; error: Error }; -export interface Continuation { - (value: T): R; - tail(value: T): void; -} +export type Control = + | { + type: "shift"; + block(k: Continuation, reject: Continuation): Computation; + } + | { + type: "reset"; + block(): Computation; + } + | { + type: "resume"; + result: Result; + iter: Iterator; + }; -export function* reset(block: () => Computation): Computation { +export function* reset(block: () => Computation): Computation { return yield { type: "reset", block }; } -export function* shift( - block: (resolve: Continuation, reject: Continuation) => Computation, +export function* shift( + block: ( + k: Continuation, + reject: Continuation, + ) => Computation, ): Computation { return yield { type: "shift", block }; } -function createStack(): Stack { - let list: Thunk[] = []; - return { - reducing: false, - push(...thunks: Thunk[]): number { - return list.push(...thunks); - }, - pop(): Thunk | undefined { - return list.pop(); - }, - }; -} - -export function evaluate(iterator: () => Computation): T { - let stack = createStack(); - stack.push({ - method: "next", - iterator: iterator()[Symbol.iterator](), - }); - return reduce(stack); -} - -function reduce(stack: Stack): T { - try { - stack.reducing = true; - for (let current = stack.pop(); current; current = stack.pop()) { - try { - let next = getNext(current); - stack.value = next.value; - - if (!next.done) { - let control = next.value; - if (control.type === "reset") { - stack.push( - { - ...current, - method: "next", - get value() { - return stack.value; - }, - }, - { - method: "next", - iterator: control.block()[Symbol.iterator](), - }, - ); - } else { - let thunk = current; - let resolve = oneshot((value: unknown) => { - stack.push({ - method: "next", - iterator: thunk.iterator, - value, - }); - return reduce(stack); - }); - resolve.tail = oneshot((value: unknown) => { - stack.push({ - method: "next", - iterator: thunk.iterator, - value, - }); - if (!stack.reducing) { - reduce(stack); - } - }); - let reject = oneshot((error: Error) => { - stack.push({ - method: "throw", - iterator: thunk.iterator, - value: error, - }); - return reduce(stack); - }); - reject.tail = oneshot((error: Error) => { - stack.push({ - method: "throw", - iterator: thunk.iterator, - value: error, - }); - - if (!stack.reducing) { - reduce(stack); - } - }); - - stack.push({ - method: "next", - iterator: control.block(resolve, reject)[Symbol.iterator](), - value: void 0, - }); - } +export function evaluate(block: () => Computation) { + let stack: Iterator[] = []; + let iter = block()[Symbol.iterator](); + + let current = $next(undefined); + while (true) { + let result = safe(() => current(iter)); + if (!result.ok) { + if (stack.length > 0) { + iter = stack.pop()!; + current = $throw(result.error); + } else { + throw result.error; + } + } else { + let next = result.value; + if (next.done) { + if (stack.length > 0) { + iter = stack.pop()!; + current = $next(next.value); + } else { + return next.value as T; } - } catch (error) { - let top = stack.pop(); - if (top) { - stack.push({ ...top, method: "throw", value: error }); + } else { + const control = next.value; + if (control.type === "reset") { + stack.push(iter); + iter = control.block()[Symbol.iterator](); + } else if (control.type === "shift") { + const continuation = iter; + + let resolve = oneshot((value: unknown) => ({ + type: "resume", + iter: continuation, + result: { ok: true, value }, + })); + + let reject = oneshot((error: Error) => { + return { type: "resume", iter: continuation, result: { error } }; + }); + iter = control.block(resolve, reject)[Symbol.iterator](); } else { - throw error; + iter = control.iter; + let { result } = control; + current = result.ok ? $next(result.value) : $throw(result.error); } } } - } finally { - stack.reducing = false; } - - return stack.value; } -function getNext(thunk: Thunk) { - let { iterator } = thunk; - if (thunk.method === "next") { - return iterator.next(thunk.value); - } else { - let value = thunk.value as Error; - if (iterator.throw) { - return iterator.throw(value); - } else { - throw value; - } - } +function $next(value: unknown) { + return (iter: Iterator) => iter.next(value); } -function oneshot(fn: (t: T) => R): Continuation { - let continued = false; - let failure: { error: unknown }; - let result: any; - - return ((value) => { - if (!continued) { - continued = true; - try { - return (result = fn(value)); - } catch (error) { - failure = { error }; - throw error; - } - } else if (failure) { - throw failure.error; +function $throw(error: Error): ReturnType { + return (iter) => { + if (iter.throw) { + return iter.throw(error); } else { - return result; + throw error; } - }) as Continuation; + }; } -export type K = Continuation; - -export type Control = - | { - type: "shift"; - block(resolve: Continuation, reject: Continuation): Computation; +function safe(fn: () => T): Result { + try { + return { ok: true, value: fn() }; + } catch (error) { + return { error }; } - | { - type: "reset"; - block(): Computation; +} + +function oneshot( + fn: (arg: TArg) => Control, +): (arg: TArg) => Computation { + let computation: Computation | undefined = undefined; + + return function k(arg: TArg) { + if (!computation) { + let control = fn(arg); + let iterator: Iterator = { + next() { + iterator.next = () => ({ + done: true, + value: undefined as unknown as T, + }); + return { done: false, value: control }; + }, + }; + computation = { [Symbol.iterator]: () => iterator }; + } + return computation; }; +} diff --git a/t/continuation.test.ts b/t/continuation.test.ts index 17579b1..f782ed3 100644 --- a/t/continuation.test.ts +++ b/t/continuation.test.ts @@ -1,6 +1,6 @@ import { describe, it } from "./bdd.ts"; import { assertEquals, assertThrows } from "./asserts.ts"; -import { Computation, evaluate, K, reset, shift } from "../mod.ts"; +import { Computation, Continuation, evaluate, reset, shift } from "../mod.ts"; describe("continuation", () => { it("evaluates synchronous values synchronously", () => { @@ -25,50 +25,54 @@ describe("continuation", () => { it("each continuation point function only resumes once", () => { let beginning, middle, end; - let next = evaluate>>(function* () { - beginning = true; - middle = yield* shift(function* (k) { - return k; - }); - end = yield* shift(function* (k) { - return k; - }); - return end * 10; - }); + let next = evaluate>>( + function* () { + beginning = true; + middle = yield* shift(function* (k) { + return k; + }); + end = yield* shift(function* (k) { + return k; + }); + return end * 10; + }, + ); assertEquals(true, beginning); assertEquals(undefined, middle); - let last = next("reached middle"); + let last = evaluate<(val: number) => Computation>(() => + next("reached middle") + ); assertEquals("reached middle", middle); assertEquals(undefined, end); assertEquals("function", typeof last); - let second = next("continue"); + let second = evaluate(() => next("continue")); assertEquals("reached middle", middle); assertEquals(undefined, end); - assertEquals(last, second); + assertEquals(void 0, second); - let result = last(10); + let result = evaluate(() => last(10)); assertEquals(10, end); assertEquals(100, result); - let result2 = last(100); + let result2 = evaluate(() => last(100)); assertEquals(10, end); - assertEquals(100, result2); + assertEquals(undefined, result2); }); it("each continuation point only fails once", () => { let bing = 0; - let boom = evaluate>(function* () { + let boom = evaluate>(function* () { yield* shift(function* (k) { return k; }); throw new Error(`bing ${++bing}`); }); - assertThrows(boom, Error, "bing 1"); - assertThrows(boom, Error, "bing 1"); + assertThrows(() => evaluate(() => boom()), Error, "bing 1"); + assertEquals(undefined, evaluate(() => boom())); }); it("can exit early from recursion", () => { @@ -89,31 +93,32 @@ describe("continuation", () => { }); it("returns the value of the following shift point when continuing ", () => { - let { k } = evaluate<{ k: K }>(function* () { + let { k } = evaluate<{ k: Continuation }>(function* () { let k = yield* reset(function* () { - let result = yield* shift(function* (k) { + let result = yield* shift(function* (k) { return k; }); - yield* shift(function* () { + + return yield* shift(function* () { return result * 2; }); }); return { k }; }); assertEquals("function", typeof k); - assertEquals(10, k(5)); + assertEquals(10, evaluate(() => k(5))); }); it("can withstand stack overflow", () => { - function* run() { + let result = evaluate(function* run() { + let sum = 0; for (let i = 0; i < 100_000; i++) { - yield* shift(function* (k) { - k.tail(1); + sum += yield* shift<1>(function* incr(k) { + return yield* k(1); }); } - } - - evaluate(run); - assertEquals(true, true); + return sum; + }); + assertEquals(result, 100_000); }); }); diff --git a/t/error.test.ts b/t/error.test.ts index 6d89470..1e03c0d 100644 --- a/t/error.test.ts +++ b/t/error.test.ts @@ -1,4 +1,4 @@ -import { evaluate, K, reset, shift } from "../mod.ts"; +import { Continuation, evaluate, reset, shift } from "../mod.ts"; import { assertEquals, assertObjectMatch, assertThrows } from "./asserts.ts"; import { describe, it } from "./bdd.ts"; @@ -60,7 +60,7 @@ describe("error", () => { }); it("can be raised programatically from a shift", () => { - let reject = evaluate>(function* () { + let reject = evaluate>(function* () { try { yield* shift(function* (_, reject) { return reject; @@ -72,50 +72,20 @@ describe("error", () => { let error = new Error("boom!"); - assertEquals({ caught: true, error }, reject(error)); - }); - - it("reject.tail should throw exception", () => { - assertThrows( - () => - evaluate>(function* () { - yield* shift(function* (_, reject) { - return reject.tail(new Error("boom!")); - }); - }), - Error, - "boom!", - ); + assertEquals({ caught: true, error }, evaluate(() => reject(error))); }); it("blows up the caller if it is not caught inside the continuation", () => { - let reject = evaluate>(function* () { + let reject = evaluate>(function* () { yield* shift(function* (_, reject) { return reject; }); }); - assertThrows(() => reject(new Error("boom!")), Error, "boom!"); - }); - - it("can tail resume even when a stack blows up", () => { - let things = evaluate<{ one: K; two: K }>(function* () { - let one = yield* reset(function* () { - yield* shift(function* (resolve) { - return resolve.tail; - }); - throw new Error("boom 1!"); - }); - let two = yield* reset(function* () { - yield* shift(function* (resolve) { - return resolve.tail; - }); - throw new Error("boom 2!"); - }); - return { one, two }; - }); - - assertThrows(() => things.one(), Error, "boom 1!"); - assertThrows(() => things.two(), Error, "boom 2!"); + assertThrows( + () => evaluate(() => reject(new Error("boom!"))), + Error, + "boom!", + ); }); });