// Tests for ProviderRegistry's failover map and FailoverRouter's fail-closed fallback.
import { describe, it, expect, vi } from 'vitest';
import { ProviderRegistry } from '../src/providers/registry.js';
import { FailoverRouter } from '../src/providers/failover-router.js';
import type { LlmProvider, CompleteResponse } from '../src/providers/types.js';
/**
 * Builds a minimal in-memory `LlmProvider` stub for registry/router tests.
 *
 * Every method is a `vi.fn` spy so tests can inspect how it was called:
 * - `complete` resolves to a fresh `{ content: 'local response', finishReason: 'stop' }`
 *   object on every call.
 * - `listModels` resolves to a one-element list containing `name`.
 * - `isAvailable` always resolves to `true`.
 *
 * @param name - Provider name; also reported as its only model.
 * @returns A stub provider usable with `ProviderRegistry.register`.
 */
function fakeProvider(name: string): LlmProvider {
  const provider: LlmProvider = {
    name,
    complete: vi.fn(
      async (): Promise<CompleteResponse> => ({
        content: 'local response',
        finishReason: 'stop',
      }),
    ),
    listModels: vi.fn(async () => [name]),
    isAvailable: vi.fn(async () => true),
  };
  return provider;
}
/**
 * Creates a `fetch` stand-in (a `vi.fn` spy) for the router's mcpd auth check.
 *
 * - With `throw: true` it simulates mcpd being unreachable: it rejects with
 *   `Error('connection refused')` before performing any other check.
 * - Otherwise it asserts the request used `behaviour.method` and targeted a URL
 *   containing `/api/v1/llms/`. It then resolves to an empty-body `Response`
 *   with `behaviour.status` (default 200).
 *
 * NOTE(review): the `expect` calls below run inside the router's fetch call. If
 * they fail, they surface to the router as a thrown fetch error. A fail-closed
 * code path may then turn that into the same outcome the test expects. Callers
 * should also re-check `mock.calls` after the run.
 *
 * @param behaviour - Expected HTTP method, response status, and whether to throw.
 * @returns The spy; cast it to `typeof fetch` when passing it to the router.
 */
function makeFetch(behaviour: { method: string; status?: number; throw?: boolean }): ReturnType<typeof vi.fn> {
  const fetchStub = async (input: string | URL, init?: RequestInit): Promise<Response> => {
    if (behaviour.throw === true) {
      throw new Error('connection refused');
    }
    expect(init?.method).toBe(behaviour.method);
    expect(String(input)).toMatch(/\/api\/v1\/llms\//);
    const status = behaviour.status ?? 200;
    return new Response(null, { status });
  };
  return vi.fn(fetchStub);
}
/**
 * ProviderRegistry failover map: central LLM name → registered local provider.
 */
describe('ProviderRegistry — failover map', () => {
  /** Returns a fresh registry with a single fake provider named `vllm-local`. */
  const registryWithLocal = (): ProviderRegistry => {
    const registry = new ProviderRegistry();
    registry.register(fakeProvider('vllm-local'));
    return registry;
  };

  it('registerFailover maps a central name → local provider name', () => {
    const registry = registryWithLocal();
    registry.registerFailover('claude', 'vllm-local');

    expect(registry.getFailoverFor('claude')?.name).toBe('vllm-local');
  });

  it('getFailoverFor returns null when no map entry exists', () => {
    const registry = registryWithLocal();

    expect(registry.getFailoverFor('claude')).toBeNull();
  });

  it('registerFailover throws when local provider is not registered', () => {
    const emptyRegistry = new ProviderRegistry();

    expect(() => emptyRegistry.registerFailover('claude', 'missing')).toThrow(/not registered/);
  });

  it('unregister removes failover entries that pointed at the removed provider', () => {
    const registry = registryWithLocal();
    registry.registerFailover('claude', 'vllm-local');

    // Removing the target provider must also drop every mapping that pointed at it.
    registry.unregister('vllm-local');

    expect(registry.getFailoverFor('claude')).toBeNull();
    expect(registry.listFailovers()).toEqual([]);
  });

  it('listFailovers reports the current map', () => {
    const registry = registryWithLocal();
    for (const central of ['claude', 'opus']) {
      registry.registerFailover(central, 'vllm-local');
    }

    expect(registry.listFailovers()).toEqual([
      { centralLlmName: 'claude', localProviderName: 'vllm-local' },
      { centralLlmName: 'opus', localProviderName: 'vllm-local' },
    ]);
  });
});
/**
 * FailoverRouter: runs the primary (central) call. When it fails, the router
 * asks mcpd (HEAD under /api/v1/llms/) whether the mapped local provider may be
 * used. Any non-allowed verdict re-throws the primary error (fail closed).
 */
describe('FailoverRouter', () => {
  it('returns primary result when primary succeeds', async () => {
    const reg = new ProviderRegistry();
    reg.register(fakeProvider('vllm-local'));
    reg.registerFailover('claude', 'vllm-local');

    const fetchFn = vi.fn();
    const local = vi.fn(async () => 'local');
    const router = new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: fetchFn as unknown as typeof fetch,
    });
    const out = await router.run('claude', async () => 'central', local);
    expect(out.failover).toBe(false);
    expect(out.result).toBe('central');
    // A healthy primary should trigger neither the mcpd auth check nor the local call.
    expect(fetchFn).not.toHaveBeenCalled();
    expect(local).not.toHaveBeenCalled();
  });

  it('falls back to local when primary fails AND mcpd auth-checks 200', async () => {
    const reg = new ProviderRegistry();
    reg.register(fakeProvider('vllm-local'));
    reg.registerFailover('claude', 'vllm-local');

    const fetchFn = makeFetch({ method: 'HEAD', status: 200 });
    const router = new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: fetchFn as unknown as typeof fetch,
      bearerToken: 'bearer-x',
    });
    const out = await router.run(
      'claude',
      async () => { throw new Error('upstream down'); },
      async (provider) => `via:${provider.name}`,
    );
    expect(out.failover).toBe(true);
    expect(out.via).toBe('vllm-local');
    expect(out.result).toBe('via:vllm-local');

    // Bearer was attached. Normalise through Headers so the check holds whether
    // the router passes a plain object, an entry list, or a Headers instance
    // (header lookup is case-insensitive, as in HTTP).
    expect(fetchFn).toHaveBeenCalled();
    const [, init] = fetchFn.mock.calls[0] as [string, RequestInit];
    expect(new Headers(init.headers).get('Authorization')).toBe('Bearer bearer-x');
  });

  it('re-throws primary error when no local failover is registered', async () => {
    const reg = new ProviderRegistry();
    const local = vi.fn(async () => 'never');
    const router = new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: vi.fn() as unknown as typeof fetch,
    });
    await expect(router.run(
      'claude',
      async () => { throw new Error('boom'); },
      local,
    )).rejects.toThrow('boom');
    expect(local).not.toHaveBeenCalled();
  });

  it('re-throws (fail-closed) when mcpd returns 403 to the auth check', async () => {
    const reg = new ProviderRegistry();
    reg.register(fakeProvider('vllm-local'));
    reg.registerFailover('claude', 'vllm-local');

    const fetchFn = makeFetch({ method: 'HEAD', status: 403 });
    const local = vi.fn(async () => 'never');
    const router = new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: fetchFn as unknown as typeof fetch,
    });
    await expect(router.run(
      'claude',
      async () => { throw new Error('upstream down'); },
      local,
    )).rejects.toThrow('upstream down');

    // Fail closed: the local provider must never run on a forbidden verdict.
    expect(local).not.toHaveBeenCalled();

    // A failed expect() inside makeFetch surfaces as a thrown fetch error, and a
    // thrown fetch also ends in this same rethrow (see the unreachable test below).
    // Re-check the request shape here so a wrong method or URL fails this test.
    expect(fetchFn).toHaveBeenCalled();
    const [url, init] = fetchFn.mock.calls[0] as [string | URL, RequestInit | undefined];
    expect(init?.method).toBe('HEAD');
    expect(String(url)).toMatch(/\/api\/v1\/llms\//);
  });

  it('re-throws (fail-closed) when mcpd itself is unreachable for the auth check', async () => {
    const reg = new ProviderRegistry();
    reg.register(fakeProvider('vllm-local'));
    reg.registerFailover('claude', 'vllm-local');

    const fetchFn = makeFetch({ method: 'HEAD', throw: true });
    const local = vi.fn(async () => 'never');
    const router = new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: fetchFn as unknown as typeof fetch,
    });
    await expect(router.run(
      'claude',
      async () => { throw new Error('upstream down'); },
      local,
    )).rejects.toThrow('upstream down');

    // The auth check was attempted (and failed), and the local provider never ran.
    expect(fetchFn).toHaveBeenCalled();
    expect(local).not.toHaveBeenCalled();
  });

  it('checkAuth maps responses correctly', async () => {
    const reg = new ProviderRegistry();
    const make = (status: number) => new FailoverRouter(reg, {
      mcpdUrl: 'http://mcpd',
      fetch: (async () => new Response(null, { status })) as unknown as typeof fetch,
    });

    // 2xx → allowed; 401/403 → forbidden; anything else → unreachable (fail closed).
    expect(await make(200).checkAuth('claude')).toBe('allowed');
    expect(await make(204).checkAuth('claude')).toBe('allowed');
    expect(await make(401).checkAuth('claude')).toBe('forbidden');
    expect(await make(403).checkAuth('claude')).toBe('forbidden');
    expect(await make(404).checkAuth('claude')).toBe('unreachable');
    expect(await make(500).checkAuth('claude')).toBe('unreachable');
  });
});