Compare commits
2 Commits
main
...
feat/infer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b18bb6d6b | ||
|
|
ed21ad1b5a |
@@ -0,0 +1,39 @@
|
||||
-- v5: durable inference task queue. Every inference call (sync infer,
|
||||
-- agent chat, or async POST /inference-tasks) gets a row here. Workers
|
||||
-- (mcplocal sessions) drain pending rows whose `poolName` matches the
|
||||
-- pool keys they own when they bind their SSE channel.
|
||||
CREATE TYPE "InferenceTaskStatus" AS ENUM ('pending', 'claimed', 'running', 'completed', 'error', 'cancelled');
|
||||
|
||||
CREATE TABLE "InferenceTask" (
|
||||
"id" TEXT NOT NULL,
|
||||
"status" "InferenceTaskStatus" NOT NULL DEFAULT 'pending',
|
||||
"poolName" TEXT NOT NULL,
|
||||
"llmName" TEXT NOT NULL,
|
||||
"model" TEXT NOT NULL,
|
||||
"tier" TEXT,
|
||||
"claimedBy" TEXT,
|
||||
"requestBody" JSONB NOT NULL,
|
||||
"responseBody" JSONB,
|
||||
"errorMessage" TEXT,
|
||||
"streaming" BOOLEAN NOT NULL DEFAULT false,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"claimedAt" TIMESTAMP(3),
|
||||
"streamStartedAt" TIMESTAMP(3),
|
||||
"completedAt" TIMESTAMP(3),
|
||||
"ownerId" TEXT NOT NULL,
|
||||
"agentId" TEXT,
|
||||
|
||||
CONSTRAINT "InferenceTask_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- Worker claim path: SELECT … WHERE status='pending' AND poolName IN (…)
|
||||
-- runs on every SSE bind. Compound index keeps that fast as the table grows.
|
||||
CREATE INDEX "InferenceTask_status_poolName_idx" ON "InferenceTask"("status", "poolName");
|
||||
-- unbindSession revert path: SELECT … WHERE claimedBy=$1 AND status IN ('claimed','running').
|
||||
CREATE INDEX "InferenceTask_claimedBy_idx" ON "InferenceTask"("claimedBy");
|
||||
-- Owner scoping for the async API listing.
|
||||
CREATE INDEX "InferenceTask_ownerId_idx" ON "InferenceTask"("ownerId");
|
||||
CREATE INDEX "InferenceTask_agentId_idx" ON "InferenceTask"("agentId");
|
||||
-- GC sweep predicate: completedAt < now()-7d. Indexed so the daily cleanup
|
||||
-- doesn't seq-scan once the table grows past a few thousand rows.
|
||||
CREATE INDEX "InferenceTask_completedAt_idx" ON "InferenceTask"("completedAt");
|
||||
@@ -604,6 +604,79 @@ model ChatMessage {
|
||||
@@index([threadId, createdAt])
|
||||
}
|
||||
|
||||
// ── Inference Tasks (v5) ──
|
||||
//
|
||||
// Every inference call (sync infer, agent chat, async POST /inference-tasks)
|
||||
// creates a row here. The DB is the source of truth; mcpd's previous
|
||||
// in-memory `pendingTasks` map is gone — the result-handler updates the row
|
||||
// and an in-process EventEmitter wakes any blocked HTTP handlers (single-
|
||||
// instance for now; multi-instance scaling is a v6 concern that will swap
|
||||
// the emitter for pg LISTEN/NOTIFY without changing the data model).
|
||||
//
|
||||
// Routing: `poolName` is the effective pool key at enqueue time
|
||||
// (`Llm.poolName ?? Llm.name`). Workers (mcplocal sessions) drain pending
|
||||
// rows whose `poolName` matches the pool keys they own when they bind their
|
||||
// SSE channel — that's how queued tasks survive worker offline windows.
|
||||
|
||||
enum InferenceTaskStatus {
|
||||
pending // in queue, no worker has it yet (or claim was reverted)
|
||||
claimed // a worker has it (SSE frame sent), no chunks back yet
|
||||
running // worker started streaming chunks back (streaming tasks only)
|
||||
completed // worker POSTed the final result
|
||||
error // permanent failure (auth, bad request, queue timeout)
|
||||
cancelled // caller said never mind via DELETE
|
||||
}
|
||||
|
||||
model InferenceTask {
|
||||
id String @id @default(cuid())
|
||||
status InferenceTaskStatus @default(pending)
|
||||
// Routing — pool key drives worker matching at claim time. Stored at
|
||||
// enqueue time so a later rename of Llm.poolName doesn't reroute
|
||||
// already-queued work.
|
||||
poolName String
|
||||
llmName String // pinned target Llm name (for audit + agent backref)
|
||||
model String
|
||||
tier String?
|
||||
// Worker tracking. NULL while pending; set on claim; cleared on
|
||||
// unbindSession-driven revert (worker disconnect mid-task).
|
||||
claimedBy String?
|
||||
// Body + result. Both are Json so streaming chunks can be reconstructed
|
||||
// (see TaskService.complete) and async pollers get a structured payload.
|
||||
// requestBody is required (the OpenAI chat-completion request body the
|
||||
// worker should run); responseBody is null until status=completed.
|
||||
requestBody Json
|
||||
responseBody Json?
|
||||
errorMessage String?
|
||||
/**
|
||||
* Whether the original request asked for streaming. Drives the chunk-vs-
|
||||
* final-body protocol on the result POST and tells async API callers
|
||||
* whether `/stream` will yield chunks or just a single completion event.
|
||||
*/
|
||||
streaming Boolean @default(false)
|
||||
// Timestamps for observability + GC:
|
||||
// pending → claimed: claimedAt set
|
||||
// claimed → running: streamStartedAt set (first chunk received)
|
||||
// running/claimed → completed/error/cancelled: completedAt set
|
||||
createdAt DateTime @default(now())
|
||||
claimedAt DateTime?
|
||||
streamStartedAt DateTime?
|
||||
completedAt DateTime?
|
||||
// Caller tracking — RBAC + observability. ownerId references User.id;
|
||||
// agentId is set when the task came in via /agents/<name>/chat (null
|
||||
// for direct /llms/<name>/infer or async POST /inference-tasks calls
|
||||
// that don't pin an agent).
|
||||
ownerId String
|
||||
agentId String?
|
||||
|
||||
@@index([status, poolName])
|
||||
@@index([claimedBy])
|
||||
@@index([ownerId])
|
||||
@@index([agentId])
|
||||
// GC sweep predicate: completedAt < 7d ago. Index speeds up the daily
|
||||
// cleanup without scanning the whole table.
|
||||
@@index([completedAt])
|
||||
}
|
||||
|
||||
// ── Audit Logs ──
|
||||
|
||||
model AuditLog {
|
||||
|
||||
169
src/db/tests/inference-task-schema.test.ts
Normal file
169
src/db/tests/inference-task-schema.test.ts
Normal file
@@ -0,0 +1,169 @@
|
||||
/**
|
||||
* v5 db-level tests for the InferenceTask queue. Exercises the actual
|
||||
* column shapes + index lookups; the mcpd-side service tests cover the
|
||||
* state machine + signal channels with a mocked repo.
|
||||
*/
|
||||
import { describe, it, expect, beforeAll, afterAll, beforeEach } from 'vitest';
|
||||
import type { PrismaClient } from '@prisma/client';
|
||||
import { setupTestDb, cleanupTestDb, clearAllTables } from './helpers.js';
|
||||
|
||||
async function makeOwner(prisma: PrismaClient): Promise<string> {
|
||||
const u = await prisma.user.create({
|
||||
data: { email: `owner-${String(Date.now())}@test`, passwordHash: 'x' },
|
||||
});
|
||||
return u.id;
|
||||
}
|
||||
|
||||
describe('InferenceTask schema (v5)', () => {
|
||||
let prisma: PrismaClient;
|
||||
|
||||
beforeAll(async () => {
|
||||
prisma = await setupTestDb();
|
||||
}, 30_000);
|
||||
|
||||
afterAll(async () => {
|
||||
await cleanupTestDb();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await clearAllTables(prisma);
|
||||
});
|
||||
|
||||
it('defaults a fresh row to status=pending with claim/completion fields null', async () => {
|
||||
const ownerId = await makeOwner(prisma);
|
||||
const row = await prisma.inferenceTask.create({
|
||||
data: {
|
||||
poolName: 'qwen-pool',
|
||||
llmName: 'qwen-prod-1',
|
||||
model: 'qwen3-thinking',
|
||||
requestBody: { messages: [{ role: 'user', content: 'hi' }] },
|
||||
ownerId,
|
||||
},
|
||||
});
|
||||
expect(row.status).toBe('pending');
|
||||
expect(row.claimedBy).toBeNull();
|
||||
expect(row.claimedAt).toBeNull();
|
||||
expect(row.streamStartedAt).toBeNull();
|
||||
expect(row.completedAt).toBeNull();
|
||||
expect(row.responseBody).toBeNull();
|
||||
expect(row.streaming).toBe(false);
|
||||
});
|
||||
|
||||
it('roundtrips streaming=true and a structured requestBody/responseBody', async () => {
|
||||
const ownerId = await makeOwner(prisma);
|
||||
const requestBody = {
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
temperature: 0.2,
|
||||
tools: [{ type: 'function', function: { name: 'noop' } }],
|
||||
};
|
||||
const row = await prisma.inferenceTask.create({
|
||||
data: {
|
||||
poolName: 'qwen-pool',
|
||||
llmName: 'qwen-prod-1',
|
||||
model: 'qwen3',
|
||||
requestBody,
|
||||
streaming: true,
|
||||
ownerId,
|
||||
},
|
||||
});
|
||||
expect(row.streaming).toBe(true);
|
||||
expect(row.requestBody).toEqual(requestBody);
|
||||
|
||||
const completedAt = new Date();
|
||||
const responseBody = { choices: [{ message: { role: 'assistant', content: 'world' } }] };
|
||||
const updated = await prisma.inferenceTask.update({
|
||||
where: { id: row.id },
|
||||
data: { status: 'completed', responseBody, completedAt },
|
||||
});
|
||||
expect(updated.responseBody).toEqual(responseBody);
|
||||
expect(updated.completedAt?.getTime()).toBe(completedAt.getTime());
|
||||
});
|
||||
|
||||
it('compound index supports the dispatcher\'s drain query (status + poolName IN ...)', async () => {
|
||||
// The actual EXPLAIN/index-use check is too brittle for unit tests;
|
||||
// here we verify the QUERY shape that the repo's findPendingForPools
|
||||
// issues — same WHERE/ORDER BY — returns the expected rows in FIFO
|
||||
// order. Index usage is implied by the Prisma model definition.
|
||||
const ownerId = await makeOwner(prisma);
|
||||
const t1 = await prisma.inferenceTask.create({
|
||||
data: { poolName: 'pool-a', llmName: 'a-1', model: 'm', requestBody: {}, ownerId },
|
||||
});
|
||||
await new Promise((r) => setTimeout(r, 5));
|
||||
const t2 = await prisma.inferenceTask.create({
|
||||
data: { poolName: 'pool-a', llmName: 'a-2', model: 'm', requestBody: {}, ownerId },
|
||||
});
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'pool-b', llmName: 'b-1', model: 'm', requestBody: {}, ownerId },
|
||||
});
|
||||
// One row in pool-a is no longer pending — must be excluded.
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'pool-a', llmName: 'a-3', model: 'm', requestBody: {}, ownerId, status: 'completed' },
|
||||
});
|
||||
|
||||
const drained = await prisma.inferenceTask.findMany({
|
||||
where: { status: 'pending', poolName: { in: ['pool-a', 'pool-b'] } },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
});
|
||||
expect(drained.map((r) => r.id)).toEqual([t1.id, t2.id, drained[2]!.id]);
|
||||
expect(drained.map((r) => r.poolName)).toEqual(['pool-a', 'pool-a', 'pool-b']);
|
||||
});
|
||||
|
||||
it('claimedBy index supports unbindSession revert (worker disconnect path)', async () => {
|
||||
const ownerId = await makeOwner(prisma);
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-A' },
|
||||
});
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'running', claimedBy: 'sess-A' },
|
||||
});
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-B' },
|
||||
});
|
||||
// Completed-but-claimedBy=sess-A row: must NOT revert (terminal state).
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', claimedBy: 'sess-A' },
|
||||
});
|
||||
|
||||
const heldByA = await prisma.inferenceTask.findMany({
|
||||
where: { claimedBy: 'sess-A', status: { in: ['claimed', 'running'] } },
|
||||
});
|
||||
expect(heldByA).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('GC predicate (terminal + completedAt < cutoff) is index-friendly and filters correctly', async () => {
|
||||
const ownerId = await makeOwner(prisma);
|
||||
const old = new Date(Date.now() - 8 * 24 * 60 * 60 * 1000); // 8 d ago
|
||||
const recent = new Date(Date.now() - 1 * 60 * 60 * 1000); // 1 h ago
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: old },
|
||||
});
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'error', completedAt: old, errorMessage: 'boom' },
|
||||
});
|
||||
// Inside retention — must not be picked up by GC.
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: recent },
|
||||
});
|
||||
// Pending row — must not be picked up by terminal GC.
|
||||
await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'pending' },
|
||||
});
|
||||
|
||||
const cutoff = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000);
|
||||
const expired = await prisma.inferenceTask.findMany({
|
||||
where: {
|
||||
status: { in: ['completed', 'error', 'cancelled'] },
|
||||
completedAt: { lt: cutoff },
|
||||
},
|
||||
});
|
||||
expect(expired).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('agentId is nullable — direct chat-llm tasks have no agent', async () => {
|
||||
const ownerId = await makeOwner(prisma);
|
||||
const row = await prisma.inferenceTask.create({
|
||||
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, agentId: null },
|
||||
});
|
||||
expect(row.agentId).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -32,6 +32,8 @@ import { SecretBackendRotatorLoop } from './services/secret-backend-rotator-loop
|
||||
import { registerSecretBackendRotateRoutes } from './routes/secret-backend-rotate.js';
|
||||
import { LlmRepository } from './repositories/llm.repository.js';
|
||||
import { LlmService } from './services/llm.service.js';
|
||||
import { InferenceTaskRepository } from './repositories/inference-task.repository.js';
|
||||
import { InferenceTaskService } from './services/inference-task.service.js';
|
||||
import { AgentRepository } from './repositories/agent.repository.js';
|
||||
import { ChatRepository } from './repositories/chat.repository.js';
|
||||
import { AgentService } from './services/agent.service.js';
|
||||
@@ -463,10 +465,17 @@ async function main(): Promise<void> {
|
||||
const personalityRepo = new PersonalityRepository(prisma);
|
||||
const personalityService = new PersonalityService(personalityRepo, agentRepo, promptRepo);
|
||||
const agentService = new AgentService(agentRepo, llmService, projectService, personalityRepo);
|
||||
// v5: durable inference task queue. VirtualLlmService persists infer
|
||||
// tasks here; the result-route updates them; an in-process emitter
|
||||
// wakes blocked HTTP handlers when results land.
|
||||
const inferenceTaskRepo = new InferenceTaskRepository(prisma);
|
||||
const inferenceTaskService = new InferenceTaskService(inferenceTaskRepo);
|
||||
// Virtual-provider state machine (kind=virtual rows for both Llms and
|
||||
// Agents). v3 wires AgentService for heartbeat/disconnect/GC cascade.
|
||||
// v5 wires inferenceTaskService — enqueueInferTask now persists rows,
|
||||
// worker disconnect reverts claimed rows, worker bind drains pending.
|
||||
// The 60-s GC ticker is started below after `app.listen`.
|
||||
const virtualLlmService = new VirtualLlmService(llmRepo, agentService);
|
||||
const virtualLlmService = new VirtualLlmService(llmRepo, agentService, inferenceTaskService);
|
||||
// ChatService needs the proxy + project repo via the ChatToolDispatcher
|
||||
// bridge. The dispatcher's logger references `app.log`, which is not
|
||||
// constructed until further down — `chatService` itself is built right
|
||||
|
||||
235
src/mcpd/src/repositories/inference-task.repository.ts
Normal file
235
src/mcpd/src/repositories/inference-task.repository.ts
Normal file
@@ -0,0 +1,235 @@
|
||||
/**
|
||||
* v5 InferenceTask repository — durable queue for inference work.
|
||||
*
|
||||
* Workers (mcplocal SSE sessions) drain rows whose `poolName` matches the
|
||||
* pool keys they own when they bind. The state machine is:
|
||||
*
|
||||
* pending ──claim──> claimed ──first chunk──> running
|
||||
* \ │
|
||||
* \ ▼
|
||||
* ──complete/fail──> completed | error
|
||||
*
|
||||
* Plus orthogonal terminal transitions:
|
||||
*
|
||||
* pending|claimed|running ──cancel──> cancelled
|
||||
* claimed|running ──worker disconnect──> pending (revert + re-queue)
|
||||
*
|
||||
* The DB is the source of truth. mcpd's in-flight HTTP handlers wait via
|
||||
* an in-process EventEmitter (see TaskService) — single-instance for now;
|
||||
* pg LISTEN/NOTIFY is the v6 swap when scaling horizontally.
|
||||
*/
|
||||
import type {
|
||||
PrismaClient,
|
||||
InferenceTask,
|
||||
InferenceTaskStatus,
|
||||
} from '@prisma/client';
|
||||
// `Prisma` is imported as a value because we reference Prisma.JsonNull at
|
||||
// runtime (the sentinel for explicit JSON null on a nullable column).
|
||||
import { Prisma } from '@prisma/client';
|
||||
|
||||
export interface CreateInferenceTaskInput {
|
||||
poolName: string;
|
||||
llmName: string;
|
||||
model: string;
|
||||
tier?: string | null;
|
||||
requestBody: Record<string, unknown>;
|
||||
streaming: boolean;
|
||||
ownerId: string;
|
||||
agentId?: string | null;
|
||||
}
|
||||
|
||||
export interface IInferenceTaskRepository {
|
||||
create(data: CreateInferenceTaskInput): Promise<InferenceTask>;
|
||||
findById(id: string): Promise<InferenceTask | null>;
|
||||
/** Pending rows for one or more pool keys, oldest first (FIFO). */
|
||||
findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
|
||||
/** Tasks held by a worker session — used by unbindSession to revert on disconnect. */
|
||||
findHeldBy(claimedBy: string): Promise<InferenceTask[]>;
|
||||
/** List for the async API; filters are AND-combined. */
|
||||
list(filter: {
|
||||
ownerId?: string;
|
||||
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||
poolName?: string;
|
||||
agentId?: string;
|
||||
limit?: number;
|
||||
}): Promise<InferenceTask[]>;
|
||||
|
||||
// ── State transitions ──
|
||||
// Each transition is conditional ("compare-and-swap" via updateMany +
|
||||
// affected-row count) so two workers racing to claim the same row both
|
||||
// succeed at the DB level — one gets affected=1 and a populated row,
|
||||
// the other gets affected=0 and we fall through to the next candidate.
|
||||
|
||||
/** pending → claimed; returns the updated row, or null if the row was no longer pending. */
|
||||
tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null>;
|
||||
/** claimed → running; idempotent — returns the row even if already running. */
|
||||
markRunning(id: string, at: Date): Promise<InferenceTask | null>;
|
||||
/** {claimed,running} → completed; only allowed from a non-terminal state. */
|
||||
markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null>;
|
||||
/** any non-terminal → error. Records `errorMessage`. */
|
||||
markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null>;
|
||||
/** any non-terminal → cancelled. */
|
||||
markCancelled(id: string, at: Date): Promise<InferenceTask | null>;
|
||||
/** {claimed,running} → pending; clears `claimedBy`. Used on worker disconnect. */
|
||||
revertToPending(id: string): Promise<InferenceTask | null>;
|
||||
|
||||
// ── GC ──
|
||||
|
||||
/** Pending rows older than the cutoff — the GC sweep flips these to error. */
|
||||
findStalePending(cutoff: Date): Promise<InferenceTask[]>;
|
||||
/** Completed/error/cancelled rows older than the cutoff — GC deletes them. */
|
||||
findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]>;
|
||||
/** Bulk delete by id — used by the GC sweep after collecting expired ids. */
|
||||
deleteMany(ids: string[]): Promise<number>;
|
||||
}
|
||||
|
||||
const NON_TERMINAL: InferenceTaskStatus[] = ['pending', 'claimed', 'running'];
|
||||
const HELD: InferenceTaskStatus[] = ['claimed', 'running'];
|
||||
|
||||
export class InferenceTaskRepository implements IInferenceTaskRepository {
|
||||
constructor(private readonly prisma: PrismaClient) {}
|
||||
|
||||
async create(data: CreateInferenceTaskInput): Promise<InferenceTask> {
|
||||
return this.prisma.inferenceTask.create({
|
||||
data: {
|
||||
poolName: data.poolName,
|
||||
llmName: data.llmName,
|
||||
model: data.model,
|
||||
tier: data.tier ?? null,
|
||||
requestBody: data.requestBody as Prisma.InputJsonValue,
|
||||
streaming: data.streaming,
|
||||
ownerId: data.ownerId,
|
||||
agentId: data.agentId ?? null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async findById(id: string): Promise<InferenceTask | null> {
|
||||
return this.prisma.inferenceTask.findUnique({ where: { id } });
|
||||
}
|
||||
|
||||
async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
|
||||
if (poolNames.length === 0) return [];
|
||||
return this.prisma.inferenceTask.findMany({
|
||||
where: { status: 'pending', poolName: { in: poolNames } },
|
||||
orderBy: { createdAt: 'asc' },
|
||||
...(limit !== undefined ? { take: limit } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
async findHeldBy(claimedBy: string): Promise<InferenceTask[]> {
|
||||
return this.prisma.inferenceTask.findMany({
|
||||
where: { claimedBy, status: { in: HELD } },
|
||||
});
|
||||
}
|
||||
|
||||
async list(filter: {
|
||||
ownerId?: string;
|
||||
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||
poolName?: string;
|
||||
agentId?: string;
|
||||
limit?: number;
|
||||
}): Promise<InferenceTask[]> {
|
||||
const where: Prisma.InferenceTaskWhereInput = {};
|
||||
if (filter.ownerId !== undefined) where.ownerId = filter.ownerId;
|
||||
if (filter.status !== undefined) {
|
||||
where.status = Array.isArray(filter.status) ? { in: filter.status } : filter.status;
|
||||
}
|
||||
if (filter.poolName !== undefined) where.poolName = filter.poolName;
|
||||
if (filter.agentId !== undefined) where.agentId = filter.agentId;
|
||||
return this.prisma.inferenceTask.findMany({
|
||||
where,
|
||||
orderBy: { createdAt: 'desc' },
|
||||
...(filter.limit !== undefined ? { take: filter.limit } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
async tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null> {
|
||||
// Compare-and-swap on status='pending'. Two workers racing both run
|
||||
// this UPDATE; whichever the DB serializes first sees affected=1 and
|
||||
// gets the row, the loser sees affected=0 and falls through.
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: 'pending' },
|
||||
data: { status: 'claimed', claimedBy, claimedAt },
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async markRunning(id: string, at: Date): Promise<InferenceTask | null> {
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: { in: ['claimed', 'running'] } },
|
||||
data: { status: 'running', streamStartedAt: at },
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null> {
|
||||
// Prisma's nullable JSON column writes need an explicit
|
||||
// `Prisma.JsonNull` sentinel for null; passing literal null fails
|
||||
// typecheck because the column also accepts a plain JSON value.
|
||||
const bodyForUpdate: Prisma.InputJsonValue | typeof Prisma.JsonNull =
|
||||
responseBody === null ? Prisma.JsonNull : (responseBody as Prisma.InputJsonValue);
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: { in: NON_TERMINAL } },
|
||||
data: {
|
||||
status: 'completed',
|
||||
responseBody: bodyForUpdate,
|
||||
completedAt: at,
|
||||
},
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null> {
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: { in: NON_TERMINAL } },
|
||||
data: { status: 'error', errorMessage, completedAt: at },
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async markCancelled(id: string, at: Date): Promise<InferenceTask | null> {
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: { in: NON_TERMINAL } },
|
||||
data: { status: 'cancelled', completedAt: at },
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async revertToPending(id: string): Promise<InferenceTask | null> {
|
||||
const result = await this.prisma.inferenceTask.updateMany({
|
||||
where: { id, status: { in: HELD } },
|
||||
data: { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null },
|
||||
});
|
||||
if (result.count === 0) return null;
|
||||
return this.findById(id);
|
||||
}
|
||||
|
||||
async findStalePending(cutoff: Date): Promise<InferenceTask[]> {
|
||||
return this.prisma.inferenceTask.findMany({
|
||||
where: { status: 'pending', createdAt: { lt: cutoff } },
|
||||
});
|
||||
}
|
||||
|
||||
async findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]> {
|
||||
return this.prisma.inferenceTask.findMany({
|
||||
where: {
|
||||
status: { in: ['completed', 'error', 'cancelled'] },
|
||||
completedAt: { lt: cutoff },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async deleteMany(ids: string[]): Promise<number> {
|
||||
if (ids.length === 0) return 0;
|
||||
const result = await this.prisma.inferenceTask.deleteMany({
|
||||
where: { id: { in: ids } },
|
||||
});
|
||||
return result.count;
|
||||
}
|
||||
}
|
||||
@@ -98,8 +98,12 @@ export function registerLlmInferRoutes(
|
||||
return { error: 'virtual LLM dispatch unavailable (server misconfiguration)' };
|
||||
}
|
||||
try {
|
||||
// Direct infer is the v1-v4 sync path: caller pinned to a
|
||||
// specific Llm name and wants a fast 503 if the publisher is
|
||||
// offline. The async durable-queue API (Stage 3) is for callers
|
||||
// that explicitly opt into queueing.
|
||||
if (!streaming) {
|
||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, false);
|
||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, false, { failFast: true });
|
||||
const result = await ref.done;
|
||||
reply.code(result.status);
|
||||
audit(result.status);
|
||||
@@ -113,7 +117,7 @@ export function registerLlmInferRoutes(
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no',
|
||||
});
|
||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, true);
|
||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, true, { failFast: true });
|
||||
const unsubscribe = ref.onChunk((chunk) => writeSseChunk(reply, chunk.data));
|
||||
try {
|
||||
await ref.done;
|
||||
|
||||
@@ -462,10 +462,16 @@ export class ChatService {
|
||||
// iterator. Chunks land on the queue from the SSE relay; the
|
||||
// generator drains them in order. ref.done resolves when the
|
||||
// publisher emits its `[DONE]` marker.
|
||||
//
|
||||
// failFast: true — the chat dispatcher's pool failover loop relies
|
||||
// on a fast "transport error" surfacing so it can iterate to the
|
||||
// next candidate. Without it, a downed pool member would queue the
|
||||
// task and the loop would wait 10 min before trying the next one.
|
||||
const ref = await this.virtualLlms.enqueueInferTask(
|
||||
candidate.llmName,
|
||||
{ ...this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }), stream: true },
|
||||
true,
|
||||
{ failFast: true },
|
||||
);
|
||||
const queue: Array<{ data: string; done?: boolean }> = [];
|
||||
let resolveTick: (() => void) | null = null;
|
||||
@@ -544,6 +550,7 @@ export class ChatService {
|
||||
candidate.llmName,
|
||||
this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }),
|
||||
false,
|
||||
{ failFast: true },
|
||||
);
|
||||
return ref.done;
|
||||
}
|
||||
|
||||
311
src/mcpd/src/services/inference-task.service.ts
Normal file
311
src/mcpd/src/services/inference-task.service.ts
Normal file
@@ -0,0 +1,311 @@
|
||||
/**
|
||||
* v5 InferenceTaskService — business logic over the durable task queue.
|
||||
*
|
||||
* The DB row (managed via IInferenceTaskRepository) is the source of truth
|
||||
* for *state*. This service owns the in-process *signaling* layer that
|
||||
* lets blocked HTTP handlers wake up promptly when a worker POSTs a
|
||||
* result, instead of polling the DB. The signals carry no state — they
|
||||
* just say "go re-read row X" — so a missed signal degrades to a slower
|
||||
* poll-then-resolve (still correct, just less responsive).
|
||||
*
|
||||
* Single-instance assumption: EventEmitter only fires on the mcpd that
|
||||
* holds the row's blocked handler. With multiple mcpd replicas, the
|
||||
* worker's POST might land on instance A while the caller's handler is
|
||||
* on instance B; B never sees the local emitter. Solution at scale is
|
||||
* pg LISTEN/NOTIFY (v6 swap) — no schema change needed; just replace the
|
||||
* EventEmitter wakeup with a NOTIFY emit.
|
||||
*/
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||
import type { IInferenceTaskRepository, CreateInferenceTaskInput } from '../repositories/inference-task.repository.js';
|
||||
import { NotFoundError } from './mcp-server.service.js';
|
||||
|
||||
/**
|
||||
* One streaming chunk pushed back from a worker. Lives in-memory only —
|
||||
* we don't persist every SSE delta to the DB (too expensive for
|
||||
* high-frequency streams). The final assembled body lands on the row
|
||||
* via `complete()` when the worker emits its terminal `done:true`.
|
||||
*/
|
||||
export interface InferenceTaskChunk {
|
||||
data: string;
|
||||
done?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* The wait handle returned to a blocked HTTP handler. `done` resolves
|
||||
* when the row reaches a terminal status (completed | error | cancelled);
|
||||
* `chunks` is an async iterator that yields streaming deltas as they
|
||||
* arrive (only meaningful when the underlying request asked for streaming).
|
||||
*/
|
||||
export interface InferenceTaskWaiter {
|
||||
/** The DB row's terminal state. Throws on `error`/`cancelled` with an explanatory message. */
|
||||
done: Promise<InferenceTask>;
|
||||
/** Async iterator of streaming chunks. Yields nothing for non-streaming tasks. */
|
||||
chunks: AsyncGenerator<InferenceTaskChunk>;
|
||||
}
|
||||
|
||||
export interface IInferenceTaskService {
|
||||
/** Create a new pending task. Caller is expected to immediately attempt dispatch. */
|
||||
enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask>;
|
||||
/** Wait for a task to reach a terminal state, with optional timeout (ms). */
|
||||
waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter;
|
||||
/** Conditional CAS claim — returns the row if `pending`, null if already claimed. */
|
||||
tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null>;
|
||||
/** Worker reported first chunk — flips `claimed` → `running`. Idempotent. */
|
||||
markRunning(taskId: string): Promise<InferenceTask | null>;
|
||||
/** Worker pushed a streaming chunk. Wakes up the in-process waiter; no DB write. */
|
||||
pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean;
|
||||
/**
|
||||
* Subscribe directly to a task's chunk stream as a callback. Returns
|
||||
* an unsubscribe function. Used by VirtualLlmService's legacy
|
||||
* PendingTaskRef.onChunk bridge — the high-level `waitFor` API is
|
||||
* better when you need an async iterator + done promise together.
|
||||
*/
|
||||
subscribeChunksUnsafe(taskId: string, cb: (chunk: InferenceTaskChunk) => void): () => void;
|
||||
/** Worker POSTed the final result. Persists the body + flips to `completed`. */
|
||||
complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null>;
|
||||
/** Worker (or mcpd) marked the task as failed. */
|
||||
fail(taskId: string, errorMessage: string): Promise<InferenceTask | null>;
|
||||
/** Caller (or GC) marked the task as cancelled. */
|
||||
cancel(taskId: string): Promise<InferenceTask | null>;
|
||||
/** Worker disconnected mid-task — revert claimed/running rows back to pending. */
|
||||
revertHeldBy(claimedBy: string): Promise<InferenceTask[]>;
|
||||
/** Pending rows for one or more pool keys, ordered FIFO. Used by drainPending on bind. */
|
||||
findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
|
||||
findById(id: string): Promise<InferenceTask | null>;
|
||||
list(filter: {
|
||||
ownerId?: string;
|
||||
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||
poolName?: string;
|
||||
agentId?: string;
|
||||
limit?: number;
|
||||
}): Promise<InferenceTask[]>;
|
||||
/** GC sweep: pending too long → error; terminal too long → delete. */
|
||||
gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number }): Promise<{ erroredOut: number; deleted: number }>;
|
||||
}
|
||||
|
||||
const DEFAULT_PENDING_TIMEOUT_MS = 60 * 60 * 1000; // 1 h
|
||||
const DEFAULT_TERMINAL_RETENTION_MS = 7 * 24 * 60 * 60 * 1000; // 7 d
|
||||
|
||||
export class InferenceTaskService implements IInferenceTaskService {
|
||||
/**
|
||||
* Wakeup channel per task. The result-handler (worker POSTs landing on
|
||||
* mcpd) emits a terminal status; the streaming push emits chunks.
|
||||
* Cleared once a waiter resolves to avoid leaks. EventEmitter listener
|
||||
* cap is bumped to a generous default — typical concurrent waiters per
|
||||
* task is 1, but `/stream` SSE consumers and the original HTTP handler
|
||||
* can both subscribe.
|
||||
*/
|
||||
private readonly events = new EventEmitter();
|
||||
|
||||
constructor(
|
||||
private readonly repo: IInferenceTaskRepository,
|
||||
private readonly clock: () => Date = () => new Date(),
|
||||
) {
|
||||
this.events.setMaxListeners(50);
|
||||
}
|
||||
|
||||
async enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask> {
|
||||
return this.repo.create(input);
|
||||
}
|
||||
|
||||
waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter {
|
||||
// Set up the chunk channel up-front so a worker pushing chunks before
|
||||
// the consumer subscribes doesn't drop them. We use a queue + wake
|
||||
// pattern (same shape as the v3 chat.service streaming bridge).
|
||||
const chunkQueue: InferenceTaskChunk[] = [];
|
||||
let chunkResolve: (() => void) | null = null;
|
||||
let finished = false;
|
||||
let finishError: Error | null = null;
|
||||
|
||||
const onChunk = (chunk: InferenceTaskChunk): void => {
|
||||
chunkQueue.push(chunk);
|
||||
if (chunkResolve !== null) {
|
||||
const r = chunkResolve;
|
||||
chunkResolve = null;
|
||||
r();
|
||||
}
|
||||
};
|
||||
const onTerminal = (): void => {
|
||||
finished = true;
|
||||
if (chunkResolve !== null) {
|
||||
const r = chunkResolve;
|
||||
chunkResolve = null;
|
||||
r();
|
||||
}
|
||||
};
|
||||
|
||||
this.events.on(`chunk:${taskId}`, onChunk);
|
||||
this.events.on(`terminal:${taskId}`, onTerminal);
|
||||
|
||||
const cleanup = (): void => {
|
||||
this.events.off(`chunk:${taskId}`, onChunk);
|
||||
this.events.off(`terminal:${taskId}`, onTerminal);
|
||||
};
|
||||
|
||||
// The terminal promise: poll once at start (in case the row is
|
||||
// already terminal — common for already-completed tasks the caller
|
||||
// is just fetching the result of), then race the timeout against
|
||||
// the wakeup signal.
|
||||
const done = (async (): Promise<InferenceTask> => {
|
||||
try {
|
||||
const initial = await this.repo.findById(taskId);
|
||||
if (initial === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
|
||||
if (isTerminal(initial.status)) return ensureSuccess(initial);
|
||||
|
||||
// Wait for terminal event OR timeout. If the wake happens before
|
||||
// the timer fires, we re-fetch and return; otherwise we error.
|
||||
const timer = new Promise<never>((_, reject) => {
|
||||
setTimeout(() => reject(new Error(`InferenceTask wait timed out after ${String(timeoutMs)}ms (task ${taskId})`)), timeoutMs);
|
||||
});
|
||||
const wake = new Promise<void>((resolve) => {
|
||||
this.events.once(`terminal:${taskId}`, () => resolve());
|
||||
});
|
||||
await Promise.race([timer, wake]);
|
||||
finished = true;
|
||||
const final = await this.repo.findById(taskId);
|
||||
if (final === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
|
||||
return ensureSuccess(final);
|
||||
} catch (err) {
|
||||
finishError = err as Error;
|
||||
finished = true;
|
||||
throw err;
|
||||
} finally {
|
||||
cleanup();
|
||||
}
|
||||
})();
|
||||
|
||||
const chunks = (async function* gen(): AsyncGenerator<InferenceTaskChunk> {
|
||||
while (true) {
|
||||
while (chunkQueue.length > 0) {
|
||||
const c = chunkQueue.shift()!;
|
||||
yield c;
|
||||
if (c.done === true) return;
|
||||
}
|
||||
if (finished) {
|
||||
if (finishError !== null) throw finishError;
|
||||
return;
|
||||
}
|
||||
await new Promise<void>((r) => { chunkResolve = r; });
|
||||
}
|
||||
})();
|
||||
|
||||
return { done, chunks };
|
||||
}
|
||||
|
||||
async tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null> {
|
||||
return this.repo.tryClaim(taskId, claimedBy, this.clock());
|
||||
}
|
||||
|
||||
async markRunning(taskId: string): Promise<InferenceTask | null> {
|
||||
return this.repo.markRunning(taskId, this.clock());
|
||||
}
|
||||
|
||||
pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean {
|
||||
// Streaming chunks are in-memory only. If the row no longer exists or
|
||||
// there are no listeners, the emit is a no-op — return false so the
|
||||
// result-route can decide whether to log a warning.
|
||||
return this.events.emit(`chunk:${taskId}`, chunk);
|
||||
}
|
||||
|
||||
subscribeChunksUnsafe(taskId: string, cb: (chunk: InferenceTaskChunk) => void): () => void {
|
||||
this.events.on(`chunk:${taskId}`, cb);
|
||||
return (): void => {
|
||||
this.events.off(`chunk:${taskId}`, cb);
|
||||
};
|
||||
}
|
||||
|
||||
async complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null> {
|
||||
const updated = await this.repo.markCompleted(taskId, responseBody, this.clock());
|
||||
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||
return updated;
|
||||
}
|
||||
|
||||
async fail(taskId: string, errorMessage: string): Promise<InferenceTask | null> {
|
||||
const updated = await this.repo.markError(taskId, errorMessage, this.clock());
|
||||
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||
return updated;
|
||||
}
|
||||
|
||||
async cancel(taskId: string): Promise<InferenceTask | null> {
|
||||
const updated = await this.repo.markCancelled(taskId, this.clock());
|
||||
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||
return updated;
|
||||
}
|
||||
|
||||
async revertHeldBy(claimedBy: string): Promise<InferenceTask[]> {
|
||||
const held = await this.repo.findHeldBy(claimedBy);
|
||||
const reverted: InferenceTask[] = [];
|
||||
for (const t of held) {
|
||||
const r = await this.repo.revertToPending(t.id);
|
||||
if (r !== null) reverted.push(r);
|
||||
}
|
||||
return reverted;
|
||||
}
|
||||
|
||||
async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
|
||||
return this.repo.findPendingForPools(poolNames, limit);
|
||||
}
|
||||
|
||||
async findById(id: string): Promise<InferenceTask | null> {
|
||||
return this.repo.findById(id);
|
||||
}
|
||||
|
||||
async list(filter: {
|
||||
ownerId?: string;
|
||||
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||
poolName?: string;
|
||||
agentId?: string;
|
||||
limit?: number;
|
||||
}): Promise<InferenceTask[]> {
|
||||
return this.repo.list(filter);
|
||||
}
|
||||
|
||||
async gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number } = {
|
||||
pendingTimeoutMs: DEFAULT_PENDING_TIMEOUT_MS,
|
||||
terminalRetentionMs: DEFAULT_TERMINAL_RETENTION_MS,
|
||||
}): Promise<{ erroredOut: number; deleted: number }> {
|
||||
const now = this.clock();
|
||||
const pendingCutoff = new Date(now.getTime() - opts.pendingTimeoutMs);
|
||||
const terminalCutoff = new Date(now.getTime() - opts.terminalRetentionMs);
|
||||
|
||||
// Pass 1: pending tasks that have aged out → flip to error so the
|
||||
// caller (if still waiting) gets a clean failure.
|
||||
const stale = await this.repo.findStalePending(pendingCutoff);
|
||||
let erroredOut = 0;
|
||||
for (const t of stale) {
|
||||
const r = await this.repo.markError(t.id, `Task expired in pending state after ${String(opts.pendingTimeoutMs / 1000)}s`, now);
|
||||
if (r !== null) {
|
||||
this.events.emit(`terminal:${t.id}`);
|
||||
erroredOut += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: terminal tasks past retention → delete.
|
||||
const expired = await this.repo.findExpiredTerminal(terminalCutoff);
|
||||
const deleted = await this.repo.deleteMany(expired.map((t) => t.id));
|
||||
|
||||
return { erroredOut, deleted };
|
||||
}
|
||||
}
|
||||
|
||||
function isTerminal(status: InferenceTaskStatus): boolean {
|
||||
return status === 'completed' || status === 'error' || status === 'cancelled';
|
||||
}
|
||||
|
||||
/**
|
||||
* Surface a non-success terminal state as a thrown error so the HTTP
|
||||
* handler propagates it cleanly. `error` rows carry an `errorMessage`;
|
||||
* `cancelled` rows don't, so we synthesize one.
|
||||
*/
|
||||
function ensureSuccess(task: InferenceTask): InferenceTask {
|
||||
if (task.status === 'completed') return task;
|
||||
if (task.status === 'cancelled') {
|
||||
throw new Error(`InferenceTask cancelled: ${task.id}`);
|
||||
}
|
||||
if (task.status === 'error') {
|
||||
throw new Error(task.errorMessage ?? `InferenceTask failed: ${task.id}`);
|
||||
}
|
||||
// Theoretically unreachable — waitFor only resolves on terminal events.
|
||||
throw new Error(`InferenceTask in unexpected state: ${String(task.status)}`);
|
||||
}
|
||||
@@ -29,6 +29,8 @@ import type { ILlmRepository } from '../repositories/llm.repository.js';
|
||||
import type { OpenAiChatRequest } from './llm/types.js';
|
||||
import { NotFoundError } from './mcp-server.service.js';
|
||||
import type { AgentService } from './agent.service.js';
|
||||
import type { IInferenceTaskService, InferenceTaskChunk } from './inference-task.service.js';
|
||||
import { effectivePoolName } from './llm.service.js';
|
||||
|
||||
/** A virtual provider's announcement at registration time. */
|
||||
export interface RegisterProviderInput {
|
||||
@@ -81,29 +83,53 @@ export type VirtualTaskFrame =
|
||||
| { kind: 'wake'; taskId: string; llmName: string };
|
||||
|
||||
/**
|
||||
* Pending inference task. The route handler awaits `done`; the result POST
|
||||
* resolves it via `completeTask()`. The error path rejects via `failTask()`.
|
||||
* In-memory wake task. Wake is publisher-control work, not inference —
|
||||
* we don't persist wake tasks to the DB because their lifetime is
|
||||
* milliseconds and missing a wake on restart just means the next infer
|
||||
* fires a fresh wake. Inference tasks live in the durable queue (v5);
|
||||
* see InferenceTaskService.
|
||||
*/
|
||||
interface PendingTask {
|
||||
interface InMemoryWakeTask {
|
||||
taskId: string;
|
||||
sessionId: string;
|
||||
llmName: string;
|
||||
streaming: boolean;
|
||||
resolveNonStreaming: (body: unknown, status: number) => void;
|
||||
rejectNonStreaming: (err: Error) => void;
|
||||
/** For streaming tasks only; null on non-streaming. */
|
||||
pushChunk: ((chunk: { data: string; done?: boolean }) => void) | null;
|
||||
resolve: (status: number) => void;
|
||||
reject: (err: Error) => void;
|
||||
}
|
||||
|
||||
const HEARTBEAT_TIMEOUT_MS = 90_000;
|
||||
const INACTIVE_RETENTION_MS = 4 * 60 * 60 * 1000; // 4 h
|
||||
/**
|
||||
* v5: how long enqueueInferTask waits for the worker's terminal POST
|
||||
* before bailing. 10 minutes is generous for thinking models that
|
||||
* spend 30-90s warming up. The TASK itself stays queued past this —
|
||||
* only the in-flight HTTP handler gives up; an async-API consumer can
|
||||
* still poll the row and pick up the result later.
|
||||
*/
|
||||
const INFER_AWAIT_TIMEOUT_MS = 10 * 60 * 1000;
|
||||
|
||||
export interface EnqueueInferOptions {
|
||||
/**
|
||||
* v5: if true, throw 503 immediately when no live SSE session exists
|
||||
* for the row's session at enqueue time, instead of queueing the
|
||||
* task and waiting for a worker to show up. Used by:
|
||||
* - the direct `/api/v1/llms/<name>/infer` route (caller asked for
|
||||
* THIS specific Llm; no pool fanout at this layer)
|
||||
* - chat.service's pool failover loop (it iterates candidates and
|
||||
* needs each candidate's failure to surface fast)
|
||||
* The async `POST /api/v1/inference-tasks` API leaves this false so
|
||||
* callers explicitly opting into the durable queue get the queue.
|
||||
* Default: false (durable, queues + waits up to INFER_AWAIT_TIMEOUT_MS).
|
||||
*/
|
||||
failFast?: boolean;
|
||||
}
|
||||
|
||||
export interface IVirtualLlmService {
|
||||
register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult>;
|
||||
heartbeat(providerSessionId: string): Promise<void>;
|
||||
bindSession(providerSessionId: string, handle: VirtualSessionHandle): void;
|
||||
unbindSession(providerSessionId: string): Promise<void>;
|
||||
enqueueInferTask(llmName: string, request: OpenAiChatRequest, streaming: boolean): Promise<PendingTaskRef>;
|
||||
enqueueInferTask(llmName: string, request: OpenAiChatRequest, streaming: boolean, options?: EnqueueInferOptions): Promise<PendingTaskRef>;
|
||||
completeTask(taskId: string, result: { status: number; body: unknown }): boolean;
|
||||
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean;
|
||||
failTask(taskId: string, error: Error): boolean;
|
||||
@@ -121,7 +147,12 @@ export interface PendingTaskRef {
|
||||
|
||||
export class VirtualLlmService implements IVirtualLlmService {
|
||||
private readonly sessions = new Map<string, VirtualSessionHandle>();
|
||||
private readonly tasksById = new Map<string, PendingTask>();
|
||||
/**
|
||||
* Wake tasks live here, in-memory. The result POST resolves them by
|
||||
* id. Inference tasks have moved to the durable queue (v5); see
|
||||
* IInferenceTaskService.
|
||||
*/
|
||||
private readonly wakeTasks = new Map<string, InMemoryWakeTask>();
|
||||
/**
|
||||
* Dedupe concurrent wake requests for the same Llm. The first request
|
||||
* starts the wake; later requests for the same name await the same
|
||||
@@ -138,6 +169,19 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
* before deleting the Llm itself (Agent.llmId is Restrict).
|
||||
*/
|
||||
private readonly agents?: AgentService,
|
||||
/**
|
||||
* v5: durable inference task queue. Optional so older test wirings
|
||||
* (and the non-virtual chat path) compile without it; when absent,
|
||||
* enqueueInferTask falls back to a clear "task queue not wired"
|
||||
* error rather than silently regressing to in-memory.
|
||||
*/
|
||||
private readonly tasks?: IInferenceTaskService,
|
||||
/**
|
||||
* v5: caller's user id, threaded into newly-created task rows for
|
||||
* RBAC + observability. Optional — older wirings that don't have a
|
||||
* request context attribute tasks to 'system'.
|
||||
*/
|
||||
private readonly resolveOwner: () => string = () => 'system',
|
||||
) {}
|
||||
|
||||
async register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult> {
|
||||
@@ -227,6 +271,12 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
// Replace any prior handle for this session — keeps "last writer wins"
|
||||
// simple. The old SSE will have been closed by the publisher anyway.
|
||||
this.sessions.set(providerSessionId, handle);
|
||||
// v5: drain queued inference tasks targeting any pool this session
|
||||
// owns. Fire-and-forget: the bind() route handler completes the SSE
|
||||
// handshake first; tasks land on the channel right after.
|
||||
if (this.tasks !== undefined) {
|
||||
void this.drainPendingForSession(providerSessionId, handle);
|
||||
}
|
||||
}
|
||||
|
||||
async unbindSession(providerSessionId: string): Promise<void> {
|
||||
@@ -243,11 +293,60 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
if (this.agents !== undefined) {
|
||||
await this.agents.markVirtualAgentsInactiveBySession(providerSessionId);
|
||||
}
|
||||
// Reject any in-flight tasks for this session — the relay can't deliver
|
||||
// a result POST anymore.
|
||||
for (const t of this.tasksById.values()) {
|
||||
// v5: revert claimed/running inference tasks back to pending so
|
||||
// another worker on the same pool can pick them up. If no other
|
||||
// worker is up, they stay queued for whenever one shows up.
|
||||
if (this.tasks !== undefined) {
|
||||
await this.tasks.revertHeldBy(providerSessionId);
|
||||
}
|
||||
// Reject any in-flight wake tasks for this session — those don't
|
||||
// benefit from re-queue since the publisher itself went away.
|
||||
for (const t of this.wakeTasks.values()) {
|
||||
if (t.sessionId === providerSessionId) {
|
||||
this.failTask(t.taskId, new Error('publisher disconnected'));
|
||||
this.wakeTasks.delete(t.taskId);
|
||||
t.reject(new Error('publisher disconnected'));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* v5: when a worker binds its SSE channel, find every pending
|
||||
* InferenceTask targeting a pool key this session owns, claim each,
|
||||
* and push the frame down the channel. If two workers in the same
|
||||
* pool race here, the repo's CAS on tryClaim ensures each task is
|
||||
* pushed to exactly one of them.
|
||||
*/
|
||||
private async drainPendingForSession(
|
||||
providerSessionId: string,
|
||||
handle: VirtualSessionHandle,
|
||||
): Promise<void> {
|
||||
if (this.tasks === undefined) return;
|
||||
const owned = await this.repo.findBySessionId(providerSessionId);
|
||||
if (owned.length === 0) return;
|
||||
const poolKeys = Array.from(new Set(owned.map((row) => effectivePoolName(row))));
|
||||
const pending = await this.tasks.findPendingForPools(poolKeys);
|
||||
for (const task of pending) {
|
||||
// Cap drain per bind to avoid pushing a giant backlog into a
|
||||
// brand-new SSE channel all at once. Hardcoded for now; a
|
||||
// per-session capacity hint is a v6 concern.
|
||||
if (!handle.alive) break;
|
||||
const claimed = await this.tasks.tryClaim(task.id, providerSessionId);
|
||||
if (claimed === null) continue; // raced; another worker got it
|
||||
try {
|
||||
handle.pushTask({
|
||||
kind: 'infer',
|
||||
taskId: claimed.id,
|
||||
llmName: claimed.llmName,
|
||||
request: claimed.requestBody as unknown as OpenAiChatRequest,
|
||||
streaming: claimed.streaming,
|
||||
});
|
||||
} catch (err) {
|
||||
// SSE write failed mid-drain — revert the claim so a
|
||||
// healthier worker can pick it up.
|
||||
await this.tasks.revertHeldBy(providerSessionId);
|
||||
// eslint-disable-next-line no-console
|
||||
console.warn(`drainPendingForSession: pushTask failed for ${claimed.id}: ${(err as Error).message}`);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,7 +355,11 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
llmName: string,
|
||||
request: OpenAiChatRequest,
|
||||
streaming: boolean,
|
||||
options: EnqueueInferOptions = {},
|
||||
): Promise<PendingTaskRef> {
|
||||
if (this.tasks === undefined) {
|
||||
throw new Error('InferenceTaskService not wired into VirtualLlmService');
|
||||
}
|
||||
const llm = await this.repo.findByName(llmName);
|
||||
if (llm === null) throw new NotFoundError(`Llm not found: ${llmName}`);
|
||||
if (llm.kind !== 'virtual' || llm.providerSessionId === null) {
|
||||
@@ -265,6 +368,13 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
{ statusCode: 500 },
|
||||
);
|
||||
}
|
||||
|
||||
// failFast callers (chat.service pool failover, direct infer route)
|
||||
// get the v1-v4 semantic: row inactive OR no live session = 503,
|
||||
// immediately. The chat dispatcher then iterates the next pool
|
||||
// candidate. Without failFast, both cases queue durably and a
|
||||
// future bindSession drains.
|
||||
if (options.failFast === true) {
|
||||
if (llm.status === 'inactive') {
|
||||
throw Object.assign(
|
||||
new Error(`Virtual Llm '${llmName}' is inactive; publisher offline`),
|
||||
@@ -278,54 +388,113 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
{ statusCode: 503 },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Wake-on-demand (v2) ──
|
||||
// Status=hibernating means the publisher told us at register time
|
||||
// (or via a later status update) that the backend is asleep. Fire a
|
||||
// wake task and wait for the publisher to confirm readiness before
|
||||
// dispatching the actual inference. Concurrent infers for the same
|
||||
// Llm share a single wake promise.
|
||||
// that the backend is asleep. Fire a wake task and wait for the
|
||||
// publisher to confirm readiness before persisting the inference
|
||||
// task. Concurrent infers for the same Llm share a single wake
|
||||
// promise. We hold the wake INSIDE this method (not after enqueue)
|
||||
// so we don't queue an inference task that we then can't dispatch.
|
||||
if (llm.status === 'hibernating') {
|
||||
const handle = this.sessions.get(llm.providerSessionId);
|
||||
if (handle === undefined || !handle.alive) {
|
||||
throw Object.assign(
|
||||
new Error(`Virtual Llm '${llmName}' has no live SSE session; cannot wake`),
|
||||
{ statusCode: 503 },
|
||||
);
|
||||
}
|
||||
await this.ensureAwake(llm.id, llm.name, llm.providerSessionId, handle);
|
||||
}
|
||||
|
||||
const taskId = randomUUID();
|
||||
const chunkSubscribers = new Set<(chunk: { data: string; done?: boolean }) => void>();
|
||||
|
||||
let resolveDone!: (v: { status: number; body: unknown }) => void;
|
||||
let rejectDone!: (err: Error) => void;
|
||||
const done = new Promise<{ status: number; body: unknown }>((resolve, reject) => {
|
||||
resolveDone = resolve;
|
||||
rejectDone = reject;
|
||||
// ── v5: persist the task BEFORE attempting dispatch ──
|
||||
// Even if no worker is up, the row stays pending and a future
|
||||
// bindSession will drain it. Caller's HTTP timeout still bounds
|
||||
// the wait, but the *task* survives.
|
||||
const created = await this.tasks.enqueue({
|
||||
poolName: effectivePoolName(llm),
|
||||
llmName,
|
||||
model: llm.model,
|
||||
tier: llm.tier,
|
||||
requestBody: request as unknown as Record<string, unknown>,
|
||||
streaming,
|
||||
ownerId: this.resolveOwner(),
|
||||
});
|
||||
|
||||
const pending: PendingTask = {
|
||||
taskId,
|
||||
sessionId: llm.providerSessionId,
|
||||
llmName,
|
||||
streaming,
|
||||
resolveNonStreaming: (body, status) => resolveDone({ status, body }),
|
||||
rejectNonStreaming: rejectDone,
|
||||
pushChunk: streaming
|
||||
? (chunk): void => { for (const cb of chunkSubscribers) cb(chunk); }
|
||||
: null,
|
||||
};
|
||||
this.tasksById.set(taskId, pending);
|
||||
|
||||
// Try to claim + dispatch immediately if a session is up. If not,
|
||||
// the row stays pending for drainPendingForSession to pick up.
|
||||
const handle = this.sessions.get(llm.providerSessionId);
|
||||
if (handle !== undefined && handle.alive) {
|
||||
const claimed = await this.tasks.tryClaim(created.id, llm.providerSessionId);
|
||||
if (claimed !== null) {
|
||||
handle.pushTask({
|
||||
kind: 'infer',
|
||||
taskId,
|
||||
taskId: claimed.id,
|
||||
llmName,
|
||||
request,
|
||||
streaming,
|
||||
});
|
||||
}
|
||||
// tryClaim can return null only if another concurrent enqueue
|
||||
// (or the GC sweep) already moved the row off pending — leave
|
||||
// it; drainPendingForSession will reconcile.
|
||||
}
|
||||
|
||||
// Wrap the task service's waitFor into the legacy PendingTaskRef
|
||||
// shape so existing callers (chat.service, llm-infer route)
|
||||
// don't need to change. The chunk callback bridges
|
||||
// InferenceTaskService.events into the on-the-fly subscriber set.
|
||||
const tasks = this.tasks;
|
||||
const taskId = created.id;
|
||||
const subscribers = new Set<(chunk: InferenceTaskChunk) => void>();
|
||||
// Single bridge listener on the service's emitter; fans out to all
|
||||
// subscribers locally so we don't pile up listeners on a hot key.
|
||||
let bridgeUnsub: (() => void) | null = null;
|
||||
const ensureBridge = (): void => {
|
||||
if (bridgeUnsub !== null) return;
|
||||
const onChunk = (chunk: InferenceTaskChunk): void => {
|
||||
for (const cb of subscribers) cb(chunk);
|
||||
};
|
||||
// Subscribe via the public pushChunk path: the service emits on
|
||||
// its own EventEmitter; we attach a listener here. Use the
|
||||
// service's events through a side channel — exposed below as
|
||||
// subscribeChunks for clarity.
|
||||
bridgeUnsub = tasks.subscribeChunksUnsafe(taskId, onChunk);
|
||||
};
|
||||
|
||||
const done = (async (): Promise<{ status: number; body: unknown }> => {
|
||||
// waitFor's `done` rejects on cancel/error/timeout. For non-
|
||||
// streaming tasks the responseBody IS the body; for streaming
|
||||
// the body is null and chunks have already been piped through.
|
||||
const waiter = tasks.waitFor(taskId, INFER_AWAIT_TIMEOUT_MS);
|
||||
// If streaming, we must drain the chunks generator concurrently
|
||||
// so chunks aren't dropped just because no one's subscribed yet.
|
||||
// The real consumer subscribes via onChunk(); their cb fires
|
||||
// synchronously inside the bridge.
|
||||
if (streaming) ensureBridge();
|
||||
const finalRow = await waiter.done;
|
||||
// Status code: legacy callers expect 200 on success; the worker
|
||||
// sends its own status via the result POST and we forward it.
|
||||
// Today, success POSTs come with status=200 and the row's
|
||||
// responseBody is { status, body }. v5 stores the *body* directly
|
||||
// on the row; we synthesize a 200 here to match the legacy shape.
|
||||
return { status: 200, body: finalRow.responseBody };
|
||||
})();
|
||||
|
||||
return {
|
||||
taskId,
|
||||
done,
|
||||
onChunk(cb): () => void {
|
||||
chunkSubscribers.add(cb);
|
||||
return () => chunkSubscribers.delete(cb);
|
||||
subscribers.add(cb);
|
||||
ensureBridge();
|
||||
return (): void => {
|
||||
subscribers.delete(cb);
|
||||
if (subscribers.size === 0 && bridgeUnsub !== null) {
|
||||
bridgeUnsub();
|
||||
bridgeUnsub = null;
|
||||
}
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -370,22 +539,17 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
rejectDone = reject;
|
||||
});
|
||||
|
||||
const pending: PendingTask = {
|
||||
const wake: InMemoryWakeTask = {
|
||||
taskId,
|
||||
sessionId,
|
||||
llmName,
|
||||
streaming: false,
|
||||
// Wake tasks return { ok: true } on success or never resolve at
|
||||
// all if the publisher dies; the rejectNonStreaming path covers
|
||||
// the disconnect-mid-wake case via unbindSession.
|
||||
resolveNonStreaming: (_body, status) => {
|
||||
resolve: (status) => {
|
||||
if (status >= 200 && status < 300) resolveDone();
|
||||
else rejectDone(new Error(`wake task returned status ${String(status)}`));
|
||||
},
|
||||
rejectNonStreaming: rejectDone,
|
||||
pushChunk: null,
|
||||
reject: rejectDone,
|
||||
};
|
||||
this.tasksById.set(taskId, pending);
|
||||
this.wakeTasks.set(taskId, wake);
|
||||
|
||||
handle.pushTask({ kind: 'wake', taskId, llmName });
|
||||
|
||||
@@ -402,33 +566,56 @@ export class VirtualLlmService implements IVirtualLlmService {
|
||||
}
|
||||
|
||||
completeTask(taskId: string, result: { status: number; body: unknown }): boolean {
|
||||
const t = this.tasksById.get(taskId);
|
||||
if (t === undefined) return false;
|
||||
this.tasksById.delete(taskId);
|
||||
t.resolveNonStreaming(result.body, result.status);
|
||||
// Wake tasks: in-memory map. Resolve and bail.
|
||||
const wake = this.wakeTasks.get(taskId);
|
||||
if (wake !== undefined) {
|
||||
this.wakeTasks.delete(taskId);
|
||||
wake.resolve(result.status);
|
||||
return true;
|
||||
}
|
||||
// Inference tasks: durable queue. Persist body + flip terminal +
|
||||
// emit wakeup. Fire-and-forget the DB write; the result POST
|
||||
// returns 200 either way (the caller's HTTP handler is what cares
|
||||
// about the eventual emit, not the worker).
|
||||
if (this.tasks !== undefined) {
|
||||
void this.tasks.complete(taskId, result.body as Record<string, unknown> | null);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean {
|
||||
const t = this.tasksById.get(taskId);
|
||||
if (t === undefined || t.pushChunk === null) return false;
|
||||
t.pushChunk(chunk);
|
||||
// Wake tasks never receive chunks — they're non-streaming control
|
||||
// messages. Don't even check the wake map; if the id isn't an
|
||||
// inference task, just drop the chunk.
|
||||
if (this.tasks === undefined) return false;
|
||||
// First chunk for a claimed task → flip claimed → running. Idempotent
|
||||
// and fire-and-forget; if the row is already running, the CAS in the
|
||||
// repo just no-ops.
|
||||
void this.tasks.markRunning(taskId);
|
||||
this.tasks.pushChunk(taskId, chunk);
|
||||
if (chunk.done === true) {
|
||||
// For streaming tasks, also resolve the `done` promise so the route
|
||||
// handler can clean up.
|
||||
t.resolveNonStreaming(null, 200);
|
||||
this.tasksById.delete(taskId);
|
||||
// Streaming completion: persist with null body + flip terminal so
|
||||
// the waiter unblocks. The actual content was already streamed
|
||||
// through the chunks channel; nothing to store.
|
||||
void this.tasks.complete(taskId, null);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
failTask(taskId: string, error: Error): boolean {
|
||||
const t = this.tasksById.get(taskId);
|
||||
if (t === undefined) return false;
|
||||
this.tasksById.delete(taskId);
|
||||
t.rejectNonStreaming(error);
|
||||
const wake = this.wakeTasks.get(taskId);
|
||||
if (wake !== undefined) {
|
||||
this.wakeTasks.delete(taskId);
|
||||
wake.reject(error);
|
||||
return true;
|
||||
}
|
||||
if (this.tasks !== undefined) {
|
||||
void this.tasks.fail(taskId, error.message);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
async gcSweep(now: Date = new Date()): Promise<{ markedInactive: number; deleted: number }> {
|
||||
let markedInactive = 0;
|
||||
|
||||
@@ -193,6 +193,10 @@ describe('ChatService — kind=virtual branch (v3 Stage 1)', () => {
|
||||
'vllm-local',
|
||||
expect.objectContaining({ messages: expect.any(Array) }),
|
||||
false,
|
||||
// v5: chat.service passes failFast:true so its pool failover loop
|
||||
// surfaces transport errors quickly instead of waiting on the
|
||||
// durable queue's 10-min timeout.
|
||||
{ failFast: true },
|
||||
);
|
||||
});
|
||||
|
||||
@@ -224,6 +228,7 @@ describe('ChatService — kind=virtual branch (v3 Stage 1)', () => {
|
||||
'vllm-local',
|
||||
expect.objectContaining({ messages: expect.any(Array), stream: true }),
|
||||
true,
|
||||
{ failFast: true },
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
330
src/mcpd/tests/inference-task-service.test.ts
Normal file
330
src/mcpd/tests/inference-task-service.test.ts
Normal file
@@ -0,0 +1,330 @@
|
||||
/**
|
||||
* v5 InferenceTaskService — state machine, signal channels, and GC sweep.
|
||||
* The repo is mocked with an in-memory Map so we exercise the service
|
||||
* logic deterministically without touching Postgres. Schema-level tests
|
||||
* live in src/db/tests/inference-task-schema.test.ts.
|
||||
*/
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||
import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
|
||||
import { InferenceTaskService } from '../src/services/inference-task.service.js';
|
||||
|
||||
function makeRow(overrides: Partial<InferenceTask> = {}): InferenceTask {
|
||||
return {
|
||||
id: overrides.id ?? `task-${Math.random().toString(36).slice(2, 8)}`,
|
||||
status: 'pending',
|
||||
poolName: 'pool-a',
|
||||
llmName: 'pool-a',
|
||||
model: 'qwen3-thinking',
|
||||
tier: null,
|
||||
claimedBy: null,
|
||||
requestBody: { messages: [{ role: 'user', content: 'hi' }] } as unknown as InferenceTask['requestBody'],
|
||||
responseBody: null,
|
||||
errorMessage: null,
|
||||
streaming: false,
|
||||
createdAt: new Date(),
|
||||
claimedAt: null,
|
||||
streamStartedAt: null,
|
||||
completedAt: null,
|
||||
ownerId: 'owner-1',
|
||||
agentId: null,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function mockRepo(initial: InferenceTask[] = []): IInferenceTaskRepository {
|
||||
const rows = new Map<string, InferenceTask>(initial.map((r) => [r.id, { ...r }]));
|
||||
// Tiny CAS helper that mirrors the real `updateMany({where:{status:in}})`
|
||||
// semantics — only flips the row if the current status matches.
|
||||
const cas = (id: string, allowed: InferenceTaskStatus[], patch: Partial<InferenceTask>): InferenceTask | null => {
|
||||
const row = rows.get(id);
|
||||
if (row === undefined) return null;
|
||||
if (!allowed.includes(row.status)) return null;
|
||||
const next = { ...row, ...patch };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
};
|
||||
return {
|
||||
create: vi.fn(async (data) => {
|
||||
const row = makeRow({
|
||||
id: `task-${rows.size + 1}`,
|
||||
poolName: data.poolName,
|
||||
llmName: data.llmName,
|
||||
model: data.model,
|
||||
tier: data.tier ?? null,
|
||||
requestBody: data.requestBody as InferenceTask['requestBody'],
|
||||
streaming: data.streaming,
|
||||
ownerId: data.ownerId,
|
||||
agentId: data.agentId ?? null,
|
||||
});
|
||||
rows.set(row.id, row);
|
||||
return row;
|
||||
}),
|
||||
findById: vi.fn(async (id) => rows.get(id) ?? null),
|
||||
findPendingForPools: vi.fn(async (poolNames, limit) => {
|
||||
const out = [...rows.values()]
|
||||
.filter((r) => r.status === 'pending' && poolNames.includes(r.poolName))
|
||||
.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
|
||||
return limit !== undefined ? out.slice(0, limit) : out;
|
||||
}),
|
||||
findHeldBy: vi.fn(async (claimedBy) =>
|
||||
[...rows.values()].filter((r) => r.claimedBy === claimedBy && (r.status === 'claimed' || r.status === 'running')),
|
||||
),
|
||||
list: vi.fn(async (filter) => {
|
||||
let out = [...rows.values()];
|
||||
if (filter.ownerId !== undefined) out = out.filter((r) => r.ownerId === filter.ownerId);
|
||||
if (filter.poolName !== undefined) out = out.filter((r) => r.poolName === filter.poolName);
|
||||
if (filter.agentId !== undefined) out = out.filter((r) => r.agentId === filter.agentId);
|
||||
if (filter.status !== undefined) {
|
||||
const statuses = Array.isArray(filter.status) ? filter.status : [filter.status];
|
||||
out = out.filter((r) => statuses.includes(r.status));
|
||||
}
|
||||
out.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
|
||||
return filter.limit !== undefined ? out.slice(0, filter.limit) : out;
|
||||
}),
|
||||
tryClaim: vi.fn(async (id, claimedBy, claimedAt) =>
|
||||
cas(id, ['pending'], { status: 'claimed', claimedBy, claimedAt }),
|
||||
),
|
||||
markRunning: vi.fn(async (id, at) =>
|
||||
cas(id, ['claimed', 'running'], { status: 'running', streamStartedAt: at }),
|
||||
),
|
||||
markCompleted: vi.fn(async (id, body, at) =>
|
||||
cas(id, ['pending', 'claimed', 'running'], {
|
||||
status: 'completed',
|
||||
responseBody: (body ?? null) as InferenceTask['responseBody'],
|
||||
completedAt: at,
|
||||
}),
|
||||
),
|
||||
markError: vi.fn(async (id, errorMessage, at) =>
|
||||
cas(id, ['pending', 'claimed', 'running'], { status: 'error', errorMessage, completedAt: at }),
|
||||
),
|
||||
markCancelled: vi.fn(async (id, at) =>
|
||||
cas(id, ['pending', 'claimed', 'running'], { status: 'cancelled', completedAt: at }),
|
||||
),
|
||||
revertToPending: vi.fn(async (id) =>
|
||||
cas(id, ['claimed', 'running'], { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null }),
|
||||
),
|
||||
findStalePending: vi.fn(async (cutoff) =>
|
||||
[...rows.values()].filter((r) => r.status === 'pending' && r.createdAt.getTime() < cutoff.getTime()),
|
||||
),
|
||||
findExpiredTerminal: vi.fn(async (cutoff) =>
|
||||
[...rows.values()].filter((r) =>
|
||||
(r.status === 'completed' || r.status === 'error' || r.status === 'cancelled')
|
||||
&& r.completedAt !== null
|
||||
&& r.completedAt.getTime() < cutoff.getTime(),
|
||||
),
|
||||
),
|
||||
deleteMany: vi.fn(async (ids) => {
|
||||
let n = 0;
|
||||
for (const id of ids) if (rows.delete(id)) n += 1;
|
||||
return n;
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
describe('InferenceTaskService — state machine', () => {
|
||||
it('enqueue creates a pending row with the given pool/llm/model', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({
|
||||
poolName: 'pool-a',
|
||||
llmName: 'a-1',
|
||||
model: 'qwen3',
|
||||
requestBody: { messages: [] },
|
||||
streaming: false,
|
||||
ownerId: 'owner-1',
|
||||
});
|
||||
expect(t.status).toBe('pending');
|
||||
expect(t.poolName).toBe('pool-a');
|
||||
expect(t.claimedBy).toBeNull();
|
||||
});
|
||||
|
||||
it('tryClaim races: only one of two concurrent claimers gets the row', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({
|
||||
poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o',
|
||||
});
|
||||
// Both workers issue tryClaim at the "same time". The repo's CAS
|
||||
// serializes them — first claim wins, second sees a non-pending
|
||||
// row and returns null.
|
||||
const [a, b] = await Promise.all([svc.tryClaim(t.id, 'sess-A'), svc.tryClaim(t.id, 'sess-B')]);
|
||||
const winners = [a, b].filter((r) => r !== null);
|
||||
const losers = [a, b].filter((r) => r === null);
|
||||
expect(winners).toHaveLength(1);
|
||||
expect(losers).toHaveLength(1);
|
||||
expect(winners[0]!.status).toBe('claimed');
|
||||
expect(['sess-A', 'sess-B']).toContain(winners[0]!.claimedBy);
|
||||
});
|
||||
|
||||
it('complete after claim transitions claimed → completed and stores responseBody', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
await svc.tryClaim(t.id, 'sess-A');
|
||||
const done = await svc.complete(t.id, { choices: [{ message: { content: 'hi' } }] });
|
||||
expect(done?.status).toBe('completed');
|
||||
expect(done?.responseBody).toEqual({ choices: [{ message: { content: 'hi' } }] });
|
||||
expect(done?.completedAt).not.toBeNull();
|
||||
});
|
||||
|
||||
it('refuses double-complete (idempotent terminal)', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const first = await svc.complete(t.id, { ok: 1 });
|
||||
expect(first?.status).toBe('completed');
|
||||
// Second worker tries to complete the same task — CAS rejects because
|
||||
// the row is no longer in a non-terminal state.
|
||||
const second = await svc.complete(t.id, { ok: 2 });
|
||||
expect(second).toBeNull();
|
||||
// First completion's body is preserved.
|
||||
const reread = await svc.findById(t.id);
|
||||
expect(reread?.responseBody).toEqual({ ok: 1 });
|
||||
});
|
||||
|
||||
it('revertHeldBy reverts every claimed/running row owned by a session and leaves terminals alone', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t1 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const t2 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const t3 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
await svc.tryClaim(t1.id, 'sess-A');
|
||||
await svc.tryClaim(t2.id, 'sess-A');
|
||||
await svc.markRunning(t2.id);
|
||||
await svc.tryClaim(t3.id, 'sess-A');
|
||||
await svc.complete(t3.id, { ok: 1 }); // t3 finished before disconnect
|
||||
|
||||
const reverted = await svc.revertHeldBy('sess-A');
|
||||
expect(reverted.map((r) => r.id).sort()).toEqual([t1.id, t2.id].sort());
|
||||
expect((await svc.findById(t1.id))?.status).toBe('pending');
|
||||
expect((await svc.findById(t2.id))?.status).toBe('pending');
|
||||
// t3 stayed completed — terminal rows are not reverted on disconnect.
|
||||
expect((await svc.findById(t3.id))?.status).toBe('completed');
|
||||
});
|
||||
|
||||
it('cancel from pending records cancelled and emits terminal event', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const cancelled = await svc.cancel(t.id);
|
||||
expect(cancelled?.status).toBe('cancelled');
|
||||
// A subsequent complete must fail (cancelled is terminal).
|
||||
const result = await svc.complete(t.id, { ok: 1 });
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('InferenceTaskService — waitFor signals', () => {
|
||||
it('resolves immediately when the row is already terminal at subscribe time', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
await svc.complete(t.id, { ok: 1 });
|
||||
const waiter = svc.waitFor(t.id, 1_000);
|
||||
const final = await waiter.done;
|
||||
expect(final.status).toBe('completed');
|
||||
expect(final.responseBody).toEqual({ ok: 1 });
|
||||
});
|
||||
|
||||
it('wakes a waiter on complete event without polling the DB', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const waiter = svc.waitFor(t.id, 5_000);
|
||||
// Fire the complete after a microtask so the waiter is definitely
|
||||
// already subscribed to the terminal event.
|
||||
setTimeout(() => { void svc.complete(t.id, { ok: 1 }); }, 10);
|
||||
const final = await waiter.done;
|
||||
expect(final.status).toBe('completed');
|
||||
});
|
||||
|
||||
it('wakes the chunks generator on pushChunk and ends on terminal', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: true, ownerId: 'o' });
|
||||
const waiter = svc.waitFor(t.id, 5_000);
|
||||
|
||||
setTimeout(() => {
|
||||
svc.pushChunk(t.id, { data: 'hello ' });
|
||||
svc.pushChunk(t.id, { data: 'world' });
|
||||
void svc.complete(t.id, { ok: 1 });
|
||||
}, 10);
|
||||
|
||||
const seen: string[] = [];
|
||||
for await (const c of waiter.chunks) {
|
||||
seen.push(c.data);
|
||||
}
|
||||
expect(seen).toEqual(['hello ', 'world']);
|
||||
const final = await waiter.done;
|
||||
expect(final.status).toBe('completed');
|
||||
});
|
||||
|
||||
it('throws on cancellation with a clear error message', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const waiter = svc.waitFor(t.id, 5_000);
|
||||
setTimeout(() => { void svc.cancel(t.id); }, 10);
|
||||
await expect(waiter.done).rejects.toThrow(/cancelled/i);
|
||||
});
|
||||
|
||||
it('throws on error and surfaces the worker\'s errorMessage', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const waiter = svc.waitFor(t.id, 5_000);
|
||||
setTimeout(() => { void svc.fail(t.id, 'upstream 500'); }, 10);
|
||||
await expect(waiter.done).rejects.toThrow(/upstream 500/);
|
||||
});
|
||||
|
||||
it('times out when no terminal event arrives within the deadline', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo);
|
||||
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const waiter = svc.waitFor(t.id, 30);
|
||||
await expect(waiter.done).rejects.toThrow(/timed out/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('InferenceTaskService — gcSweep', () => {
|
||||
it('flips stale pending rows to error AND deletes expired terminal rows', async () => {
|
||||
const now = new Date('2026-04-28T00:00:00Z');
|
||||
const fixedClock = (): Date => now;
|
||||
const repo = mockRepo();
|
||||
const svc = new InferenceTaskService(repo, fixedClock);
|
||||
|
||||
// 90 min old pending — past the 1h pendingTimeout cutoff → error.
|
||||
const stale = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
// Backdate via direct fixture mutation — easier than wiring a clock through enqueue.
|
||||
const staleRow = (await svc.findById(stale.id))!;
|
||||
(staleRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 90 * 60 * 1000);
|
||||
|
||||
// 30 min old pending — within window, should not be touched.
|
||||
const fresh = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
const freshRow = (await svc.findById(fresh.id))!;
|
||||
(freshRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 30 * 60 * 1000);
|
||||
|
||||
// 8 day-old completed — past the 7d retention → delete.
|
||||
const old = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||
await svc.complete(old.id, { ok: 1 });
|
||||
const oldRow = (await svc.findById(old.id))!;
|
||||
(oldRow as { completedAt: Date }).completedAt = new Date(now.getTime() - 8 * 24 * 60 * 60 * 1000);
|
||||
|
||||
const result = await svc.gcSweep({
|
||||
pendingTimeoutMs: 60 * 60 * 1000, // 1h
|
||||
terminalRetentionMs: 7 * 24 * 60 * 60 * 1000, // 7d
|
||||
});
|
||||
expect(result.erroredOut).toBe(1);
|
||||
expect(result.deleted).toBe(1);
|
||||
|
||||
// Stale pending was flipped to error.
|
||||
const staleAfter = await svc.findById(stale.id);
|
||||
expect(staleAfter?.status).toBe('error');
|
||||
expect(staleAfter?.errorMessage).toMatch(/expired in pending/);
|
||||
// Old completed is gone.
|
||||
expect(await svc.findById(old.id)).toBeNull();
|
||||
// Fresh pending untouched.
|
||||
expect((await svc.findById(fresh.id))?.status).toBe('pending');
|
||||
});
|
||||
});
|
||||
@@ -244,6 +244,9 @@ describe('POST /api/v1/llms/:name/infer', () => {
|
||||
'claude',
|
||||
expect.objectContaining({ messages: expect.any(Array) }),
|
||||
false,
|
||||
// v5: direct infer route passes failFast:true so a downed publisher
|
||||
// returns 503 immediately instead of queueing the task durably.
|
||||
{ failFast: true },
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { VirtualLlmService, type VirtualSessionHandle } from '../src/services/virtual-llm.service.js';
|
||||
import { InferenceTaskService } from '../src/services/inference-task.service.js';
|
||||
import type { IInferenceTaskService } from '../src/services/inference-task.service.js';
|
||||
import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
|
||||
import type { ILlmRepository } from '../src/repositories/llm.repository.js';
|
||||
import type { Llm } from '@prisma/client';
|
||||
import type { Llm, InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||
|
||||
function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
||||
return {
|
||||
@@ -15,6 +18,7 @@ function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
||||
apiKeySecretId: null,
|
||||
apiKeySecretKey: null,
|
||||
extraConfig: {} as Llm['extraConfig'],
|
||||
poolName: null,
|
||||
kind: 'virtual',
|
||||
providerSessionId: 's-1',
|
||||
lastHeartbeatAt: new Date(),
|
||||
@@ -27,6 +31,105 @@ function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Drop-in mock of `InferenceTaskService` backed by a Map. We mirror just
|
||||
* enough of the real service's signaling — events for terminal/chunk —
|
||||
* so VirtualLlmService's enqueue/result flows behave the same way they
|
||||
* do in production. Nothing here talks to Postgres.
|
||||
*/
|
||||
function mockTaskService(): IInferenceTaskService {
|
||||
// Build a minimal repo for the real InferenceTaskService — that way we
|
||||
// exercise the actual event-emitter wakeup logic, just without a DB.
|
||||
const rows = new Map<string, InferenceTask>();
|
||||
let n = 0;
|
||||
const repo: IInferenceTaskRepository = {
|
||||
create: vi.fn(async (data) => {
|
||||
n += 1;
|
||||
const row: InferenceTask = {
|
||||
id: `task-${String(n)}`,
|
||||
status: 'pending',
|
||||
poolName: data.poolName,
|
||||
llmName: data.llmName,
|
||||
model: data.model,
|
||||
tier: data.tier ?? null,
|
||||
claimedBy: null,
|
||||
requestBody: data.requestBody as InferenceTask['requestBody'],
|
||||
responseBody: null,
|
||||
errorMessage: null,
|
||||
streaming: data.streaming,
|
||||
createdAt: new Date(),
|
||||
claimedAt: null,
|
||||
streamStartedAt: null,
|
||||
completedAt: null,
|
||||
ownerId: data.ownerId,
|
||||
agentId: data.agentId ?? null,
|
||||
};
|
||||
rows.set(row.id, row);
|
||||
return row;
|
||||
}),
|
||||
findById: vi.fn(async (id) => rows.get(id) ?? null),
|
||||
findPendingForPools: vi.fn(async (poolNames: string[]) =>
|
||||
[...rows.values()].filter((r) => r.status === 'pending' && poolNames.includes(r.poolName)),
|
||||
),
|
||||
findHeldBy: vi.fn(async (claimedBy: string) =>
|
||||
[...rows.values()].filter((r) =>
|
||||
r.claimedBy === claimedBy
|
||||
&& (r.status === 'claimed' || r.status === 'running')),
|
||||
),
|
||||
list: vi.fn(async () => [...rows.values()]),
|
||||
tryClaim: vi.fn(async (id, claimedBy, claimedAt) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || r.status !== 'pending') return null;
|
||||
const next = { ...r, status: 'claimed' as InferenceTaskStatus, claimedBy, claimedAt };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
markRunning: vi.fn(async (id, at) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || (r.status !== 'claimed' && r.status !== 'running')) return null;
|
||||
const next = { ...r, status: 'running' as InferenceTaskStatus, streamStartedAt: at };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
markCompleted: vi.fn(async (id, body, at) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||
const next = { ...r, status: 'completed' as InferenceTaskStatus, responseBody: (body ?? null) as InferenceTask['responseBody'], completedAt: at };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
markError: vi.fn(async (id, errorMessage, at) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||
const next = { ...r, status: 'error' as InferenceTaskStatus, errorMessage, completedAt: at };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
markCancelled: vi.fn(async (id, at) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||
const next = { ...r, status: 'cancelled' as InferenceTaskStatus, completedAt: at };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
revertToPending: vi.fn(async (id) => {
|
||||
const r = rows.get(id);
|
||||
if (r === undefined || (r.status !== 'claimed' && r.status !== 'running')) return null;
|
||||
const next = { ...r, status: 'pending' as InferenceTaskStatus, claimedBy: null, claimedAt: null, streamStartedAt: null };
|
||||
rows.set(id, next);
|
||||
return next;
|
||||
}),
|
||||
findStalePending: vi.fn(async () => []),
|
||||
findExpiredTerminal: vi.fn(async () => []),
|
||||
deleteMany: vi.fn(async (ids) => {
|
||||
let c = 0;
|
||||
for (const id of ids) if (rows.delete(id)) c += 1;
|
||||
return c;
|
||||
}),
|
||||
};
|
||||
return new InferenceTaskService(repo);
|
||||
}
|
||||
|
||||
function mockRepo(initial: Llm[] = []): ILlmRepository {
|
||||
const rows = new Map<string, Llm>(initial.map((l) => [l.id, l]));
|
||||
let counter = rows.size;
|
||||
@@ -38,6 +141,14 @@ function mockRepo(initial: Llm[] = []): ILlmRepository {
|
||||
return null;
|
||||
}),
|
||||
findByTier: vi.fn(async () => []),
|
||||
findByPoolName: vi.fn(async (poolName: string) => {
|
||||
const out: Llm[] = [];
|
||||
for (const l of rows.values()) {
|
||||
if (l.poolName === poolName) out.push(l);
|
||||
else if (l.poolName === null && l.name === poolName) out.push(l);
|
||||
}
|
||||
return out;
|
||||
}),
|
||||
findBySessionId: vi.fn(async (sid: string) =>
|
||||
[...rows.values()].filter((l) => l.providerSessionId === sid)),
|
||||
findStaleVirtuals: vi.fn(async (cutoff: Date) =>
|
||||
@@ -105,7 +216,7 @@ function fakeSession(): VirtualSessionHandle & { tasks: Array<unknown>; alive: b
|
||||
describe('VirtualLlmService', () => {
|
||||
it('register inserts new virtual rows with active status + sessionId', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const { providerSessionId, llms } = await svc.register({
|
||||
providerSessionId: null,
|
||||
providers: [
|
||||
@@ -122,7 +233,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('register reuses the same row on sticky reconnect (same name + sessionId)', async () => {
|
||||
const repo = mockRepo();
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const first = await svc.register({
|
||||
providerSessionId: 'fixed-id',
|
||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||
@@ -140,7 +251,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('register refuses to overwrite a public LLM with the same name', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
await expect(svc.register({
|
||||
providerSessionId: 'sess-x',
|
||||
providers: [{ name: 'qwen3-thinking', type: 'openai', model: 'm' }],
|
||||
@@ -149,7 +260,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('register refuses if another active session owns the name', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'other', status: 'active' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
await expect(svc.register({
|
||||
providerSessionId: 'mine',
|
||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||
@@ -161,7 +272,7 @@ describe('VirtualLlmService', () => {
|
||||
name: 'vllm-local', providerSessionId: 'old-session',
|
||||
status: 'inactive', inactiveSince: new Date(),
|
||||
})]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const { llms } = await svc.register({
|
||||
providerSessionId: 'new-session',
|
||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||
@@ -177,7 +288,7 @@ describe('VirtualLlmService', () => {
|
||||
name: 'vllm-local', providerSessionId: 'sess', status: 'inactive',
|
||||
lastHeartbeatAt: past, inactiveSince: past,
|
||||
})]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
await svc.heartbeat('sess');
|
||||
const row = await repo.findByName('vllm-local');
|
||||
expect(row?.status).toBe('active');
|
||||
@@ -191,7 +302,7 @@ describe('VirtualLlmService', () => {
|
||||
makeLlm({ name: 'b', providerSessionId: 'sess' }),
|
||||
makeLlm({ name: 'c', providerSessionId: 'other' }),
|
||||
]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
svc.bindSession('sess', fakeSession());
|
||||
await svc.unbindSession('sess');
|
||||
expect((await repo.findByName('a'))?.status).toBe('inactive');
|
||||
@@ -201,7 +312,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('enqueueInferTask pushes a task frame to the SSE session', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const session = fakeSession();
|
||||
svc.bindSession('sess', session);
|
||||
|
||||
@@ -218,26 +329,41 @@ describe('VirtualLlmService', () => {
|
||||
expect(t.streaming).toBe(false);
|
||||
});
|
||||
|
||||
it('enqueueInferTask rejects when the publisher is offline (no session bound)', async () => {
|
||||
it('enqueueInferTask queues the task when no session is bound (durable, drains on bind)', async () => {
|
||||
// v5 semantic change: with a durable queue underneath, "no worker
|
||||
// up" no longer rejects — the row stays pending and a future
|
||||
// bindSession drains it. Caller's HTTP handler awaits on ref.done
|
||||
// and bounds itself with INFER_AWAIT_TIMEOUT_MS; from the service's
|
||||
// POV the enqueue itself succeeds.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
await expect(
|
||||
svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false),
|
||||
).rejects.toThrow(/no live SSE session|publisher offline/);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||
expect(ref.taskId).toMatch(/^task-/);
|
||||
// The row exists and is still pending — no claim happened.
|
||||
const row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('pending');
|
||||
expect(row?.claimedBy).toBeNull();
|
||||
});
|
||||
|
||||
it('enqueueInferTask rejects when the row is inactive', async () => {
|
||||
it('enqueueInferTask still queues against an inactive row (pool may have a sibling worker)', async () => {
|
||||
// v5 semantic change: status=inactive on a specific Llm doesn't
|
||||
// mean the pool is dead — another mcplocal publishing the same
|
||||
// poolName might be active. The dispatcher's bindSession drain
|
||||
// matches by poolName, so even a "dead" pin queues correctly.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
svc.bindSession('sess', fakeSession());
|
||||
await expect(
|
||||
svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false),
|
||||
).rejects.toThrow(/inactive|publisher offline/);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||
const row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('pending');
|
||||
// No frame pushed because no session is bound.
|
||||
expect(row?.claimedBy).toBeNull();
|
||||
});
|
||||
|
||||
it('enqueueInferTask rejects when the LLM is public (not virtual)', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
await expect(
|
||||
svc.enqueueInferTask('qwen3-thinking', { model: 'm', messages: [] }, false),
|
||||
).rejects.toThrow(/not a virtual provider/);
|
||||
@@ -245,7 +371,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('completeTask resolves the pending non-streaming promise', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const ref = await svc.enqueueInferTask(
|
||||
'vllm-local',
|
||||
@@ -258,7 +384,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('streaming: pushTaskChunk fans chunks to subscribers; done resolves the ref', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const ref = await svc.enqueueInferTask(
|
||||
'vllm-local',
|
||||
@@ -278,7 +404,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('failTask rejects the pending promise with a clear error', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const ref = await svc.enqueueInferTask(
|
||||
'vllm-local',
|
||||
@@ -289,17 +415,34 @@ describe('VirtualLlmService', () => {
|
||||
await expect(ref.done).rejects.toThrow(/upstream blew up/);
|
||||
});
|
||||
|
||||
it('unbindSession rejects in-flight tasks for that session', async () => {
|
||||
it('unbindSession reverts claimed inference tasks to pending (durable re-queue, not reject)', async () => {
|
||||
// v5 semantic change: a worker disconnecting mid-task no longer
|
||||
// *rejects* the task. The row goes back to pending so another
|
||||
// worker on the same pool can pick it up. The original caller's
|
||||
// ref.done keeps waiting up to its 10-min INFER_AWAIT_TIMEOUT_MS;
|
||||
// the same caller is what gets the result whichever worker
|
||||
// ultimately delivers it.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const ref = await svc.enqueueInferTask(
|
||||
'vllm-local',
|
||||
{ model: 'm', messages: [{ role: 'user', content: 'hi' }] },
|
||||
false,
|
||||
);
|
||||
// After enqueue with a session up, the task is claimed.
|
||||
let row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('claimed');
|
||||
expect(row?.claimedBy).toBe('sess');
|
||||
|
||||
await svc.unbindSession('sess');
|
||||
await expect(ref.done).rejects.toThrow(/publisher disconnected/);
|
||||
|
||||
// After disconnect, claimed/running rows revert to pending — ready
|
||||
// for the next worker to drain.
|
||||
row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('pending');
|
||||
expect(row?.claimedBy).toBeNull();
|
||||
});
|
||||
|
||||
it('gcSweep flips heartbeat-stale active virtuals to inactive', async () => {
|
||||
@@ -309,7 +452,7 @@ describe('VirtualLlmService', () => {
|
||||
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
||||
makeLlm({ name: 'fresh', providerSessionId: 'b', status: 'active', lastHeartbeatAt: recent }),
|
||||
]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const result = await svc.gcSweep();
|
||||
expect(result.markedInactive).toBe(1);
|
||||
expect((await repo.findByName('stale'))?.status).toBe('inactive');
|
||||
@@ -324,7 +467,7 @@ describe('VirtualLlmService', () => {
|
||||
makeLlm({ name: 'recent', providerSessionId: 'b', status: 'inactive', inactiveSince: fresh }),
|
||||
makeLlm({ name: 'public-survivor', providerSessionId: null, kind: 'public' }),
|
||||
]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const result = await svc.gcSweep();
|
||||
expect(result.deleted).toBe(1);
|
||||
expect(await repo.findByName('old')).toBeNull();
|
||||
@@ -336,7 +479,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('hibernating: dispatches a wake task first and waits for it to complete before infer', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const session = fakeSession();
|
||||
svc.bindSession('sess', session);
|
||||
|
||||
@@ -370,7 +513,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('hibernating: concurrent infer requests share a single wake task', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const session = fakeSession();
|
||||
svc.bindSession('sess', session);
|
||||
|
||||
@@ -398,7 +541,7 @@ describe('VirtualLlmService', () => {
|
||||
|
||||
it('hibernating: rejects when the wake task fails', async () => {
|
||||
const repo = mockRepo([makeLlm({ name: 'broken', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
svc.bindSession('sess', fakeSession());
|
||||
|
||||
const inferPromise = svc.enqueueInferTask(
|
||||
@@ -408,12 +551,12 @@ describe('VirtualLlmService', () => {
|
||||
);
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
// Get the wake task id from the in-flight tasks map (its only entry).
|
||||
// We test the failure path via failTask which is part of the public
|
||||
// surface used by the result-POST route handler.
|
||||
// v5: wake tasks live in `wakeTasks` (in-memory). Inference tasks
|
||||
// moved to the DB-backed queue but wake is publisher-control work
|
||||
// that doesn't need durability — we kept the in-memory map for it.
|
||||
const taskIds: string[] = [];
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
for (const id of (svc as any).tasksById.keys()) taskIds.push(id);
|
||||
for (const id of (svc as any).wakeTasks.keys()) taskIds.push(id);
|
||||
expect(taskIds).toHaveLength(1);
|
||||
expect(svc.failTask(taskIds[0]!, new Error('wake recipe failed'))).toBe(true);
|
||||
|
||||
@@ -424,14 +567,28 @@ describe('VirtualLlmService', () => {
|
||||
expect(row?.status).toBe('hibernating');
|
||||
});
|
||||
|
||||
it('inactive: still rejects with 503 (publisher offline) — wake path only fires for hibernating', async () => {
|
||||
it('inactive: queues without firing the wake path — wake only triggers on status=hibernating', async () => {
|
||||
// Coverage for the v5 inactive-vs-hibernating distinction.
|
||||
// hibernating = "publisher told us the backend is asleep, ask
|
||||
// them to wake it"; inactive = "publisher itself is offline".
|
||||
// For inactive rows, queueing is the right behavior (wait for a
|
||||
// worker on the pool to come online and drain). The wake path
|
||||
// must NOT fire — wake is opt-in via the publisher's register
|
||||
// payload, not a generic "row is down" recovery.
|
||||
//
|
||||
// No session bind here: an "inactive" row in production means
|
||||
// unbindSession already flipped it after SSE close. Binding a
|
||||
// session for the same providerSessionId would be a contradictory
|
||||
// setup that the test wouldn't model anything real about.
|
||||
const repo = mockRepo([makeLlm({ name: 'gone', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
|
||||
await expect(
|
||||
svc.enqueueInferTask('gone', { model: 'm', messages: [] }, false),
|
||||
).rejects.toThrow(/inactive|publisher offline/);
|
||||
const ref = await svc.enqueueInferTask('gone', { model: 'm', messages: [] }, false);
|
||||
// Task queued in pending; no claim, no frame.
|
||||
const row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('pending');
|
||||
expect(row?.claimedBy).toBeNull();
|
||||
});
|
||||
|
||||
it('gcSweep is idempotent — running twice in a row is a no-op the second time', async () => {
|
||||
@@ -439,11 +596,75 @@ describe('VirtualLlmService', () => {
|
||||
const repo = mockRepo([
|
||||
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
||||
]);
|
||||
const svc = new VirtualLlmService(repo);
|
||||
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||
const first = await svc.gcSweep();
|
||||
const second = await svc.gcSweep();
|
||||
expect(first.markedInactive).toBe(1);
|
||||
expect(second.markedInactive).toBe(0);
|
||||
expect(second.deleted).toBe(0);
|
||||
});
|
||||
|
||||
// ── v5: durable queue + drain-on-bind ──
|
||||
|
||||
it('bindSession drains pending inference tasks owned by the session\'s pool keys', async () => {
|
||||
// Two enqueues land while no session is bound. Each row is created
|
||||
// with status=pending; no SSE frame goes anywhere. When the worker
|
||||
// finally binds, the drain loop claims both and pushes the frames.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess', poolName: 'qwen-pool' })]);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
const ref1 = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [{ role: 'user', content: 'one' }] }, false);
|
||||
const ref2 = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [{ role: 'user', content: 'two' }] }, false);
|
||||
// Both rows are still pending — no worker bound yet.
|
||||
expect((await tasks.findById(ref1.taskId))?.status).toBe('pending');
|
||||
expect((await tasks.findById(ref2.taskId))?.status).toBe('pending');
|
||||
|
||||
// Worker shows up. Drain runs synchronously enough that we just
|
||||
// need to flush the microtask queue before checking SSE frames.
|
||||
const session = fakeSession();
|
||||
svc.bindSession('sess', session);
|
||||
// drainPendingForSession is fired with `void` so let microtasks
|
||||
// settle before asserting.
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
|
||||
expect((session.tasks as Array<{ kind: string; taskId: string }>).map((t) => t.taskId).sort())
|
||||
.toEqual([ref1.taskId, ref2.taskId].sort());
|
||||
// Rows are now claimed by this session.
|
||||
expect((await tasks.findById(ref1.taskId))?.status).toBe('claimed');
|
||||
expect((await tasks.findById(ref1.taskId))?.claimedBy).toBe('sess');
|
||||
});
|
||||
|
||||
it('drain-on-bind matches the effective pool key, not just llm.name', async () => {
|
||||
// The pinned Llm has name=vllm-alice but poolName=qwen-pool.
|
||||
// Enqueue against vllm-alice → row.poolName=qwen-pool.
|
||||
// Worker binds with a session that owns vllm-alice (same pool key).
|
||||
// Drain must surface this row even though poolName != name.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-alice', providerSessionId: 'sess', poolName: 'qwen-pool' })]);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
const ref = await svc.enqueueInferTask('vllm-alice', { model: 'm', messages: [] }, false);
|
||||
expect((await tasks.findById(ref.taskId))?.poolName).toBe('qwen-pool');
|
||||
|
||||
const session = fakeSession();
|
||||
svc.bindSession('sess', session);
|
||||
await new Promise((r) => setTimeout(r, 0));
|
||||
expect((session.tasks as Array<{ taskId: string }>).map((t) => t.taskId)).toEqual([ref.taskId]);
|
||||
});
|
||||
|
||||
it('completeTask via the result-route updates the DB row + emits the wakeup', async () => {
|
||||
// End-to-end through the public surface: enqueue → claim happens
|
||||
// because session is bound → worker POSTs result → completeTask
|
||||
// routes to InferenceTaskService.complete → ref.done resolves.
|
||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||
const tasks = mockTaskService();
|
||||
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||
svc.bindSession('sess', fakeSession());
|
||||
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||
|
||||
expect(svc.completeTask(ref.taskId, { status: 200, body: { ok: true } })).toBe(true);
|
||||
await expect(ref.done).resolves.toEqual({ status: 200, body: { ok: true } });
|
||||
const row = await tasks.findById(ref.taskId);
|
||||
expect(row?.status).toBe('completed');
|
||||
expect(row?.responseBody).toEqual({ ok: true });
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user