diff --git a/src/db/prisma/migrations/20260428010811_add_inference_task_queue/migration.sql b/src/db/prisma/migrations/20260428010811_add_inference_task_queue/migration.sql new file mode 100644 index 0000000..41ff6ba --- /dev/null +++ b/src/db/prisma/migrations/20260428010811_add_inference_task_queue/migration.sql @@ -0,0 +1,39 @@ +-- v5: durable inference task queue. Every inference call (sync infer, +-- agent chat, or async POST /inference-tasks) gets a row here. Workers +-- (mcplocal sessions) drain pending rows whose `poolName` matches the +-- pool keys they own when they bind their SSE channel. +CREATE TYPE "InferenceTaskStatus" AS ENUM ('pending', 'claimed', 'running', 'completed', 'error', 'cancelled'); + +CREATE TABLE "InferenceTask" ( + "id" TEXT NOT NULL, + "status" "InferenceTaskStatus" NOT NULL DEFAULT 'pending', + "poolName" TEXT NOT NULL, + "llmName" TEXT NOT NULL, + "model" TEXT NOT NULL, + "tier" TEXT, + "claimedBy" TEXT, + "requestBody" JSONB NOT NULL, + "responseBody" JSONB, + "errorMessage" TEXT, + "streaming" BOOLEAN NOT NULL DEFAULT false, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "claimedAt" TIMESTAMP(3), + "streamStartedAt" TIMESTAMP(3), + "completedAt" TIMESTAMP(3), + "ownerId" TEXT NOT NULL, + "agentId" TEXT, + + CONSTRAINT "InferenceTask_pkey" PRIMARY KEY ("id") +); + +-- Worker claim path: SELECT … WHERE status='pending' AND poolName IN (…) +-- runs on every SSE bind. Compound index keeps that fast as the table grows. +CREATE INDEX "InferenceTask_status_poolName_idx" ON "InferenceTask"("status", "poolName"); +-- unbindSession revert path: SELECT … WHERE claimedBy=$1 AND status IN ('claimed','running'). +CREATE INDEX "InferenceTask_claimedBy_idx" ON "InferenceTask"("claimedBy"); +-- Owner scoping for the async API listing. +CREATE INDEX "InferenceTask_ownerId_idx" ON "InferenceTask"("ownerId"); +CREATE INDEX "InferenceTask_agentId_idx" ON "InferenceTask"("agentId"); +-- GC sweep predicate: completedAt < now()-7d. Indexed so the daily cleanup +-- doesn't seq-scan once the table grows past a few thousand rows. +CREATE INDEX "InferenceTask_completedAt_idx" ON "InferenceTask"("completedAt"); diff --git a/src/db/prisma/schema.prisma b/src/db/prisma/schema.prisma index 403f18d..4289ae0 100644 --- a/src/db/prisma/schema.prisma +++ b/src/db/prisma/schema.prisma @@ -604,6 +604,79 @@ model ChatMessage { @@index([threadId, createdAt]) } +// ── Inference Tasks (v5) ── +// +// Every inference call (sync infer, agent chat, async POST /inference-tasks) +// creates a row here. The DB is the source of truth; mcpd's previous +// in-memory `pendingTasks` map is gone — the result-handler updates the row +// and an in-process EventEmitter wakes any blocked HTTP handlers (single- +// instance for now; multi-instance scaling is a v6 concern that will swap +// the emitter for pg LISTEN/NOTIFY without changing the data model). +// +// Routing: `poolName` is the effective pool key at enqueue time +// (`Llm.poolName ?? Llm.name`). Workers (mcplocal sessions) drain pending +// rows whose `poolName` matches the pool keys they own when they bind their +// SSE channel — that's how queued tasks survive worker offline windows. 
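+//
+// Claim is a conditional UPDATE ("compare-and-swap"), not a
+// SELECT-then-UPDATE. A sketch of the shape the repository issues via
+// Prisma updateMany (illustrative, not the literal generated SQL):
+//
+//   UPDATE "InferenceTask"
+//   SET "status" = 'claimed', "claimedBy" = $2, "claimedAt" = now()
+//   WHERE "id" = $1 AND "status" = 'pending';
+//
+// An affected-row count of 1 means this worker won the race; 0 means
+// another worker (or a cancel) got there first.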
+
+enum InferenceTaskStatus {
+  pending   // in queue, no worker has it yet (or claim was reverted)
+  claimed   // a worker has it (SSE frame sent), no chunks back yet
+  running   // worker started streaming chunks back (streaming tasks only)
+  completed // worker POSTed the final result
+  error     // permanent failure (auth, bad request, queue timeout)
+  cancelled // caller said never mind via DELETE
+}
+
+model InferenceTask {
+  id              String              @id @default(cuid())
+  status          InferenceTaskStatus @default(pending)
+  // Routing — pool key drives worker matching at claim time. Stored at
+  // enqueue time so a later rename of Llm.poolName doesn't reroute
+  // already-queued work.
+  poolName        String
+  llmName         String // pinned target Llm name (for audit + agent backref)
+  model           String
+  tier            String?
+  // Worker tracking. NULL while pending; set on claim; cleared on
+  // unbindSession-driven revert (worker disconnect mid-task).
+  claimedBy       String?
+  // Body + result. Both are Json so streaming chunks can be reconstructed
+  // (see TaskService.complete) and async pollers get a structured payload.
+  // requestBody is required (the OpenAI chat-completion request body the
+  // worker should run); responseBody is null until status=completed.
+  requestBody     Json
+  responseBody    Json?
+  errorMessage    String?
+  // Whether the original request asked for streaming. Drives the chunk-vs-
+  // final-body protocol on the result POST and tells async API callers
+  // whether `/stream` will yield chunks or just a single completion event.
+  streaming       Boolean             @default(false)
+  // Timestamps for observability + GC:
+  //   pending → claimed: claimedAt set
+  //   claimed → running: streamStartedAt set (first chunk received)
+  //   running/claimed → completed/error/cancelled: completedAt set
+  createdAt       DateTime            @default(now())
+  claimedAt       DateTime?
+  streamStartedAt DateTime?
+  completedAt     DateTime?
+  // Caller tracking — RBAC + observability. ownerId references User.id;
+  // agentId is set when the task came in via /agents/<name>/chat (null
+  // for direct /llms/<name>/infer or async POST /inference-tasks calls
+  // that don't pin an agent).
+  ownerId         String
+  agentId         String?
+
+  @@index([status, poolName])
+  @@index([claimedBy])
+  @@index([ownerId])
+  @@index([agentId])
+  // GC sweep predicate: completedAt < 7d ago. Index speeds up the daily
+  // cleanup without scanning the whole table.
+  @@index([completedAt])
+}
+
 // ── Audit Logs ──
 
 model AuditLog {
diff --git a/src/db/tests/inference-task-schema.test.ts b/src/db/tests/inference-task-schema.test.ts
new file mode 100644
index 0000000..a95bbe9
--- /dev/null
+++ b/src/db/tests/inference-task-schema.test.ts
@@ -0,0 +1,169 @@
+/**
+ * v5 db-level tests for the InferenceTask queue. Exercises the actual
+ * column shapes + index lookups; the mcpd-side service tests cover the
+ * state machine + signal channels with a mocked repo.
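+ *
+ * The drain query these tests exercise, in the shape the repository
+ * issues it (a sketch; same WHERE/ORDER BY as findPendingForPools):
+ *
+ *   prisma.inferenceTask.findMany({
+ *     where: { status: 'pending', poolName: { in: pools } },
+ *     orderBy: { createdAt: 'asc' },
+ *   });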
+ */
+import { describe, it, expect, beforeAll, afterAll, beforeEach } from 'vitest';
+import type { PrismaClient } from '@prisma/client';
+import { setupTestDb, cleanupTestDb, clearAllTables } from './helpers.js';
+
+async function makeOwner(prisma: PrismaClient): Promise<string> {
+  const u = await prisma.user.create({
+    data: { email: `owner-${String(Date.now())}@test`, passwordHash: 'x' },
+  });
+  return u.id;
+}
+
+describe('InferenceTask schema (v5)', () => {
+  let prisma: PrismaClient;
+
+  beforeAll(async () => {
+    prisma = await setupTestDb();
+  }, 30_000);
+
+  afterAll(async () => {
+    await cleanupTestDb();
+  });
+
+  beforeEach(async () => {
+    await clearAllTables(prisma);
+  });
+
+  it('defaults a fresh row to status=pending with claim/completion fields null', async () => {
+    const ownerId = await makeOwner(prisma);
+    const row = await prisma.inferenceTask.create({
+      data: {
+        poolName: 'qwen-pool',
+        llmName: 'qwen-prod-1',
+        model: 'qwen3-thinking',
+        requestBody: { messages: [{ role: 'user', content: 'hi' }] },
+        ownerId,
+      },
+    });
+    expect(row.status).toBe('pending');
+    expect(row.claimedBy).toBeNull();
+    expect(row.claimedAt).toBeNull();
+    expect(row.streamStartedAt).toBeNull();
+    expect(row.completedAt).toBeNull();
+    expect(row.responseBody).toBeNull();
+    expect(row.streaming).toBe(false);
+  });
+
+  it('roundtrips streaming=true and a structured requestBody/responseBody', async () => {
+    const ownerId = await makeOwner(prisma);
+    const requestBody = {
+      messages: [{ role: 'user', content: 'hello' }],
+      temperature: 0.2,
+      tools: [{ type: 'function', function: { name: 'noop' } }],
+    };
+    const row = await prisma.inferenceTask.create({
+      data: {
+        poolName: 'qwen-pool',
+        llmName: 'qwen-prod-1',
+        model: 'qwen3',
+        requestBody,
+        streaming: true,
+        ownerId,
+      },
+    });
+    expect(row.streaming).toBe(true);
+    expect(row.requestBody).toEqual(requestBody);
+
+    const completedAt = new Date();
+    const responseBody = { choices: [{ message: { role: 'assistant', content: 'world' } }] };
+    const updated = await prisma.inferenceTask.update({
+      where: { id: row.id },
+      data: { status: 'completed', responseBody, completedAt },
+    });
+    expect(updated.responseBody).toEqual(responseBody);
+    expect(updated.completedAt?.getTime()).toBe(completedAt.getTime());
+  });
+
+  it('compound index supports the dispatcher\'s drain query (status + poolName IN ...)', async () => {
+    // An actual EXPLAIN/index-use check is too brittle for unit tests;
+    // here we verify that the query shape the repo's findPendingForPools
+    // issues — same WHERE/ORDER BY — returns the expected rows in FIFO
+    // order. Index usage is implied by the Prisma model definition.
+    const ownerId = await makeOwner(prisma);
+    const t1 = await prisma.inferenceTask.create({
+      data: { poolName: 'pool-a', llmName: 'a-1', model: 'm', requestBody: {}, ownerId },
+    });
+    await new Promise((r) => setTimeout(r, 5));
+    const t2 = await prisma.inferenceTask.create({
+      data: { poolName: 'pool-a', llmName: 'a-2', model: 'm', requestBody: {}, ownerId },
+    });
+    await prisma.inferenceTask.create({
+      data: { poolName: 'pool-b', llmName: 'b-1', model: 'm', requestBody: {}, ownerId },
+    });
+    // One row in pool-a is no longer pending — must be excluded.
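+    // (To eyeball index use outside the test suite, something like
+    //   EXPLAIN SELECT * FROM "InferenceTask"
+    //   WHERE "status" = 'pending' AND "poolName" IN ('pool-a', 'pool-b');
+    // in psql shows the plan; illustrative only, nothing here asserts on it.)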
+ await prisma.inferenceTask.create({ + data: { poolName: 'pool-a', llmName: 'a-3', model: 'm', requestBody: {}, ownerId, status: 'completed' }, + }); + + const drained = await prisma.inferenceTask.findMany({ + where: { status: 'pending', poolName: { in: ['pool-a', 'pool-b'] } }, + orderBy: { createdAt: 'asc' }, + }); + expect(drained.map((r) => r.id)).toEqual([t1.id, t2.id, drained[2]!.id]); + expect(drained.map((r) => r.poolName)).toEqual(['pool-a', 'pool-a', 'pool-b']); + }); + + it('claimedBy index supports unbindSession revert (worker disconnect path)', async () => { + const ownerId = await makeOwner(prisma); + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-A' }, + }); + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'running', claimedBy: 'sess-A' }, + }); + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-B' }, + }); + // Completed-but-claimedBy=sess-A row: must NOT revert (terminal state). + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', claimedBy: 'sess-A' }, + }); + + const heldByA = await prisma.inferenceTask.findMany({ + where: { claimedBy: 'sess-A', status: { in: ['claimed', 'running'] } }, + }); + expect(heldByA).toHaveLength(2); + }); + + it('GC predicate (terminal + completedAt < cutoff) is index-friendly and filters correctly', async () => { + const ownerId = await makeOwner(prisma); + const old = new Date(Date.now() - 8 * 24 * 60 * 60 * 1000); // 8 d ago + const recent = new Date(Date.now() - 1 * 60 * 60 * 1000); // 1 h ago + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: old }, + }); + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'error', completedAt: old, errorMessage: 'boom' }, + }); + // Inside retention — must not be picked up by GC. + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: recent }, + }); + // Pending row — must not be picked up by terminal GC. + await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'pending' }, + }); + + const cutoff = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000); + const expired = await prisma.inferenceTask.findMany({ + where: { + status: { in: ['completed', 'error', 'cancelled'] }, + completedAt: { lt: cutoff }, + }, + }); + expect(expired).toHaveLength(2); + }); + + it('agentId is nullable — direct chat-llm tasks have no agent', async () => { + const ownerId = await makeOwner(prisma); + const row = await prisma.inferenceTask.create({ + data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, agentId: null }, + }); + expect(row.agentId).toBeNull(); + }); +}); diff --git a/src/mcpd/src/repositories/inference-task.repository.ts b/src/mcpd/src/repositories/inference-task.repository.ts new file mode 100644 index 0000000..3e6edb5 --- /dev/null +++ b/src/mcpd/src/repositories/inference-task.repository.ts @@ -0,0 +1,235 @@ +/** + * v5 InferenceTask repository — durable queue for inference work. 
+ * + * Workers (mcplocal SSE sessions) drain rows whose `poolName` matches the + * pool keys they own when they bind. The state machine is: + * + * pending ──claim──> claimed ──first chunk──> running + * \ │ + * \ ▼ + * ──complete/fail──> completed | error + * + * Plus orthogonal terminal transitions: + * + * pending|claimed|running ──cancel──> cancelled + * claimed|running ──worker disconnect──> pending (revert + re-queue) + * + * The DB is the source of truth. mcpd's in-flight HTTP handlers wait via + * an in-process EventEmitter (see TaskService) — single-instance for now; + * pg LISTEN/NOTIFY is the v6 swap when scaling horizontally. + */ +import type { + PrismaClient, + InferenceTask, + InferenceTaskStatus, +} from '@prisma/client'; +// `Prisma` is imported as a value because we reference Prisma.JsonNull at +// runtime (the sentinel for explicit JSON null on a nullable column). +import { Prisma } from '@prisma/client'; + +export interface CreateInferenceTaskInput { + poolName: string; + llmName: string; + model: string; + tier?: string | null; + requestBody: Record; + streaming: boolean; + ownerId: string; + agentId?: string | null; +} + +export interface IInferenceTaskRepository { + create(data: CreateInferenceTaskInput): Promise; + findById(id: string): Promise; + /** Pending rows for one or more pool keys, oldest first (FIFO). */ + findPendingForPools(poolNames: string[], limit?: number): Promise; + /** Tasks held by a worker session — used by unbindSession to revert on disconnect. */ + findHeldBy(claimedBy: string): Promise; + /** List for the async API; filters are AND-combined. */ + list(filter: { + ownerId?: string; + status?: InferenceTaskStatus | InferenceTaskStatus[]; + poolName?: string; + agentId?: string; + limit?: number; + }): Promise; + + // ── State transitions ── + // Each transition is conditional ("compare-and-swap" via updateMany + + // affected-row count) so two workers racing to claim the same row both + // succeed at the DB level — one gets affected=1 and a populated row, + // the other gets affected=0 and we fall through to the next candidate. + + /** pending → claimed; returns the updated row, or null if the row was no longer pending. */ + tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise; + /** claimed → running; idempotent — returns the row even if already running. */ + markRunning(id: string, at: Date): Promise; + /** {claimed,running} → completed; only allowed from a non-terminal state. */ + markCompleted(id: string, responseBody: Record | null, at: Date): Promise; + /** any non-terminal → error. Records `errorMessage`. */ + markError(id: string, errorMessage: string, at: Date): Promise; + /** any non-terminal → cancelled. */ + markCancelled(id: string, at: Date): Promise; + /** {claimed,running} → pending; clears `claimedBy`. Used on worker disconnect. */ + revertToPending(id: string): Promise; + + // ── GC ── + + /** Pending rows older than the cutoff — the GC sweep flips these to error. */ + findStalePending(cutoff: Date): Promise; + /** Completed/error/cancelled rows older than the cutoff — GC deletes them. */ + findExpiredTerminal(cutoff: Date): Promise; + /** Bulk delete by id — used by the GC sweep after collecting expired ids. 
*/ + deleteMany(ids: string[]): Promise; +} + +const NON_TERMINAL: InferenceTaskStatus[] = ['pending', 'claimed', 'running']; +const HELD: InferenceTaskStatus[] = ['claimed', 'running']; + +export class InferenceTaskRepository implements IInferenceTaskRepository { + constructor(private readonly prisma: PrismaClient) {} + + async create(data: CreateInferenceTaskInput): Promise { + return this.prisma.inferenceTask.create({ + data: { + poolName: data.poolName, + llmName: data.llmName, + model: data.model, + tier: data.tier ?? null, + requestBody: data.requestBody as Prisma.InputJsonValue, + streaming: data.streaming, + ownerId: data.ownerId, + agentId: data.agentId ?? null, + }, + }); + } + + async findById(id: string): Promise { + return this.prisma.inferenceTask.findUnique({ where: { id } }); + } + + async findPendingForPools(poolNames: string[], limit?: number): Promise { + if (poolNames.length === 0) return []; + return this.prisma.inferenceTask.findMany({ + where: { status: 'pending', poolName: { in: poolNames } }, + orderBy: { createdAt: 'asc' }, + ...(limit !== undefined ? { take: limit } : {}), + }); + } + + async findHeldBy(claimedBy: string): Promise { + return this.prisma.inferenceTask.findMany({ + where: { claimedBy, status: { in: HELD } }, + }); + } + + async list(filter: { + ownerId?: string; + status?: InferenceTaskStatus | InferenceTaskStatus[]; + poolName?: string; + agentId?: string; + limit?: number; + }): Promise { + const where: Prisma.InferenceTaskWhereInput = {}; + if (filter.ownerId !== undefined) where.ownerId = filter.ownerId; + if (filter.status !== undefined) { + where.status = Array.isArray(filter.status) ? { in: filter.status } : filter.status; + } + if (filter.poolName !== undefined) where.poolName = filter.poolName; + if (filter.agentId !== undefined) where.agentId = filter.agentId; + return this.prisma.inferenceTask.findMany({ + where, + orderBy: { createdAt: 'desc' }, + ...(filter.limit !== undefined ? { take: filter.limit } : {}), + }); + } + + async tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise { + // Compare-and-swap on status='pending'. Two workers racing both run + // this UPDATE; whichever the DB serializes first sees affected=1 and + // gets the row, the loser sees affected=0 and falls through. + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: 'pending' }, + data: { status: 'claimed', claimedBy, claimedAt }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async markRunning(id: string, at: Date): Promise { + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: { in: ['claimed', 'running'] } }, + data: { status: 'running', streamStartedAt: at }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async markCompleted(id: string, responseBody: Record | null, at: Date): Promise { + // Prisma's nullable JSON column writes need an explicit + // `Prisma.JsonNull` sentinel for null; passing literal null fails + // typecheck because the column also accepts a plain JSON value. + const bodyForUpdate: Prisma.InputJsonValue | typeof Prisma.JsonNull = + responseBody === null ? 
Prisma.JsonNull : (responseBody as Prisma.InputJsonValue); + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: { in: NON_TERMINAL } }, + data: { + status: 'completed', + responseBody: bodyForUpdate, + completedAt: at, + }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async markError(id: string, errorMessage: string, at: Date): Promise { + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: { in: NON_TERMINAL } }, + data: { status: 'error', errorMessage, completedAt: at }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async markCancelled(id: string, at: Date): Promise { + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: { in: NON_TERMINAL } }, + data: { status: 'cancelled', completedAt: at }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async revertToPending(id: string): Promise { + const result = await this.prisma.inferenceTask.updateMany({ + where: { id, status: { in: HELD } }, + data: { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null }, + }); + if (result.count === 0) return null; + return this.findById(id); + } + + async findStalePending(cutoff: Date): Promise { + return this.prisma.inferenceTask.findMany({ + where: { status: 'pending', createdAt: { lt: cutoff } }, + }); + } + + async findExpiredTerminal(cutoff: Date): Promise { + return this.prisma.inferenceTask.findMany({ + where: { + status: { in: ['completed', 'error', 'cancelled'] }, + completedAt: { lt: cutoff }, + }, + }); + } + + async deleteMany(ids: string[]): Promise { + if (ids.length === 0) return 0; + const result = await this.prisma.inferenceTask.deleteMany({ + where: { id: { in: ids } }, + }); + return result.count; + } +} diff --git a/src/mcpd/src/services/inference-task.service.ts b/src/mcpd/src/services/inference-task.service.ts new file mode 100644 index 0000000..81d4109 --- /dev/null +++ b/src/mcpd/src/services/inference-task.service.ts @@ -0,0 +1,297 @@ +/** + * v5 InferenceTaskService — business logic over the durable task queue. + * + * The DB row (managed via IInferenceTaskRepository) is the source of truth + * for *state*. This service owns the in-process *signaling* layer that + * lets blocked HTTP handlers wake up promptly when a worker POSTs a + * result, instead of polling the DB. The signals carry no state — they + * just say "go re-read row X" — so a missed signal degrades to a slower + * poll-then-resolve (still correct, just less responsive). + * + * Single-instance assumption: EventEmitter only fires on the mcpd that + * holds the row's blocked handler. With multiple mcpd replicas, the + * worker's POST might land on instance A while the caller's handler is + * on instance B; B never sees the local emitter. Solution at scale is + * pg LISTEN/NOTIFY (v6 swap) — no schema change needed; just replace the + * EventEmitter wakeup with a NOTIFY emit. + */ +import { EventEmitter } from 'node:events'; +import type { InferenceTask, InferenceTaskStatus } from '@prisma/client'; +import type { IInferenceTaskRepository, CreateInferenceTaskInput } from '../repositories/inference-task.repository.js'; +import { NotFoundError } from './mcp-server.service.js'; + +/** + * One streaming chunk pushed back from a worker. Lives in-memory only — + * we don't persist every SSE delta to the DB (too expensive for + * high-frequency streams). 
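+ *
+ * Typical mcpd-side flow, sketched (route wiring and worker dispatch live
+ * outside this file; variable names below are illustrative, the calls are
+ * this service's own API):
+ *
+ *   const task = await tasks.enqueue({ poolName, llmName, model,
+ *     requestBody, streaming: false, ownerId });
+ *   // ...dispatch to a bound worker over SSE here...
+ *   const waiter = tasks.waitFor(task.id, 120_000);
+ *   const row = await waiter.done; // rejects on error/cancelled/timeout
+ *   // row.responseBody now holds the completed result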
+ */
+import { EventEmitter } from 'node:events';
+import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
+import type { IInferenceTaskRepository, CreateInferenceTaskInput } from '../repositories/inference-task.repository.js';
+import { NotFoundError } from './mcp-server.service.js';
+
+/**
+ * One streaming chunk pushed back from a worker. Lives in-memory only —
+ * we don't persist every SSE delta to the DB (too expensive for
+ * high-frequency streams). The final assembled body lands on the row
+ * via `complete()` when the worker emits its terminal `done:true`.
+ */
+export interface InferenceTaskChunk {
+  data: string;
+  done?: boolean;
+}
+
+/**
+ * The wait handle returned to a blocked HTTP handler. `done` resolves
+ * when the row reaches a terminal status (completed | error | cancelled);
+ * `chunks` is an async iterator that yields streaming deltas as they
+ * arrive (only meaningful when the underlying request asked for streaming).
+ */
+export interface InferenceTaskWaiter {
+  /** Resolves with the completed DB row; rejects on `error`/`cancelled` with an explanatory message. */
+  done: Promise<InferenceTask>;
+  /** Async iterator of streaming chunks. Yields nothing for non-streaming tasks. */
+  chunks: AsyncGenerator<InferenceTaskChunk, void>;
+}
+
+export interface IInferenceTaskService {
+  /** Create a new pending task. Caller is expected to immediately attempt dispatch. */
+  enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask>;
+  /** Wait for a task to reach a terminal state, failing after the given timeout (ms). */
+  waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter;
+  /** Conditional CAS claim — returns the row if `pending`, null if already claimed. */
+  tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null>;
+  /** Worker reported first chunk — flips `claimed` → `running`. Idempotent. */
+  markRunning(taskId: string): Promise<InferenceTask | null>;
+  /** Worker pushed a streaming chunk. Wakes up the in-process waiter; no DB write. */
+  pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean;
+  /** Worker POSTed the final result. Persists the body + flips to `completed`. */
+  complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null>;
+  /** Worker (or mcpd) marked the task as failed. */
+  fail(taskId: string, errorMessage: string): Promise<InferenceTask | null>;
+  /** Caller (or GC) marked the task as cancelled. */
+  cancel(taskId: string): Promise<InferenceTask | null>;
+  /** Worker disconnected mid-task — revert claimed/running rows back to pending. */
+  revertHeldBy(claimedBy: string): Promise<InferenceTask[]>;
+  /** Pending rows for one or more pool keys, ordered FIFO. Used by drainPending on bind. */
+  findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
+  findById(id: string): Promise<InferenceTask | null>;
+  list(filter: {
+    ownerId?: string;
+    status?: InferenceTaskStatus | InferenceTaskStatus[];
+    poolName?: string;
+    agentId?: string;
+    limit?: number;
+  }): Promise<InferenceTask[]>;
+  /** GC sweep: pending too long → error; terminal too long → delete. */
+  gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number }): Promise<{ erroredOut: number; deleted: number }>;
+}
+
+const DEFAULT_PENDING_TIMEOUT_MS = 60 * 60 * 1000; // 1 h
+const DEFAULT_TERMINAL_RETENTION_MS = 7 * 24 * 60 * 60 * 1000; // 7 d
+
+export class InferenceTaskService implements IInferenceTaskService {
+  /**
+   * Wakeup channel per task. The result-handler (worker POSTs landing on
+   * mcpd) emits a terminal status; the streaming push emits chunks.
+   * Cleared once a waiter resolves to avoid leaks. The EventEmitter
+   * listener cap is bumped to a generous default — typical concurrent
+   * waiters per task is 1, but `/stream` SSE consumers and the original
+   * HTTP handler can both subscribe.
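+   *
+   * Channel names are derived from the task id: `chunk:<id>` for
+   * streaming deltas and `terminal:<id>` for the final wake, matching
+   * the emit sites in pushChunk()/complete()/fail()/cancel() (and the
+   * GC sweep) below.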
+   */
+  private readonly events = new EventEmitter();
+
+  constructor(
+    private readonly repo: IInferenceTaskRepository,
+    private readonly clock: () => Date = () => new Date(),
+  ) {
+    this.events.setMaxListeners(50);
+  }
+
+  async enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask> {
+    return this.repo.create(input);
+  }
+
+  waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter {
+    // Set up the chunk channel up-front so a worker pushing chunks before
+    // the consumer subscribes doesn't drop them. We use a queue + wake
+    // pattern (same shape as the v3 chat.service streaming bridge).
+    const chunkQueue: InferenceTaskChunk[] = [];
+    let chunkResolve: (() => void) | null = null;
+    let finished = false;
+    let finishError: Error | null = null;
+
+    const onChunk = (chunk: InferenceTaskChunk): void => {
+      chunkQueue.push(chunk);
+      if (chunkResolve !== null) {
+        const r = chunkResolve;
+        chunkResolve = null;
+        r();
+      }
+    };
+    const onTerminal = (): void => {
+      finished = true;
+      if (chunkResolve !== null) {
+        const r = chunkResolve;
+        chunkResolve = null;
+        r();
+      }
+    };
+
+    this.events.on(`chunk:${taskId}`, onChunk);
+    this.events.on(`terminal:${taskId}`, onTerminal);
+
+    const cleanup = (): void => {
+      this.events.off(`chunk:${taskId}`, onChunk);
+      this.events.off(`terminal:${taskId}`, onTerminal);
+    };
+
+    // The terminal promise: poll once at start (in case the row is
+    // already terminal — common for already-completed tasks the caller
+    // is just fetching the result of), then race the timeout against
+    // the wakeup signal.
+    const done = (async (): Promise<InferenceTask> => {
+      let timer: NodeJS.Timeout | undefined;
+      let onWake: (() => void) | null = null;
+      try {
+        const initial = await this.repo.findById(taskId);
+        if (initial === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
+        if (isTerminal(initial.status)) return ensureSuccess(initial);
+
+        // Wait for terminal event OR timeout. If the wake happens before
+        // the timer fires, we re-fetch and return; otherwise we error.
+        const timeout = new Promise<never>((_, reject) => {
+          timer = setTimeout(() => reject(new Error(`InferenceTask wait timed out after ${String(timeoutMs)}ms (task ${taskId})`)), timeoutMs);
+        });
+        const wake = new Promise<void>((resolve) => {
+          onWake = resolve;
+          this.events.once(`terminal:${taskId}`, resolve);
+        });
+        await Promise.race([timeout, wake]);
+        finished = true;
+        const final = await this.repo.findById(taskId);
+        if (final === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
+        return ensureSuccess(final);
+      } catch (err) {
+        finishError = err as Error;
+        finished = true;
+        throw err;
+      } finally {
+        // Clear the timer and the once-listener so a resolved (or timed-
+        // out) wait doesn't pin the event loop or leak a listener.
+        if (timer !== undefined) clearTimeout(timer);
+        if (onWake !== null) this.events.off(`terminal:${taskId}`, onWake);
+        cleanup();
+      }
+    })();
+
+    const chunks = (async function* gen(): AsyncGenerator<InferenceTaskChunk, void> {
+      while (true) {
+        while (chunkQueue.length > 0) {
+          const c = chunkQueue.shift()!;
+          yield c;
+          if (c.done === true) return;
+        }
+        if (finished) {
+          if (finishError !== null) throw finishError;
+          return;
+        }
+        await new Promise<void>((r) => { chunkResolve = r; });
+      }
+    })();
+
+    return { done, chunks };
+  }
+
+  async tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null> {
+    return this.repo.tryClaim(taskId, claimedBy, this.clock());
+  }
+
+  async markRunning(taskId: string): Promise<InferenceTask | null> {
+    return this.repo.markRunning(taskId, this.clock());
+  }
+
+  pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean {
+    // Streaming chunks are in-memory only. If the row no longer exists or
+    // there are no listeners, the emit is a no-op — return false so the
+    // result-route can decide whether to log a warning.
+    return this.events.emit(`chunk:${taskId}`, chunk);
+  }
+
+  async complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null> {
+    const updated = await this.repo.markCompleted(taskId, responseBody, this.clock());
+    if (updated !== null) this.events.emit(`terminal:${taskId}`);
+    return updated;
+  }
+
+  async fail(taskId: string, errorMessage: string): Promise<InferenceTask | null> {
+    const updated = await this.repo.markError(taskId, errorMessage, this.clock());
+    if (updated !== null) this.events.emit(`terminal:${taskId}`);
+    return updated;
+  }
+
+  async cancel(taskId: string): Promise<InferenceTask | null> {
+    const updated = await this.repo.markCancelled(taskId, this.clock());
+    if (updated !== null) this.events.emit(`terminal:${taskId}`);
+    return updated;
+  }
+
+  async revertHeldBy(claimedBy: string): Promise<InferenceTask[]> {
+    const held = await this.repo.findHeldBy(claimedBy);
+    const reverted: InferenceTask[] = [];
+    for (const t of held) {
+      const r = await this.repo.revertToPending(t.id);
+      if (r !== null) reverted.push(r);
+    }
+    return reverted;
+  }
+
+  async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
+    return this.repo.findPendingForPools(poolNames, limit);
+  }
+
+  async findById(id: string): Promise<InferenceTask | null> {
+    return this.repo.findById(id);
+  }
+
+  async list(filter: {
+    ownerId?: string;
+    status?: InferenceTaskStatus | InferenceTaskStatus[];
+    poolName?: string;
+    agentId?: string;
+    limit?: number;
+  }): Promise<InferenceTask[]> {
+    return this.repo.list(filter);
+  }
+
+  async gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number } = {
+    pendingTimeoutMs: DEFAULT_PENDING_TIMEOUT_MS,
+    terminalRetentionMs: DEFAULT_TERMINAL_RETENTION_MS,
+  }): Promise<{ erroredOut: number; deleted: number }> {
+    const now = this.clock();
+    const pendingCutoff = new Date(now.getTime() - opts.pendingTimeoutMs);
+    const terminalCutoff = new Date(now.getTime() - opts.terminalRetentionMs);
+
+    // Pass 1: pending tasks that have aged out → flip to error so the
+    // caller (if still waiting) gets a clean failure.
+    const stale = await this.repo.findStalePending(pendingCutoff);
+    let erroredOut = 0;
+    for (const t of stale) {
+      const r = await this.repo.markError(t.id, `Task expired in pending state after ${String(opts.pendingTimeoutMs / 1000)}s`, now);
+      if (r !== null) {
+        this.events.emit(`terminal:${t.id}`);
+        erroredOut += 1;
+      }
+    }
+
+    // Pass 2: terminal tasks past retention → delete.
+    const expired = await this.repo.findExpiredTerminal(terminalCutoff);
+    const deleted = await this.repo.deleteMany(expired.map((t) => t.id));
+
+    return { erroredOut, deleted };
+  }
+}
+
+function isTerminal(status: InferenceTaskStatus): boolean {
+  return status === 'completed' || status === 'error' || status === 'cancelled';
+}
+
+/**
+ * Surface a non-success terminal state as a thrown error so the HTTP
+ * handler propagates it cleanly. `error` rows carry an `errorMessage`;
+ * `cancelled` rows don't, so we synthesize one.
+ */
+function ensureSuccess(task: InferenceTask): InferenceTask {
+  if (task.status === 'completed') return task;
+  if (task.status === 'cancelled') {
+    throw new Error(`InferenceTask cancelled: ${task.id}`);
+  }
+  if (task.status === 'error') {
+    throw new Error(task.errorMessage ?? `InferenceTask failed: ${task.id}`);
+  }
+  // Theoretically unreachable — waitFor only resolves on terminal events.
+  throw new Error(`InferenceTask in unexpected state: ${String(task.status)}`);
+}
diff --git a/src/mcpd/tests/inference-task-service.test.ts b/src/mcpd/tests/inference-task-service.test.ts
new file mode 100644
index 0000000..77a1413
--- /dev/null
+++ b/src/mcpd/tests/inference-task-service.test.ts
@@ -0,0 +1,330 @@
+/**
+ * v5 InferenceTaskService — state machine, signal channels, and GC sweep.
+ * The repo is mocked with an in-memory Map so we exercise the service
+ * logic deterministically without touching Postgres. Schema-level tests
+ * live in src/db/tests/inference-task-schema.test.ts.
+ */
+import { describe, it, expect, vi } from 'vitest';
+import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
+import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
+import { InferenceTaskService } from '../src/services/inference-task.service.js';
+
+function makeRow(overrides: Partial<InferenceTask> = {}): InferenceTask {
+  return {
+    id: overrides.id ?? `task-${Math.random().toString(36).slice(2, 8)}`,
+    status: 'pending',
+    poolName: 'pool-a',
+    llmName: 'pool-a',
+    model: 'qwen3-thinking',
+    tier: null,
+    claimedBy: null,
+    requestBody: { messages: [{ role: 'user', content: 'hi' }] } as unknown as InferenceTask['requestBody'],
+    responseBody: null,
+    errorMessage: null,
+    streaming: false,
+    createdAt: new Date(),
+    claimedAt: null,
+    streamStartedAt: null,
+    completedAt: null,
+    ownerId: 'owner-1',
+    agentId: null,
+    ...overrides,
+  };
+}
+
+function mockRepo(initial: InferenceTask[] = []): IInferenceTaskRepository {
+  const rows = new Map<string, InferenceTask>(initial.map((r) => [r.id, { ...r }]));
+  // Tiny CAS helper that mirrors the real `updateMany({where:{status:in}})`
+  // semantics — only flips the row if the current status matches.
+  const cas = (id: string, allowed: InferenceTaskStatus[], patch: Partial<InferenceTask>): InferenceTask | null => {
+    const row = rows.get(id);
+    if (row === undefined) return null;
+    if (!allowed.includes(row.status)) return null;
+    const next = { ...row, ...patch };
+    rows.set(id, next);
+    return next;
+  };
+  return {
+    create: vi.fn(async (data) => {
+      const row = makeRow({
+        id: `task-${rows.size + 1}`,
+        poolName: data.poolName,
+        llmName: data.llmName,
+        model: data.model,
+        tier: data.tier ?? null,
+        requestBody: data.requestBody as InferenceTask['requestBody'],
+        streaming: data.streaming,
+        ownerId: data.ownerId,
+        agentId: data.agentId ?? null,
+      });
+      rows.set(row.id, row);
+      return row;
+    }),
+    findById: vi.fn(async (id) => rows.get(id) ?? null),
+    findPendingForPools: vi.fn(async (poolNames, limit) => {
+      const out = [...rows.values()]
+        .filter((r) => r.status === 'pending' && poolNames.includes(r.poolName))
+        .sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
+      return limit !== undefined ? out.slice(0, limit) : out;
+    }),
+    findHeldBy: vi.fn(async (claimedBy) =>
+      [...rows.values()].filter((r) => r.claimedBy === claimedBy && (r.status === 'claimed' || r.status === 'running')),
+    ),
+    list: vi.fn(async (filter) => {
+      let out = [...rows.values()];
+      if (filter.ownerId !== undefined) out = out.filter((r) => r.ownerId === filter.ownerId);
+      if (filter.poolName !== undefined) out = out.filter((r) => r.poolName === filter.poolName);
+      if (filter.agentId !== undefined) out = out.filter((r) => r.agentId === filter.agentId);
+      if (filter.status !== undefined) {
+        const statuses = Array.isArray(filter.status) ?
filter.status : [filter.status]; + out = out.filter((r) => statuses.includes(r.status)); + } + out.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime()); + return filter.limit !== undefined ? out.slice(0, filter.limit) : out; + }), + tryClaim: vi.fn(async (id, claimedBy, claimedAt) => + cas(id, ['pending'], { status: 'claimed', claimedBy, claimedAt }), + ), + markRunning: vi.fn(async (id, at) => + cas(id, ['claimed', 'running'], { status: 'running', streamStartedAt: at }), + ), + markCompleted: vi.fn(async (id, body, at) => + cas(id, ['pending', 'claimed', 'running'], { + status: 'completed', + responseBody: (body ?? null) as InferenceTask['responseBody'], + completedAt: at, + }), + ), + markError: vi.fn(async (id, errorMessage, at) => + cas(id, ['pending', 'claimed', 'running'], { status: 'error', errorMessage, completedAt: at }), + ), + markCancelled: vi.fn(async (id, at) => + cas(id, ['pending', 'claimed', 'running'], { status: 'cancelled', completedAt: at }), + ), + revertToPending: vi.fn(async (id) => + cas(id, ['claimed', 'running'], { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null }), + ), + findStalePending: vi.fn(async (cutoff) => + [...rows.values()].filter((r) => r.status === 'pending' && r.createdAt.getTime() < cutoff.getTime()), + ), + findExpiredTerminal: vi.fn(async (cutoff) => + [...rows.values()].filter((r) => + (r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') + && r.completedAt !== null + && r.completedAt.getTime() < cutoff.getTime(), + ), + ), + deleteMany: vi.fn(async (ids) => { + let n = 0; + for (const id of ids) if (rows.delete(id)) n += 1; + return n; + }), + }; +} + +describe('InferenceTaskService — state machine', () => { + it('enqueue creates a pending row with the given pool/llm/model', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t = await svc.enqueue({ + poolName: 'pool-a', + llmName: 'a-1', + model: 'qwen3', + requestBody: { messages: [] }, + streaming: false, + ownerId: 'owner-1', + }); + expect(t.status).toBe('pending'); + expect(t.poolName).toBe('pool-a'); + expect(t.claimedBy).toBeNull(); + }); + + it('tryClaim races: only one of two concurrent claimers gets the row', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t = await svc.enqueue({ + poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o', + }); + // Both workers issue tryClaim at the "same time". The repo's CAS + // serializes them — first claim wins, second sees a non-pending + // row and returns null. 
+ const [a, b] = await Promise.all([svc.tryClaim(t.id, 'sess-A'), svc.tryClaim(t.id, 'sess-B')]); + const winners = [a, b].filter((r) => r !== null); + const losers = [a, b].filter((r) => r === null); + expect(winners).toHaveLength(1); + expect(losers).toHaveLength(1); + expect(winners[0]!.status).toBe('claimed'); + expect(['sess-A', 'sess-B']).toContain(winners[0]!.claimedBy); + }); + + it('complete after claim transitions claimed → completed and stores responseBody', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + await svc.tryClaim(t.id, 'sess-A'); + const done = await svc.complete(t.id, { choices: [{ message: { content: 'hi' } }] }); + expect(done?.status).toBe('completed'); + expect(done?.responseBody).toEqual({ choices: [{ message: { content: 'hi' } }] }); + expect(done?.completedAt).not.toBeNull(); + }); + + it('refuses double-complete (idempotent terminal)', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + const first = await svc.complete(t.id, { ok: 1 }); + expect(first?.status).toBe('completed'); + // Second worker tries to complete the same task — CAS rejects because + // the row is no longer in a non-terminal state. + const second = await svc.complete(t.id, { ok: 2 }); + expect(second).toBeNull(); + // First completion's body is preserved. + const reread = await svc.findById(t.id); + expect(reread?.responseBody).toEqual({ ok: 1 }); + }); + + it('revertHeldBy reverts every claimed/running row owned by a session and leaves terminals alone', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t1 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + const t2 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + const t3 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + await svc.tryClaim(t1.id, 'sess-A'); + await svc.tryClaim(t2.id, 'sess-A'); + await svc.markRunning(t2.id); + await svc.tryClaim(t3.id, 'sess-A'); + await svc.complete(t3.id, { ok: 1 }); // t3 finished before disconnect + + const reverted = await svc.revertHeldBy('sess-A'); + expect(reverted.map((r) => r.id).sort()).toEqual([t1.id, t2.id].sort()); + expect((await svc.findById(t1.id))?.status).toBe('pending'); + expect((await svc.findById(t2.id))?.status).toBe('pending'); + // t3 stayed completed — terminal rows are not reverted on disconnect. + expect((await svc.findById(t3.id))?.status).toBe('completed'); + }); + + it('cancel from pending records cancelled and emits terminal event', async () => { + const repo = mockRepo(); + const svc = new InferenceTaskService(repo); + const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + const cancelled = await svc.cancel(t.id); + expect(cancelled?.status).toBe('cancelled'); + // A subsequent complete must fail (cancelled is terminal). 
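+    // (What rejects it is the repo-level CAS guard: markCompleted's WHERE
+    // only matches status IN ('pending','claimed','running'), so a
+    // cancelled row matches nothing and the service returns null.)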
+    const result = await svc.complete(t.id, { ok: 1 });
+    expect(result).toBeNull();
+  });
+});
+
+describe('InferenceTaskService — waitFor signals', () => {
+  it('resolves immediately when the row is already terminal at subscribe time', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
+    await svc.complete(t.id, { ok: 1 });
+    const waiter = svc.waitFor(t.id, 1_000);
+    const final = await waiter.done;
+    expect(final.status).toBe('completed');
+    expect(final.responseBody).toEqual({ ok: 1 });
+  });
+
+  it('wakes a waiter on complete event without polling the DB', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
+    const waiter = svc.waitFor(t.id, 5_000);
+    // Fire the complete on a short timer so the waiter is already
+    // subscribed to the terminal event by the time it lands.
+    setTimeout(() => { void svc.complete(t.id, { ok: 1 }); }, 10);
+    const final = await waiter.done;
+    expect(final.status).toBe('completed');
+  });
+
+  it('wakes the chunks generator on pushChunk and ends on terminal', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: true, ownerId: 'o' });
+    const waiter = svc.waitFor(t.id, 5_000);
+
+    setTimeout(() => {
+      svc.pushChunk(t.id, { data: 'hello ' });
+      svc.pushChunk(t.id, { data: 'world' });
+      void svc.complete(t.id, { ok: 1 });
+    }, 10);
+
+    const seen: string[] = [];
+    for await (const c of waiter.chunks) {
+      seen.push(c.data);
+    }
+    expect(seen).toEqual(['hello ', 'world']);
+    const final = await waiter.done;
+    expect(final.status).toBe('completed');
+  });
+
+  it('throws on cancellation with a clear error message', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
+    const waiter = svc.waitFor(t.id, 5_000);
+    setTimeout(() => { void svc.cancel(t.id); }, 10);
+    await expect(waiter.done).rejects.toThrow(/cancelled/i);
+  });
+
+  it('throws on error and surfaces the worker\'s errorMessage', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
+    const waiter = svc.waitFor(t.id, 5_000);
+    setTimeout(() => { void svc.fail(t.id, 'upstream 500'); }, 10);
+    await expect(waiter.done).rejects.toThrow(/upstream 500/);
+  });
+
+  it('times out when no terminal event arrives within the deadline', async () => {
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo);
+    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
+    const waiter = svc.waitFor(t.id, 30);
+    await expect(waiter.done).rejects.toThrow(/timed out/);
+  });
+});
+
+describe('InferenceTaskService — gcSweep', () => {
+  it('flips stale pending rows to error AND deletes expired terminal rows', async () => {
+    const now = new Date('2026-04-28T00:00:00Z');
+    const fixedClock = (): Date => now;
+    const repo = mockRepo();
+    const svc = new InferenceTaskService(repo, fixedClock);
+
+    //
90 min old pending — past the 1h pendingTimeout cutoff → error. + const stale = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + // Backdate via direct fixture mutation — easier than wiring a clock through enqueue. + const staleRow = (await svc.findById(stale.id))!; + (staleRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 90 * 60 * 1000); + + // 30 min old pending — within window, should not be touched. + const fresh = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + const freshRow = (await svc.findById(fresh.id))!; + (freshRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 30 * 60 * 1000); + + // 8 day-old completed — past the 7d retention → delete. + const old = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' }); + await svc.complete(old.id, { ok: 1 }); + const oldRow = (await svc.findById(old.id))!; + (oldRow as { completedAt: Date }).completedAt = new Date(now.getTime() - 8 * 24 * 60 * 60 * 1000); + + const result = await svc.gcSweep({ + pendingTimeoutMs: 60 * 60 * 1000, // 1h + terminalRetentionMs: 7 * 24 * 60 * 60 * 1000, // 7d + }); + expect(result.erroredOut).toBe(1); + expect(result.deleted).toBe(1); + + // Stale pending was flipped to error. + const staleAfter = await svc.findById(stale.id); + expect(staleAfter?.status).toBe('error'); + expect(staleAfter?.errorMessage).toMatch(/expired in pending/); + // Old completed is gone. + expect(await svc.findById(old.id)).toBeNull(); + // Fresh pending untouched. + expect((await svc.findById(fresh.id))?.status).toBe('pending'); + }); +});