import { pipeline, env, RawImage, Pipeline } from 'sillytavern-transformers';
import { getConfigValue } from './util.js';
import path from 'path';
import fs from 'fs';

configureTransformers();

/**
 * Applies the module-wide transformers.js settings for the ONNX WASM backend.
 * Called once when this module is first loaded.
 */
function configureTransformers() {
    // Limit the number of threads to 1 to avoid issues on Android
    env.backends.onnx.wasm.numThreads = 1;
    // Use WASM from a local folder to avoid CDN connections
    env.backends.onnx.wasm.wasmPaths = path.join(process.cwd(), 'dist') + path.sep;
}

/**
 * Registry of supported tasks, keyed by transformers.js task name.
 * Each entry holds:
 * - defaultModel: model used when the config field is unset or unreadable
 * - pipeline: cached pipeline instance, or null when none is loaded
 * - currentModel: name of the model backing the cached pipeline, or null
 * - configField: config key that can override the default model
 * - quantized: value passed as the `quantized` option when creating the pipeline
 */
const tasks = {
    'text-classification': {
        defaultModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
        pipeline: null,
        currentModel: null,
        configField: 'extras.classificationModel',
        quantized: true,
    },
    'image-to-text': {
        defaultModel: 'Xenova/vit-gpt2-image-captioning',
        pipeline: null,
        currentModel: null,
        configField: 'extras.captioningModel',
        quantized: true,
    },
    'feature-extraction': {
        defaultModel: 'Xenova/all-mpnet-base-v2',
        pipeline: null,
        currentModel: null,
        configField: 'extras.embeddingModel',
        quantized: true,
    },
    'text-generation': {
        defaultModel: 'Cohee/fooocus_expansion-onnx',
        pipeline: null,
        currentModel: null,
        configField: 'extras.promptExpansionModel',
        quantized: false,
    },
    'automatic-speech-recognition': {
        defaultModel: 'Xenova/whisper-small',
        pipeline: null,
        currentModel: null,
        configField: 'extras.speechToTextModel',
        quantized: true,
    },
    'text-to-speech': {
        defaultModel: 'Xenova/speecht5_tts',
        pipeline: null,
        currentModel: null,
        configField: 'extras.textToSpeechModel',
        quantized: false,
    },
};

/**
 * Gets a RawImage object from a base64-encoded image.
 * @param {string} image Base64-encoded image
 * @returns {Promise<RawImage|null>} Object representing the image, or null if it could not be decoded
 */
async function getRawImage(image) {
    try {
        const buffer = Buffer.from(image, 'base64');
        const byteArray = new Uint8Array(buffer);
        const blob = new Blob([byteArray]);
        const rawImage = await RawImage.fromBlob(blob);
        return rawImage;
    } catch (error) {
        // Callers rely on the null return value; log the cause instead of hiding it.
        console.warn('Failed to create a RawImage from the provided base64 data.', error);
        return null;
    }
}

/**
 * Gets the model to use for a given transformers.js task.
 * Falls back to the task's default model when the config field is empty
 * or reading the config value throws.
 * @param {string} task The task to get the model for
 * @returns {string} The model to use for the given task
 */
function getModelForTask(task) {
    const defaultModel = tasks[task].defaultModel;

    try {
        const model = getConfigValue(tasks[task].configField, null);
        return model || defaultModel;
    } catch (error) {
        console.warn('Failed to read config.yaml, using default model for task', task, error);
        return defaultModel;
    }
}

/**
 * Moves model cache entries from the legacy `cache` folder in the working directory
 * to the `_cache` folder under the data root, creating the latter if it is missing.
 * An entry that fails to move is logged and skipped.
 * @returns {Promise<void>}
 */
async function migrateCacheToDataDir() {
    const oldCacheDir = path.join(process.cwd(), 'cache');
    const newCacheDir = path.join(global.DATA_ROOT, '_cache');

    if (!fs.existsSync(newCacheDir)) {
        fs.mkdirSync(newCacheDir, { recursive: true });
    }

    if (fs.existsSync(oldCacheDir) && fs.statSync(oldCacheDir).isDirectory()) {
        const files = fs.readdirSync(oldCacheDir);

        if (files.length === 0) {
            return;
        }

        console.log('Migrating model cache files to data directory. Please wait...');

        for (const file of files) {
            try {
                const oldPath = path.join(oldCacheDir, file);
                const newPath = path.join(newCacheDir, file);
                // Copy, then delete the source entry.
                fs.cpSync(oldPath, newPath, { recursive: true, force: true });
                fs.rmSync(oldPath, { recursive: true, force: true });
            } catch (error) {
                console.warn('Failed to migrate cache file. The model will be re-downloaded.', error);
            }
        }
    }
}

/**
 * Gets the transformers.js pipeline for a given task.
 * A cached pipeline is reused unless a different model is explicitly requested,
 * in which case the cached one is disposed and a new one is created.
 * @param {import('sillytavern-transformers').PipelineType} task The task to get the pipeline for
 * @param {string} forceModel The model to use for the pipeline, if any
 * @returns {Promise<Pipeline>} Pipeline for the task
 */
async function getPipeline(task, forceModel = '') {
    await migrateCacheToDataDir();

    const entry = tasks[task];

    if (entry.pipeline) {
        if (forceModel === '' || entry.currentModel === forceModel) {
            return entry.pipeline;
        }

        console.log('Disposing transformers.js pipeline for task', task, 'with model', entry.currentModel);

        // Clear the cache before disposing, so a failure in dispose() or in the
        // initialization below can never leave a disposed pipeline cached for later calls.
        const stalePipeline = entry.pipeline;
        entry.pipeline = null;
        entry.currentModel = null;
        await stalePipeline.dispose();
    }

    const cacheDir = path.join(global.DATA_ROOT, '_cache');
    const model = forceModel || getModelForTask(task);
    const localOnly = getConfigValue('extras.disableAutoDownload', false);

    console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
    const instance = await pipeline(task, model, {
        cache_dir: cacheDir,
        quantized: entry.quantized ?? true,
        local_files_only: localOnly,
    });

    // Update the cache only after the pipeline was created successfully.
    entry.pipeline = instance;
    entry.currentModel = model;
    return instance;
}

export default {
    getPipeline,
    getRawImage,
};