Compare commits
2 Commits
main
...
feat/infer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b18bb6d6b | ||
|
|
ed21ad1b5a |
@@ -0,0 +1,39 @@
|
|||||||
|
-- v5: durable inference task queue. Every inference call (sync infer,
|
||||||
|
-- agent chat, or async POST /inference-tasks) gets a row here. Workers
|
||||||
|
-- (mcplocal sessions) drain pending rows whose `poolName` matches the
|
||||||
|
-- pool keys they own when they bind their SSE channel.
|
||||||
|
CREATE TYPE "InferenceTaskStatus" AS ENUM ('pending', 'claimed', 'running', 'completed', 'error', 'cancelled');
|
||||||
|
|
||||||
|
CREATE TABLE "InferenceTask" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"status" "InferenceTaskStatus" NOT NULL DEFAULT 'pending',
|
||||||
|
"poolName" TEXT NOT NULL,
|
||||||
|
"llmName" TEXT NOT NULL,
|
||||||
|
"model" TEXT NOT NULL,
|
||||||
|
"tier" TEXT,
|
||||||
|
"claimedBy" TEXT,
|
||||||
|
"requestBody" JSONB NOT NULL,
|
||||||
|
"responseBody" JSONB,
|
||||||
|
"errorMessage" TEXT,
|
||||||
|
"streaming" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"claimedAt" TIMESTAMP(3),
|
||||||
|
"streamStartedAt" TIMESTAMP(3),
|
||||||
|
"completedAt" TIMESTAMP(3),
|
||||||
|
"ownerId" TEXT NOT NULL,
|
||||||
|
"agentId" TEXT,
|
||||||
|
|
||||||
|
CONSTRAINT "InferenceTask_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Worker claim path: SELECT … WHERE status='pending' AND poolName IN (…)
|
||||||
|
-- runs on every SSE bind. Compound index keeps that fast as the table grows.
|
||||||
|
CREATE INDEX "InferenceTask_status_poolName_idx" ON "InferenceTask"("status", "poolName");
|
||||||
|
-- unbindSession revert path: SELECT … WHERE claimedBy=$1 AND status IN ('claimed','running').
|
||||||
|
CREATE INDEX "InferenceTask_claimedBy_idx" ON "InferenceTask"("claimedBy");
|
||||||
|
-- Owner scoping for the async API listing.
|
||||||
|
CREATE INDEX "InferenceTask_ownerId_idx" ON "InferenceTask"("ownerId");
|
||||||
|
CREATE INDEX "InferenceTask_agentId_idx" ON "InferenceTask"("agentId");
|
||||||
|
-- GC sweep predicate: completedAt < now()-7d. Indexed so the daily cleanup
|
||||||
|
-- doesn't seq-scan once the table grows past a few thousand rows.
|
||||||
|
CREATE INDEX "InferenceTask_completedAt_idx" ON "InferenceTask"("completedAt");
|
||||||
@@ -604,6 +604,79 @@ model ChatMessage {
|
|||||||
@@index([threadId, createdAt])
|
@@index([threadId, createdAt])
|
||||||
}
|
}
|
||||||
// ── Inference Tasks (v5) ──
//
// Every inference call (sync infer, agent chat, async POST /inference-tasks)
// creates a row here. The DB is the source of truth; mcpd's previous
// in-memory `pendingTasks` map is gone — the result-handler updates the row
// and an in-process EventEmitter wakes any blocked HTTP handlers (single-
// instance for now; multi-instance scaling is a v6 concern that will swap
// the emitter for pg LISTEN/NOTIFY without changing the data model).
//
// Routing: `poolName` is the effective pool key at enqueue time
// (`Llm.poolName ?? Llm.name`). Workers (mcplocal sessions) drain pending
// rows whose `poolName` matches the pool keys they own when they bind their
// SSE channel — that's how queued tasks survive worker offline windows.

enum InferenceTaskStatus {
  pending // in queue, no worker has it yet (or claim was reverted)
  claimed // a worker has it (SSE frame sent), no chunks back yet
  running // worker started streaming chunks back (streaming tasks only)
  completed // worker POSTed the final result
  error // permanent failure (auth, bad request, queue timeout)
  cancelled // caller said never mind via DELETE
}

model InferenceTask {
  id     String              @id @default(cuid())
  status InferenceTaskStatus @default(pending)

  // Routing — pool key drives worker matching at claim time. Stored at
  // enqueue time so a later rename of Llm.poolName doesn't reroute
  // already-queued work.
  poolName String
  llmName  String // pinned target Llm name (for audit + agent backref)
  model    String
  tier     String?

  // Worker tracking. NULL while pending; set on claim; cleared on
  // unbindSession-driven revert (worker disconnect mid-task).
  claimedBy String?

  // Body + result. Both are Json so streaming chunks can be reconstructed
  // (see TaskService.complete) and async pollers get a structured payload.
  // requestBody is required (the OpenAI chat-completion request body the
  // worker should run); responseBody is null until status=completed.
  requestBody  Json
  responseBody Json?
  errorMessage String?

  /**
   * Whether the original request asked for streaming. Drives the chunk-vs-
   * final-body protocol on the result POST and tells async API callers
   * whether `/stream` will yield chunks or just a single completion event.
   */
  streaming Boolean @default(false)

  // Timestamps for observability + GC:
  //   pending  → claimed:                    claimedAt set
  //   claimed  → running:                    streamStartedAt set (first chunk received)
  //   running/claimed → completed/error/cancelled: completedAt set
  createdAt       DateTime  @default(now())
  claimedAt       DateTime?
  streamStartedAt DateTime?
  completedAt     DateTime?

  // Caller tracking — RBAC + observability. ownerId references User.id;
  // agentId is set when the task came in via /agents/<name>/chat (null
  // for direct /llms/<name>/infer or async POST /inference-tasks calls
  // that don't pin an agent).
  ownerId String
  agentId String?

  @@index([status, poolName])
  @@index([claimedBy])
  @@index([ownerId])
  @@index([agentId])
  // GC sweep predicate: completedAt < 7d ago. Index speeds up the daily
  // cleanup without scanning the whole table.
  @@index([completedAt])
}
|
|
||||||
// ── Audit Logs ──
|
// ── Audit Logs ──
|
||||||
|
|
||||||
model AuditLog {
|
model AuditLog {
|
||||||
|
|||||||
169
src/db/tests/inference-task-schema.test.ts
Normal file
169
src/db/tests/inference-task-schema.test.ts
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
/**
|
||||||
|
* v5 db-level tests for the InferenceTask queue. Exercises the actual
|
||||||
|
* column shapes + index lookups; the mcpd-side service tests cover the
|
||||||
|
* state machine + signal channels with a mocked repo.
|
||||||
|
*/
|
||||||
|
import { describe, it, expect, beforeAll, afterAll, beforeEach } from 'vitest';
|
||||||
|
import type { PrismaClient } from '@prisma/client';
|
||||||
|
import { setupTestDb, cleanupTestDb, clearAllTables } from './helpers.js';
|
||||||
|
|
||||||
|
async function makeOwner(prisma: PrismaClient): Promise<string> {
|
||||||
|
const u = await prisma.user.create({
|
||||||
|
data: { email: `owner-${String(Date.now())}@test`, passwordHash: 'x' },
|
||||||
|
});
|
||||||
|
return u.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('InferenceTask schema (v5)', () => {
|
||||||
|
let prisma: PrismaClient;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
prisma = await setupTestDb();
|
||||||
|
}, 30_000);
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await cleanupTestDb();
|
||||||
|
});
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
await clearAllTables(prisma);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('defaults a fresh row to status=pending with claim/completion fields null', async () => {
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
const row = await prisma.inferenceTask.create({
|
||||||
|
data: {
|
||||||
|
poolName: 'qwen-pool',
|
||||||
|
llmName: 'qwen-prod-1',
|
||||||
|
model: 'qwen3-thinking',
|
||||||
|
requestBody: { messages: [{ role: 'user', content: 'hi' }] },
|
||||||
|
ownerId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
expect(row.status).toBe('pending');
|
||||||
|
expect(row.claimedBy).toBeNull();
|
||||||
|
expect(row.claimedAt).toBeNull();
|
||||||
|
expect(row.streamStartedAt).toBeNull();
|
||||||
|
expect(row.completedAt).toBeNull();
|
||||||
|
expect(row.responseBody).toBeNull();
|
||||||
|
expect(row.streaming).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('roundtrips streaming=true and a structured requestBody/responseBody', async () => {
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
const requestBody = {
|
||||||
|
messages: [{ role: 'user', content: 'hello' }],
|
||||||
|
temperature: 0.2,
|
||||||
|
tools: [{ type: 'function', function: { name: 'noop' } }],
|
||||||
|
};
|
||||||
|
const row = await prisma.inferenceTask.create({
|
||||||
|
data: {
|
||||||
|
poolName: 'qwen-pool',
|
||||||
|
llmName: 'qwen-prod-1',
|
||||||
|
model: 'qwen3',
|
||||||
|
requestBody,
|
||||||
|
streaming: true,
|
||||||
|
ownerId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
expect(row.streaming).toBe(true);
|
||||||
|
expect(row.requestBody).toEqual(requestBody);
|
||||||
|
|
||||||
|
const completedAt = new Date();
|
||||||
|
const responseBody = { choices: [{ message: { role: 'assistant', content: 'world' } }] };
|
||||||
|
const updated = await prisma.inferenceTask.update({
|
||||||
|
where: { id: row.id },
|
||||||
|
data: { status: 'completed', responseBody, completedAt },
|
||||||
|
});
|
||||||
|
expect(updated.responseBody).toEqual(responseBody);
|
||||||
|
expect(updated.completedAt?.getTime()).toBe(completedAt.getTime());
|
||||||
|
});
|
||||||
|
|
||||||
|
it('compound index supports the dispatcher\'s drain query (status + poolName IN ...)', async () => {
|
||||||
|
// The actual EXPLAIN/index-use check is too brittle for unit tests;
|
||||||
|
// here we verify the QUERY shape that the repo's findPendingForPools
|
||||||
|
// issues — same WHERE/ORDER BY — returns the expected rows in FIFO
|
||||||
|
// order. Index usage is implied by the Prisma model definition.
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
const t1 = await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'pool-a', llmName: 'a-1', model: 'm', requestBody: {}, ownerId },
|
||||||
|
});
|
||||||
|
await new Promise((r) => setTimeout(r, 5));
|
||||||
|
const t2 = await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'pool-a', llmName: 'a-2', model: 'm', requestBody: {}, ownerId },
|
||||||
|
});
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'pool-b', llmName: 'b-1', model: 'm', requestBody: {}, ownerId },
|
||||||
|
});
|
||||||
|
// One row in pool-a is no longer pending — must be excluded.
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'pool-a', llmName: 'a-3', model: 'm', requestBody: {}, ownerId, status: 'completed' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const drained = await prisma.inferenceTask.findMany({
|
||||||
|
where: { status: 'pending', poolName: { in: ['pool-a', 'pool-b'] } },
|
||||||
|
orderBy: { createdAt: 'asc' },
|
||||||
|
});
|
||||||
|
expect(drained.map((r) => r.id)).toEqual([t1.id, t2.id, drained[2]!.id]);
|
||||||
|
expect(drained.map((r) => r.poolName)).toEqual(['pool-a', 'pool-a', 'pool-b']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('claimedBy index supports unbindSession revert (worker disconnect path)', async () => {
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-A' },
|
||||||
|
});
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'running', claimedBy: 'sess-A' },
|
||||||
|
});
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'claimed', claimedBy: 'sess-B' },
|
||||||
|
});
|
||||||
|
// Completed-but-claimedBy=sess-A row: must NOT revert (terminal state).
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', claimedBy: 'sess-A' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const heldByA = await prisma.inferenceTask.findMany({
|
||||||
|
where: { claimedBy: 'sess-A', status: { in: ['claimed', 'running'] } },
|
||||||
|
});
|
||||||
|
expect(heldByA).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('GC predicate (terminal + completedAt < cutoff) is index-friendly and filters correctly', async () => {
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
const old = new Date(Date.now() - 8 * 24 * 60 * 60 * 1000); // 8 d ago
|
||||||
|
const recent = new Date(Date.now() - 1 * 60 * 60 * 1000); // 1 h ago
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: old },
|
||||||
|
});
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'error', completedAt: old, errorMessage: 'boom' },
|
||||||
|
});
|
||||||
|
// Inside retention — must not be picked up by GC.
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'completed', completedAt: recent },
|
||||||
|
});
|
||||||
|
// Pending row — must not be picked up by terminal GC.
|
||||||
|
await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, status: 'pending' },
|
||||||
|
});
|
||||||
|
|
||||||
|
const cutoff = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000);
|
||||||
|
const expired = await prisma.inferenceTask.findMany({
|
||||||
|
where: {
|
||||||
|
status: { in: ['completed', 'error', 'cancelled'] },
|
||||||
|
completedAt: { lt: cutoff },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
expect(expired).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('agentId is nullable — direct chat-llm tasks have no agent', async () => {
|
||||||
|
const ownerId = await makeOwner(prisma);
|
||||||
|
const row = await prisma.inferenceTask.create({
|
||||||
|
data: { poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, ownerId, agentId: null },
|
||||||
|
});
|
||||||
|
expect(row.agentId).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -32,6 +32,8 @@ import { SecretBackendRotatorLoop } from './services/secret-backend-rotator-loop
|
|||||||
import { registerSecretBackendRotateRoutes } from './routes/secret-backend-rotate.js';
|
import { registerSecretBackendRotateRoutes } from './routes/secret-backend-rotate.js';
|
||||||
import { LlmRepository } from './repositories/llm.repository.js';
|
import { LlmRepository } from './repositories/llm.repository.js';
|
||||||
import { LlmService } from './services/llm.service.js';
|
import { LlmService } from './services/llm.service.js';
|
||||||
|
import { InferenceTaskRepository } from './repositories/inference-task.repository.js';
|
||||||
|
import { InferenceTaskService } from './services/inference-task.service.js';
|
||||||
import { AgentRepository } from './repositories/agent.repository.js';
|
import { AgentRepository } from './repositories/agent.repository.js';
|
||||||
import { ChatRepository } from './repositories/chat.repository.js';
|
import { ChatRepository } from './repositories/chat.repository.js';
|
||||||
import { AgentService } from './services/agent.service.js';
|
import { AgentService } from './services/agent.service.js';
|
||||||
@@ -463,10 +465,17 @@ async function main(): Promise<void> {
|
|||||||
const personalityRepo = new PersonalityRepository(prisma);
|
const personalityRepo = new PersonalityRepository(prisma);
|
||||||
const personalityService = new PersonalityService(personalityRepo, agentRepo, promptRepo);
|
const personalityService = new PersonalityService(personalityRepo, agentRepo, promptRepo);
|
||||||
const agentService = new AgentService(agentRepo, llmService, projectService, personalityRepo);
|
const agentService = new AgentService(agentRepo, llmService, projectService, personalityRepo);
|
||||||
|
// v5: durable inference task queue. VirtualLlmService persists infer
|
||||||
|
// tasks here; the result-route updates them; an in-process emitter
|
||||||
|
// wakes blocked HTTP handlers when results land.
|
||||||
|
const inferenceTaskRepo = new InferenceTaskRepository(prisma);
|
||||||
|
const inferenceTaskService = new InferenceTaskService(inferenceTaskRepo);
|
||||||
// Virtual-provider state machine (kind=virtual rows for both Llms and
|
// Virtual-provider state machine (kind=virtual rows for both Llms and
|
||||||
// Agents). v3 wires AgentService for heartbeat/disconnect/GC cascade.
|
// Agents). v3 wires AgentService for heartbeat/disconnect/GC cascade.
|
||||||
|
// v5 wires inferenceTaskService — enqueueInferTask now persists rows,
|
||||||
|
// worker disconnect reverts claimed rows, worker bind drains pending.
|
||||||
// The 60-s GC ticker is started below after `app.listen`.
|
// The 60-s GC ticker is started below after `app.listen`.
|
||||||
const virtualLlmService = new VirtualLlmService(llmRepo, agentService);
|
const virtualLlmService = new VirtualLlmService(llmRepo, agentService, inferenceTaskService);
|
||||||
// ChatService needs the proxy + project repo via the ChatToolDispatcher
|
// ChatService needs the proxy + project repo via the ChatToolDispatcher
|
||||||
// bridge. The dispatcher's logger references `app.log`, which is not
|
// bridge. The dispatcher's logger references `app.log`, which is not
|
||||||
// constructed until further down — `chatService` itself is built right
|
// constructed until further down — `chatService` itself is built right
|
||||||
|
|||||||
235
src/mcpd/src/repositories/inference-task.repository.ts
Normal file
235
src/mcpd/src/repositories/inference-task.repository.ts
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
/**
|
||||||
|
* v5 InferenceTask repository — durable queue for inference work.
|
||||||
|
*
|
||||||
|
* Workers (mcplocal SSE sessions) drain rows whose `poolName` matches the
|
||||||
|
* pool keys they own when they bind. The state machine is:
|
||||||
|
*
|
||||||
|
* pending ──claim──> claimed ──first chunk──> running
|
||||||
|
* \ │
|
||||||
|
* \ ▼
|
||||||
|
* ──complete/fail──> completed | error
|
||||||
|
*
|
||||||
|
* Plus orthogonal terminal transitions:
|
||||||
|
*
|
||||||
|
* pending|claimed|running ──cancel──> cancelled
|
||||||
|
* claimed|running ──worker disconnect──> pending (revert + re-queue)
|
||||||
|
*
|
||||||
|
* The DB is the source of truth. mcpd's in-flight HTTP handlers wait via
|
||||||
|
* an in-process EventEmitter (see TaskService) — single-instance for now;
|
||||||
|
* pg LISTEN/NOTIFY is the v6 swap when scaling horizontally.
|
||||||
|
*/
|
||||||
|
import type {
|
||||||
|
PrismaClient,
|
||||||
|
InferenceTask,
|
||||||
|
InferenceTaskStatus,
|
||||||
|
} from '@prisma/client';
|
||||||
|
// `Prisma` is imported as a value because we reference Prisma.JsonNull at
|
||||||
|
// runtime (the sentinel for explicit JSON null on a nullable column).
|
||||||
|
import { Prisma } from '@prisma/client';
|
||||||
|
|
||||||
|
export interface CreateInferenceTaskInput {
|
||||||
|
poolName: string;
|
||||||
|
llmName: string;
|
||||||
|
model: string;
|
||||||
|
tier?: string | null;
|
||||||
|
requestBody: Record<string, unknown>;
|
||||||
|
streaming: boolean;
|
||||||
|
ownerId: string;
|
||||||
|
agentId?: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface IInferenceTaskRepository {
|
||||||
|
create(data: CreateInferenceTaskInput): Promise<InferenceTask>;
|
||||||
|
findById(id: string): Promise<InferenceTask | null>;
|
||||||
|
/** Pending rows for one or more pool keys, oldest first (FIFO). */
|
||||||
|
findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
|
||||||
|
/** Tasks held by a worker session — used by unbindSession to revert on disconnect. */
|
||||||
|
findHeldBy(claimedBy: string): Promise<InferenceTask[]>;
|
||||||
|
/** List for the async API; filters are AND-combined. */
|
||||||
|
list(filter: {
|
||||||
|
ownerId?: string;
|
||||||
|
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||||
|
poolName?: string;
|
||||||
|
agentId?: string;
|
||||||
|
limit?: number;
|
||||||
|
}): Promise<InferenceTask[]>;
|
||||||
|
|
||||||
|
// ── State transitions ──
|
||||||
|
// Each transition is conditional ("compare-and-swap" via updateMany +
|
||||||
|
// affected-row count) so two workers racing to claim the same row both
|
||||||
|
// succeed at the DB level — one gets affected=1 and a populated row,
|
||||||
|
// the other gets affected=0 and we fall through to the next candidate.
|
||||||
|
|
||||||
|
/** pending → claimed; returns the updated row, or null if the row was no longer pending. */
|
||||||
|
tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null>;
|
||||||
|
/** claimed → running; idempotent — returns the row even if already running. */
|
||||||
|
markRunning(id: string, at: Date): Promise<InferenceTask | null>;
|
||||||
|
/** {claimed,running} → completed; only allowed from a non-terminal state. */
|
||||||
|
markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null>;
|
||||||
|
/** any non-terminal → error. Records `errorMessage`. */
|
||||||
|
markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null>;
|
||||||
|
/** any non-terminal → cancelled. */
|
||||||
|
markCancelled(id: string, at: Date): Promise<InferenceTask | null>;
|
||||||
|
/** {claimed,running} → pending; clears `claimedBy`. Used on worker disconnect. */
|
||||||
|
revertToPending(id: string): Promise<InferenceTask | null>;
|
||||||
|
|
||||||
|
// ── GC ──
|
||||||
|
|
||||||
|
/** Pending rows older than the cutoff — the GC sweep flips these to error. */
|
||||||
|
findStalePending(cutoff: Date): Promise<InferenceTask[]>;
|
||||||
|
/** Completed/error/cancelled rows older than the cutoff — GC deletes them. */
|
||||||
|
findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]>;
|
||||||
|
/** Bulk delete by id — used by the GC sweep after collecting expired ids. */
|
||||||
|
deleteMany(ids: string[]): Promise<number>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const NON_TERMINAL: InferenceTaskStatus[] = ['pending', 'claimed', 'running'];
|
||||||
|
const HELD: InferenceTaskStatus[] = ['claimed', 'running'];
|
||||||
|
|
||||||
|
export class InferenceTaskRepository implements IInferenceTaskRepository {
|
||||||
|
constructor(private readonly prisma: PrismaClient) {}
|
||||||
|
|
||||||
|
async create(data: CreateInferenceTaskInput): Promise<InferenceTask> {
|
||||||
|
return this.prisma.inferenceTask.create({
|
||||||
|
data: {
|
||||||
|
poolName: data.poolName,
|
||||||
|
llmName: data.llmName,
|
||||||
|
model: data.model,
|
||||||
|
tier: data.tier ?? null,
|
||||||
|
requestBody: data.requestBody as Prisma.InputJsonValue,
|
||||||
|
streaming: data.streaming,
|
||||||
|
ownerId: data.ownerId,
|
||||||
|
agentId: data.agentId ?? null,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async findById(id: string): Promise<InferenceTask | null> {
|
||||||
|
return this.prisma.inferenceTask.findUnique({ where: { id } });
|
||||||
|
}
|
||||||
|
|
||||||
|
async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
|
||||||
|
if (poolNames.length === 0) return [];
|
||||||
|
return this.prisma.inferenceTask.findMany({
|
||||||
|
where: { status: 'pending', poolName: { in: poolNames } },
|
||||||
|
orderBy: { createdAt: 'asc' },
|
||||||
|
...(limit !== undefined ? { take: limit } : {}),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async findHeldBy(claimedBy: string): Promise<InferenceTask[]> {
|
||||||
|
return this.prisma.inferenceTask.findMany({
|
||||||
|
where: { claimedBy, status: { in: HELD } },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async list(filter: {
|
||||||
|
ownerId?: string;
|
||||||
|
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||||
|
poolName?: string;
|
||||||
|
agentId?: string;
|
||||||
|
limit?: number;
|
||||||
|
}): Promise<InferenceTask[]> {
|
||||||
|
const where: Prisma.InferenceTaskWhereInput = {};
|
||||||
|
if (filter.ownerId !== undefined) where.ownerId = filter.ownerId;
|
||||||
|
if (filter.status !== undefined) {
|
||||||
|
where.status = Array.isArray(filter.status) ? { in: filter.status } : filter.status;
|
||||||
|
}
|
||||||
|
if (filter.poolName !== undefined) where.poolName = filter.poolName;
|
||||||
|
if (filter.agentId !== undefined) where.agentId = filter.agentId;
|
||||||
|
return this.prisma.inferenceTask.findMany({
|
||||||
|
where,
|
||||||
|
orderBy: { createdAt: 'desc' },
|
||||||
|
...(filter.limit !== undefined ? { take: filter.limit } : {}),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async tryClaim(id: string, claimedBy: string, claimedAt: Date): Promise<InferenceTask | null> {
|
||||||
|
// Compare-and-swap on status='pending'. Two workers racing both run
|
||||||
|
// this UPDATE; whichever the DB serializes first sees affected=1 and
|
||||||
|
// gets the row, the loser sees affected=0 and falls through.
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: 'pending' },
|
||||||
|
data: { status: 'claimed', claimedBy, claimedAt },
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async markRunning(id: string, at: Date): Promise<InferenceTask | null> {
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: { in: ['claimed', 'running'] } },
|
||||||
|
data: { status: 'running', streamStartedAt: at },
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async markCompleted(id: string, responseBody: Record<string, unknown> | null, at: Date): Promise<InferenceTask | null> {
|
||||||
|
// Prisma's nullable JSON column writes need an explicit
|
||||||
|
// `Prisma.JsonNull` sentinel for null; passing literal null fails
|
||||||
|
// typecheck because the column also accepts a plain JSON value.
|
||||||
|
const bodyForUpdate: Prisma.InputJsonValue | typeof Prisma.JsonNull =
|
||||||
|
responseBody === null ? Prisma.JsonNull : (responseBody as Prisma.InputJsonValue);
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: { in: NON_TERMINAL } },
|
||||||
|
data: {
|
||||||
|
status: 'completed',
|
||||||
|
responseBody: bodyForUpdate,
|
||||||
|
completedAt: at,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async markError(id: string, errorMessage: string, at: Date): Promise<InferenceTask | null> {
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: { in: NON_TERMINAL } },
|
||||||
|
data: { status: 'error', errorMessage, completedAt: at },
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async markCancelled(id: string, at: Date): Promise<InferenceTask | null> {
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: { in: NON_TERMINAL } },
|
||||||
|
data: { status: 'cancelled', completedAt: at },
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async revertToPending(id: string): Promise<InferenceTask | null> {
|
||||||
|
const result = await this.prisma.inferenceTask.updateMany({
|
||||||
|
where: { id, status: { in: HELD } },
|
||||||
|
data: { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null },
|
||||||
|
});
|
||||||
|
if (result.count === 0) return null;
|
||||||
|
return this.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async findStalePending(cutoff: Date): Promise<InferenceTask[]> {
|
||||||
|
return this.prisma.inferenceTask.findMany({
|
||||||
|
where: { status: 'pending', createdAt: { lt: cutoff } },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async findExpiredTerminal(cutoff: Date): Promise<InferenceTask[]> {
|
||||||
|
return this.prisma.inferenceTask.findMany({
|
||||||
|
where: {
|
||||||
|
status: { in: ['completed', 'error', 'cancelled'] },
|
||||||
|
completedAt: { lt: cutoff },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async deleteMany(ids: string[]): Promise<number> {
|
||||||
|
if (ids.length === 0) return 0;
|
||||||
|
const result = await this.prisma.inferenceTask.deleteMany({
|
||||||
|
where: { id: { in: ids } },
|
||||||
|
});
|
||||||
|
return result.count;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -98,8 +98,12 @@ export function registerLlmInferRoutes(
|
|||||||
return { error: 'virtual LLM dispatch unavailable (server misconfiguration)' };
|
return { error: 'virtual LLM dispatch unavailable (server misconfiguration)' };
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
|
// Direct infer is the v1-v4 sync path: caller pinned to a
|
||||||
|
// specific Llm name and wants a fast 503 if the publisher is
|
||||||
|
// offline. The async durable-queue API (Stage 3) is for callers
|
||||||
|
// that explicitly opt into queueing.
|
||||||
if (!streaming) {
|
if (!streaming) {
|
||||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, false);
|
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, false, { failFast: true });
|
||||||
const result = await ref.done;
|
const result = await ref.done;
|
||||||
reply.code(result.status);
|
reply.code(result.status);
|
||||||
audit(result.status);
|
audit(result.status);
|
||||||
@@ -113,7 +117,7 @@ export function registerLlmInferRoutes(
|
|||||||
Connection: 'keep-alive',
|
Connection: 'keep-alive',
|
||||||
'X-Accel-Buffering': 'no',
|
'X-Accel-Buffering': 'no',
|
||||||
});
|
});
|
||||||
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, true);
|
const ref = await deps.virtualLlms.enqueueInferTask(llm.name, body, true, { failFast: true });
|
||||||
const unsubscribe = ref.onChunk((chunk) => writeSseChunk(reply, chunk.data));
|
const unsubscribe = ref.onChunk((chunk) => writeSseChunk(reply, chunk.data));
|
||||||
try {
|
try {
|
||||||
await ref.done;
|
await ref.done;
|
||||||
|
|||||||
@@ -462,10 +462,16 @@ export class ChatService {
|
|||||||
// iterator. Chunks land on the queue from the SSE relay; the
|
// iterator. Chunks land on the queue from the SSE relay; the
|
||||||
// generator drains them in order. ref.done resolves when the
|
// generator drains them in order. ref.done resolves when the
|
||||||
// publisher emits its `[DONE]` marker.
|
// publisher emits its `[DONE]` marker.
|
||||||
|
//
|
||||||
|
// failFast: true — the chat dispatcher's pool failover loop relies
|
||||||
|
// on a fast "transport error" surfacing so it can iterate to the
|
||||||
|
// next candidate. Without it, a downed pool member would queue the
|
||||||
|
// task and the loop would wait 10 min before trying the next one.
|
||||||
const ref = await this.virtualLlms.enqueueInferTask(
|
const ref = await this.virtualLlms.enqueueInferTask(
|
||||||
candidate.llmName,
|
candidate.llmName,
|
||||||
{ ...this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }), stream: true },
|
{ ...this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }), stream: true },
|
||||||
true,
|
true,
|
||||||
|
{ failFast: true },
|
||||||
);
|
);
|
||||||
const queue: Array<{ data: string; done?: boolean }> = [];
|
const queue: Array<{ data: string; done?: boolean }> = [];
|
||||||
let resolveTick: (() => void) | null = null;
|
let resolveTick: (() => void) | null = null;
|
||||||
@@ -544,6 +550,7 @@ export class ChatService {
|
|||||||
candidate.llmName,
|
candidate.llmName,
|
||||||
this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }),
|
this.buildBody({ ...ctx, modelOverride: candidate.modelOverride }),
|
||||||
false,
|
false,
|
||||||
|
{ failFast: true },
|
||||||
);
|
);
|
||||||
return ref.done;
|
return ref.done;
|
||||||
}
|
}
|
||||||
|
|||||||
311
src/mcpd/src/services/inference-task.service.ts
Normal file
311
src/mcpd/src/services/inference-task.service.ts
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
/**
|
||||||
|
* v5 InferenceTaskService — business logic over the durable task queue.
|
||||||
|
*
|
||||||
|
* The DB row (managed via IInferenceTaskRepository) is the source of truth
|
||||||
|
* for *state*. This service owns the in-process *signaling* layer that
|
||||||
|
* lets blocked HTTP handlers wake up promptly when a worker POSTs a
|
||||||
|
* result, instead of polling the DB. The signals carry no state — they
|
||||||
|
* just say "go re-read row X" — so a missed signal degrades to a slower
|
||||||
|
* poll-then-resolve (still correct, just less responsive).
|
||||||
|
*
|
||||||
|
* Single-instance assumption: EventEmitter only fires on the mcpd that
|
||||||
|
* holds the row's blocked handler. With multiple mcpd replicas, the
|
||||||
|
* worker's POST might land on instance A while the caller's handler is
|
||||||
|
* on instance B; B never sees the local emitter. Solution at scale is
|
||||||
|
* pg LISTEN/NOTIFY (v6 swap) — no schema change needed; just replace the
|
||||||
|
* EventEmitter wakeup with a NOTIFY emit.
|
||||||
|
*/
|
||||||
|
import { EventEmitter } from 'node:events';
|
||||||
|
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||||
|
import type { IInferenceTaskRepository, CreateInferenceTaskInput } from '../repositories/inference-task.repository.js';
|
||||||
|
import { NotFoundError } from './mcp-server.service.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* One streaming chunk pushed back from a worker. Lives in-memory only —
|
||||||
|
* we don't persist every SSE delta to the DB (too expensive for
|
||||||
|
* high-frequency streams). The final assembled body lands on the row
|
||||||
|
* via `complete()` when the worker emits its terminal `done:true`.
|
||||||
|
*/
|
||||||
|
export interface InferenceTaskChunk {
|
||||||
|
data: string;
|
||||||
|
done?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The wait handle returned to a blocked HTTP handler. `done` resolves
|
||||||
|
* when the row reaches a terminal status (completed | error | cancelled);
|
||||||
|
* `chunks` is an async iterator that yields streaming deltas as they
|
||||||
|
* arrive (only meaningful when the underlying request asked for streaming).
|
||||||
|
*/
|
||||||
|
export interface InferenceTaskWaiter {
|
||||||
|
/** The DB row's terminal state. Throws on `error`/`cancelled` with an explanatory message. */
|
||||||
|
done: Promise<InferenceTask>;
|
||||||
|
/** Async iterator of streaming chunks. Yields nothing for non-streaming tasks. */
|
||||||
|
chunks: AsyncGenerator<InferenceTaskChunk>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface IInferenceTaskService {
|
||||||
|
/** Create a new pending task. Caller is expected to immediately attempt dispatch. */
|
||||||
|
enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask>;
|
||||||
|
/** Wait for a task to reach a terminal state, with optional timeout (ms). */
|
||||||
|
waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter;
|
||||||
|
/** Conditional CAS claim — returns the row if `pending`, null if already claimed. */
|
||||||
|
tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null>;
|
||||||
|
/** Worker reported first chunk — flips `claimed` → `running`. Idempotent. */
|
||||||
|
markRunning(taskId: string): Promise<InferenceTask | null>;
|
||||||
|
/** Worker pushed a streaming chunk. Wakes up the in-process waiter; no DB write. */
|
||||||
|
pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean;
|
||||||
|
/**
|
||||||
|
* Subscribe directly to a task's chunk stream as a callback. Returns
|
||||||
|
* an unsubscribe function. Used by VirtualLlmService's legacy
|
||||||
|
* PendingTaskRef.onChunk bridge — the high-level `waitFor` API is
|
||||||
|
* better when you need an async iterator + done promise together.
|
||||||
|
*/
|
||||||
|
subscribeChunksUnsafe(taskId: string, cb: (chunk: InferenceTaskChunk) => void): () => void;
|
||||||
|
/** Worker POSTed the final result. Persists the body + flips to `completed`. */
|
||||||
|
complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null>;
|
||||||
|
/** Worker (or mcpd) marked the task as failed. */
|
||||||
|
fail(taskId: string, errorMessage: string): Promise<InferenceTask | null>;
|
||||||
|
/** Caller (or GC) marked the task as cancelled. */
|
||||||
|
cancel(taskId: string): Promise<InferenceTask | null>;
|
||||||
|
/** Worker disconnected mid-task — revert claimed/running rows back to pending. */
|
||||||
|
revertHeldBy(claimedBy: string): Promise<InferenceTask[]>;
|
||||||
|
/** Pending rows for one or more pool keys, ordered FIFO. Used by drainPending on bind. */
|
||||||
|
findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]>;
|
||||||
|
findById(id: string): Promise<InferenceTask | null>;
|
||||||
|
list(filter: {
|
||||||
|
ownerId?: string;
|
||||||
|
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||||
|
poolName?: string;
|
||||||
|
agentId?: string;
|
||||||
|
limit?: number;
|
||||||
|
}): Promise<InferenceTask[]>;
|
||||||
|
/** GC sweep: pending too long → error; terminal too long → delete. */
|
||||||
|
gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number }): Promise<{ erroredOut: number; deleted: number }>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_PENDING_TIMEOUT_MS = 60 * 60 * 1000; // 1 h
|
||||||
|
const DEFAULT_TERMINAL_RETENTION_MS = 7 * 24 * 60 * 60 * 1000; // 7 d
|
||||||
|
|
||||||
|
export class InferenceTaskService implements IInferenceTaskService {
|
||||||
|
/**
|
||||||
|
* Wakeup channel per task. The result-handler (worker POSTs landing on
|
||||||
|
* mcpd) emits a terminal status; the streaming push emits chunks.
|
||||||
|
* Cleared once a waiter resolves to avoid leaks. EventEmitter listener
|
||||||
|
* cap is bumped to a generous default — typical concurrent waiters per
|
||||||
|
* task is 1, but `/stream` SSE consumers and the original HTTP handler
|
||||||
|
* can both subscribe.
|
||||||
|
*/
|
||||||
|
private readonly events = new EventEmitter();
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly repo: IInferenceTaskRepository,
|
||||||
|
private readonly clock: () => Date = () => new Date(),
|
||||||
|
) {
|
||||||
|
this.events.setMaxListeners(50);
|
||||||
|
}
|
||||||
|
|
||||||
|
async enqueue(input: CreateInferenceTaskInput): Promise<InferenceTask> {
|
||||||
|
return this.repo.create(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
waitFor(taskId: string, timeoutMs: number): InferenceTaskWaiter {
|
||||||
|
// Set up the chunk channel up-front so a worker pushing chunks before
|
||||||
|
// the consumer subscribes doesn't drop them. We use a queue + wake
|
||||||
|
// pattern (same shape as the v3 chat.service streaming bridge).
|
||||||
|
const chunkQueue: InferenceTaskChunk[] = [];
|
||||||
|
let chunkResolve: (() => void) | null = null;
|
||||||
|
let finished = false;
|
||||||
|
let finishError: Error | null = null;
|
||||||
|
|
||||||
|
const onChunk = (chunk: InferenceTaskChunk): void => {
|
||||||
|
chunkQueue.push(chunk);
|
||||||
|
if (chunkResolve !== null) {
|
||||||
|
const r = chunkResolve;
|
||||||
|
chunkResolve = null;
|
||||||
|
r();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const onTerminal = (): void => {
|
||||||
|
finished = true;
|
||||||
|
if (chunkResolve !== null) {
|
||||||
|
const r = chunkResolve;
|
||||||
|
chunkResolve = null;
|
||||||
|
r();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.events.on(`chunk:${taskId}`, onChunk);
|
||||||
|
this.events.on(`terminal:${taskId}`, onTerminal);
|
||||||
|
|
||||||
|
const cleanup = (): void => {
|
||||||
|
this.events.off(`chunk:${taskId}`, onChunk);
|
||||||
|
this.events.off(`terminal:${taskId}`, onTerminal);
|
||||||
|
};
|
||||||
|
|
||||||
|
// The terminal promise: poll once at start (in case the row is
|
||||||
|
// already terminal — common for already-completed tasks the caller
|
||||||
|
// is just fetching the result of), then race the timeout against
|
||||||
|
// the wakeup signal.
|
||||||
|
const done = (async (): Promise<InferenceTask> => {
|
||||||
|
try {
|
||||||
|
const initial = await this.repo.findById(taskId);
|
||||||
|
if (initial === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
|
||||||
|
if (isTerminal(initial.status)) return ensureSuccess(initial);
|
||||||
|
|
||||||
|
// Wait for terminal event OR timeout. If the wake happens before
|
||||||
|
// the timer fires, we re-fetch and return; otherwise we error.
|
||||||
|
const timer = new Promise<never>((_, reject) => {
|
||||||
|
setTimeout(() => reject(new Error(`InferenceTask wait timed out after ${String(timeoutMs)}ms (task ${taskId})`)), timeoutMs);
|
||||||
|
});
|
||||||
|
const wake = new Promise<void>((resolve) => {
|
||||||
|
this.events.once(`terminal:${taskId}`, () => resolve());
|
||||||
|
});
|
||||||
|
await Promise.race([timer, wake]);
|
||||||
|
finished = true;
|
||||||
|
const final = await this.repo.findById(taskId);
|
||||||
|
if (final === null) throw new NotFoundError(`InferenceTask not found: ${taskId}`);
|
||||||
|
return ensureSuccess(final);
|
||||||
|
} catch (err) {
|
||||||
|
finishError = err as Error;
|
||||||
|
finished = true;
|
||||||
|
throw err;
|
||||||
|
} finally {
|
||||||
|
cleanup();
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
const chunks = (async function* gen(): AsyncGenerator<InferenceTaskChunk> {
|
||||||
|
while (true) {
|
||||||
|
while (chunkQueue.length > 0) {
|
||||||
|
const c = chunkQueue.shift()!;
|
||||||
|
yield c;
|
||||||
|
if (c.done === true) return;
|
||||||
|
}
|
||||||
|
if (finished) {
|
||||||
|
if (finishError !== null) throw finishError;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
await new Promise<void>((r) => { chunkResolve = r; });
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
return { done, chunks };
|
||||||
|
}
|
||||||
|
|
||||||
|
async tryClaim(taskId: string, claimedBy: string): Promise<InferenceTask | null> {
|
||||||
|
return this.repo.tryClaim(taskId, claimedBy, this.clock());
|
||||||
|
}
|
||||||
|
|
||||||
|
async markRunning(taskId: string): Promise<InferenceTask | null> {
|
||||||
|
return this.repo.markRunning(taskId, this.clock());
|
||||||
|
}
|
||||||
|
|
||||||
|
pushChunk(taskId: string, chunk: InferenceTaskChunk): boolean {
|
||||||
|
// Streaming chunks are in-memory only. If the row no longer exists or
|
||||||
|
// there are no listeners, the emit is a no-op — return false so the
|
||||||
|
// result-route can decide whether to log a warning.
|
||||||
|
return this.events.emit(`chunk:${taskId}`, chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
subscribeChunksUnsafe(taskId: string, cb: (chunk: InferenceTaskChunk) => void): () => void {
|
||||||
|
this.events.on(`chunk:${taskId}`, cb);
|
||||||
|
return (): void => {
|
||||||
|
this.events.off(`chunk:${taskId}`, cb);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async complete(taskId: string, responseBody: Record<string, unknown> | null): Promise<InferenceTask | null> {
|
||||||
|
const updated = await this.repo.markCompleted(taskId, responseBody, this.clock());
|
||||||
|
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fail(taskId: string, errorMessage: string): Promise<InferenceTask | null> {
|
||||||
|
const updated = await this.repo.markError(taskId, errorMessage, this.clock());
|
||||||
|
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
async cancel(taskId: string): Promise<InferenceTask | null> {
|
||||||
|
const updated = await this.repo.markCancelled(taskId, this.clock());
|
||||||
|
if (updated !== null) this.events.emit(`terminal:${taskId}`);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
async revertHeldBy(claimedBy: string): Promise<InferenceTask[]> {
|
||||||
|
const held = await this.repo.findHeldBy(claimedBy);
|
||||||
|
const reverted: InferenceTask[] = [];
|
||||||
|
for (const t of held) {
|
||||||
|
const r = await this.repo.revertToPending(t.id);
|
||||||
|
if (r !== null) reverted.push(r);
|
||||||
|
}
|
||||||
|
return reverted;
|
||||||
|
}
|
||||||
|
|
||||||
|
async findPendingForPools(poolNames: string[], limit?: number): Promise<InferenceTask[]> {
|
||||||
|
return this.repo.findPendingForPools(poolNames, limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
async findById(id: string): Promise<InferenceTask | null> {
|
||||||
|
return this.repo.findById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
async list(filter: {
|
||||||
|
ownerId?: string;
|
||||||
|
status?: InferenceTaskStatus | InferenceTaskStatus[];
|
||||||
|
poolName?: string;
|
||||||
|
agentId?: string;
|
||||||
|
limit?: number;
|
||||||
|
}): Promise<InferenceTask[]> {
|
||||||
|
return this.repo.list(filter);
|
||||||
|
}
|
||||||
|
|
||||||
|
async gcSweep(opts: { pendingTimeoutMs: number; terminalRetentionMs: number } = {
|
||||||
|
pendingTimeoutMs: DEFAULT_PENDING_TIMEOUT_MS,
|
||||||
|
terminalRetentionMs: DEFAULT_TERMINAL_RETENTION_MS,
|
||||||
|
}): Promise<{ erroredOut: number; deleted: number }> {
|
||||||
|
const now = this.clock();
|
||||||
|
const pendingCutoff = new Date(now.getTime() - opts.pendingTimeoutMs);
|
||||||
|
const terminalCutoff = new Date(now.getTime() - opts.terminalRetentionMs);
|
||||||
|
|
||||||
|
// Pass 1: pending tasks that have aged out → flip to error so the
|
||||||
|
// caller (if still waiting) gets a clean failure.
|
||||||
|
const stale = await this.repo.findStalePending(pendingCutoff);
|
||||||
|
let erroredOut = 0;
|
||||||
|
for (const t of stale) {
|
||||||
|
const r = await this.repo.markError(t.id, `Task expired in pending state after ${String(opts.pendingTimeoutMs / 1000)}s`, now);
|
||||||
|
if (r !== null) {
|
||||||
|
this.events.emit(`terminal:${t.id}`);
|
||||||
|
erroredOut += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 2: terminal tasks past retention → delete.
|
||||||
|
const expired = await this.repo.findExpiredTerminal(terminalCutoff);
|
||||||
|
const deleted = await this.repo.deleteMany(expired.map((t) => t.id));
|
||||||
|
|
||||||
|
return { erroredOut, deleted };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function isTerminal(status: InferenceTaskStatus): boolean {
|
||||||
|
return status === 'completed' || status === 'error' || status === 'cancelled';
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Surface a non-success terminal state as a thrown error so the HTTP
|
||||||
|
* handler propagates it cleanly. `error` rows carry an `errorMessage`;
|
||||||
|
* `cancelled` rows don't, so we synthesize one.
|
||||||
|
*/
|
||||||
|
function ensureSuccess(task: InferenceTask): InferenceTask {
|
||||||
|
if (task.status === 'completed') return task;
|
||||||
|
if (task.status === 'cancelled') {
|
||||||
|
throw new Error(`InferenceTask cancelled: ${task.id}`);
|
||||||
|
}
|
||||||
|
if (task.status === 'error') {
|
||||||
|
throw new Error(task.errorMessage ?? `InferenceTask failed: ${task.id}`);
|
||||||
|
}
|
||||||
|
// Theoretically unreachable — waitFor only resolves on terminal events.
|
||||||
|
throw new Error(`InferenceTask in unexpected state: ${String(task.status)}`);
|
||||||
|
}
|
||||||
@@ -29,6 +29,8 @@ import type { ILlmRepository } from '../repositories/llm.repository.js';
|
|||||||
import type { OpenAiChatRequest } from './llm/types.js';
|
import type { OpenAiChatRequest } from './llm/types.js';
|
||||||
import { NotFoundError } from './mcp-server.service.js';
|
import { NotFoundError } from './mcp-server.service.js';
|
||||||
import type { AgentService } from './agent.service.js';
|
import type { AgentService } from './agent.service.js';
|
||||||
|
import type { IInferenceTaskService, InferenceTaskChunk } from './inference-task.service.js';
|
||||||
|
import { effectivePoolName } from './llm.service.js';
|
||||||
|
|
||||||
/** A virtual provider's announcement at registration time. */
|
/** A virtual provider's announcement at registration time. */
|
||||||
export interface RegisterProviderInput {
|
export interface RegisterProviderInput {
|
||||||
@@ -81,29 +83,53 @@ export type VirtualTaskFrame =
|
|||||||
| { kind: 'wake'; taskId: string; llmName: string };
|
| { kind: 'wake'; taskId: string; llmName: string };
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pending inference task. The route handler awaits `done`; the result POST
|
* In-memory wake task. Wake is publisher-control work, not inference —
|
||||||
* resolves it via `completeTask()`. The error path rejects via `failTask()`.
|
* we don't persist wake tasks to the DB because their lifetime is
|
||||||
|
* milliseconds and missing a wake on restart just means the next infer
|
||||||
|
* fires a fresh wake. Inference tasks live in the durable queue (v5);
|
||||||
|
* see InferenceTaskService.
|
||||||
*/
|
*/
|
||||||
interface PendingTask {
|
interface InMemoryWakeTask {
|
||||||
taskId: string;
|
taskId: string;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
llmName: string;
|
llmName: string;
|
||||||
streaming: boolean;
|
resolve: (status: number) => void;
|
||||||
resolveNonStreaming: (body: unknown, status: number) => void;
|
reject: (err: Error) => void;
|
||||||
rejectNonStreaming: (err: Error) => void;
|
|
||||||
/** For streaming tasks only; null on non-streaming. */
|
|
||||||
pushChunk: ((chunk: { data: string; done?: boolean }) => void) | null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const HEARTBEAT_TIMEOUT_MS = 90_000;
|
const HEARTBEAT_TIMEOUT_MS = 90_000;
|
||||||
const INACTIVE_RETENTION_MS = 4 * 60 * 60 * 1000; // 4 h
|
const INACTIVE_RETENTION_MS = 4 * 60 * 60 * 1000; // 4 h
|
||||||
|
/**
|
||||||
|
* v5: how long enqueueInferTask waits for the worker's terminal POST
|
||||||
|
* before bailing. 10 minutes is generous for thinking models that
|
||||||
|
* spend 30-90s warming up. The TASK itself stays queued past this —
|
||||||
|
* only the in-flight HTTP handler gives up; an async-API consumer can
|
||||||
|
* still poll the row and pick up the result later.
|
||||||
|
*/
|
||||||
|
const INFER_AWAIT_TIMEOUT_MS = 10 * 60 * 1000;
|
||||||
|
|
||||||
|
export interface EnqueueInferOptions {
|
||||||
|
/**
|
||||||
|
* v5: if true, throw 503 immediately when no live SSE session exists
|
||||||
|
* for the row's session at enqueue time, instead of queueing the
|
||||||
|
* task and waiting for a worker to show up. Used by:
|
||||||
|
* - the direct `/api/v1/llms/<name>/infer` route (caller asked for
|
||||||
|
* THIS specific Llm; no pool fanout at this layer)
|
||||||
|
* - chat.service's pool failover loop (it iterates candidates and
|
||||||
|
* needs each candidate's failure to surface fast)
|
||||||
|
* The async `POST /api/v1/inference-tasks` API leaves this false so
|
||||||
|
* callers explicitly opting into the durable queue get the queue.
|
||||||
|
* Default: false (durable, queues + waits up to INFER_AWAIT_TIMEOUT_MS).
|
||||||
|
*/
|
||||||
|
failFast?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
export interface IVirtualLlmService {
|
export interface IVirtualLlmService {
|
||||||
register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult>;
|
register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult>;
|
||||||
heartbeat(providerSessionId: string): Promise<void>;
|
heartbeat(providerSessionId: string): Promise<void>;
|
||||||
bindSession(providerSessionId: string, handle: VirtualSessionHandle): void;
|
bindSession(providerSessionId: string, handle: VirtualSessionHandle): void;
|
||||||
unbindSession(providerSessionId: string): Promise<void>;
|
unbindSession(providerSessionId: string): Promise<void>;
|
||||||
enqueueInferTask(llmName: string, request: OpenAiChatRequest, streaming: boolean): Promise<PendingTaskRef>;
|
enqueueInferTask(llmName: string, request: OpenAiChatRequest, streaming: boolean, options?: EnqueueInferOptions): Promise<PendingTaskRef>;
|
||||||
completeTask(taskId: string, result: { status: number; body: unknown }): boolean;
|
completeTask(taskId: string, result: { status: number; body: unknown }): boolean;
|
||||||
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean;
|
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean;
|
||||||
failTask(taskId: string, error: Error): boolean;
|
failTask(taskId: string, error: Error): boolean;
|
||||||
@@ -121,7 +147,12 @@ export interface PendingTaskRef {
|
|||||||
|
|
||||||
export class VirtualLlmService implements IVirtualLlmService {
|
export class VirtualLlmService implements IVirtualLlmService {
|
||||||
private readonly sessions = new Map<string, VirtualSessionHandle>();
|
private readonly sessions = new Map<string, VirtualSessionHandle>();
|
||||||
private readonly tasksById = new Map<string, PendingTask>();
|
/**
|
||||||
|
* Wake tasks live here, in-memory. The result POST resolves them by
|
||||||
|
* id. Inference tasks have moved to the durable queue (v5); see
|
||||||
|
* IInferenceTaskService.
|
||||||
|
*/
|
||||||
|
private readonly wakeTasks = new Map<string, InMemoryWakeTask>();
|
||||||
/**
|
/**
|
||||||
* Dedupe concurrent wake requests for the same Llm. The first request
|
* Dedupe concurrent wake requests for the same Llm. The first request
|
||||||
* starts the wake; later requests for the same name await the same
|
* starts the wake; later requests for the same name await the same
|
||||||
@@ -138,6 +169,19 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
* before deleting the Llm itself (Agent.llmId is Restrict).
|
* before deleting the Llm itself (Agent.llmId is Restrict).
|
||||||
*/
|
*/
|
||||||
private readonly agents?: AgentService,
|
private readonly agents?: AgentService,
|
||||||
|
/**
|
||||||
|
* v5: durable inference task queue. Optional so older test wirings
|
||||||
|
* (and the non-virtual chat path) compile without it; when absent,
|
||||||
|
* enqueueInferTask falls back to a clear "task queue not wired"
|
||||||
|
* error rather than silently regressing to in-memory.
|
||||||
|
*/
|
||||||
|
private readonly tasks?: IInferenceTaskService,
|
||||||
|
/**
|
||||||
|
* v5: caller's user id, threaded into newly-created task rows for
|
||||||
|
* RBAC + observability. Optional — older wirings that don't have a
|
||||||
|
* request context attribute tasks to 'system'.
|
||||||
|
*/
|
||||||
|
private readonly resolveOwner: () => string = () => 'system',
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
async register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult> {
|
async register(input: { providerSessionId?: string | null; providers: RegisterProviderInput[] }): Promise<RegisterResult> {
|
||||||
@@ -227,6 +271,12 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
// Replace any prior handle for this session — keeps "last writer wins"
|
// Replace any prior handle for this session — keeps "last writer wins"
|
||||||
// simple. The old SSE will have been closed by the publisher anyway.
|
// simple. The old SSE will have been closed by the publisher anyway.
|
||||||
this.sessions.set(providerSessionId, handle);
|
this.sessions.set(providerSessionId, handle);
|
||||||
|
// v5: drain queued inference tasks targeting any pool this session
|
||||||
|
// owns. Fire-and-forget: the bind() route handler completes the SSE
|
||||||
|
// handshake first; tasks land on the channel right after.
|
||||||
|
if (this.tasks !== undefined) {
|
||||||
|
void this.drainPendingForSession(providerSessionId, handle);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async unbindSession(providerSessionId: string): Promise<void> {
|
async unbindSession(providerSessionId: string): Promise<void> {
|
||||||
@@ -243,11 +293,60 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
if (this.agents !== undefined) {
|
if (this.agents !== undefined) {
|
||||||
await this.agents.markVirtualAgentsInactiveBySession(providerSessionId);
|
await this.agents.markVirtualAgentsInactiveBySession(providerSessionId);
|
||||||
}
|
}
|
||||||
// Reject any in-flight tasks for this session — the relay can't deliver
|
// v5: revert claimed/running inference tasks back to pending so
|
||||||
// a result POST anymore.
|
// another worker on the same pool can pick them up. If no other
|
||||||
for (const t of this.tasksById.values()) {
|
// worker is up, they stay queued for whenever one shows up.
|
||||||
|
if (this.tasks !== undefined) {
|
||||||
|
await this.tasks.revertHeldBy(providerSessionId);
|
||||||
|
}
|
||||||
|
// Reject any in-flight wake tasks for this session — those don't
|
||||||
|
// benefit from re-queue since the publisher itself went away.
|
||||||
|
for (const t of this.wakeTasks.values()) {
|
||||||
if (t.sessionId === providerSessionId) {
|
if (t.sessionId === providerSessionId) {
|
||||||
this.failTask(t.taskId, new Error('publisher disconnected'));
|
this.wakeTasks.delete(t.taskId);
|
||||||
|
t.reject(new Error('publisher disconnected'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* v5: when a worker binds its SSE channel, find every pending
|
||||||
|
* InferenceTask targeting a pool key this session owns, claim each,
|
||||||
|
* and push the frame down the channel. If two workers in the same
|
||||||
|
* pool race here, the repo's CAS on tryClaim ensures each task is
|
||||||
|
* pushed to exactly one of them.
|
||||||
|
*/
|
||||||
|
private async drainPendingForSession(
|
||||||
|
providerSessionId: string,
|
||||||
|
handle: VirtualSessionHandle,
|
||||||
|
): Promise<void> {
|
||||||
|
if (this.tasks === undefined) return;
|
||||||
|
const owned = await this.repo.findBySessionId(providerSessionId);
|
||||||
|
if (owned.length === 0) return;
|
||||||
|
const poolKeys = Array.from(new Set(owned.map((row) => effectivePoolName(row))));
|
||||||
|
const pending = await this.tasks.findPendingForPools(poolKeys);
|
||||||
|
for (const task of pending) {
|
||||||
|
// Cap drain per bind to avoid pushing a giant backlog into a
|
||||||
|
// brand-new SSE channel all at once. Hardcoded for now; a
|
||||||
|
// per-session capacity hint is a v6 concern.
|
||||||
|
if (!handle.alive) break;
|
||||||
|
const claimed = await this.tasks.tryClaim(task.id, providerSessionId);
|
||||||
|
if (claimed === null) continue; // raced; another worker got it
|
||||||
|
try {
|
||||||
|
handle.pushTask({
|
||||||
|
kind: 'infer',
|
||||||
|
taskId: claimed.id,
|
||||||
|
llmName: claimed.llmName,
|
||||||
|
request: claimed.requestBody as unknown as OpenAiChatRequest,
|
||||||
|
streaming: claimed.streaming,
|
||||||
|
});
|
||||||
|
} catch (err) {
|
||||||
|
// SSE write failed mid-drain — revert the claim so a
|
||||||
|
// healthier worker can pick it up.
|
||||||
|
await this.tasks.revertHeldBy(providerSessionId);
|
||||||
|
// eslint-disable-next-line no-console
|
||||||
|
console.warn(`drainPendingForSession: pushTask failed for ${claimed.id}: ${(err as Error).message}`);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -256,7 +355,11 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
llmName: string,
|
llmName: string,
|
||||||
request: OpenAiChatRequest,
|
request: OpenAiChatRequest,
|
||||||
streaming: boolean,
|
streaming: boolean,
|
||||||
|
options: EnqueueInferOptions = {},
|
||||||
): Promise<PendingTaskRef> {
|
): Promise<PendingTaskRef> {
|
||||||
|
if (this.tasks === undefined) {
|
||||||
|
throw new Error('InferenceTaskService not wired into VirtualLlmService');
|
||||||
|
}
|
||||||
const llm = await this.repo.findByName(llmName);
|
const llm = await this.repo.findByName(llmName);
|
||||||
if (llm === null) throw new NotFoundError(`Llm not found: ${llmName}`);
|
if (llm === null) throw new NotFoundError(`Llm not found: ${llmName}`);
|
||||||
if (llm.kind !== 'virtual' || llm.providerSessionId === null) {
|
if (llm.kind !== 'virtual' || llm.providerSessionId === null) {
|
||||||
@@ -265,67 +368,133 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
{ statusCode: 500 },
|
{ statusCode: 500 },
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (llm.status === 'inactive') {
|
|
||||||
throw Object.assign(
|
// failFast callers (chat.service pool failover, direct infer route)
|
||||||
new Error(`Virtual Llm '${llmName}' is inactive; publisher offline`),
|
// get the v1-v4 semantic: row inactive OR no live session = 503,
|
||||||
{ statusCode: 503 },
|
// immediately. The chat dispatcher then iterates the next pool
|
||||||
);
|
// candidate. Without failFast, both cases queue durably and a
|
||||||
}
|
// future bindSession drains.
|
||||||
const handle = this.sessions.get(llm.providerSessionId);
|
if (options.failFast === true) {
|
||||||
if (handle === undefined || !handle.alive) {
|
if (llm.status === 'inactive') {
|
||||||
throw Object.assign(
|
throw Object.assign(
|
||||||
new Error(`Virtual Llm '${llmName}' has no live SSE session; publisher offline`),
|
new Error(`Virtual Llm '${llmName}' is inactive; publisher offline`),
|
||||||
{ statusCode: 503 },
|
{ statusCode: 503 },
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
const handle = this.sessions.get(llm.providerSessionId);
|
||||||
|
if (handle === undefined || !handle.alive) {
|
||||||
|
throw Object.assign(
|
||||||
|
new Error(`Virtual Llm '${llmName}' has no live SSE session; publisher offline`),
|
||||||
|
{ statusCode: 503 },
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Wake-on-demand (v2) ──
|
// ── Wake-on-demand (v2) ──
|
||||||
// Status=hibernating means the publisher told us at register time
|
// Status=hibernating means the publisher told us at register time
|
||||||
// (or via a later status update) that the backend is asleep. Fire a
|
// that the backend is asleep. Fire a wake task and wait for the
|
||||||
// wake task and wait for the publisher to confirm readiness before
|
// publisher to confirm readiness before persisting the inference
|
||||||
// dispatching the actual inference. Concurrent infers for the same
|
// task. Concurrent infers for the same Llm share a single wake
|
||||||
// Llm share a single wake promise.
|
// promise. We hold the wake INSIDE this method (not after enqueue)
|
||||||
|
// so we don't queue an inference task that we then can't dispatch.
|
||||||
if (llm.status === 'hibernating') {
|
if (llm.status === 'hibernating') {
|
||||||
|
const handle = this.sessions.get(llm.providerSessionId);
|
||||||
|
if (handle === undefined || !handle.alive) {
|
||||||
|
throw Object.assign(
|
||||||
|
new Error(`Virtual Llm '${llmName}' has no live SSE session; cannot wake`),
|
||||||
|
{ statusCode: 503 },
|
||||||
|
);
|
||||||
|
}
|
||||||
await this.ensureAwake(llm.id, llm.name, llm.providerSessionId, handle);
|
await this.ensureAwake(llm.id, llm.name, llm.providerSessionId, handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
const taskId = randomUUID();
|
// ── v5: persist the task BEFORE attempting dispatch ──
|
||||||
const chunkSubscribers = new Set<(chunk: { data: string; done?: boolean }) => void>();
|
// Even if no worker is up, the row stays pending and a future
|
||||||
|
// bindSession will drain it. Caller's HTTP timeout still bounds
|
||||||
let resolveDone!: (v: { status: number; body: unknown }) => void;
|
// the wait, but the *task* survives.
|
||||||
let rejectDone!: (err: Error) => void;
|
const created = await this.tasks.enqueue({
|
||||||
const done = new Promise<{ status: number; body: unknown }>((resolve, reject) => {
|
poolName: effectivePoolName(llm),
|
||||||
resolveDone = resolve;
|
llmName,
|
||||||
rejectDone = reject;
|
model: llm.model,
|
||||||
|
tier: llm.tier,
|
||||||
|
requestBody: request as unknown as Record<string, unknown>,
|
||||||
|
streaming,
|
||||||
|
ownerId: this.resolveOwner(),
|
||||||
});
|
});
|
||||||
|
|
||||||
const pending: PendingTask = {
|
// Try to claim + dispatch immediately if a session is up. If not,
|
||||||
taskId,
|
// the row stays pending for drainPendingForSession to pick up.
|
||||||
sessionId: llm.providerSessionId,
|
const handle = this.sessions.get(llm.providerSessionId);
|
||||||
llmName,
|
if (handle !== undefined && handle.alive) {
|
||||||
streaming,
|
const claimed = await this.tasks.tryClaim(created.id, llm.providerSessionId);
|
||||||
resolveNonStreaming: (body, status) => resolveDone({ status, body }),
|
if (claimed !== null) {
|
||||||
rejectNonStreaming: rejectDone,
|
handle.pushTask({
|
||||||
pushChunk: streaming
|
kind: 'infer',
|
||||||
? (chunk): void => { for (const cb of chunkSubscribers) cb(chunk); }
|
taskId: claimed.id,
|
||||||
: null,
|
llmName,
|
||||||
|
request,
|
||||||
|
streaming,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// tryClaim can return null only if another concurrent enqueue
|
||||||
|
// (or the GC sweep) already moved the row off pending — leave
|
||||||
|
// it; drainPendingForSession will reconcile.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the task service's waitFor into the legacy PendingTaskRef
|
||||||
|
// shape so existing callers (chat.service, llm-infer route)
|
||||||
|
// don't need to change. The chunk callback bridges
|
||||||
|
// InferenceTaskService.events into the on-the-fly subscriber set.
|
||||||
|
const tasks = this.tasks;
|
||||||
|
const taskId = created.id;
|
||||||
|
const subscribers = new Set<(chunk: InferenceTaskChunk) => void>();
|
||||||
|
// Single bridge listener on the service's emitter; fans out to all
|
||||||
|
// subscribers locally so we don't pile up listeners on a hot key.
|
||||||
|
let bridgeUnsub: (() => void) | null = null;
|
||||||
|
const ensureBridge = (): void => {
|
||||||
|
if (bridgeUnsub !== null) return;
|
||||||
|
const onChunk = (chunk: InferenceTaskChunk): void => {
|
||||||
|
for (const cb of subscribers) cb(chunk);
|
||||||
|
};
|
||||||
|
// Subscribe via the public pushChunk path: the service emits on
|
||||||
|
// its own EventEmitter; we attach a listener here. Use the
|
||||||
|
// service's events through a side channel — exposed below as
|
||||||
|
// subscribeChunks for clarity.
|
||||||
|
bridgeUnsub = tasks.subscribeChunksUnsafe(taskId, onChunk);
|
||||||
};
|
};
|
||||||
this.tasksById.set(taskId, pending);
|
|
||||||
|
|
||||||
handle.pushTask({
|
const done = (async (): Promise<{ status: number; body: unknown }> => {
|
||||||
kind: 'infer',
|
// waitFor's `done` rejects on cancel/error/timeout. For non-
|
||||||
taskId,
|
// streaming tasks the responseBody IS the body; for streaming
|
||||||
llmName,
|
// the body is null and chunks have already been piped through.
|
||||||
request,
|
const waiter = tasks.waitFor(taskId, INFER_AWAIT_TIMEOUT_MS);
|
||||||
streaming,
|
// If streaming, we must drain the chunks generator concurrently
|
||||||
});
|
// so chunks aren't dropped just because no one's subscribed yet.
|
||||||
|
// The real consumer subscribes via onChunk(); their cb fires
|
||||||
|
// synchronously inside the bridge.
|
||||||
|
if (streaming) ensureBridge();
|
||||||
|
const finalRow = await waiter.done;
|
||||||
|
// Status code: legacy callers expect 200 on success; the worker
|
||||||
|
// sends its own status via the result POST and we forward it.
|
||||||
|
// Today, success POSTs come with status=200 and the row's
|
||||||
|
// responseBody is { status, body }. v5 stores the *body* directly
|
||||||
|
// on the row; we synthesize a 200 here to match the legacy shape.
|
||||||
|
return { status: 200, body: finalRow.responseBody };
|
||||||
|
})();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
taskId,
|
taskId,
|
||||||
done,
|
done,
|
||||||
onChunk(cb): () => void {
|
onChunk(cb): () => void {
|
||||||
chunkSubscribers.add(cb);
|
subscribers.add(cb);
|
||||||
return () => chunkSubscribers.delete(cb);
|
ensureBridge();
|
||||||
|
return (): void => {
|
||||||
|
subscribers.delete(cb);
|
||||||
|
if (subscribers.size === 0 && bridgeUnsub !== null) {
|
||||||
|
bridgeUnsub();
|
||||||
|
bridgeUnsub = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -370,22 +539,17 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
rejectDone = reject;
|
rejectDone = reject;
|
||||||
});
|
});
|
||||||
|
|
||||||
const pending: PendingTask = {
|
const wake: InMemoryWakeTask = {
|
||||||
taskId,
|
taskId,
|
||||||
sessionId,
|
sessionId,
|
||||||
llmName,
|
llmName,
|
||||||
streaming: false,
|
resolve: (status) => {
|
||||||
// Wake tasks return { ok: true } on success or never resolve at
|
|
||||||
// all if the publisher dies; the rejectNonStreaming path covers
|
|
||||||
// the disconnect-mid-wake case via unbindSession.
|
|
||||||
resolveNonStreaming: (_body, status) => {
|
|
||||||
if (status >= 200 && status < 300) resolveDone();
|
if (status >= 200 && status < 300) resolveDone();
|
||||||
else rejectDone(new Error(`wake task returned status ${String(status)}`));
|
else rejectDone(new Error(`wake task returned status ${String(status)}`));
|
||||||
},
|
},
|
||||||
rejectNonStreaming: rejectDone,
|
reject: rejectDone,
|
||||||
pushChunk: null,
|
|
||||||
};
|
};
|
||||||
this.tasksById.set(taskId, pending);
|
this.wakeTasks.set(taskId, wake);
|
||||||
|
|
||||||
handle.pushTask({ kind: 'wake', taskId, llmName });
|
handle.pushTask({ kind: 'wake', taskId, llmName });
|
||||||
|
|
||||||
@@ -402,32 +566,55 @@ export class VirtualLlmService implements IVirtualLlmService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
completeTask(taskId: string, result: { status: number; body: unknown }): boolean {
|
completeTask(taskId: string, result: { status: number; body: unknown }): boolean {
|
||||||
const t = this.tasksById.get(taskId);
|
// Wake tasks: in-memory map. Resolve and bail.
|
||||||
if (t === undefined) return false;
|
const wake = this.wakeTasks.get(taskId);
|
||||||
this.tasksById.delete(taskId);
|
if (wake !== undefined) {
|
||||||
t.resolveNonStreaming(result.body, result.status);
|
this.wakeTasks.delete(taskId);
|
||||||
return true;
|
wake.resolve(result.status);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Inference tasks: durable queue. Persist body + flip terminal +
|
||||||
|
// emit wakeup. Fire-and-forget the DB write; the result POST
|
||||||
|
// returns 200 either way (the caller's HTTP handler is what cares
|
||||||
|
// about the eventual emit, not the worker).
|
||||||
|
if (this.tasks !== undefined) {
|
||||||
|
void this.tasks.complete(taskId, result.body as Record<string, unknown> | null);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean {
|
pushTaskChunk(taskId: string, chunk: { data: string; done?: boolean }): boolean {
|
||||||
const t = this.tasksById.get(taskId);
|
// Wake tasks never receive chunks — they're non-streaming control
|
||||||
if (t === undefined || t.pushChunk === null) return false;
|
// messages. Don't even check the wake map; if the id isn't an
|
||||||
t.pushChunk(chunk);
|
// inference task, just drop the chunk.
|
||||||
|
if (this.tasks === undefined) return false;
|
||||||
|
// First chunk for a claimed task → flip claimed → running. Idempotent
|
||||||
|
// and fire-and-forget; if the row is already running, the CAS in the
|
||||||
|
// repo just no-ops.
|
||||||
|
void this.tasks.markRunning(taskId);
|
||||||
|
this.tasks.pushChunk(taskId, chunk);
|
||||||
if (chunk.done === true) {
|
if (chunk.done === true) {
|
||||||
// For streaming tasks, also resolve the `done` promise so the route
|
// Streaming completion: persist with null body + flip terminal so
|
||||||
// handler can clean up.
|
// the waiter unblocks. The actual content was already streamed
|
||||||
t.resolveNonStreaming(null, 200);
|
// through the chunks channel; nothing to store.
|
||||||
this.tasksById.delete(taskId);
|
void this.tasks.complete(taskId, null);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
failTask(taskId: string, error: Error): boolean {
|
failTask(taskId: string, error: Error): boolean {
|
||||||
const t = this.tasksById.get(taskId);
|
const wake = this.wakeTasks.get(taskId);
|
||||||
if (t === undefined) return false;
|
if (wake !== undefined) {
|
||||||
this.tasksById.delete(taskId);
|
this.wakeTasks.delete(taskId);
|
||||||
t.rejectNonStreaming(error);
|
wake.reject(error);
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
if (this.tasks !== undefined) {
|
||||||
|
void this.tasks.fail(taskId, error.message);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
async gcSweep(now: Date = new Date()): Promise<{ markedInactive: number; deleted: number }> {
|
async gcSweep(now: Date = new Date()): Promise<{ markedInactive: number; deleted: number }> {
|
||||||
|
|||||||
@@ -193,6 +193,10 @@ describe('ChatService — kind=virtual branch (v3 Stage 1)', () => {
|
|||||||
'vllm-local',
|
'vllm-local',
|
||||||
expect.objectContaining({ messages: expect.any(Array) }),
|
expect.objectContaining({ messages: expect.any(Array) }),
|
||||||
false,
|
false,
|
||||||
|
// v5: chat.service passes failFast:true so its pool failover loop
|
||||||
|
// surfaces transport errors quickly instead of waiting on the
|
||||||
|
// durable queue's 10-min timeout.
|
||||||
|
{ failFast: true },
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -224,6 +228,7 @@ describe('ChatService — kind=virtual branch (v3 Stage 1)', () => {
|
|||||||
'vllm-local',
|
'vllm-local',
|
||||||
expect.objectContaining({ messages: expect.any(Array), stream: true }),
|
expect.objectContaining({ messages: expect.any(Array), stream: true }),
|
||||||
true,
|
true,
|
||||||
|
{ failFast: true },
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
330
src/mcpd/tests/inference-task-service.test.ts
Normal file
330
src/mcpd/tests/inference-task-service.test.ts
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
/**
|
||||||
|
* v5 InferenceTaskService — state machine, signal channels, and GC sweep.
|
||||||
|
* The repo is mocked with an in-memory Map so we exercise the service
|
||||||
|
* logic deterministically without touching Postgres. Schema-level tests
|
||||||
|
* live in src/db/tests/inference-task-schema.test.ts.
|
||||||
|
*/
|
||||||
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
|
import type { InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||||
|
import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
|
||||||
|
import { InferenceTaskService } from '../src/services/inference-task.service.js';
|
||||||
|
|
||||||
|
function makeRow(overrides: Partial<InferenceTask> = {}): InferenceTask {
|
||||||
|
return {
|
||||||
|
id: overrides.id ?? `task-${Math.random().toString(36).slice(2, 8)}`,
|
||||||
|
status: 'pending',
|
||||||
|
poolName: 'pool-a',
|
||||||
|
llmName: 'pool-a',
|
||||||
|
model: 'qwen3-thinking',
|
||||||
|
tier: null,
|
||||||
|
claimedBy: null,
|
||||||
|
requestBody: { messages: [{ role: 'user', content: 'hi' }] } as unknown as InferenceTask['requestBody'],
|
||||||
|
responseBody: null,
|
||||||
|
errorMessage: null,
|
||||||
|
streaming: false,
|
||||||
|
createdAt: new Date(),
|
||||||
|
claimedAt: null,
|
||||||
|
streamStartedAt: null,
|
||||||
|
completedAt: null,
|
||||||
|
ownerId: 'owner-1',
|
||||||
|
agentId: null,
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function mockRepo(initial: InferenceTask[] = []): IInferenceTaskRepository {
|
||||||
|
const rows = new Map<string, InferenceTask>(initial.map((r) => [r.id, { ...r }]));
|
||||||
|
// Tiny CAS helper that mirrors the real `updateMany({where:{status:in}})`
|
||||||
|
// semantics — only flips the row if the current status matches.
|
||||||
|
const cas = (id: string, allowed: InferenceTaskStatus[], patch: Partial<InferenceTask>): InferenceTask | null => {
|
||||||
|
const row = rows.get(id);
|
||||||
|
if (row === undefined) return null;
|
||||||
|
if (!allowed.includes(row.status)) return null;
|
||||||
|
const next = { ...row, ...patch };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
};
|
||||||
|
return {
|
||||||
|
create: vi.fn(async (data) => {
|
||||||
|
const row = makeRow({
|
||||||
|
id: `task-${rows.size + 1}`,
|
||||||
|
poolName: data.poolName,
|
||||||
|
llmName: data.llmName,
|
||||||
|
model: data.model,
|
||||||
|
tier: data.tier ?? null,
|
||||||
|
requestBody: data.requestBody as InferenceTask['requestBody'],
|
||||||
|
streaming: data.streaming,
|
||||||
|
ownerId: data.ownerId,
|
||||||
|
agentId: data.agentId ?? null,
|
||||||
|
});
|
||||||
|
rows.set(row.id, row);
|
||||||
|
return row;
|
||||||
|
}),
|
||||||
|
findById: vi.fn(async (id) => rows.get(id) ?? null),
|
||||||
|
findPendingForPools: vi.fn(async (poolNames, limit) => {
|
||||||
|
const out = [...rows.values()]
|
||||||
|
.filter((r) => r.status === 'pending' && poolNames.includes(r.poolName))
|
||||||
|
.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
|
||||||
|
return limit !== undefined ? out.slice(0, limit) : out;
|
||||||
|
}),
|
||||||
|
findHeldBy: vi.fn(async (claimedBy) =>
|
||||||
|
[...rows.values()].filter((r) => r.claimedBy === claimedBy && (r.status === 'claimed' || r.status === 'running')),
|
||||||
|
),
|
||||||
|
list: vi.fn(async (filter) => {
|
||||||
|
let out = [...rows.values()];
|
||||||
|
if (filter.ownerId !== undefined) out = out.filter((r) => r.ownerId === filter.ownerId);
|
||||||
|
if (filter.poolName !== undefined) out = out.filter((r) => r.poolName === filter.poolName);
|
||||||
|
if (filter.agentId !== undefined) out = out.filter((r) => r.agentId === filter.agentId);
|
||||||
|
if (filter.status !== undefined) {
|
||||||
|
const statuses = Array.isArray(filter.status) ? filter.status : [filter.status];
|
||||||
|
out = out.filter((r) => statuses.includes(r.status));
|
||||||
|
}
|
||||||
|
out.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime());
|
||||||
|
return filter.limit !== undefined ? out.slice(0, filter.limit) : out;
|
||||||
|
}),
|
||||||
|
tryClaim: vi.fn(async (id, claimedBy, claimedAt) =>
|
||||||
|
cas(id, ['pending'], { status: 'claimed', claimedBy, claimedAt }),
|
||||||
|
),
|
||||||
|
markRunning: vi.fn(async (id, at) =>
|
||||||
|
cas(id, ['claimed', 'running'], { status: 'running', streamStartedAt: at }),
|
||||||
|
),
|
||||||
|
markCompleted: vi.fn(async (id, body, at) =>
|
||||||
|
cas(id, ['pending', 'claimed', 'running'], {
|
||||||
|
status: 'completed',
|
||||||
|
responseBody: (body ?? null) as InferenceTask['responseBody'],
|
||||||
|
completedAt: at,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
markError: vi.fn(async (id, errorMessage, at) =>
|
||||||
|
cas(id, ['pending', 'claimed', 'running'], { status: 'error', errorMessage, completedAt: at }),
|
||||||
|
),
|
||||||
|
markCancelled: vi.fn(async (id, at) =>
|
||||||
|
cas(id, ['pending', 'claimed', 'running'], { status: 'cancelled', completedAt: at }),
|
||||||
|
),
|
||||||
|
revertToPending: vi.fn(async (id) =>
|
||||||
|
cas(id, ['claimed', 'running'], { status: 'pending', claimedBy: null, claimedAt: null, streamStartedAt: null }),
|
||||||
|
),
|
||||||
|
findStalePending: vi.fn(async (cutoff) =>
|
||||||
|
[...rows.values()].filter((r) => r.status === 'pending' && r.createdAt.getTime() < cutoff.getTime()),
|
||||||
|
),
|
||||||
|
findExpiredTerminal: vi.fn(async (cutoff) =>
|
||||||
|
[...rows.values()].filter((r) =>
|
||||||
|
(r.status === 'completed' || r.status === 'error' || r.status === 'cancelled')
|
||||||
|
&& r.completedAt !== null
|
||||||
|
&& r.completedAt.getTime() < cutoff.getTime(),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
deleteMany: vi.fn(async (ids) => {
|
||||||
|
let n = 0;
|
||||||
|
for (const id of ids) if (rows.delete(id)) n += 1;
|
||||||
|
return n;
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('InferenceTaskService — state machine', () => {
|
||||||
|
it('enqueue creates a pending row with the given pool/llm/model', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({
|
||||||
|
poolName: 'pool-a',
|
||||||
|
llmName: 'a-1',
|
||||||
|
model: 'qwen3',
|
||||||
|
requestBody: { messages: [] },
|
||||||
|
streaming: false,
|
||||||
|
ownerId: 'owner-1',
|
||||||
|
});
|
||||||
|
expect(t.status).toBe('pending');
|
||||||
|
expect(t.poolName).toBe('pool-a');
|
||||||
|
expect(t.claimedBy).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('tryClaim races: only one of two concurrent claimers gets the row', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({
|
||||||
|
poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o',
|
||||||
|
});
|
||||||
|
// Both workers issue tryClaim at the "same time". The repo's CAS
|
||||||
|
// serializes them — first claim wins, second sees a non-pending
|
||||||
|
// row and returns null.
|
||||||
|
const [a, b] = await Promise.all([svc.tryClaim(t.id, 'sess-A'), svc.tryClaim(t.id, 'sess-B')]);
|
||||||
|
const winners = [a, b].filter((r) => r !== null);
|
||||||
|
const losers = [a, b].filter((r) => r === null);
|
||||||
|
expect(winners).toHaveLength(1);
|
||||||
|
expect(losers).toHaveLength(1);
|
||||||
|
expect(winners[0]!.status).toBe('claimed');
|
||||||
|
expect(['sess-A', 'sess-B']).toContain(winners[0]!.claimedBy);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('complete after claim transitions claimed → completed and stores responseBody', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
await svc.tryClaim(t.id, 'sess-A');
|
||||||
|
const done = await svc.complete(t.id, { choices: [{ message: { content: 'hi' } }] });
|
||||||
|
expect(done?.status).toBe('completed');
|
||||||
|
expect(done?.responseBody).toEqual({ choices: [{ message: { content: 'hi' } }] });
|
||||||
|
expect(done?.completedAt).not.toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('refuses double-complete (idempotent terminal)', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const first = await svc.complete(t.id, { ok: 1 });
|
||||||
|
expect(first?.status).toBe('completed');
|
||||||
|
// Second worker tries to complete the same task — CAS rejects because
|
||||||
|
// the row is no longer in a non-terminal state.
|
||||||
|
const second = await svc.complete(t.id, { ok: 2 });
|
||||||
|
expect(second).toBeNull();
|
||||||
|
// First completion's body is preserved.
|
||||||
|
const reread = await svc.findById(t.id);
|
||||||
|
expect(reread?.responseBody).toEqual({ ok: 1 });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('revertHeldBy reverts every claimed/running row owned by a session and leaves terminals alone', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t1 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const t2 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const t3 = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
await svc.tryClaim(t1.id, 'sess-A');
|
||||||
|
await svc.tryClaim(t2.id, 'sess-A');
|
||||||
|
await svc.markRunning(t2.id);
|
||||||
|
await svc.tryClaim(t3.id, 'sess-A');
|
||||||
|
await svc.complete(t3.id, { ok: 1 }); // t3 finished before disconnect
|
||||||
|
|
||||||
|
const reverted = await svc.revertHeldBy('sess-A');
|
||||||
|
expect(reverted.map((r) => r.id).sort()).toEqual([t1.id, t2.id].sort());
|
||||||
|
expect((await svc.findById(t1.id))?.status).toBe('pending');
|
||||||
|
expect((await svc.findById(t2.id))?.status).toBe('pending');
|
||||||
|
// t3 stayed completed — terminal rows are not reverted on disconnect.
|
||||||
|
expect((await svc.findById(t3.id))?.status).toBe('completed');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('cancel from pending records cancelled and emits terminal event', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const cancelled = await svc.cancel(t.id);
|
||||||
|
expect(cancelled?.status).toBe('cancelled');
|
||||||
|
// A subsequent complete must fail (cancelled is terminal).
|
||||||
|
const result = await svc.complete(t.id, { ok: 1 });
|
||||||
|
expect(result).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('InferenceTaskService — waitFor signals', () => {
|
||||||
|
it('resolves immediately when the row is already terminal at subscribe time', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
await svc.complete(t.id, { ok: 1 });
|
||||||
|
const waiter = svc.waitFor(t.id, 1_000);
|
||||||
|
const final = await waiter.done;
|
||||||
|
expect(final.status).toBe('completed');
|
||||||
|
expect(final.responseBody).toEqual({ ok: 1 });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('wakes a waiter on complete event without polling the DB', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const waiter = svc.waitFor(t.id, 5_000);
|
||||||
|
// Fire the complete after a microtask so the waiter is definitely
|
||||||
|
// already subscribed to the terminal event.
|
||||||
|
setTimeout(() => { void svc.complete(t.id, { ok: 1 }); }, 10);
|
||||||
|
const final = await waiter.done;
|
||||||
|
expect(final.status).toBe('completed');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('wakes the chunks generator on pushChunk and ends on terminal', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: true, ownerId: 'o' });
|
||||||
|
const waiter = svc.waitFor(t.id, 5_000);
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
svc.pushChunk(t.id, { data: 'hello ' });
|
||||||
|
svc.pushChunk(t.id, { data: 'world' });
|
||||||
|
void svc.complete(t.id, { ok: 1 });
|
||||||
|
}, 10);
|
||||||
|
|
||||||
|
const seen: string[] = [];
|
||||||
|
for await (const c of waiter.chunks) {
|
||||||
|
seen.push(c.data);
|
||||||
|
}
|
||||||
|
expect(seen).toEqual(['hello ', 'world']);
|
||||||
|
const final = await waiter.done;
|
||||||
|
expect(final.status).toBe('completed');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws on cancellation with a clear error message', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const waiter = svc.waitFor(t.id, 5_000);
|
||||||
|
setTimeout(() => { void svc.cancel(t.id); }, 10);
|
||||||
|
await expect(waiter.done).rejects.toThrow(/cancelled/i);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws on error and surfaces the worker\'s errorMessage', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const waiter = svc.waitFor(t.id, 5_000);
|
||||||
|
setTimeout(() => { void svc.fail(t.id, 'upstream 500'); }, 10);
|
||||||
|
await expect(waiter.done).rejects.toThrow(/upstream 500/);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('times out when no terminal event arrives within the deadline', async () => {
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo);
|
||||||
|
const t = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const waiter = svc.waitFor(t.id, 30);
|
||||||
|
await expect(waiter.done).rejects.toThrow(/timed out/);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('InferenceTaskService — gcSweep', () => {
|
||||||
|
it('flips stale pending rows to error AND deletes expired terminal rows', async () => {
|
||||||
|
const now = new Date('2026-04-28T00:00:00Z');
|
||||||
|
const fixedClock = (): Date => now;
|
||||||
|
const repo = mockRepo();
|
||||||
|
const svc = new InferenceTaskService(repo, fixedClock);
|
||||||
|
|
||||||
|
// 90 min old pending — past the 1h pendingTimeout cutoff → error.
|
||||||
|
const stale = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
// Backdate via direct fixture mutation — easier than wiring a clock through enqueue.
|
||||||
|
const staleRow = (await svc.findById(stale.id))!;
|
||||||
|
(staleRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 90 * 60 * 1000);
|
||||||
|
|
||||||
|
// 30 min old pending — within window, should not be touched.
|
||||||
|
const fresh = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
const freshRow = (await svc.findById(fresh.id))!;
|
||||||
|
(freshRow as { createdAt: Date }).createdAt = new Date(now.getTime() - 30 * 60 * 1000);
|
||||||
|
|
||||||
|
// 8 day-old completed — past the 7d retention → delete.
|
||||||
|
const old = await svc.enqueue({ poolName: 'p', llmName: 'l', model: 'm', requestBody: {}, streaming: false, ownerId: 'o' });
|
||||||
|
await svc.complete(old.id, { ok: 1 });
|
||||||
|
const oldRow = (await svc.findById(old.id))!;
|
||||||
|
(oldRow as { completedAt: Date }).completedAt = new Date(now.getTime() - 8 * 24 * 60 * 60 * 1000);
|
||||||
|
|
||||||
|
const result = await svc.gcSweep({
|
||||||
|
pendingTimeoutMs: 60 * 60 * 1000, // 1h
|
||||||
|
terminalRetentionMs: 7 * 24 * 60 * 60 * 1000, // 7d
|
||||||
|
});
|
||||||
|
expect(result.erroredOut).toBe(1);
|
||||||
|
expect(result.deleted).toBe(1);
|
||||||
|
|
||||||
|
// Stale pending was flipped to error.
|
||||||
|
const staleAfter = await svc.findById(stale.id);
|
||||||
|
expect(staleAfter?.status).toBe('error');
|
||||||
|
expect(staleAfter?.errorMessage).toMatch(/expired in pending/);
|
||||||
|
// Old completed is gone.
|
||||||
|
expect(await svc.findById(old.id)).toBeNull();
|
||||||
|
// Fresh pending untouched.
|
||||||
|
expect((await svc.findById(fresh.id))?.status).toBe('pending');
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -244,6 +244,9 @@ describe('POST /api/v1/llms/:name/infer', () => {
|
|||||||
'claude',
|
'claude',
|
||||||
expect.objectContaining({ messages: expect.any(Array) }),
|
expect.objectContaining({ messages: expect.any(Array) }),
|
||||||
false,
|
false,
|
||||||
|
// v5: direct infer route passes failFast:true so a downed publisher
|
||||||
|
// returns 503 immediately instead of queueing the task durably.
|
||||||
|
{ failFast: true },
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
import { VirtualLlmService, type VirtualSessionHandle } from '../src/services/virtual-llm.service.js';
|
import { VirtualLlmService, type VirtualSessionHandle } from '../src/services/virtual-llm.service.js';
|
||||||
|
import { InferenceTaskService } from '../src/services/inference-task.service.js';
|
||||||
|
import type { IInferenceTaskService } from '../src/services/inference-task.service.js';
|
||||||
|
import type { IInferenceTaskRepository } from '../src/repositories/inference-task.repository.js';
|
||||||
import type { ILlmRepository } from '../src/repositories/llm.repository.js';
|
import type { ILlmRepository } from '../src/repositories/llm.repository.js';
|
||||||
import type { Llm } from '@prisma/client';
|
import type { Llm, InferenceTask, InferenceTaskStatus } from '@prisma/client';
|
||||||
|
|
||||||
function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
||||||
return {
|
return {
|
||||||
@@ -15,6 +18,7 @@ function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
|||||||
apiKeySecretId: null,
|
apiKeySecretId: null,
|
||||||
apiKeySecretKey: null,
|
apiKeySecretKey: null,
|
||||||
extraConfig: {} as Llm['extraConfig'],
|
extraConfig: {} as Llm['extraConfig'],
|
||||||
|
poolName: null,
|
||||||
kind: 'virtual',
|
kind: 'virtual',
|
||||||
providerSessionId: 's-1',
|
providerSessionId: 's-1',
|
||||||
lastHeartbeatAt: new Date(),
|
lastHeartbeatAt: new Date(),
|
||||||
@@ -27,6 +31,105 @@ function makeLlm(overrides: Partial<Llm> = {}): Llm {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Drop-in mock of `InferenceTaskService` backed by a Map. We mirror just
|
||||||
|
* enough of the real service's signaling — events for terminal/chunk —
|
||||||
|
* so VirtualLlmService's enqueue/result flows behave the same way they
|
||||||
|
* do in production. Nothing here talks to Postgres.
|
||||||
|
*/
|
||||||
|
function mockTaskService(): IInferenceTaskService {
|
||||||
|
// Build a minimal repo for the real InferenceTaskService — that way we
|
||||||
|
// exercise the actual event-emitter wakeup logic, just without a DB.
|
||||||
|
const rows = new Map<string, InferenceTask>();
|
||||||
|
let n = 0;
|
||||||
|
const repo: IInferenceTaskRepository = {
|
||||||
|
create: vi.fn(async (data) => {
|
||||||
|
n += 1;
|
||||||
|
const row: InferenceTask = {
|
||||||
|
id: `task-${String(n)}`,
|
||||||
|
status: 'pending',
|
||||||
|
poolName: data.poolName,
|
||||||
|
llmName: data.llmName,
|
||||||
|
model: data.model,
|
||||||
|
tier: data.tier ?? null,
|
||||||
|
claimedBy: null,
|
||||||
|
requestBody: data.requestBody as InferenceTask['requestBody'],
|
||||||
|
responseBody: null,
|
||||||
|
errorMessage: null,
|
||||||
|
streaming: data.streaming,
|
||||||
|
createdAt: new Date(),
|
||||||
|
claimedAt: null,
|
||||||
|
streamStartedAt: null,
|
||||||
|
completedAt: null,
|
||||||
|
ownerId: data.ownerId,
|
||||||
|
agentId: data.agentId ?? null,
|
||||||
|
};
|
||||||
|
rows.set(row.id, row);
|
||||||
|
return row;
|
||||||
|
}),
|
||||||
|
findById: vi.fn(async (id) => rows.get(id) ?? null),
|
||||||
|
findPendingForPools: vi.fn(async (poolNames: string[]) =>
|
||||||
|
[...rows.values()].filter((r) => r.status === 'pending' && poolNames.includes(r.poolName)),
|
||||||
|
),
|
||||||
|
findHeldBy: vi.fn(async (claimedBy: string) =>
|
||||||
|
[...rows.values()].filter((r) =>
|
||||||
|
r.claimedBy === claimedBy
|
||||||
|
&& (r.status === 'claimed' || r.status === 'running')),
|
||||||
|
),
|
||||||
|
list: vi.fn(async () => [...rows.values()]),
|
||||||
|
tryClaim: vi.fn(async (id, claimedBy, claimedAt) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || r.status !== 'pending') return null;
|
||||||
|
const next = { ...r, status: 'claimed' as InferenceTaskStatus, claimedBy, claimedAt };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
markRunning: vi.fn(async (id, at) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || (r.status !== 'claimed' && r.status !== 'running')) return null;
|
||||||
|
const next = { ...r, status: 'running' as InferenceTaskStatus, streamStartedAt: at };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
markCompleted: vi.fn(async (id, body, at) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||||
|
const next = { ...r, status: 'completed' as InferenceTaskStatus, responseBody: (body ?? null) as InferenceTask['responseBody'], completedAt: at };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
markError: vi.fn(async (id, errorMessage, at) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||||
|
const next = { ...r, status: 'error' as InferenceTaskStatus, errorMessage, completedAt: at };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
markCancelled: vi.fn(async (id, at) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || r.status === 'completed' || r.status === 'error' || r.status === 'cancelled') return null;
|
||||||
|
const next = { ...r, status: 'cancelled' as InferenceTaskStatus, completedAt: at };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
revertToPending: vi.fn(async (id) => {
|
||||||
|
const r = rows.get(id);
|
||||||
|
if (r === undefined || (r.status !== 'claimed' && r.status !== 'running')) return null;
|
||||||
|
const next = { ...r, status: 'pending' as InferenceTaskStatus, claimedBy: null, claimedAt: null, streamStartedAt: null };
|
||||||
|
rows.set(id, next);
|
||||||
|
return next;
|
||||||
|
}),
|
||||||
|
findStalePending: vi.fn(async () => []),
|
||||||
|
findExpiredTerminal: vi.fn(async () => []),
|
||||||
|
deleteMany: vi.fn(async (ids) => {
|
||||||
|
let c = 0;
|
||||||
|
for (const id of ids) if (rows.delete(id)) c += 1;
|
||||||
|
return c;
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
return new InferenceTaskService(repo);
|
||||||
|
}
|
||||||
|
|
||||||
function mockRepo(initial: Llm[] = []): ILlmRepository {
|
function mockRepo(initial: Llm[] = []): ILlmRepository {
|
||||||
const rows = new Map<string, Llm>(initial.map((l) => [l.id, l]));
|
const rows = new Map<string, Llm>(initial.map((l) => [l.id, l]));
|
||||||
let counter = rows.size;
|
let counter = rows.size;
|
||||||
@@ -38,6 +141,14 @@ function mockRepo(initial: Llm[] = []): ILlmRepository {
|
|||||||
return null;
|
return null;
|
||||||
}),
|
}),
|
||||||
findByTier: vi.fn(async () => []),
|
findByTier: vi.fn(async () => []),
|
||||||
|
findByPoolName: vi.fn(async (poolName: string) => {
|
||||||
|
const out: Llm[] = [];
|
||||||
|
for (const l of rows.values()) {
|
||||||
|
if (l.poolName === poolName) out.push(l);
|
||||||
|
else if (l.poolName === null && l.name === poolName) out.push(l);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}),
|
||||||
findBySessionId: vi.fn(async (sid: string) =>
|
findBySessionId: vi.fn(async (sid: string) =>
|
||||||
[...rows.values()].filter((l) => l.providerSessionId === sid)),
|
[...rows.values()].filter((l) => l.providerSessionId === sid)),
|
||||||
findStaleVirtuals: vi.fn(async (cutoff: Date) =>
|
findStaleVirtuals: vi.fn(async (cutoff: Date) =>
|
||||||
@@ -105,7 +216,7 @@ function fakeSession(): VirtualSessionHandle & { tasks: Array<unknown>; alive: b
|
|||||||
describe('VirtualLlmService', () => {
|
describe('VirtualLlmService', () => {
|
||||||
it('register inserts new virtual rows with active status + sessionId', async () => {
|
it('register inserts new virtual rows with active status + sessionId', async () => {
|
||||||
const repo = mockRepo();
|
const repo = mockRepo();
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const { providerSessionId, llms } = await svc.register({
|
const { providerSessionId, llms } = await svc.register({
|
||||||
providerSessionId: null,
|
providerSessionId: null,
|
||||||
providers: [
|
providers: [
|
||||||
@@ -122,7 +233,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('register reuses the same row on sticky reconnect (same name + sessionId)', async () => {
|
it('register reuses the same row on sticky reconnect (same name + sessionId)', async () => {
|
||||||
const repo = mockRepo();
|
const repo = mockRepo();
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const first = await svc.register({
|
const first = await svc.register({
|
||||||
providerSessionId: 'fixed-id',
|
providerSessionId: 'fixed-id',
|
||||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||||
@@ -140,7 +251,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('register refuses to overwrite a public LLM with the same name', async () => {
|
it('register refuses to overwrite a public LLM with the same name', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
await expect(svc.register({
|
await expect(svc.register({
|
||||||
providerSessionId: 'sess-x',
|
providerSessionId: 'sess-x',
|
||||||
providers: [{ name: 'qwen3-thinking', type: 'openai', model: 'm' }],
|
providers: [{ name: 'qwen3-thinking', type: 'openai', model: 'm' }],
|
||||||
@@ -149,7 +260,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('register refuses if another active session owns the name', async () => {
|
it('register refuses if another active session owns the name', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'other', status: 'active' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'other', status: 'active' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
await expect(svc.register({
|
await expect(svc.register({
|
||||||
providerSessionId: 'mine',
|
providerSessionId: 'mine',
|
||||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||||
@@ -161,7 +272,7 @@ describe('VirtualLlmService', () => {
|
|||||||
name: 'vllm-local', providerSessionId: 'old-session',
|
name: 'vllm-local', providerSessionId: 'old-session',
|
||||||
status: 'inactive', inactiveSince: new Date(),
|
status: 'inactive', inactiveSince: new Date(),
|
||||||
})]);
|
})]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const { llms } = await svc.register({
|
const { llms } = await svc.register({
|
||||||
providerSessionId: 'new-session',
|
providerSessionId: 'new-session',
|
||||||
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
providers: [{ name: 'vllm-local', type: 'openai', model: 'm' }],
|
||||||
@@ -177,7 +288,7 @@ describe('VirtualLlmService', () => {
|
|||||||
name: 'vllm-local', providerSessionId: 'sess', status: 'inactive',
|
name: 'vllm-local', providerSessionId: 'sess', status: 'inactive',
|
||||||
lastHeartbeatAt: past, inactiveSince: past,
|
lastHeartbeatAt: past, inactiveSince: past,
|
||||||
})]);
|
})]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
await svc.heartbeat('sess');
|
await svc.heartbeat('sess');
|
||||||
const row = await repo.findByName('vllm-local');
|
const row = await repo.findByName('vllm-local');
|
||||||
expect(row?.status).toBe('active');
|
expect(row?.status).toBe('active');
|
||||||
@@ -191,7 +302,7 @@ describe('VirtualLlmService', () => {
|
|||||||
makeLlm({ name: 'b', providerSessionId: 'sess' }),
|
makeLlm({ name: 'b', providerSessionId: 'sess' }),
|
||||||
makeLlm({ name: 'c', providerSessionId: 'other' }),
|
makeLlm({ name: 'c', providerSessionId: 'other' }),
|
||||||
]);
|
]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
await svc.unbindSession('sess');
|
await svc.unbindSession('sess');
|
||||||
expect((await repo.findByName('a'))?.status).toBe('inactive');
|
expect((await repo.findByName('a'))?.status).toBe('inactive');
|
||||||
@@ -201,7 +312,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('enqueueInferTask pushes a task frame to the SSE session', async () => {
|
it('enqueueInferTask pushes a task frame to the SSE session', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const session = fakeSession();
|
const session = fakeSession();
|
||||||
svc.bindSession('sess', session);
|
svc.bindSession('sess', session);
|
||||||
|
|
||||||
@@ -218,26 +329,41 @@ describe('VirtualLlmService', () => {
|
|||||||
expect(t.streaming).toBe(false);
|
expect(t.streaming).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('enqueueInferTask rejects when the publisher is offline (no session bound)', async () => {
|
it('enqueueInferTask queues the task when no session is bound (durable, drains on bind)', async () => {
|
||||||
|
// v5 semantic change: with a durable queue underneath, "no worker
|
||||||
|
// up" no longer rejects — the row stays pending and a future
|
||||||
|
// bindSession drains it. Caller's HTTP handler awaits on ref.done
|
||||||
|
// and bounds itself with INFER_AWAIT_TIMEOUT_MS; from the service's
|
||||||
|
// POV the enqueue itself succeeds.
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const tasks = mockTaskService();
|
||||||
await expect(
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false),
|
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||||
).rejects.toThrow(/no live SSE session|publisher offline/);
|
expect(ref.taskId).toMatch(/^task-/);
|
||||||
|
// The row exists and is still pending — no claim happened.
|
||||||
|
const row = await tasks.findById(ref.taskId);
|
||||||
|
expect(row?.status).toBe('pending');
|
||||||
|
expect(row?.claimedBy).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('enqueueInferTask rejects when the row is inactive', async () => {
|
it('enqueueInferTask still queues against an inactive row (pool may have a sibling worker)', async () => {
|
||||||
|
// v5 semantic change: status=inactive on a specific Llm doesn't
|
||||||
|
// mean the pool is dead — another mcplocal publishing the same
|
||||||
|
// poolName might be active. The dispatcher's bindSession drain
|
||||||
|
// matches by poolName, so even a "dead" pin queues correctly.
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const tasks = mockTaskService();
|
||||||
svc.bindSession('sess', fakeSession());
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
await expect(
|
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||||
svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false),
|
const row = await tasks.findById(ref.taskId);
|
||||||
).rejects.toThrow(/inactive|publisher offline/);
|
expect(row?.status).toBe('pending');
|
||||||
|
// No frame pushed because no session is bound.
|
||||||
|
expect(row?.claimedBy).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('enqueueInferTask rejects when the LLM is public (not virtual)', async () => {
|
it('enqueueInferTask rejects when the LLM is public (not virtual)', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
const repo = mockRepo([makeLlm({ name: 'qwen3-thinking', kind: 'public', providerSessionId: null })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
await expect(
|
await expect(
|
||||||
svc.enqueueInferTask('qwen3-thinking', { model: 'm', messages: [] }, false),
|
svc.enqueueInferTask('qwen3-thinking', { model: 'm', messages: [] }, false),
|
||||||
).rejects.toThrow(/not a virtual provider/);
|
).rejects.toThrow(/not a virtual provider/);
|
||||||
@@ -245,7 +371,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('completeTask resolves the pending non-streaming promise', async () => {
|
it('completeTask resolves the pending non-streaming promise', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
const ref = await svc.enqueueInferTask(
|
const ref = await svc.enqueueInferTask(
|
||||||
'vllm-local',
|
'vllm-local',
|
||||||
@@ -258,7 +384,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('streaming: pushTaskChunk fans chunks to subscribers; done resolves the ref', async () => {
|
it('streaming: pushTaskChunk fans chunks to subscribers; done resolves the ref', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
const ref = await svc.enqueueInferTask(
|
const ref = await svc.enqueueInferTask(
|
||||||
'vllm-local',
|
'vllm-local',
|
||||||
@@ -278,7 +404,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('failTask rejects the pending promise with a clear error', async () => {
|
it('failTask rejects the pending promise with a clear error', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
const ref = await svc.enqueueInferTask(
|
const ref = await svc.enqueueInferTask(
|
||||||
'vllm-local',
|
'vllm-local',
|
||||||
@@ -289,17 +415,34 @@ describe('VirtualLlmService', () => {
|
|||||||
await expect(ref.done).rejects.toThrow(/upstream blew up/);
|
await expect(ref.done).rejects.toThrow(/upstream blew up/);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('unbindSession rejects in-flight tasks for that session', async () => {
|
it('unbindSession reverts claimed inference tasks to pending (durable re-queue, not reject)', async () => {
|
||||||
|
// v5 semantic change: a worker disconnecting mid-task no longer
|
||||||
|
// *rejects* the task. The row goes back to pending so another
|
||||||
|
// worker on the same pool can pick it up. The original caller's
|
||||||
|
// ref.done keeps waiting up to its 10-min INFER_AWAIT_TIMEOUT_MS;
|
||||||
|
// the same caller is what gets the result whichever worker
|
||||||
|
// ultimately delivers it.
|
||||||
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const tasks = mockTaskService();
|
||||||
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
const ref = await svc.enqueueInferTask(
|
const ref = await svc.enqueueInferTask(
|
||||||
'vllm-local',
|
'vllm-local',
|
||||||
{ model: 'm', messages: [{ role: 'user', content: 'hi' }] },
|
{ model: 'm', messages: [{ role: 'user', content: 'hi' }] },
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
|
// After enqueue with a session up, the task is claimed.
|
||||||
|
let row = await tasks.findById(ref.taskId);
|
||||||
|
expect(row?.status).toBe('claimed');
|
||||||
|
expect(row?.claimedBy).toBe('sess');
|
||||||
|
|
||||||
await svc.unbindSession('sess');
|
await svc.unbindSession('sess');
|
||||||
await expect(ref.done).rejects.toThrow(/publisher disconnected/);
|
|
||||||
|
// After disconnect, claimed/running rows revert to pending — ready
|
||||||
|
// for the next worker to drain.
|
||||||
|
row = await tasks.findById(ref.taskId);
|
||||||
|
expect(row?.status).toBe('pending');
|
||||||
|
expect(row?.claimedBy).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('gcSweep flips heartbeat-stale active virtuals to inactive', async () => {
|
it('gcSweep flips heartbeat-stale active virtuals to inactive', async () => {
|
||||||
@@ -309,7 +452,7 @@ describe('VirtualLlmService', () => {
|
|||||||
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
||||||
makeLlm({ name: 'fresh', providerSessionId: 'b', status: 'active', lastHeartbeatAt: recent }),
|
makeLlm({ name: 'fresh', providerSessionId: 'b', status: 'active', lastHeartbeatAt: recent }),
|
||||||
]);
|
]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const result = await svc.gcSweep();
|
const result = await svc.gcSweep();
|
||||||
expect(result.markedInactive).toBe(1);
|
expect(result.markedInactive).toBe(1);
|
||||||
expect((await repo.findByName('stale'))?.status).toBe('inactive');
|
expect((await repo.findByName('stale'))?.status).toBe('inactive');
|
||||||
@@ -324,7 +467,7 @@ describe('VirtualLlmService', () => {
|
|||||||
makeLlm({ name: 'recent', providerSessionId: 'b', status: 'inactive', inactiveSince: fresh }),
|
makeLlm({ name: 'recent', providerSessionId: 'b', status: 'inactive', inactiveSince: fresh }),
|
||||||
makeLlm({ name: 'public-survivor', providerSessionId: null, kind: 'public' }),
|
makeLlm({ name: 'public-survivor', providerSessionId: null, kind: 'public' }),
|
||||||
]);
|
]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const result = await svc.gcSweep();
|
const result = await svc.gcSweep();
|
||||||
expect(result.deleted).toBe(1);
|
expect(result.deleted).toBe(1);
|
||||||
expect(await repo.findByName('old')).toBeNull();
|
expect(await repo.findByName('old')).toBeNull();
|
||||||
@@ -336,7 +479,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('hibernating: dispatches a wake task first and waits for it to complete before infer', async () => {
|
it('hibernating: dispatches a wake task first and waits for it to complete before infer', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const session = fakeSession();
|
const session = fakeSession();
|
||||||
svc.bindSession('sess', session);
|
svc.bindSession('sess', session);
|
||||||
|
|
||||||
@@ -370,7 +513,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('hibernating: concurrent infer requests share a single wake task', async () => {
|
it('hibernating: concurrent infer requests share a single wake task', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
const repo = mockRepo([makeLlm({ name: 'sleeping', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const session = fakeSession();
|
const session = fakeSession();
|
||||||
svc.bindSession('sess', session);
|
svc.bindSession('sess', session);
|
||||||
|
|
||||||
@@ -398,7 +541,7 @@ describe('VirtualLlmService', () => {
|
|||||||
|
|
||||||
it('hibernating: rejects when the wake task fails', async () => {
|
it('hibernating: rejects when the wake task fails', async () => {
|
||||||
const repo = mockRepo([makeLlm({ name: 'broken', providerSessionId: 'sess', status: 'hibernating' })]);
|
const repo = mockRepo([makeLlm({ name: 'broken', providerSessionId: 'sess', status: 'hibernating' })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
svc.bindSession('sess', fakeSession());
|
svc.bindSession('sess', fakeSession());
|
||||||
|
|
||||||
const inferPromise = svc.enqueueInferTask(
|
const inferPromise = svc.enqueueInferTask(
|
||||||
@@ -408,12 +551,12 @@ describe('VirtualLlmService', () => {
|
|||||||
);
|
);
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
await new Promise((r) => setTimeout(r, 0));
|
||||||
|
|
||||||
// Get the wake task id from the in-flight tasks map (its only entry).
|
// v5: wake tasks live in `wakeTasks` (in-memory). Inference tasks
|
||||||
// We test the failure path via failTask which is part of the public
|
// moved to the DB-backed queue but wake is publisher-control work
|
||||||
// surface used by the result-POST route handler.
|
// that doesn't need durability — we kept the in-memory map for it.
|
||||||
const taskIds: string[] = [];
|
const taskIds: string[] = [];
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
for (const id of (svc as any).tasksById.keys()) taskIds.push(id);
|
for (const id of (svc as any).wakeTasks.keys()) taskIds.push(id);
|
||||||
expect(taskIds).toHaveLength(1);
|
expect(taskIds).toHaveLength(1);
|
||||||
expect(svc.failTask(taskIds[0]!, new Error('wake recipe failed'))).toBe(true);
|
expect(svc.failTask(taskIds[0]!, new Error('wake recipe failed'))).toBe(true);
|
||||||
|
|
||||||
@@ -424,14 +567,28 @@ describe('VirtualLlmService', () => {
|
|||||||
expect(row?.status).toBe('hibernating');
|
expect(row?.status).toBe('hibernating');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('inactive: still rejects with 503 (publisher offline) — wake path only fires for hibernating', async () => {
|
it('inactive: queues without firing the wake path — wake only triggers on status=hibernating', async () => {
|
||||||
|
// Coverage for the v5 inactive-vs-hibernating distinction.
|
||||||
|
// hibernating = "publisher told us the backend is asleep, ask
|
||||||
|
// them to wake it"; inactive = "publisher itself is offline".
|
||||||
|
// For inactive rows, queueing is the right behavior (wait for a
|
||||||
|
// worker on the pool to come online and drain). The wake path
|
||||||
|
// must NOT fire — wake is opt-in via the publisher's register
|
||||||
|
// payload, not a generic "row is down" recovery.
|
||||||
|
//
|
||||||
|
// No session bind here: an "inactive" row in production means
|
||||||
|
// unbindSession already flipped it after SSE close. Binding a
|
||||||
|
// session for the same providerSessionId would be a contradictory
|
||||||
|
// setup that the test wouldn't model anything real about.
|
||||||
const repo = mockRepo([makeLlm({ name: 'gone', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
const repo = mockRepo([makeLlm({ name: 'gone', providerSessionId: 'sess', status: 'inactive', inactiveSince: new Date() })]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const tasks = mockTaskService();
|
||||||
svc.bindSession('sess', fakeSession());
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
|
|
||||||
await expect(
|
const ref = await svc.enqueueInferTask('gone', { model: 'm', messages: [] }, false);
|
||||||
svc.enqueueInferTask('gone', { model: 'm', messages: [] }, false),
|
// Task queued in pending; no claim, no frame.
|
||||||
).rejects.toThrow(/inactive|publisher offline/);
|
const row = await tasks.findById(ref.taskId);
|
||||||
|
expect(row?.status).toBe('pending');
|
||||||
|
expect(row?.claimedBy).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('gcSweep is idempotent — running twice in a row is a no-op the second time', async () => {
|
it('gcSweep is idempotent — running twice in a row is a no-op the second time', async () => {
|
||||||
@@ -439,11 +596,75 @@ describe('VirtualLlmService', () => {
|
|||||||
const repo = mockRepo([
|
const repo = mockRepo([
|
||||||
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
makeLlm({ name: 'stale', providerSessionId: 'a', status: 'active', lastHeartbeatAt: long }),
|
||||||
]);
|
]);
|
||||||
const svc = new VirtualLlmService(repo);
|
const svc = new VirtualLlmService(repo, undefined, mockTaskService());
|
||||||
const first = await svc.gcSweep();
|
const first = await svc.gcSweep();
|
||||||
const second = await svc.gcSweep();
|
const second = await svc.gcSweep();
|
||||||
expect(first.markedInactive).toBe(1);
|
expect(first.markedInactive).toBe(1);
|
||||||
expect(second.markedInactive).toBe(0);
|
expect(second.markedInactive).toBe(0);
|
||||||
expect(second.deleted).toBe(0);
|
expect(second.deleted).toBe(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// ── v5: durable queue + drain-on-bind ──
|
||||||
|
|
||||||
|
it('bindSession drains pending inference tasks owned by the session\'s pool keys', async () => {
|
||||||
|
// Two enqueues land while no session is bound. Each row is created
|
||||||
|
// with status=pending; no SSE frame goes anywhere. When the worker
|
||||||
|
// finally binds, the drain loop claims both and pushes the frames.
|
||||||
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess', poolName: 'qwen-pool' })]);
|
||||||
|
const tasks = mockTaskService();
|
||||||
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
|
const ref1 = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [{ role: 'user', content: 'one' }] }, false);
|
||||||
|
const ref2 = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [{ role: 'user', content: 'two' }] }, false);
|
||||||
|
// Both rows are still pending — no worker bound yet.
|
||||||
|
expect((await tasks.findById(ref1.taskId))?.status).toBe('pending');
|
||||||
|
expect((await tasks.findById(ref2.taskId))?.status).toBe('pending');
|
||||||
|
|
||||||
|
// Worker shows up. Drain runs synchronously enough that we just
|
||||||
|
// need to flush the microtask queue before checking SSE frames.
|
||||||
|
const session = fakeSession();
|
||||||
|
svc.bindSession('sess', session);
|
||||||
|
// drainPendingForSession is fired with `void` so let microtasks
|
||||||
|
// settle before asserting.
|
||||||
|
await new Promise((r) => setTimeout(r, 0));
|
||||||
|
|
||||||
|
expect((session.tasks as Array<{ kind: string; taskId: string }>).map((t) => t.taskId).sort())
|
||||||
|
.toEqual([ref1.taskId, ref2.taskId].sort());
|
||||||
|
// Rows are now claimed by this session.
|
||||||
|
expect((await tasks.findById(ref1.taskId))?.status).toBe('claimed');
|
||||||
|
expect((await tasks.findById(ref1.taskId))?.claimedBy).toBe('sess');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('drain-on-bind matches the effective pool key, not just llm.name', async () => {
|
||||||
|
// The pinned Llm has name=vllm-alice but poolName=qwen-pool.
|
||||||
|
// Enqueue against vllm-alice → row.poolName=qwen-pool.
|
||||||
|
// Worker binds with a session that owns vllm-alice (same pool key).
|
||||||
|
// Drain must surface this row even though poolName != name.
|
||||||
|
const repo = mockRepo([makeLlm({ name: 'vllm-alice', providerSessionId: 'sess', poolName: 'qwen-pool' })]);
|
||||||
|
const tasks = mockTaskService();
|
||||||
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
|
const ref = await svc.enqueueInferTask('vllm-alice', { model: 'm', messages: [] }, false);
|
||||||
|
expect((await tasks.findById(ref.taskId))?.poolName).toBe('qwen-pool');
|
||||||
|
|
||||||
|
const session = fakeSession();
|
||||||
|
svc.bindSession('sess', session);
|
||||||
|
await new Promise((r) => setTimeout(r, 0));
|
||||||
|
expect((session.tasks as Array<{ taskId: string }>).map((t) => t.taskId)).toEqual([ref.taskId]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('completeTask via the result-route updates the DB row + emits the wakeup', async () => {
|
||||||
|
// End-to-end through the public surface: enqueue → claim happens
|
||||||
|
// because session is bound → worker POSTs result → completeTask
|
||||||
|
// routes to InferenceTaskService.complete → ref.done resolves.
|
||||||
|
const repo = mockRepo([makeLlm({ name: 'vllm-local', providerSessionId: 'sess' })]);
|
||||||
|
const tasks = mockTaskService();
|
||||||
|
const svc = new VirtualLlmService(repo, undefined, tasks);
|
||||||
|
svc.bindSession('sess', fakeSession());
|
||||||
|
const ref = await svc.enqueueInferTask('vllm-local', { model: 'm', messages: [] }, false);
|
||||||
|
|
||||||
|
expect(svc.completeTask(ref.taskId, { status: 200, body: { ok: true } })).toBe(true);
|
||||||
|
await expect(ref.done).resolves.toEqual({ status: 200, body: { ok: true } });
|
||||||
|
const row = await tasks.findById(ref.taskId);
|
||||||
|
expect(row?.status).toBe('completed');
|
||||||
|
expect(row?.responseBody).toEqual({ ok: true });
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user