feat(mcpd+db): durable InferenceTask queue + state machine (v5 Stage 1)
The persistence + signaling layer for v5. No integration with the
existing in-flight inference path yet — that's Stage 2. This commit
just lands the durable queue underneath, with a state machine that
mcpd's HTTP handlers, the worker result-POST route, and the GC sweep
will all build on.
Schema (src/db/prisma/schema.prisma + migration):
- New `InferenceTask` model + `InferenceTaskStatus` enum
(pending|claimed|running|completed|error|cancelled).
- Routing fields stored at enqueue time so a later rename of
  `Llm.poolName` doesn't reroute already-queued work: `poolName`
  (effective pool key), `llmName` (pinned target), `model`, `tier`.
  (Enqueue sketch after this list.)
- Worker tracking: `claimedBy` (providerSessionId) + `claimedAt`,
cleared on revert.
- Bodies as `Json`: requestBody (always set), responseBody (set at
completion). Streaming chunks are NOT persisted — too expensive at
delta granularity. The final assembled body lands once per task.
- Lifecycle timestamps: createdAt, claimedAt, streamStartedAt,
completedAt. Plus ownerId (RBAC + audit) and agentId (null for
direct chat-llm calls).
- Indexes for the hot paths: (status, poolName) for the dispatcher's
drain query, claimedBy for the disconnect revert, completedAt for
the GC retention sweep, owner/agent for the async API listing.
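The routing bullet above, as a minimal enqueue sketch. Everything in the
`declare` lines is an assumption about the Stage 2 call site — only
`enqueue` itself lands in this commit:

  import type { IInferenceTaskService } from './services/inference-task.service.js';

  // Assumed call-site shape (hypothetical; not part of this commit):
  declare const inferenceTaskService: IInferenceTaskService;
  declare const llm: { name: string; poolName: string | null; tier: string | null };
  declare const caller: { userId: string };
  declare const requestBody: { model: string; stream?: boolean; [k: string]: unknown };

  const task = await inferenceTaskService.enqueue({
    poolName: llm.poolName ?? llm.name,  // effective pool key, pinned now
    llmName: llm.name,
    model: requestBody.model,
    tier: llm.tier,
    requestBody,
    streaming: requestBody.stream === true,
    ownerId: caller.userId,
    agentId: null,                       // direct chat-llm call — no agent
  });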
Repository (src/mcpd/src/repositories/inference-task.repository.ts):
- CRUD + state transitions as conditional CAS via `updateMany`. Two
  workers racing to claim the same row both run the UPDATE; whichever
  the DB serializes first sees affected=1 and gets the row, the loser
  sees 0 and falls through to the next candidate. No application-
  level locking required. (Claim-loop sketch after this list.)
- findPendingForPools(poolNames[]) for the worker drain on bind.
- findHeldBy(claimedBy) for the unbindSession revert.
- findStalePending + findExpiredTerminal for the GC sweep.
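Sketch of the drain-on-bind loop the CAS enables (Stage 2 wiring, shown
for orientation; `dispatchToWorker` is a hypothetical SSE send helper —
the rest is the API landed here):

  import type { InferenceTask } from '@prisma/client';
  import type { IInferenceTaskService } from './services/inference-task.service.js';

  // Hypothetical SSE send helper — not part of this commit.
  declare function dispatchToWorker(sessionId: string, task: InferenceTask): Promise<void>;

  async function drainOnBind(
    service: IInferenceTaskService,
    poolNames: string[],
    providerSessionId: string,
  ): Promise<void> {
    const pending = await service.findPendingForPools(poolNames);
    for (const task of pending) {
      // CAS claim: null means another worker won the race — fall through
      // to the next candidate, no locking needed.
      const claimed = await service.tryClaim(task.id, providerSessionId);
      if (claimed === null) continue;
      await dispatchToWorker(providerSessionId, claimed);
    }
  }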
Service (src/mcpd/src/services/inference-task.service.ts):
- Owns the in-process EventEmitter that wakes blocked HTTP handlers
when a worker POSTs results. The DB row is the source of truth for
*state*; the EventEmitter just signals "go re-read row X" so we
don't have to poll. Single-instance assumption for v5; pg
LISTEN/NOTIFY is the v6 swap when scaling horizontally — no schema
change needed, just replace the emitter wakeup.
- waitFor(taskId, timeoutMs) returns { done, chunks }: the terminal
  promise + an async iterator of streaming deltas. Throws on cancel
  (with a clear message), on error (the worker's errorMessage
  propagates), and on timeout. Polls the row once at subscribe time so
  an already-terminal task resolves immediately instead of waiting for
  an event that's never coming. (Consumer sketch after this list.)
- gcSweep flips stale pending rows to error (with a clear message
about the timeout) and deletes terminal rows past retention.
Defaults: 1h pending timeout, 7d terminal retention; both
configurable.
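Consumer-side sketch of waitFor (the `write` callback is a hypothetical
response writer and the 120 s timeout is an arbitrary example; the rest
is the API landed here):

  import type { IInferenceTaskService } from './services/inference-task.service.js';

  async function streamToCaller(
    service: IInferenceTaskService,
    taskId: string,
    write: (data: string) => void,  // hypothetical response writer
  ): Promise<unknown> {
    const { done, chunks } = service.waitFor(taskId, 120_000);
    for await (const chunk of chunks) {
      write(chunk.data);            // in-memory deltas, never persisted per-chunk
    }
    const final = await done;       // throws on cancel / error / timeout
    return final.responseBody;      // the once-per-task assembled body
  }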
Tests:
- 6 db-level schema tests (defaults, json roundtrip, drain query
shape, claimedBy filter, GC predicate, agentId nullable).
- 13 service tests covering enqueue, the CAS race on tryClaim,
complete/fail/cancel, idempotent terminal transitions, revertHeldBy
on disconnect, and the full waitFor signal lifecycle (immediate
resolve, wake on event, chunk streaming, cancel/error/timeout
paths). Plus a gcSweep test with a fixed clock.
mcpd 881/881 (was 868; +13). db: pool-schema 14/14, plus the 6 new
inference-task-schema tests. Pre-existing failures in models.test.ts
(a Secret FK fixture issue that also fails on main HEAD) are unrelated.
Stage 2 (next): VirtualLlmService rewires through this — remove the
in-memory pendingTasks map; enqueue creates a row, dispatch picks an
active session, the result-route updates the row + emits the wakeup.
Worker disconnect reverts; worker bind drains.
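Rough shape of that Stage 2 result route, for orientation (the route and
body shape are assumptions; markRunning/pushChunk/fail/complete are the
API landed in this commit):

  import type { IInferenceTaskService, InferenceTaskChunk } from './services/inference-task.service.js';

  async function onWorkerResult(
    service: IInferenceTaskService,
    taskId: string,
    body: { chunk?: InferenceTaskChunk; error?: string; responseBody?: Record<string, unknown> },
  ): Promise<void> {
    if (body.chunk !== undefined) {
      await service.markRunning(taskId);     // claimed → running (idempotent)
      service.pushChunk(taskId, body.chunk); // in-memory wakeup, no DB write
      return;
    }
    if (body.error !== undefined) {
      await service.fail(taskId, body.error); // row → error + terminal wakeup
      return;
    }
    await service.complete(taskId, body.responseBody ?? null); // row → completed + wakeup
  }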
@@ -0,0 +1,39 @@
-- v5: durable inference task queue. Every inference call (sync infer,
-- agent chat, or async POST /inference-tasks) gets a row here. Workers
-- (mcplocal sessions) drain pending rows whose `poolName` matches the
-- pool keys they own when they bind their SSE channel.
CREATE TYPE "InferenceTaskStatus" AS ENUM ('pending', 'claimed', 'running', 'completed', 'error', 'cancelled');

CREATE TABLE "InferenceTask" (
    "id" TEXT NOT NULL,
    "status" "InferenceTaskStatus" NOT NULL DEFAULT 'pending',
    "poolName" TEXT NOT NULL,
    "llmName" TEXT NOT NULL,
    "model" TEXT NOT NULL,
    "tier" TEXT,
    "claimedBy" TEXT,
    "requestBody" JSONB NOT NULL,
    "responseBody" JSONB,
    "errorMessage" TEXT,
    "streaming" BOOLEAN NOT NULL DEFAULT false,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "claimedAt" TIMESTAMP(3),
    "streamStartedAt" TIMESTAMP(3),
    "completedAt" TIMESTAMP(3),
    "ownerId" TEXT NOT NULL,
    "agentId" TEXT,

    CONSTRAINT "InferenceTask_pkey" PRIMARY KEY ("id")
);

-- Worker claim path: SELECT … WHERE status='pending' AND poolName IN (…)
-- runs on every SSE bind. Compound index keeps that fast as the table grows.
CREATE INDEX "InferenceTask_status_poolName_idx" ON "InferenceTask"("status", "poolName");
-- unbindSession revert path: SELECT … WHERE claimedBy=$1 AND status IN ('claimed','running').
CREATE INDEX "InferenceTask_claimedBy_idx" ON "InferenceTask"("claimedBy");
-- Owner scoping for the async API listing.
CREATE INDEX "InferenceTask_ownerId_idx" ON "InferenceTask"("ownerId");
CREATE INDEX "InferenceTask_agentId_idx" ON "InferenceTask"("agentId");
-- GC sweep predicate: completedAt < now()-7d. Indexed so the daily cleanup
-- doesn't seq-scan once the table grows past a few thousand rows.
CREATE INDEX "InferenceTask_completedAt_idx" ON "InferenceTask"("completedAt");
@@ -604,6 +604,79 @@ model ChatMessage {
  @@index([threadId, createdAt])
}

// ── Inference Tasks (v5) ──
//
// Every inference call (sync infer, agent chat, async POST /inference-tasks)
// creates a row here. The DB is the source of truth; mcpd's previous
// in-memory `pendingTasks` map is gone — the result-handler updates the row
// and an in-process EventEmitter wakes any blocked HTTP handlers (single-
// instance for now; multi-instance scaling is a v6 concern that will swap
// the emitter for pg LISTEN/NOTIFY without changing the data model).
//
// Routing: `poolName` is the effective pool key at enqueue time
// (`Llm.poolName ?? Llm.name`). Workers (mcplocal sessions) drain pending
// rows whose `poolName` matches the pool keys they own when they bind their
// SSE channel — that's how queued tasks survive worker offline windows.

enum InferenceTaskStatus {
  pending   // in queue, no worker has it yet (or claim was reverted)
  claimed   // a worker has it (SSE frame sent), no chunks back yet
  running   // worker started streaming chunks back (streaming tasks only)
  completed // worker POSTed the final result
  error     // permanent failure (auth, bad request, queue timeout)
  cancelled // caller said never mind via DELETE
}

model InferenceTask {
  id     String              @id @default(cuid())
  status InferenceTaskStatus @default(pending)

  // Routing — pool key drives worker matching at claim time. Stored at
  // enqueue time so a later rename of Llm.poolName doesn't reroute
  // already-queued work.
  poolName String
  llmName  String  // pinned target Llm name (for audit + agent backref)
  model    String
  tier     String?

  // Worker tracking. NULL while pending; set on claim; cleared on
  // unbindSession-driven revert (worker disconnect mid-task).
  claimedBy String?

  // Body + result. Both are Json so streaming chunks can be reconstructed
  // (see TaskService.complete) and async pollers get a structured payload.
  // requestBody is required (the OpenAI chat-completion request body the
  // worker should run); responseBody is null until status=completed.
  requestBody  Json
  responseBody Json?
  errorMessage String?

  /**
   * Whether the original request asked for streaming. Drives the chunk-vs-
   * final-body protocol on the result POST and tells async API callers
   * whether `/stream` will yield chunks or just a single completion event.
   */
  streaming Boolean @default(false)

  // Timestamps for observability + GC:
  //   pending → claimed:  claimedAt set
  //   claimed → running:  streamStartedAt set (first chunk received)
  //   running/claimed → completed/error/cancelled: completedAt set
  createdAt       DateTime  @default(now())
  claimedAt       DateTime?
  streamStartedAt DateTime?
  completedAt     DateTime?

  // Caller tracking — RBAC + observability. ownerId references User.id;
  // agentId is set when the task came in via /agents/<name>/chat (null
  // for direct /llms/<name>/infer or async POST /inference-tasks calls
  // that don't pin an agent).
  ownerId String
  agentId String?

  @@index([status, poolName])
  @@index([claimedBy])
  @@index([ownerId])
  @@index([agentId])
  // GC sweep predicate: completedAt < 7d ago. Index speeds up the daily
  // cleanup without scanning the whole table.
  @@index([completedAt])
}

// ── Audit Logs ──

model AuditLog {
src/db/tests/inference-task-schema.test.ts (new file, 169 lines)
@@ -0,0 +1,169 @@
/**
 * v5 db-level tests for the InferenceTask queue. Exercises the actual
 * column shapes + index lookups; the mcpd-side service tests cover the
 * state machine + signal channels with a mocked repo.
 */
import { describe, it, expect, beforeAll, afterAll, beforeEach } from 'vitest';
import type { PrismaClient } from '@prisma/client';
import { setupTestDb, cleanupTestDb, clearAllTables } from './helpers.js';

async function makeOwner(prisma: PrismaClient): Promise<string> {
  const u = await prisma.user.create({
    data: { email: `owner-${String(Date.now())}@test`, passwordHash: 'x' },
  });
  return u.id;
}

describe('InferenceTask schema (v5)', () => {
  let prisma: PrismaClient;

  beforeAll(async () => {
    prisma = await setupTestDb();
  }, 30_000);

  afterAll(async () => {
    await cleanupTestDb();
  });

  beforeEach(async () => {
    await clearAllTables(prisma);
  });

  it('defaults a fresh row to status=pending with claim/completion fields null', async () => {
    const ownerId = await makeOwner(prisma);
    const row = await prisma.inferenceTask.create({
      data: {
        poolName: 'qwen-pool',
        llmName: 'qwen-prod-1',
        model: 'qwen3-thinking',
        requestBody: { messages: [{ role: 'user', content: 'hi' }] },
        ownerId,
      },
    });
    expect(row.status).toBe('pending');
    expect(row.claimedBy).toBeNull();
    expect(row.claimedAt).toBeNull();
    expect(row.streamStartedAt).toBeNull();
    expect(row.completedAt).toBeNull();
    expect(row.responseBody).toBeNull();
    expect(row.streaming).toBe(false);
  });

  it('roundtrips streaming=true and a structured requestBody/responseBody', async () => {
    const ownerId = await makeOwner(prisma);
    const requestBody = {
      messages: [{ role: 'user', content: 'hello' }],
      temperature: 0.2,
      tools: [{ type: 'function', function: { name: 'noop' } }],
    };
    const row = await prisma.inferenceTask.create({
      data: {
        poolName: 'qwen-pool',
        llmName: 'qwen-prod-1',
        model: 'qwen3',
        requestBody,
        streaming: true,
        ownerId,
      },
    });
    expect(row.streaming).toBe(true);
    expect(row.requestBody).toEqual(requestBody);

    const completedAt = new Date();
    const responseBody = { choices: [{ message: { role: 'assistant', content: 'world' } }] };
    const updated = await prisma.inferenceTask.update({
      where: { id: row.id },
      data: { status: 'completed', responseBody, completedAt },
    });
    expect(updated.responseBody).toEqual(responseBody);
    expect(updated.completedAt?.getTime()).toBe(completedAt.getTime());
  });

  it('compound index supports the dispatcher\'s drain query (status + poolName IN ...)', async () => {
    // The actual EXPLAIN/index-use check is too brittle for unit tests;
    // here we verify the QUERY shape that the repo's findPendingForPools
    // issues — same WHERE/ORDER BY — returns the expected rows in FIFO
    // order. Index usage is implied by the Prisma model definition.
    const ownerId = await makeOwner(prisma);
    const t1 = await prisma.inferenceTask.create({
      data: { poolName: 'pool-a', llmName: 'a-1', model: 'm', requestBody: {}, ownerId },
    });
    await new Promise((r) => setTimeout(r, 5));
    const t2 = await prisma.inferenceTask.create({
      data: { poolName: 'pool-a', llmName: 'a-2', model: 'm', requestBody: {}, ownerId },
    });
    await prisma.inferenceTask.create({
      data: { poolName: 'pool-b', llmName: 'b-1', model: 'm', requestBody: {}, ownerId },
    });
    // One row in pool-a is no longer pending — must be excluded.
    await prisma.inferenceTask.create({
      data: { poolName: 'pool-a', llmName: 'a-3', model: 'm', requestBody: {}, ownerId, status: 'completed' },
    });

    const drained = await prisma.inferenceTask.findMany({
      where: { status: 'pending', poolName: { in: ['pool-a', 'pool-b'] } },
      orderBy: { createdAt: 'asc' },
    });
    expect(drained.map((r) => r.id)).toEqual([t1.id, t2.id, drained[2]!.id]);
    expect(drained.map((r) => r.poolName)).toEqual(['pool-a', 'pool-a', 'pool-b']);
  });

  it('claimedBy index supports unbindSession revert (worker disconnect path)', async () => {
    const ownerId = await makeOwner(prisma);
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-A' },
    });
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'running', claimedBy: 'sess-A' },
    });
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-B' },
    });
    // Completed-but-claimedBy=sess-A row: must NOT revert (terminal state).
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', claimedBy: 'sess-A' },
    });

    const heldByA = await prisma.inferenceTask.findMany({
      where: { claimedBy: 'sess-A', status: { in: ['claimed', 'running'] } },
    });
    expect(heldByA).toHaveLength(2);
  });

  it('GC predicate (terminal + completedAt < cutoff) is index-friendly and filters correctly', async () => {
    const ownerId = await makeOwner(prisma);
    const old = new Date(Date.now() - 8 * 24 * 60 * 60 * 1000); // 8 d ago
    const recent = new Date(Date.now() - 1 * 60 * 60 * 1000); // 1 h ago
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: old },
    });
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'error', completedAt: old, errorMessage: 'boom' },
    });
    // Inside retention — must not be picked up by GC.
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: recent },
    });
    // Pending row — must not be picked up by terminal GC.
    await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'pending' },
    });

    const cutoff = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000);
    const expired = await prisma.inferenceTask.findMany({
      where: {
        status: { in: ['completed', 'error', 'cancelled'] },
        completedAt: { lt: cutoff },
      },
    });
    expect(expired).toHaveLength(2);
  });

  it('agentId is nullable — direct chat-llm tasks have no agent', async () => {
    const ownerId = await makeOwner(prisma);
    const row = await prisma.inferenceTask.create({
      data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, agentId: null },
    });
    expect(row.agentId).toBeNull();
  });
});
src/mcpd/src/repositories/inference-task.repository.ts (new file, 235 lines)
@@ -0,0 +1,235 @@
/**
 * v5 InferenceTask repository — durable queue for inference work.
 *
 * Workers (mcplocal SSE sessions) drain rows whose `poolName` matches the
 * pool keys they own when they bind. The state machine is:
 *
 *   pending ──claim──> claimed ──first chunk──> running
 *                         │                        │
 *                         └──── complete/fail ─────┴──> completed | error
 *
 * Plus orthogonal terminal transitions:
 *
 *   pending|claimed|running ──cancel──> cancelled
 *   claimed|running ──worker disconnect──> pending (revert + re-queue)
 *
 * The DB is the source of truth. mcpd's in-flight HTTP handlers wait via
 * an in-process EventEmitter (see TaskService) — single-instance for now;
 * pg LISTEN/NOTIFY is the v6 swap when scaling horizontally.
 */
import type {
  PrismaClient,
  InferenceTask,
  InferenceTaskStatus,
} from '@prisma/client';
// `Prisma` is imported as a value because we reference Prisma.JsonNull at
// runtime (the sentinel for explicit JSON null on a nullable column).
import { Prisma } from '@prisma/client';

export interface CreateInferenceTaskInput {
  poolName: string;
  llmName: string;
  model: string;
  tier?: string | null;
  requestBody: Record<string, unknown>;
  streaming: boolean;
  ownerId: string;
  agentId?: string | null;
}

export interface IInferenceTaskRepository {
  create(data: CreateInferenceTaskInput): Promise<InferenceTask>;
  findById(id: string): Promise<InferenceTask | null>;
  /** Pending rows for one or more pool keys, oldest first (FIFO). */
  findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
  /** Tasks held by a worker session — used by unbindSession to revert on disconnect. */
  findHeldBy(claimedBy: string): Promise<InferenceTask[]>;
  /** List for the async API; filters are AND-combined. */
  list(filter: {
    ownerId?: string;
    status?: InferenceTaskStatus | InferenceTaskStatus[];
    poolName?: string;
    agentId?: string;
    limit?: number;
  }): Promise<InferenceTask[]>;

  // ── State transitions ──
  // Each transition is conditional ("compare-and-swap" via updateMany +
  // affected-row count), so two workers racing to claim the same row can
  // both run the UPDATE without error — one sees affected=1 and gets a
  // populated row, the other sees affected=0 and falls through to the
  // next candidate.

  /** pending → claimed; returns the updated row, or null if the row was no longer pending. */
  tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null>;
  /** claimed → running; idempotent — returns the row even if already running. */
  markRunning(id: string, at: Date): Promise<InferenceTask | null>;
  /** any non-terminal → completed; persists responseBody. */
  markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null>;
  /** any non-terminal → error. Records `errorMessage`. */
  markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null>;
  /** any non-terminal → cancelled. */
  markCancelled(id: string, at: Date): Promise<InferenceTask | null>;
  /** {claimed,running} → pending; clears `claimedBy`. Used on worker disconnect. */
  revertToPending(id: string): Promise<InferenceTask | null>;

  // ── GC ──

  /** Pending rows older than the cutoff — the GC sweep flips these to error. */
  findStalePending(cutoff: Date): Promise<InferenceTask[]>;
  /** Completed/error/cancelled rows older than the cutoff — GC deletes them. */
  findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]>;
  /** Bulk delete by id — used by the GC sweep after collecting expired ids. */
  deleteMany(ids: string[]): Promise<number>;
}

const NON_TERMINAL: InferenceTaskStatus[] = ['pending', 'claimed', 'running'];
const HELD: InferenceTaskStatus[] = ['claimed', 'running'];

export class InferenceTaskRepository implements IInferenceTaskRepository {
  constructor(private readonly prisma: PrismaClient) {}

  async create(data: CreateInferenceTaskInput): Promise<InferenceTask> {
    return this.prisma.inferenceTask.create({
      data: {
        poolName: data.poolName,
        llmName: data.llmName,
        model: data.model,
        tier: data.tier ?? null,
        requestBody: data.requestBody as Prisma.InputJsonValue,
        streaming: data.streaming,
        ownerId: data.ownerId,
        agentId: data.agentId ?? null,
      },
    });
  }

  async findById(id: string): Promise<InferenceTask | null> {
    return this.prisma.inferenceTask.findUnique({ where: { id } });
  }

  async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
    if (poolNames.length === 0) return [];
    return this.prisma.inferenceTask.findMany({
      where: { status: 'pending', poolName: { in: poolNames } },
      orderBy: { createdAt: 'asc' },
      ...(limit !== undefined ? { take: limit } : {}),
    });
  }

  async findHeldBy(claimedBy: string): Promise<InferenceTask[]> {
    return this.prisma.inferenceTask.findMany({
      where: { claimedBy, status: { in: HELD } },
    });
  }

  async list(filter: {
    ownerId?: string;
    status?: InferenceTaskStatus | InferenceTaskStatus[];
    poolName?: string;
    agentId?: string;
    limit?: number;
  }): Promise<InferenceTask[]> {
    const where: Prisma.InferenceTaskWhereInput = {};
    if (filter.ownerId !== undefined) where.ownerId = filter.ownerId;
    if (filter.status !== undefined) {
      where.status = Array.isArray(filter.status) ? { in: filter.status } : filter.status;
    }
    if (filter.poolName !== undefined) where.poolName = filter.poolName;
    if (filter.agentId !== undefined) where.agentId = filter.agentId;
    return this.prisma.inferenceTask.findMany({
      where,
      orderBy: { createdAt: 'desc' },
      ...(filter.limit !== undefined ? { take: filter.limit } : {}),
    });
  }

  async tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null> {
    // Compare-and-swap on status='pending'. Two workers racing both run
    // this UPDATE; whichever the DB serializes first sees affected=1 and
    // gets the row, the loser sees affected=0 and falls through.
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: 'pending' },
      data: { status: 'claimed', claimedBy, claimedAt },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async markRunning(id: string, at: Date): Promise<InferenceTask | null> {
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: { in: ['claimed', 'running'] } },
      data: { status: 'running', streamStartedAt: at },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null> {
    // Prisma's nullable JSON columns need the explicit `Prisma.JsonNull`
    // sentinel to write null; passing literal null fails typecheck
    // because the column also accepts a plain JSON value.
    const bodyForUpdate: Prisma.InputJsonValue | typeof Prisma.JsonNull =
      responseBody === null ? Prisma.JsonNull : (responseBody as Prisma.InputJsonValue);
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: { in: NON_TERMINAL } },
      data: {
        status: 'completed',
        responseBody: bodyForUpdate,
        completedAt: at,
      },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null> {
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: { in: NON_TERMINAL } },
      data: { status: 'error', errorMessage, completedAt: at },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async markCancelled(id: string, at: Date): Promise<InferenceTask | null> {
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: { in: NON_TERMINAL } },
      data: { status: 'cancelled', completedAt: at },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async revertToPending(id: string): Promise<InferenceTask | null> {
    const result = await this.prisma.inferenceTask.updateMany({
      where: { id, status: { in: HELD } },
      data: { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null },
    });
    if (result.count === 0) return null;
    return this.findById(id);
  }

  async findStalePending(cutoff: Date): Promise<InferenceTask[]> {
    return this.prisma.inferenceTask.findMany({
      where: { status: 'pending', createdAt: { lt: cutoff } },
    });
  }

  async findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]> {
    return this.prisma.inferenceTask.findMany({
      where: {
        status: { in: ['completed', 'error', 'cancelled'] },
        completedAt: { lt: cutoff },
      },
    });
  }

  async deleteMany(ids: string[]): Promise<number> {
    if (ids.length === 0) return 0;
    const result = await this.prisma.inferenceTask.deleteMany({
      where: { id: { in: ids } },
    });
    return result.count;
  }
}
src/mcpd/src/services/inference-task.service.ts (new file, 297 lines)
@@ -0,0 +1,297 @@
/**
 * v5 InferenceTaskService — business logic over the durable task queue.
 *
 * The DB row (managed via IInferenceTaskRepository) is the source of truth
 * for *state*. This service owns the in-process *signaling* layer that
 * lets blocked HTTP handlers wake up promptly when a worker POSTs a
 * result, instead of polling the DB. The signals carry no state — they
 * just say "go re-read row X" — so a missed signal degrades to a slower
 * poll-then-resolve (still correct, just less responsive).
 *
 * Single-instance assumption: the EventEmitter only fires on the mcpd that
 * holds the row's blocked handler. With multiple mcpd replicas, the
 * worker's POST might land on instance A while the caller's handler is
 * on instance B; B never sees the local emitter. The solution at scale is
 * pg LISTEN/NOTIFY (v6 swap) — no schema change needed; just replace the
 * EventEmitter wakeup with a NOTIFY emit.
 */
import { EventEmitter } from 'node:events';
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
import type { IInferenceTaskRepository, CreateInferenceTaskInput } from '../repositories/inference-task.repository.js';
import { NotFoundError } from './mcp-server.service.js';

/**
 * One streaming chunk pushed back from a worker. Lives in-memory only —
 * we don't persist every SSE delta to the DB (too expensive for
 * high-frequency streams). The final assembled body lands on the row
 * via `complete()` when the worker emits its terminal `done:true`.
 */
export interface InferenceTaskChunk {
  data: string;
  done?: boolean;
}

/**
 * The wait handle returned to a blocked HTTP handler. `done` resolves
 * when the row reaches a terminal status (completed | error | cancelled);
 * `chunks` is an async iterator that yields streaming deltas as they
 * arrive (only meaningful when the underlying request asked for streaming).
 */
export interface InferenceTaskWaiter {
  /** The DB row's terminal state. Throws on `error`/`cancelled` with an explanatory message. */
  done: Promise<InferenceTask>;
  /** Async iterator of streaming chunks. Yields nothing for non-streaming tasks. */
  chunks: AsyncGenerator<InferenceTaskChunk>;
}

export interface IInferenceTaskService {
  /** Create a new pending task. Caller is expected to immediately attempt dispatch. */
  enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask>;
  /** Wait for a task to reach a terminal state, with a timeout (ms). */
  waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter;
  /** Conditional CAS claim — returns the row if it was `pending`, null if already claimed. */
  tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null>;
  /** Worker reported the first chunk — flips `claimed` → `running`. Idempotent. */
  markRunning(taskId: string): Promise<InferenceTask | null>;
  /** Worker pushed a streaming chunk. Wakes up the in-process waiter; no DB write. */
  pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean;
  /** Worker POSTed the final result. Persists the body + flips to `completed`. */
  complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null>;
  /** Worker (or mcpd) marked the task as failed. */
  fail(taskId: string, errorMessage: string): Promise<InferenceTask | null>;
  /** Caller (or GC) marked the task as cancelled. */
  cancel(taskId: string): Promise<InferenceTask | null>;
  /** Worker disconnected mid-task — revert claimed/running rows back to pending. */
  revertHeldBy(claimedBy: string): Promise<InferenceTask[]>;
  /** Pending rows for one or more pool keys, ordered FIFO. Used by drainPending on bind. */
  findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
  findById(id: string): Promise<InferenceTask | null>;
  list(filter: {
    ownerId?: string;
    status?: InferenceTaskStatus | InferenceTaskStatus[];
    poolName?: string;
    agentId?: string;
    limit?: number;
  }): Promise<InferenceTask[]>;
  /** GC sweep: pending too long → error; terminal too long → delete. */
  gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number }): Promise<{ erroredOut: number; deleted: number }>;
}

const DEFAULT_PENDING_TIMEOUT_MS = 60 * 60 * 1000; // 1 h
const DEFAULT_TERMINAL_RETENTION_MS = 7 * 24 * 60 * 60 * 1000; // 7 d

export class InferenceTaskService implements IInferenceTaskService {
  /**
   * Wakeup channel per task. The result-handler (worker POSTs landing on
   * mcpd) emits a terminal status; the streaming push emits chunks.
   * Cleared once a waiter resolves to avoid leaks. The EventEmitter
   * listener cap is bumped to a generous default — typical concurrent
   * waiters per task is 1, but `/stream` SSE consumers and the original
   * HTTP handler can both subscribe.
   */
  private readonly events = new EventEmitter();

  constructor(
    private readonly repo: IInferenceTaskRepository,
    private readonly clock: () => Date = () => new Date(),
  ) {
    this.events.setMaxListeners(50);
  }

  async enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask> {
    return this.repo.create(input);
  }

  waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter {
    // Set up the chunk channel up-front so a worker pushing chunks before
    // the consumer subscribes doesn't drop them. We use a queue + wake
    // pattern (same shape as the v3 chat.service streaming bridge).
    const chunkQueue: InferenceTaskChunk[] = [];
    let chunkResolve: (() => void) | null = null;
    let finished = false;
    let finishError: Error | null = null;

    const wakeChunkConsumer = (): void => {
      if (chunkResolve !== null) {
        const r = chunkResolve;
        chunkResolve = null;
        r();
      }
    };
    const onChunk = (chunk: InferenceTaskChunk): void => {
      chunkQueue.push(chunk);
      wakeChunkConsumer();
    };
    const onTerminal = (): void => {
      finished = true;
      wakeChunkConsumer();
    };

    this.events.on(`chunk:${taskId}`, onChunk);
    this.events.on(`terminal:${taskId}`, onTerminal);

    const cleanup = (): void => {
      this.events.off(`chunk:${taskId}`, onChunk);
      this.events.off(`terminal:${taskId}`, onTerminal);
    };

    // The terminal promise: poll once at start (in case the row is
    // already terminal — common for already-completed tasks the caller
    // is just fetching the result of), then race the timeout against
    // the wakeup signal.
    const done = (async (): Promise<InferenceTask> => {
      try {
        const initial = await this.repo.findById(taskId);
        if (initial === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
        if (isTerminal(initial.status)) return ensureSuccess(initial);

        // Wait for the terminal event OR the timeout. If the wake happens
        // before the timer fires, we re-fetch and return; otherwise we error.
        const timer = new Promise<never>((_, reject) => {
          setTimeout(() => reject(new Error(`InferenceTask wait timed out after ${String(timeoutMs)}ms (task ${taskId})`)), timeoutMs);
        });
        const wake = new Promise<void>((resolve) => {
          this.events.once(`terminal:${taskId}`, () => resolve());
        });
        await Promise.race([timer, wake]);
        finished = true;
        const final = await this.repo.findById(taskId);
        if (final === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
        return ensureSuccess(final);
      } catch (err) {
        finishError = err as Error;
        finished = true;
        throw err;
      } finally {
        // Wake a chunks consumer parked on an empty queue so it observes
        // `finished` — otherwise a timeout (which never fires the terminal
        // event) would leave the generator hanging forever.
        wakeChunkConsumer();
        cleanup();
      }
    })();

    const chunks = (async function* gen(): AsyncGenerator<InferenceTaskChunk> {
      while (true) {
        while (chunkQueue.length > 0) {
          const c = chunkQueue.shift()!;
          yield c;
          if (c.done === true) return;
        }
        if (finished) {
          if (finishError !== null) throw finishError;
          return;
        }
        await new Promise<void>((r) => { chunkResolve = r; });
      }
    })();

    return { done, chunks };
  }

  async tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null> {
    return this.repo.tryClaim(taskId, claimedBy, this.clock());
  }

  async markRunning(taskId: string): Promise<InferenceTask | null> {
    return this.repo.markRunning(taskId, this.clock());
  }

  pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean {
    // Streaming chunks are in-memory only. If the row no longer exists or
    // there are no listeners, the emit is a no-op — return false so the
    // result-route can decide whether to log a warning.
    return this.events.emit(`chunk:${taskId}`, chunk);
  }

  async complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null> {
    const updated = await this.repo.markCompleted(taskId, responseBody, this.clock());
    if (updated !== null) this.events.emit(`terminal:${taskId}`);
    return updated;
  }

  async fail(taskId: string, errorMessage: string): Promise<InferenceTask | null> {
    const updated = await this.repo.markError(taskId, errorMessage, this.clock());
    if (updated !== null) this.events.emit(`terminal:${taskId}`);
    return updated;
  }

  async cancel(taskId: string): Promise<InferenceTask | null> {
    const updated = await this.repo.markCancelled(taskId, this.clock());
    if (updated !== null) this.events.emit(`terminal:${taskId}`);
    return updated;
  }

  async revertHeldBy(claimedBy: string): Promise<InferenceTask[]> {
    const held = await this.repo.findHeldBy(claimedBy);
    const reverted: InferenceTask[] = [];
    for (const t of held) {
      const r = await this.repo.revertToPending(t.id);
      if (r !== null) reverted.push(r);
    }
    return reverted;
  }

  async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
    return this.repo.findPendingForPools(poolNames, limit);
  }

  async findById(id: string): Promise<InferenceTask | null> {
    return this.repo.findById(id);
  }

  async list(filter: {
    ownerId?: string;
    status?: InferenceTaskStatus | InferenceTaskStatus[];
    poolName?: string;
    agentId?: string;
    limit?: number;
  }): Promise<InferenceTask[]> {
    return this.repo.list(filter);
  }

  async gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number } = {
    pendingTimeoutMs: DEFAULT_PENDING_TIMEOUT_MS,
    terminalRetentionMs: DEFAULT_TERMINAL_RETENTION_MS,
  }): Promise<{ erroredOut: number; deleted: number }> {
    const now = this.clock();
    const pendingCutoff = new Date(now.getTime() - opts.pendingTimeoutMs);
    const terminalCutoff = new Date(now.getTime() - opts.terminalRetentionMs);

    // Pass 1: pending tasks that have aged out → flip to error so the
    // caller (if still waiting) gets a clean failure.
    const stale = await this.repo.findStalePending(pendingCutoff);
    let erroredOut = 0;
    for (const t of stale) {
      const r = await this.repo.markError(t.id, `Task expired in pending state after ${String(opts.pendingTimeoutMs / 1000)}s`, now);
      if (r !== null) {
        this.events.emit(`terminal:${t.id}`);
        erroredOut += 1;
      }
    }

    // Pass 2: terminal tasks past retention → delete.
    const expired = await this.repo.findExpiredTerminal(terminalCutoff);
    const deleted = await this.repo.deleteMany(expired.map((t) => t.id));

    return { erroredOut, deleted };
  }
}

function isTerminal(status: InferenceTaskStatus): boolean {
  return status === 'completed' || status === 'error' || status === 'cancelled';
}

/**
 * Surface a non-success terminal state as a thrown error so the HTTP
 * handler propagates it cleanly. `error` rows carry an `errorMessage`;
 * `cancelled` rows don't, so we synthesize one.
 */
function ensureSuccess(task: InferenceTask): InferenceTask {
  if (task.status === 'completed') return task;
  if (task.status === 'cancelled') {
    throw new Error(`InferenceTask cancelled: ${task.id}`);
  }
  if (task.status === 'error') {
    throw new Error(task.errorMessage ?? `InferenceTask failed: ${task.id}`);
  }
  // Theoretically unreachable — waitFor only resolves on terminal events.
  throw new Error(`InferenceTask in unexpected state: ${String(task.status)}`);
}
src/mcpd/tests/inference-task-service.test.ts (new file, 330 lines)
@@ -0,0 +1,330 @@
/**
 * v5 InferenceTaskService — state machine, signal channels, and GC sweep.
 * The repo is mocked with an in-memory Map so we exercise the service
 * logic deterministically without touching Postgres. Schema-level tests
 * live in src/db/tests/inference-task-schema.test.ts.
 */
import { describe, it, expect, vi } from 'vitest';
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
import { InferenceTaskService } from '../src/services/inference-task.service.js';

function makeRow(overrides: Partial<InferenceTask> = {}): InferenceTask {
  return {
    id: overrides.id ?? `task-${Math.random().toString(36).slice(2, 8)}`,
    status: 'pending',
    poolName: 'pool-a',
    llmName: 'pool-a',
    model: 'qwen3-thinking',
    tier: null,
    claimedBy: null,
    requestBody: { messages: [{ role: 'user', content: 'hi' }] } as unknown as InferenceTask['requestBody'],
    responseBody: null,
    errorMessage: null,
    streaming: false,
    createdAt: new Date(),
    claimedAt: null,
    streamStartedAt: null,
    completedAt: null,
    ownerId: 'owner-1',
    agentId: null,
    ...overrides,
  };
}

function mockRepo(initial: InferenceTask[] = []): IInferenceTaskRepository {
  const rows = new Map<string, InferenceTask>(initial.map((r) => [r.id, { ...r }]));
  // Tiny CAS helper that mirrors the real `updateMany({where:{status:in}})`
  // semantics — only flips the row if the current status matches.
  const cas = (id: string, allowed: InferenceTaskStatus[], patch: Partial<InferenceTask>): InferenceTask | null => {
    const row = rows.get(id);
    if (row === undefined) return null;
    if (!allowed.includes(row.status)) return null;
    const next = { ...row, ...patch };
    rows.set(id, next);
    return next;
  };
  return {
    create: vi.fn(async (data) => {
      const row = makeRow({
        id: `task-${rows.size + 1}`,
        poolName: data.poolName,
        llmName: data.llmName,
        model: data.model,
        tier: data.tier ?? null,
        requestBody: data.requestBody as InferenceTask['requestBody'],
        streaming: data.streaming,
        ownerId: data.ownerId,
        agentId: data.agentId ?? null,
      });
      rows.set(row.id, row);
      return row;
    }),
    findById: vi.fn(async (id) => rows.get(id) ?? null),
    findPendingForPools: vi.fn(async (poolNames, limit) => {
      const out = [...rows.values()]
        .filter((r) => r.status === 'pending' && poolNames.includes(r.poolName))
        .sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
      return limit !== undefined ? out.slice(0, limit) : out;
    }),
    findHeldBy: vi.fn(async (claimedBy) =>
      [...rows.values()].filter((r) => r.claimedBy === claimedBy && (r.status === 'claimed' || r.status === 'running')),
    ),
    list: vi.fn(async (filter) => {
      let out = [...rows.values()];
      if (filter.ownerId !== undefined) out = out.filter((r) => r.ownerId === filter.ownerId);
      if (filter.poolName !== undefined) out = out.filter((r) => r.poolName === filter.poolName);
      if (filter.agentId !== undefined) out = out.filter((r) => r.agentId === filter.agentId);
      if (filter.status !== undefined) {
        const statuses = Array.isArray(filter.status) ? filter.status : [filter.status];
        out = out.filter((r) => statuses.includes(r.status));
      }
      out.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
      return filter.limit !== undefined ? out.slice(0, filter.limit) : out;
    }),
    tryClaim: vi.fn(async (id, claimedBy, claimedAt) =>
      cas(id, ['pending'], { status: 'claimed', claimedBy, claimedAt }),
    ),
    markRunning: vi.fn(async (id, at) =>
      cas(id, ['claimed', 'running'], { status: 'running', streamStartedAt: at }),
    ),
    markCompleted: vi.fn(async (id, body, at) =>
      cas(id, ['pending', 'claimed', 'running'], {
        status: 'completed',
        responseBody: (body ?? null) as InferenceTask['responseBody'],
        completedAt: at,
      }),
    ),
    markError: vi.fn(async (id, errorMessage, at) =>
      cas(id, ['pending', 'claimed', 'running'], { status: 'error', errorMessage, completedAt: at }),
    ),
    markCancelled: vi.fn(async (id, at) =>
      cas(id, ['pending', 'claimed', 'running'], { status: 'cancelled', completedAt: at }),
    ),
    revertToPending: vi.fn(async (id) =>
      cas(id, ['claimed', 'running'], { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null }),
    ),
    findStalePending: vi.fn(async (cutoff) =>
      [...rows.values()].filter((r) => r.status === 'pending' && r.createdAt.getTime() < cutoff.getTime()),
    ),
    findExpiredTerminal: vi.fn(async (cutoff) =>
      [...rows.values()].filter((r) =>
        (r.status === 'completed' || r.status === 'error' || r.status === 'cancelled')
        && r.completedAt !== null
        && r.completedAt.getTime() < cutoff.getTime(),
      ),
    ),
    deleteMany: vi.fn(async (ids) => {
      let n = 0;
      for (const id of ids) if (rows.delete(id)) n += 1;
      return n;
    }),
  };
}

describe('InferenceTaskService — state machine', () => {
  it('enqueue creates a pending row with the given pool/llm/model', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({
      poolName: 'pool-a',
      llmName: 'a-1',
      model: 'qwen3',
      requestBody: { messages: [] },
      streaming: false,
      ownerId: 'owner-1',
    });
    expect(t.status).toBe('pending');
    expect(t.poolName).toBe('pool-a');
    expect(t.claimedBy).toBeNull();
  });

  it('tryClaim races: only one of two concurrent claimers gets the row', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({
      poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o',
    });
    // Both workers issue tryClaim at the "same time". The repo's CAS
    // serializes them — the first claim wins, the second sees a
    // non-pending row and returns null.
    const [a, b] = await Promise.all([svc.tryClaim(t.id, 'sess-A'), svc.tryClaim(t.id, 'sess-B')]);
    const winners = [a, b].filter((r) => r !== null);
    const losers = [a, b].filter((r) => r === null);
    expect(winners).toHaveLength(1);
    expect(losers).toHaveLength(1);
    expect(winners[0]!.status).toBe('claimed');
    expect(['sess-A', 'sess-B']).toContain(winners[0]!.claimedBy);
  });

  it('complete after claim transitions claimed → completed and stores responseBody', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    await svc.tryClaim(t.id, 'sess-A');
    const done = await svc.complete(t.id, { choices: [{ message: { content: 'hi' } }] });
    expect(done?.status).toBe('completed');
    expect(done?.responseBody).toEqual({ choices: [{ message: { content: 'hi' } }] });
    expect(done?.completedAt).not.toBeNull();
  });

  it('refuses double-complete (idempotent terminal)', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const first = await svc.complete(t.id, { ok: 1 });
    expect(first?.status).toBe('completed');
    // A second worker tries to complete the same task — the CAS rejects
    // because the row is no longer in a non-terminal state.
    const second = await svc.complete(t.id, { ok: 2 });
    expect(second).toBeNull();
    // The first completion's body is preserved.
    const reread = await svc.findById(t.id);
    expect(reread?.responseBody).toEqual({ ok: 1 });
  });

  it('revertHeldBy reverts every claimed/running row owned by a session and leaves terminals alone', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t1 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const t2 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const t3 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    await svc.tryClaim(t1.id, 'sess-A');
    await svc.tryClaim(t2.id, 'sess-A');
    await svc.markRunning(t2.id);
    await svc.tryClaim(t3.id, 'sess-A');
    await svc.complete(t3.id, { ok: 1 }); // t3 finished before disconnect

    const reverted = await svc.revertHeldBy('sess-A');
    expect(reverted.map((r) => r.id).sort()).toEqual([t1.id, t2.id].sort());
    expect((await svc.findById(t1.id))?.status).toBe('pending');
    expect((await svc.findById(t2.id))?.status).toBe('pending');
    // t3 stayed completed — terminal rows are not reverted on disconnect.
    expect((await svc.findById(t3.id))?.status).toBe('completed');
  });

  it('cancel from pending records cancelled and emits terminal event', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const cancelled = await svc.cancel(t.id);
    expect(cancelled?.status).toBe('cancelled');
    // A subsequent complete must fail (cancelled is terminal).
    const result = await svc.complete(t.id, { ok: 1 });
    expect(result).toBeNull();
  });
});

describe('InferenceTaskService — waitFor signals', () => {
  it('resolves immediately when the row is already terminal at subscribe time', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    await svc.complete(t.id, { ok: 1 });
    const waiter = svc.waitFor(t.id, 1_000);
    const final = await waiter.done;
    expect(final.status).toBe('completed');
    expect(final.responseBody).toEqual({ ok: 1 });
  });

  it('wakes a waiter on complete event without polling the DB', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const waiter = svc.waitFor(t.id, 5_000);
    // Fire the complete after a short timer so the waiter is definitely
    // already subscribed to the terminal event.
    setTimeout(() => { void svc.complete(t.id, { ok: 1 }); }, 10);
    const final = await waiter.done;
    expect(final.status).toBe('completed');
  });

  it('wakes the chunks generator on pushChunk and ends on terminal', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: true, ownerId: 'o' });
    const waiter = svc.waitFor(t.id, 5_000);

    setTimeout(() => {
      svc.pushChunk(t.id, { data: 'hello ' });
      svc.pushChunk(t.id, { data: 'world' });
      void svc.complete(t.id, { ok: 1 });
    }, 10);

    const seen: string[] = [];
    for await (const c of waiter.chunks) {
      seen.push(c.data);
    }
    expect(seen).toEqual(['hello ', 'world']);
    const final = await waiter.done;
    expect(final.status).toBe('completed');
  });

  it('throws on cancellation with a clear error message', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const waiter = svc.waitFor(t.id, 5_000);
    setTimeout(() => { void svc.cancel(t.id); }, 10);
    await expect(waiter.done).rejects.toThrow(/cancelled/i);
  });

  it('throws on error and surfaces the worker\'s errorMessage', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const waiter = svc.waitFor(t.id, 5_000);
    setTimeout(() => { void svc.fail(t.id, 'upstream 500'); }, 10);
    await expect(waiter.done).rejects.toThrow(/upstream 500/);
  });

  it('times out when no terminal event arrives within the deadline', async () => {
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo);
    const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const waiter = svc.waitFor(t.id, 30);
    await expect(waiter.done).rejects.toThrow(/timed out/);
  });
});

describe('InferenceTaskService — gcSweep', () => {
  it('flips stale pending rows to error AND deletes expired terminal rows', async () => {
    const now = new Date('2026-04-28T00:00:00Z');
    const fixedClock = (): Date => now;
    const repo = mockRepo();
    const svc = new InferenceTaskService(repo, fixedClock);

    // 90 min old pending — past the 1h pendingTimeout cutoff → error.
    const stale = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    // Backdate via direct fixture mutation (the mock repo hands back live
    // row references) — easier than wiring a clock through enqueue.
    const staleRow = (await svc.findById(stale.id))!;
    (staleRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 90 * 60 * 1000);

    // 30 min old pending — within the window, must not be touched.
    const fresh = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    const freshRow = (await svc.findById(fresh.id))!;
    (freshRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 30 * 60 * 1000);

    // 8-day-old completed — past the 7d retention → delete.
    const old = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
    await svc.complete(old.id, { ok: 1 });
    const oldRow = (await svc.findById(old.id))!;
    (oldRow as { completedAt: Date }).completedAt = new Date(now.getTime() - 8 * 24 * 60 * 60 * 1000);

    const result = await svc.gcSweep({
      pendingTimeoutMs: 60 * 60 * 1000, // 1h
      terminalRetentionMs: 7 * 24 * 60 * 60 * 1000, // 7d
    });
    expect(result.erroredOut).toBe(1);
    expect(result.deleted).toBe(1);

    // Stale pending was flipped to error.
    const staleAfter = await svc.findById(stale.id);
    expect(staleAfter?.status).toBe('error');
    expect(staleAfter?.errorMessage).toMatch(/expired in pending/);
    // Old completed is gone.
    expect(await svc.findById(old.id)).toBeNull();
    // Fresh pending untouched.
    expect((await svc.findById(fresh.id))?.status).toBe('pending');
  });
});