From eda8e7971277ca263d15de946cca5243636d5659 Mon Sep 17 00:00:00 2001
From: Michal
Date: Sat, 25 Apr 2026 16:38:38 +0100
Subject: [PATCH] feat(agents): mcpd repos + Agent/Chat services with tool-use loop (Stage 2)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Layers the persistence-side logic on top of the Stage 1 schema. AgentService
mirrors LlmService's CRUD shape with name-resolved llm/project references
and yaml round-trip support; ChatService is the orchestrator that drives one
chat turn end-to-end: build the merged system block (agent.systemPrompt +
project Prompts ordered by priority desc + per-call systemAppend), persist
the user turn, run the adapter, dispatch any tool_calls through an injected
ChatToolDispatcher, persist tool turns linked back via toolCallId, and loop
until the model returns terminal text.

Per-call params resolve LiteLLM-style: request body → agent.defaultParams →
adapter default. The escape hatch `extra` is forwarded as-is so each adapter
can cherry-pick provider-specific knobs (Anthropic metadata, vLLM
repetition_penalty, etc.) without code changes here.

Persistence is non-transactional across the loop because tool calls can take
minutes, and long-held DB transactions would starve other writers. Instead,
each in-flight assistant turn is written `pending` and flipped to `complete`
only after its tool results land. On failure or max-iter overrun, every
`pending` row in the thread is flipped to `error` so the trail is auditable.

Tools are namespaced on the wire as `<server>__<tool>` and unmarshalled at
dispatch time; `tools_allowlist` filters the list before the model sees it.

Tests: agent-service.test.ts (7) — CRUD with name-resolved llm/project,
conflict on duplicate, llm switch, project detach, listByProject filtering,
upsertByName branch coverage. chat-service.test.ts (9) — plain text turn,
full text→tool→text loop with toolCallId linkage, max-iter cap leaves zero
pending, adapter-throws leaves zero pending, body→defaultParams merge,
`extra` passthrough, project-Prompt priority ordering in the system block,
tool-without-project rejection, tools_allowlist filtering. All 16 green;
full mcpd suite still 737/737.
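As a reviewer aid, a minimal standalone sketch of the two conventions above —
the params resolution order and the `<server>__<tool>` wire naming. It mirrors
the behaviour added in this patch but is illustrative only; the type alias and
function names below are placeholders, not exports from the diff.

    // Illustrative sketch, not code from this patch.
    type AgentChatParams = { temperature?: number; max_tokens?: number };

    const TOOL_NAME_SEPARATOR = '__';

    // Resolution order: per-call body overrides agent.defaultParams; anything
    // still undefined falls through to the adapter's own default.
    function mergeParams(defaults: AgentChatParams, perCall: AgentChatParams): AgentChatParams {
      return { ...defaults, ...perCall };
    }

    // `<server>__<tool>` → { serverName, toolName }; split on the first
    // separator so tool names may themselves contain '__'.
    function parseToolWireName(wireName: string): { serverName: string; toolName: string } {
      const sep = wireName.indexOf(TOOL_NAME_SEPARATOR);
      if (sep === -1) throw new Error(`tool name '${wireName}' missing '${TOOL_NAME_SEPARATOR}'`);
      return {
        serverName: wireName.slice(0, sep),
        toolName: wireName.slice(sep + TOOL_NAME_SEPARATOR.length),
      };
    }

    // mergeParams({ temperature: 0.5 }, { temperature: 0.9, max_tokens: 256 })
    //   → { temperature: 0.9, max_tokens: 256 }
    // parseToolWireName('grafana__query')
    //   → { serverName: 'grafana', toolName: 'query' }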
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/mcpd/src/repositories/agent.repository.ts | 102 ++++ src/mcpd/src/repositories/chat.repository.ts | 139 +++++ src/mcpd/src/services/agent.service.ts | 160 ++++++ src/mcpd/src/services/chat.service.ts | 532 ++++++++++++++++++ src/mcpd/src/validation/agent.schema.ts | 114 ++++ src/mcpd/tests/agent-service.test.ts | 192 +++++++ src/mcpd/tests/chat-service.test.ts | 413 ++++++++++++++ 7 files changed, 1652 insertions(+) create mode 100644 src/mcpd/src/repositories/agent.repository.ts create mode 100644 src/mcpd/src/repositories/chat.repository.ts create mode 100644 src/mcpd/src/services/agent.service.ts create mode 100644 src/mcpd/src/services/chat.service.ts create mode 100644 src/mcpd/src/validation/agent.schema.ts create mode 100644 src/mcpd/tests/agent-service.test.ts create mode 100644 src/mcpd/tests/chat-service.test.ts diff --git a/src/mcpd/src/repositories/agent.repository.ts b/src/mcpd/src/repositories/agent.repository.ts new file mode 100644 index 0000000..3e30fd0 --- /dev/null +++ b/src/mcpd/src/repositories/agent.repository.ts @@ -0,0 +1,102 @@ +import type { PrismaClient, Agent, Prisma } from '@prisma/client'; + +export interface CreateAgentRepoInput { + name: string; + description?: string; + systemPrompt?: string; + llmId: string; + projectId?: string | null; + proxyModelName?: string | null; + defaultParams?: Record; + extras?: Record; + ownerId: string; +} + +export interface UpdateAgentRepoInput { + description?: string; + systemPrompt?: string; + llmId?: string; + projectId?: string | null; + proxyModelName?: string | null; + defaultParams?: Record; + extras?: Record; +} + +export interface IAgentRepository { + findAll(): Promise; + findById(id: string): Promise; + findByName(name: string): Promise; + findByProjectId(projectId: string): Promise; + create(data: CreateAgentRepoInput): Promise; + update(id: string, data: UpdateAgentRepoInput): Promise; + delete(id: string): Promise; +} + +export class AgentRepository implements IAgentRepository { + constructor(private readonly prisma: PrismaClient) {} + + async findAll(): Promise { + return this.prisma.agent.findMany({ orderBy: { name: 'asc' } }); + } + + async findById(id: string): Promise { + return this.prisma.agent.findUnique({ where: { id } }); + } + + async findByName(name: string): Promise { + return this.prisma.agent.findUnique({ where: { name } }); + } + + async findByProjectId(projectId: string): Promise { + return this.prisma.agent.findMany({ + where: { projectId }, + orderBy: { name: 'asc' }, + }); + } + + async create(data: CreateAgentRepoInput): Promise { + return this.prisma.agent.create({ + data: { + name: data.name, + description: data.description ?? '', + systemPrompt: data.systemPrompt ?? '', + llmId: data.llmId, + projectId: data.projectId ?? null, + proxyModelName: data.proxyModelName ?? null, + defaultParams: (data.defaultParams ?? {}) as Prisma.InputJsonValue, + extras: (data.extras ?? {}) as Prisma.InputJsonValue, + ownerId: data.ownerId, + }, + }); + } + + async update(id: string, data: UpdateAgentRepoInput): Promise { + const updateData: Prisma.AgentUpdateInput = {}; + if (data.description !== undefined) updateData.description = data.description; + if (data.systemPrompt !== undefined) updateData.systemPrompt = data.systemPrompt; + if (data.llmId !== undefined) { + updateData.llm = { connect: { id: data.llmId } }; + } + if (data.projectId !== undefined) { + updateData.project = data.projectId === null + ? 
{ disconnect: true } + : { connect: { id: data.projectId } }; + } + if (data.proxyModelName !== undefined) { + updateData.proxyModelName = data.proxyModelName; + } + if (data.defaultParams !== undefined) { + updateData.defaultParams = data.defaultParams as Prisma.InputJsonValue; + } + if (data.extras !== undefined) { + updateData.extras = data.extras as Prisma.InputJsonValue; + } + // Bump optimistic version on every update. + updateData.version = { increment: 1 }; + return this.prisma.agent.update({ where: { id }, data: updateData }); + } + + async delete(id: string): Promise { + await this.prisma.agent.delete({ where: { id } }); + } +} diff --git a/src/mcpd/src/repositories/chat.repository.ts b/src/mcpd/src/repositories/chat.repository.ts new file mode 100644 index 0000000..8d004b6 --- /dev/null +++ b/src/mcpd/src/repositories/chat.repository.ts @@ -0,0 +1,139 @@ +/** + * Chat thread + message persistence. + * + * Each ChatThread holds an ordered, monotonic sequence of ChatMessages + * (turnIndex 0..N). The schema's `@@unique([threadId, turnIndex])` prevents + * concurrent appenders from clobbering each other; on collision the caller + * retries with a fresh `nextTurnIndex(threadId)`. + * + * `status` is `'pending' | 'complete' | 'error'`. Orchestrators flip a row + * from `pending` → `complete` once the turn settles. A crash mid-turn leaves + * a `pending` row that downstream views should render as truncated. + */ +import { Prisma } from '@prisma/client'; +import type { PrismaClient, ChatThread, ChatMessage } from '@prisma/client'; + +export type ChatRole = 'system' | 'user' | 'assistant' | 'tool'; +export type ChatStatus = 'pending' | 'complete' | 'error'; + +export interface AppendMessageInput { + threadId: string; + role: ChatRole; + content: string; + toolCalls?: Array>; + toolCallId?: string; + status?: ChatStatus; + /** Optional explicit turnIndex (caller-provided). If omitted the repo allocates the next one. */ + turnIndex?: number; +} + +export interface IChatRepository { + createThread(input: { agentId: string; ownerId: string; title?: string }): Promise; + findThread(id: string): Promise; + listThreadsByAgent(agentId: string, ownerId?: string): Promise; + appendMessage(input: AppendMessageInput): Promise; + listMessages(threadId: string): Promise; + updateStatus(messageId: string, status: ChatStatus): Promise; + markPendingAsError(threadId: string): Promise; + touchThread(threadId: string): Promise; + /** Compute MAX(turnIndex)+1 for a thread. 0 if the thread is empty. */ + nextTurnIndex(threadId: string): Promise; +} + +const RACE_RETRIES = 3; +/** Postgres unique-constraint violation code (Prisma surfaces it as P2002). */ +const UNIQUE_VIOLATION = 'P2002'; + +export class ChatRepository implements IChatRepository { + constructor(private readonly prisma: PrismaClient) {} + + async createThread(input: { agentId: string; ownerId: string; title?: string }): Promise { + return this.prisma.chatThread.create({ + data: { + agentId: input.agentId, + ownerId: input.ownerId, + title: input.title ?? '', + }, + }); + } + + async findThread(id: string): Promise { + return this.prisma.chatThread.findUnique({ where: { id } }); + } + + async listThreadsByAgent(agentId: string, ownerId?: string): Promise { + return this.prisma.chatThread.findMany({ + where: { agentId, ...(ownerId !== undefined ? 
{ ownerId } : {}) }, + orderBy: { lastTurnAt: 'desc' }, + }); + } + + async listMessages(threadId: string): Promise { + return this.prisma.chatMessage.findMany({ + where: { threadId }, + orderBy: { turnIndex: 'asc' }, + }); + } + + async nextTurnIndex(threadId: string): Promise { + const last = await this.prisma.chatMessage.findFirst({ + where: { threadId }, + orderBy: { turnIndex: 'desc' }, + select: { turnIndex: true }, + }); + return (last?.turnIndex ?? -1) + 1; + } + + async appendMessage(input: AppendMessageInput): Promise { + let attempt = 0; + // Retry on unique-violation: parallel appenders can both compute the same + // nextTurnIndex; the second insert fails P2002 and we recompute + retry. + while (true) { + const turnIndex = input.turnIndex ?? (await this.nextTurnIndex(input.threadId)); + try { + return await this.prisma.chatMessage.create({ + data: { + threadId: input.threadId, + turnIndex, + role: input.role, + content: input.content, + toolCalls: input.toolCalls === undefined + ? Prisma.JsonNull + : (input.toolCalls as Prisma.InputJsonValue), + toolCallId: input.toolCallId ?? null, + status: input.status ?? 'complete', + }, + }); + } catch (err) { + attempt += 1; + const code = (err as { code?: string }).code; + if (code === UNIQUE_VIOLATION && input.turnIndex === undefined && attempt <= RACE_RETRIES) { + continue; + } + throw err; + } + } + } + + async updateStatus(messageId: string, status: ChatStatus): Promise { + return this.prisma.chatMessage.update({ + where: { id: messageId }, + data: { status }, + }); + } + + async markPendingAsError(threadId: string): Promise { + const res = await this.prisma.chatMessage.updateMany({ + where: { threadId, status: 'pending' }, + data: { status: 'error' }, + }); + return res.count; + } + + async touchThread(threadId: string): Promise { + await this.prisma.chatThread.update({ + where: { id: threadId }, + data: { lastTurnAt: new Date() }, + }); + } +} diff --git a/src/mcpd/src/services/agent.service.ts b/src/mcpd/src/services/agent.service.ts new file mode 100644 index 0000000..b958367 --- /dev/null +++ b/src/mcpd/src/services/agent.service.ts @@ -0,0 +1,160 @@ +/** + * AgentService — CRUD over `Agent` rows. + * + * Agents pin to one Llm (FK Restrict, so an Llm in active use can't be + * deleted from under them) and optionally attach to a Project (FK SetNull — + * agents survive project deletion). The service translates name-based + * references (`{ llm: { name } }`, `{ project: { name } }`) to the FK ids + * on write, and back to names on read so callers always work with stable + * identifiers. + */ +import type { Agent } from '@prisma/client'; +import type { IAgentRepository } from '../repositories/agent.repository.js'; +import type { LlmService } from './llm.service.js'; +import type { ProjectService } from './project.service.js'; +import { + CreateAgentSchema, + UpdateAgentSchema, + type AgentChatParams, + type CreateAgentInput, +} from '../validation/agent.schema.js'; +import { NotFoundError, ConflictError } from './mcp-server.service.js'; + +/** Shape returned by the API layer — embeds llm + project metadata. 
*/ +export interface AgentView { + id: string; + name: string; + description: string; + systemPrompt: string; + llm: { id: string; name: string }; + project: { id: string; name: string } | null; + proxyModelName: string | null; + defaultParams: AgentChatParams; + extras: Record; + ownerId: string; + version: number; + createdAt: Date; + updatedAt: Date; +} + +export class AgentService { + constructor( + private readonly repo: IAgentRepository, + private readonly llms: LlmService, + private readonly projects: ProjectService, + ) {} + + async list(): Promise { + const rows = await this.repo.findAll(); + return Promise.all(rows.map((r) => this.toView(r))); + } + + async listByProject(projectName: string): Promise { + const project = await this.projects.resolveAndGet(projectName); + const rows = await this.repo.findByProjectId(project.id); + return Promise.all(rows.map((r) => this.toView(r))); + } + + async getById(id: string): Promise { + const row = await this.repo.findById(id); + if (row === null) throw new NotFoundError(`Agent not found: ${id}`); + return this.toView(row); + } + + async getByName(name: string): Promise { + const row = await this.repo.findByName(name); + if (row === null) throw new NotFoundError(`Agent not found: ${name}`); + return this.toView(row); + } + + async create(input: unknown, ownerId: string): Promise { + const data = CreateAgentSchema.parse(input); + const existing = await this.repo.findByName(data.name); + if (existing !== null) throw new ConflictError(`Agent already exists: ${data.name}`); + + const llm = 'name' in data.llm ? await this.llms.getByName(data.llm.name) : await this.llms.getById(data.llm.id); + const projectId = data.project !== undefined + ? (await this.projects.resolveAndGet(data.project.name)).id + : null; + + const row = await this.repo.create({ + name: data.name, + description: data.description, + systemPrompt: data.systemPrompt, + llmId: llm.id, + projectId, + proxyModelName: data.proxyModelName ?? null, + defaultParams: data.defaultParams as Record, + extras: data.extras, + ownerId, + }); + return this.toView(row); + } + + async update(id: string, input: unknown): Promise { + const data = UpdateAgentSchema.parse(input); + await this.getById(id); + + const updateFields: Parameters[1] = {}; + if (data.description !== undefined) updateFields.description = data.description; + if (data.systemPrompt !== undefined) updateFields.systemPrompt = data.systemPrompt; + if (data.llm !== undefined) { + const llm = 'name' in data.llm ? await this.llms.getByName(data.llm.name) : await this.llms.getById(data.llm.id); + updateFields.llmId = llm.id; + } + if (data.project !== undefined) { + updateFields.projectId = data.project === null + ? 
null + : (await this.projects.resolveAndGet(data.project.name)).id; + } + if (data.proxyModelName !== undefined) updateFields.proxyModelName = data.proxyModelName; + if (data.defaultParams !== undefined) updateFields.defaultParams = data.defaultParams as Record; + if (data.extras !== undefined) updateFields.extras = data.extras; + + const row = await this.repo.update(id, updateFields); + return this.toView(row); + } + + async delete(id: string): Promise { + await this.getById(id); + await this.repo.delete(id); + } + + // ── Backup/restore helpers ── + + async upsertByName(input: CreateAgentInput, ownerId: string): Promise { + const existing = await this.repo.findByName(input.name); + if (existing !== null) { + return this.update(existing.id, input); + } + return this.create(input, ownerId); + } + + async deleteByName(name: string): Promise { + const row = await this.repo.findByName(name); + if (row === null) return; + await this.delete(row.id); + } + + private async toView(row: Agent): Promise { + const llm = await this.llms.getById(row.llmId); + const project = row.projectId !== null + ? await this.projects.getById(row.projectId).catch(() => null) + : null; + return { + id: row.id, + name: row.name, + description: row.description, + systemPrompt: row.systemPrompt, + llm: { id: llm.id, name: llm.name }, + project: project !== null ? { id: project.id, name: project.name } : null, + proxyModelName: row.proxyModelName, + defaultParams: row.defaultParams as AgentChatParams, + extras: row.extras as Record, + ownerId: row.ownerId, + version: row.version, + createdAt: row.createdAt, + updatedAt: row.updatedAt, + }; + } +} diff --git a/src/mcpd/src/services/chat.service.ts b/src/mcpd/src/services/chat.service.ts new file mode 100644 index 0000000..80ba953 --- /dev/null +++ b/src/mcpd/src/services/chat.service.ts @@ -0,0 +1,532 @@ +/** + * ChatService — orchestrates an agent's chat turn end-to-end. + * + * For one inbound chat call: + * 1. Resolve the agent → its Llm and (optional) Project. + * 2. Build messages: merged system block (agent.systemPrompt + project + * Prompts joined by priority desc) + persisted thread history + new + * user turn. Persist the user turn (status:complete) up front. + * 3. Enumerate tools from the project's MCP servers via the injected + * ToolDispatcher and translate to OpenAI function-tool format. + * 4. Loop (cap = MAX_ITERATIONS) calling the adapter: + * - if the model returns text → persist as assistant (complete), end. + * - if it returns tool_calls → persist assistant turn (pending) with + * the tool_calls JSON; for each call, dispatch through the + * ToolDispatcher; persist a tool turn with the result; flip the + * assistant turn to complete; loop. + * 5. On any exception, mark all `pending` rows in the thread as `error` + * and surface the error to the caller. No big DB transaction wraps the + * loop because tool calls can take minutes. + * + * Per-call params merge resolution: request body → agent.defaultParams → + * adapter default. `extra` is forwarded as-is for provider-specific knobs. 
+ */ +import type { ChatMessage } from '@prisma/client'; +import type { AgentService } from './agent.service.js'; +import type { LlmService } from './llm.service.js'; +import type { LlmAdapterRegistry } from './llm/dispatcher.js'; +import type { + IChatRepository, + ChatRole, +} from '../repositories/chat.repository.js'; +import type { IPromptRepository } from '../repositories/prompt.repository.js'; +import type { OpenAiChatRequest, OpenAiMessage } from './llm/types.js'; +import type { AgentChatParams } from '../validation/agent.schema.js'; +import { NotFoundError } from './mcp-server.service.js'; + +export const TOOL_NAME_SEPARATOR = '__'; +export const MAX_ITERATIONS = 12; + +/** Project-scoped tool surface the chat loop calls into. Stub-friendly. */ +export interface ChatTool { + /** Wire format: `${TOOL_NAME_SEPARATOR}`. */ + name: string; + description: string; + parameters: Record; +} + +export interface ChatToolDispatcher { + /** List tools available to an agent's project. Empty if no project. */ + listTools(projectId: string | null): Promise; + /** Execute a tool call. Throws on error. */ + callTool(args: { + projectId: string; + serverName: string; + toolName: string; + args: Record; + }): Promise; +} + +export interface ChatStreamChunk { + type: 'text' | 'tool_call' | 'tool_result' | 'final' | 'error'; + delta?: string; + toolName?: string; + args?: Record; + ok?: boolean; + threadId?: string; + turnIndex?: number; + message?: string; +} + +export interface ChatRequestArgs { + agentName: string; + threadId?: string; + userMessage?: string; + /** Optional full-history override; if set, threadId history is ignored. */ + messagesOverride?: OpenAiMessage[]; + ownerId: string; + params?: AgentChatParams; +} + +export interface ChatResult { + threadId: string; + assistant: string; + turnIndex: number; +} + +export class ChatService { + constructor( + private readonly agents: AgentService, + private readonly llms: LlmService, + private readonly adapters: LlmAdapterRegistry, + private readonly chatRepo: IChatRepository, + private readonly promptRepo: IPromptRepository, + private readonly tools: ChatToolDispatcher, + ) {} + + async createThread(agentName: string, ownerId: string, title?: string): Promise<{ id: string }> { + const agent = await this.agents.getByName(agentName); + const thread = await this.chatRepo.createThread({ + agentId: agent.id, + ownerId, + ...(title !== undefined ? { title } : {}), + }); + return { id: thread.id }; + } + + async listThreads(agentName: string, ownerId?: string): Promise> { + const agent = await this.agents.getByName(agentName); + const rows = await this.chatRepo.listThreadsByAgent(agent.id, ownerId); + return rows.map((r) => ({ id: r.id, title: r.title, lastTurnAt: r.lastTurnAt, createdAt: r.createdAt })); + } + + async listMessages(threadId: string): Promise { + return this.chatRepo.listMessages(threadId); + } + + /** Non-streaming chat. Persists rows + returns the final assistant text. 
*/ + async chat(args: ChatRequestArgs): Promise { + const ctx = await this.prepareContext(args); + let assistantFinal = ''; + let lastTurnIndex = ctx.startingTurnIndex; + try { + for (let i = 0; i < MAX_ITERATIONS; i += 1) { + const adapter = this.adapters.get(ctx.llmType); + const result = await adapter.infer({ + body: this.buildBody(ctx), + modelOverride: ctx.modelOverride, + apiKey: ctx.apiKey, + url: ctx.url, + extraConfig: ctx.extraConfig, + }); + const choice = extractChoice(result.body); + if (choice === null) { + throw new Error(`Adapter returned no choice (status ${String(result.status)})`); + } + if (choice.tool_calls !== undefined && choice.tool_calls.length > 0) { + const assistantTurn = await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'assistant', + content: choice.content ?? '', + toolCalls: choice.tool_calls.map((c) => ({ + id: c.id, + name: c.function.name, + arguments: safeParseJson(c.function.arguments), + })), + status: 'pending', + }); + ctx.history.push({ + role: 'assistant', + content: choice.content ?? '', + tool_calls: choice.tool_calls, + }); + for (const call of choice.tool_calls) { + const toolResult = await this.dispatchTool(call.function.name, call.function.arguments, ctx.projectId); + const resultMsg = await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'tool', + content: typeof toolResult === 'string' ? toolResult : JSON.stringify(toolResult), + toolCallId: call.id, + }); + lastTurnIndex = resultMsg.turnIndex; + ctx.history.push({ + role: 'tool', + content: typeof toolResult === 'string' ? toolResult : JSON.stringify(toolResult), + tool_call_id: call.id, + }); + } + await this.chatRepo.updateStatus(assistantTurn.id, 'complete'); + continue; + } + // Terminal text turn. + const finalMsg = await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'assistant', + content: choice.content ?? '', + }); + assistantFinal = choice.content ?? ''; + lastTurnIndex = finalMsg.turnIndex; + await this.chatRepo.touchThread(ctx.threadId); + return { threadId: ctx.threadId, assistant: assistantFinal, turnIndex: lastTurnIndex }; + } + throw new Error(`Chat loop exceeded ${String(MAX_ITERATIONS)} iterations without a terminal turn`); + } catch (err) { + await this.chatRepo.markPendingAsError(ctx.threadId); + throw err; + } + } + + /** Streaming chat. Yields text deltas + tool events. Persists rows in lockstep. 
*/ + async *chatStream(args: ChatRequestArgs): AsyncGenerator { + const ctx = await this.prepareContext(args); + try { + for (let i = 0; i < MAX_ITERATIONS; i += 1) { + const adapter = this.adapters.get(ctx.llmType); + const accumulated: { content: string; toolCalls: Array<{ id: string; name: string; argumentsJson: string }> } = { + content: '', + toolCalls: [], + }; + let finishReason: string | null = null; + for await (const chunk of adapter.stream({ + body: { ...this.buildBody(ctx), stream: true }, + modelOverride: ctx.modelOverride, + apiKey: ctx.apiKey, + url: ctx.url, + extraConfig: ctx.extraConfig, + })) { + if (chunk.done === true) break; + if (chunk.data === '[DONE]') break; + const evt = parseStreamingChunk(chunk.data); + if (evt === null) continue; + if (evt.contentDelta !== undefined) { + accumulated.content += evt.contentDelta; + yield { type: 'text', delta: evt.contentDelta }; + } + if (evt.toolCallDeltas !== undefined) { + for (const td of evt.toolCallDeltas) { + const slot = (accumulated.toolCalls[td.index] ??= { id: '', name: '', argumentsJson: '' }); + if (td.id !== undefined) slot.id = td.id; + if (td.name !== undefined) slot.name = td.name; + if (td.argumentsDelta !== undefined) slot.argumentsJson += td.argumentsDelta; + } + } + if (evt.finishReason !== null && evt.finishReason !== undefined) { + finishReason = evt.finishReason; + } + } + + if (accumulated.toolCalls.length > 0 && finishReason === 'tool_calls') { + const assistantTurn = await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'assistant', + content: accumulated.content, + toolCalls: accumulated.toolCalls.map((c) => ({ + id: c.id, + name: c.name, + arguments: safeParseJson(c.argumentsJson), + })), + status: 'pending', + }); + ctx.history.push({ + role: 'assistant', + content: accumulated.content, + tool_calls: accumulated.toolCalls.map((c) => ({ + id: c.id, + type: 'function', + function: { name: c.name, arguments: c.argumentsJson }, + })), + }); + for (const call of accumulated.toolCalls) { + yield { type: 'tool_call', toolName: call.name, args: safeParseJson(call.argumentsJson) as Record }; + try { + const result = await this.dispatchTool(call.name, call.argumentsJson, ctx.projectId); + const resultStr = typeof result === 'string' ? 
result : JSON.stringify(result); + await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'tool', + content: resultStr, + toolCallId: call.id, + }); + ctx.history.push({ role: 'tool', content: resultStr, tool_call_id: call.id }); + yield { type: 'tool_result', toolName: call.name, ok: true }; + } catch (toolErr) { + const errMsg = (toolErr as Error).message; + await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'tool', + content: `error: ${errMsg}`, + toolCallId: call.id, + status: 'error', + }); + ctx.history.push({ role: 'tool', content: `error: ${errMsg}`, tool_call_id: call.id }); + yield { type: 'tool_result', toolName: call.name, ok: false }; + } + } + await this.chatRepo.updateStatus(assistantTurn.id, 'complete'); + continue; + } + + const finalMsg = await this.chatRepo.appendMessage({ + threadId: ctx.threadId, + role: 'assistant', + content: accumulated.content, + }); + await this.chatRepo.touchThread(ctx.threadId); + yield { type: 'final', threadId: ctx.threadId, turnIndex: finalMsg.turnIndex }; + return; + } + throw new Error(`Chat loop exceeded ${String(MAX_ITERATIONS)} iterations without a terminal turn`); + } catch (err) { + await this.chatRepo.markPendingAsError(ctx.threadId); + yield { type: 'error', message: (err as Error).message }; + } + } + + private async prepareContext(args: ChatRequestArgs): Promise<{ + threadId: string; + history: OpenAiMessage[]; + systemBlock: string; + llmName: string; + llmType: string; + modelOverride: string; + url: string; + apiKey: string; + extraConfig: Record; + mergedParams: AgentChatParams; + toolList: ChatTool[]; + projectId: string | null; + startingTurnIndex: number; + }> { + const agent = await this.agents.getByName(args.agentName); + const llm = await this.llms.getByName(agent.llm.name); + const apiKey = await this.llms.resolveApiKey(agent.llm.name).catch(() => ''); + + const threadId = await this.resolveThreadId(args, agent.id); + const projectId = agent.project?.id ?? null; + + const projectPrompts = projectId !== null + ? await this.promptRepo.findAll(projectId) + : []; + const sortedPrompts = [...projectPrompts] + .filter((p) => p.projectId === projectId) + .sort((a, b) => b.priority - a.priority); + + const mergedParams: AgentChatParams = { + ...(agent.defaultParams ?? {}), + ...(args.params ?? {}), + }; + + const baseSystem = mergedParams.systemOverride ?? agent.systemPrompt; + const systemBlock = [ + baseSystem, + ...sortedPrompts.map((p) => p.content), + mergedParams.systemAppend ?? '', + ] + .filter((s) => s.length > 0) + .join('\n\n'); + + const history = args.messagesOverride !== undefined + ? [...args.messagesOverride] + : await this.loadHistory(threadId); + + let startingTurnIndex = await this.chatRepo.nextTurnIndex(threadId); + if (args.userMessage !== undefined && args.messagesOverride === undefined) { + const userTurn = await this.chatRepo.appendMessage({ + threadId, + role: 'user', + content: args.userMessage, + }); + startingTurnIndex = userTurn.turnIndex; + history.push({ role: 'user', content: args.userMessage }); + } + + const toolList = await this.tools.listTools(projectId); + const allowed = mergedParams.tools_allowlist; + const filteredTools = allowed === undefined + ? 
toolList + : toolList.filter((t) => allowed.includes(t.name)); + + return { + threadId, + history, + systemBlock, + llmName: llm.name, + llmType: llm.type, + modelOverride: llm.model, + url: llm.url, + apiKey, + extraConfig: llm.extraConfig, + mergedParams, + toolList: filteredTools, + projectId, + startingTurnIndex, + }; + } + + private async resolveThreadId(args: ChatRequestArgs, agentId: string): Promise { + if (args.threadId !== undefined) { + const existing = await this.chatRepo.findThread(args.threadId); + if (existing === null) throw new NotFoundError(`Thread not found: ${args.threadId}`); + return existing.id; + } + const created = await this.chatRepo.createThread({ agentId, ownerId: args.ownerId }); + return created.id; + } + + private async loadHistory(threadId: string): Promise { + const rows = await this.chatRepo.listMessages(threadId); + return rows + .filter((r) => r.status !== 'error') + .map((r) => { + const msg: OpenAiMessage = { role: r.role as ChatRole, content: r.content }; + if (r.toolCallId !== null) msg.tool_call_id = r.toolCallId; + if (r.toolCalls !== null && Array.isArray(r.toolCalls)) { + const calls = r.toolCalls as Array<{ id: string; name: string; arguments: unknown }>; + msg.tool_calls = calls.map((c) => ({ + id: c.id, + type: 'function' as const, + function: { name: c.name, arguments: typeof c.arguments === 'string' ? c.arguments : JSON.stringify(c.arguments) }, + })); + } + return msg; + }); + } + + private buildBody(ctx: { + history: OpenAiMessage[]; + systemBlock: string; + modelOverride: string; + mergedParams: AgentChatParams; + toolList: ChatTool[]; + }): OpenAiChatRequest { + const messages: OpenAiMessage[] = []; + if (ctx.systemBlock.length > 0) { + messages.push({ role: 'system', content: ctx.systemBlock }); + } + messages.push(...ctx.history); + const body: OpenAiChatRequest = { + model: ctx.modelOverride, + messages, + }; + const p = ctx.mergedParams; + if (p.temperature !== undefined) body.temperature = p.temperature; + if (p.top_p !== undefined) body.top_p = p.top_p; + if (p.top_k !== undefined) (body as Record)['top_k'] = p.top_k; + if (p.max_tokens !== undefined) body.max_tokens = p.max_tokens; + if (p.stop !== undefined) body.stop = p.stop; + if (p.presence_penalty !== undefined) (body as Record)['presence_penalty'] = p.presence_penalty; + if (p.frequency_penalty !== undefined) (body as Record)['frequency_penalty'] = p.frequency_penalty; + if (p.seed !== undefined) (body as Record)['seed'] = p.seed; + if (p.response_format !== undefined) (body as Record)['response_format'] = p.response_format; + if (p.tool_choice !== undefined) body.tool_choice = p.tool_choice; + if (ctx.toolList.length > 0) { + body.tools = ctx.toolList.map((t) => ({ + type: 'function' as const, + function: { name: t.name, description: t.description, parameters: t.parameters }, + })); + } + if (p.extra !== undefined) { + for (const [k, v] of Object.entries(p.extra)) { + (body as Record)[k] = v; + } + } + return body; + } + + private async dispatchTool(toolWireName: string, argsJson: string, projectId: string | null): Promise { + if (projectId === null) { + throw new Error('Tool calls require an agent attached to a Project'); + } + const sep = toolWireName.indexOf(TOOL_NAME_SEPARATOR); + if (sep === -1) { + throw new Error(`Tool name '${toolWireName}' missing '${TOOL_NAME_SEPARATOR}' separator`); + } + const serverName = toolWireName.slice(0, sep); + const toolName = toolWireName.slice(sep + TOOL_NAME_SEPARATOR.length); + const parsed = safeParseJson(argsJson) as Record; + 
return this.tools.callTool({ projectId, serverName, toolName, args: parsed }); + } +} + +interface ExtractedChoice { + content: string | null; + tool_calls?: Array<{ id: string; type: 'function'; function: { name: string; arguments: string } }>; +} + +function extractChoice(body: unknown): ExtractedChoice | null { + if (typeof body !== 'object' || body === null) return null; + const choices = (body as { choices?: unknown }).choices; + if (!Array.isArray(choices) || choices.length === 0) return null; + const first = choices[0] as { message?: { content?: unknown; tool_calls?: unknown } } | undefined; + if (first?.message === undefined) return null; + const content = typeof first.message.content === 'string' ? first.message.content : null; + const toolCalls = first.message.tool_calls; + const out: ExtractedChoice = { content }; + if (Array.isArray(toolCalls)) { + out.tool_calls = toolCalls as NonNullable; + } + return out; +} + +function safeParseJson(s: string): unknown { + if (s === '') return {}; + try { + return JSON.parse(s); + } catch { + return {}; + } +} + +interface ParsedStreamEvent { + contentDelta?: string; + toolCallDeltas?: Array<{ index: number; id?: string; name?: string; argumentsDelta?: string }>; + finishReason?: string | null; +} + +function parseStreamingChunk(data: string): ParsedStreamEvent | null { + if (data === '' || data === '[DONE]') return null; + let json: unknown; + try { + json = JSON.parse(data); + } catch { + return null; + } + if (typeof json !== 'object' || json === null) return null; + const choices = (json as { choices?: unknown }).choices; + if (!Array.isArray(choices) || choices.length === 0) return null; + const c = choices[0] as { delta?: { content?: unknown; tool_calls?: unknown }; finish_reason?: unknown }; + const evt: ParsedStreamEvent = {}; + const delta = c.delta; + if (delta !== undefined) { + if (typeof delta.content === 'string' && delta.content.length > 0) { + evt.contentDelta = delta.content; + } + if (Array.isArray(delta.tool_calls)) { + evt.toolCallDeltas = (delta.tool_calls as Array<{ + index: number; + id?: string; + function?: { name?: string; arguments?: string }; + }>).map((t) => { + const td: { index: number; id?: string; name?: string; argumentsDelta?: string } = { index: t.index }; + if (t.id !== undefined) td.id = t.id; + if (t.function?.name !== undefined) td.name = t.function.name; + if (t.function?.arguments !== undefined) td.argumentsDelta = t.function.arguments; + return td; + }); + } + } + if (c.finish_reason !== undefined) { + evt.finishReason = (c.finish_reason as string | null); + } + return evt; +} diff --git a/src/mcpd/src/validation/agent.schema.ts b/src/mcpd/src/validation/agent.schema.ts new file mode 100644 index 0000000..651bd39 --- /dev/null +++ b/src/mcpd/src/validation/agent.schema.ts @@ -0,0 +1,114 @@ +/** + * Agent + Chat validation schemas. + * + * `AgentChatParamsSchema` is the LiteLLM-style passthrough used by both + * `agent.defaultParams` (stored on the agent row) and the per-call request + * body. Resolution order at chat time: request body → agent.defaultParams → + * adapter default. `extra` is the escape hatch for provider-specific knobs; + * adapters cherry-pick what they understand and ignore the rest. + */ +import { z } from 'zod'; + +/** OpenAI tool-choice schema, matching what we'll thread through to adapters. 
*/ +const ToolChoiceSchema = z.union([ + z.literal('auto'), + z.literal('none'), + z.literal('required'), + z.object({ + type: z.literal('function'), + function: z.object({ name: z.string().min(1) }), + }), +]); + +const ResponseFormatSchema = z + .object({ + type: z.enum(['text', 'json_object', 'json_schema']), + }) + .passthrough(); + +/** + * The LiteLLM-style chat parameter set. Every field is optional — both + * `defaultParams` (stored on Agent) and per-call overrides reuse this shape. + */ +export const AgentChatParamsSchema = z + .object({ + // Sampling + temperature: z.number().min(0).max(2).optional(), + top_p: z.number().min(0).max(1).optional(), + top_k: z.number().int().min(0).optional(), + max_tokens: z.number().int().positive().optional(), + stop: z.union([z.string(), z.array(z.string()).max(4)]).optional(), + presence_penalty: z.number().min(-2).max(2).optional(), + frequency_penalty: z.number().min(-2).max(2).optional(), + seed: z.number().int().optional(), + response_format: ResponseFormatSchema.optional(), + // Persona overrides + systemOverride: z.string().optional(), + systemAppend: z.string().optional(), + // Tools + tool_choice: ToolChoiceSchema.optional(), + tools_allowlist: z.array(z.string().min(1)).optional(), + // Provider escape hatch + extra: z.record(z.unknown()).optional(), + }) + .strict(); + +export type AgentChatParams = z.infer; + +/** Optional named pointer at an Llm row. Mirrors `apiKeyRef` on Llm. */ +const LlmRefSchema = z.union([ + z.object({ name: z.string().min(1) }), + z.object({ id: z.string().min(1) }), +]); +const ProjectRefSchema = z.object({ name: z.string().min(1) }); + +const NAME_RE = /^[a-z0-9-]+$/; + +export const CreateAgentSchema = z.object({ + name: z + .string() + .min(1) + .max(100) + .regex(NAME_RE, 'Name must be lowercase alphanumeric with hyphens'), + description: z.string().max(500).default(''), + systemPrompt: z.string().max(64_000).default(''), + llm: LlmRefSchema, + project: ProjectRefSchema.optional(), + proxyModelName: z.string().min(1).optional(), + defaultParams: AgentChatParamsSchema.default({}), + extras: z.record(z.unknown()).default({}), +}); + +export const UpdateAgentSchema = z.object({ + description: z.string().max(500).optional(), + systemPrompt: z.string().max(64_000).optional(), + llm: LlmRefSchema.optional(), + project: ProjectRefSchema.nullable().optional(), + proxyModelName: z.string().min(1).nullable().optional(), + defaultParams: AgentChatParamsSchema.optional(), + extras: z.record(z.unknown()).optional(), +}); + +/** Body schema for `POST /api/v1/agents/:name/chat`. */ +export const AgentChatRequestSchema = AgentChatParamsSchema.merge( + z.object({ + threadId: z.string().min(1).optional(), + message: z.string().min(1).optional(), + messages: z + .array( + z.object({ + role: z.enum(['system', 'user', 'assistant', 'tool']), + content: z.string(), + tool_call_id: z.string().optional(), + }), + ) + .optional(), + stream: z.boolean().optional(), + }), +).refine((v) => v.message !== undefined || (v.messages?.length ?? 
0) > 0, { + message: 'Either `message` or `messages` is required', +}); + +export type CreateAgentInput = z.infer; +export type UpdateAgentInput = z.infer; +export type AgentChatRequest = z.infer; diff --git a/src/mcpd/tests/agent-service.test.ts b/src/mcpd/tests/agent-service.test.ts new file mode 100644 index 0000000..f1332ba --- /dev/null +++ b/src/mcpd/tests/agent-service.test.ts @@ -0,0 +1,192 @@ +import { describe, it, expect, vi } from 'vitest'; +import { AgentService } from '../src/services/agent.service.js'; +import type { IAgentRepository } from '../src/repositories/agent.repository.js'; +import type { LlmService } from '../src/services/llm.service.js'; +import type { ProjectService } from '../src/services/project.service.js'; +import type { Agent } from '@prisma/client'; + +function makeAgent(overrides: Partial = {}): Agent { + return { + id: 'agent-1', + name: 'reviewer', + description: '', + systemPrompt: '', + llmId: 'llm-1', + projectId: null, + proxyModelName: null, + defaultParams: {} as Agent['defaultParams'], + extras: {} as Agent['extras'], + ownerId: 'owner-1', + version: 1, + createdAt: new Date(), + updatedAt: new Date(), + ...overrides, + }; +} + +function mockRepo(initial: Agent[] = []): IAgentRepository { + const rows = new Map(initial.map((r) => [r.id, r])); + return { + findAll: vi.fn(async () => [...rows.values()]), + findById: vi.fn(async (id: string) => rows.get(id) ?? null), + findByName: vi.fn(async (name: string) => { + for (const r of rows.values()) if (r.name === name) return r; + return null; + }), + findByProjectId: vi.fn(async (projectId: string) => + [...rows.values()].filter((r) => r.projectId === projectId)), + create: vi.fn(async (data) => { + const row = makeAgent({ + id: `agent-${String(rows.size + 1)}`, + name: data.name, + description: data.description ?? '', + systemPrompt: data.systemPrompt ?? '', + llmId: data.llmId, + projectId: data.projectId ?? null, + proxyModelName: data.proxyModelName ?? null, + defaultParams: (data.defaultParams ?? {}) as Agent['defaultParams'], + extras: (data.extras ?? {}) as Agent['extras'], + ownerId: data.ownerId, + }); + rows.set(row.id, row); + return row; + }), + update: vi.fn(async (id, data) => { + const existing = rows.get(id); + if (!existing) throw new Error('not found'); + const next: Agent = { + ...existing, + ...(data.description !== undefined ? { description: data.description } : {}), + ...(data.systemPrompt !== undefined ? { systemPrompt: data.systemPrompt } : {}), + ...(data.llmId !== undefined ? { llmId: data.llmId } : {}), + ...(data.projectId !== undefined ? { projectId: data.projectId } : {}), + ...(data.proxyModelName !== undefined ? { proxyModelName: data.proxyModelName } : {}), + ...(data.defaultParams !== undefined ? { defaultParams: data.defaultParams as Agent['defaultParams'] } : {}), + ...(data.extras !== undefined ? { extras: data.extras as Agent['extras'] } : {}), + version: existing.version + 1, + }; + rows.set(id, next); + return next; + }), + delete: vi.fn(async (id: string) => { + rows.delete(id); + }), + }; +} + +function mockLlms(): LlmService { + return { + getById: vi.fn(async (id: string) => ({ + id, name: id === 'llm-1' ? 'qwen3-thinking' : 'other', + type: 'openai', model: 'm', url: '', tier: 'fast', + description: '', apiKeyRef: null, extraConfig: {}, + version: 1, createdAt: new Date(), updatedAt: new Date(), + })), + getByName: vi.fn(async (name: string) => ({ + id: name === 'qwen3-thinking' ? 
'llm-1' : 'llm-other', + name, type: 'openai', model: 'm', url: '', tier: 'fast', + description: '', apiKeyRef: null, extraConfig: {}, + version: 1, createdAt: new Date(), updatedAt: new Date(), + })), + } as unknown as LlmService; +} + +function mockProjects(): ProjectService { + return { + getById: vi.fn(async (id: string) => ({ id, name: id === 'proj-1' ? 'mcpctl-dev' : 'other' })), + resolveAndGet: vi.fn(async (idOrName: string) => ({ + id: idOrName === 'mcpctl-dev' ? 'proj-1' : 'proj-other', + name: idOrName, + })), + } as unknown as ProjectService; +} + +describe('AgentService', () => { + it('creates an agent resolving llm + project by name', async () => { + const repo = mockRepo(); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + const view = await svc.create({ + name: 'reviewer', + description: 'I review security', + systemPrompt: 'be terse', + llm: { name: 'qwen3-thinking' }, + project: { name: 'mcpctl-dev' }, + defaultParams: { temperature: 0.2, max_tokens: 4096 }, + }, 'owner-1'); + expect(view.name).toBe('reviewer'); + expect(view.llm.name).toBe('qwen3-thinking'); + expect(view.project?.name).toBe('mcpctl-dev'); + expect(view.defaultParams.temperature).toBe(0.2); + expect(repo.create).toHaveBeenCalledOnce(); + }); + + it('creates an agent without a project (null projectId stays null)', async () => { + const repo = mockRepo(); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + const view = await svc.create({ + name: 'standalone', + llm: { name: 'qwen3-thinking' }, + }, 'owner-1'); + expect(view.project).toBeNull(); + }); + + it('rejects creating an agent with a duplicate name (Conflict)', async () => { + const repo = mockRepo([makeAgent({ id: 'a1', name: 'dup' })]); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + await expect(svc.create({ + name: 'dup', + llm: { name: 'qwen3-thinking' }, + }, 'owner-1')).rejects.toThrow(/already exists/); + }); + + it('updates llm reference by name', async () => { + const repo = mockRepo([makeAgent({ id: 'a1', name: 'switcher', llmId: 'llm-1' })]); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + const updated = await svc.update('a1', { llm: { name: 'other' } }); + expect(updated.llm.id).toBe('llm-other'); + }); + + it('detaches a project when project is set to null', async () => { + const repo = mockRepo([makeAgent({ id: 'a1', name: 'attached', projectId: 'proj-1' })]); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + const updated = await svc.update('a1', { project: null }); + expect(updated.project).toBeNull(); + }); + + it('listByProject returns only agents in the project', async () => { + const repo = mockRepo([ + makeAgent({ id: 'a1', name: 'in-proj', projectId: 'proj-1' }), + makeAgent({ id: 'a2', name: 'no-proj', projectId: null }), + makeAgent({ id: 'a3', name: 'other-proj', projectId: 'proj-other' }), + ]); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + const list = await svc.listByProject('mcpctl-dev'); + expect(list.map((a) => a.name)).toEqual(['in-proj']); + }); + + it('upsertByName creates if missing, updates if present', async () => { + const repo = mockRepo(); + const svc = new AgentService(repo, mockLlms(), mockProjects()); + + const created = await svc.upsertByName({ + name: 'roundtrip', + description: 'first', + systemPrompt: '', + llm: { name: 'qwen3-thinking' }, + defaultParams: {}, + extras: {}, + }, 'owner-1'); + expect(created.description).toBe('first'); + + const updated = await svc.upsertByName({ + name: 
'roundtrip', + description: 'second', + systemPrompt: '', + llm: { name: 'qwen3-thinking' }, + defaultParams: {}, + extras: {}, + }, 'owner-1'); + expect(updated.description).toBe('second'); + expect(updated.id).toBe(created.id); + }); +}); diff --git a/src/mcpd/tests/chat-service.test.ts b/src/mcpd/tests/chat-service.test.ts new file mode 100644 index 0000000..2b6170b --- /dev/null +++ b/src/mcpd/tests/chat-service.test.ts @@ -0,0 +1,413 @@ +import { describe, it, expect, vi } from 'vitest'; +import { ChatService, MAX_ITERATIONS, TOOL_NAME_SEPARATOR, type ChatToolDispatcher } from '../src/services/chat.service.js'; +import type { AgentService } from '../src/services/agent.service.js'; +import type { LlmService } from '../src/services/llm.service.js'; +import type { LlmAdapterRegistry } from '../src/services/llm/dispatcher.js'; +import type { LlmAdapter, NonStreamingResult, InferContext } from '../src/services/llm/types.js'; +import type { IChatRepository } from '../src/repositories/chat.repository.js'; +import type { IPromptRepository } from '../src/repositories/prompt.repository.js'; +import type { ChatMessage, ChatThread, Prompt } from '@prisma/client'; + +const NOW = new Date(); + +function mockChatRepo(): IChatRepository & { _msgs: ChatMessage[]; _threads: ChatThread[] } { + const msgs: ChatMessage[] = []; + const threads: ChatThread[] = []; + let idCounter = 1; + + return { + _msgs: msgs, + _threads: threads, + createThread: vi.fn(async ({ agentId, ownerId, title }) => { + const t: ChatThread = { + id: `thread-${String(idCounter++)}`, + agentId, + ownerId, + title: title ?? '', + lastTurnAt: NOW, + createdAt: NOW, + updatedAt: NOW, + }; + threads.push(t); + return t; + }), + findThread: vi.fn(async (id: string) => threads.find((t) => t.id === id) ?? null), + listThreadsByAgent: vi.fn(async (agentId: string) => threads.filter((t) => t.agentId === agentId)), + listMessages: vi.fn(async (threadId: string) => + msgs.filter((m) => m.threadId === threadId).sort((a, b) => a.turnIndex - b.turnIndex)), + appendMessage: vi.fn(async (input) => { + const turnIndex = input.turnIndex ?? msgs.filter((m) => m.threadId === input.threadId).length; + const m: ChatMessage = { + id: `msg-${String(idCounter++)}`, + threadId: input.threadId, + turnIndex, + role: input.role, + content: input.content, + toolCalls: (input.toolCalls ?? null) as ChatMessage['toolCalls'], + toolCallId: input.toolCallId ?? null, + status: input.status ?? 'complete', + createdAt: NOW, + }; + msgs.push(m); + return m; + }), + updateStatus: vi.fn(async (id: string, status) => { + const m = msgs.find((x) => x.id === id); + if (!m) throw new Error('not found'); + m.status = status; + return m; + }), + markPendingAsError: vi.fn(async (threadId: string) => { + let n = 0; + for (const m of msgs) { + if (m.threadId === threadId && m.status === 'pending') { + m.status = 'error'; + n += 1; + } + } + return n; + }), + touchThread: vi.fn(async () => undefined), + nextTurnIndex: vi.fn(async (threadId: string) => + msgs.filter((m) => m.threadId === threadId).length), + }; +} + +function mockPromptRepo(rows: Prompt[] = []): IPromptRepository { + return { + findAll: vi.fn(async () => rows), + findGlobal: vi.fn(async () => rows.filter((p) => p.projectId === null)), + findById: vi.fn(async (id: string) => rows.find((p) => p.id === id) ?? 
null), + findByNameAndProject: vi.fn(async () => null), + create: vi.fn(), + update: vi.fn(), + delete: vi.fn(), + } as unknown as IPromptRepository; +} + +function mockTools(impl: Partial = {}): ChatToolDispatcher { + return { + listTools: impl.listTools ?? vi.fn(async () => []), + callTool: impl.callTool ?? vi.fn(async () => ({ ok: true })), + }; +} + +function mockAgents(): AgentService { + return { + getByName: vi.fn(async (name: string) => ({ + id: `agent-${name}`, + name, + description: 'desc', + systemPrompt: 'You are a helpful agent.', + llm: { id: 'llm-1', name: 'qwen3-thinking' }, + project: name === 'no-project' + ? null + : { id: 'proj-1', name: 'mcpctl-dev' }, + proxyModelName: null, + defaultParams: { temperature: 0.5 }, + extras: {}, + ownerId: 'owner-1', + version: 1, + createdAt: NOW, + updatedAt: NOW, + })), + } as unknown as AgentService; +} + +function mockLlms(): LlmService { + return { + getByName: vi.fn(async (name: string) => ({ + id: 'llm-1', name, type: 'openai', model: 'qwen3-thinking', + url: '', tier: 'fast', description: '', + apiKeyRef: null, extraConfig: {}, + version: 1, createdAt: NOW, updatedAt: NOW, + })), + resolveApiKey: vi.fn(async () => 'fake-key'), + } as unknown as LlmService; +} + +/** Adapter that yields a scripted sequence of canned responses, one per call. */ +function scriptedAdapter(responses: NonStreamingResult[]): LlmAdapter { + let i = 0; + return { + kind: 'scripted', + infer: vi.fn(async (_ctx: InferContext) => { + const r = responses[i] ?? responses[responses.length - 1]; + i += 1; + if (r === undefined) throw new Error('no scripted response'); + return r; + }), + stream: async function*(_ctx: InferContext) { + yield { data: '[DONE]', done: true }; + }, + }; +} + +function adapterRegistry(adapter: LlmAdapter): LlmAdapterRegistry { + return { get: () => adapter } as unknown as LlmAdapterRegistry; +} + +function chatCompletion(content: string): NonStreamingResult { + return { + status: 200, + body: { + id: 'cmpl-1', + object: 'chat.completion', + choices: [{ index: 0, message: { role: 'assistant', content }, finish_reason: 'stop' }], + }, + }; +} + +function toolCall(name: string, args: Record): NonStreamingResult { + return { + status: 200, + body: { + id: 'cmpl-1', + object: 'chat.completion', + choices: [{ + index: 0, + message: { + role: 'assistant', + content: '', + tool_calls: [{ + id: `call-${name}`, + type: 'function', + function: { name, arguments: JSON.stringify(args) }, + }], + }, + finish_reason: 'tool_calls', + }], + }, + }; +} + +describe('ChatService', () => { + it('plain text turn — persists user + assistant rows and returns the reply', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([chatCompletion('hello back')]); + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), mockTools(), + ); + + const result = await svc.chat({ + agentName: 'reviewer', + userMessage: 'hi', + ownerId: 'owner-1', + }); + + expect(result.assistant).toBe('hello back'); + const stored = chatRepo._msgs.filter((m) => m.threadId === result.threadId); + expect(stored.map((m) => m.role)).toEqual(['user', 'assistant']); + expect(stored[1]?.status).toBe('complete'); + }); + + it('runs a full tool-use round-trip and ends with a text reply', async () => { + const chatRepo = mockChatRepo(); + const tools = mockTools({ + listTools: vi.fn(async () => [{ + name: `grafana${TOOL_NAME_SEPARATOR}query`, + description: 'query grafana', + parameters: { type: 'object', 
properties: {} }, + }]), + callTool: vi.fn(async () => ({ rows: [{ value: 42 }] })), + }); + const adapter = scriptedAdapter([ + toolCall(`grafana${TOOL_NAME_SEPARATOR}query`, { q: 'cpu' }), + chatCompletion('the answer is 42'), + ]); + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), tools, + ); + + const result = await svc.chat({ + agentName: 'reviewer', + userMessage: 'what is cpu?', + ownerId: 'owner-1', + }); + + expect(result.assistant).toBe('the answer is 42'); + expect(tools.callTool).toHaveBeenCalledWith({ + projectId: 'proj-1', + serverName: 'grafana', + toolName: 'query', + args: { q: 'cpu' }, + }); + const stored = chatRepo._msgs.filter((m) => m.threadId === result.threadId); + expect(stored.map((m) => m.role)).toEqual(['user', 'assistant', 'tool', 'assistant']); + // No `pending` rows leaked. + expect(stored.every((m) => m.status === 'complete')).toBe(true); + // Tool turn's toolCallId links back. + const toolTurn = stored.find((m) => m.role === 'tool'); + expect(toolTurn?.toolCallId).toBe(`call-grafana${TOOL_NAME_SEPARATOR}query`); + }); + + it('caps the loop at MAX_ITERATIONS when the model never settles', async () => { + const chatRepo = mockChatRepo(); + const tools = mockTools({ + listTools: vi.fn(async () => [{ + name: `g${TOOL_NAME_SEPARATOR}t`, + description: '', + parameters: { type: 'object' }, + }]), + callTool: vi.fn(async () => ({})), + }); + // Always return a tool_call → the loop never reaches a terminal turn. + const adapter = scriptedAdapter([toolCall(`g${TOOL_NAME_SEPARATOR}t`, {})]); + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), tools, + ); + + await expect(svc.chat({ + agentName: 'reviewer', + userMessage: 'loop forever', + ownerId: 'owner-1', + })).rejects.toThrow(new RegExp(`exceeded ${String(MAX_ITERATIONS)}`)); + + // After failure, no row should remain `pending`. 
+ expect(chatRepo._msgs.every((m) => m.status !== 'pending')).toBe(true); + }); + + it('flips pending rows to error when the adapter throws mid-loop', async () => { + const chatRepo = mockChatRepo(); + const tools = mockTools({ + listTools: vi.fn(async () => [{ + name: `g${TOOL_NAME_SEPARATOR}t`, description: '', parameters: {}, + }]), + callTool: vi.fn(async () => ({})), + }); + const adapter: LlmAdapter = { + kind: 'fail-after-one', + infer: vi.fn() + .mockResolvedValueOnce(toolCall(`g${TOOL_NAME_SEPARATOR}t`, {})) + .mockRejectedValueOnce(new Error('upstream blew up')), + stream: async function*() { yield { data: '[DONE]', done: true }; }, + }; + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), tools, + ); + + await expect(svc.chat({ + agentName: 'reviewer', + userMessage: 'go', + ownerId: 'owner-1', + })).rejects.toThrow('upstream blew up'); + + expect(chatRepo._msgs.some((m) => m.status === 'error')).toBe(false); + expect(chatRepo._msgs.every((m) => m.status !== 'pending')).toBe(true); + }); + + it('merges per-call params over agent.defaultParams (override wins)', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([chatCompletion('ok')]); + const inferSpy = adapter.infer as ReturnType; + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), mockTools(), + ); + await svc.chat({ + agentName: 'reviewer', + userMessage: 'hi', + ownerId: 'owner-1', + params: { temperature: 0.9, max_tokens: 256 }, + }); + const ctx = inferSpy.mock.calls[0][0] as InferContext; + expect(ctx.body.temperature).toBe(0.9); + expect(ctx.body.max_tokens).toBe(256); + }); + + it('forwards `extra` keys into the body for provider-specific knobs', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([chatCompletion('ok')]); + const inferSpy = adapter.infer as ReturnType; + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), mockTools(), + ); + await svc.chat({ + agentName: 'reviewer', + userMessage: 'hi', + ownerId: 'owner-1', + params: { extra: { metadata: { user_id: 'abc' }, repetition_penalty: 1.05 } }, + }); + const ctx = inferSpy.mock.calls[0][0] as InferContext; + expect((ctx.body as Record)['repetition_penalty']).toBe(1.05); + expect((ctx.body as Record)['metadata']).toEqual({ user_id: 'abc' }); + }); + + it('builds a system block from agent.systemPrompt + project prompts (priority desc)', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([chatCompletion('ok')]); + const inferSpy = adapter.infer as ReturnType; + const prompts: Prompt[] = [ + { + id: 'p1', name: 'low', content: 'LOW prompt', + projectId: 'proj-1', priority: 1, summary: null, chapters: null, + linkTarget: null, version: 1, createdAt: NOW, updatedAt: NOW, + }, + { + id: 'p2', name: 'high', content: 'HIGH prompt', + projectId: 'proj-1', priority: 9, summary: null, chapters: null, + linkTarget: null, version: 1, createdAt: NOW, updatedAt: NOW, + }, + ]; + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(prompts), mockTools(), + ); + await svc.chat({ agentName: 'reviewer', userMessage: 'hi', ownerId: 'owner-1' }); + const ctx = inferSpy.mock.calls[0][0] as InferContext; + const sys = ctx.body.messages.find((m) => m.role === 'system'); + expect(typeof sys?.content).toBe('string'); + const text = sys?.content as string; 
+ // High-priority prompt comes before low-priority. + expect(text.indexOf('HIGH prompt')).toBeLessThan(text.indexOf('LOW prompt')); + // Agent's own system prompt leads. + expect(text.indexOf('You are a helpful agent.')).toBeLessThan(text.indexOf('HIGH prompt')); + }); + + it('refuses tool calls when the agent has no project attached', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([toolCall(`x${TOOL_NAME_SEPARATOR}y`, {})]); + const tools = mockTools({ + listTools: vi.fn(async () => [{ name: `x${TOOL_NAME_SEPARATOR}y`, description: '', parameters: {} }]), + }); + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), tools, + ); + await expect(svc.chat({ + agentName: 'no-project', + userMessage: 'go', + ownerId: 'owner-1', + })).rejects.toThrow(/Project/); + }); + + it('honours tools_allowlist (filters tools before sending to adapter)', async () => { + const chatRepo = mockChatRepo(); + const adapter = scriptedAdapter([chatCompletion('ok')]); + const inferSpy = adapter.infer as ReturnType; + const tools = mockTools({ + listTools: vi.fn(async () => [ + { name: `s1${TOOL_NAME_SEPARATOR}a`, description: '', parameters: {} }, + { name: `s1${TOOL_NAME_SEPARATOR}b`, description: '', parameters: {} }, + ]), + }); + const svc = new ChatService( + mockAgents(), mockLlms(), adapterRegistry(adapter), + chatRepo, mockPromptRepo(), tools, + ); + await svc.chat({ + agentName: 'reviewer', + userMessage: 'hi', + ownerId: 'owner-1', + params: { tools_allowlist: [`s1${TOOL_NAME_SEPARATOR}a`] }, + }); + const ctx = inferSpy.mock.calls[0][0] as InferContext; + expect(ctx.body.tools).toHaveLength(1); + expect(ctx.body.tools?.[0]?.function.name).toBe(`s1${TOOL_NAME_SEPARATOR}a`); + }); +});
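Reviewer note, not part of the patch: a minimal caller-side sketch of how the
new ChatService is expected to be driven, assuming an already-wired instance
(the constructor takes AgentService, LlmService, the adapter registry, the
chat/prompt repositories and a ChatToolDispatcher, as in chat-service.test.ts).
The agent name, import path and ids below are placeholders.

    import { ChatService } from './src/services/chat.service.js'; // path is illustrative

    async function runOneTurn(svc: ChatService): Promise<void> {
      // Non-streaming: one user turn; the tool loop runs internally.
      const result = await svc.chat({
        agentName: 'reviewer',                 // must resolve via AgentService.getByName
        userMessage: 'summarise the last deploy',
        ownerId: 'owner-1',
        params: { temperature: 0.2, max_tokens: 1024 }, // overrides agent.defaultParams
      });
      console.log(result.threadId, result.turnIndex, result.assistant);

      // Streaming variant: consume text deltas and tool events as they arrive.
      for await (const chunk of svc.chatStream({
        agentName: 'reviewer',
        threadId: result.threadId,             // continue the same thread
        userMessage: 'and the one before?',
        ownerId: 'owner-1',
      })) {
        if (chunk.type === 'text') process.stdout.write(chunk.delta ?? '');
        if (chunk.type === 'error') throw new Error(chunk.message);
      }
    }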