386 lines
13 KiB
TypeScript
386 lines
13 KiB
TypeScript
import { BadRequestException, HttpException, Injectable } from '@nestjs/common';
|
||
import type { SafeUser } from '../auth/auth.types.js';
|
||
import { SettingsService } from '../settings/settings.service.js';
|
||
import { aiChatSchema } from './ai.schemas.js';
|
||
|
||
/**
 * Connection settings of the AI provider that is currently active in the system settings.
 */
interface AiProvider {
  /** Model id requested first for every call. */
  modelName: string;
  /** Optional second model, tried only when the primary one keeps failing with a retryable error. */
  fallbackModelName?: string;
  /** Base URL of the provider API; request paths such as `/models` are appended to it. */
  endpoint: string;
  /** Secret sent to the provider as a Bearer token. */
  apiKey: string;
}
|
||
|
||
/** Provider HTTP statuses treated as transient (rate limiting or upstream failure) and eligible for retry/fallback. */
const RETRYABLE_PROVIDER_STATUSES = new Set<number>([429, 500, 502, 503, 504]);

/** Back-off delays in milliseconds between retries, used when AI_PROVIDER_RETRY_DELAYS_MS is not set. */
const DEFAULT_RETRY_DELAYS_MS: number[] = [600, 1200];

/** Per-request provider timeout in milliseconds, used when AI_PROVIDER_TIMEOUT_MS is unset or invalid. */
const DEFAULT_PROVIDER_TIMEOUT_MS = 45 * 1000;

/** Replacement for deprecated kimi-k2 text models on Moonshot endpoints (overridable via AI_KIMI_TEXT_MODEL). */
const DEFAULT_KIMI_TEXT_MODEL = 'kimi-k2.6';

/** Model used on Moonshot endpoints when media is sent to a model without media support (overridable via AI_KIMI_VISION_MODEL). */
const DEFAULT_KIMI_VISION_MODEL = 'kimi-k2.6';
|
||
|
||
/**
 * Proxies model-listing and chat requests to the AI provider that is active in the
 * system settings, calling `${endpoint}/models` and `${endpoint}/chat/completions`.
 *
 * Responsibilities:
 * - resolve the active provider (endpoint, API key, model names) via SettingsService;
 * - choose the model(s) to call, with Moonshot/Kimi-specific substitutions;
 * - retry transient provider failures, then fall back to the configured fallback model;
 * - translate provider failures into HttpExceptions carrying a stable `code` field.
 */
@Injectable()
export class AiService {
  constructor(private readonly settingsService: SettingsService) {}

  /**
   * Lists the model ids exposed by the active provider.
   *
   * @param actor user on whose behalf the system settings (including secrets) are read
   * @returns the string ids found in the provider's `data[].id`, the configured model
   *          names, and the raw parsed provider payload
   * @throws BadRequestException when the provider is not fully configured or cannot be reached
   * @throws HttpException 504 (`AI_PROVIDER_TIMEOUT`) when the provider does not answer in time
   * @throws HttpException with the provider's status when it answers with a non-2xx status
   */
  async listModels(actor: SafeUser) {
    const provider = await this.getActiveProvider(actor);
    const { response, payload } = await this.fetchProvider(`${provider.endpoint}/models`, {
      method: 'GET',
      headers: this.headers(provider),
    });

    if (!response.ok) {
      throw this.createProviderException(response.status, payload);
    }

    return {
      models: this.extractModelIds(payload),
      provider: {
        endpoint: provider.endpoint,
        modelName: provider.modelName,
        fallbackModelName: provider.fallbackModelName || '',
      },
      raw: payload,
    };
  }

  /**
   * Validates a chat request and forwards it to the provider's chat-completions endpoint.
   *
   * Each model candidate is retried on transient failures; if the primary model still
   * fails with a retryable error, the configured fallback model is tried next.
   *
   * @param actor user on whose behalf the system settings (including secrets) are read
   * @param rawInput request body, validated with `aiChatSchema`
   * @returns the provider's parsed response body
   * @throws BadRequestException on invalid input, missing configuration or connection failure
   * @throws HttpException with the provider's status (or 504 on timeout) when the call fails
   */
  async chat(actor: SafeUser, rawInput: unknown) {
    const result = aiChatSchema.safeParse(rawInput);
    if (!result.success) {
      throw new BadRequestException(result.error.issues.map((issue) => issue.message).join(';'));
    }

    const provider = await this.getActiveProvider(actor);
    const input = result.data;
    const modelCandidates = this.selectModelCandidates(provider, input);
    const { response, responsePayload } = await this.fetchChatWithModelFallback(
      provider,
      input,
      modelCandidates,
    );

    if (!response.ok) {
      throw this.createProviderException(response.status, responsePayload);
    }

    return responsePayload;
  }

  /**
   * Reads the active provider's connection settings.
   *
   * Uses the `kimi` entry when no active provider is set, and strips trailing slashes
   * from the endpoint so request paths can be appended with a leading `/`.
   *
   * @throws BadRequestException when the endpoint, API key or model name is missing
   */
  private async getActiveProvider(actor: SafeUser): Promise<AiProvider> {
    const settings = await this.settingsService.getSystemSettings(actor, { includeSecrets: true });
    const activeProvider = settings.activeAiProvider || 'kimi';
    const provider = settings.aiProviders?.[activeProvider];
    const endpoint = provider?.endpoint?.replace(/\/+$/, '') || '';
    const apiKey = provider?.apiKey || '';
    const modelName = provider?.modelName || '';
    const fallbackModelName = provider?.fallbackModelName || '';

    if (!endpoint) {
      throw new BadRequestException('尚未配置 AI 接口地址');
    }
    if (!apiKey) {
      throw new BadRequestException('尚未配置 AI API Key');
    }
    if (!modelName) {
      throw new BadRequestException('尚未配置 AI 模型名称');
    }

    return { endpoint, apiKey, modelName, fallbackModelName };
  }

  /** JSON content type plus Bearer authorization for the given provider. */
  private headers(provider: AiProvider) {
    return {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${provider.apiKey}`,
    };
  }

  /**
   * Adjusts an outgoing chat payload for kimi-k2 family models; other models pass through
   * unchanged (the same object is returned).
   *
   * For kimi-k2 models, a copy is returned without temperature, top_p, n,
   * presence_penalty and frequency_penalty (NOTE(review): presumably because the provider
   * fixes or rejects these for these models — confirm against Moonshot docs). For models
   * that accept a `thinking` toggle, thinking is disabled unless the caller set it.
   */
  private normalizeProviderPayload(payload: Record<string, unknown>) {
    const model = typeof payload.model === 'string' ? payload.model : '';
    if (!this.isKimiK2Model(model)) return payload;

    const normalized = { ...payload };
    delete normalized.temperature;
    delete normalized.top_p;
    delete normalized.n;
    delete normalized.presence_penalty;
    delete normalized.frequency_penalty;

    if (this.supportsKimiThinkingToggle(model) && !('thinking' in normalized)) {
      normalized.thinking = { type: 'disabled' };
    }

    return normalized;
  }

  /**
   * Ordered, de-duplicated list of models to try: the (possibly substituted) primary
   * model, then the (possibly substituted) fallback model when one is configured.
   */
  private selectModelCandidates(provider: AiProvider, input: Record<string, unknown>) {
    const primaryModel = this.selectModel(provider, input, provider.modelName);
    const fallbackModel = provider.fallbackModelName
      ? this.selectModel(provider, input, provider.fallbackModelName)
      : '';

    return [primaryModel, fallbackModel].filter((model, index, models): model is string => (
      Boolean(model) && models.indexOf(model) === index
    ));
  }

  /**
   * Resolves the model id to send for one candidate.
   *
   * Uses `configuredModel`, or `input.model` when nothing is configured. On Moonshot
   * endpoints only: when the messages carry image/video parts the model cannot accept,
   * the vision model is used instead; ids matched by `isDeprecatedKimiK2Model` are
   * replaced with the current text model. Both replacements are overridable via env.
   */
  private selectModel(provider: AiProvider, input: Record<string, unknown>, configuredModel: string) {
    const model = configuredModel || (typeof input.model === 'string' ? input.model : '');
    if (!this.isMoonshotProvider(provider)) return model;

    const hasMedia = this.hasMediaInput(input.messages);
    if (hasMedia && !this.supportsMediaInput(model)) {
      return process.env.AI_KIMI_VISION_MODEL || DEFAULT_KIMI_VISION_MODEL;
    }
    if (this.isDeprecatedKimiK2Model(model)) {
      return process.env.AI_KIMI_TEXT_MODEL || DEFAULT_KIMI_TEXT_MODEL;
    }

    return model;
  }

  /**
   * Whether the endpoint host contains `moonshot.cn`.
   * NOTE(review): other Moonshot hosts (e.g. an international `.ai` domain) are not
   * matched, so Kimi model substitution is skipped for them — confirm this is intended.
   */
  private isMoonshotProvider(provider: AiProvider) {
    return /moonshot\.cn/i.test(provider.endpoint);
  }

  /** Whether the model is assumed to accept image/video parts (name contains "vision", or kimi-k2.5/k2.6). */
  private supportsMediaInput(model: string) {
    return /vision/i.test(model) || this.isKimiMultimodalModel(model);
  }

  /** Exactly `kimi-k2.5` or `kimi-k2.6` (case-insensitive). */
  private isKimiMultimodalModel(model: string) {
    return /^kimi-k2\.(?:5|6)$/i.test(model);
  }

  /** Whether the model accepts the `thinking` request field (currently the multimodal kimi-k2 models). */
  private supportsKimiThinkingToggle(model: string) {
    return this.isKimiMultimodalModel(model);
  }

  /** Any kimi-k2 family id: `kimi-k2`, `kimi-k2.<x>` or `kimi-k2-<suffix>`. */
  private isKimiK2Model(model: string) {
    return /^kimi-k2(?:[.-]|$)/i.test(model);
  }

  /** Bare `kimi-k2` or `kimi-k2-<suffix>` ids, which this service treats as deprecated. */
  private isDeprecatedKimiK2Model(model: string) {
    return /^kimi-k2(?:-|$)/i.test(model);
  }

  /** Whether any message in a chat `messages` array has media content parts. */
  private hasMediaInput(messages: unknown) {
    if (!Array.isArray(messages)) return false;
    return messages.some((message) => {
      if (typeof message !== 'object' || message === null || !('content' in message)) return false;
      return this.hasMediaContent((message as { content?: unknown }).content);
    });
  }

  /** Whether a content-parts array contains an `image_url` or `video_url` part (by key or by `type`). */
  private hasMediaContent(content: unknown): boolean {
    if (!Array.isArray(content)) return false;
    return content.some((part) => (
      typeof part === 'object' &&
      part !== null &&
      (
        'image_url' in part ||
        'video_url' in part ||
        (part as { type?: unknown }).type === 'image_url' ||
        (part as { type?: unknown }).type === 'video_url'
      )
    ));
  }

  /**
   * Reads and parses a response body.
   *
   * @returns `null` for an empty body, the parsed JSON value, or `{ message: text }` when
   *          the body is not valid JSON
   */
  private async parseProviderResponse(response: Response) {
    const text = await response.text();
    if (!text) return null;
    try {
      return JSON.parse(text) as unknown;
    } catch {
      return { message: text };
    }
  }

  /**
   * Pulls the non-empty string ids out of a `/models` payload shaped like
   * `{ data: [{ id }] }`. Any other shape (including an empty-body `null`) yields `[]`.
   */
  private extractModelIds(payload: unknown): string[] {
    if (typeof payload !== 'object' || payload === null) return [];
    const data = (payload as { data?: unknown }).data;
    if (!Array.isArray(data)) return [];

    return data
      .map((entry: unknown) => (
        typeof entry === 'object' && entry !== null ? (entry as { id?: unknown }).id : undefined
      ))
      .filter((id): id is string => typeof id === 'string' && id.length > 0);
  }

  /**
   * Performs one provider request and reads its body, both within the provider timeout.
   *
   * The body is read before the timer is cleared; otherwise a provider that sends headers
   * promptly but stalls the body would escape the timeout entirely.
   *
   * @returns the response (body already consumed) together with its parsed payload
   * @throws HttpException 504 (`AI_PROVIDER_TIMEOUT`) when the timeout elapses
   * @throws BadRequestException when the request or body read fails for any other reason
   */
  private async fetchProvider(
    url: string,
    init: RequestInit,
  ): Promise<{ response: Response; payload: unknown }> {
    const timeoutMs = this.providerTimeoutMs();
    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), timeoutMs);
    try {
      const response = await fetch(url, { ...init, signal: controller.signal });
      const payload = await this.parseProviderResponse(response);
      return { response, payload };
    } catch (error) {
      if (error instanceof Error && error.name === 'AbortError') {
        throw new HttpException(
          {
            code: 'AI_PROVIDER_TIMEOUT',
            message: `AI 服务响应超时(${Math.round(timeoutMs / 1000)}秒),请稍后重试或缩短报告上下文。`,
          },
          504,
        );
      }
      throw new BadRequestException(`AI 服务连接失败:${error instanceof Error ? error.message : String(error)}`);
    } finally {
      clearTimeout(timeout);
    }
  }

  /**
   * Calls `fetchProvider`, re-issuing the same request after each configured delay while
   * the answer is a retryable failure. Thrown errors (timeout, connection failure) are
   * not retried here.
   *
   * @returns the last response and its parsed payload
   */
  private async fetchProviderWithRetry(url: string, init: RequestInit) {
    let result = await this.fetchProvider(url, init);
    for (const delayMs of this.retryDelays()) {
      if (!this.shouldRetryResponse(result.response.status, result.payload)) break;
      await this.sleep(delayMs);
      result = await this.fetchProvider(url, init);
    }
    return result;
  }

  /**
   * Sends the chat request to each model candidate in order until one succeeds.
   *
   * Moves on to the next candidate only for retryable failures (see
   * `shouldFallbackFromResponse` / `shouldFallbackFromError`). The last candidate's
   * response is always returned as-is, and its thrown errors are rethrown.
   *
   * @returns the final response and its parsed payload (the response may be non-2xx)
   */
  private async fetchChatWithModelFallback(
    provider: AiProvider,
    input: Record<string, unknown>,
    models: string[],
  ) {
    let lastError: unknown;

    for (let index = 0; index < models.length; index += 1) {
      const model = models[index];
      const isLastCandidate = index === models.length - 1;
      try {
        const payload = this.normalizeProviderPayload({
          ...input,
          model,
        });
        const { response, payload: responsePayload } = await this.fetchProviderWithRetry(
          `${provider.endpoint}/chat/completions`,
          {
            method: 'POST',
            headers: this.headers(provider),
            body: JSON.stringify(payload),
          },
        );

        if (
          response.ok ||
          isLastCandidate ||
          !this.shouldFallbackFromResponse(response.status, responsePayload)
        ) {
          return { response, responsePayload };
        }

        lastError = this.createProviderException(response.status, responsePayload);
      } catch (error) {
        if (isLastCandidate || !this.shouldFallbackFromError(error)) {
          throw error;
        }
        lastError = error;
      }
    }

    // Only reachable when `models` is empty.
    throw lastError instanceof Error ? lastError : new BadRequestException('AI 服务请求失败');
  }

  /** Whether a response warrants another attempt against the same model (same rule as model fallback). */
  private shouldRetryResponse(status: number, payload: unknown) {
    return this.shouldFallbackFromResponse(status, payload);
  }

  /** Retryable status and not a quota/balance problem. */
  private shouldFallbackFromResponse(status: number, payload: unknown) {
    return RETRYABLE_PROVIDER_STATUSES.has(status) && this.isRetryableProviderPayload(status, payload);
  }

  /** Applies the response fallback rule to a thrown HttpException (e.g. the 504 timeout). */
  private shouldFallbackFromError(error: unknown) {
    return error instanceof HttpException && this.shouldFallbackFromResponse(error.getStatus(), error.getResponse());
  }

  /** Quota/balance failures are never retried; otherwise the status decides. */
  private isRetryableProviderPayload(status: number, payload: unknown) {
    if (this.isQuotaOrBalanceError(status, payload)) return false;
    return RETRYABLE_PROVIDER_STATUSES.has(status);
  }

  /**
   * Human-readable error message for a failed provider call.
   *
   * The detail after the status is the serialized `error` field, the `message` field, or
   * the whole payload. It is omitted when there is no payload: an empty body parses to
   * `null`, and a bare " - null" would be noise.
   */
  private formatProviderError(status: number, payload: unknown) {
    const message = this.describeProviderPayload(payload);
    return `AI 服务请求失败:${status}${message ? ` - ${message}` : ''}`;
  }

  /** Short textual detail for a provider payload, or '' / undefined when there is nothing to show. */
  private describeProviderPayload(payload: unknown): string | undefined {
    if (payload === null || payload === undefined) return '';
    if (typeof payload === 'object' && 'error' in payload) {
      return JSON.stringify((payload as { error: unknown }).error);
    }
    if (typeof payload === 'object' && 'message' in payload) {
      return String((payload as { message: unknown }).message);
    }
    return JSON.stringify(payload);
  }

  /** Wraps a provider failure in an HttpException with the provider's status and a `{ code, message }` body. */
  private createProviderException(status: number, payload: unknown) {
    return new HttpException(
      {
        code: this.providerErrorCode(status, payload),
        message: this.formatProviderError(status, payload),
      },
      status,
    );
  }

  /**
   * Maps a provider failure to a stable error code: quota/balance, overloaded (429 with an
   * "overloaded" type), rate limited (429), timeout (504), unavailable (other 5xx), or generic.
   */
  private providerErrorCode(status: number, payload: unknown) {
    if (this.isQuotaOrBalanceError(status, payload)) return 'AI_PROVIDER_QUOTA_EXCEEDED';

    const providerType =
      typeof payload === 'object' && payload !== null && 'type' in payload
        ? String((payload as { type: unknown }).type)
        : typeof payload === 'object' && payload !== null && 'error' in payload
          ? this.extractProviderErrorType((payload as { error: unknown }).error)
          : '';

    if (status === 429 && /overloaded/i.test(providerType)) return 'AI_PROVIDER_OVERLOADED';
    if (status === 429) return 'AI_PROVIDER_RATE_LIMITED';
    if (status === 504) return 'AI_PROVIDER_TIMEOUT';
    if (status >= 500) return 'AI_PROVIDER_UNAVAILABLE';
    return 'AI_PROVIDER_ERROR';
  }

  /** 402, or error text mentioning quota/balance/billing/insufficient/suspended. */
  private isQuotaOrBalanceError(status: number, payload: unknown) {
    if (status === 402) return true;
    const text = this.providerErrorText(payload);
    return /quota|balance|billing|insufficient|suspended/i.test(text);
  }

  /**
   * Concatenates the type/message/code fields found at the payload's top level and inside
   * its `error` object (or a string `error`), for keyword matching.
   */
  private providerErrorText(payload: unknown): string {
    if (typeof payload === 'string') return payload;
    if (typeof payload !== 'object' || payload === null) return '';

    const error = 'error' in payload ? (payload as { error: unknown }).error : null;
    const topLevel = [
      'type' in payload ? (payload as { type: unknown }).type : '',
      'message' in payload ? (payload as { message: unknown }).message : '',
      'code' in payload ? (payload as { code: unknown }).code : '',
    ];

    if (typeof error === 'object' && error !== null) {
      topLevel.push(
        'type' in error ? (error as { type: unknown }).type : '',
        'message' in error ? (error as { message: unknown }).message : '',
        'code' in error ? (error as { code: unknown }).code : '',
      );
    } else if (typeof error === 'string') {
      topLevel.push(error);
    }

    return topLevel.filter(Boolean).join(' ');
  }

  /** `error.type` as a string when `error` is an object that has one, else ''. */
  private extractProviderErrorType(error: unknown) {
    return typeof error === 'object' && error !== null && 'type' in error
      ? String((error as { type: unknown }).type)
      : '';
  }

  /**
   * Retry back-off delays in milliseconds.
   *
   * Read from the comma-separated AI_PROVIDER_RETRY_DELAYS_MS, or the defaults when that is
   * unset or blank. Blank segments (e.g. "600,,1200" or a trailing comma) are skipped rather
   * than read as 0 ms. Negative or non-numeric segments are dropped, so a value with no valid
   * entries disables retries.
   */
  private retryDelays() {
    const raw = process.env.AI_PROVIDER_RETRY_DELAYS_MS?.trim();
    if (!raw) return DEFAULT_RETRY_DELAYS_MS;
    return raw
      .split(',')
      .map((value) => value.trim())
      .filter((value) => value !== '')
      .map((value) => Number(value))
      .filter((value) => Number.isFinite(value) && value >= 0);
  }

  /** AI_PROVIDER_TIMEOUT_MS when it is a positive finite number, else the default. */
  private providerTimeoutMs() {
    const value = Number(process.env.AI_PROVIDER_TIMEOUT_MS);
    return Number.isFinite(value) && value > 0 ? value : DEFAULT_PROVIDER_TIMEOUT_MS;
  }

  /** Resolves after `ms` milliseconds. */
  private sleep(ms: number) {
    return new Promise<void>((resolve) => setTimeout(resolve, ms));
  }
}
|