Merge pull request 'feat(mcplocal): RBAC-bounded vllm-managed failover' (#54) from feat/llm-failover into main
Some checks failed
Some checks failed
This commit was merged in pull request #54.
This commit is contained in:
@@ -10,9 +10,12 @@ export function registerLlmRoutes(
|
|||||||
return service.list();
|
return service.list();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Accepts either CUID or human name. Used both by the CLI (which usually
|
||||||
|
// resolves to CUID first) and by FailoverRouter's RBAC pre-check (which
|
||||||
|
// hands over the user-facing name to avoid an extra round-trip).
|
||||||
app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => {
|
app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => {
|
||||||
try {
|
try {
|
||||||
return await service.getById(request.params.id);
|
return await getByIdOrName(service, request.params.id);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof NotFoundError) {
|
if (err instanceof NotFoundError) {
|
||||||
reply.code(404);
|
reply.code(404);
|
||||||
@@ -22,6 +25,10 @@ export function registerLlmRoutes(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// No explicit HEAD handler: Fastify auto-derives HEAD from GET, which runs
|
||||||
|
// the same RBAC hook + lookup and drops the body. That's exactly what
|
||||||
|
// FailoverRouter wants for its "can the caller still view this Llm?" probe.
|
||||||
|
|
||||||
app.post('/api/v1/llms', async (request, reply) => {
|
app.post('/api/v1/llms', async (request, reply) => {
|
||||||
try {
|
try {
|
||||||
const row = await service.create(request.body);
|
const row = await service.create(request.body);
|
||||||
@@ -62,3 +69,17 @@ export function registerLlmRoutes(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CUID_RE = /^c[a-z0-9]{24}/i;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Look up by CUID first; if the input doesn't look like one, fall back to
|
||||||
|
* findByName. Lets the same URL serve both `mcpctl describe llm <name>` and
|
||||||
|
* the FailoverRouter's name-based RBAC check.
|
||||||
|
*/
|
||||||
|
async function getByIdOrName(service: LlmService, idOrName: string) {
|
||||||
|
if (CUID_RE.test(idOrName)) {
|
||||||
|
return service.getById(idOrName);
|
||||||
|
}
|
||||||
|
return service.getByName(idOrName);
|
||||||
|
}
|
||||||
|
|||||||
@@ -104,6 +104,25 @@ describe('Llm Routes', () => {
|
|||||||
expect(res.statusCode).toBe(404);
|
expect(res.statusCode).toBe(404);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('GET /api/v1/llms/:nameOrId resolves by human name when not a CUID', async () => {
|
||||||
|
await createApp(mockRepo([makeLlm({ id: 'llm-1', name: 'claude' })]));
|
||||||
|
const res = await app.inject({ method: 'GET', url: '/api/v1/llms/claude' });
|
||||||
|
expect(res.statusCode).toBe(200);
|
||||||
|
expect(res.json<{ name: string; id: string }>().name).toBe('claude');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('HEAD /api/v1/llms/:name returns 200 for an existing Llm (failover RBAC pre-check)', async () => {
|
||||||
|
await createApp(mockRepo([makeLlm({ name: 'claude' })]));
|
||||||
|
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/claude' });
|
||||||
|
expect(res.statusCode).toBe(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('HEAD /api/v1/llms/:name returns 404 for a missing Llm', async () => {
|
||||||
|
await createApp(mockRepo());
|
||||||
|
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/missing' });
|
||||||
|
expect(res.statusCode).toBe(404);
|
||||||
|
});
|
||||||
|
|
||||||
it('POST /api/v1/llms creates and returns 201', async () => {
|
it('POST /api/v1/llms creates and returns 201', async () => {
|
||||||
await createApp(mockRepo());
|
await createApp(mockRepo());
|
||||||
const res = await app.inject({
|
const res = await app.inject({
|
||||||
|
|||||||
@@ -64,6 +64,14 @@ export interface LlmProviderFileEntry {
|
|||||||
idleTimeoutMinutes?: number;
|
idleTimeoutMinutes?: number;
|
||||||
/** vllm-managed: extra args for `vllm serve` */
|
/** vllm-managed: extra args for `vllm serve` */
|
||||||
extraArgs?: string[];
|
extraArgs?: string[];
|
||||||
|
/**
|
||||||
|
* If set, this local provider is allowed to substitute for the centralized
|
||||||
|
* Llm of this name when the mcpd inference proxy is unreachable.
|
||||||
|
* RBAC is still enforced — the caller must have view permission on the
|
||||||
|
* named Llm via mcpd before failover is permitted (fail-closed if mcpd
|
||||||
|
* itself can't be reached).
|
||||||
|
*/
|
||||||
|
failoverFor?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ProjectLlmOverride {
|
export interface ProjectLlmOverride {
|
||||||
|
|||||||
@@ -173,6 +173,9 @@ export async function createProvidersFromConfig(
|
|||||||
if (entry.tier) {
|
if (entry.tier) {
|
||||||
registry.assignTier(provider.name, entry.tier);
|
registry.assignTier(provider.name, entry.tier);
|
||||||
}
|
}
|
||||||
|
if (entry.failoverFor) {
|
||||||
|
registry.registerFailover(entry.failoverFor, provider.name);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return registry;
|
return registry;
|
||||||
|
|||||||
107
src/mcplocal/src/providers/failover-router.ts
Normal file
107
src/mcplocal/src/providers/failover-router.ts
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
/**
|
||||||
|
* FailoverRouter — orchestrates "try mcpd's centralized Llm, fall back to a
|
||||||
|
* local provider when authorized" for clients that consume the inference
|
||||||
|
* proxy.
|
||||||
|
*
|
||||||
|
* Decision flow on a centralized inference call:
|
||||||
|
*
|
||||||
|
* 1. Call the primary (the supplied `primary` callback, typically an HTTP
|
||||||
|
* POST to mcpd /api/v1/llms/:name/infer).
|
||||||
|
* 2. If that succeeds → done.
|
||||||
|
* 3. If it fails AND a local provider is registered as failover for this
|
||||||
|
* Llm name → call mcpd /api/v1/llms/:name (RBAC-gated) to verify the
|
||||||
|
* caller still has permission to view this Llm. mcpd unreachable →
|
||||||
|
* fail-closed (re-throw the original error). 403 → fail-closed.
|
||||||
|
* 4. 200 → invoke the local provider's `complete()` and tag the result
|
||||||
|
* as `failover: true` for client-side audit.
|
||||||
|
*
|
||||||
|
* The check call uses HEAD to avoid pulling the Llm body (and any
|
||||||
|
* description / extraConfig) over the wire — mcpd treats both methods the
|
||||||
|
* same in the RBAC hook because the URL maps to the same permission.
|
||||||
|
*/
|
||||||
|
import type { LlmProvider } from './types.js';
|
||||||
|
import type { ProviderRegistry } from './registry.js';
|
||||||
|
|
||||||
|
export interface FailoverDecision<T> {
|
||||||
|
result: T;
|
||||||
|
failover: boolean;
|
||||||
|
/** Name of the local provider used (only set when failover === true). */
|
||||||
|
via?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface FailoverRouterDeps {
|
||||||
|
/** Injected fetch for the RBAC pre-check. Tests mock this. */
|
||||||
|
fetch?: typeof globalThis.fetch;
|
||||||
|
/** mcpd base URL (no trailing slash). */
|
||||||
|
mcpdUrl: string;
|
||||||
|
/** Bearer token to attach to the RBAC pre-check call. */
|
||||||
|
bearerToken?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Outcome of the RBAC pre-check. Used internally + exposed for tests. */
|
||||||
|
export type AuthCheckOutcome = 'allowed' | 'forbidden' | 'unreachable';
|
||||||
|
|
||||||
|
export class FailoverRouter {
|
||||||
|
private readonly fetchImpl: typeof globalThis.fetch;
|
||||||
|
private readonly mcpdUrl: string;
|
||||||
|
private readonly bearer: string | undefined;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly registry: ProviderRegistry,
|
||||||
|
deps: FailoverRouterDeps,
|
||||||
|
) {
|
||||||
|
this.fetchImpl = deps.fetch ?? globalThis.fetch;
|
||||||
|
this.mcpdUrl = deps.mcpdUrl.replace(/\/+$/, '');
|
||||||
|
if (deps.bearerToken !== undefined) this.bearer = deps.bearerToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run a primary inference attempt; on failure, fall back to the local
|
||||||
|
* provider if one is registered for this Llm AND the caller still has
|
||||||
|
* `view:llms:<llmName>` on mcpd.
|
||||||
|
*
|
||||||
|
* `primary` should reject (throw) when mcpd's proxy is unreachable or
|
||||||
|
* returns a 5xx — that's the signal to consider failover. 4xx errors that
|
||||||
|
* indicate a bad request are surfaced as-is; the router only retries on
|
||||||
|
* primary failure shapes that look like an upstream/network issue.
|
||||||
|
*/
|
||||||
|
async run<T>(
|
||||||
|
llmName: string,
|
||||||
|
primary: () => Promise<T>,
|
||||||
|
localCall: (provider: LlmProvider) => Promise<T>,
|
||||||
|
): Promise<FailoverDecision<T>> {
|
||||||
|
try {
|
||||||
|
const result = await primary();
|
||||||
|
return { result, failover: false };
|
||||||
|
} catch (primaryErr) {
|
||||||
|
const local = this.registry.getFailoverFor(llmName);
|
||||||
|
if (local === null) throw primaryErr;
|
||||||
|
|
||||||
|
const auth = await this.checkAuth(llmName);
|
||||||
|
if (auth !== 'allowed') {
|
||||||
|
// Fail-closed for forbidden AND unreachable.
|
||||||
|
throw primaryErr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await localCall(local);
|
||||||
|
return { result, failover: true, via: local.name };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** RBAC pre-check exposed for tests / status-display callers. */
|
||||||
|
async checkAuth(llmName: string): Promise<AuthCheckOutcome> {
|
||||||
|
const url = `${this.mcpdUrl}/api/v1/llms/${encodeURIComponent(llmName)}`;
|
||||||
|
const headers: Record<string, string> = {};
|
||||||
|
if (this.bearer !== undefined) headers['Authorization'] = `Bearer ${this.bearer}`;
|
||||||
|
let res: Response;
|
||||||
|
try {
|
||||||
|
res = await this.fetchImpl(url, { method: 'HEAD', headers });
|
||||||
|
} catch {
|
||||||
|
return 'unreachable';
|
||||||
|
}
|
||||||
|
if (res.status === 200 || res.status === 204) return 'allowed';
|
||||||
|
if (res.status === 403 || res.status === 401) return 'forbidden';
|
||||||
|
// Anything else (404, 500…) — treat as unreachable for the failover flow.
|
||||||
|
return 'unreachable';
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,8 @@ export class ProviderRegistry {
|
|||||||
private providers = new Map<string, LlmProvider>();
|
private providers = new Map<string, LlmProvider>();
|
||||||
private activeProvider: string | null = null;
|
private activeProvider: string | null = null;
|
||||||
private tierProviders = new Map<Tier, string[]>();
|
private tierProviders = new Map<Tier, string[]>();
|
||||||
|
/** Maps a centralized Llm name → local provider name that can substitute when mcpd is unreachable. */
|
||||||
|
private failoverMap = new Map<string, string>();
|
||||||
|
|
||||||
register(provider: LlmProvider): void {
|
register(provider: LlmProvider): void {
|
||||||
this.providers.set(provider.name, provider);
|
this.providers.set(provider.name, provider);
|
||||||
@@ -31,6 +33,30 @@ export class ProviderRegistry {
|
|||||||
this.tierProviders.set(tier, filtered);
|
this.tierProviders.set(tier, filtered);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Remove from failover map (any entry whose local-provider value points at this name)
|
||||||
|
for (const [centralName, localName] of this.failoverMap) {
|
||||||
|
if (localName === name) this.failoverMap.delete(centralName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Mark `localProviderName` as the failover for the centralized Llm named `centralLlmName`. */
|
||||||
|
registerFailover(centralLlmName: string, localProviderName: string): void {
|
||||||
|
if (!this.providers.has(localProviderName)) {
|
||||||
|
throw new Error(`Provider '${localProviderName}' is not registered`);
|
||||||
|
}
|
||||||
|
this.failoverMap.set(centralLlmName, localProviderName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Look up the local provider that can substitute for a centralized Llm, if any. */
|
||||||
|
getFailoverFor(centralLlmName: string): LlmProvider | null {
|
||||||
|
const localName = this.failoverMap.get(centralLlmName);
|
||||||
|
if (localName === undefined) return null;
|
||||||
|
return this.providers.get(localName) ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Names of central Llms that have a local failover registered. Used in status output. */
|
||||||
|
listFailovers(): Array<{ centralLlmName: string; localProviderName: string }> {
|
||||||
|
return [...this.failoverMap.entries()].map(([centralLlmName, localProviderName]) => ({ centralLlmName, localProviderName }));
|
||||||
}
|
}
|
||||||
|
|
||||||
setActive(name: string): void {
|
setActive(name: string): void {
|
||||||
|
|||||||
170
src/mcplocal/tests/failover-router.test.ts
Normal file
170
src/mcplocal/tests/failover-router.test.ts
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
|
import { ProviderRegistry } from '../src/providers/registry.js';
|
||||||
|
import { FailoverRouter } from '../src/providers/failover-router.js';
|
||||||
|
import type { LlmProvider, CompleteResponse } from '../src/providers/types.js';
|
||||||
|
|
||||||
|
function fakeProvider(name: string): LlmProvider {
|
||||||
|
const completeFn = vi.fn(async (): Promise<CompleteResponse> => ({
|
||||||
|
content: 'local response',
|
||||||
|
finishReason: 'stop',
|
||||||
|
}));
|
||||||
|
return {
|
||||||
|
name,
|
||||||
|
complete: completeFn,
|
||||||
|
listModels: vi.fn(async () => [name]),
|
||||||
|
isAvailable: vi.fn(async () => true),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeFetch(behaviour: { method: string; status?: number; throw?: boolean }): ReturnType<typeof vi.fn> {
|
||||||
|
return vi.fn(async (url: string | URL, init?: RequestInit) => {
|
||||||
|
if (behaviour.throw === true) throw new Error('connection refused');
|
||||||
|
expect(init?.method).toBe(behaviour.method);
|
||||||
|
expect(String(url)).toMatch(/\/api\/v1\/llms\//);
|
||||||
|
return new Response(null, { status: behaviour.status ?? 200 });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ProviderRegistry — failover map', () => {
|
||||||
|
it('registerFailover maps a central name → local provider name', () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
const local = fakeProvider('vllm-local');
|
||||||
|
reg.register(local);
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
|
||||||
|
const found = reg.getFailoverFor('claude');
|
||||||
|
expect(found?.name).toBe('vllm-local');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getFailoverFor returns null when no map entry exists', () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
expect(reg.getFailoverFor('claude')).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('registerFailover throws when local provider is not registered', () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
expect(() => reg.registerFailover('claude', 'missing')).toThrow(/not registered/);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('unregister removes failover entries that pointed at the removed provider', () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
reg.unregister('vllm-local');
|
||||||
|
expect(reg.getFailoverFor('claude')).toBeNull();
|
||||||
|
expect(reg.listFailovers()).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('listFailovers reports the current map', () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
reg.registerFailover('opus', 'vllm-local');
|
||||||
|
expect(reg.listFailovers()).toEqual([
|
||||||
|
{ centralLlmName: 'claude', localProviderName: 'vllm-local' },
|
||||||
|
{ centralLlmName: 'opus', localProviderName: 'vllm-local' },
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('FailoverRouter', () => {
|
||||||
|
it('returns primary result when primary succeeds', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
|
||||||
|
const router = new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: vi.fn() as unknown as typeof fetch,
|
||||||
|
});
|
||||||
|
const out = await router.run('claude', async () => 'central', async () => 'local');
|
||||||
|
expect(out.failover).toBe(false);
|
||||||
|
expect(out.result).toBe('central');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back to local when primary fails AND mcpd auth-checks 200', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
|
||||||
|
const fetchFn = makeFetch({ method: 'HEAD', status: 200 });
|
||||||
|
const router = new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: fetchFn as unknown as typeof fetch,
|
||||||
|
bearerToken: 'bearer-x',
|
||||||
|
});
|
||||||
|
const out = await router.run(
|
||||||
|
'claude',
|
||||||
|
async () => { throw new Error('upstream down'); },
|
||||||
|
async (provider) => `via:${provider.name}`,
|
||||||
|
);
|
||||||
|
expect(out.failover).toBe(true);
|
||||||
|
expect(out.via).toBe('vllm-local');
|
||||||
|
expect(out.result).toBe('via:vllm-local');
|
||||||
|
|
||||||
|
// Bearer was attached
|
||||||
|
const [, init] = fetchFn.mock.calls[0] as [string, RequestInit];
|
||||||
|
expect((init.headers as Record<string, string>)['Authorization']).toBe('Bearer bearer-x');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('re-throws primary error when no local failover is registered', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
const router = new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: vi.fn() as unknown as typeof fetch,
|
||||||
|
});
|
||||||
|
await expect(router.run(
|
||||||
|
'claude',
|
||||||
|
async () => { throw new Error('boom'); },
|
||||||
|
async () => 'never',
|
||||||
|
)).rejects.toThrow('boom');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('re-throws (fail-closed) when mcpd returns 403 to the auth check', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
|
||||||
|
const router = new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: makeFetch({ method: 'HEAD', status: 403 }) as unknown as typeof fetch,
|
||||||
|
});
|
||||||
|
await expect(router.run(
|
||||||
|
'claude',
|
||||||
|
async () => { throw new Error('upstream down'); },
|
||||||
|
async () => 'never',
|
||||||
|
)).rejects.toThrow('upstream down');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('re-throws (fail-closed) when mcpd itself is unreachable for the auth check', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
reg.register(fakeProvider('vllm-local'));
|
||||||
|
reg.registerFailover('claude', 'vllm-local');
|
||||||
|
|
||||||
|
const router = new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: makeFetch({ method: 'HEAD', throw: true }) as unknown as typeof fetch,
|
||||||
|
});
|
||||||
|
await expect(router.run(
|
||||||
|
'claude',
|
||||||
|
async () => { throw new Error('upstream down'); },
|
||||||
|
async () => 'never',
|
||||||
|
)).rejects.toThrow('upstream down');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('checkAuth maps responses correctly', async () => {
|
||||||
|
const reg = new ProviderRegistry();
|
||||||
|
const make = (status: number) => new FailoverRouter(reg, {
|
||||||
|
mcpdUrl: 'http://mcpd',
|
||||||
|
fetch: (async () => new Response(null, { status })) as unknown as typeof fetch,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(await make(200).checkAuth('claude')).toBe('allowed');
|
||||||
|
expect(await make(204).checkAuth('claude')).toBe('allowed');
|
||||||
|
expect(await make(401).checkAuth('claude')).toBe('forbidden');
|
||||||
|
expect(await make(403).checkAuth('claude')).toBe('forbidden');
|
||||||
|
expect(await make(404).checkAuth('claude')).toBe('unreachable');
|
||||||
|
expect(await make(500).checkAuth('claude')).toBe('unreachable');
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user