Compare commits
3 Commits
feat/llm-f
...
d217eadd13
| Author | SHA1 | Date | |
|---|---|---|---|
| d217eadd13 | |||
| 9e3507752f | |||
| 97ac1e75ef |
@@ -10,12 +10,9 @@ export function registerLlmRoutes(
|
|||||||
return service.list();
|
return service.list();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Accepts either CUID or human name. Used both by the CLI (which usually
|
|
||||||
// resolves to CUID first) and by FailoverRouter's RBAC pre-check (which
|
|
||||||
// hands over the user-facing name to avoid an extra round-trip).
|
|
||||||
app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => {
|
app.get<{ Params: { id: string } }>('/api/v1/llms/:id', async (request, reply) => {
|
||||||
try {
|
try {
|
||||||
return await getByIdOrName(service, request.params.id);
|
return await service.getById(request.params.id);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof NotFoundError) {
|
if (err instanceof NotFoundError) {
|
||||||
reply.code(404);
|
reply.code(404);
|
||||||
@@ -25,10 +22,6 @@ export function registerLlmRoutes(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// No explicit HEAD handler: Fastify auto-derives HEAD from GET, which runs
|
|
||||||
// the same RBAC hook + lookup and drops the body. That's exactly what
|
|
||||||
// FailoverRouter wants for its "can the caller still view this Llm?" probe.
|
|
||||||
|
|
||||||
app.post('/api/v1/llms', async (request, reply) => {
|
app.post('/api/v1/llms', async (request, reply) => {
|
||||||
try {
|
try {
|
||||||
const row = await service.create(request.body);
|
const row = await service.create(request.body);
|
||||||
@@ -69,17 +62,3 @@ export function registerLlmRoutes(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const CUID_RE = /^c[a-z0-9]{24}/i;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Look up by CUID first; if the input doesn't look like one, fall back to
|
|
||||||
* findByName. Lets the same URL serve both `mcpctl describe llm <name>` and
|
|
||||||
* the FailoverRouter's name-based RBAC check.
|
|
||||||
*/
|
|
||||||
async function getByIdOrName(service: LlmService, idOrName: string) {
|
|
||||||
if (CUID_RE.test(idOrName)) {
|
|
||||||
return service.getById(idOrName);
|
|
||||||
}
|
|
||||||
return service.getByName(idOrName);
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -104,25 +104,6 @@ describe('Llm Routes', () => {
|
|||||||
expect(res.statusCode).toBe(404);
|
expect(res.statusCode).toBe(404);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('GET /api/v1/llms/:nameOrId resolves by human name when not a CUID', async () => {
|
|
||||||
await createApp(mockRepo([makeLlm({ id: 'llm-1', name: 'claude' })]));
|
|
||||||
const res = await app.inject({ method: 'GET', url: '/api/v1/llms/claude' });
|
|
||||||
expect(res.statusCode).toBe(200);
|
|
||||||
expect(res.json<{ name: string; id: string }>().name).toBe('claude');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('HEAD /api/v1/llms/:name returns 200 for an existing Llm (failover RBAC pre-check)', async () => {
|
|
||||||
await createApp(mockRepo([makeLlm({ name: 'claude' })]));
|
|
||||||
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/claude' });
|
|
||||||
expect(res.statusCode).toBe(200);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('HEAD /api/v1/llms/:name returns 404 for a missing Llm', async () => {
|
|
||||||
await createApp(mockRepo());
|
|
||||||
const res = await app.inject({ method: 'HEAD', url: '/api/v1/llms/missing' });
|
|
||||||
expect(res.statusCode).toBe(404);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('POST /api/v1/llms creates and returns 201', async () => {
|
it('POST /api/v1/llms creates and returns 201', async () => {
|
||||||
await createApp(mockRepo());
|
await createApp(mockRepo());
|
||||||
const res = await app.inject({
|
const res = await app.inject({
|
||||||
|
|||||||
@@ -64,14 +64,6 @@ export interface LlmProviderFileEntry {
|
|||||||
idleTimeoutMinutes?: number;
|
idleTimeoutMinutes?: number;
|
||||||
/** vllm-managed: extra args for `vllm serve` */
|
/** vllm-managed: extra args for `vllm serve` */
|
||||||
extraArgs?: string[];
|
extraArgs?: string[];
|
||||||
/**
|
|
||||||
* If set, this local provider is allowed to substitute for the centralized
|
|
||||||
* Llm of this name when the mcpd inference proxy is unreachable.
|
|
||||||
* RBAC is still enforced — the caller must have view permission on the
|
|
||||||
* named Llm via mcpd before failover is permitted (fail-closed if mcpd
|
|
||||||
* itself can't be reached).
|
|
||||||
*/
|
|
||||||
failoverFor?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ProjectLlmOverride {
|
export interface ProjectLlmOverride {
|
||||||
|
|||||||
@@ -173,9 +173,6 @@ export async function createProvidersFromConfig(
|
|||||||
if (entry.tier) {
|
if (entry.tier) {
|
||||||
registry.assignTier(provider.name, entry.tier);
|
registry.assignTier(provider.name, entry.tier);
|
||||||
}
|
}
|
||||||
if (entry.failoverFor) {
|
|
||||||
registry.registerFailover(entry.failoverFor, provider.name);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return registry;
|
return registry;
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
/**
|
|
||||||
* FailoverRouter — orchestrates "try mcpd's centralized Llm, fall back to a
|
|
||||||
* local provider when authorized" for clients that consume the inference
|
|
||||||
* proxy.
|
|
||||||
*
|
|
||||||
* Decision flow on a centralized inference call:
|
|
||||||
*
|
|
||||||
* 1. Call the primary (the supplied `primary` callback, typically an HTTP
|
|
||||||
* POST to mcpd /api/v1/llms/:name/infer).
|
|
||||||
* 2. If that succeeds → done.
|
|
||||||
* 3. If it fails AND a local provider is registered as failover for this
|
|
||||||
* Llm name → call mcpd /api/v1/llms/:name (RBAC-gated) to verify the
|
|
||||||
* caller still has permission to view this Llm. mcpd unreachable →
|
|
||||||
* fail-closed (re-throw the original error). 403 → fail-closed.
|
|
||||||
* 4. 200 → invoke the local provider's `complete()` and tag the result
|
|
||||||
* as `failover: true` for client-side audit.
|
|
||||||
*
|
|
||||||
* The check call uses HEAD to avoid pulling the Llm body (and any
|
|
||||||
* description / extraConfig) over the wire — mcpd treats both methods the
|
|
||||||
* same in the RBAC hook because the URL maps to the same permission.
|
|
||||||
*/
|
|
||||||
import type { LlmProvider } from './types.js';
|
|
||||||
import type { ProviderRegistry } from './registry.js';
|
|
||||||
|
|
||||||
export interface FailoverDecision<T> {
|
|
||||||
result: T;
|
|
||||||
failover: boolean;
|
|
||||||
/** Name of the local provider used (only set when failover === true). */
|
|
||||||
via?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface FailoverRouterDeps {
|
|
||||||
/** Injected fetch for the RBAC pre-check. Tests mock this. */
|
|
||||||
fetch?: typeof globalThis.fetch;
|
|
||||||
/** mcpd base URL (no trailing slash). */
|
|
||||||
mcpdUrl: string;
|
|
||||||
/** Bearer token to attach to the RBAC pre-check call. */
|
|
||||||
bearerToken?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Outcome of the RBAC pre-check. Used internally + exposed for tests. */
|
|
||||||
export type AuthCheckOutcome = 'allowed' | 'forbidden' | 'unreachable';
|
|
||||||
|
|
||||||
export class FailoverRouter {
|
|
||||||
private readonly fetchImpl: typeof globalThis.fetch;
|
|
||||||
private readonly mcpdUrl: string;
|
|
||||||
private readonly bearer: string | undefined;
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
private readonly registry: ProviderRegistry,
|
|
||||||
deps: FailoverRouterDeps,
|
|
||||||
) {
|
|
||||||
this.fetchImpl = deps.fetch ?? globalThis.fetch;
|
|
||||||
this.mcpdUrl = deps.mcpdUrl.replace(/\/+$/, '');
|
|
||||||
if (deps.bearerToken !== undefined) this.bearer = deps.bearerToken;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Run a primary inference attempt; on failure, fall back to the local
|
|
||||||
* provider if one is registered for this Llm AND the caller still has
|
|
||||||
* `view:llms:<llmName>` on mcpd.
|
|
||||||
*
|
|
||||||
* `primary` should reject (throw) when mcpd's proxy is unreachable or
|
|
||||||
* returns a 5xx — that's the signal to consider failover. 4xx errors that
|
|
||||||
* indicate a bad request are surfaced as-is; the router only retries on
|
|
||||||
* primary failure shapes that look like an upstream/network issue.
|
|
||||||
*/
|
|
||||||
async run<T>(
|
|
||||||
llmName: string,
|
|
||||||
primary: () => Promise<T>,
|
|
||||||
localCall: (provider: LlmProvider) => Promise<T>,
|
|
||||||
): Promise<FailoverDecision<T>> {
|
|
||||||
try {
|
|
||||||
const result = await primary();
|
|
||||||
return { result, failover: false };
|
|
||||||
} catch (primaryErr) {
|
|
||||||
const local = this.registry.getFailoverFor(llmName);
|
|
||||||
if (local === null) throw primaryErr;
|
|
||||||
|
|
||||||
const auth = await this.checkAuth(llmName);
|
|
||||||
if (auth !== 'allowed') {
|
|
||||||
// Fail-closed for forbidden AND unreachable.
|
|
||||||
throw primaryErr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = await localCall(local);
|
|
||||||
return { result, failover: true, via: local.name };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** RBAC pre-check exposed for tests / status-display callers. */
|
|
||||||
async checkAuth(llmName: string): Promise<AuthCheckOutcome> {
|
|
||||||
const url = `${this.mcpdUrl}/api/v1/llms/${encodeURIComponent(llmName)}`;
|
|
||||||
const headers: Record<string, string> = {};
|
|
||||||
if (this.bearer !== undefined) headers['Authorization'] = `Bearer ${this.bearer}`;
|
|
||||||
let res: Response;
|
|
||||||
try {
|
|
||||||
res = await this.fetchImpl(url, { method: 'HEAD', headers });
|
|
||||||
} catch {
|
|
||||||
return 'unreachable';
|
|
||||||
}
|
|
||||||
if (res.status === 200 || res.status === 204) return 'allowed';
|
|
||||||
if (res.status === 403 || res.status === 401) return 'forbidden';
|
|
||||||
// Anything else (404, 500…) — treat as unreachable for the failover flow.
|
|
||||||
return 'unreachable';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -8,8 +8,6 @@ export class ProviderRegistry {
|
|||||||
private providers = new Map<string, LlmProvider>();
|
private providers = new Map<string, LlmProvider>();
|
||||||
private activeProvider: string | null = null;
|
private activeProvider: string | null = null;
|
||||||
private tierProviders = new Map<Tier, string[]>();
|
private tierProviders = new Map<Tier, string[]>();
|
||||||
/** Maps a centralized Llm name → local provider name that can substitute when mcpd is unreachable. */
|
|
||||||
private failoverMap = new Map<string, string>();
|
|
||||||
|
|
||||||
register(provider: LlmProvider): void {
|
register(provider: LlmProvider): void {
|
||||||
this.providers.set(provider.name, provider);
|
this.providers.set(provider.name, provider);
|
||||||
@@ -33,30 +31,6 @@ export class ProviderRegistry {
|
|||||||
this.tierProviders.set(tier, filtered);
|
this.tierProviders.set(tier, filtered);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove from failover map (any entry whose local-provider value points at this name)
|
|
||||||
for (const [centralName, localName] of this.failoverMap) {
|
|
||||||
if (localName === name) this.failoverMap.delete(centralName);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Mark `localProviderName` as the failover for the centralized Llm named `centralLlmName`. */
|
|
||||||
registerFailover(centralLlmName: string, localProviderName: string): void {
|
|
||||||
if (!this.providers.has(localProviderName)) {
|
|
||||||
throw new Error(`Provider '${localProviderName}' is not registered`);
|
|
||||||
}
|
|
||||||
this.failoverMap.set(centralLlmName, localProviderName);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Look up the local provider that can substitute for a centralized Llm, if any. */
|
|
||||||
getFailoverFor(centralLlmName: string): LlmProvider | null {
|
|
||||||
const localName = this.failoverMap.get(centralLlmName);
|
|
||||||
if (localName === undefined) return null;
|
|
||||||
return this.providers.get(localName) ?? null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Names of central Llms that have a local failover registered. Used in status output. */
|
|
||||||
listFailovers(): Array<{ centralLlmName: string; localProviderName: string }> {
|
|
||||||
return [...this.failoverMap.entries()].map(([centralLlmName, localProviderName]) => ({ centralLlmName, localProviderName }));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setActive(name: string): void {
|
setActive(name: string): void {
|
||||||
|
|||||||
@@ -1,170 +0,0 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest';
|
|
||||||
import { ProviderRegistry } from '../src/providers/registry.js';
|
|
||||||
import { FailoverRouter } from '../src/providers/failover-router.js';
|
|
||||||
import type { LlmProvider, CompleteResponse } from '../src/providers/types.js';
|
|
||||||
|
|
||||||
function fakeProvider(name: string): LlmProvider {
|
|
||||||
const completeFn = vi.fn(async (): Promise<CompleteResponse> => ({
|
|
||||||
content: 'local response',
|
|
||||||
finishReason: 'stop',
|
|
||||||
}));
|
|
||||||
return {
|
|
||||||
name,
|
|
||||||
complete: completeFn,
|
|
||||||
listModels: vi.fn(async () => [name]),
|
|
||||||
isAvailable: vi.fn(async () => true),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function makeFetch(behaviour: { method: string; status?: number; throw?: boolean }): ReturnType<typeof vi.fn> {
|
|
||||||
return vi.fn(async (url: string | URL, init?: RequestInit) => {
|
|
||||||
if (behaviour.throw === true) throw new Error('connection refused');
|
|
||||||
expect(init?.method).toBe(behaviour.method);
|
|
||||||
expect(String(url)).toMatch(/\/api\/v1\/llms\//);
|
|
||||||
return new Response(null, { status: behaviour.status ?? 200 });
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
describe('ProviderRegistry — failover map', () => {
|
|
||||||
it('registerFailover maps a central name → local provider name', () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
const local = fakeProvider('vllm-local');
|
|
||||||
reg.register(local);
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
|
|
||||||
const found = reg.getFailoverFor('claude');
|
|
||||||
expect(found?.name).toBe('vllm-local');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('getFailoverFor returns null when no map entry exists', () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
expect(reg.getFailoverFor('claude')).toBeNull();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('registerFailover throws when local provider is not registered', () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
expect(() => reg.registerFailover('claude', 'missing')).toThrow(/not registered/);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('unregister removes failover entries that pointed at the removed provider', () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
reg.unregister('vllm-local');
|
|
||||||
expect(reg.getFailoverFor('claude')).toBeNull();
|
|
||||||
expect(reg.listFailovers()).toEqual([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('listFailovers reports the current map', () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
reg.registerFailover('opus', 'vllm-local');
|
|
||||||
expect(reg.listFailovers()).toEqual([
|
|
||||||
{ centralLlmName: 'claude', localProviderName: 'vllm-local' },
|
|
||||||
{ centralLlmName: 'opus', localProviderName: 'vllm-local' },
|
|
||||||
]);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('FailoverRouter', () => {
|
|
||||||
it('returns primary result when primary succeeds', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
|
|
||||||
const router = new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: vi.fn() as unknown as typeof fetch,
|
|
||||||
});
|
|
||||||
const out = await router.run('claude', async () => 'central', async () => 'local');
|
|
||||||
expect(out.failover).toBe(false);
|
|
||||||
expect(out.result).toBe('central');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('falls back to local when primary fails AND mcpd auth-checks 200', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
|
|
||||||
const fetchFn = makeFetch({ method: 'HEAD', status: 200 });
|
|
||||||
const router = new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: fetchFn as unknown as typeof fetch,
|
|
||||||
bearerToken: 'bearer-x',
|
|
||||||
});
|
|
||||||
const out = await router.run(
|
|
||||||
'claude',
|
|
||||||
async () => { throw new Error('upstream down'); },
|
|
||||||
async (provider) => `via:${provider.name}`,
|
|
||||||
);
|
|
||||||
expect(out.failover).toBe(true);
|
|
||||||
expect(out.via).toBe('vllm-local');
|
|
||||||
expect(out.result).toBe('via:vllm-local');
|
|
||||||
|
|
||||||
// Bearer was attached
|
|
||||||
const [, init] = fetchFn.mock.calls[0] as [string, RequestInit];
|
|
||||||
expect((init.headers as Record<string, string>)['Authorization']).toBe('Bearer bearer-x');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('re-throws primary error when no local failover is registered', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
const router = new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: vi.fn() as unknown as typeof fetch,
|
|
||||||
});
|
|
||||||
await expect(router.run(
|
|
||||||
'claude',
|
|
||||||
async () => { throw new Error('boom'); },
|
|
||||||
async () => 'never',
|
|
||||||
)).rejects.toThrow('boom');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('re-throws (fail-closed) when mcpd returns 403 to the auth check', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
|
|
||||||
const router = new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: makeFetch({ method: 'HEAD', status: 403 }) as unknown as typeof fetch,
|
|
||||||
});
|
|
||||||
await expect(router.run(
|
|
||||||
'claude',
|
|
||||||
async () => { throw new Error('upstream down'); },
|
|
||||||
async () => 'never',
|
|
||||||
)).rejects.toThrow('upstream down');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('re-throws (fail-closed) when mcpd itself is unreachable for the auth check', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
reg.register(fakeProvider('vllm-local'));
|
|
||||||
reg.registerFailover('claude', 'vllm-local');
|
|
||||||
|
|
||||||
const router = new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: makeFetch({ method: 'HEAD', throw: true }) as unknown as typeof fetch,
|
|
||||||
});
|
|
||||||
await expect(router.run(
|
|
||||||
'claude',
|
|
||||||
async () => { throw new Error('upstream down'); },
|
|
||||||
async () => 'never',
|
|
||||||
)).rejects.toThrow('upstream down');
|
|
||||||
});
|
|
||||||
|
|
||||||
it('checkAuth maps responses correctly', async () => {
|
|
||||||
const reg = new ProviderRegistry();
|
|
||||||
const make = (status: number) => new FailoverRouter(reg, {
|
|
||||||
mcpdUrl: 'http://mcpd',
|
|
||||||
fetch: (async () => new Response(null, { status })) as unknown as typeof fetch,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(await make(200).checkAuth('claude')).toBe('allowed');
|
|
||||||
expect(await make(204).checkAuth('claude')).toBe('allowed');
|
|
||||||
expect(await make(401).checkAuth('claude')).toBe('forbidden');
|
|
||||||
expect(await make(403).checkAuth('claude')).toBe('forbidden');
|
|
||||||
expect(await make(404).checkAuth('claude')).toBe('unreachable');
|
|
||||||
expect(await make(500).checkAuth('claude')).toBe('unreachable');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
Reference in New Issue
Block a user