diff --git a/src/mcpd/tests/mcp-server-flow.test.ts b/src/mcpd/tests/mcp-server-flow.test.ts new file mode 100644 index 0000000..0ae6d11 --- /dev/null +++ b/src/mcpd/tests/mcp-server-flow.test.ts @@ -0,0 +1,753 @@ +import { describe, it, expect, vi, beforeAll, afterAll, beforeEach } from 'vitest'; +import Fastify from 'fastify'; +import type { FastifyInstance } from 'fastify'; +import http from 'node:http'; +import { McpServerService } from '../src/services/mcp-server.service.js'; +import { InstanceService } from '../src/services/instance.service.js'; +import { McpProxyService } from '../src/services/mcp-proxy-service.js'; +import { AuditLogService } from '../src/services/audit-log.service.js'; +import { errorHandler } from '../src/middleware/error-handler.js'; +import { registerMcpServerRoutes } from '../src/routes/mcp-servers.js'; +import { registerInstanceRoutes } from '../src/routes/instances.js'; +import { registerMcpProxyRoutes } from '../src/routes/mcp-proxy.js'; +import type { + IMcpServerRepository, + IMcpInstanceRepository, + IAuditLogRepository, +} from '../src/repositories/interfaces.js'; +import type { McpOrchestrator } from '../src/services/orchestrator.js'; +import type { McpServer, McpInstance, InstanceStatus } from '@prisma/client'; + +// --------------------------------------------------------------------------- +// In-memory repository implementations (stateful mocks) +// --------------------------------------------------------------------------- + +function createInMemoryServerRepo(): IMcpServerRepository { + const servers = new Map(); + let nextId = 1; + + return { + findAll: vi.fn(async () => [...servers.values()]), + findById: vi.fn(async (id: string) => servers.get(id) ?? null), + findByName: vi.fn(async (name: string) => [...servers.values()].find((s) => s.name === name) ?? null), + create: vi.fn(async (data) => { + const id = `srv-${nextId++}`; + const server = { + id, + name: data.name, + description: data.description ?? '', + packageName: data.packageName ?? null, + dockerImage: data.dockerImage ?? null, + transport: data.transport ?? 'STDIO', + repositoryUrl: data.repositoryUrl ?? null, + externalUrl: data.externalUrl ?? null, + command: data.command ?? null, + containerPort: data.containerPort ?? null, + envTemplate: data.envTemplate ?? [], + version: 1, + createdAt: new Date(), + updatedAt: new Date(), + } as McpServer; + servers.set(id, server); + return server; + }), + update: vi.fn(async (id: string, data) => { + const existing = servers.get(id); + if (!existing) throw new Error(`Server ${id} not found`); + const updated = { ...existing, ...data, updatedAt: new Date() } as McpServer; + servers.set(id, updated); + return updated; + }), + delete: vi.fn(async (id: string) => { + servers.delete(id); + }), + }; +} + +function createInMemoryInstanceRepo(): IMcpInstanceRepository { + const instances = new Map(); + let nextId = 1; + + return { + findAll: vi.fn(async (serverId?: string) => { + const all = [...instances.values()]; + return serverId ? all.filter((i) => i.serverId === serverId) : all; + }), + findById: vi.fn(async (id: string) => instances.get(id) ?? null), + findByContainerId: vi.fn(async (containerId: string) => + [...instances.values()].find((i) => i.containerId === containerId) ?? null, + ), + create: vi.fn(async (data) => { + const id = `inst-${nextId++}`; + const instance = { + id, + serverId: data.serverId, + containerId: data.containerId ?? null, + status: (data.status ?? 'STOPPED') as InstanceStatus, + port: data.port ?? null, + metadata: data.metadata ?? {}, + version: 1, + createdAt: new Date(), + updatedAt: new Date(), + } as McpInstance; + instances.set(id, instance); + return instance; + }), + updateStatus: vi.fn(async (id: string, status: InstanceStatus, fields?) => { + const existing = instances.get(id); + if (!existing) throw new Error(`Instance ${id} not found`); + const updated = { + ...existing, + status, + ...(fields?.containerId !== undefined ? { containerId: fields.containerId } : {}), + ...(fields?.port !== undefined ? { port: fields.port } : {}), + ...(fields?.metadata !== undefined ? { metadata: fields.metadata } : {}), + version: existing.version + 1, + updatedAt: new Date(), + } as McpInstance; + instances.set(id, updated); + return updated; + }), + delete: vi.fn(async (id: string) => { + instances.delete(id); + }), + }; +} + +function createInMemoryAuditLogRepo(): IAuditLogRepository { + const logs: Array<{ id: string; userId: string; action: string; resource: string; resourceId: string | null; details: Record; createdAt: Date }> = []; + let nextId = 1; + + return { + findAll: vi.fn(async () => logs as never[]), + findById: vi.fn(async (id: string) => (logs.find((l) => l.id === id) as never) ?? null), + create: vi.fn(async (data) => { + const log = { + id: `log-${nextId++}`, + userId: data.userId, + action: data.action, + resource: data.resource, + resourceId: data.resourceId ?? null, + details: data.details ?? {}, + createdAt: new Date(), + }; + logs.push(log); + return log as never; + }), + count: vi.fn(async () => logs.length), + deleteOlderThan: vi.fn(async () => 0), + }; +} + +function createMockOrchestrator(): McpOrchestrator { + let containerPort = 40000; + return { + ping: vi.fn(async () => true), + pullImage: vi.fn(async () => {}), + createContainer: vi.fn(async (spec) => ({ + containerId: `ctr-${spec.name}`, + name: spec.name, + state: 'running' as const, + port: spec.containerPort ?? ++containerPort, + createdAt: new Date(), + })), + stopContainer: vi.fn(async () => {}), + removeContainer: vi.fn(async () => {}), + inspectContainer: vi.fn(async (id) => ({ + containerId: id, + name: 'test', + state: 'running' as const, + createdAt: new Date(), + })), + getContainerLogs: vi.fn(async () => ({ stdout: '', stderr: '' })), + }; +} + +// --------------------------------------------------------------------------- +// Fake MCP server (streamable-http) +// --------------------------------------------------------------------------- + +function createFakeMcpServer(): { server: http.Server; getPort: () => number; requests: Array<{ method: string; body: unknown }> } { + const requests: Array<{ method: string; body: unknown }> = []; + let sessionCounter = 0; + + const server = http.createServer((req, res) => { + let body = ''; + req.on('data', (chunk) => (body += chunk)); + req.on('end', () => { + let parsed: { method?: string; id?: number; params?: unknown } = {}; + try { + parsed = JSON.parse(body); + } catch { + // notifications may not have id + } + + requests.push({ method: parsed.method ?? 'unknown', body: parsed }); + + if (parsed.method === 'initialize') { + const sessionId = `session-${++sessionCounter}`; + const response = { + jsonrpc: '2.0', + id: parsed.id, + result: { + protocolVersion: '2025-03-26', + capabilities: { tools: {} }, + serverInfo: { name: 'fake-mcp', version: '1.0.0' }, + }, + }; + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Mcp-Session-Id': sessionId, + }); + res.end(`event: message\ndata: ${JSON.stringify(response)}\n\n`); + return; + } + + if (parsed.method === 'notifications/initialized') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(''); + return; + } + + if (parsed.method === 'tools/list') { + const response = { + jsonrpc: '2.0', + id: parsed.id, + result: { + tools: [ + { name: 'ha_get_overview', description: 'Get Home Assistant overview', inputSchema: { type: 'object', properties: {} } }, + { name: 'ha_search_entities', description: 'Search HA entities', inputSchema: { type: 'object', properties: { query: { type: 'string' } } } }, + ], + }, + }; + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.end(`event: message\ndata: ${JSON.stringify(response)}\n\n`); + return; + } + + if (parsed.method === 'tools/call') { + const toolName = (parsed.params as { name?: string })?.name; + const response = { + jsonrpc: '2.0', + id: parsed.id, + result: { + content: [{ type: 'text', text: `Result from ${toolName}` }], + }, + }; + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.end(`event: message\ndata: ${JSON.stringify(response)}\n\n`); + return; + } + + // Default: echo back + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ jsonrpc: '2.0', id: parsed.id, result: {} })); + }); + }); + + let port = 0; + return { + server, + getPort: () => port, + requests, + ...{ + listen: () => + new Promise((resolve) => { + server.listen(0, () => { + const addr = server.address(); + if (addr && typeof addr === 'object') port = addr.port; + resolve(); + }); + }), + close: () => new Promise((resolve) => server.close(() => resolve())), + }, + } as ReturnType & { listen: () => Promise; close: () => Promise }; +} + +// --------------------------------------------------------------------------- +// Test app builder +// --------------------------------------------------------------------------- + +async function buildTestApp(deps: { + serverRepo: IMcpServerRepository; + instanceRepo: IMcpInstanceRepository; + auditLogRepo: IAuditLogRepository; + orchestrator: McpOrchestrator; +}): Promise { + const app = Fastify({ logger: false }); + app.setErrorHandler(errorHandler); + + const serverService = new McpServerService(deps.serverRepo); + const instanceService = new InstanceService(deps.instanceRepo, deps.serverRepo, deps.orchestrator); + const proxyService = new McpProxyService(deps.instanceRepo, deps.serverRepo); + const auditLogService = new AuditLogService(deps.auditLogRepo); + + registerMcpServerRoutes(app, serverService); + registerInstanceRoutes(app, instanceService); + registerMcpProxyRoutes(app, { + mcpProxyService: proxyService, + auditLogService, + authDeps: { + findSession: async () => ({ userId: 'test-user', expiresAt: new Date(Date.now() + 3600_000) }), + }, + }); + + await app.ready(); + return app; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe('MCP server full flow', () => { + let fakeMcp: ReturnType & { listen: () => Promise; close: () => Promise }; + let fakeMcpPort: number; + + beforeAll(async () => { + fakeMcp = createFakeMcpServer() as typeof fakeMcp; + await fakeMcp.listen(); + fakeMcpPort = fakeMcp.getPort(); + }); + + afterAll(async () => { + await fakeMcp.close(); + }); + + describe('external server flow (externalUrl)', () => { + let app: FastifyInstance; + let serverRepo: IMcpServerRepository; + let instanceRepo: IMcpInstanceRepository; + + beforeEach(async () => { + serverRepo = createInMemoryServerRepo(); + instanceRepo = createInMemoryInstanceRepo(); + app = await buildTestApp({ + serverRepo, + instanceRepo, + auditLogRepo: createInMemoryAuditLogRepo(), + orchestrator: createMockOrchestrator(), + }); + }); + + afterAll(async () => { + if (app) await app.close(); + }); + + it('registers server, starts virtual instance, and proxies tools/list', async () => { + // 1. Register external MCP server + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'ha-mcp', + description: 'Home Assistant MCP', + transport: 'STREAMABLE_HTTP', + externalUrl: `http://localhost:${fakeMcpPort}`, + containerPort: 3000, + envTemplate: [ + { name: 'HOMEASSISTANT_TOKEN', description: 'HA token', isSecret: true }, + ], + }, + }); + + expect(createRes.statusCode).toBe(201); + const server = createRes.json<{ id: string; name: string; externalUrl: string }>(); + expect(server.name).toBe('ha-mcp'); + expect(server.externalUrl).toBe(`http://localhost:${fakeMcpPort}`); + + // 2. Verify server is listed + const listRes = await app.inject({ method: 'GET', url: '/api/v1/servers' }); + expect(listRes.statusCode).toBe(200); + const servers = listRes.json>(); + expect(servers).toHaveLength(1); + expect(servers[0]!.name).toBe('ha-mcp'); + + // 3. Start a virtual instance (external server — no Docker) + const startRes = await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { serverId: server.id }, + }); + + expect(startRes.statusCode).toBe(201); + const instance = startRes.json<{ id: string; status: string; containerId: string | null }>(); + expect(instance.status).toBe('RUNNING'); + expect(instance.containerId).toBeNull(); + + // 4. Proxy tools/list to the fake MCP server + const proxyRes = await app.inject({ + method: 'POST', + url: '/api/v1/mcp/proxy', + headers: { authorization: 'Bearer test-token' }, + payload: { + serverId: server.id, + method: 'tools/list', + }, + }); + + expect(proxyRes.statusCode).toBe(200); + const proxyBody = proxyRes.json<{ jsonrpc: string; result: { tools: Array<{ name: string }> } }>(); + expect(proxyBody.jsonrpc).toBe('2.0'); + expect(proxyBody.result.tools).toHaveLength(2); + expect(proxyBody.result.tools.map((t) => t.name)).toContain('ha_get_overview'); + expect(proxyBody.result.tools.map((t) => t.name)).toContain('ha_search_entities'); + + // 5. Verify the fake server received the protocol handshake + tools/list + const methods = fakeMcp.requests.map((r) => r.method); + expect(methods).toContain('initialize'); + expect(methods).toContain('notifications/initialized'); + expect(methods).toContain('tools/list'); + }); + + it('proxies tools/call with parameters', async () => { + // Register + start + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'ha-mcp-call', + description: 'HA MCP for call test', + transport: 'STREAMABLE_HTTP', + externalUrl: `http://localhost:${fakeMcpPort}`, + }, + }); + const server = createRes.json<{ id: string }>(); + + await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { serverId: server.id }, + }); + + // Proxy tools/call + const proxyRes = await app.inject({ + method: 'POST', + url: '/api/v1/mcp/proxy', + headers: { authorization: 'Bearer test-token' }, + payload: { + serverId: server.id, + method: 'tools/call', + params: { name: 'ha_get_overview' }, + }, + }); + + expect(proxyRes.statusCode).toBe(200); + const body = proxyRes.json<{ result: { content: Array<{ text: string }> } }>(); + expect(body.result.content[0]!.text).toBe('Result from ha_get_overview'); + }); + }); + + describe('managed server flow (Docker)', () => { + let app: FastifyInstance; + let orchestrator: ReturnType; + + beforeEach(async () => { + orchestrator = createMockOrchestrator(); + app = await buildTestApp({ + serverRepo: createInMemoryServerRepo(), + instanceRepo: createInMemoryInstanceRepo(), + auditLogRepo: createInMemoryAuditLogRepo(), + orchestrator, + }); + }); + + afterAll(async () => { + if (app) await app.close(); + }); + + it('registers server with dockerImage, starts container, and creates instance', async () => { + // 1. Register managed server + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'ha-mcp-docker', + description: 'HA MCP managed by Docker', + dockerImage: 'ghcr.io/homeassistant-ai/ha-mcp:2.4', + transport: 'STREAMABLE_HTTP', + containerPort: 3000, + command: ['python', '-c', 'print("hello")'], + envTemplate: [ + { name: 'HOMEASSISTANT_URL', description: 'HA URL' }, + { name: 'HOMEASSISTANT_TOKEN', description: 'HA token', isSecret: true }, + ], + }, + }); + + expect(createRes.statusCode).toBe(201); + const server = createRes.json<{ id: string; name: string; dockerImage: string; command: string[] }>(); + expect(server.name).toBe('ha-mcp-docker'); + expect(server.dockerImage).toBe('ghcr.io/homeassistant-ai/ha-mcp:2.4'); + expect(server.command).toEqual(['python', '-c', 'print("hello")']); + + // 2. Start container instance with env + const startRes = await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { + serverId: server.id, + env: { HOMEASSISTANT_URL: 'https://ha.example.com', HOMEASSISTANT_TOKEN: 'secret' }, + }, + }); + + expect(startRes.statusCode).toBe(201); + const instance = startRes.json<{ id: string; status: string; containerId: string }>(); + expect(instance.status).toBe('RUNNING'); + expect(instance.containerId).toBeTruthy(); + + // 3. Verify orchestrator was called with correct spec + expect(orchestrator.createContainer).toHaveBeenCalledTimes(1); + const spec = vi.mocked(orchestrator.createContainer).mock.calls[0]![0]; + expect(spec.image).toBe('ghcr.io/homeassistant-ai/ha-mcp:2.4'); + expect(spec.containerPort).toBe(3000); + expect(spec.command).toEqual(['python', '-c', 'print("hello")']); + expect(spec.env).toEqual({ + HOMEASSISTANT_URL: 'https://ha.example.com', + HOMEASSISTANT_TOKEN: 'secret', + }); + }); + + it('marks instance as ERROR when Docker fails', async () => { + vi.mocked(orchestrator.createContainer).mockRejectedValueOnce(new Error('Docker socket unavailable')); + + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'failing-server', + description: 'Will fail to start', + dockerImage: 'some-image:latest', + transport: 'STDIO', + }, + }); + const server = createRes.json<{ id: string }>(); + + const startRes = await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { serverId: server.id }, + }); + + expect(startRes.statusCode).toBe(201); + const instance = startRes.json<{ id: string; status: string }>(); + expect(instance.status).toBe('ERROR'); + }); + }); + + describe('full lifecycle', () => { + let app: FastifyInstance; + let orchestrator: ReturnType; + + beforeEach(async () => { + orchestrator = createMockOrchestrator(); + app = await buildTestApp({ + serverRepo: createInMemoryServerRepo(), + instanceRepo: createInMemoryInstanceRepo(), + auditLogRepo: createInMemoryAuditLogRepo(), + orchestrator, + }); + }); + + afterAll(async () => { + if (app) await app.close(); + }); + + it('register → start → list → stop → remove', async () => { + // Register + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'lifecycle-test', + description: 'Full lifecycle', + dockerImage: 'test:latest', + transport: 'SSE', + containerPort: 8080, + }, + }); + expect(createRes.statusCode).toBe(201); + const server = createRes.json<{ id: string }>(); + + // Start + const startRes = await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { serverId: server.id }, + }); + expect(startRes.statusCode).toBe(201); + const instance = startRes.json<{ id: string; status: string }>(); + expect(instance.status).toBe('RUNNING'); + + // List instances + const listRes = await app.inject({ + method: 'GET', + url: `/api/v1/instances?serverId=${server.id}`, + }); + expect(listRes.statusCode).toBe(200); + const instances = listRes.json>(); + expect(instances).toHaveLength(1); + + // Stop + const stopRes = await app.inject({ + method: 'POST', + url: `/api/v1/instances/${instance.id}/stop`, + }); + expect(stopRes.statusCode).toBe(200); + expect(stopRes.json<{ status: string }>().status).toBe('STOPPED'); + + // Remove + const removeRes = await app.inject({ + method: 'DELETE', + url: `/api/v1/instances/${instance.id}`, + }); + expect(removeRes.statusCode).toBe(204); + + // Verify instance is gone + const listAfter = await app.inject({ + method: 'GET', + url: `/api/v1/instances?serverId=${server.id}`, + }); + expect(listAfter.json()).toHaveLength(0); + + // Delete server + const deleteRes = await app.inject({ + method: 'DELETE', + url: `/api/v1/servers/${server.id}`, + }); + expect(deleteRes.statusCode).toBe(204); + + // Verify server is gone + const serversAfter = await app.inject({ method: 'GET', url: '/api/v1/servers' }); + expect(serversAfter.json()).toHaveLength(0); + }); + + it('external server lifecycle: register → start → proxy → stop → cleanup', async () => { + // Register external + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'external-lifecycle', + transport: 'STREAMABLE_HTTP', + externalUrl: `http://localhost:${fakeMcpPort}`, + }, + }); + const server = createRes.json<{ id: string }>(); + + // Start (virtual instance) + const startRes = await app.inject({ + method: 'POST', + url: '/api/v1/instances', + payload: { serverId: server.id }, + }); + const instance = startRes.json<{ id: string; status: string; containerId: string | null }>(); + expect(instance.status).toBe('RUNNING'); + expect(instance.containerId).toBeNull(); + + // Proxy tools/list + const proxyRes = await app.inject({ + method: 'POST', + url: '/api/v1/mcp/proxy', + headers: { authorization: 'Bearer test-token' }, + payload: { serverId: server.id, method: 'tools/list' }, + }); + expect(proxyRes.statusCode).toBe(200); + expect(proxyRes.json<{ result: { tools: unknown[] } }>().result.tools.length).toBeGreaterThan(0); + + // Stop (no container to stop) + const stopRes = await app.inject({ + method: 'POST', + url: `/api/v1/instances/${instance.id}/stop`, + }); + expect(stopRes.statusCode).toBe(200); + expect(stopRes.json<{ status: string }>().status).toBe('STOPPED'); + + // Docker orchestrator should NOT have been called + expect(orchestrator.createContainer).not.toHaveBeenCalled(); + expect(orchestrator.stopContainer).not.toHaveBeenCalled(); + }); + }); + + describe('proxy authentication', () => { + let app: FastifyInstance; + + beforeEach(async () => { + app = await buildTestApp({ + serverRepo: createInMemoryServerRepo(), + instanceRepo: createInMemoryInstanceRepo(), + auditLogRepo: createInMemoryAuditLogRepo(), + orchestrator: createMockOrchestrator(), + }); + }); + + afterAll(async () => { + if (app) await app.close(); + }); + + it('rejects proxy calls without auth header', async () => { + const res = await app.inject({ + method: 'POST', + url: '/api/v1/mcp/proxy', + payload: { serverId: 'srv-1', method: 'tools/list' }, + }); + // Auth middleware rejects with 401 (no Bearer token) + expect(res.statusCode).toBe(401); + }); + }); + + describe('server update flow', () => { + let app: FastifyInstance; + + beforeEach(async () => { + app = await buildTestApp({ + serverRepo: createInMemoryServerRepo(), + instanceRepo: createInMemoryInstanceRepo(), + auditLogRepo: createInMemoryAuditLogRepo(), + orchestrator: createMockOrchestrator(), + }); + }); + + afterAll(async () => { + if (app) await app.close(); + }); + + it('creates and updates server fields', async () => { + // Create + const createRes = await app.inject({ + method: 'POST', + url: '/api/v1/servers', + payload: { + name: 'updatable', + description: 'Original desc', + transport: 'STDIO', + }, + }); + const server = createRes.json<{ id: string; description: string }>(); + expect(server.description).toBe('Original desc'); + + // Update + const updateRes = await app.inject({ + method: 'PUT', + url: `/api/v1/servers/${server.id}`, + payload: { + description: 'Updated desc', + externalUrl: `http://localhost:${fakeMcpPort}`, + transport: 'STREAMABLE_HTTP', + }, + }); + expect(updateRes.statusCode).toBe(200); + const updated = updateRes.json<{ description: string; externalUrl: string; transport: string }>(); + expect(updated.description).toBe('Updated desc'); + expect(updated.externalUrl).toBe(`http://localhost:${fakeMcpPort}`); + expect(updated.transport).toBe('STREAMABLE_HTTP'); + + // Fetch to verify persistence + const getRes = await app.inject({ + method: 'GET', + url: `/api/v1/servers/${server.id}`, + }); + expect(getRes.json<{ description: string }>().description).toBe('Updated desc'); + }); + }); +});