feat: gated project experience & prompt intelligence
Some checks failed
CI / lint (pull_request) Has been cancelled
CI / typecheck (pull_request) Has been cancelled
CI / test (pull_request) Has been cancelled
CI / build (pull_request) Has been cancelled
CI / package (pull_request) Has been cancelled

Implements the full gated session flow and prompt intelligence system:

- Prisma schema: add gated, priority, summary, chapters, linkTarget fields
- Session gate: state machine (gated → begin_session → ungated) with LLM-powered
  tool selection based on prompt index
- Tag matcher: intelligent prompt-to-tool matching with project/server/action tags
- LLM selector: tiered provider selection (fast for gating, heavy for complex tasks)
- Link resolver: cross-project MCP resource references (project/server:uri format)
- Prompt summary service: LLM-generated summaries and chapter extraction
- System project bootstrap: ensures default project exists on startup
- Structural link health checks: enrichWithLinkStatus on prompt GET endpoints
- CLI: create prompt --priority/--link, create project --gated/--no-gated,
  describe project shows prompts section, get prompts shows PRI/LINK/STATUS
- Apply/edit: priority, linkTarget, gated fields supported
- Shell completions: fish updated with new flags
- 1,253 tests passing across all packages

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michal
2026-02-25 23:22:42 +00:00
parent 62647a7f90
commit 705df06996
46 changed files with 4946 additions and 105 deletions

View File

@@ -55,6 +55,7 @@ export async function refreshProjectUpstreams(
export interface ProjectLlmConfig {
llmProvider?: string;
llmModel?: string;
gated?: boolean;
}
export async function fetchProjectLlmConfig(
@@ -65,10 +66,12 @@ export async function fetchProjectLlmConfig(
const project = await mcpdClient.get<{
llmProvider?: string;
llmModel?: string;
gated?: boolean;
}>(`/api/v1/projects/${encodeURIComponent(projectName)}`);
const config: ProjectLlmConfig = {};
if (project.llmProvider) config.llmProvider = project.llmProvider;
if (project.llmModel) config.llmModel = project.llmModel;
if (project.gated !== undefined) config.gated = project.gated;
return config;
} catch {
return {};

View File

@@ -0,0 +1,81 @@
/**
* LLM-based prompt selection for the gating flow.
*
* Sends tags + prompt index to the heavy LLM, which returns
* a ranked list of relevant prompt names.
*/
import type { ProviderRegistry } from '../providers/registry.js';
export interface PromptIndexForLlm {
name: string;
priority: number;
summary: string | null;
chapters: string[] | null;
}
export interface LlmSelectionResult {
selectedNames: string[];
reasoning: string;
}
export class LlmPromptSelector {
constructor(
private readonly providerRegistry: ProviderRegistry,
private readonly modelOverride?: string,
) {}
async selectPrompts(
tags: string[],
promptIndex: PromptIndexForLlm[],
): Promise<LlmSelectionResult> {
const systemPrompt = `You are a context selection assistant. Given a developer's task keywords and a list of available project prompts, select which prompts are relevant to their work. Return a JSON object with "selectedNames" (array of prompt names) and "reasoning" (brief explanation). Priority 10 prompts must always be included.`;
const userPrompt = `Task keywords: ${tags.join(', ')}
Available prompts:
${promptIndex.map((p) => `- ${p.name} (priority: ${p.priority}): ${p.summary ?? 'No summary'}${p.chapters?.length ? `\n Chapters: ${p.chapters.join(', ')}` : ''}`).join('\n')}
Select the relevant prompts. Return JSON: { "selectedNames": [...], "reasoning": "..." }`;
const provider = this.providerRegistry.getProvider('heavy');
if (!provider) {
throw new Error('No heavy LLM provider available');
}
const completionOptions: import('../providers/types.js').CompletionOptions = {
messages: [
{ role: 'system', content: systemPrompt },
{ role: 'user', content: userPrompt },
],
temperature: 0,
maxTokens: 1024,
};
if (this.modelOverride) {
completionOptions.model = this.modelOverride;
}
const result = await provider.complete(completionOptions);
const response = result.content;
// Parse JSON from response (may be wrapped in markdown code blocks)
const jsonMatch = response.match(/\{[\s\S]*"selectedNames"[\s\S]*\}/);
if (!jsonMatch) {
throw new Error('LLM response did not contain valid selection JSON');
}
const parsed = JSON.parse(jsonMatch[0]) as { selectedNames?: string[]; reasoning?: string };
const selectedNames = parsed.selectedNames ?? [];
const reasoning = parsed.reasoning ?? '';
// Always include priority 10 prompts
for (const p of promptIndex) {
if (p.priority === 10 && !selectedNames.includes(p.name)) {
selectedNames.push(p.name);
}
}
return { selectedNames, reasoning };
}
}

View File

@@ -0,0 +1,76 @@
/**
* Per-session gating state machine.
*
* Tracks whether a session has gone through the prompt selection flow.
* When gated, only begin_session is accessible. After ungating, all tools work.
*/
import type { PromptIndexEntry, TagMatchResult } from './tag-matcher.js';
export interface SessionState {
gated: boolean;
tags: string[];
retrievedPrompts: Set<string>;
briefing: string | null;
}
export class SessionGate {
private sessions = new Map<string, SessionState>();
/** Create a new session. Starts gated if the project is gated. */
createSession(sessionId: string, projectGated: boolean): void {
this.sessions.set(sessionId, {
gated: projectGated,
tags: [],
retrievedPrompts: new Set(),
briefing: null,
});
}
/** Get session state. Returns null if session doesn't exist. */
getSession(sessionId: string): SessionState | null {
return this.sessions.get(sessionId) ?? null;
}
/** Check if a session is currently gated. Unknown sessions are treated as ungated. */
isGated(sessionId: string): boolean {
return this.sessions.get(sessionId)?.gated ?? false;
}
/** Ungate a session after prompt selection is complete. */
ungate(sessionId: string, tags: string[], matchResult: TagMatchResult): void {
const session = this.sessions.get(sessionId);
if (!session) return;
session.gated = false;
session.tags = [...session.tags, ...tags];
// Track which prompts have been sent
for (const p of matchResult.fullContent) {
session.retrievedPrompts.add(p.name);
}
}
/** Record additional prompts retrieved via read_prompts. */
addRetrievedPrompts(sessionId: string, tags: string[], promptNames: string[]): void {
const session = this.sessions.get(sessionId);
if (!session) return;
session.tags = [...session.tags, ...tags];
for (const name of promptNames) {
session.retrievedPrompts.add(name);
}
}
/** Filter out prompts already sent to avoid duplicates. */
filterAlreadySent(sessionId: string, prompts: PromptIndexEntry[]): PromptIndexEntry[] {
const session = this.sessions.get(sessionId);
if (!session) return prompts;
return prompts.filter((p) => !session.retrievedPrompts.has(p.name));
}
/** Remove a session (cleanup on disconnect). */
removeSession(sessionId: string): void {
this.sessions.delete(sessionId);
}
}

View File

@@ -0,0 +1,109 @@
/**
* Deterministic keyword-based tag matching for prompt selection.
*
* Used as the no-LLM fallback (and for read_prompts in hybrid mode).
* Scores prompts by tag overlap * priority, then fills a byte budget.
*/
export interface PromptIndexEntry {
name: string;
priority: number;
summary: string | null;
chapters: string[] | null;
content: string;
}
export interface TagMatchResult {
/** Prompts with full content included (within byte budget) */
fullContent: PromptIndexEntry[];
/** Matched prompts beyond byte budget — name + summary only */
indexOnly: PromptIndexEntry[];
/** Non-matched prompts — listed for awareness */
remaining: PromptIndexEntry[];
}
const DEFAULT_BYTE_BUDGET = 8192;
export class TagMatcher {
constructor(private readonly byteBudget: number = DEFAULT_BYTE_BUDGET) {}
match(tags: string[], prompts: PromptIndexEntry[]): TagMatchResult {
const lowerTags = tags.map((t) => t.toLowerCase());
// Score each prompt
const scored = prompts.map((p) => ({
prompt: p,
score: this.computeScore(lowerTags, p),
matched: this.computeScore(lowerTags, p) > 0,
}));
// Partition: matched (score > 0) vs non-matched
const matched = scored.filter((s) => s.matched).sort((a, b) => b.score - a.score);
const nonMatched = scored.filter((s) => !s.matched).map((s) => s.prompt);
// Fill byte budget from matched prompts
let budgetRemaining = this.byteBudget;
const fullContent: PromptIndexEntry[] = [];
const indexOnly: PromptIndexEntry[] = [];
for (const { prompt } of matched) {
const contentBytes = Buffer.byteLength(prompt.content, 'utf-8');
if (budgetRemaining >= contentBytes) {
fullContent.push(prompt);
budgetRemaining -= contentBytes;
} else {
indexOnly.push(prompt);
}
}
return { fullContent, indexOnly, remaining: nonMatched };
}
private computeScore(lowerTags: string[], prompt: PromptIndexEntry): number {
// Priority 10 always included
if (prompt.priority === 10) return Infinity;
if (lowerTags.length === 0) return 0;
const searchText = [
prompt.name,
prompt.summary ?? '',
...(prompt.chapters ?? []),
].join(' ').toLowerCase();
let matchCount = 0;
for (const tag of lowerTags) {
if (searchText.includes(tag)) matchCount++;
}
return matchCount * prompt.priority;
}
}
/**
* Extract keywords from a tool call for the intercept fallback path.
* Pulls words from the tool name and string argument values.
*/
export function extractKeywordsFromToolCall(
toolName: string,
args: Record<string, unknown>,
): string[] {
const keywords = new Set<string>();
// Tool name parts (split on / and -)
for (const part of toolName.split(/[/-]/)) {
if (part.length > 2) keywords.add(part.toLowerCase());
}
// String argument values — extract words
for (const value of Object.values(args)) {
if (typeof value === 'string' && value.length < 200) {
for (const word of value.split(/\s+/)) {
const clean = word.replace(/[^a-zA-Z0-9-]/g, '').toLowerCase();
if (clean.length > 2) keywords.add(clean);
}
}
}
return [...keywords].slice(0, 10); // Cap at 10 keywords
}

View File

@@ -52,13 +52,28 @@ export function registerProjectMcpEndpoint(app: FastifyInstance, mcpdClient: Mcp
const mcpdConfig = await fetchProjectLlmConfig(mcpdClient, projectName);
const resolvedModel = localOverride?.model ?? mcpdConfig.llmModel ?? undefined;
// If project llmProvider is "none", disable LLM for this project
const llmDisabled = mcpdConfig.llmProvider === 'none' || localOverride?.provider === 'none';
const effectiveRegistry = llmDisabled ? null : (providerRegistry ?? null);
// Wire pagination support with LLM provider and project model override
router.setPaginator(new ResponsePaginator(providerRegistry ?? null, {}, resolvedModel));
router.setPaginator(new ResponsePaginator(effectiveRegistry, {}, resolvedModel));
// Configure prompt resources with SA-scoped client for RBAC
const saClient = mcpdClient.withHeaders({ 'X-Service-Account': `project:${projectName}` });
router.setPromptConfig(saClient, projectName);
// Configure gating if project has it enabled (default: true)
const isGated = mcpdConfig.gated !== false;
const gateConfig: import('../router.js').GateConfig = {
gated: isGated,
providerRegistry: effectiveRegistry,
};
if (resolvedModel) {
gateConfig.modelOverride = resolvedModel;
}
router.setGateConfig(gateConfig);
// Fetch project instructions and set on router
try {
const instructions = await mcpdClient.get<{ prompt: string; servers: Array<{ name: string; description: string }> }>(
@@ -131,6 +146,7 @@ export function registerProjectMcpEndpoint(app: FastifyInstance, mcpdClient: Mcp
const id = transport.sessionId;
if (id) {
sessions.delete(id);
router.cleanupSession(id);
}
};

View File

@@ -2,11 +2,23 @@ import type { UpstreamConnection, JsonRpcRequest, JsonRpcResponse, JsonRpcNotifi
import type { LlmProcessor } from './llm/processor.js';
import { ResponsePaginator } from './llm/pagination.js';
import type { McpdClient } from './http/mcpd-client.js';
import { SessionGate } from './gate/session-gate.js';
import { TagMatcher, extractKeywordsFromToolCall } from './gate/tag-matcher.js';
import type { PromptIndexEntry, TagMatchResult } from './gate/tag-matcher.js';
import { LlmPromptSelector } from './gate/llm-selector.js';
import type { ProviderRegistry } from './providers/registry.js';
export interface RouteContext {
sessionId?: string;
}
export interface GateConfig {
gated: boolean;
providerRegistry: ProviderRegistry | null;
modelOverride?: string;
byteBudget?: number;
}
/**
* Routes MCP requests to the appropriate upstream server.
*
@@ -28,11 +40,28 @@ export class McpRouter {
private projectName: string | null = null;
private mcpctlResourceContents = new Map<string, string>();
private paginator: ResponsePaginator | null = null;
private sessionGate = new SessionGate();
private gateConfig: GateConfig | null = null;
private tagMatcher: TagMatcher | null = null;
private llmSelector: LlmPromptSelector | null = null;
private cachedPromptIndex: PromptIndexEntry[] | null = null;
private promptIndexFetchedAt = 0;
private readonly PROMPT_INDEX_TTL_MS = 60_000;
private systemPromptCache = new Map<string, { content: string; fetchedAt: number }>();
private readonly SYSTEM_PROMPT_TTL_MS = 300_000; // 5 minutes
setPaginator(paginator: ResponsePaginator): void {
this.paginator = paginator;
}
setGateConfig(config: GateConfig): void {
this.gateConfig = config;
this.tagMatcher = new TagMatcher(config.byteBudget);
if (config.providerRegistry) {
this.llmSelector = new LlmPromptSelector(config.providerRegistry, config.modelOverride);
}
}
setLlmProcessor(processor: LlmProcessor): void {
this.llmProcessor = processor;
}
@@ -257,28 +286,50 @@ export class McpRouter {
*/
async route(request: JsonRpcRequest, context?: RouteContext): Promise<JsonRpcResponse> {
switch (request.method) {
case 'initialize':
return {
jsonrpc: '2.0',
id: request.id,
result: {
protocolVersion: '2024-11-05',
serverInfo: {
name: 'mcpctl-proxy',
version: '0.1.0',
},
capabilities: {
tools: {},
resources: {},
prompts: {},
},
...(this.instructions ? { instructions: this.instructions } : {}),
case 'initialize': {
// Create gated session if project is gated
const isGated = this.gateConfig?.gated ?? false;
if (context?.sessionId && this.gateConfig) {
this.sessionGate.createSession(context.sessionId, isGated);
}
// Build instructions: base project instructions + gate message with prompt index
let instructions = this.instructions ?? '';
if (isGated) {
instructions = await this.buildGatedInstructions(instructions);
}
const result: Record<string, unknown> = {
protocolVersion: '2024-11-05',
serverInfo: {
name: 'mcpctl-proxy',
version: '0.1.0',
},
capabilities: {
tools: {},
resources: {},
prompts: {},
},
};
if (instructions) {
result['instructions'] = instructions;
}
return { jsonrpc: '2.0', id: request.id, result };
}
case 'tools/list': {
// When gated, only show begin_session
if (context?.sessionId && this.sessionGate.isGated(context.sessionId)) {
return {
jsonrpc: '2.0',
id: request.id,
result: { tools: [this.getBeginSessionTool()] },
};
}
const tools = await this.discoverTools();
// Append propose_prompt tool if prompt config is set
// Append built-in tools if prompt config is set
if (this.mcpdClient && this.projectName) {
tools.push({
name: 'propose_prompt',
@@ -293,6 +344,10 @@ export class McpRouter {
},
});
}
// Always offer read_prompts when gating is configured (even for ungated sessions)
if (this.gateConfig && this.mcpdClient && this.projectName) {
tools.push(this.getReadPromptsTool());
}
return {
jsonrpc: '2.0',
id: request.id,
@@ -337,6 +392,44 @@ export class McpRouter {
case 'resources/read': {
const params = request.params as Record<string, unknown> | undefined;
const uri = params?.['uri'] as string | undefined;
if (uri?.startsWith('mcpctl://prompts/') && this.mcpdClient && this.projectName) {
const promptName = uri.slice('mcpctl://prompts/'.length);
try {
const sessionParam = context?.sessionId ? `?session=${encodeURIComponent(context.sessionId)}` : '';
const visible = await this.mcpdClient.get<Array<{ name: string; content: string; type: string }>>(
`/api/v1/projects/${encodeURIComponent(this.projectName)}/prompts/visible${sessionParam}`,
);
const found = visible.find((p) => p.name === promptName);
if (found) {
this.mcpctlResourceContents.set(uri, found.content);
return {
jsonrpc: '2.0',
id: request.id,
result: {
contents: [{ uri, mimeType: 'text/plain', text: found.content }],
},
};
}
} catch {
// Fall through to cache
}
// Fallback to cache if mcpd is unreachable
const cached = this.mcpctlResourceContents.get(uri);
if (cached !== undefined) {
return {
jsonrpc: '2.0',
id: request.id,
result: {
contents: [{ uri, mimeType: 'text/plain', text: cached }],
},
};
}
return {
jsonrpc: '2.0',
id: request.id,
error: { code: -32602, message: `Resource not found: ${uri}` },
};
}
if (uri?.startsWith('mcpctl://')) {
const content = this.mcpctlResourceContents.get(uri);
if (content !== undefined) {
@@ -400,13 +493,26 @@ export class McpRouter {
const params = request.params as Record<string, unknown> | undefined;
const toolName = params?.['name'] as string | undefined;
// Handle built-in propose_prompt tool
// Handle built-in tools
if (toolName === 'propose_prompt') {
return this.handleProposePrompt(request, context);
}
if (toolName === 'begin_session') {
return this.handleBeginSession(request, context);
}
if (toolName === 'read_prompts') {
return this.handleReadPrompts(request, context);
}
// Extract tool arguments early (needed for both gated intercept and pagination)
const toolArgs = (params?.['arguments'] ?? {}) as Record<string, unknown>;
// Intercept: if session is gated and trying to call a real tool, auto-ungate with keyword extraction
if (context?.sessionId && this.sessionGate.isGated(context.sessionId)) {
return this.handleGatedIntercept(request, context, toolName ?? '', toolArgs);
}
// Intercept pagination page requests before routing to upstream
const toolArgs = (params?.['arguments'] ?? {}) as Record<string, unknown>;
if (this.paginator) {
const paginationReq = ResponsePaginator.extractPaginationParams(toolArgs);
if (paginationReq) {
@@ -525,6 +631,417 @@ export class McpRouter {
}
}
// ── Gate tool definitions ──
private getBeginSessionTool(): { name: string; description: string; inputSchema: unknown } {
return {
name: 'begin_session',
description: 'Start your session by providing keywords that describe your current task. You will receive relevant project context, policies, and guidelines. This is required before using other tools.',
inputSchema: {
type: 'object',
properties: {
tags: {
type: 'array',
items: { type: 'string' },
maxItems: 10,
description: '3-7 keywords describing your current task (e.g. ["zigbee", "pairing", "mqtt"])',
},
},
required: ['tags'],
},
};
}
private getReadPromptsTool(): { name: string; description: string; inputSchema: unknown } {
return {
name: 'read_prompts',
description: 'Retrieve additional project prompts by keywords. Use this if you need more context about specific topics. Returns matched prompts and a list of other available prompts.',
inputSchema: {
type: 'object',
properties: {
tags: {
type: 'array',
items: { type: 'string' },
maxItems: 10,
description: 'Keywords to match against available prompts',
},
},
required: ['tags'],
},
};
}
// ── Gate handlers ──
private async handleBeginSession(request: JsonRpcRequest, context?: RouteContext): Promise<JsonRpcResponse> {
if (!this.gateConfig || !this.mcpdClient || !this.projectName) {
return { jsonrpc: '2.0', id: request.id, error: { code: -32603, message: 'Gating not configured' } };
}
const params = request.params as Record<string, unknown> | undefined;
const args = (params?.['arguments'] ?? {}) as Record<string, unknown>;
const tags = args['tags'] as string[] | undefined;
if (!tags || !Array.isArray(tags) || tags.length === 0) {
return { jsonrpc: '2.0', id: request.id, error: { code: -32602, message: 'Missing or empty tags array' } };
}
const sessionId = context?.sessionId;
if (sessionId && !this.sessionGate.isGated(sessionId)) {
return {
jsonrpc: '2.0',
id: request.id,
result: {
content: [{ type: 'text', text: 'Session already started. Use read_prompts to retrieve additional context.' }],
},
};
}
try {
const promptIndex = await this.fetchPromptIndex();
// Primary: LLM selection. Fallback: deterministic tag matching.
let matchResult: TagMatchResult;
let reasoning = '';
if (this.llmSelector) {
try {
const llmIndex = promptIndex.map((p) => ({
name: p.name,
priority: p.priority,
summary: p.summary,
chapters: p.chapters,
}));
const llmResult = await this.llmSelector.selectPrompts(tags, llmIndex);
reasoning = llmResult.reasoning;
// Convert LLM names back to full PromptIndexEntry results via TagMatcher for byte-budget
const selectedSet = new Set(llmResult.selectedNames);
const selected = promptIndex.filter((p) => selectedSet.has(p.name));
const remaining = promptIndex.filter((p) => !selectedSet.has(p.name));
// Apply byte budget to the LLM-selected prompts
matchResult = this.tagMatcher!.match(
// Use all tags + selected names as keywords so everything scores > 0
[...tags, ...llmResult.selectedNames],
selected,
);
// Put LLM-unselected in remaining
matchResult.remaining = [...matchResult.remaining, ...remaining];
} catch {
// LLM failed — fall back to keyword matching
matchResult = this.tagMatcher!.match(tags, promptIndex);
}
} else {
matchResult = this.tagMatcher!.match(tags, promptIndex);
}
// Ungate the session
if (sessionId) {
this.sessionGate.ungate(sessionId, tags, matchResult);
}
// Build response
const responseParts: string[] = [];
if (reasoning) {
responseParts.push(`Selection reasoning: ${reasoning}\n`);
}
// Full content prompts
for (const p of matchResult.fullContent) {
responseParts.push(`--- ${p.name} (priority: ${p.priority}) ---\n${p.content}\n`);
}
// Index-only (over budget)
if (matchResult.indexOnly.length > 0) {
responseParts.push('Additional matched prompts (use read_prompts to retrieve full content):');
for (const p of matchResult.indexOnly) {
responseParts.push(` - ${p.name}: ${p.summary ?? 'No description'}`);
}
responseParts.push('');
}
// Remaining prompts for awareness
if (matchResult.remaining.length > 0) {
responseParts.push('Other available prompts:');
for (const p of matchResult.remaining) {
responseParts.push(` - ${p.name}: ${p.summary ?? 'No description'}`);
}
responseParts.push('');
}
// Encouragement (from system prompt or fallback)
const encouragement = await this.getSystemPrompt(
'gate-encouragement',
'If any of the listed prompts seem relevant to your work, or if you encounter unfamiliar patterns, conventions, or constraints during implementation, use read_prompts({ tags: [...] }) to retrieve them. It is better to check and not need it than to proceed without important context.',
);
responseParts.push(encouragement);
return {
jsonrpc: '2.0',
id: request.id,
result: {
content: [{ type: 'text', text: responseParts.join('\n') }],
},
};
} catch (err) {
return {
jsonrpc: '2.0',
id: request.id,
error: { code: -32603, message: `begin_session failed: ${err instanceof Error ? err.message : String(err)}` },
};
}
}
private async handleReadPrompts(request: JsonRpcRequest, context?: RouteContext): Promise<JsonRpcResponse> {
if (!this.tagMatcher || !this.mcpdClient || !this.projectName) {
return { jsonrpc: '2.0', id: request.id, error: { code: -32603, message: 'Prompt retrieval not configured' } };
}
const params = request.params as Record<string, unknown> | undefined;
const args = (params?.['arguments'] ?? {}) as Record<string, unknown>;
const tags = args['tags'] as string[] | undefined;
if (!tags || !Array.isArray(tags) || tags.length === 0) {
return { jsonrpc: '2.0', id: request.id, error: { code: -32602, message: 'Missing or empty tags array' } };
}
try {
const promptIndex = await this.fetchPromptIndex();
const sessionId = context?.sessionId;
// Filter out already-sent prompts
const available = sessionId ? this.sessionGate.filterAlreadySent(sessionId, promptIndex) : promptIndex;
// Always use deterministic tag matching for read_prompts (hybrid mode)
const matchResult = this.tagMatcher.match(tags, available);
// Record retrieved prompts
if (sessionId) {
this.sessionGate.addRetrievedPrompts(
sessionId,
tags,
matchResult.fullContent.map((p) => p.name),
);
}
if (matchResult.fullContent.length === 0 && matchResult.indexOnly.length === 0) {
return {
jsonrpc: '2.0',
id: request.id,
result: {
content: [{ type: 'text', text: 'No new matching prompts found for the given keywords.' }],
},
};
}
const responseParts: string[] = [];
for (const p of matchResult.fullContent) {
responseParts.push(`--- ${p.name} (priority: ${p.priority}) ---\n${p.content}\n`);
}
if (matchResult.indexOnly.length > 0) {
responseParts.push('Additional matched prompts (too large to include, try more specific keywords):');
for (const p of matchResult.indexOnly) {
responseParts.push(` - ${p.name}: ${p.summary ?? 'No description'}`);
}
}
return {
jsonrpc: '2.0',
id: request.id,
result: {
content: [{ type: 'text', text: responseParts.join('\n') }],
},
};
} catch (err) {
return {
jsonrpc: '2.0',
id: request.id,
error: { code: -32603, message: `read_prompts failed: ${err instanceof Error ? err.message : String(err)}` },
};
}
}
/**
* Intercept handler: when a gated session tries to call a real tool,
* extract keywords from the tool call, auto-ungate, and prepend a briefing.
*/
private async handleGatedIntercept(
request: JsonRpcRequest,
context: RouteContext,
toolName: string,
toolArgs: Record<string, unknown>,
): Promise<JsonRpcResponse> {
const sessionId = context.sessionId!;
// Extract keywords from the tool call as a fallback
const tags = extractKeywordsFromToolCall(toolName, toolArgs);
try {
const promptIndex = await this.fetchPromptIndex();
const matchResult = this.tagMatcher!.match(tags, promptIndex);
// Ungate the session
this.sessionGate.ungate(sessionId, tags, matchResult);
// Build briefing from matched content
const briefingParts: string[] = [];
if (matchResult.fullContent.length > 0) {
const preamble = await this.getSystemPrompt(
'gate-intercept-preamble',
'The following project context was automatically retrieved based on your tool call.',
);
briefingParts.push(`--- ${preamble} ---\n`);
for (const p of matchResult.fullContent) {
briefingParts.push(`--- ${p.name} (priority: ${p.priority}) ---\n${p.content}\n`);
}
briefingParts.push('--- End of project context ---\n');
}
if (matchResult.remaining.length > 0 || matchResult.indexOnly.length > 0) {
briefingParts.push('Other prompts available (use read_prompts to retrieve):');
for (const p of [...matchResult.indexOnly, ...matchResult.remaining]) {
briefingParts.push(` - ${p.name}: ${p.summary ?? 'No description'}`);
}
briefingParts.push('');
}
// Now route the actual tool call
const response = await this.routeNamespacedCall(request, 'name', this.toolToServer);
const paginatedResponse = await this.maybePaginate(toolName, response);
// Prepend briefing to the response
if (briefingParts.length > 0 && paginatedResponse.result && !paginatedResponse.error) {
const result = paginatedResponse.result as { content?: Array<{ type: string; text: string }> };
const briefing = briefingParts.join('\n');
if (result.content && Array.isArray(result.content)) {
result.content.unshift({ type: 'text', text: briefing });
} else {
(paginatedResponse.result as Record<string, unknown>)['_briefing'] = briefing;
}
}
return paginatedResponse;
} catch {
// If prompt retrieval fails, just ungate and route normally
this.sessionGate.ungate(sessionId, tags, { fullContent: [], indexOnly: [], remaining: [] });
return this.routeNamespacedCall(request, 'name', this.toolToServer);
}
}
/**
* Fetch prompt index from mcpd with caching.
*/
private async fetchPromptIndex(): Promise<PromptIndexEntry[]> {
const now = Date.now();
if (this.cachedPromptIndex && (now - this.promptIndexFetchedAt) < this.PROMPT_INDEX_TTL_MS) {
return this.cachedPromptIndex;
}
if (!this.mcpdClient || !this.projectName) {
return [];
}
const index = await this.mcpdClient.get<Array<{
name: string;
priority: number;
summary: string | null;
chapters: string[] | null;
content?: string;
}>>(
`/api/v1/projects/${encodeURIComponent(this.projectName)}/prompts/visible`,
);
this.cachedPromptIndex = index.map((p) => ({
name: p.name,
priority: p.priority,
summary: p.summary,
chapters: p.chapters,
content: p.content ?? '',
}));
this.promptIndexFetchedAt = now;
return this.cachedPromptIndex;
}
/**
* Build instructions for gated projects: base instructions + gate message + prompt index.
*/
private async buildGatedInstructions(baseInstructions: string): Promise<string> {
const parts: string[] = [];
if (baseInstructions) {
parts.push(baseInstructions);
}
const gateInstructions = await this.getSystemPrompt(
'gate-instructions',
'IMPORTANT: This project uses a gated session. You must call begin_session with keywords describing your task before using any other tools. This will provide you with relevant project context, policies, and guidelines.',
);
parts.push(`\n${gateInstructions}`);
// Append compact prompt index so the LLM knows what's available
try {
const promptIndex = await this.fetchPromptIndex();
if (promptIndex.length > 0) {
// Cap at 50 entries; if over 50, show priority 7+ only
let displayIndex = promptIndex;
if (displayIndex.length > 50) {
displayIndex = displayIndex.filter((p) => p.priority >= 7);
}
// Sort by priority descending
displayIndex.sort((a, b) => b.priority - a.priority);
parts.push('\nAvailable project prompts:');
for (const p of displayIndex) {
const summary = p.summary ? `: ${p.summary}` : '';
parts.push(`- ${p.name} (priority ${p.priority})${summary}`);
}
parts.push(
'\nChoose your begin_session keywords based on which of these prompts seem relevant to your task.',
);
}
} catch {
// Prompt index is optional — don't fail initialization
}
return parts.join('\n');
}
/**
* Fetch a system prompt from mcpctl-system project, with caching and fallback.
*/
private async getSystemPrompt(name: string, fallback: string): Promise<string> {
const now = Date.now();
const cached = this.systemPromptCache.get(name);
if (cached && (now - cached.fetchedAt) < this.SYSTEM_PROMPT_TTL_MS) {
return cached.content;
}
if (!this.mcpdClient) return fallback;
try {
const visible = await this.mcpdClient.get<Array<{ name: string; content: string }>>(
'/api/v1/projects/mcpctl-system/prompts/visible',
);
// Cache all system prompts from the response
for (const p of visible) {
this.systemPromptCache.set(p.name, { content: p.content, fetchedAt: now });
}
const found = visible.find((p) => p.name === name);
return found?.content ?? fallback;
} catch {
return fallback;
}
}
// ── Session cleanup ──
cleanupSession(sessionId: string): void {
this.sessionGate.removeSession(sessionId);
}
getUpstreamNames(): string[] {
return [...this.upstreams.keys()];
}

View File

@@ -0,0 +1,133 @@
import type { McpdClient } from '../http/mcpd-client.js';
export interface LinkResolution {
content: string | null;
status: 'alive' | 'dead' | 'unknown';
error?: string;
}
interface CacheEntry {
resolution: LinkResolution;
expiresAt: number;
}
interface ParsedLink {
project: string;
server: string;
uri: string;
}
/**
* Resolves prompt links by fetching MCP resources from source projects via mcpd.
* Link format: project/server:resource-uri
*/
export class LinkResolver {
private cache = new Map<string, CacheEntry>();
constructor(
private readonly mcpdClient: McpdClient,
private readonly cacheTtlMs = 5 * 60 * 1000, // 5 minutes
) {}
/**
* Parse a link target string into its components.
* Format: project/server:resource-uri
*/
parseLink(linkTarget: string): ParsedLink {
const slashIdx = linkTarget.indexOf('/');
if (slashIdx < 1) throw new Error(`Invalid link format (missing project): ${linkTarget}`);
const project = linkTarget.slice(0, slashIdx);
const rest = linkTarget.slice(slashIdx + 1);
const colonIdx = rest.indexOf(':');
if (colonIdx < 1) throw new Error(`Invalid link format (missing server:uri): ${linkTarget}`);
const server = rest.slice(0, colonIdx);
const uri = rest.slice(colonIdx + 1);
if (!uri) throw new Error(`Invalid link format (empty uri): ${linkTarget}`);
return { project, server, uri };
}
/**
* Resolve a link target and return the fetched content + status.
* Results are cached with a configurable TTL.
*/
async resolve(linkTarget: string): Promise<LinkResolution> {
// Check cache first
const cached = this.cache.get(linkTarget);
if (cached && cached.expiresAt > Date.now()) {
return cached.resolution;
}
let resolution: LinkResolution;
try {
const { project, server, uri } = this.parseLink(linkTarget);
const content = await this.fetchResource(project, server, uri);
resolution = { content, status: 'alive' };
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
console.error(`[link-resolver] Dead link: ${linkTarget}${message}`);
resolution = { content: null, status: 'dead', error: message };
}
// Cache the result
this.cache.set(linkTarget, {
resolution,
expiresAt: Date.now() + this.cacheTtlMs,
});
return resolution;
}
/**
* Check link health without returning full content (uses cache if available).
*/
async checkHealth(linkTarget: string): Promise<'alive' | 'dead' | 'unknown'> {
const cached = this.cache.get(linkTarget);
if (cached && cached.expiresAt > Date.now()) {
return cached.resolution.status;
}
// Don't do a full resolve just for health — return unknown
return 'unknown';
}
/** Clear all cached resolutions. */
clearCache(): void {
this.cache.clear();
}
private async fetchResource(project: string, server: string, uri: string): Promise<string> {
// Step 1: Resolve server name → server ID from the project's servers
const servers = await this.mcpdClient.get<Array<{ id: string; name: string }>>(
`/api/v1/projects/${encodeURIComponent(project)}/servers`,
);
const target = servers.find((s) => s.name === server);
if (!target) {
throw new Error(`Server '${server}' not found in project '${project}'`);
}
// Step 2: Call resources/read via the MCP proxy
const proxyResponse = await this.mcpdClient.post<{
result?: { contents?: Array<{ text?: string; uri?: string }> };
error?: { code: number; message: string };
}>('/api/v1/mcp/proxy', {
serverId: target.id,
method: 'resources/read',
params: { uri },
});
if (proxyResponse.error) {
throw new Error(`MCP error: ${proxyResponse.error.message}`);
}
const contents = proxyResponse.result?.contents;
if (!contents || contents.length === 0) {
throw new Error(`No content returned for resource: ${uri}`);
}
// Concatenate all text contents
return contents.map((c) => c.text ?? '').join('\n');
}
}

View File

@@ -0,0 +1,241 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { LinkResolver } from '../src/services/link-resolver.js';
import type { McpdClient } from '../src/http/mcpd-client.js';
function mockClient(): McpdClient {
return {
get: vi.fn(),
post: vi.fn(),
put: vi.fn(),
delete: vi.fn(),
forward: vi.fn(),
withHeaders: vi.fn(),
} as unknown as McpdClient;
}
describe('LinkResolver', () => {
let client: McpdClient;
let resolver: LinkResolver;
beforeEach(() => {
client = mockClient();
resolver = new LinkResolver(client, 1000); // 1s TTL for tests
});
// ── parseLink ──
describe('parseLink', () => {
it('parses valid link target', () => {
const result = resolver.parseLink('my-project/docmost-mcp:docmost://pages/abc');
expect(result).toEqual({
project: 'my-project',
server: 'docmost-mcp',
uri: 'docmost://pages/abc',
});
});
it('parses link with complex URI', () => {
const result = resolver.parseLink('proj/srv:file:///path/to/resource');
expect(result).toEqual({
project: 'proj',
server: 'srv',
uri: 'file:///path/to/resource',
});
});
it('throws on missing project separator', () => {
expect(() => resolver.parseLink('noslash')).toThrow('missing project');
});
it('throws on missing server:uri separator', () => {
expect(() => resolver.parseLink('proj/nocolon')).toThrow('missing server:uri');
});
it('throws on empty uri', () => {
expect(() => resolver.parseLink('proj/srv:')).toThrow('empty uri');
});
it('throws when project is empty', () => {
expect(() => resolver.parseLink('/srv:uri')).toThrow('missing project');
});
it('throws when server is empty', () => {
expect(() => resolver.parseLink('proj/:uri')).toThrow('missing server:uri');
});
});
// ── resolve ──
describe('resolve', () => {
it('fetches resource content successfully', async () => {
vi.mocked(client.get).mockResolvedValue([
{ id: 'srv-id-1', name: 'docmost-mcp' },
]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [{ text: 'Hello from docmost', uri: 'docmost://pages/abc' }] },
});
const result = await resolver.resolve('my-project/docmost-mcp:docmost://pages/abc');
expect(result).toEqual({ content: 'Hello from docmost', status: 'alive' });
expect(client.get).toHaveBeenCalledWith('/api/v1/projects/my-project/servers');
expect(client.post).toHaveBeenCalledWith('/api/v1/mcp/proxy', {
serverId: 'srv-id-1',
method: 'resources/read',
params: { uri: 'docmost://pages/abc' },
});
});
it('returns dead status when server not found in project', async () => {
vi.mocked(client.get).mockResolvedValue([
{ id: 'srv-other', name: 'other-server' },
]);
const result = await resolver.resolve('proj/missing-srv:some://uri');
expect(result.status).toBe('dead');
expect(result.content).toBeNull();
expect(result.error).toContain("Server 'missing-srv' not found");
});
it('returns dead status when MCP proxy returns error', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
error: { code: -32601, message: 'Method not found' },
});
const result = await resolver.resolve('proj/srv:some://uri');
expect(result.status).toBe('dead');
expect(result.error).toContain('Method not found');
});
it('returns dead status when no content returned', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [] },
});
const result = await resolver.resolve('proj/srv:some://uri');
expect(result.status).toBe('dead');
expect(result.error).toContain('No content returned');
});
it('returns dead status on network error', async () => {
vi.mocked(client.get).mockRejectedValue(new Error('Connection refused'));
const result = await resolver.resolve('proj/srv:some://uri');
expect(result.status).toBe('dead');
expect(result.error).toContain('Connection refused');
});
it('concatenates multiple content entries', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: {
contents: [
{ text: 'Part 1', uri: 'uri1' },
{ text: 'Part 2', uri: 'uri2' },
],
},
});
const result = await resolver.resolve('proj/srv:some://uri');
expect(result.content).toBe('Part 1\nPart 2');
expect(result.status).toBe('alive');
});
it('logs dead link to console.error', async () => {
vi.mocked(client.get).mockRejectedValue(new Error('fail'));
const spy = vi.spyOn(console, 'error').mockImplementation(() => {});
await resolver.resolve('proj/srv:some://uri');
expect(spy).toHaveBeenCalledWith(expect.stringContaining('[link-resolver] Dead link'));
spy.mockRestore();
});
});
// ── caching ──
describe('caching', () => {
it('returns cached result on second call', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [{ text: 'cached content' }] },
});
const first = await resolver.resolve('proj/srv:some://uri');
const second = await resolver.resolve('proj/srv:some://uri');
expect(first).toEqual(second);
// Only one HTTP call — second was cached
expect(client.get).toHaveBeenCalledTimes(1);
});
it('refetches after cache expires', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [{ text: 'content' }] },
});
await resolver.resolve('proj/srv:some://uri');
// Advance time past TTL
vi.useFakeTimers();
vi.advanceTimersByTime(1500);
await resolver.resolve('proj/srv:some://uri');
expect(client.get).toHaveBeenCalledTimes(2);
vi.useRealTimers();
});
it('clearCache removes all entries', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [{ text: 'content' }] },
});
await resolver.resolve('proj/srv:some://uri');
resolver.clearCache();
await resolver.resolve('proj/srv:some://uri');
expect(client.get).toHaveBeenCalledTimes(2);
});
});
// ── checkHealth ──
describe('checkHealth', () => {
it('returns cached status if available', async () => {
vi.mocked(client.get).mockResolvedValue([{ id: 'srv-1', name: 'srv' }]);
vi.mocked(client.post).mockResolvedValue({
result: { contents: [{ text: 'content' }] },
});
await resolver.resolve('proj/srv:some://uri');
const health = await resolver.checkHealth('proj/srv:some://uri');
expect(health).toBe('alive');
});
it('returns unknown if not cached', async () => {
const health = await resolver.checkHealth('proj/srv:some://uri');
expect(health).toBe('unknown');
});
it('returns dead from cached dead link', async () => {
vi.mocked(client.get).mockRejectedValue(new Error('fail'));
vi.spyOn(console, 'error').mockImplementation(() => {});
await resolver.resolve('proj/srv:some://uri');
const health = await resolver.checkHealth('proj/srv:some://uri');
expect(health).toBe('dead');
});
});
});

View File

@@ -0,0 +1,166 @@
import { describe, it, expect, vi } from 'vitest';
import { LlmPromptSelector, type PromptIndexForLlm } from '../src/gate/llm-selector.js';
import { ProviderRegistry } from '../src/providers/registry.js';
import type { LlmProvider, CompletionOptions, CompletionResult } from '../src/providers/types.js';
function makeMockProvider(responseContent: string): LlmProvider {
return {
name: 'mock-heavy',
complete: vi.fn().mockResolvedValue({
content: responseContent,
toolCalls: [],
usage: { promptTokens: 100, completionTokens: 50, totalTokens: 150 },
finishReason: 'stop',
} satisfies CompletionResult),
listModels: vi.fn().mockResolvedValue(['mock-model']),
isAvailable: vi.fn().mockResolvedValue(true),
};
}
function makeRegistry(provider: LlmProvider): ProviderRegistry {
const registry = new ProviderRegistry();
registry.register(provider);
registry.assignTier(provider.name, 'heavy');
return registry;
}
const sampleIndex: PromptIndexForLlm[] = [
{ name: 'zigbee-pairing', priority: 7, summary: 'How to pair Zigbee devices', chapters: ['Setup', 'Troubleshooting'] },
{ name: 'mqtt-config', priority: 5, summary: 'MQTT broker configuration', chapters: null },
{ name: 'common-mistakes', priority: 10, summary: 'Critical safety rules', chapters: null },
];
describe('LlmPromptSelector', () => {
it('sends tags and index to heavy LLM and parses response', async () => {
const provider = makeMockProvider(
'```json\n{ "selectedNames": ["zigbee-pairing"], "reasoning": "User is working with zigbee" }\n```',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
const result = await selector.selectPrompts(['zigbee', 'pairing'], sampleIndex);
expect(result.selectedNames).toContain('zigbee-pairing');
expect(result.selectedNames).toContain('common-mistakes'); // Priority 10 always included
expect(result.reasoning).toBe('User is working with zigbee');
});
it('always includes priority 10 prompts even if LLM omits them', async () => {
const provider = makeMockProvider(
'{ "selectedNames": ["mqtt-config"], "reasoning": "MQTT related" }',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
const result = await selector.selectPrompts(['mqtt'], sampleIndex);
expect(result.selectedNames).toContain('mqtt-config');
expect(result.selectedNames).toContain('common-mistakes');
});
it('does not duplicate priority 10 if LLM already selected them', async () => {
const provider = makeMockProvider(
'{ "selectedNames": ["common-mistakes", "mqtt-config"], "reasoning": "Both needed" }',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
const result = await selector.selectPrompts(['mqtt'], sampleIndex);
const count = result.selectedNames.filter((n) => n === 'common-mistakes').length;
expect(count).toBe(1);
});
it('passes system and user messages to provider.complete', async () => {
const provider = makeMockProvider(
'{ "selectedNames": [], "reasoning": "none" }',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
await selector.selectPrompts(['test'], sampleIndex);
expect(provider.complete).toHaveBeenCalledOnce();
const call = (provider.complete as ReturnType<typeof vi.fn>).mock.calls[0]![0] as CompletionOptions;
expect(call.messages).toHaveLength(2);
expect(call.messages[0]!.role).toBe('system');
expect(call.messages[1]!.role).toBe('user');
expect(call.messages[1]!.content).toContain('test');
expect(call.temperature).toBe(0);
});
it('passes model override to complete options', async () => {
const provider = makeMockProvider(
'{ "selectedNames": [], "reasoning": "" }',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry, 'gemini-pro');
await selector.selectPrompts(['test'], sampleIndex);
const call = (provider.complete as ReturnType<typeof vi.fn>).mock.calls[0]![0] as CompletionOptions;
expect(call.model).toBe('gemini-pro');
});
it('throws when no heavy provider is available', async () => {
const registry = new ProviderRegistry(); // Empty registry
const selector = new LlmPromptSelector(registry);
await expect(selector.selectPrompts(['test'], sampleIndex)).rejects.toThrow(
'No heavy LLM provider available',
);
});
it('throws when LLM response has no valid JSON', async () => {
const provider = makeMockProvider('I cannot help with that request.');
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
await expect(selector.selectPrompts(['test'], sampleIndex)).rejects.toThrow(
'LLM response did not contain valid selection JSON',
);
});
it('handles response with empty selectedNames', async () => {
const provider = makeMockProvider('{ "selectedNames": [], "reasoning": "nothing matched" }');
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
// Empty selectedNames, but priority 10 should still be included
const result = await selector.selectPrompts(['test'], sampleIndex);
expect(result.selectedNames).toEqual(['common-mistakes']);
expect(result.reasoning).toBe('nothing matched');
});
it('handles response with reasoning missing', async () => {
const provider = makeMockProvider('{ "selectedNames": ["mqtt-config"] }');
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
const result = await selector.selectPrompts(['test'], sampleIndex);
expect(result.reasoning).toBe('');
expect(result.selectedNames).toContain('mqtt-config');
});
it('includes prompt details in the user prompt', async () => {
const indexWithNull: PromptIndexForLlm[] = [
...sampleIndex,
{ name: 'no-desc', priority: 3, summary: null, chapters: null },
];
const provider = makeMockProvider(
'{ "selectedNames": [], "reasoning": "" }',
);
const registry = makeRegistry(provider);
const selector = new LlmPromptSelector(registry);
await selector.selectPrompts(['zigbee'], indexWithNull);
const call = (provider.complete as ReturnType<typeof vi.fn>).mock.calls[0]![0] as CompletionOptions;
const userMsg = call.messages[1]!.content;
expect(userMsg).toContain('zigbee-pairing');
expect(userMsg).toContain('priority: 7');
expect(userMsg).toContain('How to pair Zigbee devices');
expect(userMsg).toContain('Setup, Troubleshooting');
expect(userMsg).toContain('No summary'); // For prompts with null summary
});
});

View File

@@ -0,0 +1,520 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { McpRouter } from '../src/router.js';
import type { UpstreamConnection, JsonRpcRequest, JsonRpcResponse, JsonRpcNotification } from '../src/types.js';
import type { McpdClient } from '../src/http/mcpd-client.js';
import { ProviderRegistry } from '../src/providers/registry.js';
import type { LlmProvider, CompletionResult } from '../src/providers/types.js';
function mockUpstream(
name: string,
opts: { tools?: Array<{ name: string; description?: string }> } = {},
): UpstreamConnection {
return {
name,
isAlive: vi.fn(() => true),
close: vi.fn(async () => {}),
onNotification: vi.fn(),
send: vi.fn(async (req: JsonRpcRequest): Promise<JsonRpcResponse> => {
if (req.method === 'tools/list') {
return { jsonrpc: '2.0', id: req.id, result: { tools: opts.tools ?? [] } };
}
if (req.method === 'tools/call') {
return {
jsonrpc: '2.0',
id: req.id,
result: { content: [{ type: 'text', text: `Called ${(req.params as Record<string, unknown>)?.name}` }] },
};
}
if (req.method === 'resources/list') {
return { jsonrpc: '2.0', id: req.id, result: { resources: [] } };
}
if (req.method === 'prompts/list') {
return { jsonrpc: '2.0', id: req.id, result: { prompts: [] } };
}
return { jsonrpc: '2.0', id: req.id, error: { code: -32601, message: 'Not found' } };
}),
} as UpstreamConnection;
}
function mockMcpdClient(prompts: Array<{ name: string; priority: number; summary: string | null; chapters: string[] | null; content: string; type?: string }> = []): McpdClient {
return {
get: vi.fn(async (path: string) => {
if (path.includes('/prompts/visible')) {
return prompts.map((p) => ({ ...p, type: p.type ?? 'prompt' }));
}
if (path.includes('/prompt-index')) {
return prompts.map((p) => ({
name: p.name,
priority: p.priority,
summary: p.summary,
chapters: p.chapters,
}));
}
return [];
}),
post: vi.fn(async () => ({})),
put: vi.fn(async () => ({})),
delete: vi.fn(async () => {}),
forward: vi.fn(async () => ({ status: 200, body: {} })),
withHeaders: vi.fn(function (this: McpdClient) { return this; }),
} as unknown as McpdClient;
}
const samplePrompts = [
{ name: 'common-mistakes', priority: 10, summary: 'Critical safety rules everyone must follow', chapters: null, content: 'NEVER do X. ALWAYS do Y.' },
{ name: 'zigbee-pairing', priority: 7, summary: 'How to pair Zigbee devices with the hub', chapters: ['Setup', 'Troubleshooting'], content: 'Step 1: Put device in pairing mode...' },
{ name: 'mqtt-config', priority: 5, summary: 'MQTT broker configuration guide', chapters: ['Broker Setup', 'Authentication'], content: 'Configure the MQTT broker at...' },
{ name: 'security-policy', priority: 8, summary: 'Security policies for production deployments', chapters: ['Network', 'Auth'], content: 'All connections must use TLS...' },
];
function setupGatedRouter(
opts: {
gated?: boolean;
prompts?: typeof samplePrompts;
withLlm?: boolean;
llmResponse?: string;
} = {},
): { router: McpRouter; mcpdClient: McpdClient } {
const router = new McpRouter();
const prompts = opts.prompts ?? samplePrompts;
const mcpdClient = mockMcpdClient(prompts);
router.setPromptConfig(mcpdClient, 'test-project');
let providerRegistry: ProviderRegistry | null = null;
if (opts.withLlm) {
providerRegistry = new ProviderRegistry();
const mockProvider: LlmProvider = {
name: 'mock-heavy',
complete: vi.fn().mockResolvedValue({
content: opts.llmResponse ?? '{ "selectedNames": ["zigbee-pairing"], "reasoning": "User is working with zigbee" }',
toolCalls: [],
usage: { promptTokens: 100, completionTokens: 50, totalTokens: 150 },
finishReason: 'stop',
} satisfies CompletionResult),
listModels: vi.fn().mockResolvedValue([]),
isAvailable: vi.fn().mockResolvedValue(true),
};
providerRegistry.register(mockProvider);
providerRegistry.assignTier(mockProvider.name, 'heavy');
}
router.setGateConfig({
gated: opts.gated !== false,
providerRegistry,
});
return { router, mcpdClient };
}
describe('McpRouter gating', () => {
describe('initialize with gating', () => {
it('creates gated session on initialize', async () => {
const { router } = setupGatedRouter();
const res = await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
expect(res.result).toBeDefined();
// The session should be gated now
const toolsRes = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (toolsRes.result as { tools: Array<{ name: string }> }).tools;
expect(tools).toHaveLength(1);
expect(tools[0]!.name).toBe('begin_session');
});
it('creates ungated session when project is not gated', async () => {
const { router } = setupGatedRouter({ gated: false });
router.addUpstream(mockUpstream('ha', { tools: [{ name: 'get_entities' }] }));
await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
const toolsRes = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (toolsRes.result as { tools: Array<{ name: string }> }).tools;
const names = tools.map((t) => t.name);
expect(names).toContain('ha/get_entities');
expect(names).toContain('read_prompts');
expect(names).not.toContain('begin_session');
});
});
describe('tools/list gating', () => {
it('shows only begin_session when session is gated', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (res.result as { tools: Array<{ name: string }> }).tools;
expect(tools).toHaveLength(1);
expect(tools[0]!.name).toBe('begin_session');
});
it('shows all tools plus read_prompts after ungating', async () => {
const { router } = setupGatedRouter();
router.addUpstream(mockUpstream('ha', { tools: [{ name: 'get_entities' }] }));
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// Ungate via begin_session
await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
const toolsRes = await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (toolsRes.result as { tools: Array<{ name: string }> }).tools;
const names = tools.map((t) => t.name);
expect(names).toContain('ha/get_entities');
expect(names).toContain('propose_prompt');
expect(names).toContain('read_prompts');
expect(names).not.toContain('begin_session');
});
});
describe('begin_session', () => {
it('returns matched prompts with keyword matching', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee', 'pairing'] } } },
{ sessionId: 's1' },
);
expect(res.error).toBeUndefined();
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
// Should include priority 10 prompt
expect(text).toContain('common-mistakes');
expect(text).toContain('NEVER do X');
// Should include zigbee-pairing (matches both tags)
expect(text).toContain('zigbee-pairing');
expect(text).toContain('pairing mode');
// Should include encouragement
expect(text).toContain('read_prompts');
});
it('includes priority 10 prompts even without matching tags', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['unrelated-keyword'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('common-mistakes');
expect(text).toContain('NEVER do X');
});
it('uses LLM selection when provider is available', async () => {
const { router } = setupGatedRouter({
withLlm: true,
llmResponse: '{ "selectedNames": ["zigbee-pairing", "security-policy"], "reasoning": "Zigbee pairing needs security awareness" }',
});
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('Zigbee pairing needs security awareness');
expect(text).toContain('zigbee-pairing');
expect(text).toContain('security-policy');
expect(text).toContain('common-mistakes'); // priority 10 always included
});
it('rejects empty tags', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: [] } } },
{ sessionId: 's1' },
);
expect(res.error).toBeDefined();
expect(res.error!.code).toBe(-32602);
});
it('returns message when session is already ungated', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// First call ungates
await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
// Second call tells user to use read_prompts
const res = await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['mqtt'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('already started');
expect(text).toContain('read_prompts');
});
it('lists remaining prompts for awareness', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
// Non-matching prompts should be listed as "other available prompts"
// security-policy doesn't match 'zigbee' in keyword mode
expect(text).toContain('security-policy');
});
});
describe('read_prompts', () => {
it('returns prompts matching keywords', async () => {
const { router } = setupGatedRouter({ gated: false });
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: ['mqtt', 'broker'] } } },
{ sessionId: 's1' },
);
expect(res.error).toBeUndefined();
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('mqtt-config');
expect(text).toContain('Configure the MQTT broker');
});
it('filters out already-sent prompts', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// begin_session sends common-mistakes (priority 10) and zigbee-pairing
await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'begin_session', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
// read_prompts for mqtt should not re-send common-mistakes
const res = await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: ['mqtt'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('mqtt-config');
// common-mistakes was already sent, should not appear again
expect(text).not.toContain('NEVER do X');
});
it('returns message when no new prompts match', async () => {
const { router } = setupGatedRouter({ prompts: [] });
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: ['nonexistent'] } } },
{ sessionId: 's1' },
);
const text = ((res.result as { content: Array<{ text: string }> }).content[0]!.text);
expect(text).toContain('No new matching prompts');
});
it('rejects empty tags', async () => {
const { router } = setupGatedRouter({ gated: false });
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: [] } } },
{ sessionId: 's1' },
);
expect(res.error).toBeDefined();
expect(res.error!.code).toBe(-32602);
});
});
describe('gated intercept', () => {
it('auto-ungates when gated session calls a real tool', async () => {
const { router } = setupGatedRouter();
const ha = mockUpstream('ha', { tools: [{ name: 'get_entities' }] });
router.addUpstream(ha);
await router.discoverTools();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// Call a real tool while gated — should intercept, extract keywords, and route
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'ha/get_entities', arguments: { domain: 'light' } } },
{ sessionId: 's1' },
);
// Response should include the tool result
expect(res.error).toBeUndefined();
const result = res.result as { content: Array<{ type: string; text: string }> };
// Should have briefing prepended
expect(result.content.length).toBeGreaterThanOrEqual(1);
// Session should now be ungated
const toolsRes = await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (toolsRes.result as { tools: Array<{ name: string }> }).tools;
expect(tools.map((t) => t.name)).toContain('ha/get_entities');
});
it('includes project context in intercepted response', async () => {
const { router } = setupGatedRouter();
const ha = mockUpstream('ha', { tools: [{ name: 'get_entities' }] });
router.addUpstream(ha);
await router.discoverTools();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
const res = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'ha/get_entities', arguments: { domain: 'light' } } },
{ sessionId: 's1' },
);
const result = res.result as { content: Array<{ type: string; text: string }> };
// First content block should be the briefing (priority 10 at minimum)
const briefing = result.content[0]!.text;
expect(briefing).toContain('common-mistakes');
expect(briefing).toContain('NEVER do X');
});
});
describe('initialize instructions for gated projects', () => {
it('includes gate message and prompt index in instructions', async () => {
const { router } = setupGatedRouter();
const res = await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
const result = res.result as { instructions?: string };
expect(result.instructions).toBeDefined();
expect(result.instructions).toContain('begin_session');
expect(result.instructions).toContain('gated session');
// Should list available prompts
expect(result.instructions).toContain('common-mistakes');
expect(result.instructions).toContain('zigbee-pairing');
});
it('does not include gate message for non-gated projects', async () => {
const { router } = setupGatedRouter({ gated: false });
router.setInstructions('Base project instructions');
const res = await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
const result = res.result as { instructions?: string };
expect(result.instructions).toBe('Base project instructions');
expect(result.instructions).not.toContain('gated session');
});
it('preserves base instructions and appends gate message', async () => {
const { router } = setupGatedRouter();
router.setInstructions('You are a helpful assistant.');
const res = await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
const result = res.result as { instructions?: string };
expect(result.instructions).toContain('You are a helpful assistant.');
expect(result.instructions).toContain('begin_session');
});
it('sorts prompt index by priority descending', async () => {
const { router } = setupGatedRouter();
const res = await router.route(
{ jsonrpc: '2.0', id: 1, method: 'initialize' },
{ sessionId: 's1' },
);
const result = res.result as { instructions: string };
const lines = result.instructions.split('\n');
// Find the prompt index lines
const promptLines = lines.filter((l) => l.startsWith('- ') && l.includes('priority'));
// priority 10 should come first
expect(promptLines[0]).toContain('common-mistakes');
expect(promptLines[0]).toContain('priority 10');
});
});
describe('session cleanup', () => {
it('cleanupSession removes gate state', async () => {
const { router } = setupGatedRouter();
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// Session is gated
let toolsRes = await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/list' },
{ sessionId: 's1' },
);
expect((toolsRes.result as { tools: Array<{ name: string }> }).tools[0]!.name).toBe('begin_session');
// Cleanup
router.cleanupSession('s1');
// After cleanup, session is treated as unknown (ungated)
toolsRes = await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/list' },
{ sessionId: 's1' },
);
const tools = (toolsRes.result as { tools: Array<{ name: string }> }).tools;
expect(tools.map((t) => t.name)).not.toContain('begin_session');
});
});
describe('prompt index caching', () => {
it('caches prompt index for 60 seconds', async () => {
const { router, mcpdClient } = setupGatedRouter({ gated: false });
await router.route({ jsonrpc: '2.0', id: 1, method: 'initialize' }, { sessionId: 's1' });
// First read_prompts call fetches from mcpd
await router.route(
{ jsonrpc: '2.0', id: 2, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: ['mqtt'] } } },
{ sessionId: 's1' },
);
// Second call should use cache
await router.route(
{ jsonrpc: '2.0', id: 3, method: 'tools/call', params: { name: 'read_prompts', arguments: { tags: ['zigbee'] } } },
{ sessionId: 's1' },
);
// mcpdClient.get should have been called only once for prompts/visible
const getCalls = vi.mocked(mcpdClient.get).mock.calls.filter((c) => (c[0] as string).includes('/prompts/visible'));
expect(getCalls).toHaveLength(1);
});
});
});

View File

@@ -165,16 +165,13 @@ describe('McpRouter - Prompt Integration', () => {
);
});
it('should read mcpctl resource content', async () => {
it('should read mcpctl resource content live from mcpd', async () => {
router.setPromptConfig(mcpdClient, 'proj');
vi.mocked(mcpdClient.get).mockResolvedValue([
{ name: 'my-prompt', content: 'The content here', type: 'prompt' },
]);
// First list to populate cache
await router.route({ jsonrpc: '2.0', id: 1, method: 'resources/list' });
// Then read
// Read directly — no need to list first
const response = await router.route({
jsonrpc: '2.0',
id: 2,
@@ -187,8 +184,55 @@ describe('McpRouter - Prompt Integration', () => {
expect(contents[0]!.text).toBe('The content here');
});
it('should return fresh content after prompt update', async () => {
router.setPromptConfig(mcpdClient, 'proj');
// First call returns old content
vi.mocked(mcpdClient.get).mockResolvedValueOnce([
{ name: 'my-prompt', content: 'Old content', type: 'prompt' },
]);
await router.route({
jsonrpc: '2.0', id: 1, method: 'resources/read',
params: { uri: 'mcpctl://prompts/my-prompt' },
});
// Second call returns updated content
vi.mocked(mcpdClient.get).mockResolvedValueOnce([
{ name: 'my-prompt', content: 'Updated content', type: 'prompt' },
]);
const response = await router.route({
jsonrpc: '2.0', id: 2, method: 'resources/read',
params: { uri: 'mcpctl://prompts/my-prompt' },
});
const contents = (response.result as { contents: Array<{ text: string }> }).contents;
expect(contents[0]!.text).toBe('Updated content');
});
it('should fall back to cache when mcpd is unreachable on read', async () => {
router.setPromptConfig(mcpdClient, 'proj');
// Populate cache via list
vi.mocked(mcpdClient.get).mockResolvedValueOnce([
{ name: 'cached-prompt', content: 'Cached content', type: 'prompt' },
]);
await router.route({ jsonrpc: '2.0', id: 1, method: 'resources/list' });
// mcpd goes down for read
vi.mocked(mcpdClient.get).mockRejectedValueOnce(new Error('Connection refused'));
const response = await router.route({
jsonrpc: '2.0', id: 2, method: 'resources/read',
params: { uri: 'mcpctl://prompts/cached-prompt' },
});
expect(response.error).toBeUndefined();
const contents = (response.result as { contents: Array<{ text: string }> }).contents;
expect(contents[0]!.text).toBe('Cached content');
});
it('should return error for unknown mcpctl resource', async () => {
router.setPromptConfig(mcpdClient, 'proj');
vi.mocked(mcpdClient.get).mockResolvedValue([]);
const response = await router.route({
jsonrpc: '2.0',

View File

@@ -0,0 +1,155 @@
import { describe, it, expect } from 'vitest';
import { SessionGate } from '../src/gate/session-gate.js';
import type { TagMatchResult, PromptIndexEntry } from '../src/gate/tag-matcher.js';
function makeMatchResult(names: string[]): TagMatchResult {
return {
fullContent: names.map((name) => ({
name,
priority: 5,
summary: null,
chapters: null,
content: `Content of ${name}`,
})),
indexOnly: [],
remaining: [],
};
}
describe('SessionGate', () => {
it('creates a gated session when project is gated', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
expect(gate.isGated('s1')).toBe(true);
});
it('creates an ungated session when project is not gated', () => {
const gate = new SessionGate();
gate.createSession('s1', false);
expect(gate.isGated('s1')).toBe(false);
});
it('unknown sessions are treated as ungated', () => {
const gate = new SessionGate();
expect(gate.isGated('nonexistent')).toBe(false);
});
it('getSession returns null for unknown sessions', () => {
const gate = new SessionGate();
expect(gate.getSession('nonexistent')).toBeNull();
});
it('getSession returns session state', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
const state = gate.getSession('s1');
expect(state).not.toBeNull();
expect(state!.gated).toBe(true);
expect(state!.tags).toEqual([]);
expect(state!.retrievedPrompts.size).toBe(0);
expect(state!.briefing).toBeNull();
});
it('ungate marks session as ungated and records tags', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
gate.ungate('s1', ['zigbee', 'mqtt'], makeMatchResult(['prompt-a', 'prompt-b']));
expect(gate.isGated('s1')).toBe(false);
const state = gate.getSession('s1');
expect(state!.tags).toEqual(['zigbee', 'mqtt']);
expect(state!.retrievedPrompts.has('prompt-a')).toBe(true);
expect(state!.retrievedPrompts.has('prompt-b')).toBe(true);
});
it('ungate appends tags on repeated calls', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
gate.ungate('s1', ['zigbee'], makeMatchResult(['p1']));
gate.ungate('s1', ['mqtt'], makeMatchResult(['p2']));
const state = gate.getSession('s1');
expect(state!.tags).toEqual(['zigbee', 'mqtt']);
expect(state!.retrievedPrompts.has('p1')).toBe(true);
expect(state!.retrievedPrompts.has('p2')).toBe(true);
});
it('ungate is no-op for unknown sessions', () => {
const gate = new SessionGate();
// Should not throw
gate.ungate('nonexistent', ['tag'], makeMatchResult(['p']));
});
it('addRetrievedPrompts records additional prompts', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
gate.ungate('s1', ['zigbee'], makeMatchResult(['p1']));
gate.addRetrievedPrompts('s1', ['mqtt', 'lights'], ['p2', 'p3']);
const state = gate.getSession('s1');
expect(state!.tags).toEqual(['zigbee', 'mqtt', 'lights']);
expect(state!.retrievedPrompts.has('p2')).toBe(true);
expect(state!.retrievedPrompts.has('p3')).toBe(true);
});
it('addRetrievedPrompts is no-op for unknown sessions', () => {
const gate = new SessionGate();
gate.addRetrievedPrompts('nonexistent', ['tag'], ['p']);
});
it('filterAlreadySent removes already-sent prompts', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
gate.ungate('s1', ['zigbee'], makeMatchResult(['p1']));
const prompts: PromptIndexEntry[] = [
{ name: 'p1', priority: 5, summary: 'already sent', chapters: null, content: 'x' },
{ name: 'p2', priority: 5, summary: 'new', chapters: null, content: 'y' },
];
const filtered = gate.filterAlreadySent('s1', prompts);
expect(filtered).toHaveLength(1);
expect(filtered[0]!.name).toBe('p2');
});
it('filterAlreadySent returns all prompts for unknown sessions', () => {
const gate = new SessionGate();
const prompts: PromptIndexEntry[] = [
{ name: 'p1', priority: 5, summary: null, chapters: null, content: 'x' },
];
const filtered = gate.filterAlreadySent('nonexistent', prompts);
expect(filtered).toHaveLength(1);
});
it('removeSession cleans up state', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
expect(gate.getSession('s1')).not.toBeNull();
gate.removeSession('s1');
expect(gate.getSession('s1')).toBeNull();
expect(gate.isGated('s1')).toBe(false);
});
it('removeSession is safe for unknown sessions', () => {
const gate = new SessionGate();
gate.removeSession('nonexistent'); // Should not throw
});
it('manages multiple sessions independently', () => {
const gate = new SessionGate();
gate.createSession('s1', true);
gate.createSession('s2', false);
expect(gate.isGated('s1')).toBe(true);
expect(gate.isGated('s2')).toBe(false);
gate.ungate('s1', ['zigbee'], makeMatchResult(['p1']));
expect(gate.isGated('s1')).toBe(false);
expect(gate.getSession('s2')!.tags).toEqual([]); // s2 untouched
});
});

View File

@@ -0,0 +1,165 @@
import { describe, it, expect } from 'vitest';
import { TagMatcher, extractKeywordsFromToolCall, type PromptIndexEntry } from '../src/gate/tag-matcher.js';
function makePrompt(overrides: Partial<PromptIndexEntry> = {}): PromptIndexEntry {
return {
name: 'test-prompt',
priority: 5,
summary: 'A test prompt for testing',
chapters: ['Chapter One', 'Chapter Two'],
content: 'Full content of the test prompt.',
...overrides,
};
}
describe('TagMatcher', () => {
it('returns priority 10 prompts regardless of tags', () => {
const matcher = new TagMatcher();
const critical = makePrompt({ name: 'common-mistakes', priority: 10, summary: 'Unrelated stuff' });
const normal = makePrompt({ name: 'normal', priority: 5, summary: 'Something else' });
const result = matcher.match([], [critical, normal]);
expect(result.fullContent.map((p) => p.name)).toEqual(['common-mistakes']);
expect(result.remaining.map((p) => p.name)).toEqual(['normal']);
});
it('scores by matching_tags * priority', () => {
const matcher = new TagMatcher();
const high = makePrompt({ name: 'important', priority: 8, summary: 'zigbee mqtt pairing' });
const low = makePrompt({ name: 'basic', priority: 3, summary: 'zigbee basics' });
// Both match "zigbee": high scores 1*8=8, low scores 1*3=3
const result = matcher.match(['zigbee'], [low, high]);
expect(result.fullContent[0]!.name).toBe('important');
expect(result.fullContent[1]!.name).toBe('basic');
});
it('matches more tags = higher score', () => {
const matcher = new TagMatcher();
const twoMatch = makePrompt({ name: 'two-match', priority: 5, summary: 'zigbee mqtt' });
const oneMatch = makePrompt({ name: 'one-match', priority: 5, summary: 'zigbee only' });
// two-match: 2*5=10, one-match: 1*5=5
const result = matcher.match(['zigbee', 'mqtt'], [oneMatch, twoMatch]);
expect(result.fullContent[0]!.name).toBe('two-match');
});
it('performs case-insensitive matching', () => {
const matcher = new TagMatcher();
const prompt = makePrompt({ name: 'test', summary: 'ZIGBEE Protocol Setup' });
const result = matcher.match(['zigbee'], [prompt]);
expect(result.fullContent).toHaveLength(1);
});
it('matches against name, summary, and chapters', () => {
const matcher = new TagMatcher();
const byName = makePrompt({ name: 'zigbee-config', summary: 'unrelated', chapters: [] });
const bySummary = makePrompt({ name: 'setup', summary: 'zigbee setup guide', chapters: [] });
const byChapter = makePrompt({ name: 'guide', summary: 'unrelated', chapters: ['Zigbee Pairing'] });
const result = matcher.match(['zigbee'], [byName, bySummary, byChapter]);
expect(result.fullContent).toHaveLength(3);
});
it('respects byte budget', () => {
const matcher = new TagMatcher(100); // Very small budget
const small = makePrompt({ name: 'small', summary: 'zigbee', content: 'Short.' }); // ~6 bytes
const big = makePrompt({ name: 'big', summary: 'zigbee', content: 'x'.repeat(200) }); // 200 bytes
const result = matcher.match(['zigbee'], [small, big]);
expect(result.fullContent.map((p) => p.name)).toEqual(['small']);
expect(result.indexOnly.map((p) => p.name)).toEqual(['big']);
});
it('puts non-matched prompts in remaining', () => {
const matcher = new TagMatcher();
const matched = makePrompt({ name: 'matched', summary: 'zigbee stuff' });
const unmatched = makePrompt({ name: 'unmatched', summary: 'completely different topic' });
const result = matcher.match(['zigbee'], [matched, unmatched]);
expect(result.fullContent.map((p) => p.name)).toEqual(['matched']);
expect(result.remaining.map((p) => p.name)).toEqual(['unmatched']);
});
it('handles empty tags — only priority 10 matched', () => {
const matcher = new TagMatcher();
const critical = makePrompt({ name: 'critical', priority: 10 });
const normal = makePrompt({ name: 'normal', priority: 5 });
const result = matcher.match([], [critical, normal]);
expect(result.fullContent.map((p) => p.name)).toEqual(['critical']);
expect(result.remaining.map((p) => p.name)).toEqual(['normal']);
});
it('handles empty prompts array', () => {
const matcher = new TagMatcher();
const result = matcher.match(['zigbee'], []);
expect(result.fullContent).toEqual([]);
expect(result.indexOnly).toEqual([]);
expect(result.remaining).toEqual([]);
});
it('all priority 10 prompts are included even beyond budget', () => {
const matcher = new TagMatcher(50); // Tiny budget
const c1 = makePrompt({ name: 'c1', priority: 10, content: 'x'.repeat(40) });
const c2 = makePrompt({ name: 'c2', priority: 10, content: 'y'.repeat(40) });
const result = matcher.match([], [c1, c2]);
// Both should be in fullContent — priority 10 has Infinity score
// First one fits budget, second overflows but still priority 10
expect(result.fullContent.length + result.indexOnly.length).toBe(2);
// At minimum the first one is in fullContent
expect(result.fullContent[0]!.name).toBe('c1');
});
it('sorts matched by score descending', () => {
const matcher = new TagMatcher();
const p1 = makePrompt({ name: 'p1', priority: 3, summary: 'mqtt zigbee lights' }); // 3 matches * 3 = 9
const p2 = makePrompt({ name: 'p2', priority: 8, summary: 'mqtt' }); // 1 match * 8 = 8
const p3 = makePrompt({ name: 'p3', priority: 2, summary: 'mqtt zigbee lights pairing automation' }); // 5 * 2 = 10
const result = matcher.match(['mqtt', 'zigbee', 'lights', 'pairing', 'automation'], [p1, p2, p3]);
expect(result.fullContent.map((p) => p.name)).toEqual(['p3', 'p1', 'p2']);
});
});
describe('extractKeywordsFromToolCall', () => {
it('extracts from tool name', () => {
const keywords = extractKeywordsFromToolCall('home-assistant/get_entities', {});
expect(keywords).toContain('home');
expect(keywords).toContain('assistant');
expect(keywords).toContain('get_entities');
});
it('extracts from string arguments', () => {
const keywords = extractKeywordsFromToolCall('tool', { domain: 'light', area: 'kitchen' });
expect(keywords).toContain('light');
expect(keywords).toContain('kitchen');
});
it('ignores short words (<=2 chars)', () => {
const keywords = extractKeywordsFromToolCall('ab', { x: 'hi' });
expect(keywords).not.toContain('ab');
expect(keywords).not.toContain('hi');
});
it('ignores long string values (>200 chars)', () => {
const keywords = extractKeywordsFromToolCall('tool', { data: 'x'.repeat(201) });
// Only 'tool' from the name
expect(keywords).toEqual(['tool']);
});
it('caps at 10 keywords', () => {
const args: Record<string, string> = {};
for (let i = 0; i < 20; i++) args[`key${i}`] = `keyword${i}value`;
const keywords = extractKeywordsFromToolCall('tool', args);
expect(keywords.length).toBeLessThanOrEqual(10);
});
it('lowercases all keywords', () => {
const keywords = extractKeywordsFromToolCall('MyTool', { name: 'MQTT' });
expect(keywords).toContain('mytool');
expect(keywords).toContain('mqtt');
});
});