Merge pull request 'feat(mcplocal): RBAC-bounded vllm-managed failover' (#54) from feat/llm-failover into main
Some checks failed
Some checks failed
This commit was merged in pull request #54.
This commit is contained in:
@@ -10,9 +10,12 @@ export function registerLlmRoutes(
|
||||
return service.list();
|
||||
});
|
||||
|
||||
// Accepts either CUID or human name. Used both by the CLI (which usually
|
||||
// resolves to CUID first) and by FailoverRouter's RBAC pre-check (which
|
||||
// hands over the user-facing name to avoid an extra round-trip).
|
||||
app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => {
|
||||
try {
|
||||
return await service.getById(request.params.id);
|
||||
return await getByIdOrName(service, request.params.id);
|
||||
} catch (err) {
|
||||
if (err instanceof NotFoundError) {
|
||||
reply.code(404);
|
||||
@@ -22,6 +25,10 @@ export function registerLlmRoutes(
|
||||
}
|
||||
});
|
||||
|
||||
// No explicit HEAD handler: Fastify auto-derives HEAD from GET, which runs
|
||||
// the same RBAC hook + lookup and drops the body. That's exactly what
|
||||
// FailoverRouter wants for its "can the caller still view this Llm?" probe.
|
||||
|
||||
app.post('/api/v1/llms', async (request, reply) => {
|
||||
try {
|
||||
const row = await service.create(request.body);
|
||||
@@ -62,3 +69,17 @@ export function registerLlmRoutes(
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const CUID_RE = /^c[a-z0-9]{24}/i;
|
||||
|
||||
/**
|
||||
* Look up by CUID first; if the input doesn't look like one, fall back to
|
||||
* findByName. Lets the same URL serve both `mcpctl describe llm <name>` and
|
||||
* the FailoverRouter's name-based RBAC check.
|
||||
*/
|
||||
async function getByIdOrName(service: LlmService, idOrName: string) {
|
||||
if (CUID_RE.test(idOrName)) {
|
||||
return service.getById(idOrName);
|
||||
}
|
||||
return service.getByName(idOrName);
|
||||
}
|
||||
|
||||
@@ -104,6 +104,25 @@ describe('Llm Routes', () => {
|
||||
expect(res.statusCode).toBe(404);
|
||||
});
|
||||
|
||||
it('GET /api/v1/llms/:nameOrId resolves by human name when not a CUID', async () => {
|
||||
await createApp(mockRepo([makeLlm({ id: 'llm-1', name: 'claude' })]));
|
||||
const res = await app.inject({ method: 'GET', url: '/api/v1/llms/claude' });
|
||||
expect(res.statusCode).toBe(200);
|
||||
expect(res.json<{ name: string; id: string }>().name).toBe('claude');
|
||||
});
|
||||
|
||||
it('HEAD /api/v1/llms/:name returns 200 for an existing Llm (failover RBAC pre-check)', async () => {
|
||||
await createApp(mockRepo([makeLlm({ name: 'claude' })]));
|
||||
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/claude' });
|
||||
expect(res.statusCode).toBe(200);
|
||||
});
|
||||
|
||||
it('HEAD /api/v1/llms/:name returns 404 for a missing Llm', async () => {
|
||||
await createApp(mockRepo());
|
||||
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/missing' });
|
||||
expect(res.statusCode).toBe(404);
|
||||
});
|
||||
|
||||
it('POST /api/v1/llms creates and returns 201', async () => {
|
||||
await createApp(mockRepo());
|
||||
const res = await app.inject({
|
||||
|
||||
@@ -64,6 +64,14 @@ export interface LlmProviderFileEntry {
|
||||
idleTimeoutMinutes?: number;
|
||||
/** vllm-managed: extra args for `vllm serve` */
|
||||
extraArgs?: string[];
|
||||
/**
|
||||
* If set, this local provider is allowed to substitute for the centralized
|
||||
* Llm of this name when the mcpd inference proxy is unreachable.
|
||||
* RBAC is still enforced — the caller must have view permission on the
|
||||
* named Llm via mcpd before failover is permitted (fail-closed if mcpd
|
||||
* itself can't be reached).
|
||||
*/
|
||||
failoverFor?: string;
|
||||
}
|
||||
|
||||
export interface ProjectLlmOverride {
|
||||
|
||||
@@ -173,6 +173,9 @@ export async function createProvidersFromConfig(
|
||||
if (entry.tier) {
|
||||
registry.assignTier(provider.name, entry.tier);
|
||||
}
|
||||
if (entry.failoverFor) {
|
||||
registry.registerFailover(entry.failoverFor, provider.name);
|
||||
}
|
||||
}
|
||||
|
||||
return registry;
|
||||
|
||||
107
src/mcplocal/src/providers/failover-router.ts
Normal file
107
src/mcplocal/src/providers/failover-router.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
/**
|
||||
* FailoverRouter — orchestrates "try mcpd's centralized Llm, fall back to a
|
||||
* local provider when authorized" for clients that consume the inference
|
||||
* proxy.
|
||||
*
|
||||
* Decision flow on a centralized inference call:
|
||||
*
|
||||
* 1. Call the primary (the supplied `primary` callback, typically an HTTP
|
||||
* POST to mcpd /api/v1/llms/:name/infer).
|
||||
* 2. If that succeeds → done.
|
||||
* 3. If it fails AND a local provider is registered as failover for this
|
||||
* Llm name → call mcpd /api/v1/llms/:name (RBAC-gated) to verify the
|
||||
* caller still has permission to view this Llm. mcpd unreachable →
|
||||
* fail-closed (re-throw the original error). 403 → fail-closed.
|
||||
* 4. 200 → invoke the local provider's `complete()` and tag the result
|
||||
* as `failover: true` for client-side audit.
|
||||
*
|
||||
* The check call uses HEAD to avoid pulling the Llm body (and any
|
||||
* description / extraConfig) over the wire — mcpd treats both methods the
|
||||
* same in the RBAC hook because the URL maps to the same permission.
|
||||
*/
|
||||
import type { LlmProvider } from './types.js';
|
||||
import type { ProviderRegistry } from './registry.js';
|
||||
|
||||
export interface FailoverDecision<T> {
|
||||
result: T;
|
||||
failover: boolean;
|
||||
/** Name of the local provider used (only set when failover === true). */
|
||||
via?: string;
|
||||
}
|
||||
|
||||
export interface FailoverRouterDeps {
|
||||
/** Injected fetch for the RBAC pre-check. Tests mock this. */
|
||||
fetch?: typeof globalThis.fetch;
|
||||
/** mcpd base URL (no trailing slash). */
|
||||
mcpdUrl: string;
|
||||
/** Bearer token to attach to the RBAC pre-check call. */
|
||||
bearerToken?: string;
|
||||
}
|
||||
|
||||
/** Outcome of the RBAC pre-check. Used internally + exposed for tests. */
|
||||
export type AuthCheckOutcome = 'allowed' | 'forbidden' | 'unreachable';
|
||||
|
||||
export class FailoverRouter {
|
||||
private readonly fetchImpl: typeof globalThis.fetch;
|
||||
private readonly mcpdUrl: string;
|
||||
private readonly bearer: string | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly registry: ProviderRegistry,
|
||||
deps: FailoverRouterDeps,
|
||||
) {
|
||||
this.fetchImpl = deps.fetch ?? globalThis.fetch;
|
||||
this.mcpdUrl = deps.mcpdUrl.replace(/\/+$/, '');
|
||||
if (deps.bearerToken !== undefined) this.bearer = deps.bearerToken;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a primary inference attempt; on failure, fall back to the local
|
||||
* provider if one is registered for this Llm AND the caller still has
|
||||
* `view:llms:<llmName>` on mcpd.
|
||||
*
|
||||
* `primary` should reject (throw) when mcpd's proxy is unreachable or
|
||||
* returns a 5xx — that's the signal to consider failover. 4xx errors that
|
||||
* indicate a bad request are surfaced as-is; the router only retries on
|
||||
* primary failure shapes that look like an upstream/network issue.
|
||||
*/
|
||||
async run<T>(
|
||||
llmName: string,
|
||||
primary: () => Promise<T>,
|
||||
localCall: (provider: LlmProvider) => Promise<T>,
|
||||
): Promise<FailoverDecision<T>> {
|
||||
try {
|
||||
const result = await primary();
|
||||
return { result, failover: false };
|
||||
} catch (primaryErr) {
|
||||
const local = this.registry.getFailoverFor(llmName);
|
||||
if (local === null) throw primaryErr;
|
||||
|
||||
const auth = await this.checkAuth(llmName);
|
||||
if (auth !== 'allowed') {
|
||||
// Fail-closed for forbidden AND unreachable.
|
||||
throw primaryErr;
|
||||
}
|
||||
|
||||
const result = await localCall(local);
|
||||
return { result, failover: true, via: local.name };
|
||||
}
|
||||
}
|
||||
|
||||
/** RBAC pre-check exposed for tests / status-display callers. */
|
||||
async checkAuth(llmName: string): Promise<AuthCheckOutcome> {
|
||||
const url = `${this.mcpdUrl}/api/v1/llms/${encodeURIComponent(llmName)}`;
|
||||
const headers: Record<string, string> = {};
|
||||
if (this.bearer !== undefined) headers['Authorization'] = `Bearer ${this.bearer}`;
|
||||
let res: Response;
|
||||
try {
|
||||
res = await this.fetchImpl(url, { method: 'HEAD', headers });
|
||||
} catch {
|
||||
return 'unreachable';
|
||||
}
|
||||
if (res.status === 200 || res.status === 204) return 'allowed';
|
||||
if (res.status === 403 || res.status === 401) return 'forbidden';
|
||||
// Anything else (404, 500…) — treat as unreachable for the failover flow.
|
||||
return 'unreachable';
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ export class ProviderRegistry {
|
||||
private providers = new Map<string, LlmProvider>();
|
||||
private activeProvider: string | null = null;
|
||||
private tierProviders = new Map<Tier, string[]>();
|
||||
/** Maps a centralized Llm name → local provider name that can substitute when mcpd is unreachable. */
|
||||
private failoverMap = new Map<string, string>();
|
||||
|
||||
register(provider: LlmProvider): void {
|
||||
this.providers.set(provider.name, provider);
|
||||
@@ -31,6 +33,30 @@ export class ProviderRegistry {
|
||||
this.tierProviders.set(tier, filtered);
|
||||
}
|
||||
}
|
||||
// Remove from failover map (any entry whose local-provider value points at this name)
|
||||
for (const [centralName, localName] of this.failoverMap) {
|
||||
if (localName === name) this.failoverMap.delete(centralName);
|
||||
}
|
||||
}
|
||||
|
||||
/** Mark `localProviderName` as the failover for the centralized Llm named `centralLlmName`. */
|
||||
registerFailover(centralLlmName: string, localProviderName: string): void {
|
||||
if (!this.providers.has(localProviderName)) {
|
||||
throw new Error(`Provider '${localProviderName}' is not registered`);
|
||||
}
|
||||
this.failoverMap.set(centralLlmName, localProviderName);
|
||||
}
|
||||
|
||||
/** Look up the local provider that can substitute for a centralized Llm, if any. */
|
||||
getFailoverFor(centralLlmName: string): LlmProvider | null {
|
||||
const localName = this.failoverMap.get(centralLlmName);
|
||||
if (localName === undefined) return null;
|
||||
return this.providers.get(localName) ?? null;
|
||||
}
|
||||
|
||||
/** Names of central Llms that have a local failover registered. Used in status output. */
|
||||
listFailovers(): Array<{ centralLlmName: string; localProviderName: string }> {
|
||||
return [...this.failoverMap.entries()].map(([centralLlmName, localProviderName]) => ({ centralLlmName, localProviderName }));
|
||||
}
|
||||
|
||||
setActive(name: string): void {
|
||||
|
||||
170
src/mcplocal/tests/failover-router.test.ts
Normal file
170
src/mcplocal/tests/failover-router.test.ts
Normal file
@@ -0,0 +1,170 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ProviderRegistry } from '../src/providers/registry.js';
|
||||
import { FailoverRouter } from '../src/providers/failover-router.js';
|
||||
import type { LlmProvider, CompleteResponse } from '../src/providers/types.js';
|
||||
|
||||
function fakeProvider(name: string): LlmProvider {
|
||||
const completeFn = vi.fn(async (): Promise<CompleteResponse> => ({
|
||||
content: 'local response',
|
||||
finishReason: 'stop',
|
||||
}));
|
||||
return {
|
||||
name,
|
||||
complete: completeFn,
|
||||
listModels: vi.fn(async () => [name]),
|
||||
isAvailable: vi.fn(async () => true),
|
||||
};
|
||||
}
|
||||
|
||||
function makeFetch(behaviour: { method: string; status?: number; throw?: boolean }): ReturnType<typeof vi.fn> {
|
||||
return vi.fn(async (url: string | URL, init?: RequestInit) => {
|
||||
if (behaviour.throw === true) throw new Error('connection refused');
|
||||
expect(init?.method).toBe(behaviour.method);
|
||||
expect(String(url)).toMatch(/\/api\/v1\/llms\//);
|
||||
return new Response(null, { status: behaviour.status ?? 200 });
|
||||
});
|
||||
}
|
||||
|
||||
describe('ProviderRegistry — failover map', () => {
|
||||
it('registerFailover maps a central name → local provider name', () => {
|
||||
const reg = new ProviderRegistry();
|
||||
const local = fakeProvider('vllm-local');
|
||||
reg.register(local);
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
|
||||
const found = reg.getFailoverFor('claude');
|
||||
expect(found?.name).toBe('vllm-local');
|
||||
});
|
||||
|
||||
it('getFailoverFor returns null when no map entry exists', () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
expect(reg.getFailoverFor('claude')).toBeNull();
|
||||
});
|
||||
|
||||
it('registerFailover throws when local provider is not registered', () => {
|
||||
const reg = new ProviderRegistry();
|
||||
expect(() => reg.registerFailover('claude', 'missing')).toThrow(/not registered/);
|
||||
});
|
||||
|
||||
it('unregister removes failover entries that pointed at the removed provider', () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
reg.unregister('vllm-local');
|
||||
expect(reg.getFailoverFor('claude')).toBeNull();
|
||||
expect(reg.listFailovers()).toEqual([]);
|
||||
});
|
||||
|
||||
it('listFailovers reports the current map', () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
reg.registerFailover('opus', 'vllm-local');
|
||||
expect(reg.listFailovers()).toEqual([
|
||||
{ centralLlmName: 'claude', localProviderName: 'vllm-local' },
|
||||
{ centralLlmName: 'opus', localProviderName: 'vllm-local' },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('FailoverRouter', () => {
|
||||
it('returns primary result when primary succeeds', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
|
||||
const router = new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: vi.fn() as unknown as typeof fetch,
|
||||
});
|
||||
const out = await router.run('claude', async () => 'central', async () => 'local');
|
||||
expect(out.failover).toBe(false);
|
||||
expect(out.result).toBe('central');
|
||||
});
|
||||
|
||||
it('falls back to local when primary fails AND mcpd auth-checks 200', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
|
||||
const fetchFn = makeFetch({ method: 'HEAD', status: 200 });
|
||||
const router = new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: fetchFn as unknown as typeof fetch,
|
||||
bearerToken: 'bearer-x',
|
||||
});
|
||||
const out = await router.run(
|
||||
'claude',
|
||||
async () => { throw new Error('upstream down'); },
|
||||
async (provider) => `via:${provider.name}`,
|
||||
);
|
||||
expect(out.failover).toBe(true);
|
||||
expect(out.via).toBe('vllm-local');
|
||||
expect(out.result).toBe('via:vllm-local');
|
||||
|
||||
// Bearer was attached
|
||||
const [, init] = fetchFn.mock.calls[0] as [string, RequestInit];
|
||||
expect((init.headers as Record<string, string>)['Authorization']).toBe('Bearer bearer-x');
|
||||
});
|
||||
|
||||
it('re-throws primary error when no local failover is registered', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
const router = new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: vi.fn() as unknown as typeof fetch,
|
||||
});
|
||||
await expect(router.run(
|
||||
'claude',
|
||||
async () => { throw new Error('boom'); },
|
||||
async () => 'never',
|
||||
)).rejects.toThrow('boom');
|
||||
});
|
||||
|
||||
it('re-throws (fail-closed) when mcpd returns 403 to the auth check', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
|
||||
const router = new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: makeFetch({ method: 'HEAD', status: 403 }) as unknown as typeof fetch,
|
||||
});
|
||||
await expect(router.run(
|
||||
'claude',
|
||||
async () => { throw new Error('upstream down'); },
|
||||
async () => 'never',
|
||||
)).rejects.toThrow('upstream down');
|
||||
});
|
||||
|
||||
it('re-throws (fail-closed) when mcpd itself is unreachable for the auth check', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
reg.register(fakeProvider('vllm-local'));
|
||||
reg.registerFailover('claude', 'vllm-local');
|
||||
|
||||
const router = new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: makeFetch({ method: 'HEAD', throw: true }) as unknown as typeof fetch,
|
||||
});
|
||||
await expect(router.run(
|
||||
'claude',
|
||||
async () => { throw new Error('upstream down'); },
|
||||
async () => 'never',
|
||||
)).rejects.toThrow('upstream down');
|
||||
});
|
||||
|
||||
it('checkAuth maps responses correctly', async () => {
|
||||
const reg = new ProviderRegistry();
|
||||
const make = (status: number) => new FailoverRouter(reg, {
|
||||
mcpdUrl: 'http://mcpd',
|
||||
fetch: (async () => new Response(null, { status })) as unknown as typeof fetch,
|
||||
});
|
||||
|
||||
expect(await make(200).checkAuth('claude')).toBe('allowed');
|
||||
expect(await make(204).checkAuth('claude')).toBe('allowed');
|
||||
expect(await make(401).checkAuth('claude')).toBe('forbidden');
|
||||
expect(await make(403).checkAuth('claude')).toBe('forbidden');
|
||||
expect(await make(404).checkAuth('claude')).toBe('unreachable');
|
||||
expect(await make(500).checkAuth('claude')).toBe('unreachable');
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user