From 4d8ee23d0e20ff08f2c9ed54afd639c6087fa128 Mon Sep 17 00:00:00 2001 From: Michal Date: Sun, 19 Apr 2026 13:05:43 +0100 Subject: [PATCH] feat(mcplocal): RBAC-bounded vllm-managed failover + name-based llm lookup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Why: when mcpd's inference proxy is unreachable, clients with a local vllm-managed provider should be able to substitute — but only if they still have view permission on the centralized Llm. Otherwise revoking an Llm wouldn't actually stop a misbehaving client. Infrastructure (the agent + mcplocal HTTP-mode wire-up will land separately when those clients pivot to mcpd's proxy): - LlmProviderFileEntry gains optional `failoverFor: `. The entry is otherwise the same local provider it always was; the new field just declares which central Llm it can substitute for. - ProviderRegistry tracks a failover map (registerFailover / getFailoverFor / listFailovers). Unregister removes any failover entry pointing at the removed provider so we don't end up with dangling references. - New FailoverRouter wraps a primary inference call. On primary failure: if a local provider is registered for the Llm, HEAD-probe `mcpd /api/v1/llms/ :name` with the caller's bearer to verify view permission, then either invoke the local provider (allowed) or re-throw the primary error (403, 401, network unreachable, anything else — all fail-closed). - Server: GET /api/v1/llms/:idOrName accepts both CUID and human name. Lets FailoverRouter probe by name without a separate id-resolution call. HEAD derives automatically from GET in Fastify, which runs the same RBAC hook and drops the body — exactly what the probe needs. Tests: 11 failover unit tests (registry map, decision flow, fail-closed for forbidden + unreachable, checkAuth status mapping) + 4 new route tests (name lookup, HEAD existing/missing). Full suite 1844/1844 (+14 from Phase 2's 1830). TypeScript clean across mcpd + mcplocal. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/mcpd/src/routes/llms.ts | 23 ++- src/mcpd/tests/llm-routes.test.ts | 19 ++ src/mcplocal/src/http/config.ts | 8 + src/mcplocal/src/llm-config.ts | 3 + src/mcplocal/src/providers/failover-router.ts | 107 +++++++++++ src/mcplocal/src/providers/registry.ts | 26 +++ src/mcplocal/tests/failover-router.test.ts | 170 ++++++++++++++++++ 7 files changed, 355 insertions(+), 1 deletion(-) create mode 100644 src/mcplocal/src/providers/failover-router.ts create mode 100644 src/mcplocal/tests/failover-router.test.ts diff --git a/src/mcpd/src/routes/llms.ts b/src/mcpd/src/routes/llms.ts index 58c26c9..3e0bf79 100644 --- a/src/mcpd/src/routes/llms.ts +++ b/src/mcpd/src/routes/llms.ts @@ -10,9 +10,12 @@ export function registerLlmRoutes( return service.list(); }); + // Accepts either CUID or human name. Used both by the CLI (which usually + // resolves to CUID first) and by FailoverRouter's RBAC pre-check (which + // hands over the user-facing name to avoid an extra round-trip). app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => { try { - return await service.getById(request.params.id); + return await getByIdOrName(service, request.params.id); } catch (err) { if (err instanceof NotFoundError) { reply.code(404); @@ -22,6 +25,10 @@ export function registerLlmRoutes( } }); + // No explicit HEAD handler: Fastify auto-derives HEAD from GET, which runs + // the same RBAC hook + lookup and drops the body. That's exactly what + // FailoverRouter wants for its "can the caller still view this Llm?" probe. + app.post('/api/v1/llms', async (request, reply) => { try { const row = await service.create(request.body); @@ -62,3 +69,17 @@ export function registerLlmRoutes( } }); } + +const CUID_RE = /^c[a-z0-9]{24}/i; + +/** + * Look up by CUID first; if the input doesn't look like one, fall back to + * findByName. Lets the same URL serve both `mcpctl describe llm ` and + * the FailoverRouter's name-based RBAC check. + */ +async function getByIdOrName(service: LlmService, idOrName: string) { + if (CUID_RE.test(idOrName)) { + return service.getById(idOrName); + } + return service.getByName(idOrName); +} diff --git a/src/mcpd/tests/llm-routes.test.ts b/src/mcpd/tests/llm-routes.test.ts index 2d06fd7..0a7ef6c 100644 --- a/src/mcpd/tests/llm-routes.test.ts +++ b/src/mcpd/tests/llm-routes.test.ts @@ -104,6 +104,25 @@ describe('Llm Routes', () => { expect(res.statusCode).toBe(404); }); + it('GET /api/v1/llms/:nameOrId resolves by human name when not a CUID', async () => { + await createApp(mockRepo([makeLlm({ id: 'llm-1', name: 'claude' })])); + const res = await app.inject({ method: 'GET', url: '/api/v1/llms/claude' }); + expect(res.statusCode).toBe(200); + expect(res.json<{ name: string; id: string }>().name).toBe('claude'); + }); + + it('HEAD /api/v1/llms/:name returns 200 for an existing Llm (failover RBAC pre-check)', async () => { + await createApp(mockRepo([makeLlm({ name: 'claude' })])); + const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/claude' }); + expect(res.statusCode).toBe(200); + }); + + it('HEAD /api/v1/llms/:name returns 404 for a missing Llm', async () => { + await createApp(mockRepo()); + const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/missing' }); + expect(res.statusCode).toBe(404); + }); + it('POST /api/v1/llms creates and returns 201', async () => { await createApp(mockRepo()); const res = await app.inject({ diff --git a/src/mcplocal/src/http/config.ts b/src/mcplocal/src/http/config.ts index 6d201d0..9be7fb4 100644 --- a/src/mcplocal/src/http/config.ts +++ b/src/mcplocal/src/http/config.ts @@ -64,6 +64,14 @@ export interface LlmProviderFileEntry { idleTimeoutMinutes?: number; /** vllm-managed: extra args for `vllm serve` */ extraArgs?: string[]; + /** + * If set, this local provider is allowed to substitute for the centralized + * Llm of this name when the mcpd inference proxy is unreachable. + * RBAC is still enforced — the caller must have view permission on the + * named Llm via mcpd before failover is permitted (fail-closed if mcpd + * itself can't be reached). + */ + failoverFor?: string; } export interface ProjectLlmOverride { diff --git a/src/mcplocal/src/llm-config.ts b/src/mcplocal/src/llm-config.ts index 3cb04c9..5777249 100644 --- a/src/mcplocal/src/llm-config.ts +++ b/src/mcplocal/src/llm-config.ts @@ -173,6 +173,9 @@ export async function createProvidersFromConfig( if (entry.tier) { registry.assignTier(provider.name, entry.tier); } + if (entry.failoverFor) { + registry.registerFailover(entry.failoverFor, provider.name); + } } return registry; diff --git a/src/mcplocal/src/providers/failover-router.ts b/src/mcplocal/src/providers/failover-router.ts new file mode 100644 index 0000000..4358b7a --- /dev/null +++ b/src/mcplocal/src/providers/failover-router.ts @@ -0,0 +1,107 @@ +/** + * FailoverRouter — orchestrates "try mcpd's centralized Llm, fall back to a + * local provider when authorized" for clients that consume the inference + * proxy. + * + * Decision flow on a centralized inference call: + * + * 1. Call the primary (the supplied `primary` callback, typically an HTTP + * POST to mcpd /api/v1/llms/:name/infer). + * 2. If that succeeds → done. + * 3. If it fails AND a local provider is registered as failover for this + * Llm name → call mcpd /api/v1/llms/:name (RBAC-gated) to verify the + * caller still has permission to view this Llm. mcpd unreachable → + * fail-closed (re-throw the original error). 403 → fail-closed. + * 4. 200 → invoke the local provider's `complete()` and tag the result + * as `failover: true` for client-side audit. + * + * The check call uses HEAD to avoid pulling the Llm body (and any + * description / extraConfig) over the wire — mcpd treats both methods the + * same in the RBAC hook because the URL maps to the same permission. + */ +import type { LlmProvider } from './types.js'; +import type { ProviderRegistry } from './registry.js'; + +export interface FailoverDecision { + result: T; + failover: boolean; + /** Name of the local provider used (only set when failover === true). */ + via?: string; +} + +export interface FailoverRouterDeps { + /** Injected fetch for the RBAC pre-check. Tests mock this. */ + fetch?: typeof globalThis.fetch; + /** mcpd base URL (no trailing slash). */ + mcpdUrl: string; + /** Bearer token to attach to the RBAC pre-check call. */ + bearerToken?: string; +} + +/** Outcome of the RBAC pre-check. Used internally + exposed for tests. */ +export type AuthCheckOutcome = 'allowed' | 'forbidden' | 'unreachable'; + +export class FailoverRouter { + private readonly fetchImpl: typeof globalThis.fetch; + private readonly mcpdUrl: string; + private readonly bearer: string | undefined; + + constructor( + private readonly registry: ProviderRegistry, + deps: FailoverRouterDeps, + ) { + this.fetchImpl = deps.fetch ?? globalThis.fetch; + this.mcpdUrl = deps.mcpdUrl.replace(/\/+$/, ''); + if (deps.bearerToken !== undefined) this.bearer = deps.bearerToken; + } + + /** + * Run a primary inference attempt; on failure, fall back to the local + * provider if one is registered for this Llm AND the caller still has + * `view:llms:` on mcpd. + * + * `primary` should reject (throw) when mcpd's proxy is unreachable or + * returns a 5xx — that's the signal to consider failover. 4xx errors that + * indicate a bad request are surfaced as-is; the router only retries on + * primary failure shapes that look like an upstream/network issue. + */ + async run( + llmName: string, + primary: () => Promise, + localCall: (provider: LlmProvider) => Promise, + ): Promise> { + try { + const result = await primary(); + return { result, failover: false }; + } catch (primaryErr) { + const local = this.registry.getFailoverFor(llmName); + if (local === null) throw primaryErr; + + const auth = await this.checkAuth(llmName); + if (auth !== 'allowed') { + // Fail-closed for forbidden AND unreachable. + throw primaryErr; + } + + const result = await localCall(local); + return { result, failover: true, via: local.name }; + } + } + + /** RBAC pre-check exposed for tests / status-display callers. */ + async checkAuth(llmName: string): Promise { + const url = `${this.mcpdUrl}/api/v1/llms/${encodeURIComponent(llmName)}`; + const headers: Record = {}; + if (this.bearer !== undefined) headers['Authorization'] = `Bearer ${this.bearer}`; + let res: Response; + try { + res = await this.fetchImpl(url, { method: 'HEAD', headers }); + } catch { + return 'unreachable'; + } + if (res.status === 200 || res.status === 204) return 'allowed'; + if (res.status === 403 || res.status === 401) return 'forbidden'; + // Anything else (404, 500…) — treat as unreachable for the failover flow. + return 'unreachable'; + } +} diff --git a/src/mcplocal/src/providers/registry.ts b/src/mcplocal/src/providers/registry.ts index 03cb52a..472ee5e 100644 --- a/src/mcplocal/src/providers/registry.ts +++ b/src/mcplocal/src/providers/registry.ts @@ -8,6 +8,8 @@ export class ProviderRegistry { private providers = new Map(); private activeProvider: string | null = null; private tierProviders = new Map(); + /** Maps a centralized Llm name → local provider name that can substitute when mcpd is unreachable. */ + private failoverMap = new Map(); register(provider: LlmProvider): void { this.providers.set(provider.name, provider); @@ -31,6 +33,30 @@ export class ProviderRegistry { this.tierProviders.set(tier, filtered); } } + // Remove from failover map (any entry whose local-provider value points at this name) + for (const [centralName, localName] of this.failoverMap) { + if (localName === name) this.failoverMap.delete(centralName); + } + } + + /** Mark `localProviderName` as the failover for the centralized Llm named `centralLlmName`. */ + registerFailover(centralLlmName: string, localProviderName: string): void { + if (!this.providers.has(localProviderName)) { + throw new Error(`Provider '${localProviderName}' is not registered`); + } + this.failoverMap.set(centralLlmName, localProviderName); + } + + /** Look up the local provider that can substitute for a centralized Llm, if any. */ + getFailoverFor(centralLlmName: string): LlmProvider | null { + const localName = this.failoverMap.get(centralLlmName); + if (localName === undefined) return null; + return this.providers.get(localName) ?? null; + } + + /** Names of central Llms that have a local failover registered. Used in status output. */ + listFailovers(): Array<{ centralLlmName: string; localProviderName: string }> { + return [...this.failoverMap.entries()].map(([centralLlmName, localProviderName]) => ({ centralLlmName, localProviderName })); } setActive(name: string): void { diff --git a/src/mcplocal/tests/failover-router.test.ts b/src/mcplocal/tests/failover-router.test.ts new file mode 100644 index 0000000..98a2631 --- /dev/null +++ b/src/mcplocal/tests/failover-router.test.ts @@ -0,0 +1,170 @@ +import { describe, it, expect, vi } from 'vitest'; +import { ProviderRegistry } from '../src/providers/registry.js'; +import { FailoverRouter } from '../src/providers/failover-router.js'; +import type { LlmProvider, CompleteResponse } from '../src/providers/types.js'; + +function fakeProvider(name: string): LlmProvider { + const completeFn = vi.fn(async (): Promise => ({ + content: 'local response', + finishReason: 'stop', + })); + return { + name, + complete: completeFn, + listModels: vi.fn(async () => [name]), + isAvailable: vi.fn(async () => true), + }; +} + +function makeFetch(behaviour: { method: string; status?: number; throw?: boolean }): ReturnType { + return vi.fn(async (url: string | URL, init?: RequestInit) => { + if (behaviour.throw === true) throw new Error('connection refused'); + expect(init?.method).toBe(behaviour.method); + expect(String(url)).toMatch(/\/api\/v1\/llms\//); + return new Response(null, { status: behaviour.status ?? 200 }); + }); +} + +describe('ProviderRegistry — failover map', () => { + it('registerFailover maps a central name → local provider name', () => { + const reg = new ProviderRegistry(); + const local = fakeProvider('vllm-local'); + reg.register(local); + reg.registerFailover('claude', 'vllm-local'); + + const found = reg.getFailoverFor('claude'); + expect(found?.name).toBe('vllm-local'); + }); + + it('getFailoverFor returns null when no map entry exists', () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + expect(reg.getFailoverFor('claude')).toBeNull(); + }); + + it('registerFailover throws when local provider is not registered', () => { + const reg = new ProviderRegistry(); + expect(() => reg.registerFailover('claude', 'missing')).toThrow(/not registered/); + }); + + it('unregister removes failover entries that pointed at the removed provider', () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + reg.unregister('vllm-local'); + expect(reg.getFailoverFor('claude')).toBeNull(); + expect(reg.listFailovers()).toEqual([]); + }); + + it('listFailovers reports the current map', () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + reg.registerFailover('opus', 'vllm-local'); + expect(reg.listFailovers()).toEqual([ + { centralLlmName: 'claude', localProviderName: 'vllm-local' }, + { centralLlmName: 'opus', localProviderName: 'vllm-local' }, + ]); + }); +}); + +describe('FailoverRouter', () => { + it('returns primary result when primary succeeds', async () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + + const router = new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: vi.fn() as unknown as typeof fetch, + }); + const out = await router.run('claude', async () => 'central', async () => 'local'); + expect(out.failover).toBe(false); + expect(out.result).toBe('central'); + }); + + it('falls back to local when primary fails AND mcpd auth-checks 200', async () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + + const fetchFn = makeFetch({ method: 'HEAD', status: 200 }); + const router = new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: fetchFn as unknown as typeof fetch, + bearerToken: 'bearer-x', + }); + const out = await router.run( + 'claude', + async () => { throw new Error('upstream down'); }, + async (provider) => `via:${provider.name}`, + ); + expect(out.failover).toBe(true); + expect(out.via).toBe('vllm-local'); + expect(out.result).toBe('via:vllm-local'); + + // Bearer was attached + const [, init] = fetchFn.mock.calls[0] as [string, RequestInit]; + expect((init.headers as Record)['Authorization']).toBe('Bearer bearer-x'); + }); + + it('re-throws primary error when no local failover is registered', async () => { + const reg = new ProviderRegistry(); + const router = new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: vi.fn() as unknown as typeof fetch, + }); + await expect(router.run( + 'claude', + async () => { throw new Error('boom'); }, + async () => 'never', + )).rejects.toThrow('boom'); + }); + + it('re-throws (fail-closed) when mcpd returns 403 to the auth check', async () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + + const router = new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: makeFetch({ method: 'HEAD', status: 403 }) as unknown as typeof fetch, + }); + await expect(router.run( + 'claude', + async () => { throw new Error('upstream down'); }, + async () => 'never', + )).rejects.toThrow('upstream down'); + }); + + it('re-throws (fail-closed) when mcpd itself is unreachable for the auth check', async () => { + const reg = new ProviderRegistry(); + reg.register(fakeProvider('vllm-local')); + reg.registerFailover('claude', 'vllm-local'); + + const router = new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: makeFetch({ method: 'HEAD', throw: true }) as unknown as typeof fetch, + }); + await expect(router.run( + 'claude', + async () => { throw new Error('upstream down'); }, + async () => 'never', + )).rejects.toThrow('upstream down'); + }); + + it('checkAuth maps responses correctly', async () => { + const reg = new ProviderRegistry(); + const make = (status: number) => new FailoverRouter(reg, { + mcpdUrl: 'http://mcpd', + fetch: (async () => new Response(null, { status })) as unknown as typeof fetch, + }); + + expect(await make(200).checkAuth('claude')).toBe('allowed'); + expect(await make(204).checkAuth('claude')).toBe('allowed'); + expect(await make(401).checkAuth('claude')).toBe('forbidden'); + expect(await make(403).checkAuth('claude')).toBe('forbidden'); + expect(await make(404).checkAuth('claude')).toBe('unreachable'); + expect(await make(500).checkAuth('claude')).toBe('unreachable'); + }); +});