diff --git a/packages/rivetkit/scripts/dump-openapi.ts b/packages/rivetkit/scripts/dump-openapi.ts index e2aa85414..bf6fb7dda 100644 --- a/packages/rivetkit/scripts/dump-openapi.ts +++ b/packages/rivetkit/scripts/dump-openapi.ts @@ -3,7 +3,12 @@ import { resolve } from "node:path"; import { createFileSystemOrMemoryDriver } from "@/drivers/file-system/mod"; import type { ManagerDriver } from "@/manager/driver"; import { createManagerRouter } from "@/manager/router"; -import { type RegistryConfig, RegistryConfigSchema, setup } from "@/mod"; +import { + createClientWithDriver, + type RegistryConfig, + RegistryConfigSchema, + setup, +} from "@/mod"; import { type RunnerConfig, RunnerConfigSchema } from "@/registry/run-config"; import { VERSION } from "@/utils"; @@ -34,11 +39,14 @@ function main() { getOrCreateInspectorAccessToken: unimplemented, }; + const client = createClientWithDriver(managerDriver); + const { openapi } = createManagerRouter( registryConfig, driverConfig, managerDriver, - undefined, + driverConfig.driver!, + client, ); const openApiDoc = openapi.getOpenAPIDocument({ diff --git a/packages/rivetkit/src/driver-test-suite/mod.ts b/packages/rivetkit/src/driver-test-suite/mod.ts index 5042188d5..dd26f8d89 100644 --- a/packages/rivetkit/src/driver-test-suite/mod.ts +++ b/packages/rivetkit/src/driver-test-suite/mod.ts @@ -6,7 +6,12 @@ import { describe } from "vitest"; import type { Transport } from "@/client/mod"; import { configureInspectorAccessToken } from "@/inspector/utils"; import { createManagerRouter } from "@/manager/router"; -import type { DriverConfig, Registry, RunConfig } from "@/mod"; +import { + createClientWithDriver, + type DriverConfig, + type Registry, + type RunConfig, +} from "@/mod"; import { RunnerConfigSchema } from "@/registry/run-config"; import { getPort } from "@/test/mod"; import { logger } from "./log"; @@ -210,12 +215,14 @@ export async function createTestRuntime( // Create router const managerDriver = driver.manager(registry.config, config); + const client = createClientWithDriver(managerDriver); configureInspectorAccessToken(config, managerDriver); const { router } = createManagerRouter( registry.config, config, managerDriver, - undefined, + driver, + client, ); // Inject WebSocket diff --git a/packages/rivetkit/src/drivers/default.ts b/packages/rivetkit/src/drivers/default.ts index c94e86b59..829e3e983 100644 --- a/packages/rivetkit/src/drivers/default.ts +++ b/packages/rivetkit/src/drivers/default.ts @@ -29,7 +29,7 @@ export function chooseDefaultDriver(runConfig: RunnerConfig): DriverConfig { msg: "using rivet engine driver", endpoint: runConfig.endpoint, }); - return createEngineDriver(runConfig); + return createEngineDriver(); } loggerWithoutContext().debug({ msg: "using default file system driver" }); diff --git a/packages/rivetkit/src/drivers/engine/actor-driver.ts b/packages/rivetkit/src/drivers/engine/actor-driver.ts index e7a044a85..affa603ed 100644 --- a/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -41,7 +41,6 @@ import { promiseWithResolvers, setLongTimeout, } from "@/utils"; -import type { EngineConfig } from "./config"; import { KEYS } from "./kv"; import { logger } from "./log"; @@ -58,7 +57,6 @@ export class EngineActorDriver implements ActorDriver { #runConfig: RunnerConfig; #managerDriver: ManagerDriver; #inlineClient: Client; - #config: EngineConfig; #runner: Runner; #actors: Map = new Map(); #actorRouter: ActorRouter; @@ -73,17 +71,15 @@ export class EngineActorDriver implements ActorDriver { runConfig: RunnerConfig, managerDriver: ManagerDriver, inlineClient: Client, - config: EngineConfig, ) { this.#registryConfig = registryConfig; this.#runConfig = runConfig; this.#managerDriver = managerDriver; this.#inlineClient = inlineClient; - this.#config = config; // HACK: Override inspector token (which are likely to be // removed later on) with token from x-rivet-token header - const token = runConfig.token ?? config.token; + const token = runConfig.token ?? runConfig.token; if (token && runConfig.inspector && runConfig.inspector.enabled) { runConfig.inspector.token = () => token; } @@ -98,12 +94,12 @@ export class EngineActorDriver implements ActorDriver { let hasDisconnected = false; const engineRunnerConfig: EngineRunnerConfig = { version: this.#version, - endpoint: getEndpoint(config), + endpoint: getEndpoint(runConfig), token, - namespace: runConfig.namespace ?? config.namespace, - totalSlots: runConfig.totalSlots ?? config.totalSlots, - runnerName: runConfig.runnerName ?? config.runnerName, - runnerKey: config.runnerKey, + namespace: runConfig.namespace ?? runConfig.namespace, + totalSlots: runConfig.totalSlots ?? runConfig.totalSlots, + runnerName: runConfig.runnerName ?? runConfig.runnerName, + runnerKey: runConfig.runnerKey, metadata: { inspectorToken: this.#runConfig.inspector.token(), }, @@ -117,14 +113,14 @@ export class EngineActorDriver implements ActorDriver { if (hasDisconnected) { logger().info({ msg: "runner reconnected", - namespace: this.#config.namespace, - runnerName: this.#config.runnerName, + namespace: this.#runConfig.namespace, + runnerName: this.#runConfig.runnerName, }); } else { logger().debug({ msg: "runner connected", - namespace: this.#config.namespace, - runnerName: this.#config.runnerName, + namespace: this.#runConfig.namespace, + runnerName: this.#runConfig.runnerName, }); } @@ -133,8 +129,8 @@ export class EngineActorDriver implements ActorDriver { onDisconnected: () => { logger().warn({ msg: "runner disconnected", - namespace: this.#config.namespace, - runnerName: this.#config.runnerName, + namespace: this.#runConfig.namespace, + runnerName: this.#runConfig.runnerName, }); hasDisconnected = true; }, @@ -153,9 +149,9 @@ export class EngineActorDriver implements ActorDriver { this.#runner.start(); logger().debug({ msg: "engine runner started", - endpoint: config.endpoint, - namespace: config.namespace, - runnerName: config.runnerName, + endpoint: runConfig.endpoint, + namespace: runConfig.namespace, + runnerName: runConfig.runnerName, }); } @@ -236,20 +232,20 @@ export class EngineActorDriver implements ActorDriver { async #runnerOnActorStart( actorId: string, generation: number, - config: RunnerActorConfig, + runConfig: RunnerActorConfig, ): Promise { logger().debug({ msg: "runner actor starting", actorId, - name: config.name, - key: config.key, + name: runConfig.name, + key: runConfig.key, generation, }); // Deserialize input let input: any; - if (config.input) { - input = cbor.decode(config.input); + if (runConfig.input) { + input = cbor.decode(runConfig.input); } // Get or create handler @@ -262,15 +258,12 @@ export class EngineActorDriver implements ActorDriver { this.#actors.set(actorId, handler); } - const name = config.name as string; - invariant(config.key, "actor should have a key"); - const key = deserializeActorKey(config.key); + const name = runConfig.name as string; + invariant(runConfig.key, "actor should have a key"); + const key = deserializeActorKey(runConfig.key); // Create actor instance - const definition = lookupInRegistry( - this.#registryConfig, - config.name as string, // TODO: Remove cast - ); + const definition = lookupInRegistry(this.#registryConfig, runConfig.name); handler.actor = definition.instantiate(); // Start actor diff --git a/packages/rivetkit/src/drivers/engine/mod.ts b/packages/rivetkit/src/drivers/engine/mod.ts index 59b232701..2437df6f7 100644 --- a/packages/rivetkit/src/drivers/engine/mod.ts +++ b/packages/rivetkit/src/drivers/engine/mod.ts @@ -13,11 +13,7 @@ export { EngingConfigSchema as ConfigSchema, } from "./config"; -export function createEngineDriver( - inputConfig?: EngineConfigInput, -): DriverConfig { - const config = EngingConfigSchema.parse(inputConfig); - +export function createEngineDriver(): DriverConfig { return { name: "engine", manager: (_registryConfig, runConfig) => { @@ -34,7 +30,6 @@ export function createEngineDriver( runConfig, managerDriver, inlineClient, - config, ); }, }; diff --git a/packages/rivetkit/src/manager/router-schema.ts b/packages/rivetkit/src/manager/router-schema.ts new file mode 100644 index 000000000..5655ef524 --- /dev/null +++ b/packages/rivetkit/src/manager/router-schema.ts @@ -0,0 +1,20 @@ +import { z } from "zod"; + +export const ServerlessStartHeadersSchema = z.object({ + endpoint: z.string({ required_error: "x-rivet-endpoint header is required" }), + token: z + .string({ invalid_type_error: "x-rivet-token header must be a string" }) + .optional(), + totalSlots: z.coerce + .number({ + invalid_type_error: "x-rivet-total-slots header must be a number", + }) + .int("x-rivet-total-slots header must be an integer") + .gte(1, "x-rivet-total-slots header must be positive"), + runnerName: z.string({ + required_error: "x-rivet-runner-name header is required", + }), + namespace: z.string({ + required_error: "x-rivet-namespace-id header is required", + }), +}); diff --git a/packages/rivetkit/src/manager/router.ts b/packages/rivetkit/src/manager/router.ts index 1e0e53fb1..d8fca244d 100644 --- a/packages/rivetkit/src/manager/router.ts +++ b/packages/rivetkit/src/manager/router.ts @@ -11,9 +11,9 @@ import { createMiddleware } from "hono/factory"; import { streamSSE } from "hono/streaming"; import invariant from "invariant"; import { z } from "zod"; -import { ActorNotFound, Unsupported } from "@/actor/errors"; +import { ActorNotFound, InvalidRequest, Unsupported } from "@/actor/errors"; import { serializeActorKey } from "@/actor/keys"; -import type { Encoding, Transport } from "@/client/mod"; +import type { Client, Encoding, Transport } from "@/client/mod"; import { WS_PROTOCOL_ACTOR, WS_PROTOCOL_CONN_ID, @@ -28,7 +28,12 @@ import { handleRouteNotFound, loggerMiddleware, } from "@/common/router"; -import { deconstructError, noopNext, stringifyError } from "@/common/utils"; +import { + assertUnreachable, + deconstructError, + noopNext, + stringifyError, +} from "@/common/utils"; import { type ActorDriver, HEADER_ACTOR_ID } from "@/driver-helpers/mod"; import type { TestInlineDriverCallRequest, @@ -50,13 +55,14 @@ import { type Actor as ApiActor, } from "@/manager-api/actors"; import { RivetIdSchema } from "@/manager-api/common"; -import type { ServerlessActorDriverBuilder } from "@/mod"; +import type { AnyClient } from "@/mod"; import type { RegistryConfig } from "@/registry/config"; -import type { RunnerConfig } from "@/registry/run-config"; +import type { DriverConfig, RunnerConfig } from "@/registry/run-config"; import { VERSION } from "@/utils"; import type { ActorOutput, ManagerDriver } from "./driver"; import { actorGateway, createTestWebSocketProxy } from "./gateway"; import { logger } from "./log"; +import { ServerlessStartHeadersSchema } from "./router-schema"; function buildOpenApiResponses(schema: T) { return { @@ -81,7 +87,8 @@ export function createManagerRouter( registryConfig: RegistryConfig, runConfig: RunnerConfig, managerDriver: ManagerDriver, - serverlessActorDriverBuilder: ServerlessActorDriverBuilder | undefined, + driverConfig: DriverConfig, + client: AnyClient, ): { router: Hono; openapi: OpenAPIHono } { const router = new OpenAPIHono({ strict: false }).basePath( runConfig.basePath, @@ -108,10 +115,19 @@ export function createManagerRouter( }), ); - if (serverlessActorDriverBuilder) { - addServerlessRoutes(runConfig, serverlessActorDriverBuilder, router); - } else { + if (runConfig.runnerKind === "serverless") { + addServerlessRoutes( + driverConfig, + registryConfig, + runConfig, + managerDriver, + client, + router, + ); + } else if (runConfig.runnerKind === "normal") { addManagerRoutes(registryConfig, runConfig, managerDriver, router); + } else { + assertUnreachable(runConfig.runnerKind); } // Error handling @@ -122,8 +138,11 @@ export function createManagerRouter( } function addServerlessRoutes( + driverConfig: DriverConfig, + registryConfig: RegistryConfig, runConfig: RunnerConfig, - serverlessActorDriverBuilder: ServerlessActorDriverBuilder, + managerDriver: ManagerDriver, + client: AnyClient, router: OpenAPIHono, ) { // Apply CORS @@ -138,24 +157,53 @@ function addServerlessRoutes( // Serverless start endpoint router.get("/start", async (c) => { - const token = c.req.header("x-rivet-token"); - let totalSlots: number | undefined = parseInt( - c.req.header("x-rivet-total-slots") as any, - ); - if (!Number.isFinite(totalSlots)) totalSlots = undefined; - const runnerName = c.req.header("x-rivet-runner-name"); - const namespace = c.req.header("x-rivet-namespace-id"); + // Parse headers + const parseResult = ServerlessStartHeadersSchema.safeParse({ + endpoint: c.req.header("x-rivet-endpoint"), + token: c.req.header("x-rivet-token") ?? undefined, + totalSlots: c.req.header("x-rivet-total-slots"), + runnerName: c.req.header("x-rivet-runner-name"), + namespace: c.req.header("x-rivet-namespace-id"), + }); + if (!parseResult.success) { + throw new InvalidRequest( + parseResult.error.issues[0]?.message ?? + "invalid serverless start headers", + ); + } + const { endpoint, token, totalSlots, runnerName, namespace } = + parseResult.data; - const actorDriver = serverlessActorDriverBuilder( - token, + logger().debug({ + msg: "received serverless runner start request", + endpoint, totalSlots, runnerName, namespace, + }); + + // Override config + // + // We can't do a structuredClone here since this holds functions + const newRunConfig = Object.assign({}, runConfig); + newRunConfig.endpoint = endpoint; + newRunConfig.token = token; + newRunConfig.totalSlots = totalSlots; + newRunConfig.runnerName = runnerName; + newRunConfig.namespace = namespace; + + // Create new actor driver with updated config + const actorDriver = driverConfig.actor( + registryConfig, + newRunConfig, + managerDriver, + client, ); invariant( actorDriver.serverlessHandleStart, "missing serverlessHandleStart on ActorDriver", ); + return await actorDriver.serverlessHandleStart(c); }); @@ -596,7 +644,7 @@ function addManagerRoutes( return c.json({ status: "ok", rivetkit: { - version: packageJson.version, + version: VERSION, }, }); }); diff --git a/packages/rivetkit/src/registry/mod.ts b/packages/rivetkit/src/registry/mod.ts index 367a00a6c..3f0b3d9d6 100644 --- a/packages/rivetkit/src/registry/mod.ts +++ b/packages/rivetkit/src/registry/mod.ts @@ -27,10 +27,7 @@ import { import { crossPlatformServe } from "./serve"; export type ServerlessActorDriverBuilder = ( - token?: string, - totalSlots?: number, - runnerName?: string, - namespace?: string, + updateConfig: (config: RunnerConfig) => void, ) => ActorDriver; interface ServerOutput> { @@ -184,38 +181,19 @@ export class Registry { }); } - // Setup serverless driver - let serverlessActorDriverBuilder: undefined | ServerlessActorDriverBuilder; - if (config.runnerKind === "serverless") { - // Configure serverless runner if enabled when actor driver is disabled - if (config.autoConfigureServerless) { - Promise.all(readyPromises).then(async () => { - await configureServerlessRunner(config); - }); - } - - serverlessActorDriverBuilder = ( - token, - totalSlots, - runnerName, - namespace, - ) => { - // Override config - if (token) config.token = token; - if (totalSlots) config.totalSlots = totalSlots; - if (runnerName) config.runnerName = runnerName; - if (namespace) config.namespace = namespace; - - // Create new actor driver with updated config - return driver.actor(this.#config, config, managerDriver, client); - }; + // Configure serverless runner if enabled when actor driver is disabled + if (config.runnerKind === "serverless" && config.autoConfigureServerless) { + Promise.all(readyPromises).then(async () => { + await configureServerlessRunner(config); + }); } const { router: hono } = createManagerRouter( this.#config, config, managerDriver, - serverlessActorDriverBuilder, + driver, + client, ); // Start server diff --git a/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/packages/rivetkit/src/remote-manager-driver/api-utils.ts index 00a790bfc..196627050 100644 --- a/packages/rivetkit/src/remote-manager-driver/api-utils.ts +++ b/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -16,7 +16,7 @@ export class EngineApiError extends Error { } export function getEndpoint(config: ClientConfig) { - return config.endpoint ?? "http://localhost:6420"; + return config.endpoint ?? "http://127.0.0.1:6420"; } // Helper function for making API calls diff --git a/packages/rivetkit/src/test/mod.ts b/packages/rivetkit/src/test/mod.ts index 5f3f308b9..332fa9418 100644 --- a/packages/rivetkit/src/test/mod.ts +++ b/packages/rivetkit/src/test/mod.ts @@ -10,6 +10,7 @@ import { getInspectorUrl, } from "@/inspector/utils"; import { createManagerRouter } from "@/manager/router"; +import { createClientWithDriver } from "@/mod"; import type { Registry } from "@/registry/mod"; import { RunnerConfigSchema } from "@/registry/run-config"; import { ConfigSchema, type InputConfig } from "./config"; @@ -30,12 +31,14 @@ function serve(registry: Registry, inputConfig?: InputConfig): ServerType { const runConfig = RunnerConfigSchema.parse(inputConfig); const driver = inputConfig.driver ?? createFileSystemOrMemoryDriver(false); const managerDriver = driver.manager(registry.config, config); + const client = createClientWithDriver(managerDriver); configureInspectorAccessToken(config, managerDriver); const { router } = createManagerRouter( registry.config, runConfig, managerDriver, - undefined, + driver, + client, ); // Inject WebSocket diff --git a/packages/rivetkit/tests/driver-engine.test.ts b/packages/rivetkit/tests/driver-engine.test.ts index 4cd382547..f33f5a0e2 100644 --- a/packages/rivetkit/tests/driver-engine.test.ts +++ b/packages/rivetkit/tests/driver-engine.test.ts @@ -39,13 +39,7 @@ runDriverTests({ } // Create driver config - const driverConfig = createEngineDriver({ - endpoint, - namespace, - runnerName, - token: "dev", - totalSlots: 1000, - }); + const driverConfig = createEngineDriver(); // Start the actor driver const runConfig = RunnerConfigSchema.parse({