Files
Mdeical_Sur_Report/server/src/ai/ai.service.ts

386 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { BadRequestException, HttpException, Injectable } from '@nestjs/common';
import type { SafeUser } from '../auth/auth.types.js';
import { SettingsService } from '../settings/settings.service.js';
import { aiChatSchema } from './ai.schemas.js';
/**
 * Connection details for the AI provider that requests are forwarded to,
 * as resolved from system settings by `AiService.getActiveProvider`.
 */
interface AiProvider {
  /** Base URL of the provider API; trailing slashes are stripped before it is stored here. */
  endpoint: string;
  /** Secret sent as the `Authorization: Bearer …` credential. */
  apiKey: string;
  /** Model tried first for chat requests. */
  modelName: string;
  /** Optional second model tried when the primary one fails with a retryable error. */
  fallbackModelName?: string;
}
/** Provider HTTP statuses that are eligible for a retry or a model fallback. */
const RETRYABLE_PROVIDER_STATUSES = new Set<number>([
  429,
  500,
  502,
  503,
  504,
]);

/** Waits (in milliseconds) before each retry when AI_PROVIDER_RETRY_DELAYS_MS is not set. */
const DEFAULT_RETRY_DELAYS_MS = [
  600,
  1200,
];

/** Request timeout (in milliseconds) used when AI_PROVIDER_TIMEOUT_MS is not set or invalid. */
const DEFAULT_PROVIDER_TIMEOUT_MS = 45 * 1000;

/** Replacement text model used when AI_KIMI_TEXT_MODEL is not set. */
const DEFAULT_KIMI_TEXT_MODEL = 'kimi-k2.6';

/** Replacement media-capable model used when AI_KIMI_VISION_MODEL is not set. */
const DEFAULT_KIMI_VISION_MODEL = 'kimi-k2.6';
/**
 * Forwards AI requests to the provider configured in system settings.
 *
 * Two provider routes are used: `GET {endpoint}/models` and
 * `POST {endpoint}/chat/completions`, both authenticated with a Bearer API key.
 * Chat requests are retried on retryable statuses and can fall back to a
 * second configured model.
 */
@Injectable()
export class AiService {
  constructor(private readonly settingsService: SettingsService) {}

  /**
   * Lists the model ids reported by the active provider.
   *
   * @param actor User on whose behalf the settings are read.
   * @returns The model ids, the non-secret provider configuration, and the raw provider payload.
   * @throws HttpException carrying the provider status when the provider answers with a non-2xx status.
   */
  async listModels(actor: SafeUser) {
    const provider = await this.getActiveProvider(actor);
    const response = await this.fetchProvider(`${provider.endpoint}/models`, {
      method: 'GET',
      headers: this.headers(provider),
    });
    const payload = await this.parseProviderResponse(response);
    if (!response.ok) {
      throw this.createProviderException(response.status, payload);
    }

    // parseProviderResponse yields null for an empty body, and the JSON may be
    // any shape, so look for `data` only on real objects instead of
    // dereferencing the payload blindly.
    const data: unknown =
      typeof payload === 'object' && payload !== null
        ? (payload as { data?: unknown }).data
        : undefined;
    const models = Array.isArray(data)
      ? data
          .map((model: unknown) =>
            typeof model === 'object' && model !== null
              ? (model as { id?: unknown }).id
              : undefined,
          )
          .filter((id): id is string => typeof id === 'string' && id.length > 0)
      : [];

    return {
      models,
      provider: {
        endpoint: provider.endpoint,
        modelName: provider.modelName,
        fallbackModelName: provider.fallbackModelName || '',
      },
      raw: payload,
    };
  }

  /**
   * Validates a chat request and forwards it to the active provider.
   *
   * @param actor User on whose behalf the settings are read.
   * @param rawInput Unvalidated request body; checked against `aiChatSchema`.
   * @returns The parsed provider response body.
   * @throws BadRequestException when validation fails.
   * @throws HttpException carrying the provider status when the provider answers with a non-2xx status.
   */
  async chat(actor: SafeUser, rawInput: unknown) {
    const result = aiChatSchema.safeParse(rawInput);
    if (!result.success) {
      // NOTE(review): issue messages are concatenated without a separator — confirm this is intentional.
      throw new BadRequestException(result.error.issues.map((issue) => issue.message).join(''));
    }
    const provider = await this.getActiveProvider(actor);
    const input = result.data;
    const modelCandidates = this.selectModelCandidates(provider, input);
    const { response, responsePayload } = await this.fetchChatWithModelFallback(
      provider,
      input,
      modelCandidates,
    );
    if (!response.ok) {
      throw this.createProviderException(response.status, responsePayload);
    }
    return responsePayload;
  }

  /**
   * Resolves the active provider from system settings (secrets included).
   *
   * Falls back to the `kimi` provider key when no active provider is set and
   * strips trailing slashes from the endpoint.
   *
   * @throws BadRequestException when the endpoint, API key, or model name is missing.
   */
  private async getActiveProvider(actor: SafeUser): Promise<AiProvider> {
    const settings = await this.settingsService.getSystemSettings(actor, { includeSecrets: true });
    const activeProvider = settings.activeAiProvider || 'kimi';
    const provider = settings.aiProviders?.[activeProvider];
    const endpoint = provider?.endpoint?.replace(/\/+$/, '') || '';
    const apiKey = provider?.apiKey || '';
    const modelName = provider?.modelName || '';
    const fallbackModelName = provider?.fallbackModelName || '';
    if (!endpoint) {
      throw new BadRequestException('尚未配置 AI 接口地址');
    }
    if (!apiKey) {
      throw new BadRequestException('尚未配置 AI API Key');
    }
    if (!modelName) {
      throw new BadRequestException('尚未配置 AI 模型名称');
    }
    return { endpoint, apiKey, modelName, fallbackModelName };
  }

  /** Builds the JSON + Bearer-auth headers sent with every provider request. */
  private headers(provider: AiProvider) {
    return {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${provider.apiKey}`,
    };
  }

  /**
   * Adjusts the outgoing chat payload for Kimi K2 models.
   *
   * For models matched by `isKimiK2Model`, the sampling parameters
   * (`temperature`, `top_p`, `n`, `presence_penalty`, `frequency_penalty`) are
   * removed — presumably because the provider rejects or ignores them; confirm
   * against the provider documentation. For models that support the thinking
   * toggle, thinking is disabled unless the caller supplied a `thinking` value.
   * Other payloads are returned untouched; the input object is never mutated.
   */
  private normalizeProviderPayload(payload: Record<string, unknown>) {
    const model = typeof payload.model === 'string' ? payload.model : '';
    if (!this.isKimiK2Model(model)) return payload;
    const normalized = { ...payload };
    delete normalized.temperature;
    delete normalized.top_p;
    delete normalized.n;
    delete normalized.presence_penalty;
    delete normalized.frequency_penalty;
    if (this.supportsKimiThinkingToggle(model) && !('thinking' in normalized)) {
      normalized.thinking = { type: 'disabled' };
    }
    return normalized;
  }

  /**
   * Returns the ordered, de-duplicated list of models to try: the resolved
   * primary model followed by the resolved fallback model (when configured).
   */
  private selectModelCandidates(provider: AiProvider, input: Record<string, unknown>) {
    const primaryModel = this.selectModel(provider, input, provider.modelName);
    const fallbackModel = provider.fallbackModelName
      ? this.selectModel(provider, input, provider.fallbackModelName)
      : '';
    return [primaryModel, fallbackModel].filter((model, index, models): model is string => (
      Boolean(model) && models.indexOf(model) === index
    ));
  }

  /**
   * Resolves the model name to send for one configured model.
   *
   * The configured model wins over `input.model`. For endpoints matched by
   * `isMoonshotProvider`, a model that cannot take media input is swapped for
   * the vision model when the messages contain media, and a model matched by
   * `isDeprecatedKimiK2Model` is swapped for the text model. Both replacements
   * can be overridden through `AI_KIMI_VISION_MODEL` / `AI_KIMI_TEXT_MODEL`.
   */
  private selectModel(provider: AiProvider, input: Record<string, unknown>, configuredModel: string) {
    const model = configuredModel || (typeof input.model === 'string' ? input.model : '');
    if (!this.isMoonshotProvider(provider)) return model;
    const hasMedia = this.hasMediaInput(input.messages);
    if (hasMedia && !this.supportsMediaInput(model)) {
      return process.env.AI_KIMI_VISION_MODEL || DEFAULT_KIMI_VISION_MODEL;
    }
    if (this.isDeprecatedKimiK2Model(model)) {
      return process.env.AI_KIMI_TEXT_MODEL || DEFAULT_KIMI_TEXT_MODEL;
    }
    return model;
  }

  /** True when the provider endpoint contains `moonshot.cn` (case-insensitive). */
  private isMoonshotProvider(provider: AiProvider) {
    return /moonshot\.cn/i.test(provider.endpoint);
  }

  /** True when the model name contains `vision` or is one of the multimodal Kimi models. */
  private supportsMediaInput(model: string) {
    return /vision/i.test(model) || this.isKimiMultimodalModel(model);
  }

  /** True for exactly `kimi-k2.5` or `kimi-k2.6` (case-insensitive). */
  private isKimiMultimodalModel(model: string) {
    return /^kimi-k2\.(?:5|6)$/i.test(model);
  }

  /** True when the `thinking` toggle is applied to the model; same set as the multimodal Kimi models. */
  private supportsKimiThinkingToggle(model: string) {
    return this.isKimiMultimodalModel(model);
  }

  /** True for `kimi-k2` itself and any name continuing with `.` or `-` (e.g. `kimi-k2.6`, `kimi-k2-…`). */
  private isKimiK2Model(model: string) {
    return /^kimi-k2(?:[.-]|$)/i.test(model);
  }

  /** True for bare `kimi-k2` and `kimi-k2-…` names; dotted versions such as `kimi-k2.6` do not match. */
  private isDeprecatedKimiK2Model(model: string) {
    return /^kimi-k2(?:-|$)/i.test(model);
  }

  /** True when any message in the array carries image or video content parts. */
  private hasMediaInput(messages: unknown) {
    if (!Array.isArray(messages)) return false;
    return messages.some((message) => {
      if (typeof message !== 'object' || message === null || !('content' in message)) return false;
      return this.hasMediaContent((message as { content?: unknown }).content);
    });
  }

  /**
   * True when `content` is an array containing a part that has an `image_url`
   * or `video_url` key, or whose `type` is `image_url` or `video_url`.
   */
  private hasMediaContent(content: unknown): boolean {
    if (!Array.isArray(content)) return false;
    return content.some((part) => (
      typeof part === 'object' &&
      part !== null &&
      (
        'image_url' in part ||
        'video_url' in part ||
        (part as { type?: unknown }).type === 'image_url' ||
        (part as { type?: unknown }).type === 'video_url'
      )
    ));
  }

  /**
   * Reads the response body and parses it as JSON.
   *
   * @returns `null` for an empty body, the parsed JSON when valid, otherwise
   *          `{ message: <raw text> }` so non-JSON error bodies are still reported.
   */
  private async parseProviderResponse(response: Response) {
    const text = await response.text();
    if (!text) return null;
    try {
      return JSON.parse(text) as unknown;
    } catch {
      return { message: text };
    }
  }

  /**
   * Performs one provider request, aborting it after the configured timeout.
   *
   * The timer is cleared as soon as `fetch` settles, so it bounds the wait for
   * the response to arrive; reading the body afterwards (in
   * `parseProviderResponse`) is not covered by this timeout.
   *
   * @throws HttpException (504, code `AI_PROVIDER_TIMEOUT`) when the request is aborted by the timer.
   * @throws BadRequestException for any other failure raised by `fetch`.
   */
  private async fetchProvider(url: string, init: RequestInit) {
    const timeoutMs = this.providerTimeoutMs();
    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), timeoutMs);
    try {
      return await fetch(url, { ...init, signal: controller.signal });
    } catch (error) {
      if (error instanceof Error && error.name === 'AbortError') {
        throw new HttpException(
          {
            code: 'AI_PROVIDER_TIMEOUT',
            message: `AI 服务响应超时(${Math.round(timeoutMs / 1000)}秒),请稍后重试或缩短报告上下文。`,
          },
          504,
        );
      }
      throw new BadRequestException(`AI 服务连接失败:${error instanceof Error ? error.message : String(error)}`);
    } finally {
      clearTimeout(timeout);
    }
  }

  /**
   * Performs a provider request and repeats it, waiting the configured delay
   * before each attempt, for as long as the response is retryable and delays
   * remain. Errors thrown by `fetchProvider` are not retried here.
   *
   * @returns The last response received, whether or not it succeeded.
   */
  private async fetchProviderWithRetry(url: string, init: RequestInit) {
    const retryDelays = this.retryDelays();
    let response = await this.fetchProvider(url, init);
    for (const delayMs of retryDelays) {
      if (!(await this.shouldRetryResponse(response))) break;
      await this.sleep(delayMs);
      response = await this.fetchProvider(url, init);
    }
    return response;
  }

  /**
   * Sends the chat request to each candidate model in order.
   *
   * A candidate's result is returned as soon as it succeeds, it is the last
   * candidate, or its failure is not eligible for fallback. Thrown errors
   * follow the same rule: they propagate unless a later candidate exists and
   * the error is fallback-eligible.
   *
   * @returns The response together with its parsed body.
   */
  private async fetchChatWithModelFallback(
    provider: AiProvider,
    input: Record<string, unknown>,
    models: string[],
  ) {
    let lastError: unknown;
    for (let index = 0; index < models.length; index += 1) {
      const model = models[index];
      const isLastCandidate = index === models.length - 1;
      try {
        const payload = this.normalizeProviderPayload({
          ...input,
          model,
        });
        const response = await this.fetchProviderWithRetry(`${provider.endpoint}/chat/completions`, {
          method: 'POST',
          headers: this.headers(provider),
          body: JSON.stringify(payload),
        });
        const responsePayload = await this.parseProviderResponse(response);
        if (
          response.ok ||
          isLastCandidate ||
          !this.shouldFallbackFromResponse(response.status, responsePayload)
        ) {
          return { response, responsePayload };
        }
        lastError = this.createProviderException(response.status, responsePayload);
      } catch (error) {
        if (isLastCandidate || !this.shouldFallbackFromError(error)) {
          throw error;
        }
        lastError = error;
      }
    }
    // Only reachable when `models` is empty.
    throw lastError instanceof Error ? lastError : new BadRequestException('AI 服务请求失败');
  }

  /**
   * Decides whether a response should be retried. The body is read from a
   * clone so the original response can still be consumed by the caller.
   */
  private async shouldRetryResponse(response: Response) {
    if (!RETRYABLE_PROVIDER_STATUSES.has(response.status)) return false;
    const payload = await this.parseProviderResponse(response.clone());
    return this.isRetryableProviderPayload(response.status, payload);
  }

  /** True when a failed response justifies trying the next candidate model. */
  private shouldFallbackFromResponse(status: number, payload: unknown) {
    return RETRYABLE_PROVIDER_STATUSES.has(status) && this.isRetryableProviderPayload(status, payload);
  }

  /** True when a thrown HttpException has a status/body that justifies trying the next candidate model. */
  private shouldFallbackFromError(error: unknown) {
    return error instanceof HttpException && this.shouldFallbackFromResponse(error.getStatus(), error.getResponse());
  }

  /** True for retryable statuses unless the payload indicates a quota or balance problem. */
  private isRetryableProviderPayload(status: number, payload: unknown) {
    if (this.isQuotaOrBalanceError(status, payload)) return false;
    return RETRYABLE_PROVIDER_STATUSES.has(status);
  }

  /**
   * Builds the user-facing error message for a failed provider response.
   *
   * The detail is taken from `payload.error` (JSON-encoded), else
   * `payload.message`, else the JSON-encoded payload. A missing payload
   * (`null`, as produced for an empty body) or a null/undefined field yields
   * no detail instead of the literal text "null"/"undefined".
   */
  private formatProviderError(status: number, payload: unknown) {
    let detail = '';
    if (typeof payload === 'object' && payload !== null) {
      const { error, message } = payload as { error?: unknown; message?: unknown };
      if (error !== undefined && error !== null) {
        detail = JSON.stringify(error);
      } else if (message !== undefined && message !== null) {
        // Strings are shown verbatim; anything else is JSON-encoded so objects
        // do not collapse to "[object Object]".
        detail = typeof message === 'string' ? message : JSON.stringify(message);
      } else {
        detail = JSON.stringify(payload);
      }
    } else if (payload !== undefined && payload !== null) {
      detail = JSON.stringify(payload);
    }
    return `AI 服务请求失败:${status}${detail ? ` - ${detail}` : ''}`;
  }

  /** Wraps a failed provider response in an HttpException that keeps the provider's status code. */
  private createProviderException(status: number, payload: unknown) {
    return new HttpException(
      {
        code: this.providerErrorCode(status, payload),
        message: this.formatProviderError(status, payload),
      },
      status,
    );
  }

  /**
   * Maps a failed provider response to an application error code.
   *
   * Quota/balance problems take precedence; then 429 is split into overloaded
   * vs. rate-limited based on the provider's error type; 504 is reported as a
   * timeout before the generic 5xx case.
   */
  private providerErrorCode(status: number, payload: unknown) {
    if (this.isQuotaOrBalanceError(status, payload)) return 'AI_PROVIDER_QUOTA_EXCEEDED';
    const providerType =
      typeof payload === 'object' && payload !== null && 'type' in payload
        ? String((payload as { type: unknown }).type)
        : typeof payload === 'object' && payload !== null && 'error' in payload
          ? this.extractProviderErrorType((payload as { error: unknown }).error)
          : '';
    if (status === 429 && /overloaded/i.test(providerType)) return 'AI_PROVIDER_OVERLOADED';
    if (status === 429) return 'AI_PROVIDER_RATE_LIMITED';
    if (status === 504) return 'AI_PROVIDER_TIMEOUT';
    if (status >= 500) return 'AI_PROVIDER_UNAVAILABLE';
    return 'AI_PROVIDER_ERROR';
  }

  /** True for status 402 or when the error text mentions quota, balance, billing, insufficient, or suspended. */
  private isQuotaOrBalanceError(status: number, payload: unknown) {
    if (status === 402) return true;
    const text = this.providerErrorText(payload);
    return /quota|balance|billing|insufficient|suspended/i.test(text);
  }

  /**
   * Collects the `type`, `message`, and `code` fields from the payload and
   * from its nested `error` (or the `error` string itself) into one
   * space-separated string used for keyword matching.
   */
  private providerErrorText(payload: unknown): string {
    if (typeof payload === 'string') return payload;
    if (typeof payload !== 'object' || payload === null) return '';
    const error = 'error' in payload ? (payload as { error: unknown }).error : null;
    const topLevel = [
      'type' in payload ? (payload as { type: unknown }).type : '',
      'message' in payload ? (payload as { message: unknown }).message : '',
      'code' in payload ? (payload as { code: unknown }).code : '',
    ];
    if (typeof error === 'object' && error !== null) {
      topLevel.push(
        'type' in error ? (error as { type: unknown }).type : '',
        'message' in error ? (error as { message: unknown }).message : '',
        'code' in error ? (error as { code: unknown }).code : '',
      );
    } else if (typeof error === 'string') {
      topLevel.push(error);
    }
    return topLevel.filter(Boolean).join(' ');
  }

  /** Returns the `type` field of a provider error object as a string, or `''` when absent. */
  private extractProviderErrorType(error: unknown) {
    return typeof error === 'object' && error !== null && 'type' in error
      ? String((error as { type: unknown }).type)
      : '';
  }

  /**
   * Returns the waits (ms) between retries.
   *
   * Reads the comma-separated `AI_PROVIDER_RETRY_DELAYS_MS`; an unset or blank
   * value selects the defaults. Empty segments (e.g. `600,,1200` or a trailing
   * comma) are skipped rather than being read as a 0 ms retry, and entries that
   * are not finite, non-negative numbers are dropped. A value with no valid
   * entries yields an empty list, i.e. no retries.
   */
  private retryDelays() {
    const raw = process.env.AI_PROVIDER_RETRY_DELAYS_MS?.trim();
    if (!raw) return DEFAULT_RETRY_DELAYS_MS;
    return raw
      .split(',')
      .map((segment) => segment.trim())
      // Number('') is 0, so blank segments must be removed before conversion.
      .filter((segment) => segment !== '')
      .map((segment) => Number(segment))
      .filter((value) => Number.isFinite(value) && value >= 0);
  }

  /**
   * Returns the request timeout (ms) from `AI_PROVIDER_TIMEOUT_MS`, or the
   * default when the variable is unset, non-numeric, or not positive.
   */
  private providerTimeoutMs() {
    // Timer delays above the signed 32-bit maximum are coerced to 1 ms by
    // Node.js, which would abort every request immediately; clamp instead.
    const maxTimerDelayMs = 2_147_483_647;
    const value = Number(process.env.AI_PROVIDER_TIMEOUT_MS);
    return Number.isFinite(value) && value > 0
      ? Math.min(value, maxTimerDelayMs)
      : DEFAULT_PROVIDER_TIMEOUT_MS;
  }

  /** Resolves after `ms` milliseconds. */
  private sleep(ms: number) {
    return new Promise<void>((resolve) => {
      setTimeout(resolve, ms);
    });
  }
}