.
├─.changeset
│ └─config.json
├─0.8-migration-guide.md
├─CODE_OF_CONDUCT.md
├─CONTRIBUTING.md
├─LICENSE
├─NOTICE
├─README.md
├─examples
│ ├─.env.example
│ ├─Dockerfile-example
│ ├─browser
│ │ ├─browser_track.py
│ │ └─standalone_app.py
│ ├─conversation_persistor.py
│ ├─echo-agent.py
│ ├─hive-moderation-agent
│ │ ├─README.md
│ │ ├─agent.py
│ │ ├─hive_data_classes.py
│ │ └─requirements.txt
│ ├─minimal_worker.py
│ ├─multimodal-agent
│ │ ├─gemini_agent.py
│ │ └─openai_agent.py
│ ├─participant-entrypoint
│ │ ├─README.md
│ │ ├─participant_entrypoint.py
│ │ └─requirements.txt
│ ├─simple-color
│ │ ├─README.md
│ │ ├─agent.py
│ │ └─requirements.txt
│ ├─speech-to-text
│ │ ├─README.md
│ │ ├─requirements.txt
│ │ └─transcriber.py
│ ├─text-to-speech
│ │ ├─README.md
│ │ ├─cartesia_tts.py
│ │ ├─elevenlabs_tts.py
│ │ ├─neuphonic_tts.py
│ │ ├─openai_tts.py
│ │ ├─requirements.txt
│ │ └─sync_tts_transcription.py
│ └─voice-pipeline-agent
│ ├─README.md
│ ├─cost_metrics.py
│ ├─custom_pronunciation.py
│ ├─fallback_adapter.py
│ ├─function_calling_weather.py
│ ├─gemini_voice_agent.py
│ ├─llamaindex-rag
│ │ ├─README.md
│ │ ├─chat_engine.py
│ │ ├─data
│ │ │ └─raw_data.txt
│ │ ├─query_engine.py
│ │ └─retrieval.py
│ ├─minimal_assistant.py
│ ├─openai_assistant.py
│ ├─requirements.txt
│ ├─save_chatctx.py
│ ├─simple-rag
│ │ ├─assistant.py
│ │ ├─build_data.py
│ │ └─raw_data.txt
│ └─turn_detector.py
├─livekit-agents
│ ├─CHANGELOG.md
│ ├─README.md
│ ├─livekit
│ │ └─agents
│ │ ├─__init__.py
│ │ ├─_exceptions.py
│ │ ├─cli
│ │ │ ├─__init__.py
│ │ │ ├─cli.py
│ │ │ ├─log.py
│ │ │ ├─proto.py
│ │ │ └─watcher.py
│ │ ├─http_server.py
│ │ ├─inference_runner.py
│ │ ├─ipc
│ │ │ ├─__init__.py
│ │ │ ├─channel.py
│ │ │ ├─inference_executor.py
│ │ │ ├─inference_proc_executor.py
│ │ │ ├─inference_proc_lazy_main.py
│ │ │ ├─job_executor.py
│ │ │ ├─job_proc_executor.py
│ │ │ ├─job_proc_lazy_main.py
│ │ │ ├─job_thread_executor.py
│ │ │ ├─log_queue.py
│ │ │ ├─proc_client.py
│ │ │ ├─proc_pool.py
│ │ │ ├─proto.py
│ │ │ └─supervised_proc.py
│ │ ├─job.py
│ │ ├─llm
│ │ │ ├─__init__.py
│ │ │ ├─chat_context.py
│ │ │ ├─fallback_adapter.py
│ │ │ ├─function_context.py
│ │ │ └─llm.py
│ │ ├─log.py
│ │ ├─metrics
│ │ │ ├─__init__.py
│ │ │ ├─base.py
│ │ │ ├─usage_collector.py
│ │ │ └─utils.py
│ │ ├─multimodal
│ │ │ ├─__init__.py
│ │ │ ├─agent_playout.py
│ │ │ └─multimodal_agent.py
│ │ ├─pipeline
│ │ │ ├─__init__.py
│ │ │ ├─agent_output.py
│ │ │ ├─agent_playout.py
│ │ │ ├─human_input.py
│ │ │ ├─log.py
│ │ │ ├─pipeline_agent.py
│ │ │ ├─plotter.py
│ │ │ └─speech_handle.py
│ │ ├─plugin.py
│ │ ├─py.typed
│ │ ├─stt
│ │ │ ├─__init__.py
│ │ │ ├─fallback_adapter.py
│ │ │ ├─stream_adapter.py
│ │ │ └─stt.py
│ │ ├─tokenize
│ │ │ ├─__init__.py
│ │ │ ├─_basic_hyphenator.py
│ │ │ ├─_basic_paragraph.py
│ │ │ ├─_basic_sent.py
│ │ │ ├─_basic_word.py
│ │ │ ├─basic.py
│ │ │ ├─token_stream.py
│ │ │ ├─tokenizer.py
│ │ │ └─utils.py
│ │ ├─transcription
│ │ │ ├─__init__.py
│ │ │ ├─_utils.py
│ │ │ ├─stt_forwarder.py
│ │ │ └─tts_forwarder.py
│ │ ├─tts
│ │ │ ├─__init__.py
│ │ │ ├─fallback_adapter.py
│ │ │ ├─stream_adapter.py
│ │ │ └─tts.py
│ │ ├─types.py
│ │ ├─utils
│ │ │ ├─__init__.py
│ │ │ ├─_message_change.py
│ │ │ ├─aio
│ │ │ │ ├─__init__.py
│ │ │ │ ├─channel.py
│ │ │ │ ├─debug.py
│ │ │ │ ├─duplex_unix.py
│ │ │ │ ├─interval.py
│ │ │ │ ├─itertools.py
│ │ │ │ ├─sleep.py
│ │ │ │ └─task_set.py
│ │ │ ├─audio.py
│ │ │ ├─codecs
│ │ │ │ ├─__init__.py
│ │ │ │ └─decoder.py
│ │ │ ├─connection_pool.py
│ │ │ ├─exp_filter.py
│ │ │ ├─http_context.py
│ │ │ ├─hw
│ │ │ │ ├─__init__.py
│ │ │ │ └─cpu.py
│ │ │ ├─images
│ │ │ │ ├─__init__.py
│ │ │ │ └─image.py
│ │ │ ├─log.py
│ │ │ ├─misc.py
│ │ │ └─moving_average.py
│ │ ├─vad.py
│ │ ├─version.py
│ │ ├─voice_assistant
│ │ │ └─__init__.py
│ │ └─worker.py
│ ├─package.json
│ ├─pyproject.toml
│ └─setup.py
├─livekit-plugins
│ ├─install_local.sh
│ ├─install_plugins_editable.sh
│ ├─livekit-plugins-anthropic
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─anthropic
│ │ │ ├─__init__.py
│ │ │ ├─llm.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-assemblyai
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─assemblyai
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-aws
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─aws
│ │ │ ├─__init__.py
│ │ │ ├─_utils.py
│ │ │ ├─llm.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-azure
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─azure
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-browser
│ │ ├─.clang-format
│ │ ├─CHANGELOG.md
│ │ ├─CMakeLists.txt
│ │ ├─LICENSE.txt
│ │ ├─README.md
│ │ ├─cmake
│ │ │ └─DownloadCEF.cmake
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─browser
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─proc.py
│ │ │ ├─proc_main.py
│ │ │ ├─proto.py
│ │ │ ├─py.typed
│ │ │ ├─resources
│ │ │ │ └─__init__.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ ├─setup.py
│ │ └─src
│ │ ├─CMakeLists.txt
│ │ ├─agents_python.cpp
│ │ ├─agents_python.hpp
│ │ ├─app.cpp
│ │ ├─app.hpp
│ │ ├─app_mac.mm
│ │ ├─browser_handle.cpp
│ │ ├─browser_handle.hpp
│ │ ├─dev_renderer.cpp
│ │ ├─dev_renderer.hpp
│ │ ├─dummy.cpp
│ │ ├─gleq.h
│ │ ├─handler.cpp
│ │ ├─handler.hpp
│ │ ├─helper_main_linux.cpp
│ │ ├─helper_main_mac.mm
│ │ ├─helper_main_win.cpp
│ │ ├─keyboard_codes.h
│ │ ├─resources
│ │ │ ├─lkcefapp-Info.plist
│ │ │ └─lkcefhelper-Info.plist
│ │ └─run_browser.py
│ ├─livekit-plugins-cartesia
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─cartesia
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-clova
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─clova
│ │ │ ├─__init__.py
│ │ │ ├─common.py
│ │ │ ├─constants.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─stt.py
│ │ │ └─version.py
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-deepgram
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─deepgram
│ │ │ ├─__init__.py
│ │ │ ├─_utils.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-elevenlabs
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─elevenlabs
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-fal
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─fal
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-google
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─google
│ │ │ ├─__init__.py
│ │ │ ├─_utils.py
│ │ │ ├─beta
│ │ │ │ ├─__init__.py
│ │ │ │ └─realtime
│ │ │ │ ├─__init__.py
│ │ │ │ ├─api_proto.py
│ │ │ │ ├─realtime_api.py
│ │ │ │ └─transcriber.py
│ │ │ ├─llm.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-groq
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─groq
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─services.py
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-llama-index
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─llama_index
│ │ │ ├─__init__.py
│ │ │ ├─llm.py
│ │ │ ├─log.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-minimal
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─minimal
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-neuphonic
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─neuphonic
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-nltk
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─nltk
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ ├─sentence_tokenizer.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-openai
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─openai
│ │ │ ├─__init__.py
│ │ │ ├─_oai_api.py
│ │ │ ├─beta
│ │ │ │ ├─__init__.py
│ │ │ │ └─assistant_llm.py
│ │ │ ├─embeddings.py
│ │ │ ├─llm.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─realtime
│ │ │ │ ├─__init__.py
│ │ │ │ ├─api_proto.py
│ │ │ │ ├─log.py
│ │ │ │ ├─realtime_model.py
│ │ │ │ └─remote_items.py
│ │ │ ├─stt.py
│ │ │ ├─tts.py
│ │ │ ├─utils.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-playai
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─playai
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-rag
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─rag
│ │ │ ├─__init__.py
│ │ │ ├─annoy.py
│ │ │ ├─chunking.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-resemble
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─resemble
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-rime
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─rime
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─models.py
│ │ │ ├─py.typed
│ │ │ ├─tts.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-silero
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─silero
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─onnx_model.py
│ │ │ ├─py.typed
│ │ │ ├─resources
│ │ │ │ ├─__init__.py
│ │ │ │ └─silero_vad.onnx
│ │ │ ├─vad.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─pyproject.toml
│ │ └─setup.py
│ ├─livekit-plugins-speechmatics
│ │ ├─CHANGELOG.md
│ │ ├─README.md
│ │ ├─livekit
│ │ │ └─plugins
│ │ │ └─speechmatics
│ │ │ ├─__init__.py
│ │ │ ├─log.py
│ │ │ ├─py.typed
│ │ │ ├─stt.py
│ │ │ ├─types.py
│ │ │ ├─utils.py
│ │ │ └─version.py
│ │ ├─package.json
│ │ ├─project.toml
│ │ └─setup.py
│ └─livekit-plugins-turn-detector
│ ├─CHANGELOG.md
│ ├─README.md
│ ├─livekit
│ │ └─plugins
│ │ └─turn_detector
│ │ ├─__init__.py
│ │ ├─base.py
│ │ ├─english.py
│ │ ├─log.py
│ │ ├─models.py
│ │ ├─multilingual.py
│ │ └─version.py
│ ├─package.json
│ ├─pyproject.toml
│ └─setup.py
├─mypy.ini
├─package.json
├─pnpm-lock.yaml
├─pnpm-workspace.yaml
├─ruff.toml
└─tests
├─__init__.py
├─conftest.py
├─fake_stt.py
├─fake_tts.py
├─hearts.rgba
├─long.mp3
├─long_synthesize.txt
├─long_transcript.txt
├─pytest.ini
├─test-requirements.txt
├─test_aio.py
├─test_build_func_desc.py
├─test_connection_pool.py
├─test_create_func.py
├─test_decoder.py
├─test_ipc.py
├─test_llm.py
├─test_message_change.py
├─test_stt.py
├─test_stt_fallback.py
├─test_tokenizer.py
├─test_tts.py
├─test_tts_fallback.py
├─test_vad.py
└─utils.py
## .changeset/config.json

```json
{
"$schema": "https://unpkg.com/@changesets/config@2.2.0/schema.json",
"changelog": [
"@livekit/changesets-changelog-github",
{
"repo": "livekit/agents"
}
],
"commit": false,
"fixed": [],
"linked": [],
"access": "public",
"baseBranch": "main",
"updateInternalDependencies": "patch",
"privatePackages": { "version": true, "tag": true }
}
```

## 0.8-migration-guide.md
# Migrating to 0.8.x
v0.8 is a major release of the framework, featuring significant reliability improvements to VoiceAssistant. This update includes a few breaking API changes that will impact the way you build your agents. We strive to minimize breaking changes, and will stabilize the API as we approach version 1.0.
## Job and Worker API
### Specifying your entrypoint function
`entrypoint_fnc` is now a parameter in WorkerOptions. Previously, you were required to explicitly accept the job.
### Namespace has been removed
We've removed the namespace option in order to simplify the registration process. In future versions, it'll be possible to provide an explicit `agent_name` to launch multiple kinds of agents for each room.
### Connecting to room is explicit
You now need to call `await ctx.connect()` to initiate the connection to the room. This allows for pre-connect setup (such as callback registrations) to avoid race conditions.
### Example
The above changes are reflected in the following minimal example:
```python
from livekit.agents import JobContext, JobRequest, WorkerOptions, cli
async def job_entrypoint(ctx: JobContext):
await ctx.connect()
# your logic here
...
if __name__ == "__main__":
cli.run_app(
WorkerOptions(entrypoint_fnc=job_entrypoint)
    )
```
## VoiceAssistant

The VoiceAssistant API remains mostly unchanged, despite significant improvements to functionality and internals. However, there have been changes to the configuration, including the new `transcription` parameter:

```python
class VoiceAssistant(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
vad: vad.VAD,
stt: stt.STT,
llm: LLM,
tts: tts.TTS,
chat_ctx: ChatContext | None = None,
fnc_ctx: FunctionContext | None = None,
allow_interruptions: bool = True,
interrupt_speech_duration: float = 0.6,
interrupt_min_words: int = 0,
preemptive_synthesis: bool = True,
transcription: AssistantTranscriptionOptions = AssistantTranscriptionOptions(),
will_synthesize_assistant_reply: WillSynthesizeAssistantReply = _default_will_synthesize_assistant_reply,
plotting: bool = False,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
        ...
```
## LLM

The LLM class has been restructured to enhance ergonomics and improve function calling support.
Function calling has gotten a complete overhaul in v0.8.0. The primary breaking change is that function calls are now NOT automatically invoked when iterating the LLM stream; `LLMStream.execute_functions()` needs to be called instead (VoiceAssistant handles this automatically).
Previously, `LLM.chat()` was an async method that returned an LLMStream (which itself is an AsyncIterable). We found it easier and less confusing for `LLM.chat()` to be synchronous, while still returning the same AsyncIterable LLMStream.

`history` has been renamed to `chat_ctx` in order to improve consistency and reduce confusion.

```python
chat_ctx = llm.ChatContext()
chat_ctx.append(role="user", text="user message")
stream = llm_plugin.chat(chat_ctx=chat_ctx)
```
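Putting these changes together, a rough sketch of the new flow (assuming `llm_plugin` is an LLM instance and `fnc_ctx` an optional FunctionContext, as in the snippet above) could look like this:

```python
from livekit.agents import llm


async def generate_reply(llm_plugin: llm.LLM, fnc_ctx=None) -> None:
    chat_ctx = llm.ChatContext()
    chat_ctx.append(role="user", text="user message")

    # chat() is now synchronous and returns the same AsyncIterable LLMStream
    stream = llm_plugin.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx)

    async for chunk in stream:
        ...  # consume streamed chunks as they arrive

    # function calls are no longer invoked automatically while iterating;
    # execute them explicitly (VoiceAssistant handles this for you)
    stream.execute_functions()
```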
## STT

### flush()

Previously, to communicate to an STT provider that you had pushed enough audio for it to produce a result, you could call `push_frame(None)`. In v0.8.0 that API has been removed and replaced with `flush()`.

### end_input()

`end_input()` signals to the STT provider that the input is complete and no additional input will follow. Previously, this was done using `aclose(wait=True)`.

### aclose()

The `wait` arg of `aclose()` has been removed in favor of `SpeechStream.end_input()` (see above). Now, if you call `aclose()` without first calling `end_input()`, the in-flight request is cancelled.
```python
stt_stream = my_stt_instance.stream()
async for ev in audio_stream:
    stt_stream.push_frame(ev.frame)
    # optionally flush when enough frames have been pushed
    stt_stream.flush()

stt_stream.end_input()
await stt_stream.aclose()
```
## TTS

The SynthesizedAudio dataclass has gone through a major change:
```python
# New SynthesizedAudio dataclass
@dataclass
class SynthesizedAudio:
    request_id: str
    """Request ID (one segment could be made up of multiple requests)"""
    segment_id: str
    """Segment ID, each segment is separated by a flush"""
    frame: rtc.AudioFrame
    """Synthesized audio frame"""
    delta_text: str = ""
    """Current segment of the synthesized audio"""


# Old SynthesizedAudio dataclass
@dataclass
class SynthesizedAudio:
    text: str
    data: rtc.AudioFrame
```
The SynthesisEvent has been removed entirely. All occurrences of it have been replaced with SynthesizedAudio.

### flush()

Similar to the STT changes, this coaxes the TTS provider into generating a response. The SynthesizedAudio response will have a new `segment_id` after calls to `flush()`.

### end_input()

Similar to the STT changes, `aclose(wait=True)` has been replaced by `end_input()`.

### aclose()

Similar to the STT changes, the `wait` arg has been removed.
```python
tts_stream = my_tts_instance.stream()
tts_stream.push_text("This is the first sentence")
tts_stream.flush()
tts_stream.push_text("This is the second sentence")
tts_stream.end_input()
await tts_stream.aclose()
```
## VAD

The same changes made to STT and TTS have also been made to VAD:
```python
vad_stream = my_vad_instance.stream()
async for ev in audio_stream:
    vad_stream.push_frame(ev.frame)
    # optionally flush when enough frames have been pushed
    vad_stream.flush()

vad_stream.end_input()
await vad_stream.aclose()
```
## CODE_OF_CONDUCT.md
```md
# Code of Conduct
## Our Pledge
We are committed to providing a welcoming, respectful, and harassment-free
environment for everyone, regardless of background, experience, or identity. We
strive to foster a positive and inclusive community where all participants feel
valued and empowered to contribute.
## Our Standards
### Expected behavior
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
### Unacceptable behavior
* Harassment, discrimination, or offensive comments regarding identity,
appearance, or background
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Personal attacks, insults, or disruptive behavior that undermines the
community
* Posting content or engaging in activities that are inappropriate, unlawful, or
harmful
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official email address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
<conduct@livekit.io>.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Violations of this Code of Conduct may result in removal from the community,
project, or repository. Severe violations may result in a permanent ban.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
It has been subtly adapted for formatting and brevity, as well as changing the
actions taken after a violation.
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
```

## CONTRIBUTING.md

# Contributing to livekit/agents
The LiveKit Agents framework is an open-source project, and we welcome any contribution from anyone
willing to work in good faith with the community. No contribution is too small!
## Code of Conduct
The LiveKit Agents project has a [Code of Conduct](/CODE_OF_CONDUCT.md) to which all contributors
must adhere.
## Contribute code
There are many ways you can contribute code to the project:
- **Write a plugin**: if there is a TTS/STT/LLM provider you use that isn't on our plugins list,
feel free to write a plugin for it! Refer to the source code of similar plugins to see how they're
built.
- **Fix bugs**: we strive to make this framework as reliable as possible, and we'd appreciate your
help with squashing bugs and improving stability. Follow the guidelines below for information
about authoring pull requests.
- **Add new features**: we're open to adding new features to the framework, though we ask that you
open an issue first to discuss the viability and scope of the new functionality before starting
work.
Our continuous integration requires a few additional code quality steps for your pull request to
be approved:
- Run `ruff check --fix` and `ruff format` before committing your changes to ensure consistent file
formatting and best practices.
- If writing new methods/enums/classes, document them. This project uses
[pdoc3](https://pdoc3.github.io/pdoc/) for automatic API documentation generation, and every new
addition has to be properly documented (see the docstring sketch after this list).
- On your first pull request, the CLA Assistant bot will give you a link to sign this project's
Contributor License Agreement, required to add your code to the repository.
- There's no need to mess around with `CHANGELOG.md` or package manifests — we have a bot handle
that for us. A maintainer will add the necessary notes before merging.
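As a generic illustration of the expected level of documentation, here is a sketch of a documented helper (the function below is hypothetical and not part of the framework):

```python
def clamp(value: float, lo: float, hi: float) -> float:
    """Clamp ``value`` to the inclusive range [``lo``, ``hi``].

    Args:
        value: the number to constrain
        lo: the lower bound of the range
        hi: the upper bound of the range

    Returns:
        ``value`` limited to the given range.
    """
    # hypothetical helper used only to illustrate docstring style
    return min(max(value, lo), hi)
```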
## Assist others in the community
If you can't contribute code, you can still help us greatly by helping out community members who
may have questions about the framework and how to use it. Join the `#agents` channel on
[our Slack](https://livekit.io/join-slack).
## LICENSE

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## NOTICE

Copyright 2023 LiveKit, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## README.md

<!--BEGIN_BANNER_IMAGE-->
<picture>
<source media="(prefers-color-scheme: dark)" srcset="/.github/banner_dark.png">
<source media="(prefers-color-scheme: light)" srcset="/.github/banner_light.png">
<img style="width:100%;" alt="The LiveKit icon, the name of the repository and some sample code in the background." src="https://raw.githubusercontent.com/livekit/agents/main/.github/banner_light.png">
</picture>
<!--END_BANNER_IMAGE-->
<br /><br />
Looking for the JS/TS library? Check out [AgentsJS](https://github.com/livekit/agents-js)
## ✨ NEW ✨
### In-house phrase endpointing model
We’ve trained a new, open-weights phrase endpointing model that significantly improves end-of-turn detection and conversational flow between voice agents and users by reducing agent interruptions. Optimized to run on CPUs, it’s available via the [livekit-plugins-turn-detector](https://pypi.org/project/livekit-plugins-turn-detector/) package.
## What is Agents?
<!--BEGIN_DESCRIPTION-->
The **Agents framework** enables you to build AI-driven server programs that can see, hear, and speak in realtime. It offers a fully open-source platform for creating realtime, agentic applications.
<!--END_DESCRIPTION-->
## Features
- **Flexible integrations**: A comprehensive ecosystem to mix and match the right models for each use case.
- **AI voice agents**: `VoicePipelineAgent` and `MultimodalAgent` help orchestrate the conversation flow using LLMs and other AI models.
- **Integrated job scheduling**: Built-in task scheduling and distribution with [dispatch APIs](https://docs.livekit.io/agents/build/dispatch/) to connect end users to agents.
- **Realtime media transport**: Stream audio, video, and data over WebRTC and SIP with client SDKs for most platforms.
- **Telephony integration**: Works seamlessly with LiveKit's [telephony stack](https://docs.livekit.io/sip/), allowing your agent to make calls to or receive calls from phones.
- **Exchange data with clients**: Use [RPCs](https://docs.livekit.io/home/client/data/rpc/) and other [Data APIs](https://docs.livekit.io/home/client/data/) to seamlessly exchange data with clients.
- **Open-source**: Fully open-source, allowing you to run the entire stack on your own servers, including [LiveKit server](https://github.com/livekit/livekit), one of the most widely used WebRTC media servers.
## Installation
To install the core Agents library:
```bash
pip install livekit-agents
```

The framework includes a variety of plugins that make it easy to process streaming input or generate output. For example, there are plugins for converting text-to-speech or running inference with popular LLMs. Here’s how you can install a plugin:

```bash
pip install livekit-plugins-openai
```
We’ve partnered with OpenAI on a new `MultimodalAgent` API in the Agents framework. This class completely wraps OpenAI’s Realtime API, abstracts away the raw wire protocol, and provides an ultra-low latency WebRTC transport between GPT-4o and your users’ devices. This same stack powers Advanced Voice in the ChatGPT app.
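As a rough sketch (mirroring the conversation persistor example later in this repo; the voice and instructions are placeholders), a minimal agent built on this API might look like:

```python
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, multimodal
from livekit.plugins import openai


async def entrypoint(ctx: JobContext):
    # connect to the room and wait for a user before starting the agent
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()

    agent = multimodal.MultimodalAgent(
        model=openai.realtime.RealtimeModel(
            voice="alloy",
            instructions="You are a helpful assistant.",
        ),
    )
    agent.start(ctx.room, participant)


if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
```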
### LLM

Provider | Package | Usage |
---|---|---|
OpenAI | livekit-plugins-openai | openai.LLM() |
Azure OpenAI | livekit-plugins-openai | openai.LLM.with_azure() |
Anthropic | livekit-plugins-anthropic | anthropic.LLM() |
Google (Gemini) | livekit-plugins-google | google.LLM() |
AWS Bedrock | livekit-plugins-aws | aws.LLM() |
Cerebras | livekit-plugins-openai | openai.LLM.with_cerebras() |
DeepSeek | livekit-plugins-openai | openai.LLM.with_deepseek() |
Groq | livekit-plugins-openai | openai.LLM.with_groq() |
Ollama | livekit-plugins-openai | openai.LLM.with_ollama() |
Perplexity | livekit-plugins-openai | openai.LLM.with_perplexity() |
Together.ai | livekit-plugins-openai | openai.LLM.with_together() |
X.ai (Grok) | livekit-plugins-openai | openai.LLM.with_x_ai() |
### STT

Provider | Package | Streaming | Usage |
---|---|---|---|
Azure | livekit-plugins-azure | ✅ | azure.STT() |
Deepgram | livekit-plugins-deepgram | ✅ | deepgram.STT() |
OpenAI (Whisper) | livekit-plugins-openai | | openai.STT() |
Google | livekit-plugins-google | ✅ | google.STT() |
AssemblyAI | livekit-plugins-assemblyai | ✅ | assemblyai.STT() |
Groq (Whisper) | livekit-plugins-openai | | openai.STT.with_groq() |
FAL (Wizper) | livekit-plugins-fal | | fal.STT() |
Speechmatics | livekit-plugins-speechmatics | ✅ | speechmatics.STT() |
AWS Transcribe | livekit-plugins-aws | ✅ | aws.STT() |
### TTS

Provider | Package | Streaming | Voice Cloning | Usage |
---|---|---|---|---|
Cartesia | livekit-plugins-cartesia | ✅ | ✅ | cartesia.TTS() |
ElevenLabs | livekit-plugins-elevenlabs | ✅ | ✅ | elevenlabs.TTS() |
OpenAI | livekit-plugins-openai | | | openai.TTS() |
Azure OpenAI | livekit-plugins-openai | | | openai.TTS.with_azure() |
Google | livekit-plugins-google | ✅ | ✅ | google.TTS() |
Deepgram | livekit-plugins-deepgram | ✅ | | deepgram.TTS() |
Play.ai | livekit-plugins-playai | ✅ | ✅ | playai.TTS() |
Rime | livekit-plugins-rime | ✅ | | rime.TTS() |
Neuphonic | livekit-plugins-neuphonic | ✅ | ✅ | neuphonic.TTS() |
AWS Polly | livekit-plugins-aws | ✅ | | aws.TTS() |
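As a hedged sketch of how these plugins compose (assuming `VoicePipelineAgent` accepts the same `vad`/`stt`/`llm`/`tts` arguments as the `VoiceAssistant` constructor shown in the 0.8 migration guide; the provider choices below are purely illustrative):

```python
from livekit.agents import AutoSubscribe, JobContext
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import cartesia, deepgram, openai, silero


async def entrypoint(ctx: JobContext):
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()

    # mix and match providers from the tables above
    agent = VoicePipelineAgent(
        vad=silero.VAD.load(),
        stt=deepgram.STT(),
        llm=openai.LLM(),
        tts=cartesia.TTS(),
    )
    agent.start(ctx.room, participant)
```

Run it with `cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))` as shown in the 0.8 migration guide.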
### Other plugins

Plugin | Description |
---|---|
livekit-plugins-rag | Annoy based simple RAG |
livekit-plugins-llama-index | RAG with LlamaIndex |
livekit-plugins-nltk | Utilities for working with text |
livekit-plugins-silero | Voice activity detection |
livekit-plugins-turn-detector | Conversational turn detection model |
Documentation on the framework and how to use it can be found [here](https://docs.livekit.io/agents/).
## Examples

Description | Demo Link | Code Link |
---|---|---|
A basic voice agent using a pipeline of STT, LLM, and TTS | demo | code |
Voice agent using the new OpenAI Realtime API | demo | code |
Super fast voice agent using Cerebras hosted Llama 3.1 | demo | code |
Voice agent using Cartesia’s Sonic model | demo | code |
Agent that looks up the current weather via function call | N/A | code |
Voice Agent using Gemini 2.0 Flash | N/A | code |
Voice agent with custom turn-detection model | N/A | code |
Voice agent that performs a RAG-based lookup | N/A | code |
Simple agent that echoes back the last utterance | N/A | code |
Video agent that publishes a stream of RGB frames | N/A | code |
Transcription agent that generates text captions from a user’s speech | N/A | code |
A chat agent you can text who will respond back with generated speech | N/A | code |
Localhost multi-agent conference call | N/A | code |
Moderation agent that uses Hive to detect spam/abusive video | N/A | code |
## Contributing

The Agents framework is under active development in a rapidly evolving field. We welcome and appreciate contributions of any kind, be it feedback, bugfixes, features, new plugins and tools, or better documentation. You can file issues under this repo, open a PR, or chat with us in LiveKit’s Slack community.
## examples/.env.example
```example
LIVEKIT_API_SECRET="<your livekit api secret>"
LIVEKIT_API_KEY="<your livekit api key>"
LIVEKIT_URL="<your livekit ws url>"
```

## examples/Dockerfile-example
# This is an example Dockerfile that builds a minimal container for running LK Agents
# syntax=docker/dockerfile:1
ARG PYTHON_VERSION=3.11.6
FROM python:${PYTHON_VERSION}-slim
# Prevents Python from writing pyc files.
ENV PYTHONDONTWRITEBYTECODE=1
# Keeps Python from buffering stdout and stderr to avoid situations where
# the application crashes without emitting any logs due to buffering.
ENV PYTHONUNBUFFERED=1
# Create a non-privileged user that the app will run under.
# See https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#user
ARG UID=10001
RUN adduser \
--disabled-password \
--gecos "" \
--home "/home/appuser" \
--shell "/sbin/nologin" \
--uid "${UID}" \
appuser
# Install gcc, g++ and other build dependencies.
RUN apt-get update && \
apt-get install -y \
gcc \
g++ \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
USER appuser
RUN mkdir -p /home/appuser/.cache
RUN chown -R appuser /home/appuser/.cache
WORKDIR /home/appuser
COPY requirements.txt .
RUN python -m pip install --user --no-cache-dir -r requirements.txt
COPY . .
# ensure that any dependent models are downloaded at build-time
RUN python myagent.py download-files
# Run the application.
ENTRYPOINT ["python", "myagent.py"]
CMD ["start"]
## examples/browser/browser_track.py

import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import JobContext, WorkerOptions, cli
from livekit.plugins import browser
WIDTH = 1920
HEIGHT = 1080
load_dotenv()
async def entrypoint(job: JobContext):
await job.connect()
ctx = browser.BrowserContext(dev_mode=True)
await ctx.initialize()
page = await ctx.new_page(url="www.livekit.io")
source = rtc.VideoSource(WIDTH, HEIGHT)
track = rtc.LocalVideoTrack.create_video_track("single-color", source)
options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_CAMERA)
publication = await job.room.local_participant.publish_track(track, options)
logging.info("published track", extra={"track_sid": publication.sid})
@page.on("paint")
def on_paint(paint_data):
source.capture_frame(paint_data.frame)
async def _test_cycle():
urls = [
"https://www.livekit.io",
"https://www.google.com",
]
i = 0
async with ctx.playwright() as browser:
while True:
i += 1
await asyncio.sleep(5)
defaultContext = browser.contexts[0]
defaultPage = defaultContext.pages[0]
try:
await defaultPage.goto(urls[i % len(urls)])
except Exception:
logging.exception(f"failed to navigate to {urls[i % len(urls)]}")
await _test_cycle()
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
## examples/browser/standalone_app.py

from livekit.plugins import browser
ctx = browser.BrowserContext(dev_mode=True)
## examples/conversation_persistor.py

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Union
import aiofiles
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
cli,
multimodal,
utils,
)
from livekit.agents.llm import ChatMessage
from livekit.agents.multimodal.multimodal_agent import EventTypes
from livekit.plugins import openai
@dataclass
class EventLog:
eventname: str | None
"""name of recorded event"""
    time: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3])
"""time the event is recorded"""
@dataclass
class TranscriptionLog:
role: str | None
"""role of the speaker"""
transcription: str | None
"""transcription of speech"""
    time: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3])
"""time the event is recorded"""
class ConversationPersistor(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
model: multimodal.MultimodalAgent | None,
log: str | None,
transcriptions_only: bool = False,
):
"""
Initializes a ConversationPersistor instance which records the events and transcriptions of a MultimodalAgent.
Args:
model (multimodal.MultimodalAgent): an instance of a MultiModalAgent
log (str): name of the external file to record events in
transcriptions_only (bool): a boolean variable to determine if only transcriptions will be recorded, False by default
user_transcriptions (arr): list of user transcriptions
agent_transcriptions (arr): list of agent transcriptions
events (arr): list of all events
log_q (asyncio.Queue): a queue of EventLog and TranscriptionLog
"""
super().__init__()
self._model = model
self._log = log
self._transcriptions_only = transcriptions_only
self._user_transcriptions = []
self._agent_transcriptions = []
self._events = []
self._log_q = asyncio.Queue[Union[EventLog, TranscriptionLog, None]]()
@property
def log(self) -> str | None:
return self._log
@property
def model(self) -> multimodal.MultimodalAgent | None:
return self._model
    @property
    def user_transcriptions(self) -> list:
        return self._user_transcriptions

    @property
    def agent_transcriptions(self) -> list:
        return self._agent_transcriptions

    @property
    def events(self) -> list:
        return self._events
@log.setter
def log(self, newlog: str | None) -> None:
self._log = newlog
async def _main_atask(self) -> None:
# Writes to file asynchronously
while True:
log = await self._log_q.get()
if log is None:
break
async with aiofiles.open(self._log, "a") as file:
if type(log) is EventLog and not self._transcriptions_only:
self._events.append(log)
await file.write("\n" + log.time + " " + log.eventname)
if type(log) is TranscriptionLog:
if log.role == "user":
self._user_transcriptions.append(log)
else:
self._agent_transcriptions.append(log)
await file.write(
"\n" + log.time + " " + log.role + " " + log.transcription
)
async def aclose(self) -> None:
# Exits
self._log_q.put_nowait(None)
await self._main_task
def start(self) -> None:
# Listens for emitted MultimodalAgent events
self._main_task = asyncio.create_task(self._main_atask())
@self._model.on("user_started_speaking")
def _user_started_speaking():
event = EventLog(eventname="user_started_speaking")
self._log_q.put_nowait(event)
@self._model.on("user_stopped_speaking")
def _user_stopped_speaking():
event = EventLog(eventname="user_stopped_speaking")
self._log_q.put_nowait(event)
@self._model.on("agent_started_speaking")
def _agent_started_speaking():
event = EventLog(eventname="agent_started_speaking")
self._log_q.put_nowait(event)
@self._model.on("agent_stopped_speaking")
def _agent_stopped_speaking():
transcription = TranscriptionLog(
role="agent",
transcription=(self._model._playing_handle._tr_fwd.played_text)[1:],
)
self._log_q.put_nowait(transcription)
event = EventLog(eventname="agent_stopped_speaking")
self._log_q.put_nowait(event)
@self._model.on("user_speech_committed")
def _user_speech_committed(user_msg: ChatMessage):
transcription = TranscriptionLog(
role="user", transcription=user_msg.content
)
self._log_q.put_nowait(transcription)
event = EventLog(eventname="user_speech_committed")
self._log_q.put_nowait(event)
@self._model.on("agent_speech_committed")
def _agent_speech_committed():
event = EventLog(eventname="agent_speech_committed")
self._log_q.put_nowait(event)
@self._model.on("agent_speech_interrupted")
def _agent_speech_interrupted():
event = EventLog(eventname="agent_speech_interrupted")
self._log_q.put_nowait(event)
@self._model.on("function_calls_collected")
def _function_calls_collected():
event = EventLog(eventname="function_calls_collected")
self._log_q.put_nowait(event)
@self._model.on("function_calls_finished")
def _function_calls_finished():
event = EventLog(eventname="function_calls_finished")
self._log_q.put_nowait(event)
load_dotenv()
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
agent = multimodal.MultimodalAgent(
model=openai.realtime.RealtimeModel(
voice="alloy",
temperature=0.8,
instructions="You are a helpful assistant.",
turn_detection=openai.realtime.ServerVadOptions(
threshold=0.6, prefix_padding_ms=200, silence_duration_ms=500
),
),
)
cp = ConversationPersistor(model=agent, log="log.txt")
cp.start()
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()
agent.start(ctx.room, participant)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
## examples/echo-agent.py

import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
ATTRIBUTE_AGENT_STATE,
AgentState,
AutoSubscribe,
JobContext,
WorkerOptions,
cli,
)
from livekit.agents.vad import VADEventType
from livekit.plugins import silero
load_dotenv()
logger = logging.getLogger("echo-agent")
# An example agent that echoes each utterance from the user back to them.
# The example uses a queue to buffer incoming streams, and uses VAD to detect
# when the user is done speaking.
async def entrypoint(ctx: JobContext):
logger.info(f"connecting to room {ctx.room.name}")
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
# wait for the first participant to connect
participant: rtc.Participant = await ctx.wait_for_participant()
stream = rtc.AudioStream.from_participant(
participant=participant,
track_source=rtc.TrackSource.SOURCE_MICROPHONE,
)
vad = silero.VAD.load(
min_speech_duration=0.2,
min_silence_duration=0.6,
)
vad_stream = vad.stream()
source = rtc.AudioSource(sample_rate=48000, num_channels=1)
track = rtc.LocalAudioTrack.create_audio_track("echo", source)
await ctx.room.local_participant.publish_track(
track,
rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE),
)
# speech queue holds AudioFrames
queue = asyncio.Queue(maxsize=1000) # 10 seconds of audio (1000 frames * 10ms)
is_speaking = False
is_echoing = False
async def _set_state(state: AgentState):
await ctx.room.local_participant.set_attributes({ATTRIBUTE_AGENT_STATE: state})
await _set_state("listening")
async def _process_input():
async for audio_event in stream:
if is_echoing: # Skip processing while echoing
continue
vad_stream.push_frame(audio_event.frame)
try:
queue.put_nowait(audio_event.frame)
except asyncio.QueueFull:
# Remove oldest frame when queue is full
queue.get_nowait()
queue.put_nowait(audio_event.frame)
async def _process_vad():
nonlocal is_speaking, is_echoing
async for vad_event in vad_stream:
if is_echoing: # Skip VAD processing while echoing
continue
if vad_event.type == VADEventType.START_OF_SPEECH:
is_speaking = True
frames_to_keep = 100
frames = []
while not queue.empty():
frames.append(queue.get_nowait())
for frame in frames[-frames_to_keep:]:
queue.put_nowait(frame)
elif vad_event.type == VADEventType.END_OF_SPEECH:
is_speaking = False
is_echoing = True
logger.info("end of speech, playing back")
await _set_state("speaking")
try:
while not queue.empty():
frame = queue.get_nowait()
await source.capture_frame(frame)
except asyncio.QueueEmpty:
pass
finally:
is_echoing = False # Reset echoing flag after playback
await _set_state("listening")
await asyncio.gather(
_process_input(),
_process_vad(),
)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
),
)
## examples/hive-moderation-agent/README.md

# LiveKit realtime moderation agent using Hive
This is an agent that performs visual moderation of every participant's video in a room. It does this moderation using the Visual Content Moderation model from [Hive](https://thehive.ai) [[docs](https://docs.thehive.ai/docs/visual-content-moderation#visual-content-moderation)].
## Prerequisites
Before running this agent, you'll need:
1. A LiveKit Cloud project (or a self-hosted LiveKit server).
2. An API key from Hive to access the above-mentioned model.
## Configuration
Currently, this agent is configured entirely from the `agent.py` source code and the environment.
### Environment Variables
| configuration | description | example value |
|---------------|-------------|---------------|
| `LIVEKIT_URL` | Your LiveKit URL | `wss://test-abc123de.livekit.cloud` |
| `LIVEKIT_API_KEY` | Your LiveKit API key | |
| `LIVEKIT_API_SECRET` | Your LiveKit API secret | |
| `HIVE_API_KEY` | The API key from Hive to access the `Visual Content Moderation` model | `abc1deFgHIjK23KLMNOp45QrsTuv6wx8` |
### Code
| configuration | description | example value |
|---------------|-------------|---------------|
| `MOD_FRAME_INTERVAL` | Minimum number of seconds to wait between frames | 5.0 |
| `HIVE_HEADERS` | The headers to send with every request to the Hive API | `{}` |
| `CONFIDENCE_THRESHOLD` | The minimum score Hive's moderation class must meet before it is considered a problem | 0.9 |
## Running
Run this code like you would any other [LiveKit agent](https://docs.livekit.io/agents/build/anatomy/#starting-the-worker):
```bash
python3 agent.py start
```
Once running, the agent will join all new LiveKit rooms by default and begin moderation.
"""
LiveKit agent that connects to a room and performs visual moderation on the video
of all participants using the Visual Content Moderation model from Hive
(https://docs.thehive.ai/docs/visual-content-moderation#visual-content-moderation).
The agent periodically sends a frame from the participant's video to Hive's API
for a moderation check. If the results of that check show a confidence score
of 0.9 or higher for any of the positive classes, it logs the result and adds a
message to the room's chat. This can easily be extended to take additional
actions like removing a participant or ending a livestream, etc.
"""
import asyncio
import logging
import os
import time
from io import BytesIO
import aiohttp
from dotenv import load_dotenv
from hive_data_classes import HiveResponse, from_dict
from livekit import agents, rtc
from PIL import Image
load_dotenv()
MOD_FRAME_INTERVAL = 5.0 # check 1 frame every 5 seconds
"""
How often to check a frame (in seconds)
"""
HIVE_HEADERS = {
"Authorization": f"Token {os.getenv('HIVE_API_KEY')}",
"accept": "application/json",
}
"""
The default headers included with every request to thehive.ai
"""
CONFIDENCE_THRESHOLD = 0.9
"""
The threshold level for scores returned by thehive.ai. See details in this doc:
https://docs.thehive.ai/docs/visual-content-moderation#choosing-thresholds-for-visual-moderation
"""
logger = logging.getLogger("hive-moderation-agent")
logger.setLevel(logging.INFO)
async def request_fnc(req: agents.JobRequest):
"""
The request handler for the agent. We use this to set the name of the
agent that is displayed to users
"""
# accept the job request and name the agent participant so users know what this is
await req.accept(
name="Moderator",
identity="hive-moderator",
)
async def entrypoint(ctx: agents.JobContext):
"""
The entrypoint of the agent. This is called every time the moderator
agent joins a room.
"""
# connect to the room and automatically subscribe to all participants' video
await ctx.connect(auto_subscribe=agents.AutoSubscribe.VIDEO_ONLY)
chat = rtc.ChatManager(ctx.room)
@ctx.room.on("track_subscribed")
def on_track_subscribed(
track: rtc.Track,
_publication: rtc.TrackPublication,
participant: rtc.RemoteParticipant,
):
"""
Event handler for video tracks. We automatically subscribe to all video
tracks when a participant joins the room. This event is triggered
once we have completed subscription to that video track.
        This creates a background task to process frames from each track.
"""
asyncio.create_task(process_track(participant, track))
async def process_track(participant: rtc.RemoteParticipant, track: rtc.VideoTrack):
"""
This function is running in a background task once for each video track
(i.e., once for each participant). It handles processing a frame
    from the video once every MOD_FRAME_INTERVAL seconds.
"""
video_stream = rtc.VideoStream(track)
last_processed_time = 0
async for frame in video_stream:
current_time = time.time()
if (current_time - last_processed_time) >= MOD_FRAME_INTERVAL:
last_processed_time = current_time
await check_frame(participant, frame)
async def check_frame(participant: rtc.RemoteParticipant, frame: rtc.VideoFrame):
"""
Uses thehive.ai API to check the frame for any classifications we care about
"""
# get the current frame and convert to png format
argb_frame = frame.frame.convert(rtc.VideoBufferType.RGBA)
image = Image.frombytes(
"RGBA", (argb_frame.width, argb_frame.height), argb_frame.data
)
buffer = BytesIO()
image.save(buffer, format="PNG")
buffer.seek(0) # reset buffer position to beginning after writing
data = aiohttp.FormData()
data.add_field("image", buffer, filename="image.png", content_type="image/png")
# submit the image to Hive
logger.info("submitting image to hive")
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.thehive.ai/api/v2/task/sync",
headers=HIVE_HEADERS,
data=data,
) as response:
response.raise_for_status()
response_dict = await response.json()
hive_response: HiveResponse = from_dict(HiveResponse, response_dict)
if (
hive_response.code == 200
and len(hive_response.status) > 0
and len(hive_response.status[0].response.output) > 0
):
results = hive_response.status[0].response.output[0].classes
# filter to anything with a confidence score > threshold
for mod_class in results:
if mod_class.class_[0:4] == "yes_":
# TODO: should also include "general_nsfw" class
if mod_class.score >= CONFIDENCE_THRESHOLD:
class_name = mod_class.class_[4:]
message = (
'FOUND %s for participant "%s" (confidence score: %0.3f)'
% (
class_name,
participant.identity,
mod_class.score,
)
)
logger.info(message)
await chat.send_message(message)
await ctx.wait_for_participant()
await chat.send_message(
"I'm a moderation agent,"
"I will detect and notify you of all inappropriate material in your video stream"
)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
agents.cli.run_app(agents.WorkerOptions(entrypoint, request_fnc=request_fnc))
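The module docstring above notes that a positive moderation result could also trigger actions such as removing the participant. As a rough, hedged sketch (not part of the example): assuming the `livekit-api` package is installed and the `LIVEKIT_URL`, `LIVEKIT_API_KEY`, and `LIVEKIT_API_SECRET` environment variables are set, a removal helper might look like this:
```python
# Hedged sketch: remove a flagged participant using the LiveKit server API.
# Assumes livekit-api is installed and LIVEKIT_* environment variables are set.
from livekit import api


async def remove_participant(room_name: str, identity: str) -> None:
    lkapi = api.LiveKitAPI()  # reads LIVEKIT_URL / API key / secret from the environment
    try:
        await lkapi.room.remove_participant(
            api.RoomParticipantIdentity(room=room_name, identity=identity)
        )
    finally:
        await lkapi.aclose()
```
Such a helper could be called from `check_frame` in place of (or in addition to) the chat message.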
from dataclasses import dataclass, is_dataclass
from typing import List, get_type_hints
def from_dict(cls, data):
if is_dataclass(cls) and isinstance(data, dict):
# Get type hints for all fields in the dataclass
field_types = get_type_hints(cls)
# Special handling for reserved words like 'class'
reserved_word_mappings = {"class": "class_"} # Map 'class' to 'class_'
processed_data = {}
for key, value in data.items():
# Check if the key is a reserved word and map it accordingly
field_name = reserved_word_mappings.get(key, key)
# Only include keys that have corresponding fields in the dataclass
if field_name in field_types:
field_type = field_types[field_name]
# Determine if the field_type is itself a dataclass
if is_dataclass(field_type):
processed_value = from_dict(field_type, value)
elif hasattr(field_type, "__origin__") and issubclass(
field_type.__origin__, List
):
# Handle List fields, assuming all elements are of the same type
item_type = field_type.__args__[0]
processed_value = [from_dict(item_type, item) for item in value]
else:
processed_value = value
processed_data[field_name] = processed_value
return cls(**processed_data)
elif isinstance(data, list):
# This assumes that the function was called with a list type as `cls`,
# which might not work as expected without context on the list's element type.
# A better approach might be needed for handling lists of dataclasses.
return [
from_dict(cls.__args__[0], item) if hasattr(cls, "__args__") else item
for item in data
]
else:
return data
@dataclass
class Status:
code: str
message: str
@dataclass
class ModInput:
id: str
charge: float
config_tag: str
config_version: float
created_on: str
model: str
model_type: str
model_version: float
project_id: int
user_id: int
@dataclass
class ModClass:
class_: str
score: float
@dataclass
class ModOutput:
time: int
classes: List[ModClass]
@dataclass
class Response:
input: ModInput
output: List[ModOutput]
@dataclass
class ModResponse:
status: Status
response: Response
@dataclass
class HiveResponse:
id: str
code: int
project_id: int
user_id: int
created_on: str
status: List[ModResponse]
from_cache: bool
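To illustrate how `from_dict` maps Hive's reserved `class` key onto the `class_` field, here is a small hedged usage sketch with a made-up payload (real Hive responses are nested much deeper):
```python
# Hypothetical payload for illustration only.
sample = {"class": "yes_gun_in_hand", "score": 0.97}

mod_class = from_dict(ModClass, sample)
print(mod_class.class_, mod_class.score)  # -> yes_gun_in_hand 0.97
```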
livekit
livekit-agents<1.0.0
python-dotenv
Pillow
aiohttp
import logging
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, WorkerType, cli
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
logger.info("starting entrypoint")
await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL)
logger.info("connected to the room")
# add your agent logic here!
if __name__ == "__main__":
# WorkerType.ROOM is the default worker type which will create an agent for every room.
# You can also use WorkerType.PUBLISHER to create a single agent for all participants that publish a track.
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
from __future__ import annotations
import logging
from typing import Annotated
import aiohttp
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
WorkerType,
cli,
llm,
multimodal,
)
from livekit.plugins import google
load_dotenv()
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
logger.info("starting entrypoint")
fnc_ctx = llm.FunctionContext()
@fnc_ctx.ai_callable()
async def get_weather(
location: Annotated[
str, llm.TypeInfo(description="The location to get the weather for")
],
):
"""Called when the user asks about the weather. This function will return the weather for the given location."""
logger.info(f"getting weather for {location}")
url = f"https://wttr.in/{location}?format=%C+%t"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
weather_data = await response.text()
# response from the function call is returned to the LLM
return f"The weather in {location} is {weather_data}."
else:
raise Exception(
f"Failed to get weather data, status code: {response.status}"
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()
# create a chat context with chat history, these will be synchronized with the server
# upon calling `agent.generate_reply()`
chat_ctx = llm.ChatContext()
# chat_ctx.append(text="I'm planning a trip to Paris next month.", role="user")
# chat_ctx.append(
# text="How exciting! Paris is a beautiful city. I'd be happy to suggest some must-visit places and help you plan your trip.",
# role="assistant",
# )
agent = multimodal.MultimodalAgent(
model=google.beta.realtime.RealtimeModel(
voice="Puck",
temperature=0.8,
instructions="You are a helpful assistant, greet the user and help them with their trip planning",
),
fnc_ctx=fnc_ctx,
chat_ctx=chat_ctx,
)
agent.start(ctx.room, participant)
agent.generate_reply()
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
from __future__ import annotations
import asyncio
import logging
from typing import Annotated
import aiohttp
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
WorkerType,
cli,
llm,
multimodal,
)
from livekit.plugins import openai
load_dotenv()
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
logger.info("starting entrypoint")
fnc_ctx = llm.FunctionContext()
@fnc_ctx.ai_callable()
async def get_weather(
location: Annotated[
str, llm.TypeInfo(description="The location to get the weather for")
],
):
"""Called when the user asks about the weather. This function will return the weather for the given location."""
logger.info(f"getting weather for {location}")
url = f"https://wttr.in/{location}?format=%C+%t"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
weather_data = await response.text()
# response from the function call is returned to the LLM
return f"The weather in {location} is {weather_data}."
else:
raise Exception(
f"Failed to get weather data, status code: {response.status}"
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()
# to use Microsoft Azure, uncomment the following lines
# agent = multimodal.MultimodalAgent(
# model=openai.realtime.RealtimeModel.with_azure(
# azure_deployment="<model-deployment>",
# azure_endpoint="wss://<endpoint>.openai.azure.com/", # or AZURE_OPENAI_ENDPOINT
# api_key="<api-key>", # or AZURE_OPENAI_API_KEY
# api_version="2024-10-01-preview", # or OPENAI_API_VERSION
# voice="alloy",
# temperature=0.8,
# instructions="You are a helpful assistant",
# turn_detection=openai.realtime.ServerVadOptions(
# threshold=0.6, prefix_padding_ms=200, silence_duration_ms=500
# ),
# ),
# fnc_ctx=fnc_ctx,
# )
# create a chat context with chat history, these will be synchronized with the server
# upon session establishment
chat_ctx = llm.ChatContext()
# chat_ctx.append(text="I'm planning a trip to Paris next month.", role="user")
# chat_ctx.append(
# text="How exciting! Paris is a beautiful city. I'd be happy to suggest some must-visit places and help you plan your trip.",
# role="assistant",
# )
agent = multimodal.MultimodalAgent(
model=openai.realtime.RealtimeModel(
voice="alloy",
temperature=0.8,
instructions=(
"You are a helpful assistant, greet the user and help them with their trip planning. "
"When performing function calls, let user know that you are checking the weather."
),
turn_detection=openai.realtime.ServerVadOptions(
threshold=0.6, prefix_padding_ms=200, silence_duration_ms=500
),
),
fnc_ctx=fnc_ctx,
chat_ctx=chat_ctx,
)
agent.start(ctx.room, participant)
agent.generate_reply()
@agent.on("agent_speech_committed")
@agent.on("agent_speech_interrupted")
def _on_agent_speech_created(msg: llm.ChatMessage):
# example of truncating the chat context
max_ctx_len = 10
chat_ctx = agent.chat_ctx_copy()
if len(chat_ctx.messages) > max_ctx_len:
chat_ctx.messages = chat_ctx.messages[-max_ctx_len:]
# NOTE: The `set_chat_ctx` function will attempt to synchronize changes made
# to the local chat context with the server instead of completely replacing it,
# provided that the message IDs are consistent.
asyncio.create_task(agent.set_chat_ctx(chat_ctx))
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
# Participant Entrypoint Example
This example shows how to run tasks when participants join the room. A common use case is to fetch some external data based on the participant's attributes.
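As a hedged sketch of that idea (the URL and response fields here are hypothetical), a per-participant entrypoint that looks up a profile keyed on the participant's identity could look like:
```python
import aiohttp
from livekit import rtc
from livekit.agents import JobContext


async def fetch_profile(ctx: JobContext, p: rtc.RemoteParticipant):
    # Hypothetical backend endpoint; replace with your own service.
    url = f"https://example.com/api/profiles/{p.identity}"
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            if resp.status == 200:
                profile = await resp.json()
                print(f"profile for {p.identity}: {profile}")


# register it before ctx.connect(), exactly like the tasks in the full example below:
# ctx.add_participant_entrypoint(entrypoint_fnc=fetch_profile)
```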
## Run
### Set up and activate a virtual env:
`python -m venv venv`
`source venv/bin/activate`
### Set environment variables:
```bash
export LIVEKIT_URL=<your LiveKit server URL>
export LIVEKIT_API_KEY=<your API Key>
export LIVEKIT_API_SECRET=<your API Secret>
pip install -r requirements.txt
python participant_entrypoint.py dev
```
We’ve built Agents Playground so you don’t have to build your own frontend while you iterate on your agent.
## examples/participant-entrypoint/participant_entrypoint.py
```py
import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli
load_dotenv()
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
logger.info("starting entrypoint")
async def participant_task_1(ctx: JobContext, p: rtc.RemoteParticipant):
# You can filter out participants you are not interested in
# if p.identity != "some_identity_of_interest":
# return
logger.info(f"participant task 1 starting for {p.identity}")
# Do something with p.attributes, p.identity, p.metadata, etc.
# my_stuff = await fetch_stuff_from_my_db(p)
# Do something
await asyncio.sleep(60)
logger.info(f"participant task done for {p.identity}")
async def participant_task_2(ctx: JobContext, p: rtc.RemoteParticipant):
# multiple tasks can be run concurrently for each participant
logger.info(f"participant task 2 starting for {p.identity}")
await asyncio.sleep(10)
# Add participant entrypoints before calling ctx.connect
ctx.add_participant_entrypoint(entrypoint_fnc=participant_task_1)
ctx.add_participant_entrypoint(entrypoint_fnc=participant_task_2)
await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL)
logger.info("connected to the room")
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
```
livekit-agents>=0.12.20
python-dotenv~=1.0
# Simple-color
This small example publishes a solid color video frame.
import asyncio
import logging
import random
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import JobContext, WorkerOptions, cli
# Load environment variables
load_dotenv()
WIDTH = 640
HEIGHT = 480
async def entrypoint(job: JobContext):
await job.connect()
room = job.room
source = rtc.VideoSource(WIDTH, HEIGHT)
track = rtc.LocalVideoTrack.create_video_track("single-color", source)
options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_CAMERA)
publication = await room.local_participant.publish_track(track, options)
logging.info("published track", extra={"track_sid": publication.sid})
async def _draw_color():
argb_frame = bytearray(WIDTH * HEIGHT * 4)
while True:
await asyncio.sleep(0.1) # 100ms
# Create a new random color
r, g, b = [random.randint(0, 255) for _ in range(3)]
color = bytes([r, g, b, 255])
# Fill the frame with the new random color
argb_frame[:] = color * WIDTH * HEIGHT
frame = rtc.VideoFrame(WIDTH, HEIGHT, rtc.VideoBufferType.RGBA, argb_frame)
source.capture_frame(frame)
await _draw_color()
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
livekit-agents>=0.12.20
python-dotenv~=1.0
# Speech-to-text
This example shows realtime transcription from voice to text.
It uses OpenAI's Whisper STT API, but supports other STT plugins by changing this line:
```python
stt = openai.STT()
```
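For instance, a hedged sketch of swapping in Deepgram instead (assuming `livekit-plugins-deepgram` is installed and `DEEPGRAM_API_KEY` is set):
```python
from livekit.plugins import deepgram

stt = deepgram.STT()
```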
To render the transcriptions into your client application, refer to the full documentation.
```bash
export LIVEKIT_URL=wss://yourhost.livekit.cloud
export LIVEKIT_API_KEY=livekit-api-key
export LIVEKIT_API_SECRET=your-api-secret
export OPENAI_API_KEY=your-api-key
python3 transcriber.py start
```
Then connect to any room. For an example frontend, you can use LiveKit’s Agents Playground.
## examples/speech-to-text/requirements.txt
```txt
livekit-agents>=0.12.20
livekit-plugins-deepgram>=0.7.3
python-dotenv~=1.0
```
import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
cli,
stt,
transcription,
)
from livekit.plugins import openai, silero
load_dotenv()
logger = logging.getLogger("transcriber")
async def _forward_transcription(
stt_stream: stt.SpeechStream, stt_forwarder: transcription.STTSegmentsForwarder
):
"""Forward the transcription to the client and log the transcript in the console"""
async for ev in stt_stream:
if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
# you may not want to log interim transcripts, they are not final and may be incorrect
pass
elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
print(" -> ", ev.alternatives[0].text)
elif ev.type == stt.SpeechEventType.RECOGNITION_USAGE:
logger.debug(f"metrics: {ev.recognition_usage}")
stt_forwarder.update(ev)
async def entrypoint(ctx: JobContext):
logger.info(f"starting transcriber (speech to text) example, room: {ctx.room.name}")
# this example uses OpenAI Whisper, but you can use assemblyai, deepgram, google, azure, etc.
stt_impl = openai.STT()
if not stt_impl.capabilities.streaming:
# wrap with a stream adapter to use streaming semantics
stt_impl = stt.StreamAdapter(
stt=stt_impl,
vad=silero.VAD.load(
min_silence_duration=0.2,
),
)
async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track):
audio_stream = rtc.AudioStream(track)
stt_forwarder = transcription.STTSegmentsForwarder(
room=ctx.room, participant=participant, track=track
)
stt_stream = stt_impl.stream()
asyncio.create_task(_forward_transcription(stt_stream, stt_forwarder))
async for ev in audio_stream:
stt_stream.push_frame(ev.frame)
@ctx.room.on("track_subscribed")
def on_track_subscribed(
track: rtc.Track,
publication: rtc.TrackPublication,
participant: rtc.RemoteParticipant,
):
# spin up a task to transcribe each track
if track.kind == rtc.TrackKind.KIND_AUDIO:
asyncio.create_task(transcribe_track(participant, track))
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
# Text-to-speech
This small example shows how you can generate real-time audio data from text.
import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli
from livekit.plugins import cartesia
load_dotenv()
logger = logging.getLogger("cartesia-tts-demo")
logger.setLevel(logging.INFO)
async def entrypoint(job: JobContext):
logger.info("starting tts example agent")
tts = cartesia.TTS(
# speed="fastest",
# emotion=["surprise:highest"],
)
source = rtc.AudioSource(tts.sample_rate, tts.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("agent-mic", source)
options = rtc.TrackPublishOptions()
options.source = rtc.TrackSource.SOURCE_MICROPHONE
await job.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_NONE)
publication = await job.room.local_participant.publish_track(track, options)
await publication.wait_for_subscription()
stream = tts.stream()
async def _playback_task():
async for audio in stream:
await source.capture_frame(audio.frame)
task = asyncio.create_task(_playback_task())
text = "hello from Cartesia. I hope you are having a great day."
# split into two word chunks to simulate LLM streaming
words = text.split()
for i in range(0, len(words), 2):
chunk = " ".join(words[i : i + 2])
if chunk:
logger.info(f'pushing chunk: "{chunk} "')
stream.push_text(chunk + " ")
# Mark end of input segment
stream.flush()
stream.end_input()
await asyncio.gather(task)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import asyncio
import logging
from typing import Optional
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import JobContext, WorkerOptions, cli
from livekit.plugins import elevenlabs
logger = logging.getLogger("elevenlabs-tts-demo")
logger.setLevel(logging.INFO)
load_dotenv()
def _text_to_chunks(text: str) -> list[str]:
"""Split the text into chunks of 2, 3, and 4 words"""
sizes = [2, 3, 4]
chunks, i = [], 0
for size in sizes:
while i + size <= len(text):
chunks.append(text[i : i + size])
i += size
chunks.append(text[i:]) # remaining
return chunks
async def _playout_task(
playout_q: asyncio.Queue, audio_source: rtc.AudioSource
) -> None:
"""Playout audio frames from the queue to the audio source"""
while True:
frame = await playout_q.get()
if frame is None:
break
await audio_source.capture_frame(frame)
async def entrypoint(job: JobContext):
# use another voice for this demo
# you can get a list of the voices using 'await tts_11labs.list_voices()'
voice = elevenlabs.Voice(
id="ODq5zmih8GrVes37Dizd", name="Patrick", category="premade"
)
tts_11labs = elevenlabs.TTS(model_id="eleven_multilingual_v2", voice=voice)
source = rtc.AudioSource(tts_11labs.sample_rate, tts_11labs.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("agent-mic", source)
options = rtc.TrackPublishOptions()
options.source = rtc.TrackSource.SOURCE_MICROPHONE
await job.connect()
publication = await job.room.local_participant.publish_track(track, options)
await publication.wait_for_subscription()
logger.info('Saying "Bonjour, comment allez-vous?"')
async for output in tts_11labs.synthesize("Bonjour, comment allez-vous?"):
await source.capture_frame(output.frame)
await asyncio.sleep(1)
logger.info('Saying "Au revoir."')
async for output in tts_11labs.synthesize("Au revoir."):
await source.capture_frame(output.frame)
await asyncio.sleep(1)
streamed_text = (
"Bonjour, ceci est un autre example avec la méthode utilisant un websocket."
)
logger.info('Streaming text "%s"', streamed_text)
stream = tts_11labs.stream()
for chunk in _text_to_chunks(
streamed_text
): # split into chunks just for the demonstration
stream.push_text(chunk)
stream.flush()
stream.end_input()
playout_q = asyncio.Queue[Optional[rtc.AudioFrame]]()
async def _synth_task():
async for ev in stream:
playout_q.put_nowait(ev.frame)
playout_q.put_nowait(None)
synth_task = asyncio.create_task(_synth_task())
playout_task = asyncio.create_task(_playout_task(playout_q, source))
await asyncio.gather(synth_task, playout_task)
await stream.aclose()
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli
from livekit.plugins import neuphonic
load_dotenv()
logger = logging.getLogger("neuphonic-tts-demo")
logger.setLevel(logging.INFO)
async def entrypoint(job: JobContext):
logger.info("starting tts example agent")
SAMPLE_RATE = 22050
NUM_CHANNELS = 1
tts = neuphonic.TTS(
# voice_id=<uuid>,
sample_rate=SAMPLE_RATE # defaults to 22050
)
source = rtc.AudioSource(SAMPLE_RATE, NUM_CHANNELS)
track = rtc.LocalAudioTrack.create_audio_track("agent-mic", source)
options = rtc.TrackPublishOptions()
options.source = rtc.TrackSource.SOURCE_MICROPHONE
await job.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_NONE)
publication = await job.room.local_participant.publish_track(track, options)
await publication.wait_for_subscription()
stream = tts.stream()
async def _playback_task():
async for audio in stream:
await source.capture_frame(audio.frame)
task = asyncio.create_task(_playback_task())
text = "Hello from Neuphonic. You have just successfully run the example!"
# split into two word chunks to simulate LLM streaming
words = text.split()
for i in range(0, len(words), 2):
chunk = " ".join(words[i : i + 2])
if chunk:
logger.info(f'pushing chunk: "{chunk} "')
stream.push_text(chunk + " ")
# Mark end of input segment
stream.flush()
stream.end_input()
await asyncio.gather(task)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import asyncio
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli
from livekit.plugins import openai
load_dotenv()
logger = logging.getLogger("openai-tts-demo")
logger.setLevel(logging.INFO)
async def entrypoint(job: JobContext):
logger.info("starting tts example agent")
tts = openai.TTS(model="tts-1", voice="nova")
source = rtc.AudioSource(tts.sample_rate, tts.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("agent-mic", source)
options = rtc.TrackPublishOptions()
options.source = rtc.TrackSource.SOURCE_MICROPHONE
await job.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_NONE)
publication = await job.room.local_participant.publish_track(track, options)
await publication.wait_for_subscription()
logger.info('Saying "Hello!"')
async for output in tts.synthesize("Hello!"):
await source.capture_frame(output.frame)
await asyncio.sleep(1)
logger.info('Saying "Goodbye."')
async for output in tts.synthesize("Goodbye."):
await source.capture_frame(output.frame)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
livekit-agents>=0.12.20
livekit-plugins-openai>=0.12.3
livekit-plugins-cartesia>=0.4.11
livekit-plugins-elevenlabs>=0.8.2
python-dotenv~=1.0
import asyncio
import logging
from typing import Optional
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
cli,
transcription,
tts,
)
from livekit.plugins import elevenlabs
load_dotenv()
logger = logging.getLogger("transcription-forwarding-demo")
logger.setLevel(logging.INFO)
async def entrypoint(ctx: JobContext):
logger.info("starting transcription protocol example")
tts_11labs = elevenlabs.TTS()
# publish an audio track
source = rtc.AudioSource(tts_11labs.sample_rate, tts_11labs.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("agent-mic", source)
options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_NONE)
publication = await ctx.room.local_participant.publish_track(track, options)
await publication.wait_for_subscription()
# start the transcription examples
tts_forwarder = transcription.TTSSegmentsForwarder(
room=ctx.room, participant=ctx.room.local_participant
)
await _eg_single_segment(tts_forwarder, tts_11labs, source)
await asyncio.sleep(2)
await _eg_streamed_tts_stream(tts_forwarder, tts_11labs, source)
async def _eg_single_segment(
tts_forwarder: transcription.TTSSegmentsForwarder,
tts_11labs: tts.TTS,
source: rtc.AudioSource,
):
"""Transcription example without streaming (single string)"""
text = "Hello world, this is a single segment"
logger.info("pushing text %s", text)
tts_forwarder.push_text(text)
tts_forwarder.mark_text_segment_end()
playout_q = asyncio.Queue[Optional[rtc.AudioFrame]]()
playout_task = asyncio.create_task(_playout_task(tts_forwarder, playout_q, source))
async for output in tts_11labs.synthesize(text):
tts_forwarder.push_audio(output.frame)
playout_q.put_nowait(output.frame)
tts_forwarder.mark_audio_segment_end()
playout_q.put_nowait(None)
await playout_task
async def _eg_streamed_tts_stream(
tts_forwarder: transcription.TTSSegmentsForwarder,
tts_11labs: tts.TTS,
source: rtc.AudioSource,
):
"""Transcription example using a tts stream (we split text into chunks just for the example)"""
# this tts_forwarder will forward the transcription to the client and sync with the audio
tts_stream = tts_11labs.stream()
streamed_text = "Hello world, this text is going to be splitted into small chunks"
logger.info("pushing text %s", streamed_text)
for chunk in _text_to_chunks(streamed_text):
tts_stream.push_text(chunk)
tts_forwarder.push_text(chunk)
tts_stream.flush()
tts_stream.end_input()
tts_forwarder.mark_text_segment_end()
playout_q = asyncio.Queue[Optional[rtc.AudioFrame]]()
async def _synth_task() -> None:
async for ev in tts_stream:
playout_q.put_nowait(ev.frame)
tts_forwarder.push_audio(ev.frame)
tts_forwarder.mark_audio_segment_end()
playout_q.put_nowait(None)
await tts_stream.aclose()
playout_task = asyncio.create_task(_playout_task(tts_forwarder, playout_q, source))
synth_task = asyncio.create_task(_synth_task())
await asyncio.gather(synth_task, playout_task)
await tts_forwarder.aclose()
async def _playout_task(
tts_forwarder: transcription.TTSSegmentsForwarder,
playout_q: asyncio.Queue,
audio_source: rtc.AudioSource,
) -> None:
"""Playout audio frames from the queue to the audio source"""
tts_forwarder.segment_playout_started()
while True:
frame = await playout_q.get()
if frame is None:
break
await audio_source.capture_frame(frame)
tts_forwarder.segment_playout_finished()
def _text_to_chunks(text: str) -> list[str]:
"""Split the text into chunks of 2, 3, and 4 words"""
sizes = [2, 3, 4]
chunks, i = [], 0
for size in sizes:
while i + size <= len(text):
chunks.append(text[i : i + size])
i += size
chunks.append(text[i:]) # remaining
return chunks
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
# Voice Assistant Examples
We have a few examples that show the various ways of using the VoicePipelineAgent class:
- `minimal_assistant.py`: a basic conversational assistant
- `function_calling_weather.py`: a weather assistant that calls an API endpoint to retrieve the weather
- `custom_pronunciation.py`: using the `before_tts_cb` hook to customize how TTS pronounces words
- `simple-rag`: a simple RAG assistant that answers questions by querying an embeddings index
The demo assistants use:
- Deepgram for Speech-to-text
- OpenAI for LLM and Text-to-speech
## Run
Instructions for running each example are identical; the following steps assume you are running `minimal_assistant.py`.
### Set up and activate a virtual env:
`python -m venv venv`
`source venv/bin/activate`
### Set environment variables:
```bash
export LIVEKIT_URL=<your LiveKit server URL>
export LIVEKIT_API_KEY=<your API Key>
export LIVEKIT_API_SECRET=<your API Secret>
export DEEPGRAM_API_KEY=<your Deepgram API key>
export OPENAI_API_KEY=<your OpenAI API key>
pip install -r requirements.txt
python minimal_assistant.py download-files
python minimal_assistant.py dev
```
We’ve built Agents Playground so you don’t have to build your own frontend while you iterate on your agent.
## examples/voice-pipeline-agent/cost_metrics.py
```py
import logging
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
load_dotenv()
logger = logging.getLogger("metrics-example")
# This example logs pipeline metrics and computes the cost of the session
OPENAI_LLM_INPUT_PRICE = 2.50 / (10**6) # $2.50 per million tokens
OPENAI_LLM_OUTPUT_PRICE = 10 / (10**6) # $10 per million tokens
OPENAI_TTS_PRICE = 15 / (10**6) # $15 per million characters
DEEPGRAM_STT_PRICE = 0.0043 # $0.0043 per minute
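# A quick sanity check of the cost math with hypothetical numbers: a session using
# 1,000 prompt tokens, 500 completion tokens, 200 TTS characters, and 2 minutes
# (120 seconds) of STT audio would cost roughly
#   1_000 * 2.50e-6 + 500 * 10e-6 + 200 * 15e-6 + 120 * 0.0043 / 60
#   = 0.0025 + 0.005 + 0.003 + 0.0086 ≈ $0.019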
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()
agent = VoicePipelineAgent(
vad=ctx.proc.userdata["vad"],
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
)
usage_collector = metrics.UsageCollector()
@agent.on("metrics_collected")
def _on_metrics_collected(mtrcs: metrics.AgentMetrics):
metrics.log_metrics(mtrcs)
usage_collector.collect(mtrcs)
async def log_session_cost():
summary = usage_collector.get_summary()
llm_cost = (
summary.llm_prompt_tokens * OPENAI_LLM_INPUT_PRICE
+ summary.llm_completion_tokens * OPENAI_LLM_OUTPUT_PRICE
)
tts_cost = summary.tts_characters_count * OPENAI_TTS_PRICE
stt_cost = summary.stt_audio_duration * DEEPGRAM_STT_PRICE / 60
total_cost = llm_cost + tts_cost + stt_cost
logger.info(
f"Total cost: ${total_cost:.4f} (LLM: ${llm_cost:.4f}, TTS: ${tts_cost:.4f}, STT: ${stt_cost:.4f})"
)
ctx.add_shutdown_callback(log_session_cost)
agent.start(ctx.room, participant)
await agent.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm))
```
from __future__ import annotations
from typing import AsyncIterable
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm, tokenize
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import cartesia, deepgram, openai, silero
load_dotenv()
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
def _before_tts_cb(agent: VoicePipelineAgent, text: str | AsyncIterable[str]):
# The TTS is incorrectly pronouncing "LiveKit", so we'll replace it with a phonetic
# spelling
return tokenize.utils.replace_words(
text=text, replacements={"livekit": r"<<l|aɪ|v|k|ɪ|t|>>"}
)
# for this example, we also boost the keyword "LiveKit" so it is more likely to be
# recognized by the STT
deepgram_stt = deepgram.STT(keywords=[("LiveKit", 3.5)])
agent = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram_stt,
llm=openai.LLM(),
tts=cartesia.TTS(),
chat_ctx=initial_ctx,
before_tts_cb=_before_tts_cb,
)
agent.start(ctx.room)
await agent.say("Hey, LiveKit is awesome!", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import logging
from datetime import datetime
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
stt,
tts,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import cartesia, deepgram, openai, playai, silero
load_dotenv()
logger = logging.getLogger("fallback-adapter-example")
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
fnc_ctx = llm.FunctionContext()
@fnc_ctx.ai_callable()
def get_time():
"""called to retrieve the current local time"""
return datetime.now().strftime("%H:%M:%S")
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
# wait for the first participant to connect
participant = await ctx.wait_for_participant()
logger.info(f"starting voice assistant for participant {participant.identity}")
vad: silero.VAD = ctx.proc.userdata["vad"]
# fallback to OpenAI if Deepgram goes down
fallback_stt = stt.FallbackAdapter(
[
deepgram.STT(),
stt.StreamAdapter(stt=openai.STT(), vad=vad),
]
)
# fallback to Azure if OpenAI goes down
fallback_llm = llm.FallbackAdapter(
[
openai.LLM(),
openai.LLM.with_azure(),
]
)
# fallback to 11labs if Cartesia goes down
# you can keep the same voice by using their voice cloning feature
fallback_tts = tts.FallbackAdapter(
[
cartesia.TTS(),
playai.TTS(),
]
)
agent = VoicePipelineAgent(
vad=vad,
stt=fallback_stt,
llm=fallback_llm,
tts=fallback_tts,
chat_ctx=initial_ctx,
fnc_ctx=fnc_ctx,
)
agent.start(ctx.room, participant)
await agent.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
prewarm_fnc=prewarm,
),
)
import asyncio
import logging
from typing import Annotated
import aiohttp
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
load_dotenv()
logger = logging.getLogger("weather-demo")
logger.setLevel(logging.INFO)
class AssistantFnc(llm.FunctionContext):
"""
The class defines a set of LLM functions that the assistant can execute.
"""
@llm.ai_callable()
async def get_weather(
self,
location: Annotated[
str, llm.TypeInfo(description="The location to get the weather for")
],
latitude: Annotated[
str,
llm.TypeInfo(description="The latitude of location to get the weather for"),
],
longitude: Annotated[
str,
llm.TypeInfo(
description="The longitude of location to get the weather for"
),
],
):
"""Called when the user asks about the weather. This function will return the weather for the given location.
When given a location, please estimate the latitude and longitude of the location and do not ask the user for them."""
# When a function call is running, there are a couple of options to inform the user
# that it might take a while:
# Option 1: you can use .say() to play a filler message immediately after the call is triggered
# Option 2: you can prompt the agent to return a text response when it's making a function call
# uncomment for option 1
# agent = AgentCallContext.get_current().agent
# filler_messages = [
# "Let me check the weather in {location} for you.",
# "Let me see what the weather is like in {location} right now.",
# # LLM will complete this sentence if it is added to the end of the chat context
# "The current weather in {location} is ",
# ]
# message = random.choice(filler_messages).format(location=location)
# logger.info(f"saying filler message: {message}")
# NOTE: set add_to_chat_ctx=True will add the message to the end
# of the chat context of the function call for answer synthesis
# speech_handle = await agent.say(message, add_to_chat_ctx=True) # noqa: F841
logger.info(f"getting weather for {latitude}, {longitude}")
url = f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m"
weather_data = {}
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
# response from the function call is returned to the LLM
weather_data = {
"temperature": data["current"]["temperature_2m"],
"temperature_unit": "Celsius",
}
else:
raise Exception(
f"Failed to get weather data, status code: {response.status}"
)
# artificially delay the function call for testing
await asyncio.sleep(2)
logger.info(f"weather data: {weather_data}")
# (optional) To wait for the speech to finish before giving results of the function call
# without waiting, the new speech result will be queued and played after current speech is finished
# await speech_handle.join()
return weather_data
def prewarm_process(proc: JobProcess):
# preload silero VAD in memory to speed up session start
proc.userdata["vad"] = silero.VAD.load()
async def entrypoint(ctx: JobContext):
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
fnc_ctx = AssistantFnc() # create our fnc ctx instance
initial_chat_ctx = llm.ChatContext().append(
text=(
"You are a weather assistant created by LiveKit. Your interface with users will be voice. "
"You will provide weather information for a given location. "
# when using option 1, you can suppress the agent's own text response with a prompt like
# "do not return any text while calling the function."
# option 2 - using LLM to generate text for the function call
"when performing function calls, let user know that you are checking the weather."
),
role="system",
)
participant = await ctx.wait_for_participant()
agent = VoicePipelineAgent(
vad=ctx.proc.userdata["vad"],
stt=deepgram.STT(),
llm=openai.LLM(model="gpt-4o"),
tts=openai.TTS(),
fnc_ctx=fnc_ctx,
chat_ctx=initial_chat_ctx,
)
usage_collector = metrics.UsageCollector()
@agent.on("metrics_collected")
def _on_metrics_collected(mtrcs: metrics.AgentMetrics):
metrics.log_metrics(mtrcs)
usage_collector.collect(mtrcs)
async def log_usage():
summary = usage_collector.get_summary()
logger.info(f"Usage: ${summary}")
# Start the assistant. This will automatically publish a microphone track and listen to the participant.
agent.start(ctx.room, participant)
await agent.say(
"Hello from the weather station. Tell me your location to check the weather."
)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
prewarm_fnc=prewarm_process,
),
)
import logging
from typing import Annotated
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import google, silero
load_dotenv()
logger = logging.getLogger("voice-assistant")
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
# An example Voice Agent using Google STT, Gemini 2.0 Flash, and Google TTS.
# Prerequisites:
# 1. livekit-plugins-openai[vertex] package installed
# 2. save your service account credentials and set the following environment variables:
# * GOOGLE_APPLICATION_CREDENTIALS to the path of the service account key file
# * GOOGLE_CLOUD_PROJECT to your Google Cloud project ID
# 3. the following services are enabled on your Google Cloud project:
# * Vertex AI
# * Cloud Speech-to-Text API
# * Cloud Text-to-Speech API
# Read more about authentication with Google: https://cloud.google.com/docs/authentication/application-default-credentials
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
logger.info(f"connecting to room {ctx.room.name}")
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
# wait for the first participant to connect
participant = await ctx.wait_for_participant()
logger.info(f"starting voice assistant for participant {participant.identity}")
fnc_ctx = llm.FunctionContext()
@fnc_ctx.ai_callable()
async def get_weather(
location: Annotated[
str, llm.TypeInfo(description="The location to get the weather for")
],
):
"""Called when the user asks about the weather. This function will return the weather for the given location."""
return f"The weather in {location} is sunny."
agent = VoicePipelineAgent(
vad=ctx.proc.userdata["vad"],
stt=google.STT(),
llm=google.LLM(),
tts=google.TTS(
voice_name="en-US-Journey-D",
),
chat_ctx=initial_ctx,
fnc_ctx=fnc_ctx,
)
agent.start(ctx.room, participant)
usage_collector = metrics.UsageCollector()
@agent.on("metrics_collected")
def _on_metrics_collected(mtrcs: metrics.AgentMetrics):
metrics.log_metrics(mtrcs)
usage_collector.collect(mtrcs)
async def log_usage():
summary = usage_collector.get_summary()
logger.info(f"Usage: ${summary}")
ctx.add_shutdown_callback(log_usage)
await agent.say(
"Hi there, this is Gemini, how can I help you today?", allow_interruptions=False
)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
prewarm_fnc=prewarm,
),
)
# RAG Example using LlamaIndex
This repository showcases three ways to build a voice assistant with Retrieval-Augmented Generation (RAG) using LlamaIndex:
1. **`chat_engine.py`**: Utilizes LlamaIndex's `as_chat_engine` for a straightforward, integrated solution. **Trade-off**: Lacks function calling support, limiting advanced interactions.
2. **`query_engine.py`**: Uses an LLM that supports function calling (e.g., OpenAI's models) to define custom functions like `query_info` for retrieval. **Trade-off**: Requires additional setup but offers greater flexibility.
3. **`retrieval.py`**: Manually injects retrieved context into the system prompt using LlamaIndex's retriever. **Trade-off**: Provides fine-grained control but involves complex prompt engineering.
**Current recommended way**: Use **`query_engine.py`** for its balance of flexibility and control, enabling function calling and custom behaviors without excessive complexity.
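For orientation, the core of the recommended `query_engine.py` approach (shown in full further below) is a function tool that queries the index; a trimmed sketch, assuming `index` is the LlamaIndex `VectorStoreIndex` built at startup:
```python
from livekit.agents import llm

fnc_ctx = llm.FunctionContext()

@fnc_ctx.ai_callable(description="Get more information about a specific topic")
async def query_info(query: str) -> str:
    # `index` is assumed to be a VectorStoreIndex loaded elsewhere (see query_engine.py)
    query_engine = index.as_query_engine(use_async=True)
    res = await query_engine.aquery(query)
    return str(res)
```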
import os
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, llama_index, openai, silero
from llama_index.core import (
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.chat_engine.types import ChatMode
load_dotenv()
# check if storage already exists
PERSIST_DIR = "./chat-engine-storage"
if not os.path.exists(PERSIST_DIR):
# load the documents and create the index
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
# store it for later
index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
# load the existing index
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context)
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
chat_engine = index.as_chat_engine(chat_mode=ChatMode.CONTEXT)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
assistant = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=llama_index.LLM(chat_engine=chat_engine),
tts=openai.TTS(),
chat_ctx=initial_ctx,
)
assistant.start(ctx.room)
await assistant.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
Cloud Architecture
LiveKit Cloud gives you the flexibility of LiveKit's WebRTC stack, combined with global, CDN-scale infrastructure offering 99.99% uptime.
Built with LiveKit SFU
LiveKit Cloud builds on our open-source SFU. This means it supports the exact same client and server APIs as the open-source stack.
Maintaining compatibility with LiveKit's Open Source stack (OSS) is important to us. We didn't want any developer locked into using Cloud, or needing to integrate a different set of features, APIs or SDKs for their applications to work with it. Our design goal: a developer should be able to switch between Cloud and self-hosted without changing a line of code.
Distributed Mesh Architecture
In contrast to traditional WebRTC architectures, LiveKit Cloud runs multiple SFU instances in a mesh formation. We've developed capabilities for media servers to discover and connect to one another, in order to relay media between servers. This key capability allows us to bypass the single-server limitation that exists in traditional SFU and MCU architectures.
Multi-home
Cloud multi-home architecture
With a multi-home architecture, participants no longer need to connect to the same server. When participants from different regions join the same meeting, they'll each connect to the SFU closest to them, minimizing latency and transmission loss between the participant and SFU.
Each SFU instance establishes connections to other instances over optimized inter-data center networks. Inter-data center networks often run close to internet backbones, delivering high throughput with a minimal number of network hops.
No SPOF
Anything that can fail, will. LiveKit Cloud is designed to anticipate (and recover from) failures in every software and hardware component.
Layers of redundancy are built into the system. A media server failure is recovered from by moving impacted participants to another instance. We isolate shared infrastructure, like our message bus, to individual data centers.
When an entire data center fails, customer traffic is automatically migrated to the next closest data center. LiveKit's client SDKs will perform a "session migration": moving existing WebRTC sessions to a different media server without service interruption for your users.
Globally distributed
To serve end users around the world, our infrastructure runs across multiple Cloud vendors and data centers. Today we have data centers in North America, South America, Southeast Asia, East Asia, and Europe, delivering under 100ms of latency for users in those regions.
Designed to scale
When you need to place many viewers on a media track, like in a livestream, LiveKit Cloud handles that capacity dynamically by forming a distribution mesh, similar to a CDN. It's important to note that this process happens automatically as your session scales up. There are no special configurations necessary. Every LiveKit Cloud project scales automatically.
The theoretical limit of this architecture is on the order of millions of participants per room/session. For practical purposes, we've placed a limit of 100k simultaneous participants in the same session. If you have a realtime application operating at a scale larger than this, you can request a limit increase in your Cloud dashboard or get in touch with us.
import os
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
from llama_index.core import (
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
load_dotenv()
# check if storage already exists
PERSIST_DIR = "./query-engine-storage"
if not os.path.exists(PERSIST_DIR):
# load the documents and create the index
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
# store it for later
index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
# load the existing index
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context)
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
fnc_ctx = llm.FunctionContext()
@fnc_ctx.ai_callable(description="Get more information about a specific topic")
async def query_info(query: str) -> str:
query_engine = index.as_query_engine(use_async=True)
res = await query_engine.aquery(query)
print("Query result:", res)
return str(res)
assistant = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
fnc_ctx=fnc_ctx,
)
assistant.start(ctx.room)
await assistant.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import os
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
from llama_index.core import (
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.schema import MetadataMode
load_dotenv()
# check if storage already exists
PERSIST_DIR = "./retrieval-engine-storage"
if not os.path.exists(PERSIST_DIR):
# load the documents and create the index
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
# store it for later
index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
# load the existing index
storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
index = load_index_from_storage(storage_context)
async def entrypoint(ctx: JobContext):
system_msg = llm.ChatMessage(
role="system",
content=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
initial_ctx = llm.ChatContext()
initial_ctx.messages.append(system_msg)
async def _will_synthesize_assistant_reply(
assistant: VoicePipelineAgent, chat_ctx: llm.ChatContext
):
ctx_msg = system_msg.copy()
user_msg = chat_ctx.messages[-1]
retriever = index.as_retriever()
nodes = await retriever.aretrieve(user_msg.content)
ctx_msg.content = "Context that might help answer the user's question:"
for node in nodes:
node_content = node.get_content(metadata_mode=MetadataMode.LLM)
ctx_msg.content += f"\n\n{node_content}"
chat_ctx.messages[0] = ctx_msg # the first message is the system message
return assistant.llm.chat(chat_ctx=chat_ctx)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
assistant = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
will_synthesize_assistant_reply=_will_synthesize_assistant_reply,
)
assistant.start(ctx.room)
await assistant.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import logging
from dotenv import load_dotenv
from livekit import rtc
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
load_dotenv()
logger = logging.getLogger("voice-assistant")
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
logger.info(f"connecting to room {ctx.room.name}")
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
# wait for the first participant to connect
participant = await ctx.wait_for_participant()
logger.info(f"starting voice assistant for participant {participant.identity}")
dg_model = "nova-3-general"
if participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP:
# use a model optimized for telephony
dg_model = "nova-2-phonecall"
agent = VoicePipelineAgent(
vad=ctx.proc.userdata["vad"],
stt=deepgram.STT(model=dg_model),
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
)
agent.start(ctx.room, participant)
usage_collector = metrics.UsageCollector()
@agent.on("metrics_collected")
def _on_metrics_collected(mtrcs: metrics.AgentMetrics):
metrics.log_metrics(mtrcs)
usage_collector.collect(mtrcs)
async def log_usage():
summary = usage_collector.get_summary()
logger.info(f"Usage: ${summary}")
ctx.add_shutdown_callback(log_usage)
await agent.say("Hello there! How can I help you today?", allow_interruptions=False)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
prewarm_fnc=prewarm,
),
)
import logging
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
from livekit.plugins.openai.beta import (
AssistantCreateOptions,
AssistantLLM,
AssistantOptions,
OnFileUploadedInfo,
)
load_dotenv()
logger = logging.getLogger("openai_assistant")
async def entrypoint(ctx: JobContext):
"""This example demonstrates a VoicePipelineAgent that uses OpenAI's Assistant API as the LLM"""
initial_ctx = llm.ChatContext()
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()
# When you add a ChatMessage that contains images, AssistantLLM will upload them
# to OpenAI's Assistant API.
# It's up to you to remove them if desired or otherwise manage them going forward.
def on_file_uploaded(info: OnFileUploadedInfo):
logger.info(f"{info.type} uploaded: {info.openai_file_object}")
agent = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=AssistantLLM(
assistant_opts=AssistantOptions(
create_options=AssistantCreateOptions(
model="gpt-4o",
instructions="You are a voice assistant created by LiveKit. Your interface with users will be voice.",
name="KITT",
)
),
on_file_uploaded=on_file_uploaded,
),
tts=openai.TTS(),
chat_ctx=initial_ctx,
)
agent.start(ctx.room, participant)
await agent.say("Hey, how can I help you today?", allow_interruptions=False)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
livekit-agents>=0.12.20
livekit-plugins-deepgram>=0.7.3
livekit-plugins-google>=0.11.3
livekit-plugins-openai[vertex]>=0.10.10,<1.0.0
livekit-plugins-silero>=0.7.5
livekit-plugins-rag>=0.2.4
python-dotenv~=1.0
aiofile~=3.8.8
import asyncio
from datetime import datetime
from aiofile import async_open
from dotenv import load_dotenv
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
load_dotenv()
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
agent = VoicePipelineAgent(
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
)
agent.start(ctx.room)
log_queue = asyncio.Queue()
@agent.on("user_speech_committed")
def on_user_speech_committed(msg: llm.ChatMessage):
# convert string lists to strings, drop images
if isinstance(msg.content, list):
msg.content = "\n".join(
"[image]" if isinstance(x, llm.ChatImage) else x for x in msg
)
log_queue.put_nowait(f"[{datetime.now()}] USER:\n{msg.content}\n\n")
@agent.on("agent_speech_committed")
def on_agent_speech_committed(msg: llm.ChatMessage):
log_queue.put_nowait(f"[{datetime.now()}] AGENT:\n{msg.content}\n\n")
async def write_transcription():
async with async_open("transcriptions.log", "w") as f:  # aiofile's async_open for non-blocking writes
while True:
msg = await log_queue.get()
if msg is None:
break
await f.write(msg)
write_task = asyncio.create_task(write_transcription())
async def finish_queue():
log_queue.put_nowait(None)
await write_task
ctx.add_shutdown_callback(finish_queue)
await agent.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import logging
import pickle
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli, llm
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, rag, silero
logger = logging.getLogger("rag-assistant")
annoy_index = rag.annoy.AnnoyIndex.load("vdb_data") # see build_data.py
embeddings_dimension = 1536
with open("my_data.pkl", "rb") as f:
paragraphs_by_uuid = pickle.load(f)
async def entrypoint(ctx: JobContext):
async def _enrich_with_rag(agent: VoicePipelineAgent, chat_ctx: llm.ChatContext):
# locate the last user message and use it to query the RAG model
# to get the most relevant paragraph
# then provide that as additional context to the LLM
user_msg = chat_ctx.messages[-1]
user_embedding = await openai.create_embeddings(
input=[user_msg.content],
model="text-embedding-3-small",
dimensions=embeddings_dimension,
)
result = annoy_index.query(user_embedding[0].embedding, n=1)[0]
paragraph = paragraphs_by_uuid[result.userdata]
if paragraph:
logger.info(f"enriching with RAG: {paragraph}")
rag_msg = llm.ChatMessage.create(
text="Context:\n" + paragraph,
role="assistant",
)
# replace last message with RAG, and append user message at the end
chat_ctx.messages[-1] = rag_msg
chat_ctx.messages.append(user_msg)
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
"Use the provided context to answer the user's question if needed."
),
)
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
agent = VoicePipelineAgent(
chat_ctx=initial_ctx,
vad=silero.VAD.load(),
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
before_llm_cb=_enrich_with_rag,
)
agent.start(ctx.room)
await agent.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
import asyncio
import pickle
import uuid
import aiohttp
from livekit.agents import tokenize
from livekit.plugins import openai, rag
from tqdm import tqdm
# from this blog https://openai.com/index/new-embedding-models-and-api-updates/
# 512 dimensions already give a good MTEB score with text-embedding-3-small,
# but we keep the model's default of 1536 here so it matches the query side in assistant.py
embeddings_dimension = 1536
raw_data = open("raw_data.txt", "r").read()
async def _create_embeddings(
input: str, http_session: aiohttp.ClientSession
) -> openai.EmbeddingData:
results = await openai.create_embeddings(
input=[input],
model="text-embedding-3-small",
dimensions=embeddings_dimension,
http_session=http_session,
)
return results[0]
async def main() -> None:
async with aiohttp.ClientSession() as http_session:
idx_builder = rag.annoy.IndexBuilder(f=embeddings_dimension, metric="angular")
paragraphs_by_uuid = {}
for p in tokenize.basic.tokenize_paragraphs(raw_data):
p_uuid = uuid.uuid4()
paragraphs_by_uuid[p_uuid] = p
for p_uuid, paragraph in tqdm(paragraphs_by_uuid.items()):
resp = await _create_embeddings(paragraph, http_session)
idx_builder.add_item(resp.embedding, p_uuid)
idx_builder.build()
idx_builder.save("vdb_data")
# save data with pickle
with open("my_data.pkl", "wb") as f:
pickle.dump(paragraphs_by_uuid, f)
if __name__ == "__main__":
asyncio.run(main())
Cloud Architecture
LiveKit Cloud gives you the flexibility of LiveKit's WebRTC stack, combined with global, CDN-scale infrastructure offering 99.99% uptime.
Built with LiveKit SFU
LiveKit Cloud builds on our open-source SFU. This means it supports the exact same client and server APIs as the open-source stack.
Maintaining compatibility with LiveKit's Open Source stack (OSS) is important to us. We didn't want any developer locked into using Cloud, or needing to integrate a different set of features, APIs or SDKs for their applications to work with it. Our design goal: a developer should be able to switch between Cloud or self-hosted without changing a line of code.
Distributed Mesh Architecture
In contrast to traditional WebRTC architectures, LiveKit Cloud runs multiple SFU instances in a mesh formation. We've developed capabilities for media servers to discover and connect to one another, in order to relay media between servers. This key capability allows us to bypass the single-server limitation that exists in traditional SFU and MCU architectures.
Multi-home
Cloud multi-home architecture
With a multi-home architecture, participants no longer need to connect to the same server. When participants from different regions join the same meeting, they'll each connect to the SFU closest to them, minimizing latency and transmission loss between the participant and SFU.
Each SFU instance establishes connections to other instances over optimized inter-data center networks. Inter-data center networks often run close to internet backbones, delivering high throughput with a minimal number of network hops.
No SPOF
Anything that can fail, will. LiveKit Cloud is designed to anticipate (and recover from) failures in every software and hardware component.
Layers of redundancy are built into the system. A media server failure is recovered from by moving impacted participants to another instance. We isolate shared infrastructure, like our message bus, to individual data centers.
When an entire data center fails, customer traffic is automatically migrated to the next closest data center. LiveKit's client SDKs will perform a "session migration": moving existing WebRTC sessions to a different media server without service interruption for your users.
Globally distributed
To serve end users around the world, our infrastructure runs across multiple Cloud vendors and data centers. Today we have data centers in North America, South America, Southeast Asia, East Asia, and Europe, delivering under 100ms of latency for users in those regions.
Designed to scale
When you need to place many viewers on a media track, like in a livestream, LiveKit Cloud handles that capacity dynamically by forming a distribution mesh, similar to a CDN. This process happens automatically as your session scales up; no special configuration is necessary. Every LiveKit Cloud project scales automatically.
The theoretical limit of this architecture is on the order of millions of participants per room/session. For practical purposes, we've placed a limit of 100k simultaneous participants in the same session. If you have a realtime application operating at a scale larger than this, you can request a limit increase in your Cloud dashboard or get in touch with us.
import logging
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
JobProcess,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
load_dotenv()
logger = logging.getLogger("voice-assistant")
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
# This example uses our open-weight turn detection model to detect when the user is
# done speaking. This approach is more accurate than the default VAD model, reducing
# false positive interruptions by the agent.
async def entrypoint(ctx: JobContext):
initial_ctx = llm.ChatContext().append(
role="system",
text=(
"You are a voice assistant created by LiveKit. Your interface with users will be voice. "
"You should use short and concise responses, and avoiding usage of unpronouncable punctuation."
),
)
logger.info(f"connecting to room {ctx.room.name}")
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
# wait for the first participant to connect
participant = await ctx.wait_for_participant()
logger.info(f"starting voice assistant for participant {participant.identity}")
agent = VoicePipelineAgent(
vad=ctx.proc.userdata["vad"],
stt=deepgram.STT(model="nova-3", language="multi"),
llm=openai.LLM(model="gpt-4o-mini"),
tts=openai.TTS(),
chat_ctx=initial_ctx,
turn_detector=MultilingualModel(),
)
agent.start(ctx.room, participant)
usage_collector = metrics.UsageCollector()
@agent.on("metrics_collected")
def _on_metrics_collected(mtrcs: metrics.AgentMetrics):
metrics.log_metrics(mtrcs)
usage_collector.collect(mtrcs)
async def log_usage():
summary = usage_collector.get_summary()
logger.info(f"Usage: ${summary}")
ctx.add_shutdown_callback(log_usage)
await agent.say("Hey, how can I help you today?", allow_interruptions=True)
if __name__ == "__main__":
cli.run_app(
WorkerOptions(
entrypoint_fnc=entrypoint,
prewarm_fnc=prewarm,
),
)
# livekit-agents
## 0.12.20
### Patch Changes
- fix decoder: if no data was pushed, close the output channel - [#1881](https://github.com/livekit/agents/pull/1881) ([@jayeshp19](https://github.com/jayeshp19))
## 0.12.19
### Patch Changes
- fixed thread safety in AudioStreamDecoder - [#1736](https://github.com/livekit/agents/pull/1736) ([@jeradf](https://github.com/jeradf))
- cleanup AudioStreamDecoder resources - [#1736](https://github.com/livekit/agents/pull/1736) ([@jeradf](https://github.com/jeradf))
## 0.12.18
### Patch Changes
- Remove unnecessary version pins - [#1682](https://github.com/livekit/agents/pull/1682) ([@hauntsaninja](https://github.com/hauntsaninja))
- reduced retry interval to 2s - [#1701](https://github.com/livekit/agents/pull/1701) ([@davidzhao](https://github.com/davidzhao))
## 0.12.17
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- added a tts.prewarm method to start the connection pool early. - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- Raise ValueError in FallbackAdapter when streaming is not supported - [#1609](https://github.com/livekit/agents/pull/1609) ([@jayeshp19](https://github.com/jayeshp19))
- fixed a bug in AudioStreamDecoder where it could fail on close - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- support for livekit noise cancellation plugin in VoicePipelineAgent and MultimodalAgent - [#1551](https://github.com/livekit/agents/pull/1551) ([@bcherry](https://github.com/bcherry))
- fix: \_play_speech getting stuck due to orphan speech handle - [#1555](https://github.com/livekit/agents/pull/1555) ([@SiyuanQi](https://github.com/SiyuanQi))
## 0.12.16
### Patch Changes
- feat: connection pooling. speeds up generation with STT/TTS providers - [#1538](https://github.com/livekit/agents/pull/1538) ([@davidzhao](https://github.com/davidzhao))
- handle process initialization failure - [#1556](https://github.com/livekit/agents/pull/1556) ([@theomonnom](https://github.com/theomonnom))
## 0.12.15
### Patch Changes
- Revert "fix(cli): update main_file path to use current directory" - [#1509](https://github.com/livekit/agents/pull/1509) ([@theomonnom](https://github.com/theomonnom))
## 0.12.14
### Patch Changes
- openai tts: switch to using Opus encoding - [#1494](https://github.com/livekit/agents/pull/1494) ([@davidzhao](https://github.com/davidzhao))
- improve exception logging - [#1490](https://github.com/livekit/agents/pull/1490) ([@jayeshp19](https://github.com/jayeshp19))
- fix interrupting nested speech from before_llm_cb - [#1504](https://github.com/livekit/agents/pull/1504) ([@longcw](https://github.com/longcw))
- add cache tokens in `CompletionUsage` dataclass - [#1478](https://github.com/livekit/agents/pull/1478) ([@jayeshp19](https://github.com/jayeshp19))
## 0.12.13
### Patch Changes
- Allow shutdown callbacks to take reason - [#1475](https://github.com/livekit/agents/pull/1475) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.12.12
### Patch Changes
- fix agent transcription could not be disabled - [#1448](https://github.com/livekit/agents/pull/1448) ([@davidzhao](https://github.com/davidzhao))
- Added an additional field in LLM capabilities class to check if model providers support function call history within chat context without needing function definitions. - [#1441](https://github.com/livekit/agents/pull/1441) ([@jayeshp19](https://github.com/jayeshp19))
- support agent.say inside the before_llm_cb - [#1460](https://github.com/livekit/agents/pull/1460) ([@longcw](https://github.com/longcw))
## 0.12.11
### Patch Changes
- gemini-realtime: fix input audio sample rate - [#1411](https://github.com/livekit/agents/pull/1411) ([@jayeshp19](https://github.com/jayeshp19))
- fix(pipeline_agent): clear user transcript when before_llm_cb returns false - [#1423](https://github.com/livekit/agents/pull/1423) ([@s-hamdananwar](https://github.com/s-hamdananwar))
- fix: fallbackadapter to correctly handle function calls - [#1429](https://github.com/livekit/agents/pull/1429) ([@davidzhao](https://github.com/davidzhao))
- improved TTFB metrics for streaming TTS - [#1431](https://github.com/livekit/agents/pull/1431) ([@davidzhao](https://github.com/davidzhao))
## 0.12.10
### Patch Changes
- fix false positive interruption tripping up certain LLMs - [#1410](https://github.com/livekit/agents/pull/1410) ([@davidzhao](https://github.com/davidzhao))
- fix: ensure llm.FallbackAdapter executes function calls - [#1409](https://github.com/livekit/agents/pull/1409) ([@davidzhao](https://github.com/davidzhao))
## 0.12.9
### Patch Changes
- add generate_reply api for multimodal agent - [#1359](https://github.com/livekit/agents/pull/1359) ([@longcw](https://github.com/longcw))
- remove aiodns from livekit-agents - [#1368](https://github.com/livekit/agents/pull/1368) ([@theomonnom](https://github.com/theomonnom))
## 0.12.8
### Patch Changes
- Fix not awaiting forward task in TTS forwarder, leading to warnings. - [#1339](https://github.com/livekit/agents/pull/1339) ([@martin-purplefish](https://github.com/martin-purplefish))
- reduces initial delay before model retries - [#1337](https://github.com/livekit/agents/pull/1337) ([@davidzhao](https://github.com/davidzhao))
- fix: function calls without a text response were not added to chat ctx - [#1349](https://github.com/livekit/agents/pull/1349) ([@longcw](https://github.com/longcw))
- add timeout for EOU inference requests made to the inference process - [#1315](https://github.com/livekit/agents/pull/1315) ([@theomonnom](https://github.com/theomonnom))
- support disabling server VAD for OpenAI realtime model - [#1347](https://github.com/livekit/agents/pull/1347) ([@longcw](https://github.com/longcw))
## 0.12.7
### Patch Changes
- ensure job status updates contain the correct status - [#1319](https://github.com/livekit/agents/pull/1319) ([@davidzhao](https://github.com/davidzhao))
## 0.12.6
### Patch Changes
- expose worker_id in jobcontext - [#1307](https://github.com/livekit/agents/pull/1307) ([@s-hamdananwar](https://github.com/s-hamdananwar))
- improved handling of LLM errors; do not retry if generation has already begun - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao))
- Do not pass function context if at max depth - [#1306](https://github.com/livekit/agents/pull/1306) ([@martin-purplefish](https://github.com/martin-purplefish))
- avoid warnings when function depth matches limit - [#1316](https://github.com/livekit/agents/pull/1316) ([@davidzhao](https://github.com/davidzhao))
- improve interruption handling, avoid agent from getting stuck - [#1290](https://github.com/livekit/agents/pull/1290) ([@davidzhao](https://github.com/davidzhao))
- add manual interrupt method for pipeline agent - [#1294](https://github.com/livekit/agents/pull/1294) ([@longcw](https://github.com/longcw))
- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) ([@jayeshp19](https://github.com/jayeshp19))
## 0.12.5
### Patch Changes
- make max_endpoint_delay configurable - [#1277](https://github.com/livekit/agents/pull/1277) ([@davidzhao](https://github.com/davidzhao))
- set USE_DOCSTRING as default for ai_callable - [#1266](https://github.com/livekit/agents/pull/1266) ([@longcw](https://github.com/longcw))
- fix: do not log process warning when process not found - [#1281](https://github.com/livekit/agents/pull/1281) ([@davidzhao](https://github.com/davidzhao))
- fix context when functions have been called - [#1279](https://github.com/livekit/agents/pull/1279) ([@jmugicagonz](https://github.com/jmugicagonz))
## 0.12.4
### Patch Changes
- avoid duplicated chat ctx for function calls with messages - [#1254](https://github.com/livekit/agents/pull/1254) ([@longcw](https://github.com/longcw))
## 0.12.3
### Patch Changes
- Moved create_ai_function_info to function_context.py for better reusability and reduced repetition - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19))
- added streaming audio decoder for compressed audio. - [#1236](https://github.com/livekit/agents/pull/1236) ([@davidzhao](https://github.com/davidzhao))
- Add JPEG quality param to image encoder - [#1249](https://github.com/livekit/agents/pull/1249) ([@bcherry](https://github.com/bcherry))
- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry))
Add support for data URLs on ChatImage in the Anthropic plugin.
- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19))
- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry))
Make scale_aspect_fit the new default resizing option for video frames.
## 0.12.2
### Patch Changes
- improvements to endpointing latency - [#1212](https://github.com/livekit/agents/pull/1212) ([@davidzhao](https://github.com/davidzhao))
- Improvements to end of turn plugin, ensure STT language settings. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao))
- fix duplicated agent speech commit for message with function call - [#1192](https://github.com/livekit/agents/pull/1192) ([@longcw](https://github.com/longcw))
- fix: Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19))
## 0.12.1
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.12.0
### Minor Changes
- add nested speech handles, now agent.say works during a function call - [#1130](https://github.com/livekit/agents/pull/1130) ([@longcw](https://github.com/longcw))
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
- expose LiveKitAPI from the JobContext - [#1159](https://github.com/livekit/agents/pull/1159) ([@theomonnom](https://github.com/theomonnom))
- add extra chat messages to the end of the function call outputs - [#1165](https://github.com/livekit/agents/pull/1165) ([@longcw](https://github.com/longcw))
- Add retries to recover from text mode to audio model for realtime API - [#1121](https://github.com/livekit/agents/pull/1121) ([@longcw](https://github.com/longcw))
- prepare for release - [#1160](https://github.com/livekit/agents/pull/1160) ([@theomonnom](https://github.com/theomonnom))
- add max_job_memory_usage; the job will be killed if it exceeds the limit - [#1136](https://github.com/livekit/agents/pull/1136) ([@longcw](https://github.com/longcw))
- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19))
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
- Expose multimodal agent metrics - [#1080](https://github.com/livekit/agents/pull/1080) ([@longcw](https://github.com/longcw))
- preload mp3 decoder for TTS plugins - [#1129](https://github.com/livekit/agents/pull/1129) ([@jayeshp19](https://github.com/jayeshp19))
- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom))
- feat: inference process & end of utterance plugin - [#1133](https://github.com/livekit/agents/pull/1133) ([@theomonnom](https://github.com/theomonnom))
- vertex ai support with openai library - [#1084](https://github.com/livekit/agents/pull/1084) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.3
### Patch Changes
- add PeriodicCollector utility for metrics - [#1094](https://github.com/livekit/agents/pull/1094) ([@davidzhao](https://github.com/davidzhao))
## 0.11.2
### Patch Changes
- Fix interrupt_min_words handling - [#1062](https://github.com/livekit/agents/pull/1062) ([@davidzhao](https://github.com/davidzhao))
- pipelineagent: fix speech_committed never called - [#1078](https://github.com/livekit/agents/pull/1078) ([@theomonnom](https://github.com/theomonnom))
- Allow setting agent attributes when accepting job - [#1076](https://github.com/livekit/agents/pull/1076) ([@davidzhao](https://github.com/davidzhao))
- handles error in function calls - [#1057](https://github.com/livekit/agents/pull/1057) ([@jayeshp19](https://github.com/jayeshp19))
- Include job count in WorkerStatus and pass in worker for load_fnc - [#1046](https://github.com/livekit/agents/pull/1046) ([@keepingitneil](https://github.com/keepingitneil))
- Fix delay calculation - [#1081](https://github.com/livekit/agents/pull/1081) ([@martin-purplefish](https://github.com/martin-purplefish))
- sync the Realtime API conversation items and add set_chat_ctx - [#1015](https://github.com/livekit/agents/pull/1015) ([@longcw](https://github.com/longcw))
- added metrics for idle time - [#1064](https://github.com/livekit/agents/pull/1064) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.1
### Patch Changes
- Fix stack dump on closed stream - [#1023](https://github.com/livekit/agents/pull/1023) ([@martin-purplefish](https://github.com/martin-purplefish))
- fix: invalid request on anthropic - [#1018](https://github.com/livekit/agents/pull/1018) ([@theomonnom](https://github.com/theomonnom))
- fix: IndexError on tts metrics - [#1028](https://github.com/livekit/agents/pull/1028) ([@theomonnom](https://github.com/theomonnom))
## 0.11.0
### Minor Changes
- prepare for release - [#1007](https://github.com/livekit/agents/pull/1007) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- Fix race in load calc initialization - [#969](https://github.com/livekit/agents/pull/969) ([@martin-purplefish](https://github.com/martin-purplefish))
- Fix incorrect load computation on docker instances - [#972](https://github.com/livekit/agents/pull/972) ([@martin-purplefish](https://github.com/martin-purplefish))
- stt: reduce bandwidth usage by reducing sample_rate to 16khz - [#920](https://github.com/livekit/agents/pull/920) ([@theomonnom](https://github.com/theomonnom))
- Reorganized metrics, added create_metrics_logger - [#1009](https://github.com/livekit/agents/pull/1009) ([@davidzhao](https://github.com/davidzhao))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- Allow kind to be list or single value - [#1006](https://github.com/livekit/agents/pull/1006) ([@keepingitneil](https://github.com/keepingitneil))
- fix before_llm_cb not handling coroutines returning False - [#961](https://github.com/livekit/agents/pull/961) ([@Tanesan](https://github.com/Tanesan))
- expose transcriptions for multimodal agents - [#1001](https://github.com/livekit/agents/pull/1001) ([@longcw](https://github.com/longcw))
- Fix stack dump on room shutdown - [#989](https://github.com/livekit/agents/pull/989) ([@martin-purplefish](https://github.com/martin-purplefish))
- Add exception logging for tool calls - [#923](https://github.com/livekit/agents/pull/923) ([@martin-purplefish](https://github.com/martin-purplefish))
- Skip egress by default in participant-related utilities on JobContext - [#1005](https://github.com/livekit/agents/pull/1005) ([@keepingitneil](https://github.com/keepingitneil))
- pipeline-agent: avoid nested function calls - [#935](https://github.com/livekit/agents/pull/935) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
- fix jobs never reloading - [#934](https://github.com/livekit/agents/pull/934) ([@theomonnom](https://github.com/theomonnom))
- voicepipeline: support recursive/chained function calls - [#970](https://github.com/livekit/agents/pull/970) ([@theomonnom](https://github.com/theomonnom))
## 0.10.2
### Patch Changes
- Fix split_paragraphs and simple-rag example - [#896](https://github.com/livekit/agents/pull/896) ([@davidzhao](https://github.com/davidzhao))
- Fix bug where if the tts_source was a string but before_tts_cb returned AsyncIterable[str], the transcript would not be synthesized. - [#906](https://github.com/livekit/agents/pull/906) ([@martin-purplefish](https://github.com/martin-purplefish))
- Allow forcing interruptions of incomplete audio - [#891](https://github.com/livekit/agents/pull/891) ([@martin-purplefish](https://github.com/martin-purplefish))
- Include chat context on collected tool calls - [#897](https://github.com/livekit/agents/pull/897) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.10.1
### Patch Changes
- use rtc.combine_audio_frames - [#841](https://github.com/livekit/agents/pull/841) ([@theomonnom](https://github.com/theomonnom))
- Fix agent state to not change to listening when user speaks - [#857](https://github.com/livekit/agents/pull/857) ([@martin-purplefish](https://github.com/martin-purplefish))
Fixed canceling uncancelable speech
Fixed bug where agent would get stuck with uninterruptable speech.
- Fix bug where empty audio would cause agent to get stuck. - [#836](https://github.com/livekit/agents/pull/836) ([@martin-purplefish](https://github.com/martin-purplefish))
- fix: handle when STT does not return any speech - [#854](https://github.com/livekit/agents/pull/854) ([@davidzhao](https://github.com/davidzhao))
- Fix watcher reloaded processes double connecting to rooms - [#822](https://github.com/livekit/agents/pull/822) ([@keepingitneil](https://github.com/keepingitneil))
- voice-pipeline: avoid stacked replies when interruptions is disallowed - [#869](https://github.com/livekit/agents/pull/869) ([@theomonnom](https://github.com/theomonnom))
- disable preemptive_synthesis by default - [#867](https://github.com/livekit/agents/pull/867) ([@theomonnom](https://github.com/theomonnom))
- Fixed bug where agent would get stuck on non-interruptable speech - [#850](https://github.com/livekit/agents/pull/850) ([@martin-purplefish](https://github.com/martin-purplefish))
- use EventEmitter from rtc - [#879](https://github.com/livekit/agents/pull/879) ([@theomonnom](https://github.com/theomonnom))
- AudioByteStream: avoid empty frames on flush - [#840](https://github.com/livekit/agents/pull/840) ([@theomonnom](https://github.com/theomonnom))
- improve worker logs - [#878](https://github.com/livekit/agents/pull/878) ([@theomonnom](https://github.com/theomonnom))
- voice-pipeline: fix tts_forwarder not always being closed - [#871](https://github.com/livekit/agents/pull/871) ([@theomonnom](https://github.com/theomonnom))
- bump livekit-rtc to v0.17.5 - [#880](https://github.com/livekit/agents/pull/880) ([@theomonnom](https://github.com/theomonnom))
- Fixed bug where agent would freeze if before_llm_cb returned false - [#865](https://github.com/livekit/agents/pull/865) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.10.0
### Minor Changes
- OpenAI Realtime API support - [#814](https://github.com/livekit/agents/pull/814) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- bump livekit to v0.17.2 - [#815](https://github.com/livekit/agents/pull/815) ([@theomonnom](https://github.com/theomonnom))
- silero: support any sample rate - [#805](https://github.com/livekit/agents/pull/805) ([@theomonnom](https://github.com/theomonnom))
## 0.9.1
### Patch Changes
- fix VoiceAssistant being stuck when interrupting before user speech is committed - [#790](https://github.com/livekit/agents/pull/790) ([@coderlxn](https://github.com/coderlxn))
- Fix function for OpenAI Assistants - [#784](https://github.com/livekit/agents/pull/784) ([@keepingitneil](https://github.com/keepingitneil))
## 0.9.0
### Minor Changes
- rename voice_assistant.state to lk.agent.state - [#772](https://github.com/livekit/agents/pull/772) ([@bcherry](https://github.com/bcherry))
### Patch Changes
- bump rtc - [#782](https://github.com/livekit/agents/pull/782) ([@nbsp](https://github.com/nbsp))
- improve graceful shutdown - [#756](https://github.com/livekit/agents/pull/756) ([@theomonnom](https://github.com/theomonnom))
- avoid returning tiny frames from TTS - [#747](https://github.com/livekit/agents/pull/747) ([@theomonnom](https://github.com/theomonnom))
- windows: default to threaded executor & fix dev mode - [#755](https://github.com/livekit/agents/pull/755) ([@theomonnom](https://github.com/theomonnom))
- 11labs: send phoneme in one entire xml chunk - [#766](https://github.com/livekit/agents/pull/766) ([@theomonnom](https://github.com/theomonnom))
- fix: process not starting if num_idle_processes is zero - [#763](https://github.com/livekit/agents/pull/763) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: avoid tiny frames on playout - [#750](https://github.com/livekit/agents/pull/750) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: expose turn_completion_delay - [#752](https://github.com/livekit/agents/pull/752) ([@theomonnom](https://github.com/theomonnom))
- limit concurrent process init to 1 - [#751](https://github.com/livekit/agents/pull/751) ([@theomonnom](https://github.com/theomonnom))
- Add typing-extensions as a dependency - [#778](https://github.com/livekit/agents/pull/778) ([@keepingitneil](https://github.com/keepingitneil))
- Allow setting LLM temperature with VoiceAssistant - [#741](https://github.com/livekit/agents/pull/741) ([@davidzhao](https://github.com/davidzhao))
- better dev defaults - [#762](https://github.com/livekit/agents/pull/762) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: allow to cancel llm generation inside before_llm_cb - [#753](https://github.com/livekit/agents/pull/753) ([@theomonnom](https://github.com/theomonnom))
- use os.exit to exit forcefully - [#770](https://github.com/livekit/agents/pull/770) ([@theomonnom](https://github.com/theomonnom))
## 0.8.12
### Patch Changes
- tts_forwarder: don't raise inside mark_{audio,text}_segment_end when nothing was pushed - [#730](https://github.com/livekit/agents/pull/730) ([@theomonnom](https://github.com/theomonnom))
## 0.8.11
### Patch Changes
- improve gracefully_cancel logic - [#720](https://github.com/livekit/agents/pull/720) ([@theomonnom](https://github.com/theomonnom))
- Make ctx.room.name available prior to connection - [#716](https://github.com/livekit/agents/pull/716) ([@davidzhao](https://github.com/davidzhao))
- ipc: add threaded job runner - [#684](https://github.com/livekit/agents/pull/684) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: add VoiceAssistantState - [#654](https://github.com/livekit/agents/pull/654) ([@lukasIO](https://github.com/lukasIO))
- add JobContext.wait_for_participant - [#712](https://github.com/livekit/agents/pull/712) ([@theomonnom](https://github.com/theomonnom))
- fix non pickleable log - [#691](https://github.com/livekit/agents/pull/691) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: skip speech initialization if interrupted - [#715](https://github.com/livekit/agents/pull/715) ([@theomonnom](https://github.com/theomonnom))
- bump required livekit version to 0.15.2 - [#722](https://github.com/livekit/agents/pull/722) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: add will_synthesize_assistant_speech - [#706](https://github.com/livekit/agents/pull/706) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix mark_audio_segment_end with no audio data - [#719](https://github.com/livekit/agents/pull/719) ([@theomonnom](https://github.com/theomonnom))
## 0.8.10
### Patch Changes
- Pass JobContext to participant entrypoint function - [#694](https://github.com/livekit/agents/pull/694) ([@davidzhao](https://github.com/davidzhao))
- voiceassistant: keep punctuations when sending agent transcription - [#648](https://github.com/livekit/agents/pull/648) ([@theomonnom](https://github.com/theomonnom))
## 0.8.9
### Patch Changes
- Introduce easy api for starting tasks for remote participants - [#679](https://github.com/livekit/agents/pull/679) ([@keepingitneil](https://github.com/keepingitneil))
- update livekit to 0.14.0 and await tracksubscribed - [#678](https://github.com/livekit/agents/pull/678) ([@nbsp](https://github.com/nbsp))
## 0.8.8
### Patch Changes
- fix uninitialized SpeechHandle error on interruption - [#665](https://github.com/livekit/agents/pull/665) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: avoid stacking assistant replies when allow_interruptions=False - [#667](https://github.com/livekit/agents/pull/667) ([@theomonnom](https://github.com/theomonnom))
- fix: disconnect event may now have arguments - [#668](https://github.com/livekit/agents/pull/668) ([@theomonnom](https://github.com/theomonnom))
- Add ServerMessage.termination handler - [#635](https://github.com/livekit/agents/pull/635) ([@nbsp](https://github.com/nbsp))
## 0.8.7
### Patch Changes
- voiceassistant: fix llm not having the full chat context on bad interruption timing - [#659](https://github.com/livekit/agents/pull/659) ([@theomonnom](https://github.com/theomonnom))
## 0.8.6
### Patch Changes
- voiceassistant: fix will_synthesize_assistant_reply race - [#638](https://github.com/livekit/agents/pull/638) ([@theomonnom](https://github.com/theomonnom))
- Switch Cartesia to a sentence tokenizer and keep the same context id throughout. - [#608](https://github.com/livekit/agents/pull/608) ([@keepingitneil](https://github.com/keepingitneil))
Propagate segment_id through the basic sentence tokenizer
- silero: adjust vad activation threshold - [#639](https://github.com/livekit/agents/pull/639) ([@theomonnom](https://github.com/theomonnom))
- limit simultaneous process initialization - [#621](https://github.com/livekit/agents/pull/621) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: remove fade effect when interrupting #622 - [#623](https://github.com/livekit/agents/pull/623) ([@theomonnom](https://github.com/theomonnom))
- ipc improvements, fix slow shutdown & cleanup leaked resources - [#607](https://github.com/livekit/agents/pull/607) ([@theomonnom](https://github.com/theomonnom))
- ipc: use our own duplex instead of mp.Queue - [#634](https://github.com/livekit/agents/pull/634) ([@theomonnom](https://github.com/theomonnom))
- Support OpenAI Assistants API as a beta feature under `livekit.plugins.openai.beta` - [#601](https://github.com/livekit/agents/pull/601) ([@keepingitneil](https://github.com/keepingitneil))
Add \_metadata to ChatCtx and ChatMessage, which can be used (in the case of OpenAI assistants) for bookkeeping to sync local state with remote OpenAI state
- llm: fix optional arguments & non-hashable list - [#637](https://github.com/livekit/agents/pull/637) ([@theomonnom](https://github.com/theomonnom))
- silero: fix vad padding & static audio - [#631](https://github.com/livekit/agents/pull/631) ([@theomonnom](https://github.com/theomonnom))
## 0.8.5
### Patch Changes
- add support for optional arguments on ai_callable functions - [#600](https://github.com/livekit/agents/pull/600) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: correctly export AssistantTranscriptionOptions - [#598](https://github.com/livekit/agents/pull/598) ([@theomonnom](https://github.com/theomonnom))
- fix: log levelname not present when using the start subcommand - [#602](https://github.com/livekit/agents/pull/602) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix incomplete committed agent transcript in the chat_ctx - [#595](https://github.com/livekit/agents/pull/595) ([@theomonnom](https://github.com/theomonnom))
- cartesia: correctly add spaces & fix tests - [#591](https://github.com/livekit/agents/pull/591) ([@theomonnom](https://github.com/theomonnom))
## 0.8.4
### Patch Changes
- voiceassistant: only commit the spoken words in the chat context. - [#589](https://github.com/livekit/agents/pull/589) ([@theomonnom](https://github.com/theomonnom))
- use aiodns by default - [#579](https://github.com/livekit/agents/pull/579) ([@theomonnom](https://github.com/theomonnom))
- voice_assistant: fix missing spaces between transcript chunks - [#566](https://github.com/livekit/agents/pull/566) ([@egoldschmidt](https://github.com/egoldschmidt))
- voiceassistant: fix transcription being fully sent even when interrupted - [#581](https://github.com/livekit/agents/pull/581) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix AssertionError when there is no user_question - [#582](https://github.com/livekit/agents/pull/582) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix speech validation cancellation - [#584](https://github.com/livekit/agents/pull/584) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix synthesis continuing after interruption - [#588](https://github.com/livekit/agents/pull/588) ([@theomonnom](https://github.com/theomonnom))
## 0.8.3
### Patch Changes
- voiceassistant: run function calls sequentially - [#554](https://github.com/livekit/agents/pull/554) ([@theomonnom](https://github.com/theomonnom))
- configure plugins loggers & more debug logs on the voiceassistant - [#555](https://github.com/livekit/agents/pull/555) ([@theomonnom](https://github.com/theomonnom))
- warn when no room connection is established within 10 seconds after job_entry is called - [#558](https://github.com/livekit/agents/pull/558) ([@theomonnom](https://github.com/theomonnom))
- deepgram: reduce chunks size to 100ms - [#561](https://github.com/livekit/agents/pull/561) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: cleanup validation behaviour #545 - [#553](https://github.com/livekit/agents/pull/553) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: commit user question directly when allow_interruptions=False - [#547](https://github.com/livekit/agents/pull/547) ([@theomonnom](https://github.com/theomonnom))
- ipc: increase high ping threshold - [#556](https://github.com/livekit/agents/pull/556) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: interrupt on final transcript - [#546](https://github.com/livekit/agents/pull/546) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: tweaks & fix speech being removed too soon from the queue - [#560](https://github.com/livekit/agents/pull/560) ([@theomonnom](https://github.com/theomonnom))
- voiceassistant: fix duplicate answers - [#548](https://github.com/livekit/agents/pull/548) ([@theomonnom](https://github.com/theomonnom))
- reduce the default load threshold to a more appropriate default - [#559](https://github.com/livekit/agents/pull/559) ([@theomonnom](https://github.com/theomonnom))
## 0.8.2
### Patch Changes
- fix: remove unnecessary async function - [#540](https://github.com/livekit/agents/pull/540) ([@Nabil372](https://github.com/Nabil372))
## 0.8.1
### Patch Changes
- update livekit-rtc to v0.12.0 - [#535](https://github.com/livekit/agents/pull/535) ([@theomonnom](https://github.com/theomonnom))
- automatically create stt.StreamAdapter when provided stt doesn't support streaming - [#536](https://github.com/livekit/agents/pull/536) ([@theomonnom](https://github.com/theomonnom))
- update examples to the latest API & export AutoSubscribe - [#534](https://github.com/livekit/agents/pull/534) ([@theomonnom](https://github.com/theomonnom))
- fix end_input not flushing & unhandled flush messages - [#528](https://github.com/livekit/agents/pull/528) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.8
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.7
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.5
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.4
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.3
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0-dev.2
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.3-dev.1
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Agents
The core LiveKit Agents Framework. See top-level README for more information.
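For orientation, here is a minimal worker sketch assembled from the examples earlier in this repository; the entrypoint body is a placeholder and a real agent would do more:

```python
# Minimal worker sketch (illustrative only), assembled from the examples in this repo.
from livekit.agents import AutoSubscribe, JobContext, WorkerOptions, cli


async def entrypoint(ctx: JobContext):
    # connect to the room assigned to this job, subscribing to audio only
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)


if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
```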
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
cli,
ipc,
llm,
metrics,
multimodal,
pipeline,
stt,
tokenize,
transcription,
tts,
utils,
vad,
voice_assistant,
)
from ._exceptions import (
APIConnectionError,
APIError,
APIStatusError,
APITimeoutError,
AssignmentTimeoutError,
)
from .job import AutoSubscribe, JobContext, JobExecutorType, JobProcess, JobRequest
from .plugin import Plugin
from .types import (
ATTRIBUTE_AGENT_STATE,
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
AgentState,
APIConnectOptions,
NotGiven,
NotGivenOr,
)
from .version import __version__
from .worker import Worker, WorkerOptions, WorkerPermissions, WorkerType
__all__ = [
"__version__",
"Worker",
"WorkerOptions",
"WorkerType",
"WorkerPermissions",
"JobProcess",
"JobContext",
"JobRequest",
"JobExecutorType",
"AutoSubscribe",
"AgentState",
"Plugin",
"ipc",
"stt",
"vad",
"utils",
"tts",
"tokenize",
"llm",
"metrics",
"transcription",
"pipeline",
"multimodal",
"voice_assistant",
"cli",
"AssignmentTimeoutError",
"APIConnectionError",
"APIError",
"APIStatusError",
"APITimeoutError",
"ATTRIBUTE_AGENT_STATE",
"APIConnectOptions",
"DEFAULT_API_CONNECT_OPTIONS",
"AgentState",
"NotGiven",
"NOT_GIVEN",
"NotGivenOr",
]
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
from __future__ import annotations
class AssignmentTimeoutError(Exception):
"""Raised when accepting a job but not receiving an assignment within the specified timeout.
The server may have chosen another worker to handle this job."""
pass
# errors used by our plugins
class APIError(Exception):
"""Raised when an API request failed.
This is used on our TTS/STT/LLM plugins."""
message: str
"""
The error message returned by the API.
"""
body: object | None
"""The API response body, if available.
If the API returned valid JSON, the body will contain
the decoded result.
"""
retryable: bool = False
"""Whether the error can be retried."""
def __init__(
self, message: str, *, body: object | None, retryable: bool = True
) -> None:
super().__init__(message)
self.message = message
self.body = body
self.retryable = retryable
class APIStatusError(APIError):
"""Raised when an API response has a status code of 4xx or 5xx."""
status_code: int
"""The status code of the API response."""
request_id: str | None
"""The request ID of the API response, if available."""
def __init__(
self,
message: str,
*,
status_code: int = -1,
request_id: str | None = None,
body: object | None = None,
retryable: bool | None = None,
) -> None:
if retryable is None:
retryable = True
# 4xx errors are not retryable
if status_code >= 400 and status_code < 500:
retryable = False
super().__init__(message, body=body, retryable=retryable)
self.status_code = status_code
self.request_id = request_id
def __str__(self):
return (
f"{self.message} "
f"(status_code={self.status_code}, request_id={self.request_id}, body={self.body})"
)
class APIConnectionError(APIError):
"""Raised when an API request failed due to a connection error."""
def __init__(
self, message: str = "Connection error.", *, retryable: bool = True
) -> None:
super().__init__(message, body=None, retryable=retryable)
class APITimeoutError(APIConnectionError):
"""Raised when an API request timed out."""
def __init__(
self, message: str = "Request timed out.", *, retryable: bool = True
) -> None:
super().__init__(message, retryable=retryable)
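As an illustration only, the retryable flag defined above can drive a simple retry loop; the helper below is a hypothetical sketch and not part of livekit-agents:
# Hypothetical sketch: retry an async API call while the error is marked retryable.
import asyncio

async def call_with_retries(fn, *, max_attempts: int = 3, delay: float = 2.0):
    for attempt in range(1, max_attempts + 1):
        try:
            return await fn()
        except APIError as e:  # APIError/APIStatusError defined above; 4xx responses are non-retryable
            if not e.retryable or attempt == max_attempts:
                raise
            await asyncio.sleep(delay)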
from .cli import run_app
__all__ = ["run_app"]
import asyncio
import pathlib
import signal
import sys
import click
from livekit.protocol import models
from .. import utils
from ..log import logger
from ..plugin import Plugin
from ..worker import Worker, WorkerOptions
from . import proto
from .log import setup_logging
def run_app(opts: WorkerOptions) -> None:
"""Run the CLI to interact with the worker"""
cli = click.Group()
@cli.command(help="Start the worker in production mode.")
@click.option(
"--log-level",
default="INFO",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level",
)
@click.option(
"--url",
envvar="LIVEKIT_URL",
help="LiveKit server or Cloud project's websocket URL",
)
@click.option(
"--api-key",
envvar="LIVEKIT_API_KEY",
help="LiveKit server or Cloud project's API key",
)
@click.option(
"--api-secret",
envvar="LIVEKIT_API_SECRET",
help="LiveKit server or Cloud project's API secret",
)
@click.option(
"--drain-timeout",
default=60,
help="Time in seconds to wait for jobs to finish before shutting down",
)
def start(
log_level: str, url: str, api_key: str, api_secret: str, drain_timeout: int
) -> None:
opts.ws_url = url or opts.ws_url
opts.api_key = api_key or opts.api_key
opts.api_secret = api_secret or opts.api_secret
args = proto.CliArgs(
opts=opts,
log_level=log_level,
devmode=False,
asyncio_debug=False,
watch=False,
drain_timeout=drain_timeout,
)
run_worker(args)
@cli.command(help="Start the worker in development mode")
@click.option(
"--log-level",
default="DEBUG",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level",
)
@click.option(
"--url",
envvar="LIVEKIT_URL",
help="LiveKit server or Cloud project's websocket URL",
)
@click.option(
"--api-key",
envvar="LIVEKIT_API_KEY",
help="LiveKit server or Cloud project's API key",
)
@click.option(
"--api-secret",
envvar="LIVEKIT_API_SECRET",
help="LiveKit server or Cloud project's API secret",
)
@click.option(
"--asyncio-debug/--no-asyncio-debug",
default=False,
help="Enable debugging feature of asyncio",
)
@click.option(
"--watch/--no-watch",
default=True,
help="Watch for changes in the current directory and plugins in editable mode",
)
def dev(
log_level: str,
url: str,
api_key: str,
api_secret: str,
asyncio_debug: bool,
watch: bool,
) -> None:
_run_dev(opts, log_level, url, api_key, api_secret, asyncio_debug, watch)
@cli.command(help="Connect to a specific room")
@click.option(
"--log-level",
default="DEBUG",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level",
)
@click.option(
"--url",
envvar="LIVEKIT_URL",
help="LiveKit server or Cloud project's websocket URL",
)
@click.option(
"--api-key",
envvar="LIVEKIT_API_KEY",
help="LiveKit server or Cloud project's API key",
)
@click.option(
"--api-secret",
envvar="LIVEKIT_API_SECRET",
help="LiveKit server or Cloud project's API secret",
)
@click.option(
"--asyncio-debug/--no-asyncio-debug",
default=False,
help="Enable debugging feature of asyncio",
)
@click.option(
"--watch/--no-watch",
default=True,
help="Watch for changes in the current directory and plugins in editable mode",
)
@click.option("--room", help="Room name to connect to", required=True)
@click.option(
"--participant-identity", help="Participant identity (JobType.JT_PUBLISHER)"
)
def connect(
log_level: str,
url: str,
api_key: str,
api_secret: str,
asyncio_debug: bool,
watch: bool,
room: str,
participant_identity: str,
) -> None:
_run_dev(
opts,
log_level,
url,
api_key,
api_secret,
asyncio_debug,
watch,
room,
participant_identity,
)
@cli.command(help="Download plugin dependency files")
@click.option(
"--log-level",
default="DEBUG",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level",
)
def download_files(log_level: str) -> None:
setup_logging(log_level, True)
for plugin in Plugin.registered_plugins:
logger.info(f"Downloading files for {plugin}")
plugin.download_files()
logger.info(f"Finished downloading files for {plugin}")
cli()
def _run_dev(
opts: WorkerOptions,
log_level: str,
url: str,
api_key: str,
api_secret: str,
asyncio_debug: bool,
watch: bool,
room: str = "",
participant_identity: str = "",
):
opts.ws_url = url or opts.ws_url
opts.api_key = api_key or opts.api_key
opts.api_secret = api_secret or opts.api_secret
args = proto.CliArgs(
opts=opts,
log_level=log_level,
devmode=True,
asyncio_debug=asyncio_debug,
watch=watch,
drain_timeout=0,
room=room,
participant_identity=participant_identity,
)
if watch:
from .watcher import WatchServer
setup_logging(log_level, args.devmode)
main_file = pathlib.Path(sys.argv[0]).parent
async def _run_loop():
server = WatchServer(
run_worker, main_file, args, loop=asyncio.get_event_loop()
)
await server.run()
try:
asyncio.run(_run_loop())
except KeyboardInterrupt:
pass
else:
run_worker(args)
def run_worker(args: proto.CliArgs) -> None:
setup_logging(args.log_level, args.devmode)
args.opts.validate_config(args.devmode)
loop = asyncio.get_event_loop()
worker = Worker(args.opts, devmode=args.devmode, loop=loop)
loop.set_debug(args.asyncio_debug)
loop.slow_callback_duration = 0.1 # 100ms
utils.aio.debug.hook_slow_callbacks(2)
if args.room and args.reload_count == 0:
# directly connect to a specific room
@worker.once("worker_registered")
def _connect_on_register(worker_id: str, server_info: models.ServerInfo):
logger.info("connecting to room %s", args.room)
loop.create_task(worker.simulate_job(args.room, args.participant_identity))
try:
def _signal_handler():
raise KeyboardInterrupt
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, _signal_handler)
except NotImplementedError:
# TODO(theomonnom): add_signal_handler is not implemented on win
pass
async def _worker_run(worker: Worker) -> None:
try:
await worker.run()
except Exception:
logger.exception("worker failed")
watch_client = None
if args.watch:
from .watcher import WatchClient
watch_client = WatchClient(worker, args, loop=loop)
watch_client.start()
try:
main_task = loop.create_task(_worker_run(worker), name="agent_runner")
try:
loop.run_until_complete(main_task)
except KeyboardInterrupt:
pass
try:
if not args.devmode:
loop.run_until_complete(worker.drain(timeout=args.drain_timeout))
loop.run_until_complete(worker.aclose())
if watch_client:
loop.run_until_complete(watch_client.aclose())
except KeyboardInterrupt:
logger.warning("exiting forcefully")
import os
os._exit(1) # TODO(theomonnom): add aclose(force=True) in worker
finally:
try:
tasks = asyncio.all_tasks(loop)
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
loop.close()
from __future__ import annotations
import json
import logging
import re
import traceback
from collections import OrderedDict
from datetime import date, datetime, time, timezone
from inspect import istraceback
from typing import Any, Dict, Tuple
from ..plugin import Plugin
# noisy loggers are set to warn by default
NOISY_LOGGERS = [
"httpx",
"httpcore",
"openai",
"watchfiles",
"anthropic",
"websockets.client",
"botocore",
"aiobotocore",
]
def _silence_noisy_loggers() -> None:
for noisy_logger in NOISY_LOGGERS:
logger = logging.getLogger(noisy_logger)
if logger.level == logging.NOTSET:
logger.setLevel(logging.WARN)
# skip default LogRecord attributes
# http://docs.python.org/library/logging.html#logrecord-attributes
_RESERVED_ATTRS: Tuple[str, ...] = (
"args",
"asctime",
"created",
"exc_info",
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
"taskName",
)
def _merge_record_extra(record: logging.LogRecord, target: Dict[Any, Any]):
for key, value in record.__dict__.items():
if key not in _RESERVED_ATTRS and not (
hasattr(key, "startswith") and key.startswith("_")
):
target[key] = value
def _parse_style(formatter: logging.Formatter) -> list[str]:
"""parse the list of fields required by the style"""
if isinstance(formatter._style, logging.StringTemplateStyle):
formatter_style_pattern = re.compile(r"\$\{(.+?)\}", re.IGNORECASE)
elif isinstance(formatter._style, logging.StrFormatStyle):
formatter_style_pattern = re.compile(r"\{(.+?)\}", re.IGNORECASE)
elif isinstance(formatter._style, logging.PercentStyle):
formatter_style_pattern = re.compile(r"%\((.+?)\)", re.IGNORECASE)
else:
raise ValueError("Invalid format: %s" % formatter._fmt)
if formatter._fmt:
return formatter_style_pattern.findall(formatter._fmt)
else:
return []
class JsonFormatter(logging.Formatter):
class JsonEncoder(json.JSONEncoder):
def default(self, o: Any):
if isinstance(o, (date, datetime, time)):
return o.isoformat()
elif istraceback(o):
return "".join(traceback.format_tb(o)).strip()
elif type(o) is Exception or isinstance(o, Exception) or type(o) is type:
return str(o)
# extra values are formatted as str() if the encoder raises TypeError
try:
return super().default(o)
except TypeError:
try:
return str(o)
except Exception:
return None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._required_fields = _parse_style(self)
def format(self, record: logging.LogRecord) -> str:
"""Formats a log record and serializes to json"""
message_dict: Dict[str, Any] = {}
message_dict["level"] = record.levelname
message_dict["name"] = record.name
if isinstance(record.msg, dict):
message_dict = record.msg
record.message = ""
else:
record.message = record.getMessage()
if "asctime" in self._required_fields:
record.asctime = self.formatTime(record, self.datefmt)
if record.exc_info and not message_dict.get("exc_info"):
message_dict["exc_info"] = self.formatException(record.exc_info)
if not message_dict.get("exc_info") and record.exc_text:
message_dict["exc_info"] = record.exc_text
if record.stack_info and not message_dict.get("stack_info"):
message_dict["stack_info"] = self.formatStack(record.stack_info)
log_record: Dict[str, Any] = OrderedDict()
for field in self._required_fields:
log_record[field] = record.__dict__.get(field)
log_record.update(message_dict)
_merge_record_extra(record, log_record)
log_record["timestamp"] = datetime.fromtimestamp(
record.created, tz=timezone.utc
)
return json.dumps(log_record, cls=JsonFormatter.JsonEncoder, ensure_ascii=True)
class ColoredFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._esc_codes = {
"esc_reset": self._esc(0),
"esc_red": self._esc(31),
"esc_green": self._esc(32),
"esc_yellow": self._esc(33),
"esc_blue": self._esc(34),
"esc_purple": self._esc(35),
"esc_cyan": self._esc(36),
"esc_gray": self._esc(90),
"esc_bold_red": self._esc(1, 31),
}
self._level_colors = {
"DEBUG": self._esc_codes["esc_cyan"],
"INFO": self._esc_codes["esc_green"],
"WARNING": self._esc_codes["esc_yellow"],
"ERROR": self._esc_codes["esc_red"],
"CRITICAL": self._esc_codes["esc_bold_red"],
"DEV": self._esc_codes["esc_purple"],
}
self._required_fields = _parse_style(self)
@classmethod
def _esc(cls, *codes: int) -> str:
return "\033[" + ";".join(str(code) for code in codes) + "m"
def formatMessage(self, record: logging.LogRecord) -> str:
"""Formats a log record with colors"""
extra: Dict[Any, Any] = {}
_merge_record_extra(record, extra)
args = {}
for field in self._required_fields:
args[field] = record.__dict__.get(field)
args["esc_levelcolor"] = self._level_colors.get(record.levelname, "")
args["extra"] = ""
args.update(self._esc_codes)
if extra:
args["extra"] = json.dumps(
extra, cls=JsonFormatter.JsonEncoder, ensure_ascii=True
)
for field in self._required_fields:
if field in extra:
del extra[field]
msg = self._style._fmt % args
return msg + self._esc_codes["esc_reset"]
def setup_logging(log_level: str, devmode: bool) -> None:
handler = logging.StreamHandler()
if devmode:
# colorful logs for dev (improves readability)
colored_formatter = ColoredFormatter(
"%(asctime)s - %(esc_levelcolor)s%(levelname)-4s%(esc_reset)s %(name)s - %(message)s %(esc_gray)s%(extra)s"
)
handler.setFormatter(colored_formatter)
else:
        # production logs (serialized as JSON)
json_formatter = JsonFormatter()
handler.setFormatter(json_formatter)
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(log_level)
_silence_noisy_loggers()
from ..log import logger
if logger.level == logging.NOTSET:
logger.setLevel(log_level)
from ..pipeline.log import logger
if logger.level == logging.NOTSET:
logger.setLevel(log_level)
def _configure_plugin_logger(plugin: Plugin) -> None:
if plugin.logger is not None and plugin.logger.level == logging.NOTSET:
plugin.logger.setLevel(log_level)
for plugin in Plugin.registered_plugins:
_configure_plugin_logger(plugin)
Plugin.emitter.on("plugin_registered", _configure_plugin_logger)
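# A minimal usage sketch for the logging setup above. setup_logging() installs the
# ColoredFormatter in dev mode and the JsonFormatter otherwise; the logger name and
# the "job_id" extra below are hypothetical, used only to show how non-reserved
# record attributes are merged into the output by _merge_record_extra().
if __name__ == "__main__":
    setup_logging("DEBUG", devmode=True)  # colored, human-readable output
    demo_logger = logging.getLogger("demo.agent")
    demo_logger.info("starting worker", extra={"job_id": "job_123"})
    # with devmode=False, the same call emits one JSON object per line, roughly:
    # {"message": "starting worker", "level": "INFO", "name": "demo.agent",
    #  "job_id": "job_123", "timestamp": "2024-01-01T00:00:00+00:00"}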
from __future__ import annotations
import io
import socket
from dataclasses import dataclass, field
from typing import ClassVar
from livekit.protocol import agent
from ..ipc import channel
from ..job import JobAcceptArguments, RunningJobInfo
from ..worker import WorkerOptions
@dataclass
class CliArgs:
opts: WorkerOptions
log_level: str
devmode: bool
asyncio_debug: bool
watch: bool
drain_timeout: int
room: str = ""
participant_identity: str = ""
    # number of times this worker has been reloaded
reload_count: int = 0
# pipe used for the communication between the watch server and the watch client
# when reload/dev mode is enabled
mp_cch: socket.socket | None = None
@dataclass
class ActiveJobsRequest:
MSG_ID: ClassVar[int] = 1
@dataclass
class ActiveJobsResponse:
MSG_ID: ClassVar[int] = 2
jobs: list[RunningJobInfo] = field(default_factory=list)
reload_count: int = 0
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, len(self.jobs))
for running_job in self.jobs:
accept_args = running_job.accept_arguments
channel.write_bytes(b, running_job.job.SerializeToString())
channel.write_string(b, accept_args.name)
channel.write_string(b, accept_args.identity)
channel.write_string(b, accept_args.metadata)
channel.write_string(b, running_job.url)
channel.write_string(b, running_job.token)
channel.write_string(b, running_job.worker_id)
channel.write_int(b, self.reload_count)
def read(self, b: io.BytesIO) -> None:
for _ in range(channel.read_int(b)):
job = agent.Job()
job.ParseFromString(channel.read_bytes(b))
self.jobs.append(
RunningJobInfo(
accept_arguments=JobAcceptArguments(
name=channel.read_string(b),
identity=channel.read_string(b),
metadata=channel.read_string(b),
),
job=job,
url=channel.read_string(b),
token=channel.read_string(b),
worker_id=channel.read_string(b),
)
)
self.reload_count = channel.read_int(b)
@dataclass
class ReloadJobsRequest:
MSG_ID: ClassVar[int] = 3
@dataclass
class ReloadJobsResponse(ActiveJobsResponse):
MSG_ID: ClassVar[int] = 4
@dataclass
class Reloaded:
MSG_ID: ClassVar[int] = 5
IPC_MESSAGES = {
ActiveJobsRequest.MSG_ID: ActiveJobsRequest,
ActiveJobsResponse.MSG_ID: ActiveJobsResponse,
ReloadJobsRequest.MSG_ID: ReloadJobsRequest,
ReloadJobsResponse.MSG_ID: ReloadJobsResponse,
Reloaded.MSG_ID: Reloaded,
}
from __future__ import annotations
import asyncio
import contextlib
import json
import pathlib
import socket
import urllib.parse
import urllib.request
from importlib.metadata import Distribution, PackageNotFoundError
from typing import Any, Callable, Set
import watchfiles
from .. import utils
from ..ipc import channel
from ..log import DEV_LEVEL, logger
from ..plugin import Plugin
from ..worker import Worker
from . import proto
def _find_watchable_paths(main_file: pathlib.Path) -> list[pathlib.Path]:
packages: list[Distribution] = []
    # also watch agent plugins installed in editable mode
def _try_add(name: str) -> bool:
nonlocal packages
try:
dist = Distribution.from_name(name)
packages.append(dist)
return True
except PackageNotFoundError:
return False
if not _try_add("livekit.agents"):
_try_add("livekit-agents")
for plugin in Plugin.registered_plugins:
if not _try_add(plugin.package):
_try_add(plugin.package.replace(".", "-"))
paths: list[pathlib.Path] = [main_file.absolute()]
for pkg in packages:
# https://packaging.python.org/en/latest/specifications/direct-url/
durl = pkg.read_text("direct_url.json")
if not durl:
continue
durl_json: dict[str, Any] = json.loads(durl)
dir_info = durl_json.get("dir_info", {})
if dir_info.get("editable", False):
path: str | None = durl_json.get("url")
if path and path.startswith("file://"):
parsed_url = urllib.parse.urlparse(path)
file_url_path = urllib.parse.unquote(parsed_url.path)
local_path = urllib.request.url2pathname(file_url_path)
file_path = pathlib.Path(local_path)
paths.append(file_path)
return paths
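# A sketch of what _find_watchable_paths() reads for an editable install: the package's
# direct_url.json metadata (PEP 610) looks roughly like the dict below (the path is
# hypothetical), and the "editable" flag is what causes the package directory to be
# added to the watch list.
_EXAMPLE_DIRECT_URL: dict[str, Any] = {
    "url": "file:///home/user/dev/livekit-agents",
    "dir_info": {"editable": True},
}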
class WatchServer:
def __init__(
self,
worker_runner: Callable[[proto.CliArgs], Any],
main_file: pathlib.Path,
cli_args: proto.CliArgs,
loop: asyncio.AbstractEventLoop,
) -> None:
self._mp_pch, cli_args.mp_cch = socket.socketpair()
self._cli_args = cli_args
self._worker_runner = worker_runner
self._main_file = main_file
self._loop = loop
self._recv_jobs_fut = asyncio.Future[None]()
self._worker_reloading = False
async def run(self) -> None:
watch_paths = _find_watchable_paths(self._main_file)
for pth in watch_paths:
logger.log(DEV_LEVEL, f"Watching {pth}")
self._pch = await utils.aio.duplex_unix._AsyncDuplex.open(self._mp_pch)
read_ipc_task = self._loop.create_task(self._read_ipc_task())
try:
await watchfiles.arun_process(
*watch_paths,
target=self._worker_runner,
args=(self._cli_args,),
watch_filter=watchfiles.filters.PythonFilter(),
callback=self._on_reload,
)
finally:
await utils.aio.gracefully_cancel(read_ipc_task)
await self._pch.aclose()
async def _on_reload(self, _: Set[watchfiles.main.FileChange]) -> None:
if self._worker_reloading:
return
self._worker_reloading = True
try:
await channel.asend_message(self._pch, proto.ActiveJobsRequest())
self._recv_jobs_fut = asyncio.Future()
with contextlib.suppress(asyncio.TimeoutError):
# wait max 1.5s to get the active jobs
await asyncio.wait_for(self._recv_jobs_fut, timeout=1.5)
finally:
self._cli_args.reload_count += 1
@utils.log_exceptions(logger=logger)
async def _read_ipc_task(self) -> None:
active_jobs = []
while True:
msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES)
if isinstance(msg, proto.ActiveJobsResponse):
if msg.reload_count != self._cli_args.reload_count:
continue
active_jobs = msg.jobs
with contextlib.suppress(asyncio.InvalidStateError):
self._recv_jobs_fut.set_result(None)
if isinstance(msg, proto.ReloadJobsRequest):
await channel.asend_message(
self._pch, proto.ReloadJobsResponse(jobs=active_jobs)
)
if isinstance(msg, proto.Reloaded):
self._worker_reloading = False
class WatchClient:
def __init__(
self,
worker: Worker,
cli_args: proto.CliArgs,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
self._loop = loop or asyncio.get_event_loop()
self._worker = worker
self._cli_args = cli_args
def start(self) -> None:
self._main_task = self._loop.create_task(self._run())
@utils.log_exceptions(logger=logger)
async def _run(self) -> None:
assert self._cli_args.mp_cch
try:
self._cch = await utils.aio.duplex_unix._AsyncDuplex.open(
self._cli_args.mp_cch
)
await channel.asend_message(self._cch, proto.ReloadJobsRequest())
while True:
try:
msg = await channel.arecv_message(self._cch, proto.IPC_MESSAGES)
except utils.aio.duplex_unix.DuplexClosed:
break
if isinstance(msg, proto.ActiveJobsRequest):
jobs = self._worker.active_jobs
await channel.asend_message(
self._cch,
proto.ActiveJobsResponse(
jobs=jobs, reload_count=self._cli_args.reload_count
),
)
elif isinstance(msg, proto.ReloadJobsResponse):
# TODO(theomonnom): wait for the worker to be fully initialized/connected
await self._worker._reload_jobs(msg.jobs)
await channel.asend_message(self._cch, proto.Reloaded())
except utils.aio.duplex_unix.DuplexClosed:
pass
async def aclose(self) -> None:
if not self._main_task:
return
self._main_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._main_task
await self._cch.aclose()
from __future__ import annotations
import asyncio
from typing import Any
from aiohttp import web
async def health_check(_: Any):
return web.Response(text="OK")
class HttpServer:
def __init__(
self, host: str, port: int, loop: asyncio.AbstractEventLoop | None = None
) -> None:
self._loop = loop or asyncio.get_event_loop()
self._host = host
self._port = port
self._app = web.Application(loop=self._loop)
self._app.add_routes([web.get("/", health_check)])
self._close_future = asyncio.Future[None](loop=self._loop)
async def run(self) -> None:
self._runner = web.AppRunner(self._app)
await self._runner.setup()
site = web.TCPSite(self._runner, self._host, self._port)
await site.start()
try:
await self._close_future
finally:
await self._runner.cleanup()
async def aclose(self) -> None:
if not self._close_future.done():
self._close_future.set_result(None)
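# A short sketch of how the health-check server above can be driven; the port and the
# 5-second lifetime are arbitrary, chosen only for illustration.
async def _run_health_server_demo() -> None:
    server = HttpServer("127.0.0.1", 8081)
    run_task = asyncio.create_task(server.run())
    await asyncio.sleep(5.0)  # during this window, GET http://127.0.0.1:8081/ returns "OK"
    await server.aclose()  # resolves the close future; run() then cleans up the AppRunner
    await run_task
# asyncio.run(_run_health_server_demo())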
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from typing import ClassVar, Protocol, Type
class _RunnerMeta(Protocol):
INFERENCE_METHOD: ClassVar[str]
_RunnersDict = dict[str, Type["_InferenceRunner"]]
# kept private until we stabilize the API (only used for EOU today)
class _InferenceRunner(ABC, _RunnerMeta):
registered_runners: _RunnersDict = {}
@classmethod
def register_runner(cls, runner_class: Type["_InferenceRunner"]) -> None:
if threading.current_thread() != threading.main_thread():
raise RuntimeError("InferenceRunner must be registered on the main thread")
if runner_class.INFERENCE_METHOD in cls.registered_runners:
raise ValueError(
f"InferenceRunner {runner_class.INFERENCE_METHOD} already registered"
)
cls.registered_runners[runner_class.INFERENCE_METHOD] = runner_class
@abstractmethod
def initialize(self) -> None:
"""Initialize the runner. This is used to load models, etc."""
...
@abstractmethod
def run(self, data: bytes) -> bytes | None:
"""Run inference on the given data."""
...
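# A minimal sketch of a custom runner built on the (still private) API above. The
# "demo.echo" method name and the echo behavior are hypothetical; they only illustrate
# the INFERENCE_METHOD / initialize() / run() contract and register_runner(), which
# must be called from the main thread.
class _EchoRunner(_InferenceRunner):
    INFERENCE_METHOD = "demo.echo"
    def initialize(self) -> None:
        # a real runner would load its model weights here, once per inference process
        pass
    def run(self, data: bytes) -> bytes | None:
        # a real runner would run the model on `data`; this sketch just echoes it back
        return data
_InferenceRunner.register_runner(_EchoRunner)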
from . import (
channel,
inference_proc_executor,
job_executor,
job_proc_executor,
job_thread_executor,
proc_pool,
proto,
)
__all__ = [
"proto",
"channel",
"proc_pool",
"job_proc_executor",
"job_thread_executor",
"inference_proc_executor",
"job_executor",
]
from __future__ import annotations
import io
import struct
from typing import ClassVar, Protocol, runtime_checkable
from .. import utils
class Message(Protocol):
MSG_ID: ClassVar[int]
@runtime_checkable
class DataMessage(Message, Protocol):
def write(self, b: io.BytesIO) -> None: ...
def read(self, b: io.BytesIO) -> None: ...
MessagesDict = dict[int, type[Message]]
def _read_message(data: bytes, messages: MessagesDict) -> Message:
bio = io.BytesIO(data)
msg_id = read_int(bio)
msg = messages[msg_id]()
if isinstance(msg, DataMessage):
msg.read(bio)
return msg
def _write_message(msg: Message) -> bytes:
bio = io.BytesIO()
write_int(bio, msg.MSG_ID)
if isinstance(msg, DataMessage):
msg.write(bio)
return bio.getvalue()
async def arecv_message(
dplx: utils.aio.duplex_unix._AsyncDuplex, messages: MessagesDict
) -> Message:
return _read_message(await dplx.recv_bytes(), messages)
async def asend_message(dplx: utils.aio.duplex_unix._AsyncDuplex, msg: Message) -> None:
await dplx.send_bytes(_write_message(msg))
def recv_message(
dplx: utils.aio.duplex_unix._Duplex, messages: MessagesDict
) -> Message:
return _read_message(dplx.recv_bytes(), messages)
def send_message(dplx: utils.aio.duplex_unix._Duplex, msg: Message) -> None:
dplx.send_bytes(_write_message(msg))
def write_bytes(b: io.BytesIO, buf: bytes) -> None:
b.write(len(buf).to_bytes(4, "big"))
b.write(buf)
def read_bytes(b: io.BytesIO) -> bytes:
length = int.from_bytes(b.read(4), "big")
return b.read(length)
def write_string(b: io.BytesIO, s: str) -> None:
encoded = s.encode("utf-8")
b.write(len(encoded).to_bytes(4, "big"))
b.write(encoded)
def read_string(b: io.BytesIO) -> str:
length = int.from_bytes(b.read(4), "big")
return b.read(length).decode("utf-8")
def write_int(b: io.BytesIO, i: int) -> None:
b.write(i.to_bytes(4, "big"))
def read_int(b: io.BytesIO) -> int:
return int.from_bytes(b.read(4), "big")
def write_bool(b: io.BytesIO, bi: bool) -> None:
b.write(bi.to_bytes(1, "big"))
def read_bool(b: io.BytesIO) -> bool:
return bool.from_bytes(b.read(1), "big")
def write_float(b: io.BytesIO, f: float) -> None:
b.write(struct.pack("f", f))
def read_float(b: io.BytesIO) -> float:
return struct.unpack("f", b.read(4))[0]
def write_double(b: io.BytesIO, d: float) -> None:
b.write(struct.pack("d", d))
def read_double(b: io.BytesIO) -> float:
return struct.unpack("d", b.read(8))[0]
def write_long(b: io.BytesIO, long: int) -> None:
b.write(long.to_bytes(8, "big"))
def read_long(b: io.BytesIO) -> int:
return int.from_bytes(b.read(8), "big")
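# A small round-trip sketch for the framing helpers above. _EchoMessage is a
# hypothetical DataMessage (MSG_ID 100 is arbitrary) showing how the write_*/read_*
# helpers pair up with _write_message/_read_message.
from dataclasses import dataclass
@dataclass
class _EchoMessage:
    MSG_ID: ClassVar[int] = 100
    text: str = ""
    def write(self, b: io.BytesIO) -> None:
        write_string(b, self.text)
    def read(self, b: io.BytesIO) -> None:
        self.text = read_string(b)
def _echo_roundtrip_demo() -> None:
    frame = _write_message(_EchoMessage(text="hello"))  # 4-byte MSG_ID + length-prefixed payload
    decoded = _read_message(frame, {_EchoMessage.MSG_ID: _EchoMessage})
    assert isinstance(decoded, _EchoMessage) and decoded.text == "hello"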
from __future__ import annotations
from typing import Protocol
class InferenceExecutor(Protocol):
async def do_inference(self, method: str, data: bytes) -> bytes | None: ...
from __future__ import annotations
import asyncio
import contextlib
import multiprocessing as mp
import socket
from multiprocessing.context import BaseContext
from ..inference_runner import _RunnersDict
from ..log import logger
from ..utils import aio, log_exceptions, shortuuid
from . import channel, proto
from .inference_proc_lazy_main import ProcStartArgs, proc_main
from .supervised_proc import SupervisedProc
class InferenceProcExecutor(SupervisedProc):
def __init__(
self,
*,
runners: _RunnersDict,
initialize_timeout: float,
close_timeout: float,
memory_warn_mb: float,
memory_limit_mb: float,
ping_interval: float,
ping_timeout: float,
high_ping_threshold: float,
mp_ctx: BaseContext,
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__(
initialize_timeout=initialize_timeout,
close_timeout=close_timeout,
memory_warn_mb=memory_warn_mb,
memory_limit_mb=memory_limit_mb,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
high_ping_threshold=high_ping_threshold,
mp_ctx=mp_ctx,
loop=loop,
)
self._runners = runners
self._active_requests: dict[str, asyncio.Future[proto.InferenceResponse]] = {}
def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process:
proc_args = ProcStartArgs(
log_cch=log_cch,
mp_cch=cch,
runners=self._runners,
)
return self._mp_ctx.Process( # type: ignore
target=proc_main,
args=(proc_args,),
name="inference_proc",
)
@log_exceptions(logger=logger)
async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None:
async for msg in ipc_ch:
if isinstance(msg, proto.InferenceResponse):
fut = self._active_requests.pop(msg.request_id, None)
if fut is None:
logger.warning(
"received unexpected inference response",
extra={"request_id": msg.request_id},
)
return
with contextlib.suppress(asyncio.InvalidStateError):
fut.set_result(msg)
async def do_inference(self, method: str, data: bytes) -> bytes | None:
if not self.started:
raise RuntimeError("process not started")
request_id = shortuuid("inference_req_")
fut = asyncio.Future[proto.InferenceResponse]()
await channel.asend_message(
self._pch,
proto.InferenceRequest(request_id=request_id, method=method, data=data),
)
self._active_requests[request_id] = fut
inf_resp = await fut
if inf_resp.error:
raise RuntimeError(f"inference of {method} failed: {inf_resp.error}")
return inf_resp.data
def logging_extra(self):
extra = super().logging_extra()
extra["inference"] = True
return extra
from multiprocessing import current_process
if current_process().name == "inference_proc":
import signal
import sys
# ignore signals in the inference process (the parent process will handle them)
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
def _no_traceback_excepthook(exc_type, exc_val, traceback):
if isinstance(exc_val, KeyboardInterrupt):
return
sys.__excepthook__(exc_type, exc_val, traceback)
sys.excepthook = _no_traceback_excepthook
import asyncio
import socket
from dataclasses import dataclass
from ..inference_runner import _RunnersDict
from ..log import logger
from ..utils import aio, log_exceptions
from . import proto
from .channel import Message
from .proc_client import _ProcClient
@dataclass
class ProcStartArgs:
log_cch: socket.socket
mp_cch: socket.socket
runners: _RunnersDict
def proc_main(args: ProcStartArgs) -> None:
from .proc_client import _ProcClient
inf_proc = _InferenceProc(args.runners)
client = _ProcClient(
args.mp_cch,
args.log_cch,
inf_proc.initialize,
inf_proc.entrypoint,
)
client.initialize_logger()
pid = current_process().pid
logger.info("initializing inference process", extra={"pid": pid})
client.initialize()
logger.info("inference process initialized", extra={"pid": pid})
client.run()
class _InferenceProc:
def __init__(self, runners: _RunnersDict) -> None:
        # create an instance of each runner (the ctor must not require any arguments)
self._runners = {name: runner() for name, runner in runners.items()}
def initialize(
self, init_req: proto.InitializeRequest, client: _ProcClient
) -> None:
self._client = client
for runner in self._runners.values():
logger.debug(
"initializing inference runner",
extra={"runner": runner.__class__.INFERENCE_METHOD},
)
runner.initialize()
@log_exceptions(logger=logger)
async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None:
async for msg in cch:
if isinstance(msg, proto.InferenceRequest):
await self._handle_inference_request(msg)
if isinstance(msg, proto.ShutdownRequest):
await self._client.send(proto.Exiting(reason=msg.reason))
break
async def _handle_inference_request(self, msg: proto.InferenceRequest) -> None:
loop = asyncio.get_running_loop()
if msg.method not in self._runners:
logger.warning("unknown inference method", extra={"method": msg.method})
try:
data = await loop.run_in_executor(
None, self._runners[msg.method].run, msg.data
)
await self._client.send(
proto.InferenceResponse(
request_id=msg.request_id,
data=data,
)
)
except Exception as e:
logger.exception("error running inference")
await self._client.send(
proto.InferenceResponse(request_id=msg.request_id, error=str(e))
)
from __future__ import annotations
from enum import Enum
from typing import Any, Protocol
from ..job import RunningJobInfo
class JobExecutor(Protocol):
@property
def started(self) -> bool: ...
@property
def user_arguments(self) -> Any | None: ...
@user_arguments.setter
def user_arguments(self, value: Any | None) -> None: ...
@property
def running_job(self) -> RunningJobInfo | None: ...
@property
def status(self) -> JobStatus: ...
async def start(self) -> None: ...
async def join(self) -> None: ...
async def initialize(self) -> None: ...
async def aclose(self) -> None: ...
async def launch_job(self, info: RunningJobInfo) -> None: ...
class JobStatus(Enum):
RUNNING = "running"
FAILED = "failed"
SUCCESS = "success"
from __future__ import annotations
import asyncio
import multiprocessing as mp
import socket
from multiprocessing.context import BaseContext
from typing import Any, Awaitable, Callable
from ..job import JobContext, JobProcess, RunningJobInfo
from ..log import logger
from ..utils import aio, log_exceptions
from . import channel, proto
from .inference_executor import InferenceExecutor
from .job_executor import JobStatus
from .job_proc_lazy_main import ProcStartArgs, proc_main
from .supervised_proc import SupervisedProc
class ProcJobExecutor(SupervisedProc):
def __init__(
self,
*,
initialize_process_fnc: Callable[[JobProcess], Any],
job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]],
inference_executor: InferenceExecutor | None,
initialize_timeout: float,
close_timeout: float,
memory_warn_mb: float,
memory_limit_mb: float,
ping_interval: float,
ping_timeout: float,
high_ping_threshold: float,
mp_ctx: BaseContext,
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__(
initialize_timeout=initialize_timeout,
close_timeout=close_timeout,
memory_warn_mb=memory_warn_mb,
memory_limit_mb=memory_limit_mb,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
high_ping_threshold=high_ping_threshold,
mp_ctx=mp_ctx,
loop=loop,
)
self._user_args: Any | None = None
self._job_status: JobStatus | None = None
self._running_job: RunningJobInfo | None = None
self._initialize_process_fnc = initialize_process_fnc
self._job_entrypoint_fnc = job_entrypoint_fnc
self._inference_executor = inference_executor
self._inference_tasks: list[asyncio.Task[None]] = []
@property
def status(self) -> JobStatus:
if self._job_status is None:
raise RuntimeError("job status not available")
return self._job_status
@property
def user_arguments(self) -> Any | None:
return self._user_args
@user_arguments.setter
def user_arguments(self, value: Any | None) -> None:
self._user_args = value
@property
def running_job(self) -> RunningJobInfo | None:
return self._running_job
def _create_process(self, cch: socket.socket, log_cch: socket.socket) -> mp.Process:
proc_args = ProcStartArgs(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
log_cch=log_cch,
mp_cch=cch,
user_arguments=self._user_args,
)
return self._mp_ctx.Process( # type: ignore
target=proc_main,
args=(proc_args,),
name="job_proc",
)
@log_exceptions(logger=logger)
async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None:
try:
async for msg in ipc_ch:
if isinstance(msg, proto.InferenceRequest):
self._inference_tasks.append(
asyncio.create_task(self._do_inference_task(msg))
)
finally:
await aio.gracefully_cancel(*self._inference_tasks)
@log_exceptions(logger=logger)
async def _supervise_task(self) -> None:
try:
await super()._supervise_task()
finally:
self._job_status = (
JobStatus.SUCCESS if self.exitcode == 0 else JobStatus.FAILED
)
async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None:
if self._inference_executor is None:
logger.warning("inference request received but no inference executor")
await channel.asend_message(
self._pch,
proto.InferenceResponse(
request_id=inf_req.request_id, error="no inference executor"
),
)
return
try:
inf_res = await self._inference_executor.do_inference(
inf_req.method, inf_req.data
)
await channel.asend_message(
self._pch,
proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res),
)
except Exception as e:
await channel.asend_message(
self._pch,
proto.InferenceResponse(request_id=inf_req.request_id, error=str(e)),
)
async def launch_job(self, info: RunningJobInfo) -> None:
"""start/assign a job to the process"""
if self._running_job is not None:
raise RuntimeError("process already has a running job")
if not self._initialize_fut.done():
raise RuntimeError("process not initialized")
self._job_status = JobStatus.RUNNING
self._running_job = info
start_req = proto.StartJobRequest()
start_req.running_job = info
await channel.asend_message(self._pch, start_req)
def logging_extra(self):
extra = super().logging_extra()
if self._running_job:
extra["job_id"] = self._running_job.job.id
return extra
from __future__ import annotations
from multiprocessing import current_process
if current_process().name == "job_proc":
import signal
import sys
# ignore signals in the jobs process (the parent process will handle them)
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
def _no_traceback_excepthook(exc_type, exc_val, traceback):
if isinstance(exc_val, KeyboardInterrupt):
return
sys.__excepthook__(exc_type, exc_val, traceback)
sys.excepthook = _no_traceback_excepthook
import asyncio
import contextlib
import socket
import threading
from dataclasses import dataclass
from typing import Any, Callable
from livekit import rtc
from ..job import JobContext, JobProcess, _JobContextVar
from ..log import logger
from ..utils import aio, http_context, log_exceptions, shortuuid
from .channel import Message
from .inference_executor import InferenceExecutor
from .proc_client import _ProcClient
from .proto import (
Exiting,
InferenceRequest,
InferenceResponse,
InitializeRequest,
ShutdownRequest,
StartJobRequest,
)
@dataclass
class ProcStartArgs:
initialize_process_fnc: Callable[[JobProcess], Any]
job_entrypoint_fnc: Callable[[JobContext], Any]
mp_cch: socket.socket
log_cch: socket.socket
user_arguments: Any | None = None
def proc_main(args: ProcStartArgs) -> None:
from .proc_client import _ProcClient
job_proc = _JobProc(
args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments
)
client = _ProcClient(
args.mp_cch,
args.log_cch,
job_proc.initialize,
job_proc.entrypoint,
)
client.initialize_logger()
pid = current_process().pid
logger.info("initializing job process", extra={"pid": pid})
try:
client.initialize()
except Exception:
return # initialization failed, exit
logger.info("job process initialized", extra={"pid": pid})
client.run()
class _InfClient(InferenceExecutor):
def __init__(self, proc_client: _ProcClient) -> None:
self._client = proc_client
self._active_requests: dict[str, asyncio.Future[InferenceResponse]] = {}
async def do_inference(self, method: str, data: bytes) -> bytes | None:
request_id = shortuuid("inference_job_")
fut = asyncio.Future[InferenceResponse]()
await self._client.send(
InferenceRequest(request_id=request_id, method=method, data=data),
)
self._active_requests[request_id] = fut
inf_resp = await fut
if inf_resp.error:
raise RuntimeError(f"inference of {method} failed: {inf_resp.error}")
return inf_resp.data
def _on_inference_response(self, resp: InferenceResponse) -> None:
fut = self._active_requests.pop(resp.request_id, None)
if fut is None:
logger.warning(
"received unexpected inference response", extra={"resp": resp}
)
return
with contextlib.suppress(asyncio.InvalidStateError):
fut.set_result(resp)
@dataclass
class _ShutdownInfo:
user_initiated: bool
reason: str
class _JobProc:
def __init__(
self,
initialize_process_fnc: Callable[[JobProcess], Any],
job_entrypoint_fnc: Callable[[JobContext], Any],
user_arguments: Any | None = None,
) -> None:
self._initialize_process_fnc = initialize_process_fnc
self._job_entrypoint_fnc = job_entrypoint_fnc
self._job_proc = JobProcess(user_arguments=user_arguments)
self._job_task: asyncio.Task | None = None
        # used to warn users if neither connect nor shutdown is called inside the job entrypoint
self._ctx_connect_called = False
self._ctx_shutdown_called = False
@property
def has_running_job(self) -> bool:
return self._job_task is not None
def initialize(self, init_req: InitializeRequest, client: _ProcClient) -> None:
self._client = client
self._inf_client = _InfClient(client)
self._initialize_process_fnc(self._job_proc)
@log_exceptions(logger=logger)
async def entrypoint(self, cch: aio.ChanReceiver[Message]) -> None:
self._exit_proc_flag = asyncio.Event()
self._shutdown_fut: asyncio.Future[_ShutdownInfo] = asyncio.Future()
@log_exceptions(logger=logger)
async def _read_ipc_task():
async for msg in cch:
if isinstance(msg, StartJobRequest):
if self.has_running_job:
logger.warning(
"trying to start a new job while one is already running"
)
continue
self._start_job(msg)
if isinstance(msg, ShutdownRequest):
if not self.has_running_job:
self._exit_proc_flag.set()
break # exit immediately
with contextlib.suppress(asyncio.InvalidStateError):
self._shutdown_fut.set_result(
_ShutdownInfo(reason=msg.reason, user_initiated=False)
)
if isinstance(msg, InferenceResponse):
self._inf_client._on_inference_response(msg)
read_task = asyncio.create_task(_read_ipc_task(), name="job_ipc_read")
await self._exit_proc_flag.wait()
await aio.gracefully_cancel(read_task)
def _start_job(self, msg: StartJobRequest) -> None:
self._room = rtc.Room()
@self._room.on("disconnected")
def _on_room_disconnected(*args):
with contextlib.suppress(asyncio.InvalidStateError):
self._shutdown_fut.set_result(
_ShutdownInfo(user_initiated=False, reason="room disconnected")
)
def _on_ctx_connect() -> None:
self._ctx_connect_called = True
def _on_ctx_shutdown(reason: str) -> None:
self._ctx_shutdown_called = True
with contextlib.suppress(asyncio.InvalidStateError):
self._shutdown_fut.set_result(
_ShutdownInfo(user_initiated=True, reason=reason)
)
self._room._info.name = msg.running_job.job.room.name
self._job_ctx = JobContext(
proc=self._job_proc,
info=msg.running_job,
room=self._room,
on_connect=_on_ctx_connect,
on_shutdown=_on_ctx_shutdown,
inference_executor=self._inf_client,
)
self._job_task = asyncio.create_task(self._run_job_task(), name="job_task")
def _exit_proc_cb(_: asyncio.Task) -> None:
self._exit_proc_flag.set()
self._job_task.add_done_callback(_exit_proc_cb)
async def _run_job_task(self) -> None:
http_context._new_session_ctx()
job_ctx_token = _JobContextVar.set(self._job_ctx)
job_entry_task = asyncio.create_task(
self._job_entrypoint_fnc(self._job_ctx), name="job_user_entrypoint"
)
async def _warn_not_connected_task():
await asyncio.sleep(10)
if not self._ctx_connect_called and not self._ctx_shutdown_called:
logger.warning(
(
"The room connection was not established within 10 seconds after calling job_entry. "
"This may indicate that job_ctx.connect() was not called. "
)
)
warn_unconnected_task = asyncio.create_task(_warn_not_connected_task())
job_entry_task.add_done_callback(lambda _: warn_unconnected_task.cancel())
def log_exception(t: asyncio.Task) -> None:
if not t.cancelled() and t.exception():
logger.error(
"unhandled exception while running the job task",
exc_info=t.exception(),
)
elif not self._ctx_connect_called and not self._ctx_shutdown_called:
logger.warning(
(
"The job task completed without establishing a connection or performing a proper shutdown. "
"Ensure that job_ctx.connect()/job_ctx.shutdown() is called and the job is correctly finalized."
)
)
job_entry_task.add_done_callback(log_exception)
shutdown_info = await self._shutdown_fut
logger.debug(
"shutting down job task",
extra={
"reason": shutdown_info.reason,
"user_initiated": shutdown_info.user_initiated,
},
)
await self._client.send(Exiting(reason=shutdown_info.reason))
await self._room.disconnect()
try:
shutdown_tasks = []
for callback in self._job_ctx._shutdown_callbacks:
shutdown_tasks.append(
asyncio.create_task(
callback(shutdown_info.reason), name="job_shutdown_callback"
)
)
await asyncio.gather(*shutdown_tasks)
except Exception:
logger.exception("error while shutting down the job")
await http_context._close_http_ctx()
_JobContextVar.reset(job_ctx_token)
@dataclass
class ThreadStartArgs:
initialize_process_fnc: Callable[[JobProcess], Any]
job_entrypoint_fnc: Callable[[JobContext], Any]
join_fnc: Callable[[], None]
mp_cch: socket.socket
user_arguments: Any | None
def thread_main(
args: ThreadStartArgs,
) -> None:
"""main function for the job process when using the ThreadedJobRunner"""
tid = threading.get_native_id()
try:
from .proc_client import _ProcClient
job_proc = _JobProc(
args.initialize_process_fnc, args.job_entrypoint_fnc, args.user_arguments
)
client = _ProcClient(
args.mp_cch,
None,
job_proc.initialize,
job_proc.entrypoint,
)
logger.info("initializing job runner", extra={"tid": tid})
client.initialize()
logger.info("job runner initialized", extra={"tid": tid})
client.run()
finally:
args.join_fnc()
from __future__ import annotations
import asyncio
import contextlib
import socket
import threading
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from .. import utils
from ..job import JobContext, JobProcess, RunningJobInfo
from ..log import logger
from ..utils.aio import duplex_unix
from . import channel, job_proc_lazy_main, proto
from .inference_executor import InferenceExecutor
from .job_executor import JobStatus
@dataclass
class _ProcOpts:
initialize_process_fnc: Callable[[JobProcess], Any]
job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]]
initialize_timeout: float
close_timeout: float
ping_interval: float
high_ping_threshold: float
class ThreadJobExecutor:
def __init__(
self,
*,
initialize_process_fnc: Callable[[JobProcess], Any],
job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]],
inference_executor: InferenceExecutor | None,
initialize_timeout: float,
close_timeout: float,
ping_interval: float,
high_ping_threshold: float,
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
self._opts = _ProcOpts(
initialize_process_fnc=initialize_process_fnc,
job_entrypoint_fnc=job_entrypoint_fnc,
initialize_timeout=initialize_timeout,
close_timeout=close_timeout,
ping_interval=ping_interval,
high_ping_threshold=high_ping_threshold,
)
self._user_args: Any | None = None
self._job_status: JobStatus | None = None
self._running_job: RunningJobInfo | None = None
self._main_atask: asyncio.Task[None] | None = None
self._initialize_fut = asyncio.Future[None]()
self._closing = False
self._lock = asyncio.Lock()
self._inference_executor = inference_executor
self._inference_tasks: list[asyncio.Task[None]] = []
@property
def status(self) -> JobStatus:
if self._job_status is None:
raise RuntimeError("job status not available")
return self._job_status
@property
def started(self) -> bool:
return self._main_atask is not None
@property
def user_arguments(self) -> Any | None:
return self._user_args
@user_arguments.setter
def user_arguments(self, value: Any | None) -> None:
self._user_args = value
@property
def running_job(self) -> RunningJobInfo | None:
return self._running_job
async def start(self) -> None:
if self.started:
raise RuntimeError("runner already started")
if self._closing:
raise RuntimeError("runner is closed")
await asyncio.shield(self._start())
async def _start(self) -> None:
async with self._lock:
            # to simplify the runner implementations, we also use a duplex socket in the
            # threaded executor, so the same IPC protocol can be reused
mp_pch, mp_cch = socket.socketpair()
self._pch = await duplex_unix._AsyncDuplex.open(mp_pch)
self._join_fut = asyncio.Future[None]()
def _on_join() -> None:
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._join_fut.set_result, None)
targs = job_proc_lazy_main.ThreadStartArgs(
mp_cch=mp_cch,
initialize_process_fnc=self._opts.initialize_process_fnc,
job_entrypoint_fnc=self._opts.job_entrypoint_fnc,
user_arguments=self._user_args,
join_fnc=_on_join,
)
self._thread = t = threading.Thread(
target=job_proc_lazy_main.thread_main,
args=(targs,),
name="job_thread_runner",
)
t.start()
self._main_atask = asyncio.create_task(self._main_task())
async def join(self) -> None:
"""wait for the thread to finish"""
if not self.started:
raise RuntimeError("runner not started")
async with self._lock:
if self._main_atask:
await asyncio.shield(self._main_atask)
async def initialize(self) -> None:
await channel.asend_message(self._pch, proto.InitializeRequest())
try:
init_res = await asyncio.wait_for(
channel.arecv_message(self._pch, proto.IPC_MESSAGES),
timeout=self._opts.initialize_timeout,
)
assert isinstance(init_res, proto.InitializeResponse), (
"first message must be InitializeResponse"
)
except asyncio.TimeoutError:
self._initialize_fut.set_exception(
asyncio.TimeoutError("runner initialization timed out")
)
logger.error(
"job initialization is taking too much time..",
extra=self.logging_extra(),
)
raise
except Exception as e: # should be channel.ChannelClosed most of the time
self._initialize_fut.set_exception(e)
raise
else:
self._initialize_fut.set_result(None)
async def aclose(self) -> None:
"""
attempt to gracefully close the job. warn if it takes too long to close
(in the threaded executor, the job can't be "killed")
"""
if not self.started:
return
self._closing = True
with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed):
await channel.asend_message(self._pch, proto.ShutdownRequest())
try:
if self._main_atask:
await asyncio.wait_for(
asyncio.shield(self._main_atask), timeout=self._opts.close_timeout
)
except asyncio.TimeoutError:
logger.error(
"job shutdown is taking too much time..", extra=self.logging_extra()
)
async with self._lock:
if self._main_atask:
await asyncio.shield(self._main_atask)
async def _do_inference_task(self, inf_req: proto.InferenceRequest) -> None:
if self._inference_executor is None:
logger.warning("inference request received but no inference executor")
await channel.asend_message(
self._pch,
proto.InferenceResponse(
request_id=inf_req.request_id, error="no inference executor"
),
)
return
try:
inf_res = await self._inference_executor.do_inference(
inf_req.method, inf_req.data
)
await channel.asend_message(
self._pch,
proto.InferenceResponse(request_id=inf_req.request_id, data=inf_res),
)
except Exception as e:
await channel.asend_message(
self._pch,
proto.InferenceResponse(request_id=inf_req.request_id, error=str(e)),
)
async def launch_job(self, info: RunningJobInfo) -> None:
"""start/assign a job to the executor"""
if self._running_job is not None:
raise RuntimeError("executor already has a running job")
if not self._initialize_fut.done():
raise RuntimeError("executor not initialized")
self._running_job = info
self._job_status = JobStatus.RUNNING
start_req = proto.StartJobRequest()
start_req.running_job = info
await channel.asend_message(self._pch, start_req)
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
try:
await self._initialize_fut
except asyncio.TimeoutError:
            pass  # this happens when initialization takes longer than initialize_timeout
except Exception:
pass # initialization failed
ping_task = asyncio.create_task(self._ping_task())
monitor_task = asyncio.create_task(self._monitor_task())
await self._join_fut
await utils.aio.gracefully_cancel(ping_task, monitor_task)
await utils.aio.gracefully_cancel(*self._inference_tasks)
with contextlib.suppress(duplex_unix.DuplexClosed):
await self._pch.aclose()
self._job_status = JobStatus.SUCCESS
@utils.log_exceptions(logger=logger)
async def _monitor_task(self) -> None:
while True:
try:
msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES)
except utils.aio.duplex_unix.DuplexClosed:
break
if isinstance(msg, proto.PongResponse):
delay = utils.time_ms() - msg.timestamp
if delay > self._opts.high_ping_threshold * 1000:
logger.warning(
"job executor is unresponsive",
extra={"delay": delay, **self.logging_extra()},
)
if isinstance(msg, proto.Exiting):
logger.debug(
"job exiting", extra={"reason": msg.reason, **self.logging_extra()}
)
if isinstance(msg, proto.InferenceRequest):
self._inference_tasks.append(
asyncio.create_task(self._do_inference_task(msg))
)
@utils.log_exceptions(logger=logger)
async def _ping_task(self) -> None:
ping_interval = utils.aio.interval(self._opts.ping_interval)
while True:
await ping_interval.tick()
try:
await channel.asend_message(
self._pch, proto.PingRequest(timestamp=utils.time_ms())
)
except utils.aio.duplex_unix.DuplexClosed:
break
def logging_extra(self):
extra: dict[str, Any] = {
"tid": self._thread.native_id,
}
if self._running_job:
extra["job_id"] = self._running_job.job.id
return extra
from __future__ import annotations
import copy
import logging
import pickle
import queue
import sys
import threading
from typing import Callable, Optional
from .. import utils
from ..utils.aio import duplex_unix
class LogQueueListener:
def __init__(
self,
duplex: utils.aio.duplex_unix._Duplex,
prepare_fnc: Callable[[logging.LogRecord], None],
):
self._thread: threading.Thread | None = None
self._duplex = duplex
self._prepare_fnc = prepare_fnc
def start(self) -> None:
self._thread = threading.Thread(target=self._monitor, name="ipc_log_listener")
self._thread.start()
def stop(self) -> None:
if self._thread is None:
return
self._duplex.close()
self._thread.join()
self._thread = None
def handle(self, record: logging.LogRecord) -> None:
self._prepare_fnc(record)
lger = logging.getLogger(record.name)
if not lger.isEnabledFor(record.levelno):
return
lger.callHandlers(record)
def _monitor(self):
while True:
try:
data = self._duplex.recv_bytes()
except utils.aio.duplex_unix.DuplexClosed:
break
record = pickle.loads(data)
self.handle(record)
class LogQueueHandler(logging.Handler):
_sentinal = None
def __init__(self, duplex: utils.aio.duplex_unix._Duplex) -> None:
super().__init__()
self._duplex = duplex
self._send_q = queue.SimpleQueue[Optional[bytes]]()
self._send_thread = threading.Thread(
target=self._forward_logs, name="ipc_log_forwarder"
)
self._send_thread.start()
def _forward_logs(self):
while True:
serialized_record = self._send_q.get()
if serialized_record is None:
break
try:
self._duplex.send_bytes(serialized_record)
except duplex_unix.DuplexClosed:
break
self._duplex.close()
def emit(self, record: logging.LogRecord) -> None:
try:
# Check if Python is shutting down
if sys.is_finalizing():
return
# from https://github.com/python/cpython/blob/91b7f2e7f6593acefda4fa860250dd87d6f849bf/Lib/logging/handlers.py#L1453
msg = self.format(record)
record = copy.copy(record)
record.message = msg
record.msg = msg
record.args = None
record.exc_info = None
record.exc_text = None
record.stack_info = None
# https://websockets.readthedocs.io/en/stable/topics/logging.html#logging-to-json
            # the websockets library adds a "websocket" attribute to log records, which is not picklable
if hasattr(record, "websocket"):
record.websocket = None
self._send_q.put_nowait(pickle.dumps(record))
except Exception:
self.handleError(record)
def close(self) -> None:
super().close()
self._send_q.put_nowait(self._sentinal)
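# A single-process sketch of the pair above, wired over a socketpair. In the real
# workers the LogQueueHandler runs in the child process and the LogQueueListener in
# the parent; here the prepare function also renames the record so the listener does
# not re-dispatch it into the very logger that carries the handler. The short sleep is
# a crude stand-in for proper shutdown ordering, good enough for a sketch.
def _log_queue_demo() -> None:
    import socket
    import time
    parent_sock, child_sock = socket.socketpair()
    def _prepare(record: logging.LogRecord) -> None:
        record.proc = "demo"  # stamp extra context onto the record, as SupervisedProc does
        record.name = "ipc_demo.parent"  # avoid a feedback loop in this single-process demo
    listener = LogQueueListener(duplex_unix._Duplex.open(parent_sock), _prepare)
    listener.start()
    handler = LogQueueHandler(duplex_unix._Duplex.open(child_sock))
    child_logger = logging.getLogger("ipc_demo.child")
    child_logger.addHandler(handler)
    child_logger.warning("hello from the child side")  # pickled and forwarded over the socket
    handler.close()  # enqueues the sentinel; the forwarder thread then closes the child side
    time.sleep(0.5)  # let the forwarder and listener threads drain before stopping
    listener.stop()
    child_logger.removeHandler(handler)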
from __future__ import annotations
import asyncio
import contextlib
import logging
import socket
import sys
from typing import Callable, Coroutine
from ..log import logger
from ..utils import aio, log_exceptions, time_ms
from .channel import Message, arecv_message, asend_message, recv_message, send_message
from .log_queue import LogQueueHandler
from .proto import (
IPC_MESSAGES,
InitializeRequest,
InitializeResponse,
PingRequest,
PongResponse,
)
class _ProcClient:
def __init__(
self,
mp_cch: socket.socket,
log_cch: socket.socket | None,
initialize_fnc: Callable[[InitializeRequest, "_ProcClient"], None],
main_task_fnc: Callable[
[aio.ChanReceiver[Message]], Coroutine[None, None, None]
],
) -> None:
self._mp_cch = mp_cch
self._log_cch = log_cch
self._initialize_fnc = initialize_fnc
self._main_task_fnc = main_task_fnc
self._initialized = False
self._log_handler: LogQueueHandler | None = None
def initialize_logger(self) -> None:
if self._log_cch is None:
raise RuntimeError("cannot initialize logger without log channel")
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET)
log_cch = aio.duplex_unix._Duplex.open(self._log_cch)
self._log_handler = LogQueueHandler(log_cch)
root_logger.addHandler(self._log_handler)
def initialize(self) -> None:
try:
cch = aio.duplex_unix._Duplex.open(self._mp_cch)
first_req = recv_message(cch, IPC_MESSAGES)
assert isinstance(first_req, InitializeRequest), (
"first message must be proto.InitializeRequest"
)
self._init_req = first_req
try:
self._initialize_fnc(self._init_req, self)
send_message(cch, InitializeResponse())
except Exception as e:
send_message(cch, InitializeResponse(error=str(e)))
raise
self._initialized = True
cch.detach()
except aio.duplex_unix.DuplexClosed as e:
raise RuntimeError("failed to initialize proc_client") from e
def run(self) -> None:
if not self._initialized:
raise RuntimeError("proc_client not initialized")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.set_debug(self._init_req.asyncio_debug)
loop.slow_callback_duration = 0.1 # 100ms
aio.debug.hook_slow_callbacks(2.0)
try:
self._task = loop.create_task(self._monitor_task(), name="proc_client_main")
while not self._task.done():
try:
loop.run_until_complete(self._task)
except KeyboardInterrupt:
# ignore the keyboard interrupt, we handle the process shutdown ourselves on the worker process
# (See proto.ShutdownRequest)
pass
except KeyboardInterrupt:
pass
finally:
if self._log_handler is not None:
self._log_handler.close()
loop.run_until_complete(loop.shutdown_default_executor())
async def send(self, msg: Message) -> None:
await asend_message(self._acch, msg)
async def _monitor_task(self) -> None:
self._acch = await aio.duplex_unix._AsyncDuplex.open(self._mp_cch)
try:
exit_flag = asyncio.Event()
ping_timeout = aio.sleep(self._init_req.ping_timeout)
ipc_ch = aio.Chan[Message]()
@log_exceptions(logger=logger)
async def _read_ipc_task():
while True:
try:
msg = await arecv_message(self._acch, IPC_MESSAGES)
except aio.duplex_unix.DuplexClosed:
break
with contextlib.suppress(aio.SleepFinished):
ping_timeout.reset()
if isinstance(msg, PingRequest):
await asend_message(
self._acch,
PongResponse(
last_timestamp=msg.timestamp, timestamp=time_ms()
),
)
ipc_ch.send_nowait(msg)
@log_exceptions(logger=logger)
async def _self_health_check():
await ping_timeout
print(
"worker process is not responding.. worker crashed?",
file=sys.stderr,
)
read_task = asyncio.create_task(_read_ipc_task(), name="ipc_read")
health_check_task: asyncio.Task | None = None
if self._init_req.ping_interval > 0:
health_check_task = asyncio.create_task(
_self_health_check(), name="health_check"
)
main_task = asyncio.create_task(
self._main_task_fnc(ipc_ch), name="main_task_entrypoint"
)
def _done_cb(_: asyncio.Task) -> None:
with contextlib.suppress(asyncio.InvalidStateError):
exit_flag.set()
ipc_ch.close()
read_task.add_done_callback(_done_cb)
if health_check_task is not None:
health_check_task.add_done_callback(_done_cb)
main_task.add_done_callback(_done_cb)
await exit_flag.wait()
await aio.gracefully_cancel(read_task, main_task)
if health_check_task is not None:
await aio.gracefully_cancel(health_check_task)
finally:
await self._acch.aclose()
from __future__ import annotations
import asyncio
from multiprocessing.context import BaseContext
from typing import Any, Awaitable, Callable, Literal
from .. import utils
from ..job import JobContext, JobExecutorType, JobProcess, RunningJobInfo
from ..log import logger
from ..utils import aio
from . import inference_executor, job_proc_executor, job_thread_executor
from .job_executor import JobExecutor
EventTypes = Literal[
"process_created",
"process_started",
"process_ready",
"process_closed",
"process_job_launched",
]
MAX_CONCURRENT_INITIALIZATIONS = 5
class ProcPool(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
initialize_process_fnc: Callable[[JobProcess], Any],
job_entrypoint_fnc: Callable[[JobContext], Awaitable[None]],
num_idle_processes: int,
initialize_timeout: float,
close_timeout: float,
inference_executor: inference_executor.InferenceExecutor | None,
job_executor_type: JobExecutorType,
mp_ctx: BaseContext,
memory_warn_mb: float,
memory_limit_mb: float,
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self._job_executor_type = job_executor_type
self._mp_ctx = mp_ctx
self._initialize_process_fnc = initialize_process_fnc
self._job_entrypoint_fnc = job_entrypoint_fnc
self._close_timeout = close_timeout
self._inf_executor = inference_executor
self._initialize_timeout = initialize_timeout
self._loop = loop
self._memory_limit_mb = memory_limit_mb
self._memory_warn_mb = memory_warn_mb
self._num_idle_processes = num_idle_processes
self._init_sem = asyncio.Semaphore(MAX_CONCURRENT_INITIALIZATIONS)
self._proc_needed_sem = asyncio.Semaphore(num_idle_processes)
self._warmed_proc_queue = asyncio.Queue[JobExecutor]()
self._executors: list[JobExecutor] = []
self._started = False
self._closed = False
@property
def processes(self) -> list[JobExecutor]:
return self._executors
def get_by_job_id(self, job_id: str) -> JobExecutor | None:
return next(
(
x
for x in self._executors
if x.running_job and x.running_job.job.id == job_id
),
None,
)
def start(self) -> None:
if self._started:
return
self._started = True
self._main_atask = asyncio.create_task(self._main_task())
async def aclose(self) -> None:
if not self._started:
return
self._closed = True
await aio.gracefully_cancel(self._main_atask)
async def launch_job(self, info: RunningJobInfo) -> None:
if self._num_idle_processes == 0:
            self._proc_needed_sem.release()  # prewarming is disabled; request a new process for this job
proc = await self._warmed_proc_queue.get()
else:
proc = await self._warmed_proc_queue.get()
self._proc_needed_sem.release() # notify that a new process can be warmed/started
await proc.launch_job(info)
self.emit("process_job_launched", proc)
@utils.log_exceptions(logger=logger)
async def _proc_watch_task(self) -> None:
proc: JobExecutor
if self._job_executor_type == JobExecutorType.THREAD:
proc = job_thread_executor.ThreadJobExecutor(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
initialize_timeout=self._initialize_timeout,
close_timeout=self._close_timeout,
inference_executor=self._inf_executor,
ping_interval=2.5,
high_ping_threshold=0.5,
loop=self._loop,
)
elif self._job_executor_type == JobExecutorType.PROCESS:
proc = job_proc_executor.ProcJobExecutor(
initialize_process_fnc=self._initialize_process_fnc,
job_entrypoint_fnc=self._job_entrypoint_fnc,
initialize_timeout=self._initialize_timeout,
close_timeout=self._close_timeout,
inference_executor=self._inf_executor,
mp_ctx=self._mp_ctx,
loop=self._loop,
ping_interval=2.5,
ping_timeout=60,
high_ping_threshold=0.5,
memory_warn_mb=self._memory_warn_mb,
memory_limit_mb=self._memory_limit_mb,
)
else:
raise ValueError(f"unsupported job executor: {self._job_executor_type}")
try:
self._executors.append(proc)
async with self._init_sem:
if self._closed:
return
self.emit("process_created", proc)
await proc.start()
self.emit("process_started", proc)
try:
await proc.initialize()
                    # a process whose initialization times out will never fire "process_ready"
                    # nor be used to launch jobs
self.emit("process_ready", proc)
self._warmed_proc_queue.put_nowait(proc)
except Exception:
self._proc_needed_sem.release() # notify to warm a new process after initialization failure
await proc.join()
self.emit("process_closed", proc)
finally:
self._executors.remove(proc)
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
watch_tasks: list[asyncio.Task[None]] = []
try:
while True:
await self._proc_needed_sem.acquire()
task = asyncio.create_task(self._proc_watch_task())
watch_tasks.append(task)
task.add_done_callback(watch_tasks.remove)
except asyncio.CancelledError:
await asyncio.gather(*[proc.aclose() for proc in self._executors])
await asyncio.gather(*watch_tasks)
from __future__ import annotations
import io
from dataclasses import dataclass, field
from typing import ClassVar
from livekit.protocol import agent
from ..job import JobAcceptArguments, RunningJobInfo
from . import channel
@dataclass
class InitializeRequest:
"""sent by the main process to the subprocess to initialize it. this is going to call initialize_process_fnc"""
MSG_ID: ClassVar[int] = 0
asyncio_debug: bool = False
ping_interval: float = 0
ping_timeout: float = 0 # if no response, process is considered dead
high_ping_threshold: float = (
0 # if ping is higher than this, process is considered unresponsive
)
def write(self, b: io.BytesIO) -> None:
channel.write_bool(b, self.asyncio_debug)
channel.write_float(b, self.ping_interval)
channel.write_float(b, self.ping_timeout)
channel.write_float(b, self.high_ping_threshold)
def read(self, b: io.BytesIO) -> None:
self.asyncio_debug = channel.read_bool(b)
self.ping_interval = channel.read_float(b)
self.ping_timeout = channel.read_float(b)
self.high_ping_threshold = channel.read_float(b)
@dataclass
class InitializeResponse:
"""mark the process as initialized"""
MSG_ID: ClassVar[int] = 1
error: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.error)
def read(self, b: io.BytesIO) -> None:
self.error = channel.read_string(b)
@dataclass
class PingRequest:
"""sent by the main process to the subprocess to check if it is still alive"""
MSG_ID: ClassVar[int] = 2
timestamp: int = 0
def write(self, b: io.BytesIO) -> None:
channel.write_long(b, self.timestamp)
def read(self, b: io.BytesIO) -> None:
self.timestamp = channel.read_long(b)
@dataclass
class PongResponse:
"""response to a PingRequest"""
MSG_ID: ClassVar[int] = 3
last_timestamp: int = 0
timestamp: int = 0
def write(self, b: io.BytesIO) -> None:
channel.write_long(b, self.last_timestamp)
channel.write_long(b, self.timestamp)
def read(self, b: io.BytesIO) -> None:
self.last_timestamp = channel.read_long(b)
self.timestamp = channel.read_long(b)
@dataclass
class StartJobRequest:
"""sent by the main process to the subprocess to start a job, the subprocess will only
receive this message if the process is fully initialized (after sending a InitializeResponse)."""
MSG_ID: ClassVar[int] = 4
running_job: RunningJobInfo = field(init=False)
def write(self, b: io.BytesIO) -> None:
accept_args = self.running_job.accept_arguments
channel.write_bytes(b, self.running_job.job.SerializeToString())
channel.write_string(b, accept_args.name)
channel.write_string(b, accept_args.identity)
channel.write_string(b, accept_args.metadata)
channel.write_string(b, self.running_job.url)
channel.write_string(b, self.running_job.token)
channel.write_string(b, self.running_job.worker_id)
def read(self, b: io.BytesIO) -> None:
job = agent.Job()
job.ParseFromString(channel.read_bytes(b))
self.running_job = RunningJobInfo(
accept_arguments=JobAcceptArguments(
name=channel.read_string(b),
identity=channel.read_string(b),
metadata=channel.read_string(b),
),
job=job,
url=channel.read_string(b),
token=channel.read_string(b),
worker_id=channel.read_string(b),
)
@dataclass
class ShutdownRequest:
"""sent by the main process to the subprocess to indicate that it should shut down
gracefully. the subprocess will follow with a ExitInfo message"""
MSG_ID: ClassVar[int] = 5
reason: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.reason)
def read(self, b: io.BytesIO) -> None:
self.reason = channel.read_string(b)
@dataclass
class Exiting:
"""sent by the subprocess to the main process to indicate that it is exiting"""
MSG_ID: ClassVar[int] = 6
reason: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.reason)
def read(self, b: io.BytesIO) -> None:
self.reason = channel.read_string(b)
@dataclass
class InferenceRequest:
"""sent by a subprocess to the main process to request inference"""
MSG_ID: ClassVar[int] = 7
method: str = ""
request_id: str = ""
data: bytes = b""
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.method)
channel.write_string(b, self.request_id)
channel.write_bytes(b, self.data)
def read(self, b: io.BytesIO) -> None:
self.method = channel.read_string(b)
self.request_id = channel.read_string(b)
self.data = channel.read_bytes(b)
@dataclass
class InferenceResponse:
"""response to an InferenceRequest"""
MSG_ID: ClassVar[int] = 8
request_id: str = ""
data: bytes | None = None
error: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.request_id)
channel.write_bool(b, self.data is not None)
if self.data is not None:
channel.write_bytes(b, self.data)
channel.write_string(b, self.error)
def read(self, b: io.BytesIO) -> None:
self.request_id = channel.read_string(b)
has_data = channel.read_bool(b)
if has_data:
self.data = channel.read_bytes(b)
self.error = channel.read_string(b)
IPC_MESSAGES = {
InitializeRequest.MSG_ID: InitializeRequest,
InitializeResponse.MSG_ID: InitializeResponse,
PingRequest.MSG_ID: PingRequest,
PongResponse.MSG_ID: PongResponse,
StartJobRequest.MSG_ID: StartJobRequest,
ShutdownRequest.MSG_ID: ShutdownRequest,
Exiting.MSG_ID: Exiting,
InferenceRequest.MSG_ID: InferenceRequest,
InferenceResponse.MSG_ID: InferenceResponse,
}
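# A tiny framing sketch for the registry above: messages are encoded with the channel
# helpers and decoded back through IPC_MESSAGES. The request_id, method name and
# payloads are arbitrary placeholders.
def _proto_roundtrip_demo() -> None:
    # request side (sent by a subprocess to the main process)
    req_frame = channel._write_message(
        InferenceRequest(request_id="req_1", method="demo.echo", data=b"\x01\x02")
    )
    req = channel._read_message(req_frame, IPC_MESSAGES)
    assert isinstance(req, InferenceRequest) and req.method == "demo.echo"
    # response side: a None payload is encoded as a has_data flag set to False
    resp_frame = channel._write_message(InferenceResponse(request_id="req_1", data=None))
    resp = channel._read_message(resp_frame, IPC_MESSAGES)
    assert isinstance(resp, InferenceResponse) and resp.data is None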
from __future__ import annotations
import asyncio
import contextlib
import logging
import multiprocessing as mp
import socket
import sys
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import Any
import psutil
from ..log import logger
from ..utils import aio, log_exceptions, time_ms
from ..utils.aio import duplex_unix
from . import channel, proto
from .log_queue import LogQueueListener
@dataclass
class _ProcOpts:
initialize_timeout: float
close_timeout: float
memory_warn_mb: float
memory_limit_mb: float
ping_interval: float
ping_timeout: float
high_ping_threshold: float
class SupervisedProc(ABC):
def __init__(
self,
*,
initialize_timeout: float,
close_timeout: float,
memory_warn_mb: float,
memory_limit_mb: float,
ping_interval: float,
ping_timeout: float,
high_ping_threshold: float,
mp_ctx: BaseContext,
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
self._mp_ctx = mp_ctx
self._opts = _ProcOpts(
initialize_timeout=initialize_timeout,
close_timeout=close_timeout,
memory_warn_mb=memory_warn_mb,
memory_limit_mb=memory_limit_mb,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
high_ping_threshold=high_ping_threshold,
)
self._exitcode: int | None = None
self._pid: int | None = None
self._supervise_atask: asyncio.Task[None] | None = None
self._closing = False
self._kill_sent = False
self._initialize_fut = asyncio.Future[None]()
self._lock = asyncio.Lock()
@abstractmethod
def _create_process(
self, cch: socket.socket, log_cch: socket.socket
) -> mp.Process: ...
@abstractmethod
async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None: ...
@property
def exitcode(self) -> int | None:
return self._exitcode
@property
def killed(self) -> bool:
return self._kill_sent
@property
def pid(self) -> int | None:
return self._pid
@property
def started(self) -> bool:
return self._supervise_atask is not None
async def start(self) -> None:
"""start the supervised process"""
if self.started:
raise RuntimeError("process already started")
if self._closing:
raise RuntimeError("process is closed")
await asyncio.shield(self._start())
async def _start(self) -> None:
def _add_proc_ctx_log(record: logging.LogRecord) -> None:
extra = self.logging_extra()
for key, value in extra.items():
setattr(record, key, value)
async with self._lock:
mp_pch, mp_cch = socket.socketpair()
mp_log_pch, mp_log_cch = socket.socketpair()
self._pch = await duplex_unix._AsyncDuplex.open(mp_pch)
log_pch = duplex_unix._Duplex.open(mp_log_pch)
log_listener = LogQueueListener(log_pch, _add_proc_ctx_log)
log_listener.start()
self._proc = self._create_process(mp_cch, mp_log_cch)
await self._loop.run_in_executor(None, self._proc.start)
mp_log_cch.close()
mp_cch.close()
self._pid = self._proc.pid
self._join_fut = asyncio.Future[None]()
def _sync_run():
self._proc.join()
log_listener.stop()
try:
self._loop.call_soon_threadsafe(self._join_fut.set_result, None)
except RuntimeError:
pass
thread = threading.Thread(target=_sync_run, name="proc_join_thread")
thread.start()
self._supervise_atask = asyncio.create_task(self._supervise_task())
async def join(self) -> None:
"""wait for the process to finish"""
if not self.started:
raise RuntimeError("process not started")
async with self._lock:
if self._supervise_atask:
await asyncio.shield(self._supervise_atask)
async def initialize(self) -> None:
"""initialize the process, this is sending a InitializeRequest message and waiting for a
InitializeResponse with a timeout"""
await channel.asend_message(
self._pch,
proto.InitializeRequest(
asyncio_debug=self._loop.get_debug(),
ping_interval=self._opts.ping_interval,
ping_timeout=self._opts.ping_timeout,
high_ping_threshold=self._opts.high_ping_threshold,
),
)
# wait for the process to become ready
try:
init_res = await asyncio.wait_for(
channel.arecv_message(self._pch, proto.IPC_MESSAGES),
timeout=self._opts.initialize_timeout,
)
assert isinstance(init_res, proto.InitializeResponse), (
"first message must be InitializeResponse"
)
if init_res.error:
self._initialize_fut.set_exception(
RuntimeError(f"process initialization failed: {init_res.error}")
)
logger.error(
f"process initialization failed: {init_res.error}",
extra=self.logging_extra(),
)
raise RuntimeError(f"process initialization failed: {init_res.error}")
else:
self._initialize_fut.set_result(None)
except asyncio.TimeoutError:
self._initialize_fut.set_exception(
asyncio.TimeoutError("process initialization timed out")
)
logger.error(
"initialization timed out, killing process", extra=self.logging_extra()
)
self._send_kill_signal()
raise
except Exception as e: # should be channel.ChannelClosed most of the time
self._initialize_fut.set_exception(e)
raise
async def aclose(self) -> None:
"""attempt to gracefully close the supervised process"""
if not self.started:
return
self._closing = True
with contextlib.suppress(duplex_unix.DuplexClosed):
await channel.asend_message(self._pch, proto.ShutdownRequest())
try:
if self._supervise_atask:
await asyncio.wait_for(
asyncio.shield(self._supervise_atask),
timeout=self._opts.close_timeout,
)
except asyncio.TimeoutError:
logger.error(
"process did not exit in time, killing process",
extra=self.logging_extra(),
)
self._send_kill_signal()
async with self._lock:
if self._supervise_atask:
await asyncio.shield(self._supervise_atask)
async def kill(self) -> None:
"""forcefully kill the supervised process"""
if not self.started:
raise RuntimeError("process not started")
self._closing = True
self._send_kill_signal()
async with self._lock:
if self._supervise_atask:
await asyncio.shield(self._supervise_atask)
def _send_kill_signal(self) -> None:
"""forcefully kill the process"""
try:
if not self._proc.is_alive():
return
except ValueError:
return
logger.info("killing process", extra=self.logging_extra())
if sys.platform == "win32":
self._proc.terminate()
else:
self._proc.kill()
self._kill_sent = True
@log_exceptions(logger=logger)
async def _supervise_task(self) -> None:
try:
await self._initialize_fut
except asyncio.TimeoutError:
            pass  # this happens when initialization takes longer than self._opts.initialize_timeout
except Exception:
pass # initialization failed
# the process is killed if it doesn't respond to ping requests
pong_timeout = aio.sleep(self._opts.ping_timeout)
ipc_ch = aio.Chan[channel.Message]()
main_task = asyncio.create_task(self._main_task(ipc_ch))
read_ipc_task = asyncio.create_task(self._read_ipc_task(ipc_ch, pong_timeout))
ping_task = asyncio.create_task(self._ping_pong_task(pong_timeout))
read_ipc_task.add_done_callback(lambda _: ipc_ch.close())
memory_monitor_task: asyncio.Task[None] | None = None
if self._opts.memory_limit_mb > 0 or self._opts.memory_warn_mb > 0:
memory_monitor_task = asyncio.create_task(self._memory_monitor_task())
await self._join_fut
self._exitcode = self._proc.exitcode
self._proc.close()
await aio.gracefully_cancel(ping_task, read_ipc_task, main_task)
if memory_monitor_task is not None:
await aio.gracefully_cancel(memory_monitor_task)
with contextlib.suppress(duplex_unix.DuplexClosed):
await self._pch.aclose()
if self._exitcode != 0 and not self._kill_sent:
logger.error(
f"process exited with non-zero exit code {self.exitcode}",
extra=self.logging_extra(),
)
@log_exceptions(logger=logger)
async def _read_ipc_task(
self, ipc_ch: aio.Chan[channel.Message], pong_timeout: aio.Sleep
) -> None:
while True:
try:
msg = await channel.arecv_message(self._pch, proto.IPC_MESSAGES)
except duplex_unix.DuplexClosed:
break
if isinstance(msg, proto.PongResponse):
delay = time_ms() - msg.timestamp
if delay > self._opts.high_ping_threshold * 1000:
logger.warning(
"process is unresponsive",
extra={"delay": delay, **self.logging_extra()},
)
with contextlib.suppress(aio.SleepFinished):
pong_timeout.reset()
if isinstance(msg, proto.Exiting):
logger.info(
"process exiting",
extra={"reason": msg.reason, **self.logging_extra()},
)
ipc_ch.send_nowait(msg)
@log_exceptions(logger=logger)
async def _ping_pong_task(self, pong_timeout: aio.Sleep) -> None:
ping_interval = aio.interval(self._opts.ping_interval)
async def _send_ping_co():
while True:
await ping_interval.tick()
try:
await channel.asend_message(
self._pch, proto.PingRequest(timestamp=time_ms())
)
except duplex_unix.DuplexClosed:
break
async def _pong_timeout_co():
await pong_timeout
logger.error(
"process is unresponsive, killing process", extra=self.logging_extra()
)
self._send_kill_signal()
tasks = [
asyncio.create_task(_send_ping_co()),
asyncio.create_task(_pong_timeout_co()),
]
try:
await asyncio.gather(*tasks)
finally:
await aio.gracefully_cancel(*tasks)
@log_exceptions(logger=logger)
async def _memory_monitor_task(self) -> None:
"""Monitor memory usage and kill the process if it exceeds the limit."""
while not self._closing and not self._kill_sent:
try:
if not self._pid:
await asyncio.sleep(5)
continue
# get process memory info
process = psutil.Process(self._pid)
memory_info = process.memory_info()
memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
if (
self._opts.memory_limit_mb > 0
and memory_mb > self._opts.memory_limit_mb
):
logger.error(
"process exceeded memory limit, killing process",
extra={
"memory_usage_mb": memory_mb,
"memory_limit_mb": self._opts.memory_limit_mb,
**self.logging_extra(),
},
)
self._send_kill_signal()
elif (
self._opts.memory_warn_mb > 0
and memory_mb > self._opts.memory_warn_mb
):
logger.warning(
"process memory usage is high",
extra={
"memory_usage_mb": memory_mb,
"memory_warn_mb": self._opts.memory_warn_mb,
"memory_limit_mb": self._opts.memory_limit_mb,
**self.logging_extra(),
},
)
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
if self._closing or self._kill_sent:
return
logger.warning(
"Failed to get memory info for process",
extra=self.logging_extra(),
exc_info=e,
)
# don't bother rechecking if we cannot get process info
return
except Exception:
if self._closing or self._kill_sent:
return
logger.exception(
"Error in memory monitoring task",
extra=self.logging_extra(),
)
await asyncio.sleep(5) # check every 5 seconds
def logging_extra(self):
extra: dict[str, Any] = {
"pid": self.pid,
}
return extra
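# Illustrative sketch (an assumption, not part of the library): the two abstract hooks a
# concrete SupervisedProc subclass must provide. `_example_child_entry` is a hypothetical
# child target; the real implementations live in job_proc_executor.py and
# inference_proc_executor.py.
def _example_child_entry(cch: socket.socket, log_cch: socket.socket) -> None:
    # hypothetical child-process entrypoint; the real ones wire up the duplex channels
    # and run the proc_client main loop
    pass
class _ExampleProc(SupervisedProc):
    def _create_process(
        self, cch: socket.socket, log_cch: socket.socket
    ) -> mp.Process:
        # spawn the child with its end of the IPC socket pair and the log socket
        return self._mp_ctx.Process(target=_example_child_entry, args=(cch, log_cch))
    async def _main_task(self, ipc_ch: aio.ChanReceiver[channel.Message]) -> None:
        # consume messages forwarded by _read_ipc_task until the channel is closed
        async for msg in ipc_ch:
            logger.debug("received IPC message", extra={"msg": type(msg).__name__})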
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextvars
import functools
import multiprocessing as mp
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Callable, Coroutine, Tuple, Union
from livekit import api, rtc
from livekit.protocol import agent, models
from .ipc.inference_executor import InferenceExecutor
from .log import logger
_JobContextVar = contextvars.ContextVar["JobContext"]("agents_job_context")
def get_current_job_context() -> JobContext:
ctx = _JobContextVar.get(None)
if ctx is None:
raise RuntimeError(
"no job context found, are you running this code inside a job entrypoint?"
)
return ctx
@unique
class JobExecutorType(Enum):
PROCESS = "process"
THREAD = "thread"
class AutoSubscribe(str, Enum):
SUBSCRIBE_ALL = "subscribe_all"
SUBSCRIBE_NONE = "subscribe_none"
AUDIO_ONLY = "audio_only"
VIDEO_ONLY = "video_only"
@dataclass
class JobAcceptArguments:
name: str
identity: str
metadata: str
attributes: dict[str, str] | None = None
@dataclass
class RunningJobInfo:
accept_arguments: JobAcceptArguments
job: agent.Job
url: str
token: str
worker_id: str
DEFAULT_PARTICIPANT_KINDS: list[rtc.ParticipantKind.ValueType] = [
rtc.ParticipantKind.PARTICIPANT_KIND_SIP,
rtc.ParticipantKind.PARTICIPANT_KIND_STANDARD,
]
class JobContext:
def __init__(
self,
*,
proc: JobProcess,
info: RunningJobInfo,
room: rtc.Room,
on_connect: Callable[[], None],
on_shutdown: Callable[[str], None],
inference_executor: InferenceExecutor,
) -> None:
self._proc = proc
self._info = info
self._room = room
self._on_connect = on_connect
self._on_shutdown = on_shutdown
self._shutdown_callbacks: list[
Callable[[str], Coroutine[None, None, None]],
] = []
self._participant_entrypoints: list[
Tuple[
Callable[
[JobContext, rtc.RemoteParticipant], Coroutine[None, None, None]
],
list[rtc.ParticipantKind.ValueType] | rtc.ParticipantKind.ValueType,
]
] = []
self._participant_tasks = dict[Tuple[str, Callable], asyncio.Task[None]]()
self._room.on("participant_connected", self._participant_available)
self._inf_executor = inference_executor
@property
def inference_executor(self) -> InferenceExecutor:
return self._inf_executor
@functools.cached_property
def api(self) -> api.LiveKitAPI:
return api.LiveKitAPI()
@property
def proc(self) -> JobProcess:
"""Returns the process running the job. Useful for storing process-specific state."""
return self._proc
@property
def job(self) -> agent.Job:
"""Returns the current job that the worker is executing."""
return self._info.job
@property
def worker_id(self) -> str:
"""Returns the id of the worker."""
return self._info.worker_id
@property
def room(self) -> rtc.Room:
"""The Room object is the main interface that the worker should interact with.
When the entrypoint is called, the worker has not connected to the Room yet.
Certain properties of Room would not be available before calling JobContext.connect()
"""
return self._room
@property
def agent(self) -> rtc.LocalParticipant:
return self._room.local_participant
def add_shutdown_callback(
self,
callback: Union[
Callable[[], Coroutine[None, None, None]],
Callable[[str], Coroutine[None, None, None]],
],
) -> None:
"""
Add a callback to be called when the job is shutting down.
Optionally the callback can take a single argument, the shutdown reason.
"""
if callback.__code__.co_argcount > 0:
self._shutdown_callbacks.append(callback) # type: ignore
else:
async def wrapper(_: str) -> None:
await callback() # type: ignore
self._shutdown_callbacks.append(wrapper)
async def wait_for_participant(
self,
*,
identity: str | None = None,
kind: list[rtc.ParticipantKind.ValueType]
| rtc.ParticipantKind.ValueType = DEFAULT_PARTICIPANT_KINDS,
) -> rtc.RemoteParticipant:
"""
Returns a participant that matches the given identity. If identity is None, the first
participant that joins the room will be returned.
If the participant has already joined, the function will return immediately.
"""
if not self._room.isconnected():
raise RuntimeError("room is not connected")
fut = asyncio.Future[rtc.RemoteParticipant]()
def kind_match(p: rtc.RemoteParticipant) -> bool:
if isinstance(kind, list):
return p.kind in kind
return p.kind == kind
for p in self._room.remote_participants.values():
if (identity is None or p.identity == identity) and kind_match(p):
fut.set_result(p)
break
def _on_participant_connected(p: rtc.RemoteParticipant):
if (identity is None or p.identity == identity) and kind_match(p):
self._room.off("participant_connected", _on_participant_connected)
if not fut.done():
fut.set_result(p)
if not fut.done():
self._room.on("participant_connected", _on_participant_connected)
return await fut
async def connect(
self,
*,
e2ee: rtc.E2EEOptions | None = None,
auto_subscribe: AutoSubscribe = AutoSubscribe.SUBSCRIBE_ALL,
rtc_config: rtc.RtcConfiguration | None = None,
) -> None:
"""Connect to the room. This method should be called only once.
Args:
e2ee: End-to-end encryption options. If provided, the Agent will utilize end-to-end encryption. Note: clients will also need to handle E2EE.
auto_subscribe: Whether to automatically subscribe to tracks. Default is AutoSubscribe.SUBSCRIBE_ALL.
rtc_config: Custom RTC configuration to use when connecting to the room.
"""
room_options = rtc.RoomOptions(
e2ee=e2ee,
auto_subscribe=auto_subscribe == AutoSubscribe.SUBSCRIBE_ALL,
rtc_config=rtc_config,
)
await self._room.connect(self._info.url, self._info.token, options=room_options)
self._on_connect()
for p in self._room.remote_participants.values():
self._participant_available(p)
_apply_auto_subscribe_opts(self._room, auto_subscribe)
def shutdown(self, reason: str = "") -> None:
self._on_shutdown(reason)
def add_participant_entrypoint(
self,
entrypoint_fnc: Callable[
[JobContext, rtc.RemoteParticipant], Coroutine[None, None, None]
],
*_,
kind: list[rtc.ParticipantKind.ValueType]
| rtc.ParticipantKind.ValueType = DEFAULT_PARTICIPANT_KINDS,
):
"""Adds an entrypoint function to be run when a participant joins the room. In cases where
the participant has already joined, the entrypoint will be run immediately. Multiple unique entrypoints can be
added and they will each be run in parallel for each participant.
"""
if entrypoint_fnc in [e for (e, _) in self._participant_entrypoints]:
raise ValueError("entrypoints cannot be added more than once")
self._participant_entrypoints.append((entrypoint_fnc, kind))
def _participant_available(self, p: rtc.RemoteParticipant) -> None:
for coro, kind in self._participant_entrypoints:
if isinstance(kind, list):
if p.kind not in kind:
continue
else:
if p.kind != kind:
continue
if (p.identity, coro) in self._participant_tasks:
logger.warning(
f"a participant has joined before a prior participant task matching the same identity has finished: '{p.identity}'"
)
task_name = f"part-entry-{p.identity}-{coro.__name__}"
task = asyncio.create_task(coro(self, p), name=task_name)
self._participant_tasks[(p.identity, coro)] = task
task.add_done_callback(
lambda _: self._participant_tasks.pop((p.identity, coro))
)
def _apply_auto_subscribe_opts(room: rtc.Room, auto_subscribe: AutoSubscribe) -> None:
if auto_subscribe not in (AutoSubscribe.AUDIO_ONLY, AutoSubscribe.VIDEO_ONLY):
return
def _subscribe_if_needed(pub: rtc.RemoteTrackPublication):
if (
auto_subscribe == AutoSubscribe.AUDIO_ONLY
and pub.kind == rtc.TrackKind.KIND_AUDIO
) or (
auto_subscribe == AutoSubscribe.VIDEO_ONLY
and pub.kind == rtc.TrackKind.KIND_VIDEO
):
pub.set_subscribed(True)
for p in room.remote_participants.values():
for pub in p.track_publications.values():
_subscribe_if_needed(pub)
@room.on("track_published")
def on_track_published(pub: rtc.RemoteTrackPublication, _: rtc.RemoteParticipant):
_subscribe_if_needed(pub)
class JobProcess:
def __init__(
self,
*,
user_arguments: Any | None = None,
) -> None:
self._mp_proc = mp.current_process()
self._userdata: dict[str, Any] = {}
self._user_arguments = user_arguments
@property
def pid(self) -> int | None:
return self._mp_proc.pid
@property
def userdata(self) -> dict:
return self._userdata
@property
def user_arguments(self) -> Any | None:
return self._user_arguments
class JobRequest:
def __init__(
self,
*,
job: agent.Job,
on_reject: Callable[[], Coroutine[None, None, None]],
on_accept: Callable[[JobAcceptArguments], Coroutine[None, None, None]],
) -> None:
self._job = job
self._lock = asyncio.Lock()
self._on_reject = on_reject
self._on_accept = on_accept
@property
def id(self) -> str:
return self._job.id
@property
def job(self) -> agent.Job:
return self._job
@property
def room(self) -> models.Room:
return self._job.room
@property
def publisher(self) -> models.ParticipantInfo | None:
return self._job.participant
@property
def agent_name(self) -> str:
return self._job.agent_name
async def reject(self) -> None:
"""Reject the job request. The job may be assigned to another worker"""
await self._on_reject()
async def accept(
self,
*,
name: str = "",
identity: str = "",
metadata: str = "",
attributes: dict[str, str] | None = None,
) -> None:
"""Accept the job request, and start the job if the LiveKit SFU assigns the job to our worker."""
if not identity:
identity = "agent-" + self.id
accept_arguments = JobAcceptArguments(
name=name,
identity=identity,
metadata=metadata,
attributes=attributes,
)
await self._on_accept(accept_arguments)
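# Illustrative sketch (an assumption): a minimal job entrypoint built on the JobContext API
# above. The surrounding worker/CLI wiring (WorkerOptions, cli.run_app) is not shown.
async def _example_entrypoint(ctx: JobContext) -> None:
    async def _on_shutdown(reason: str) -> None:
        logger.info("job shutting down", extra={"reason": reason})
    ctx.add_shutdown_callback(_on_shutdown)
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()
    logger.info("participant joined", extra={"identity": participant.identity})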
from .chat_context import (
ChatAudio,
ChatContent,
ChatContext,
ChatImage,
ChatMessage,
ChatRole,
)
from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter
from .function_context import (
USE_DOCSTRING,
CalledFunction,
FunctionArgInfo,
FunctionCallInfo,
FunctionContext,
FunctionInfo,
TypeInfo,
_create_ai_function_info,
ai_callable,
)
from .llm import (
LLM,
ChatChunk,
Choice,
ChoiceDelta,
CompletionUsage,
LLMCapabilities,
LLMStream,
ToolChoice,
)
__all__ = [
"LLM",
"LLMStream",
"ChatContext",
"ChatRole",
"ChatMessage",
"ChatAudio",
"ChatImage",
"ChatContent",
"ChatContext",
"ChoiceDelta",
"Choice",
"ChatChunk",
"CompletionUsage",
"FunctionContext",
"ai_callable",
"TypeInfo",
"FunctionArgInfo",
"FunctionInfo",
"FunctionCallInfo",
"CalledFunction",
"USE_DOCSTRING",
"LLMCapabilities",
"FallbackAdapter",
"AvailabilityChangedEvent",
"ToolChoice",
"_create_ai_function_info",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal, Union
from livekit import rtc
from livekit.agents import utils
from . import function_context
ChatRole = Literal["system", "user", "assistant", "tool"]
@dataclass
class ChatImage:
"""
ChatImage is used to input images into the ChatContext on supported LLM providers / plugins.
You may need to consult your LLM provider's documentation on supported URL types.
```python
# Pass a VideoFrame directly, which will be automatically converted to a JPEG data URL internally
async for event in rtc.VideoStream(video_track):
chat_image = ChatImage(image=event.frame)
# this instance is now available for your ChatContext
# Encode your VideoFrame yourself for more control, and pass the result as a data URL (see EncodeOptions for more details)
    import base64
    from livekit.agents.utils.images import encode, EncodeOptions, ResizeOptions
image_bytes = encode(
event.frame,
EncodeOptions(
format="PNG",
resize_options=ResizeOptions(
width=512, height=512, strategy="scale_aspect_fit"
),
),
)
chat_image = ChatImage(
image=f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
)
# With an external URL
chat_image = ChatImage(image="https://example.com/image.jpg")
```
"""
image: str | rtc.VideoFrame
"""
Either a string URL or a VideoFrame object
"""
inference_width: int | None = None
"""
Resizing parameter for rtc.VideoFrame inputs (ignored for URL images)
"""
inference_height: int | None = None
"""
Resizing parameter for rtc.VideoFrame inputs (ignored for URL images)
"""
inference_detail: Literal["auto", "high", "low"] = "auto"
"""
Detail parameter for LLM provider, if supported.
Currently only supported by OpenAI (see https://platform.openai.com/docs/guides/vision?lang=node#low-or-high-fidelity-image-understanding)
"""
_cache: dict[Any, Any] = field(default_factory=dict, repr=False, init=False)
"""
_cache is used internally by LLM implementations to store a processed version of the image
for later use.
"""
@dataclass
class ChatAudio:
frame: rtc.AudioFrame | list[rtc.AudioFrame]
ChatContent = Union[str, ChatImage, ChatAudio]
@dataclass
class ChatMessage:
role: ChatRole
id: str = field(
default_factory=lambda: utils.shortuuid("item_")
) # used by the OAI realtime API
name: str | None = None
content: ChatContent | list[ChatContent] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
tool_call_id: str | None = None
tool_exception: Exception | None = None
_metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False)
@staticmethod
def create_tool_from_called_function(
called_function: function_context.CalledFunction,
) -> "ChatMessage":
if not called_function.task.done():
raise ValueError("cannot create a tool result from a running ai function")
tool_exception: Exception | None = None
try:
content = called_function.task.result()
except BaseException as e:
if isinstance(e, Exception):
tool_exception = e
content = f"Error: {e}"
return ChatMessage(
role="tool",
name=called_function.call_info.function_info.name,
content=content,
tool_call_id=called_function.call_info.tool_call_id,
tool_exception=tool_exception,
)
@staticmethod
def create_tool_calls(
called_functions: list[function_context.FunctionCallInfo],
*,
text: str = "",
) -> "ChatMessage":
return ChatMessage(role="assistant", tool_calls=called_functions, content=text)
@staticmethod
def create(
*,
text: str = "",
images: list[ChatImage] = [],
role: ChatRole = "system",
id: str | None = None,
) -> "ChatMessage":
id = id or utils.shortuuid("item_")
if len(images) == 0:
return ChatMessage(role=role, content=text, id=id)
else:
content: list[ChatContent] = []
if text:
content.append(text)
if len(images) > 0:
content.extend(images)
return ChatMessage(role=role, content=content, id=id)
def copy(self):
content = self.content
if isinstance(content, list):
content = content.copy()
tool_calls = self.tool_calls
if tool_calls is not None:
tool_calls = tool_calls.copy()
copied_msg = ChatMessage(
role=self.role,
id=self.id,
name=self.name,
content=content,
tool_calls=tool_calls,
tool_call_id=self.tool_call_id,
)
copied_msg._metadata = self._metadata
return copied_msg
@dataclass
class ChatContext:
messages: list[ChatMessage] = field(default_factory=list)
_metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False)
def append(
self, *, text: str = "", images: list[ChatImage] = [], role: ChatRole = "system"
) -> ChatContext:
self.messages.append(ChatMessage.create(text=text, images=images, role=role))
return self
def copy(self) -> ChatContext:
copied_chat_ctx = ChatContext(messages=[m.copy() for m in self.messages])
copied_chat_ctx._metadata = self._metadata
return copied_chat_ctx
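# Illustrative sketch (an assumption): composing a ChatContext with the helpers above.
# The image URL is a placeholder.
def _example_chat_ctx() -> ChatContext:
    ctx = ChatContext().append(text="You are a helpful voice assistant.", role="system")
    ctx.append(
        text="What is in this picture?",
        images=[ChatImage(image="https://example.com/image.jpg")],
        role="user",
    )
    return ctx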
from __future__ import annotations
import asyncio
import dataclasses
import time
from dataclasses import dataclass
from typing import AsyncIterable, Literal, Optional, Union
from livekit.agents._exceptions import APIConnectionError, APIError
from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .chat_context import ChatContext
from .function_context import CalledFunction, FunctionCallInfo, FunctionContext
from .llm import LLM, ChatChunk, LLMCapabilities, LLMStream, ToolChoice
DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
@dataclass
class _LLMStatus:
available: bool
recovering_task: asyncio.Task | None
@dataclass
class AvailabilityChangedEvent:
llm: LLM
available: bool
class FallbackAdapter(
LLM[Literal["llm_availability_changed"]],
):
def __init__(
self,
llm: list[LLM],
*,
attempt_timeout: float = 10.0,
max_retry_per_llm: int = 1,
retry_interval: float = 5,
) -> None:
if len(llm) < 1:
raise ValueError("at least one LLM instance must be provided.")
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=all(
t.capabilities.supports_choices_on_int for t in llm
),
requires_persistent_functions=all(
t.capabilities.requires_persistent_functions for t in llm
),
)
)
self._llm_instances = llm
self._attempt_timeout = attempt_timeout
self._max_retry_per_llm = max_retry_per_llm
self._retry_interval = retry_interval
self._status = [
_LLMStatus(available=True, recovering_task=None)
for _ in self._llm_instances
]
def chat(
self,
*,
chat_ctx: ChatContext,
conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
fnc_ctx: FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
return FallbackLLMStream(
llm=self,
conn_options=conn_options,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
temperature=temperature,
n=n,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
class FallbackLLMStream(LLMStream):
def __init__(
self,
*,
llm: FallbackAdapter,
conn_options: APIConnectOptions,
chat_ctx: ChatContext,
fnc_ctx: FunctionContext | None,
temperature: float | None,
n: int | None,
parallel_tool_calls: bool | None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._fallback_adapter = llm
self._temperature = temperature
self._n = n
self._parallel_tool_calls = parallel_tool_calls
self._tool_choice = tool_choice
self._current_stream: Optional[LLMStream] = None
@property
def function_calls(self) -> list[FunctionCallInfo]:
if self._current_stream is None:
return []
return self._current_stream.function_calls
@property
def chat_ctx(self) -> ChatContext:
if self._current_stream is None:
return self._chat_ctx
return self._current_stream.chat_ctx
@property
def fnc_ctx(self) -> FunctionContext | None:
if self._current_stream is None:
return self._fnc_ctx
return self._current_stream.fnc_ctx
def execute_functions(self) -> list[CalledFunction]:
        # this method is currently unused; it is provided for completeness
if self._current_stream is None:
return []
return self._current_stream.execute_functions()
async def _try_generate(
self, *, llm: LLM, check_recovery: bool = False
) -> AsyncIterable[ChatChunk]:
"""
Try to generate with the given LLM.
Args:
llm: The LLM instance to generate with
check_recovery: When True, indicates this is a background recovery check and the
result will not be used. Recovery checks verify if a previously
failed LLM has become available again.
"""
try:
async with llm.chat(
chat_ctx=self._chat_ctx,
fnc_ctx=self._fnc_ctx,
temperature=self._temperature,
n=self._n,
parallel_tool_calls=self._parallel_tool_calls,
tool_choice=self._tool_choice,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=self._fallback_adapter._max_retry_per_llm,
timeout=self._fallback_adapter._attempt_timeout,
retry_interval=self._fallback_adapter._retry_interval,
),
) as stream:
should_set_current = not check_recovery
async for chunk in stream:
if should_set_current:
should_set_current = False
self._current_stream = stream
yield chunk
except asyncio.TimeoutError:
if check_recovery:
logger.warning(f"{llm.label} recovery timed out")
raise
logger.warning(
f"{llm.label} timed out, switching to next LLM",
)
raise
except APIError as e:
if check_recovery:
logger.warning(
f"{llm.label} recovery failed",
exc_info=e,
)
raise
logger.warning(
f"{llm.label} failed, switching to next LLM",
exc_info=e,
)
raise
except Exception:
if check_recovery:
logger.exception(
f"{llm.label} recovery unexpected error",
)
raise
logger.exception(
f"{llm.label} unexpected error, switching to next LLM",
)
raise
def _try_recovery(self, llm: LLM) -> None:
llm_status = self._fallback_adapter._status[
self._fallback_adapter._llm_instances.index(llm)
]
if llm_status.recovering_task is None or llm_status.recovering_task.done():
async def _recover_llm_task(llm: LLM) -> None:
try:
async for _ in self._try_generate(llm=llm, check_recovery=True):
pass
llm_status.available = True
logger.info(f"llm.FallbackAdapter, {llm.label} recovered")
self._fallback_adapter.emit(
"llm_availability_changed",
AvailabilityChangedEvent(llm=llm, available=True),
)
except Exception:
return
llm_status.recovering_task = asyncio.create_task(_recover_llm_task(llm))
async def _run(self) -> None:
start_time = time.time()
all_failed = all(
not llm_status.available for llm_status in self._fallback_adapter._status
)
if all_failed:
logger.error("all LLMs are unavailable, retrying..")
for i, llm in enumerate(self._fallback_adapter._llm_instances):
llm_status = self._fallback_adapter._status[i]
if llm_status.available or all_failed:
chunk_sent = False
try:
async for result in self._try_generate(
llm=llm, check_recovery=False
):
chunk_sent = True
self._event_ch.send_nowait(result)
return
                except Exception:  # exceptions already logged inside _try_generate
if llm_status.available:
llm_status.available = False
self._fallback_adapter.emit(
"llm_availability_changed",
AvailabilityChangedEvent(llm=llm, available=False),
)
if chunk_sent:
raise
self._try_recovery(llm)
raise APIConnectionError(
"all LLMs failed (%s) after %s seconds"
% (
[llm.label for llm in self._fallback_adapter._llm_instances],
time.time() - start_time,
)
)
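# Illustrative sketch (an assumption): wiring the FallbackAdapter above. `primary` and
# `backup` stand for any concrete LLM implementations (e.g. from a plugin); they are not
# defined in this module.
def _example_fallback(primary: LLM, backup: LLM) -> FallbackAdapter:
    adapter = FallbackAdapter([primary, backup], attempt_timeout=10.0, retry_interval=5)
    def _on_availability_changed(ev: AvailabilityChangedEvent) -> None:
        logger.info(f"{ev.llm.label} availability changed: available={ev.available}")
    adapter.on("llm_availability_changed", _on_availability_changed)
    return adapter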
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import enum
import functools
import inspect
import json
import types
import typing
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
from ..log import logger
class _UseDocMarker:
pass
METADATA_ATTR = "__livekit_ai_metadata__"
USE_DOCSTRING = _UseDocMarker()
@dataclass(frozen=True, init=False)
class TypeInfo:
description: str
choices: tuple
def __init__(self, description: str, choices: tuple | list[Any] = tuple()) -> None:
object.__setattr__(self, "description", description)
if isinstance(choices, list):
choices = tuple(choices)
object.__setattr__(self, "choices", choices)
@dataclass(frozen=True)
class FunctionArgInfo:
name: str
description: str
type: type
default: Any
choices: tuple | None
@dataclass(frozen=True)
class FunctionInfo:
name: str
description: str
auto_retry: bool
callable: Callable
arguments: dict[str, FunctionArgInfo]
@dataclass(frozen=True)
class FunctionCallInfo:
tool_call_id: str
function_info: FunctionInfo
raw_arguments: str
arguments: dict[str, Any]
def execute(self) -> CalledFunction:
function_info = self.function_info
func = functools.partial(function_info.callable, **self.arguments)
if asyncio.iscoroutinefunction(function_info.callable):
task = asyncio.create_task(func())
else:
task = asyncio.create_task(asyncio.to_thread(func))
called_fnc = CalledFunction(call_info=self, task=task)
def _on_done(fut):
try:
called_fnc.result = fut.result()
except BaseException as e:
called_fnc.exception = e
task.add_done_callback(_on_done)
return called_fnc
@dataclass
class CalledFunction:
call_info: FunctionCallInfo
task: asyncio.Task[Any]
result: Any | None = None
exception: BaseException | None = None
def ai_callable(
*,
name: str | None = None,
description: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = False,
) -> Callable:
def deco(f):
_set_metadata(f, name=name, desc=description, auto_retry=auto_retry)
return f
return deco
class FunctionContext:
def __init__(self) -> None:
self._fncs = dict[str, FunctionInfo]()
for _, member in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(member, METADATA_ATTR):
self._register_ai_function(member)
def ai_callable(
self,
*,
name: str | None = None,
description: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = True,
) -> Callable:
        def deco(f):
            _set_metadata(f, name=name, desc=description, auto_retry=auto_retry)
            self._register_ai_function(f)
            return f  # return the original function so it remains usable after decoration
        return deco
def _register_ai_function(self, fnc: Callable) -> None:
if not hasattr(fnc, METADATA_ATTR):
logger.warning(f"function {fnc.__name__} does not have ai metadata")
return
metadata: _AIFncMetadata = getattr(fnc, METADATA_ATTR)
fnc_name = metadata.name
if fnc_name in self._fncs:
raise ValueError(f"duplicate ai_callable name: {fnc_name}")
sig = inspect.signature(fnc)
        # get_type_hints with include_extras=True is needed when using Annotated
        # (typing.get_args on the bare parameter annotation returns an empty tuple otherwise)
type_hints = typing.get_type_hints(
fnc, include_extras=True
) # Annotated[T, ...] -> T
args = dict[str, FunctionArgInfo]()
for name, param in sig.parameters.items():
if param.kind not in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
):
raise ValueError(f"{fnc_name}: unsupported parameter kind {param.kind}")
inner_th, type_info = _extract_types(type_hints[name])
if not is_type_supported(inner_th):
raise ValueError(
f"{fnc_name}: unsupported type {inner_th} for parameter {name}"
)
desc = type_info.description if type_info else ""
choices = type_info.choices if type_info else ()
if (
isinstance(inner_th, type)
and issubclass(inner_th, enum.Enum)
and not choices
):
# the enum must be a str or int (and at least one value)
# this is verified by is_type_supported
choices = tuple([item.value for item in inner_th])
inner_th = type(choices[0])
args[name] = FunctionArgInfo(
name=name,
description=desc,
type=inner_th,
default=param.default,
choices=choices,
)
self._fncs[metadata.name] = FunctionInfo(
name=metadata.name,
description=metadata.description,
auto_retry=metadata.auto_retry,
callable=fnc,
arguments=args,
)
@property
def ai_functions(self) -> dict[str, FunctionInfo]:
return self._fncs
@dataclass(frozen=True)
class _AIFncMetadata:
name: str
description: str
auto_retry: bool
def _extract_types(annotation: type) -> tuple[type, TypeInfo | None]:
"""Return inner_type, TypeInfo"""
if typing.get_origin(annotation) is not typing.Annotated:
# email: Annotated[
# Optional[str], TypeInfo(description="The user address email")
# ] = None,
#
# An argument like the above will return us:
# `typing.Optional[typing.Annotated[typing.Optional[str], TypeInfo(description='The user address email', choices=())]]`
# So we ignore the first typing.Optional
is_optional, optional_inner = _is_optional_type(annotation)
if is_optional:
inner_type, info = _extract_types(optional_inner)
return Optional[inner_type], info # type: ignore
return annotation, None
# assume the first argument is always the inner type the LLM will use
args = typing.get_args(annotation)
if len(args) < 2:
return args[0], None
for a in args:
if isinstance(a, TypeInfo):
return args[0], a
return args[0], None
def _set_metadata(
f: Callable,
name: str | None = None,
desc: str | _UseDocMarker = USE_DOCSTRING,
auto_retry: bool = False,
) -> None:
if isinstance(desc, _UseDocMarker):
docstring = inspect.getdoc(f)
if docstring is None:
raise ValueError(
f"missing docstring for function {f.__name__}, "
"use explicit description or provide docstring"
)
desc = docstring
metadata = _AIFncMetadata(
name=name or f.__name__, description=desc, auto_retry=auto_retry
)
setattr(f, METADATA_ATTR, metadata)
def is_type_supported(t: type) -> bool:
if t in (str, int, float, bool):
return True
if typing.get_origin(t) is list:
in_type = typing.get_args(t)[0]
return is_type_supported(in_type)
is_optional, ty = _is_optional_type(t)
if is_optional:
return is_type_supported(ty)
if issubclass(t, enum.Enum):
initial_type = None
for e in t:
if initial_type is None:
initial_type = type(e.value)
if type(e.value) is not initial_type:
return False
return initial_type in (str, int)
return False
def _is_optional_type(typ) -> Tuple[bool, Any]:
"""return is_optional, inner_type"""
origin = typing.get_origin(typ)
if origin is None or origin is list:
return False, typ
if origin in {typing.Union, getattr(types, "UnionType", typing.Union)}:
args = typing.get_args(typ)
is_optional = type(None) in args
non_none_args = [a for a in args if a is not type(None)]
if is_optional and len(non_none_args) == 1:
# Exactly one non-None type + None means optional
return True, non_none_args[0]
return False, None
def _create_ai_function_info(
fnc_ctx: FunctionContext,
tool_call_id: str,
fnc_name: str,
raw_arguments: str, # JSON string
) -> FunctionCallInfo:
if fnc_name not in fnc_ctx.ai_functions:
raise ValueError(f"AI function {fnc_name} not found")
parsed_arguments: dict[str, Any] = {}
try:
if raw_arguments: # ignore empty string
parsed_arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
raise ValueError(
f"AI function {fnc_name} received invalid JSON arguments - {raw_arguments}"
)
fnc_info = fnc_ctx.ai_functions[fnc_name]
# Ensure all necessary arguments are present and of the correct type.
sanitized_arguments: dict[str, Any] = {}
for arg_info in fnc_info.arguments.values():
if arg_info.name not in parsed_arguments:
if arg_info.default is inspect.Parameter.empty:
raise ValueError(
f"AI function {fnc_name} missing required argument {arg_info.name}"
)
continue
arg_value = parsed_arguments[arg_info.name]
is_optional, inner_th = _is_optional_type(arg_info.type)
if typing.get_origin(inner_th) is not None:
if not isinstance(arg_value, list):
raise ValueError(
f"AI function {fnc_name} argument {arg_info.name} should be a list"
)
inner_type = typing.get_args(inner_th)[0]
sanitized_value = [
_sanitize_primitive(
value=v,
expected_type=inner_type,
choices=arg_info.choices,
)
for v in arg_value
]
else:
sanitized_value = _sanitize_primitive(
value=arg_value,
expected_type=inner_th,
choices=arg_info.choices,
)
sanitized_arguments[arg_info.name] = sanitized_value
return FunctionCallInfo(
tool_call_id=tool_call_id,
raw_arguments=raw_arguments,
function_info=fnc_info,
arguments=sanitized_arguments,
)
def _sanitize_primitive(
*, value: Any, expected_type: type, choices: tuple | None
) -> Any:
if expected_type is str:
if not isinstance(value, str):
raise ValueError(f"expected str, got {type(value)}")
elif expected_type in (int, float):
if not isinstance(value, (int, float)):
raise ValueError(f"expected number, got {type(value)}")
if expected_type is int:
if value % 1 != 0:
raise ValueError("expected int, got float")
value = int(value)
elif expected_type is float:
value = float(value)
elif expected_type is bool:
if not isinstance(value, bool):
raise ValueError(f"expected bool, got {type(value)}")
if choices and value not in choices:
raise ValueError(f"invalid value {value}, not in {choices}")
return value
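# Illustrative sketch (an assumption): declaring an AI-callable function with the decorator
# and the Annotated/TypeInfo machinery defined above. The weather lookup itself is a
# placeholder.
class _ExampleFunctions(FunctionContext):
    @ai_callable(description="Get the current weather for a location")
    async def get_weather(
        self,
        location: typing.Annotated[str, TypeInfo(description="The city to look up")],
        unit: typing.Annotated[
            str, TypeInfo(description="Temperature unit", choices=["celsius", "fahrenheit"])
        ] = "celsius",
    ) -> str:
        # a real implementation would call a weather API here
        return f"It is sunny in {location} ({unit})"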
from __future__ import annotations
import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from types import TracebackType
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Generic,
Literal,
TypeVar,
Union,
)
from livekit import rtc
from livekit.agents._exceptions import APIConnectionError, APIError
from .. import utils
from ..log import logger
from ..metrics import LLMMetrics
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ..utils import aio
from . import function_context
from .chat_context import ChatContext, ChatRole
@dataclass
class ChoiceDelta:
role: ChatRole
content: str | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
@dataclass
class CompletionUsage:
completion_tokens: int
prompt_tokens: int
total_tokens: int
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
@dataclass
class Choice:
delta: ChoiceDelta
index: int = 0
@dataclass
class LLMCapabilities:
supports_choices_on_int: bool = True
"""check whether the LLM supports integer enums choices as function arguments"""
requires_persistent_functions: bool = False
"""if the LLM requires function definition when previous function calls exist in chat context"""
@dataclass
class ChatChunk:
request_id: str
choices: list[Choice] = field(default_factory=list)
usage: CompletionUsage | None = None
@dataclass
class ToolChoice:
type: Literal["function"]
name: str
TEvent = TypeVar("TEvent")
class LLM(
ABC,
rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
Generic[TEvent],
):
def __init__(self, *, capabilities: LLMCapabilities | None = None) -> None:
super().__init__()
if capabilities is None:
capabilities = LLMCapabilities()
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"
@property
def label(self) -> str:
return self._label
@abstractmethod
def chat(
self,
*,
chat_ctx: ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: function_context.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream": ...
@property
def capabilities(self) -> LLMCapabilities:
return self._capabilities
async def aclose(self) -> None: ...
async def __aenter__(self) -> LLM:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
class LLMStream(ABC):
def __init__(
self,
llm: LLM,
*,
chat_ctx: ChatContext,
fnc_ctx: function_context.FunctionContext | None,
conn_options: APIConnectOptions,
) -> None:
self._llm = llm
self._chat_ctx = chat_ctx
self._fnc_ctx = fnc_ctx
self._conn_options = conn_options
self._event_ch = aio.Chan[ChatChunk]()
self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(monitor_aiter), name="LLM._metrics_task"
)
self._task = asyncio.create_task(self._main_task())
self._task.add_done_callback(lambda _: self._event_ch.close())
self._function_calls_info: list[function_context.FunctionCallInfo] = []
self._function_tasks = set[asyncio.Task[Any]]()
@abstractmethod
async def _run(self) -> None: ...
async def _main_task(self) -> None:
for i in range(self._conn_options.max_retry + 1):
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0 or not e.retryable:
raise
elif i == self._conn_options.max_retry:
raise APIConnectionError(
f"failed to generate LLM completion after {self._conn_options.max_retry + 1} attempts",
) from e
else:
logger.warning(
f"failed to generate LLM completion, retrying in {retry_interval}s",
exc_info=e,
extra={
"llm": self._llm._label,
"attempt": i + 1,
},
)
await asyncio.sleep(retry_interval)
@utils.log_exceptions(logger=logger)
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[ChatChunk]
) -> None:
start_time = time.perf_counter()
ttft = -1.0
request_id = ""
usage: CompletionUsage | None = None
async for ev in event_aiter:
request_id = ev.request_id
if ttft == -1.0:
ttft = time.perf_counter() - start_time
if ev.usage is not None:
usage = ev.usage
duration = time.perf_counter() - start_time
metrics = LLMMetrics(
timestamp=time.time(),
request_id=request_id,
ttft=ttft,
duration=duration,
cancelled=self._task.cancelled(),
label=self._llm._label,
completion_tokens=usage.completion_tokens if usage else 0,
prompt_tokens=usage.prompt_tokens if usage else 0,
total_tokens=usage.total_tokens if usage else 0,
tokens_per_second=usage.completion_tokens / duration if usage else 0.0,
error=None,
)
self._llm.emit("metrics_collected", metrics)
@property
def function_calls(self) -> list[function_context.FunctionCallInfo]:
"""List of called functions from this stream."""
return self._function_calls_info
@property
def chat_ctx(self) -> ChatContext:
"""The initial chat context of this stream."""
return self._chat_ctx
@property
def fnc_ctx(self) -> function_context.FunctionContext | None:
"""The function context of this stream."""
return self._fnc_ctx
def execute_functions(self) -> list[function_context.CalledFunction]:
"""Execute all functions concurrently of this stream."""
called_functions: list[function_context.CalledFunction] = []
for fnc_info in self._function_calls_info:
called_fnc = fnc_info.execute()
self._function_tasks.add(called_fnc.task)
called_fnc.task.add_done_callback(self._function_tasks.remove)
called_functions.append(called_fnc)
return called_functions
async def aclose(self) -> None:
await aio.gracefully_cancel(self._task)
await utils.aio.gracefully_cancel(*self._function_tasks)
await self._metrics_task
async def __anext__(self) -> ChatChunk:
try:
val = await self._event_aiter.__anext__()
except StopAsyncIteration:
if not self._task.cancelled() and (exc := self._task.exception()):
raise exc from None
raise StopAsyncIteration
return val
def __aiter__(self) -> AsyncIterator[ChatChunk]:
return self
async def __aenter__(self) -> LLMStream:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
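# Illustrative sketch (an assumption): collecting the streamed text from an LLMStream
# produced by LLM.chat(). `some_llm` stands for any concrete LLM implementation.
async def _example_collect_text(some_llm: LLM, chat_ctx: ChatContext) -> str:
    collected = ""
    async with some_llm.chat(chat_ctx=chat_ctx) as stream:
        async for chunk in stream:
            for choice in chunk.choices:
                if choice.delta.content:
                    collected += choice.delta.content
    return collected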
import logging
DEV_LEVEL = 23
logging.addLevelName(DEV_LEVEL, "DEV")
logger = logging.getLogger("livekit.agents")
from .base import (
AgentMetrics,
LLMMetrics,
MultimodalLLMError,
MultimodalLLMMetrics,
PipelineEOUMetrics,
PipelineLLMMetrics,
PipelineSTTMetrics,
PipelineTTSMetrics,
PipelineVADMetrics,
STTMetrics,
TTSMetrics,
VADMetrics,
)
from .usage_collector import UsageCollector, UsageSummary
from .utils import log_metrics
__all__ = [
"LLMMetrics",
"MultimodalLLMError",
"MultimodalLLMMetrics",
"AgentMetrics",
"PipelineEOUMetrics",
"PipelineSTTMetrics",
"PipelineTTSMetrics",
"PipelineVADMetrics",
"PipelineLLMMetrics",
"VADMetrics",
"STTMetrics",
"TTSMetrics",
"UsageSummary",
"UsageCollector",
"log_metrics",
]
from __future__ import annotations
from dataclasses import dataclass
from typing import Union
@dataclass
class Error:
pass
@dataclass
class LLMMetrics:
request_id: str
timestamp: float
ttft: float
duration: float
label: str
cancelled: bool
completion_tokens: int
prompt_tokens: int
total_tokens: int
tokens_per_second: float
error: Error | None
@dataclass
class STTMetrics:
request_id: str
timestamp: float
duration: float
label: str
audio_duration: float
streamed: bool
error: Error | None
@dataclass
class TTSMetrics:
request_id: str
timestamp: float
ttfb: float
duration: float
audio_duration: float
cancelled: bool
characters_count: int
label: str
streamed: bool
error: Error | None
@dataclass
class VADMetrics:
timestamp: float
idle_time: float
inference_duration_total: float
inference_count: int
label: str
@dataclass
class PipelineSTTMetrics(STTMetrics):
pass
@dataclass
class PipelineEOUMetrics:
sequence_id: str
"""Unique identifier shared across different metrics to combine related STT, LLM, and TTS metrics."""
timestamp: float
"""Timestamp of when the event was recorded."""
end_of_utterance_delay: float
"""Amount of time between the end of speech from VAD and the decision to end the user's turn."""
transcription_delay: float
"""Time taken to obtain the transcript after the end of the user's speech.
May be 0 if the transcript was already available.
"""
@dataclass
class PipelineLLMMetrics(LLMMetrics):
sequence_id: str
"""Unique identifier shared across different metrics to combine related STT, LLM, and TTS metrics."""
@dataclass
class PipelineTTSMetrics(TTSMetrics):
sequence_id: str
"""Unique identifier shared across different metrics to combine related STT, LLM, and TTS metrics."""
@dataclass
class PipelineVADMetrics(VADMetrics):
pass
@dataclass
class MultimodalLLMError(Error):
type: str | None
reason: str | None = None
code: str | None = None
message: str | None = None
@dataclass
class MultimodalLLMMetrics(LLMMetrics):
@dataclass
class CachedTokenDetails:
text_tokens: int
audio_tokens: int
@dataclass
class InputTokenDetails:
cached_tokens: int
text_tokens: int
audio_tokens: int
cached_tokens_details: MultimodalLLMMetrics.CachedTokenDetails
@dataclass
class OutputTokenDetails:
text_tokens: int
audio_tokens: int
input_token_details: InputTokenDetails
output_token_details: OutputTokenDetails
AgentMetrics = Union[
STTMetrics,
LLMMetrics,
TTSMetrics,
VADMetrics,
PipelineSTTMetrics,
PipelineEOUMetrics,
PipelineLLMMetrics,
PipelineTTSMetrics,
PipelineVADMetrics,
MultimodalLLMMetrics,
]
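# Illustrative sketch (an assumption): AgentMetrics is a plain Union, so consumers narrow it
# with isinstance checks (log_metrics in metrics/utils.py does the same).
def _example_prompt_tokens(metrics: AgentMetrics) -> int:
    if isinstance(metrics, LLMMetrics):
        return metrics.prompt_tokens
    return 0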
from copy import deepcopy
from dataclasses import dataclass
from .base import AgentMetrics, LLMMetrics, STTMetrics, TTSMetrics
@dataclass
class UsageSummary:
llm_prompt_tokens: int
llm_completion_tokens: int
tts_characters_count: int
stt_audio_duration: float
class UsageCollector:
def __init__(self) -> None:
self._summary = UsageSummary(0, 0, 0, 0.0)
def __call__(self, metrics: AgentMetrics) -> None:
self.collect(metrics)
def collect(self, metrics: AgentMetrics) -> None:
if isinstance(metrics, LLMMetrics):
self._summary.llm_prompt_tokens += metrics.prompt_tokens
self._summary.llm_completion_tokens += metrics.completion_tokens
elif isinstance(metrics, TTSMetrics):
self._summary.tts_characters_count += metrics.characters_count
elif isinstance(metrics, STTMetrics):
self._summary.stt_audio_duration += metrics.audio_duration
def get_summary(self) -> UsageSummary:
return deepcopy(self._summary)
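# Illustrative sketch (an assumption): feeding "metrics_collected" events into the
# UsageCollector above. `agent` stands for any emitter of such events (e.g. a pipeline or
# multimodal agent); the exact emitter type is outside this module.
def _example_collect_usage(agent) -> UsageCollector:
    collector = UsageCollector()
    agent.on("metrics_collected", collector)  # UsageCollector instances are callable
    return collector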
from __future__ import annotations
import logging
from ..log import logger as default_logger
from .base import (
AgentMetrics,
LLMMetrics,
PipelineEOUMetrics,
PipelineLLMMetrics,
PipelineSTTMetrics,
PipelineTTSMetrics,
STTMetrics,
TTSMetrics,
)
def log_metrics(metrics: AgentMetrics, *, logger: logging.Logger | None = None):
if logger is None:
logger = default_logger
if isinstance(metrics, PipelineLLMMetrics):
logger.info(
f"Pipeline LLM metrics: sequence_id={metrics.sequence_id}, ttft={metrics.ttft:.2f}, input_tokens={metrics.prompt_tokens}, output_tokens={metrics.completion_tokens}, tokens_per_second={metrics.tokens_per_second:.2f}"
)
elif isinstance(metrics, LLMMetrics):
logger.info(
f"LLM metrics: ttft={metrics.ttft:.2f}, input_tokens={metrics.prompt_tokens}, output_tokens={metrics.completion_tokens}, tokens_per_second={metrics.tokens_per_second:.2f}"
)
elif isinstance(metrics, PipelineTTSMetrics):
logger.info(
f"Pipeline TTS metrics: sequence_id={metrics.sequence_id}, ttfb={metrics.ttfb}, audio_duration={metrics.audio_duration:.2f}"
)
elif isinstance(metrics, TTSMetrics):
logger.info(
f"TTS metrics: ttfb={metrics.ttfb}, audio_duration={metrics.audio_duration:.2f}"
)
elif isinstance(metrics, PipelineEOUMetrics):
logger.info(
f"Pipeline EOU metrics: sequence_id={metrics.sequence_id}, end_of_utterance_delay={metrics.end_of_utterance_delay:.2f}, transcription_delay={metrics.transcription_delay:.2f}"
)
elif isinstance(metrics, PipelineSTTMetrics):
logger.info(
f"Pipeline STT metrics: duration={metrics.duration:.2f}, audio_duration={metrics.audio_duration:.2f}"
)
elif isinstance(metrics, STTMetrics):
logger.info(f"STT metrics: audio_duration={metrics.audio_duration:.2f}")
from .multimodal_agent import (
AgentTranscriptionOptions,
MultimodalAgent,
_RealtimeAPI,
_RealtimeAPISession,
)
__all__ = [
"MultimodalAgent",
"AgentTranscriptionOptions",
"_RealtimeAPI",
"_RealtimeAPISession",
]
from __future__ import annotations
import asyncio
from typing import AsyncIterable, Literal
from livekit import rtc
from livekit.agents import transcription, utils
from ..log import logger
EventTypes = Literal["playout_started", "playout_stopped"]
class PlayoutHandle:
def __init__(
self,
*,
audio_source: rtc.AudioSource,
item_id: str,
content_index: int,
transcription_fwd: transcription.TTSSegmentsForwarder,
) -> None:
self._audio_source = audio_source
self._tr_fwd = transcription_fwd
self._item_id = item_id
self._content_index = content_index
self._int_fut = asyncio.Future[None]()
self._done_fut = asyncio.Future[None]()
self._interrupted = False
self._pushed_duration = 0.0
self._total_played_time: float | None = None # set when the playout is done
@property
def item_id(self) -> str:
return self._item_id
@property
def audio_samples(self) -> int:
if self._total_played_time is not None:
return int(self._total_played_time * 24000)
return int((self._pushed_duration - self._audio_source.queued_duration) * 24000)
@property
def text_chars(self) -> int:
return len(self._tr_fwd.played_text)
@property
def content_index(self) -> int:
return self._content_index
@property
def interrupted(self) -> bool:
return self._interrupted
def done(self) -> bool:
return self._done_fut.done() or self._interrupted
def interrupt(self) -> None:
if self.done():
return
self._int_fut.set_result(None)
self._interrupted = True
class AgentPlayout(utils.EventEmitter[EventTypes]):
def __init__(self, *, audio_source: rtc.AudioSource) -> None:
super().__init__()
self._source = audio_source
self._playout_atask: asyncio.Task[None] | None = None
def play(
self,
*,
item_id: str,
content_index: int,
transcription_fwd: transcription.TTSSegmentsForwarder,
text_stream: AsyncIterable[str],
audio_stream: AsyncIterable[rtc.AudioFrame],
) -> PlayoutHandle:
handle = PlayoutHandle(
audio_source=self._source,
item_id=item_id,
content_index=content_index,
transcription_fwd=transcription_fwd,
)
self._playout_atask = asyncio.create_task(
self._playout_task(self._playout_atask, handle, text_stream, audio_stream)
)
return handle
@utils.log_exceptions(logger=logger)
async def _playout_task(
self,
old_task: asyncio.Task[None],
handle: PlayoutHandle,
text_stream: AsyncIterable[str],
audio_stream: AsyncIterable[rtc.AudioFrame],
) -> None:
if old_task is not None:
await utils.aio.gracefully_cancel(old_task)
first_frame = True
@utils.log_exceptions(logger=logger)
async def _play_text_stream():
async for text in text_stream:
handle._tr_fwd.push_text(text)
handle._tr_fwd.mark_text_segment_end()
@utils.log_exceptions(logger=logger)
async def _capture_task():
nonlocal first_frame
samples_per_channel = 1200
bstream = utils.audio.AudioByteStream(
24000,
1,
samples_per_channel=samples_per_channel,
)
async for frame in audio_stream:
if first_frame:
handle._tr_fwd.segment_playout_started()
self.emit("playout_started")
first_frame = False
handle._tr_fwd.push_audio(frame)
for f in bstream.write(frame.data.tobytes()):
handle._pushed_duration += f.samples_per_channel / f.sample_rate
await self._source.capture_frame(f)
for f in bstream.flush():
handle._pushed_duration += f.samples_per_channel / f.sample_rate
await self._source.capture_frame(f)
handle._tr_fwd.mark_audio_segment_end()
await self._source.wait_for_playout()
read_text_task = asyncio.create_task(_play_text_stream())
capture_task = asyncio.create_task(_capture_task())
try:
await asyncio.wait(
[capture_task, handle._int_fut],
return_when=asyncio.FIRST_COMPLETED,
)
finally:
await utils.aio.gracefully_cancel(capture_task)
handle._total_played_time = (
handle._pushed_duration - self._source.queued_duration
)
if handle.interrupted or capture_task.exception():
self._source.clear_queue() # make sure to remove any queued frames
await utils.aio.gracefully_cancel(read_text_task)
# make sure the text_data.sentence_stream is closed
handle._tr_fwd.mark_text_segment_end()
if not first_frame and not handle.interrupted:
handle._tr_fwd.segment_playout_finished()
await handle._tr_fwd.aclose()
handle._done_fut.set_result(None)
# emit playout_stopped after the transcription forwarder has been closed
if not first_frame:
self.emit("playout_stopped", handle.interrupted)
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import (
Any,
AsyncIterable,
Callable,
Literal,
Optional,
Protocol,
TypeVar,
Union,
overload,
)
import aiohttp
from livekit import rtc
from livekit.agents import llm, stt, tokenize, transcription, utils, vad
from livekit.agents.llm import ChatMessage
from livekit.agents.metrics import MultimodalLLMMetrics
from ..log import logger
from ..types import ATTRIBUTE_AGENT_STATE, AgentState
from . import agent_playout
EventTypes = Literal[
"user_started_speaking",
"user_stopped_speaking",
"agent_started_speaking",
"agent_stopped_speaking",
"user_speech_committed",
"agent_speech_committed",
"agent_speech_interrupted",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]
class _InputTranscriptionProto(Protocol):
item_id: str
"""id of the item"""
transcript: str
"""transcript of the input audio"""
class _ContentProto(Protocol):
response_id: str
item_id: str
output_index: int
content_index: int
text: str
audio: list[rtc.AudioFrame]
text_stream: AsyncIterable[str]
audio_stream: AsyncIterable[rtc.AudioFrame]
content_type: Literal["text", "audio"]
class _CapabilitiesProto(Protocol):
supports_truncate: bool
input_audio_sample_rate: int | None
class _RealtimeAPI(Protocol):
"""Realtime API protocol"""
@property
def capabilities(self) -> _CapabilitiesProto: ...
def session(
self,
*,
chat_ctx: llm.ChatContext | None = None,
fnc_ctx: llm.FunctionContext | None = None,
) -> _RealtimeAPISession:
"""
Create a new realtime session with the given chat and function contexts.
"""
pass
T = TypeVar("T", bound=Callable[..., Any])
class _RealtimeAPISession(Protocol):
async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: ...
@overload
def on(self, event: str, callback: None = None) -> Callable[[T], T]: ...
@overload
def on(self, event: str, callback: T) -> T: ...
def on(
self, event: str, callback: Optional[T] = None
) -> Union[T, Callable[[T], T]]: ...
def _push_audio(self, frame: rtc.AudioFrame) -> None: ...
@property
def fnc_ctx(self) -> llm.FunctionContext | None: ...
@fnc_ctx.setter
def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ...
def chat_ctx_copy(self) -> llm.ChatContext: ...
def cancel_response(self) -> None: ...
def create_response(
self,
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "keep_both",
) -> None: ...
def commit_audio_buffer(self) -> None: ...
@property
def server_vad_enabled(self) -> bool: ...
def _recover_from_text_response(self, item_id: str) -> None: ...
def _update_conversation_item_content(
self,
item_id: str,
content: llm.ChatContent | list[llm.ChatContent] | None = None,
) -> None: ...
def _truncate_conversation_item(
self, item_id: str, content_index: int, audio_end_ms: int
) -> None: ...
@property
def playout_complete(self) -> asyncio.Event | None:
"""Event that is set when the playout is done"""
pass
@dataclass(frozen=True)
class AgentTranscriptionOptions:
user_transcription: bool = True
"""Whether to forward the user transcription to the client"""
agent_transcription: bool = True
"""Whether to forward the agent transcription to the client"""
agent_transcription_speed: float = 1.0
"""The speed at which the agent's speech transcription is forwarded to the client.
We try to mimic the agent's speech speed by adjusting the transcription speed."""
sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer()
"""The tokenizer used to split the speech into sentences.
This is used to decide when to mark a transcript as final for the agent transcription."""
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False
)
"""The tokenizer used to split the speech into words.
This is used to simulate the "interim results" of the agent transcription."""
hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word
"""A function that takes a string (word) as input and returns a list of strings,
representing the hyphenated parts of the word."""
@dataclass(frozen=True)
class _ImplOptions:
transcription: AgentTranscriptionOptions
class MultimodalAgent(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
model: _RealtimeAPI,
vad: vad.VAD | None = None,
chat_ctx: llm.ChatContext | None = None,
fnc_ctx: llm.FunctionContext | None = None,
transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(),
max_text_response_retries: int = 5,
loop: asyncio.AbstractEventLoop | None = None,
noise_cancellation: rtc.NoiseCancellationOptions | None = None,
):
"""Create a new MultimodalAgent.
Args:
model: RealtimeAPI instance.
vad: Voice Activity Detection (VAD) instance.
chat_ctx: Chat context for the assistant.
fnc_ctx: Function context for the assistant.
transcription: Options for assistant transcription.
max_text_response_retries: Maximum number of attempts to recover
from a text response back to audio mode. OpenAI's Realtime API may
return a text response instead of audio when the chat context
includes text system or assistant messages. The agent recovers by
deleting the text response and appending an empty audio message to
the conversation.
loop: Event loop to use. Defaults to asyncio.get_event_loop().
"""
super().__init__()
self._loop = loop or asyncio.get_event_loop()
self._model = model
self._vad = vad
self._chat_ctx = chat_ctx
self._fnc_ctx = fnc_ctx
self._opts = _ImplOptions(
transcription=transcription,
)
# audio input
self._read_micro_atask: asyncio.Task | None = None
self._subscribed_track: rtc.RemoteAudioTrack | None = None
self._input_audio_ch = utils.aio.Chan[rtc.AudioFrame]()
# audio output
self._playing_handle: agent_playout.PlayoutHandle | None = None
self._linked_participant: rtc.RemoteParticipant | None = None
self._started, self._closed = False, False
self._update_state_task: asyncio.Task | None = None
self._http_session: aiohttp.ClientSession | None = None
self._text_response_retries = 0
self._max_text_response_retries = max_text_response_retries
self._noise_cancellation = noise_cancellation
@property
def vad(self) -> vad.VAD | None:
return self._vad
@property
def fnc_ctx(self) -> llm.FunctionContext | None:
return self._session.fnc_ctx
@fnc_ctx.setter
def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
self._session.fnc_ctx = value
def chat_ctx_copy(self) -> llm.ChatContext:
return self._session.chat_ctx_copy()
async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
await self._session.set_chat_ctx(ctx)
def start(
self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None
) -> None:
if self._started:
raise RuntimeError("voice assistant already started")
room.on("participant_connected", self._on_participant_connected)
room.on("track_published", self._subscribe_to_microphone)
room.on("track_subscribed", self._subscribe_to_microphone)
self._room, self._participant = room, participant
if participant is not None:
if isinstance(participant, rtc.RemoteParticipant):
self._link_participant(participant.identity)
else:
self._link_participant(participant)
else:
# no participant provided, try to find the first participant in the room
for participant in self._room.remote_participants.values():
self._link_participant(participant.identity)
break
self._session = self._model.session(
chat_ctx=self._chat_ctx, fnc_ctx=self._fnc_ctx
)
# Create a task to wait for initialization and start the main task
async def _init_and_start():
try:
await self._session._init_sync_task
logger.info("Session initialized with chat context")
self._main_atask = asyncio.create_task(self._main_task())
except Exception as e:
logger.exception("Failed to initialize session")
raise e
# Schedule the initialization and start task
asyncio.create_task(_init_and_start())
@self._session.on("response_content_added")
def _on_content_added(message: _ContentProto):
tr_fwd = transcription.TTSSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
speed=self._opts.transcription.agent_transcription_speed,
sentence_tokenizer=self._opts.transcription.sentence_tokenizer,
word_tokenizer=self._opts.transcription.word_tokenizer,
hyphenate_word=self._opts.transcription.hyphenate_word,
)
self._playing_handle = self._agent_playout.play(
item_id=message.item_id,
content_index=message.content_index,
transcription_fwd=tr_fwd,
text_stream=message.text_stream,
audio_stream=message.audio_stream,
)
@self._session.on("response_content_done")
def _response_content_done(message: _ContentProto):
if message.content_type == "text":
if self._text_response_retries >= self._max_text_response_retries:
raise RuntimeError(
f"The OpenAI Realtime API returned a text response "
f"after {self._max_text_response_retries} retries. "
f"Please try to reduce the number of text system or "
f"assistant messages in the chat context."
)
self._text_response_retries += 1
logger.warning(
"The OpenAI Realtime API returned a text response instead of audio. "
"Attempting to recover to audio mode...",
extra={
"item_id": message.item_id,
"text": message.text,
"retries": self._text_response_retries,
},
)
self._session._recover_from_text_response(message.item_id)
else:
self._text_response_retries = 0
@self._session.on("input_speech_committed")
def _input_speech_committed():
self._stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[stt.SpeechData(language="", text="")],
)
)
@self._session.on("input_speech_transcription_completed")
def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
self._stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
if self._model.capabilities.supports_truncate:
user_msg = ChatMessage.create(
text=ev.transcript, role="user", id=ev.item_id
)
self._session._update_conversation_item_content(
ev.item_id, user_msg.content
)
self._emit_speech_committed("user", ev.transcript)
@self._session.on("agent_speech_transcription_completed")
def _agent_speech_transcription_completed(ev: _InputTranscriptionProto):
self._agent_stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
self._emit_speech_committed("agent", ev.transcript)
# Similar to _input_speech_started, this handles updating the state to "listening" when the agent's speech is complete.
# However, since Gemini doesn't support VAD events, we are not emitting the `user_started_speaking` event here.
@self._session.on("agent_speech_stopped")
def _agent_speech_stopped():
self.interrupt()
@self._session.on("input_speech_started")
def _input_speech_started():
self.emit("user_started_speaking")
self.interrupt()
@self._session.on("input_speech_stopped")
def _input_speech_stopped():
self.emit("user_stopped_speaking")
@self._session.on("function_calls_collected")
def _function_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]):
self.emit("function_calls_collected", fnc_call_infos)
@self._session.on("function_calls_finished")
def _function_calls_finished(called_fncs: list[llm.CalledFunction]):
self.emit("function_calls_finished", called_fncs)
@self._session.on("metrics_collected")
def _metrics_collected(metrics: MultimodalLLMMetrics):
self.emit("metrics_collected", metrics)
def interrupt(self) -> None:
if self._playing_handle is not None and not self._playing_handle.done():
self._playing_handle.interrupt()
if self._model.capabilities.supports_truncate:
self._session.cancel_response() # Only supported by OpenAI
self._session._truncate_conversation_item(
item_id=self._playing_handle.item_id,
content_index=self._playing_handle.content_index,
audio_end_ms=int(self._playing_handle.audio_samples / 24000 * 1000),
)
self._update_state("listening")
def generate_reply(
self,
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "cancel_existing",
) -> None:
"""Generate a reply from the agent"""
if not self._session.server_vad_enabled:
self._session.commit_audio_buffer()
self._session.create_response(on_duplicate=on_duplicate)
def _update_state(self, state: AgentState, delay: float = 0.0):
"""Set the current state of the agent"""
@utils.log_exceptions(logger=logger)
async def _run_task(delay: float) -> None:
await asyncio.sleep(delay)
if self._room.isconnected():
await self._room.local_participant.set_attributes(
{ATTRIBUTE_AGENT_STATE: state}
)
if self._update_state_task is not None:
self._update_state_task.cancel()
self._update_state_task = asyncio.create_task(_run_task(delay))
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
self._update_state("initializing")
self._audio_source = rtc.AudioSource(24000, 1)
track = rtc.LocalAudioTrack.create_audio_track(
"assistant_voice", self._audio_source
)
self._agent_publication = await self._room.local_participant.publish_track(
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
)
self._agent_stt_forwarder = transcription.STTSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
track=track,
)
self._agent_playout = agent_playout.AgentPlayout(
audio_source=self._audio_source
)
def _on_playout_started() -> None:
if self._session.playout_complete is not None:
self._session.playout_complete.clear()
self.emit("agent_started_speaking")
self._update_state("speaking")
def _on_playout_stopped(interrupted: bool) -> None:
if self._session.playout_complete is not None:
self._session.playout_complete.set()
self.emit("agent_stopped_speaking")
self._update_state("listening")
if self._playing_handle is not None:
collected_text = self._playing_handle._tr_fwd.played_text
if interrupted:
collected_text += "..."
if self._model.capabilities.supports_truncate and collected_text:
msg = ChatMessage.create(
text=collected_text,
role="assistant",
id=self._playing_handle.item_id,
)
self._session._update_conversation_item_content(
self._playing_handle.item_id, msg.content
)
self._emit_speech_committed("agent", collected_text, interrupted)
self._agent_playout.on("playout_started", _on_playout_started)
self._agent_playout.on("playout_stopped", _on_playout_stopped)
await self._agent_publication.wait_for_subscription()
bstream = utils.audio.AudioByteStream(
24000,
1,
samples_per_channel=2400,
)
async for frame in self._input_audio_ch:
for f in bstream.write(frame.data.tobytes()):
self._session._push_audio(f)
def _on_participant_connected(self, participant: rtc.RemoteParticipant):
if self._linked_participant is None:
return
self._link_participant(participant.identity)
def _link_participant(self, participant_identity: str) -> None:
self._linked_participant = self._room.remote_participants.get(
participant_identity
)
if self._linked_participant is None:
logger.error("_link_participant must be called with a valid identity")
return
self._subscribe_to_microphone()
async def _micro_task(self, track: rtc.LocalAudioTrack) -> None:
sample_rate = self._model.capabilities.input_audio_sample_rate
if sample_rate is None:
sample_rate = 24000
input_stream = rtc.AudioStream(
track,
sample_rate=sample_rate,
num_channels=1,
noise_cancellation=self._noise_cancellation,
)
async for ev in input_stream:
self._input_audio_ch.send_nowait(ev.frame)
def _subscribe_to_microphone(self, *args, **kwargs) -> None:
"""Subscribe to the participant microphone if found"""
if self._linked_participant is None:
return
for publication in self._linked_participant.track_publications.values():
if publication.source != rtc.TrackSource.SOURCE_MICROPHONE:
continue
if not publication.subscribed:
publication.set_subscribed(True)
if (
publication.track is not None
and publication.track != self._subscribed_track
):
self._subscribed_track = publication.track # type: ignore
self._stt_forwarder = transcription.STTSegmentsForwarder(
room=self._room,
participant=self._linked_participant,
track=self._subscribed_track,
)
if self._read_micro_atask is not None:
self._read_micro_atask.cancel()
self._read_micro_atask = asyncio.create_task(
self._micro_task(self._subscribed_track) # type: ignore
)
break
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._http_session:
self._http_session = utils.http_context.http_session()
return self._http_session
def _emit_speech_committed(
self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False
):
if speaker == "user":
self.emit("user_speech_committed", msg)
else:
if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)
logger.debug(
f"committed {speaker} speech",
extra={
f"{speaker}_transcript": msg,
"interrupted": interrupted,
},
)
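# Illustrative sketch (not part of the module above): a minimal MultimodalAgent
# entrypoint. The realtime model is an assumption: any object implementing the
# _RealtimeAPI protocol works; `openai.realtime.RealtimeModel` from the
# livekit-plugins-openai package is used here purely as an example, and `ctx` is
# assumed to be a livekit.agents.JobContext.
async def _example_entrypoint(ctx) -> None:
    from livekit.plugins import openai  # assumed plugin dependency

    await ctx.connect()
    model = openai.realtime.RealtimeModel(
        instructions="You are a helpful voice assistant.",
    )
    agent = MultimodalAgent(model=model, chat_ctx=llm.ChatContext())
    agent.start(ctx.room)

    @agent.on("user_speech_committed")
    def _on_user_speech(msg: str) -> None:
        logger.info("user speech committed", extra={"transcript": msg})

    # trigger the first reply; when server-side VAD is disabled this also commits
    # the pending input audio buffer (see generate_reply above)
    agent.generate_reply()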
from .pipeline_agent import (
AgentCallContext,
AgentTranscriptionOptions,
VoicePipelineAgent,
)
__all__ = [
"VoicePipelineAgent",
"AgentCallContext",
"AgentTranscriptionOptions",
]
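# Illustrative sketch (assumption, not part of this package __init__): a minimal job
# entrypoint using the exports above. The VAD/STT/LLM/TTS instances are provided by the
# application (typically from plugins), and `ctx` is assumed to be a
# livekit.agents.JobContext.
async def _example_entrypoint(ctx, *, vad, stt, llm, tts) -> None:
    from livekit.agents.llm import ChatContext, ChatMessage

    await ctx.connect()
    initial_ctx = ChatContext().append(
        role="system", text="You are a helpful voice assistant."
    )
    agent = VoicePipelineAgent(vad=vad, stt=stt, llm=llm, tts=tts, chat_ctx=initial_ctx)
    agent.start(ctx.room)

    @agent.on("user_speech_committed")
    def _on_user_speech_committed(msg: ChatMessage) -> None:
        print("user said:", msg.content)

    # greet the user; say() returns a SpeechHandle that can be used to track the utterance
    await agent.say("Hey, how can I help you today?", allow_interruptions=True)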
from __future__ import annotations
import asyncio
import inspect
from typing import Any, AsyncIterable, Awaitable, Callable, Union
from livekit import rtc
from .. import llm, tokenize, utils
from .. import transcription as agent_transcription
from .. import tts as text_to_speech
from .agent_playout import AgentPlayout, PlayoutHandle
from .log import logger
SpeechSource = Union[AsyncIterable[str], str, Awaitable[str]]
class SynthesisHandle:
def __init__(
self,
*,
speech_id: str,
tts_source: SpeechSource,
transcript_source: SpeechSource,
agent_playout: AgentPlayout,
tts: text_to_speech.TTS,
transcription_fwd: agent_transcription.TTSSegmentsForwarder,
) -> None:
(
self._tts_source,
self._transcript_source,
self._agent_playout,
self._tts,
self._tr_fwd,
) = (
tts_source,
transcript_source,
agent_playout,
tts,
transcription_fwd,
)
self._buf_ch = utils.aio.Chan[rtc.AudioFrame]()
self._play_handle: PlayoutHandle | None = None
self._interrupt_fut = asyncio.Future[None]()
self._speech_id = speech_id
@property
def speech_id(self) -> str:
return self._speech_id
@property
def tts_forwarder(self) -> agent_transcription.TTSSegmentsForwarder:
return self._tr_fwd
@property
def validated(self) -> bool:
return self._play_handle is not None
@property
def interrupted(self) -> bool:
return self._interrupt_fut.done()
@property
def play_handle(self) -> PlayoutHandle | None:
return self._play_handle
def play(self) -> PlayoutHandle:
"""Validate the speech for playout"""
if self.interrupted:
raise RuntimeError("synthesis was interrupted")
self._play_handle = self._agent_playout.play(
self._speech_id, self._buf_ch, transcription_fwd=self._tr_fwd
)
return self._play_handle
def interrupt(self) -> None:
"""Interrupt the speech"""
if self.interrupted:
return
logger.debug(
"agent interrupted",
extra={"speech_id": self.speech_id},
)
if self._play_handle is not None:
self._play_handle.interrupt()
self._interrupt_fut.set_result(None)
class AgentOutput:
def __init__(
self,
*,
room: rtc.Room,
agent_playout: AgentPlayout,
llm: llm.LLM,
tts: text_to_speech.TTS,
) -> None:
self._room, self._agent_playout, self._llm, self._tts = (
room,
agent_playout,
llm,
tts,
)
self._tasks = set[asyncio.Task[Any]]()
@property
def playout(self) -> AgentPlayout:
return self._agent_playout
async def aclose(self) -> None:
for task in self._tasks:
task.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
def synthesize(
self,
*,
speech_id: str,
tts_source: SpeechSource,
transcript_source: SpeechSource,
transcription: bool,
transcription_speed: float,
sentence_tokenizer: tokenize.SentenceTokenizer,
word_tokenizer: tokenize.WordTokenizer,
hyphenate_word: Callable[[str], list[str]],
) -> SynthesisHandle:
def _before_forward(
fwd: agent_transcription.TTSSegmentsForwarder,
rtc_transcription: rtc.Transcription,
):
if not transcription:
rtc_transcription.segments = []
return rtc_transcription
transcription_fwd = agent_transcription.TTSSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
speed=transcription_speed,
sentence_tokenizer=sentence_tokenizer,
word_tokenizer=word_tokenizer,
hyphenate_word=hyphenate_word,
before_forward_cb=_before_forward,
)
handle = SynthesisHandle(
tts_source=tts_source,
transcript_source=transcript_source,
agent_playout=self._agent_playout,
tts=self._tts,
transcription_fwd=transcription_fwd,
speech_id=speech_id,
)
task = asyncio.create_task(self._synthesize_task(handle))
self._tasks.add(task)
task.add_done_callback(self._tasks.remove)
return handle
@utils.log_exceptions(logger=logger)
async def _synthesize_task(self, handle: SynthesisHandle) -> None:
"""Synthesize speech from the source"""
tts_source = handle._tts_source
transcript_source = handle._transcript_source
if isinstance(tts_source, Awaitable):
tts_source = await tts_source
if isinstance(transcript_source, Awaitable):
transcript_source = await transcript_source
tts_stream: AsyncIterable[str] | None = None
if isinstance(tts_source, str):
# wrap in async iterator
async def string_to_stream(text: str):
yield text
tts_stream = string_to_stream(tts_source)
else:
tts_stream = tts_source
co = self._stream_synthesis_task(tts_stream, transcript_source, handle)
synth = asyncio.create_task(co)
synth.add_done_callback(lambda _: handle._buf_ch.close())
try:
_ = await asyncio.wait(
[synth, handle._interrupt_fut], return_when=asyncio.FIRST_COMPLETED
)
finally:
await utils.aio.gracefully_cancel(synth)
@utils.log_exceptions(logger=logger)
async def _read_transcript_task(
self, transcript_source: AsyncIterable[str] | str, handle: SynthesisHandle
) -> None:
try:
if isinstance(transcript_source, str):
handle._tr_fwd.push_text(transcript_source)
else:
async for seg in transcript_source:
if not handle._tr_fwd.closed:
handle._tr_fwd.push_text(seg)
if not handle.tts_forwarder.closed:
handle.tts_forwarder.mark_text_segment_end()
finally:
if inspect.isasyncgen(transcript_source):
await transcript_source.aclose()
@utils.log_exceptions(logger=logger)
async def _stream_synthesis_task(
self,
tts_source: AsyncIterable[str],
transcript_source: AsyncIterable[str] | str,
handle: SynthesisHandle,
) -> None:
"""synthesize speech from streamed text"""
@utils.log_exceptions(logger=logger)
async def _read_generated_audio_task(
tts_stream: text_to_speech.SynthesizeStream,
) -> None:
try:
async for audio in tts_stream:
if not handle._tr_fwd.closed:
handle._tr_fwd.push_audio(audio.frame)
handle._buf_ch.send_nowait(audio.frame)
finally:
if handle._tr_fwd and not handle._tr_fwd.closed:
handle._tr_fwd.mark_audio_segment_end()
await tts_stream.aclose()
tts_stream: text_to_speech.SynthesizeStream | None = None
read_tts_atask: asyncio.Task | None = None
read_transcript_atask: asyncio.Task | None = None
try:
async for seg in tts_source:
if tts_stream is None:
tts_stream = handle._tts.stream()
read_tts_atask = asyncio.create_task(
_read_generated_audio_task(tts_stream)
)
read_transcript_atask = asyncio.create_task(
self._read_transcript_task(transcript_source, handle)
)
tts_stream.push_text(seg)
if tts_stream is not None:
tts_stream.end_input()
assert read_transcript_atask and read_tts_atask
await read_tts_atask
await read_transcript_atask
finally:
if read_tts_atask is not None:
assert read_transcript_atask is not None
await utils.aio.gracefully_cancel(read_tts_atask, read_transcript_atask)
if inspect.isasyncgen(tts_source):
await tts_source.aclose()
from __future__ import annotations
import asyncio
from typing import AsyncIterable, Literal
from livekit import rtc
from .. import transcription, utils
from .log import logger
EventTypes = Literal["playout_started", "playout_stopped"]
class PlayoutHandle:
def __init__(
self,
speech_id: str,
audio_source: rtc.AudioSource,
playout_source: AsyncIterable[rtc.AudioFrame],
transcription_fwd: transcription.TTSSegmentsForwarder,
) -> None:
self._playout_source = playout_source
self._audio_source = audio_source
self._tr_fwd = transcription_fwd
self._interrupted = False
self._int_fut = asyncio.Future[None]()
self._done_fut = asyncio.Future[None]()
self._speech_id = speech_id
self._pushed_duration = 0.0
self._total_played_time: float | None = None  # set when the playout is done
@property
def speech_id(self) -> str:
return self._speech_id
@property
def interrupted(self) -> bool:
return self._interrupted
@property
def time_played(self) -> float:
if self._total_played_time is not None:
return self._total_played_time
return self._pushed_duration - self._audio_source.queued_duration
def done(self) -> bool:
return self._done_fut.done() or self._interrupted
def interrupt(self) -> None:
if self.done():
return
self._int_fut.set_result(None)
self._interrupted = True
def join(self) -> asyncio.Future:
return self._done_fut
class AgentPlayout(utils.EventEmitter[EventTypes]):
def __init__(self, *, audio_source: rtc.AudioSource) -> None:
super().__init__()
self._audio_source = audio_source
self._target_volume = 1.0
self._playout_atask: asyncio.Task[None] | None = None
self._closed = False
@property
def target_volume(self) -> float:
return self._target_volume
@target_volume.setter
def target_volume(self, value: float) -> None:
self._target_volume = value
@property
def smoothed_volume(self) -> float:
return self._target_volume
async def aclose(self) -> None:
if self._closed:
return
self._closed = True
if self._playout_atask is not None:
await self._playout_atask
def play(
self,
speech_id: str,
playout_source: AsyncIterable[rtc.AudioFrame],
transcription_fwd: transcription.TTSSegmentsForwarder,
) -> PlayoutHandle:
if self._closed:
raise ValueError("cancellable source is closed")
handle = PlayoutHandle(
speech_id=speech_id,
audio_source=self._audio_source,
playout_source=playout_source,
transcription_fwd=transcription_fwd,
)
self._playout_atask = asyncio.create_task(
self._playout_task(self._playout_atask, handle)
)
return handle
@utils.log_exceptions(logger=logger)
async def _playout_task(
self, old_task: asyncio.Task[None] | None, handle: PlayoutHandle
) -> None:
if old_task is not None:
await utils.aio.gracefully_cancel(old_task)
if self._audio_source.queued_duration > 0:
# this should not happen, but log it just in case
logger.warning(
"new playout while the source is still playing",
extra={
"speech_id": handle.speech_id,
"queued_duration": self._audio_source.queued_duration,
},
)
first_frame = True
@utils.log_exceptions(logger=logger)
async def _capture_task():
nonlocal first_frame
async for frame in handle._playout_source:
if first_frame:
handle._tr_fwd.segment_playout_started()
logger.debug(
"speech playout started",
extra={"speech_id": handle.speech_id},
)
self.emit("playout_started")
first_frame = False
handle._pushed_duration += frame.samples_per_channel / frame.sample_rate
await self._audio_source.capture_frame(frame)
if self._audio_source.queued_duration > 0:
await self._audio_source.wait_for_playout()
capture_task = asyncio.create_task(_capture_task())
try:
await asyncio.wait(
[capture_task, handle._int_fut],
return_when=asyncio.FIRST_COMPLETED,
)
finally:
await utils.aio.gracefully_cancel(capture_task)
handle._total_played_time = (
handle._pushed_duration - self._audio_source.queued_duration
)
if handle.interrupted or capture_task.exception():
self._audio_source.clear_queue() # make sure to remove any queued frames
if not first_frame:
if not handle.interrupted:
handle._tr_fwd.segment_playout_finished()
self.emit("playout_stopped", handle.interrupted)
await handle._tr_fwd.aclose()
handle._done_fut.set_result(None)
logger.debug(
"speech playout finished",
extra={
"speech_id": handle.speech_id,
"interrupted": handle.interrupted,
},
)
from __future__ import annotations
import asyncio
from typing import Literal
from livekit import rtc
from .. import stt as speech_to_text
from .. import transcription, utils
from .. import vad as voice_activity_detection
from .log import logger
EventTypes = Literal[
"start_of_speech",
"vad_inference_done",
"end_of_speech",
"final_transcript",
"interim_transcript",
]
class HumanInput(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
room: rtc.Room,
vad: voice_activity_detection.VAD,
stt: speech_to_text.STT,
participant: rtc.RemoteParticipant,
transcription: bool,
noise_cancellation: rtc.NoiseCancellationOptions | None = None,
) -> None:
super().__init__()
self._room, self._vad, self._stt, self._participant, self._transcription = (
room,
vad,
stt,
participant,
transcription,
)
self._noise_cancellation = noise_cancellation
self._subscribed_track: rtc.RemoteAudioTrack | None = None
self._recognize_atask: asyncio.Task[None] | None = None
self._closed = False
self._speaking = False
self._speech_probability = 0.0
self._room.on("track_published", self._subscribe_to_microphone)
self._room.on("track_subscribed", self._subscribe_to_microphone)
self._subscribe_to_microphone()
async def aclose(self) -> None:
if self._closed:
raise RuntimeError("HumanInput already closed")
self._closed = True
self._room.off("track_published", self._subscribe_to_microphone)
self._room.off("track_subscribed", self._subscribe_to_microphone)
self._speaking = False
if self._recognize_atask is not None:
await utils.aio.gracefully_cancel(self._recognize_atask)
@property
def speaking(self) -> bool:
return self._speaking
@property
def speaking_probability(self) -> float:
return self._speech_probability
def _subscribe_to_microphone(self, *args, **kwargs) -> None:
"""
Subscribe to the participant microphone if found and not already subscribed.
Do nothing if no track is found.
"""
for publication in self._participant.track_publications.values():
if publication.source != rtc.TrackSource.SOURCE_MICROPHONE:
continue
if not publication.subscribed:
publication.set_subscribed(True)
track: rtc.RemoteAudioTrack | None = publication.track # type: ignore
if track is not None and track != self._subscribed_track:
self._subscribed_track = track
if self._recognize_atask is not None:
self._recognize_atask.cancel()
self._recognize_atask = asyncio.create_task(
self._recognize_task(
rtc.AudioStream(
track,
sample_rate=16000,
noise_cancellation=self._noise_cancellation,
)
)
)
break
@utils.log_exceptions(logger=logger)
async def _recognize_task(self, audio_stream: rtc.AudioStream) -> None:
"""
Receive the frames from the user audio stream and detect voice activity.
"""
vad_stream = self._vad.stream()
stt_stream = self._stt.stream()
def _before_forward(
fwd: transcription.STTSegmentsForwarder, transcription: rtc.Transcription
):
if not self._transcription:
transcription.segments = []
return transcription
stt_forwarder = transcription.STTSegmentsForwarder(
room=self._room,
participant=self._participant,
track=self._subscribed_track,
before_forward_cb=_before_forward,
)
async def _audio_stream_co() -> None:
# forward the audio stream to the VAD and STT streams
async for ev in audio_stream:
stt_stream.push_frame(ev.frame)
vad_stream.push_frame(ev.frame)
async def _vad_stream_co() -> None:
async for ev in vad_stream:
if ev.type == voice_activity_detection.VADEventType.START_OF_SPEECH:
self._speaking = True
self.emit("start_of_speech", ev)
elif ev.type == voice_activity_detection.VADEventType.INFERENCE_DONE:
self._speech_probability = ev.probability
self.emit("vad_inference_done", ev)
elif ev.type == voice_activity_detection.VADEventType.END_OF_SPEECH:
self._speaking = False
self.emit("end_of_speech", ev)
async def _stt_stream_co() -> None:
async for ev in stt_stream:
stt_forwarder.update(ev)
if ev.type == speech_to_text.SpeechEventType.FINAL_TRANSCRIPT:
self.emit("final_transcript", ev)
elif ev.type == speech_to_text.SpeechEventType.INTERIM_TRANSCRIPT:
self.emit("interim_transcript", ev)
tasks = [
asyncio.create_task(_audio_stream_co()),
asyncio.create_task(_vad_stream_co()),
asyncio.create_task(_stt_stream_co()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await stt_forwarder.aclose()
await stt_stream.aclose()
await vad_stream.aclose()
import logging
logger = logging.getLogger("livekit.agents.pipeline")
from __future__ import annotations
import asyncio
import contextvars
import time
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
AsyncIterable,
Awaitable,
Callable,
Literal,
Optional,
Protocol,
Union,
)
from livekit import rtc
from .. import metrics, stt, tokenize, tts, utils, vad
from ..llm import LLM, ChatContext, ChatMessage, FunctionContext, LLMStream
from ..types import ATTRIBUTE_AGENT_STATE, AgentState
from .agent_output import AgentOutput, SpeechSource, SynthesisHandle
from .agent_playout import AgentPlayout
from .human_input import HumanInput
from .log import logger
from .plotter import AssistantPlotter
from .speech_handle import SpeechHandle
BeforeLLMCallback = Callable[
["VoicePipelineAgent", ChatContext],
Union[
Optional[LLMStream],
Awaitable[Optional[LLMStream]],
Literal[False],
Awaitable[Literal[False]],
],
]
WillSynthesizeAssistantReply = BeforeLLMCallback
BeforeTTSCallback = Callable[
["VoicePipelineAgent", Union[str, AsyncIterable[str]]],
SpeechSource,
]
EventTypes = Literal[
"user_started_speaking",
"user_stopped_speaking",
"agent_started_speaking",
"agent_stopped_speaking",
"user_speech_committed",
"agent_speech_committed",
"agent_speech_interrupted",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]
_CallContextVar = contextvars.ContextVar["AgentCallContext"](
"voice_assistant_contextvar"
)
class AgentCallContext:
def __init__(self, assistant: "VoicePipelineAgent", llm_stream: LLMStream) -> None:
self._assistant = assistant
self._metadata = dict[str, Any]()
self._llm_stream = llm_stream
self._extra_chat_messages: list[ChatMessage] = []
@staticmethod
def get_current() -> "AgentCallContext":
return _CallContextVar.get()
@property
def agent(self) -> "VoicePipelineAgent":
return self._assistant
@property
def chat_ctx(self) -> ChatContext:
return self._llm_stream.chat_ctx
def store_metadata(self, key: str, value: Any) -> None:
self._metadata[key] = value
def get_metadata(self, key: str, default: Any = None) -> Any:
return self._metadata.get(key, default)
def llm_stream(self) -> LLMStream:
return self._llm_stream
def add_extra_chat_message(self, message: ChatMessage) -> None:
"""Append chat message to the end of function outputs for the answer LLM call"""
self._extra_chat_messages.append(message)
@property
def extra_chat_messages(self) -> list[ChatMessage]:
return self._extra_chat_messages
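# Illustrative sketch (assumption, not part of the module): AgentCallContext is only set
# while the agent is executing an LLM function call, so a tool implementation can use it
# to reach the agent mid-call. The weather lookup below is a placeholder for a real one.
async def _example_weather_tool(location: str) -> str:
    call_ctx = AgentCallContext.get_current()  # raises LookupError outside a function call
    # speak filler text while the lookup runs, without persisting it to the chat context
    await call_ctx.agent.say(
        f"Checking the weather in {location}...", add_to_chat_ctx=False
    )
    call_ctx.store_metadata("last_weather_location", location)
    return f"The weather in {location} is sunny."  # placeholder result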
def _default_before_llm_cb(
agent: VoicePipelineAgent, chat_ctx: ChatContext
) -> LLMStream:
return agent.llm.chat(
chat_ctx=chat_ctx,
fnc_ctx=agent.fnc_ctx,
)
@dataclass
class SpeechData:
sequence_id: str
SpeechDataContextVar = contextvars.ContextVar[SpeechData]("voice_assistant_speech_data")
def _default_before_tts_cb(
agent: VoicePipelineAgent, text: str | AsyncIterable[str]
) -> str | AsyncIterable[str]:
return text
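# Illustrative sketch (assumption, not part of the module): a custom before_llm_cb that
# injects retrieved context into the chat history before delegating to the default LLM
# call. The retrieval helper is hypothetical; the callback would be passed as
# VoicePipelineAgent(..., before_llm_cb=_example_rag_before_llm_cb).
async def _example_rag_before_llm_cb(
    agent: "VoicePipelineAgent", chat_ctx: ChatContext
) -> LLMStream:
    user_msg = chat_ctx.messages[-1]
    context = await _lookup_context(user_msg.content)  # hypothetical retrieval helper
    chat_ctx.append(role="system", text=f"Relevant context:\n{context}")
    return _default_before_llm_cb(agent, chat_ctx)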
@dataclass(frozen=True)
class _ImplOptions:
allow_interruptions: bool
int_speech_duration: float
int_min_words: int
min_endpointing_delay: float
max_endpointing_delay: float
max_nested_fnc_calls: int
preemptive_synthesis: bool
before_llm_cb: BeforeLLMCallback
before_tts_cb: BeforeTTSCallback
plotting: bool
transcription: AgentTranscriptionOptions
@dataclass(frozen=True)
class AgentTranscriptionOptions:
user_transcription: bool = True
"""Whether to forward the user transcription to the client"""
agent_transcription: bool = True
"""Whether to forward the agent transcription to the client"""
agent_transcription_speed: float = 1.0
"""The speed at which the agent's speech transcription is forwarded to the client.
We try to mimic the agent's speech speed by adjusting the transcription speed."""
sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer()
"""The tokenizer used to split the speech into sentences.
This is used to decide when to mark a transcript as final for the agent transcription."""
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False
)
"""The tokenizer used to split the speech into words.
This is used to simulate the "interim results" of the agent transcription."""
hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word
"""A function that takes a string (word) as input and returns a list of strings,
representing the hyphenated parts of the word."""
class _TurnDetector(Protocol):
# When the end-of-turn probability is below this threshold, we assume the user has not
# finished speaking, so a longer endpointing delay is used.
def unlikely_threshold(self, language: str | None) -> float: ...
def supports_language(self, language: str | None) -> bool: ...
async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ...
class VoicePipelineAgent(utils.EventEmitter[EventTypes]):
"""
A pipeline agent (VAD + STT + LLM + TTS) implementation.
"""
MIN_TIME_PLAYED_FOR_COMMIT = 1.5
"""Minimum time played for the user speech to be committed to the chat context"""
def __init__(
self,
*,
vad: vad.VAD,
stt: stt.STT,
llm: LLM,
tts: tts.TTS,
noise_cancellation: rtc.NoiseCancellationOptions | None = None,
turn_detector: _TurnDetector | None = None,
chat_ctx: ChatContext | None = None,
fnc_ctx: FunctionContext | None = None,
allow_interruptions: bool = True,
interrupt_speech_duration: float = 0.5,
interrupt_min_words: int = 0,
min_endpointing_delay: float = 0.5,
max_endpointing_delay: float = 6.0,
max_nested_fnc_calls: int = 1,
preemptive_synthesis: bool = False,
transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(),
before_llm_cb: BeforeLLMCallback = _default_before_llm_cb,
before_tts_cb: BeforeTTSCallback = _default_before_tts_cb,
plotting: bool = False,
loop: asyncio.AbstractEventLoop | None = None,
# backward compatibility
will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None,
) -> None:
"""
Create a new VoicePipelineAgent.
Args:
vad: Voice Activity Detection (VAD) instance.
stt: Speech-to-Text (STT) instance.
llm: Large Language Model (LLM) instance.
tts: Text-to-Speech (TTS) instance.
chat_ctx: Chat context for the assistant.
fnc_ctx: Function context for the assistant.
allow_interruptions: Whether to allow the user to interrupt the assistant.
interrupt_speech_duration: Minimum duration of speech to consider for interruption.
interrupt_min_words: Minimum number of words to consider for interruption.
Defaults to 0, since requiring a minimum word count may add latency depending on the STT.
min_endpointing_delay: Minimum delay to wait before considering the user finished speaking.
max_endpointing_delay: Maximum delay to wait before considering the user finished speaking;
    typically applied when a turn detector suggests the user is likely to continue.
max_nested_fnc_calls: Maximum number of nested function calls allowed for chaining
function calls (e.g functions that depend on each other).
preemptive_synthesis: Whether to preemptively synthesize responses.
transcription: Options for assistant transcription.
before_llm_cb: Callback called when the assistant is about to synthesize a reply.
This can be used to customize the reply (e.g: inject context/RAG).
Returning None will create a default LLM stream. You can also return your own llm
stream by calling the llm.chat() method.
Returning False will cancel the synthesis of the reply.
before_tts_cb: Callback called when the assistant is about to
synthesize a speech. This can be used to customize text before the speech synthesis.
(e.g: editing the pronunciation of a word).
plotting: Whether to enable plotting for debugging. matplotlib must be installed.
loop: Event loop to use. Defaults to asyncio.get_event_loop().
"""
super().__init__()
self._loop = loop or asyncio.get_event_loop()
if will_synthesize_assistant_reply is not None:
logger.warning(
"will_synthesize_assistant_reply is deprecated and will be removed in 1.5.0, use before_llm_cb instead",
)
before_llm_cb = will_synthesize_assistant_reply
self._opts = _ImplOptions(
plotting=plotting,
allow_interruptions=allow_interruptions,
int_speech_duration=interrupt_speech_duration,
int_min_words=interrupt_min_words,
min_endpointing_delay=min_endpointing_delay,
max_endpointing_delay=max_endpointing_delay,
max_nested_fnc_calls=max_nested_fnc_calls,
preemptive_synthesis=preemptive_synthesis,
transcription=transcription,
before_llm_cb=before_llm_cb,
before_tts_cb=before_tts_cb,
)
self._plotter = AssistantPlotter(self._loop)
# wrap with StreamAdapter automatically when streaming is not supported on a specific TTS/STT.
# To override StreamAdapter options, create the adapter manually.
if not tts.capabilities.streaming:
from .. import tts as text_to_speech
tts = text_to_speech.StreamAdapter(
tts=tts, sentence_tokenizer=tokenize.basic.SentenceTokenizer()
)
if not stt.capabilities.streaming:
from .. import stt as speech_to_text
stt = speech_to_text.StreamAdapter(
stt=stt,
vad=vad,
)
self._stt, self._vad, self._llm, self._tts = stt, vad, llm, tts
self._turn_detector = turn_detector
self._chat_ctx = chat_ctx or ChatContext()
self._fnc_ctx = fnc_ctx
self._started, self._closed = False, False
self._human_input: HumanInput | None = None
self._agent_output: AgentOutput | None = None
# done when the agent output track is published
self._track_published_fut = asyncio.Future[None]()
self._pending_agent_reply: SpeechHandle | None = None
self._agent_reply_task: asyncio.Task[None] | None = None
self._playing_speech: SpeechHandle | None = None
self._transcribed_text, self._transcribed_interim_text = "", ""
self._deferred_validation = _DeferredReplyValidation(
self._validate_reply_if_possible,
min_endpointing_delay=self._opts.min_endpointing_delay,
max_endpointing_delay=self._opts.max_endpointing_delay,
turn_detector=self._turn_detector,
agent=self,
)
self._speech_q: list[SpeechHandle] = []
self._speech_q_changed = asyncio.Event()
self._update_state_task: asyncio.Task | None = None
self._last_final_transcript_time: float | None = None
self._last_speech_time: float | None = None
self._noise_cancellation = noise_cancellation
@property
def fnc_ctx(self) -> FunctionContext | None:
return self._fnc_ctx
@fnc_ctx.setter
def fnc_ctx(self, fnc_ctx: FunctionContext | None) -> None:
self._fnc_ctx = fnc_ctx
@property
def chat_ctx(self) -> ChatContext:
return self._chat_ctx
@property
def llm(self) -> LLM:
return self._llm
@property
def tts(self) -> tts.TTS:
return self._tts
@property
def stt(self) -> stt.STT:
return self._stt
@property
def vad(self) -> vad.VAD:
return self._vad
def start(
self, room: rtc.Room, participant: rtc.RemoteParticipant | str | None = None
) -> None:
"""Start the voice assistant
Args:
room: the room to use
participant: the participant to listen to; can be either a participant object or a participant identity.
    If None, the first participant found in the room will be selected.
"""
if self._started:
raise RuntimeError("voice assistant already started")
@self._stt.on("metrics_collected")
def _on_stt_metrics(stt_metrics: metrics.STTMetrics) -> None:
self.emit(
"metrics_collected",
metrics.PipelineSTTMetrics(
**stt_metrics.__dict__,
),
)
@self._tts.on("metrics_collected")
def _on_tts_metrics(tts_metrics: metrics.TTSMetrics) -> None:
speech_data = SpeechDataContextVar.get(None)
if speech_data is None:
return
self.emit(
"metrics_collected",
metrics.PipelineTTSMetrics(
**tts_metrics.__dict__,
sequence_id=speech_data.sequence_id,
),
)
@self._llm.on("metrics_collected")
def _on_llm_metrics(llm_metrics: metrics.LLMMetrics) -> None:
speech_data = SpeechDataContextVar.get(None)
if speech_data is None:
return
self.emit(
"metrics_collected",
metrics.PipelineLLMMetrics(
**llm_metrics.__dict__,
sequence_id=speech_data.sequence_id,
),
)
@self._vad.on("metrics_collected")
def _on_vad_metrics(vad_metrics: vad.VADMetrics) -> None:
self.emit(
"metrics_collected", metrics.PipelineVADMetrics(**vad_metrics.__dict__)
)
room.on("participant_connected", self._on_participant_connected)
self._room, self._participant = room, participant
if participant is not None:
if isinstance(participant, rtc.RemoteParticipant):
self._link_participant(participant.identity)
else:
self._link_participant(participant)
else:
# no participant provided, try to find the first participant in the room
for participant in self._room.remote_participants.values():
self._link_participant(participant.identity)
break
self._main_atask = asyncio.create_task(self._main_task())
def on(self, event: EventTypes, callback: Callable[[Any], None] | None = None):
"""Register a callback for an event
Args:
event: the event to listen to (see EventTypes)
- user_started_speaking: the user started speaking
- user_stopped_speaking: the user stopped speaking
- agent_started_speaking: the agent started speaking
- agent_stopped_speaking: the agent stopped speaking
- user_speech_committed: the user speech was committed to the chat context
- agent_speech_committed: the agent speech was committed to the chat context
- agent_speech_interrupted: the agent speech was interrupted
- function_calls_collected: received the complete set of functions to be executed
- function_calls_finished: all function calls have been completed
callback: the callback to call when the event is emitted
"""
return super().on(event, callback)
async def say(
self,
source: str | LLMStream | AsyncIterable[str],
*,
allow_interruptions: bool = True,
add_to_chat_ctx: bool = True,
) -> SpeechHandle:
"""
Play a speech source through the voice assistant.
Args:
source: The source of the speech to play.
It can be a string, an LLMStream, or an asynchronous iterable of strings.
allow_interruptions: Whether to allow interruptions during the speech playback.
add_to_chat_ctx: Whether to add the speech to the chat context.
Returns:
The speech handle for the speech that was played, which can be used to
wait for the speech to finish.
"""
await self._track_published_fut
call_ctx = None
fnc_source: str | AsyncIterable[str] | None = None
if add_to_chat_ctx:
try:
call_ctx = AgentCallContext.get_current()
except LookupError:
# no active call context, ignore
pass
else:
if isinstance(source, LLMStream):
logger.warning(
"LLMStream will be ignored for function call chat context"
)
elif isinstance(source, AsyncIterable):
source, fnc_source = utils.aio.itertools.tee(source, 2) # type: ignore
else:
fnc_source = source
new_handle = SpeechHandle.create_assistant_speech(
allow_interruptions=allow_interruptions, add_to_chat_ctx=add_to_chat_ctx
)
synthesis_handle = self._synthesize_agent_speech(new_handle.id, source)
new_handle.initialize(source=source, synthesis_handle=synthesis_handle)
if self._playing_speech and not self._playing_speech.nested_speech_done:
self._playing_speech.add_nested_speech(new_handle)
elif self._speech_q:
self._speech_q[0].add_nested_speech(new_handle)
else:
self._add_speech_for_playout(new_handle)
# add the speech to the function call context if needed
if call_ctx is not None and fnc_source is not None:
if isinstance(fnc_source, AsyncIterable):
text = ""
async for chunk in fnc_source:
text += chunk
else:
text = fnc_source
call_ctx.add_extra_chat_message(
ChatMessage.create(text=text, role="assistant")
)
logger.debug(
"added speech to function call chat context",
extra={"text": text},
)
return new_handle
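# Illustrative sketch (assumption): besides plain strings, say() accepts an LLMStream or
# an async iterable of strings, e.g. streaming a one-off completion outside the normal
# reply flow:
#
#   stream = agent.llm.chat(chat_ctx=agent.chat_ctx)
#   await agent.say(stream, add_to_chat_ctx=False)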
def interrupt(self, interrupt_all: bool = True) -> None:
"""Interrupt the current speech
Args:
interrupt_all: Whether to interrupt all pending speech
"""
if interrupt_all:
# interrupt all pending speech
if self._pending_agent_reply is not None:
self._pending_agent_reply.cancel(cancel_nested=True)
for speech in self._speech_q:
speech.cancel(cancel_nested=True)
# interrupt the playing speech
if self._playing_speech is not None:
self._playing_speech.cancel()
def _update_state(self, state: AgentState, delay: float = 0.0):
"""Set the current state of the agent"""
@utils.log_exceptions(logger=logger)
async def _run_task(delay: float) -> None:
await asyncio.sleep(delay)
if self._room.isconnected():
await self._room.local_participant.set_attributes(
{ATTRIBUTE_AGENT_STATE: state}
)
if self._update_state_task is not None:
self._update_state_task.cancel()
self._update_state_task = asyncio.create_task(_run_task(delay))
async def aclose(self) -> None:
"""Close the voice assistant"""
if not self._started:
return
self._room.off("participant_connected", self._on_participant_connected)
await self._deferred_validation.aclose()
def _on_participant_connected(self, participant: rtc.RemoteParticipant):
if self._human_input is not None:
return
self._link_participant(participant.identity)
def _link_participant(self, identity: str) -> None:
participant = self._room.remote_participants.get(identity)
if participant is None:
logger.error("_link_participant must be called with a valid identity")
return
self._human_input = HumanInput(
room=self._room,
vad=self._vad,
stt=self._stt,
participant=participant,
transcription=self._opts.transcription.user_transcription,
noise_cancellation=self._noise_cancellation,
)
def _on_start_of_speech(ev: vad.VADEvent) -> None:
self._plotter.plot_event("user_started_speaking")
self.emit("user_started_speaking")
self._deferred_validation.on_human_start_of_speech(ev)
def _on_vad_inference_done(ev: vad.VADEvent) -> None:
if not self._track_published_fut.done():
return
assert self._agent_output is not None
tv = 1.0
if self._opts.allow_interruptions:
tv = max(0.0, 1.0 - ev.probability)
self._agent_output.playout.target_volume = tv
smoothed_tv = self._agent_output.playout.smoothed_volume
self._plotter.plot_value("raw_vol", tv)
self._plotter.plot_value("smoothed_vol", smoothed_tv)
self._plotter.plot_value("vad_probability", ev.probability)
if ev.speech_duration >= self._opts.int_speech_duration:
self._interrupt_if_possible()
if ev.raw_accumulated_speech > 0.0:
self._last_speech_time = (
time.perf_counter() - ev.raw_accumulated_silence
)
def _on_end_of_speech(ev: vad.VADEvent) -> None:
self._plotter.plot_event("user_stopped_speaking")
self.emit("user_stopped_speaking")
self._deferred_validation.on_human_end_of_speech(ev)
def _on_interim_transcript(ev: stt.SpeechEvent) -> None:
self._transcribed_interim_text = ev.alternatives[0].text
def _on_final_transcript(ev: stt.SpeechEvent) -> None:
new_transcript = ev.alternatives[0].text
if not new_transcript:
return
logger.debug(
"received user transcript",
extra={"user_transcript": new_transcript},
)
self._last_final_transcript_time = time.perf_counter()
self._transcribed_text += (
" " if self._transcribed_text else ""
) + new_transcript
if self._opts.preemptive_synthesis:
if (
self._playing_speech is None
or self._playing_speech.allow_interruptions
):
self._synthesize_agent_reply()
self._deferred_validation.on_human_final_transcript(
new_transcript, ev.alternatives[0].language
)
words = self._opts.transcription.word_tokenizer.tokenize(
text=new_transcript
)
if len(words) >= 3:
# VAD can sometimes fail to detect that the user is speaking.
# To make interruption more reliable, we also interrupt on the final transcript.
self._interrupt_if_possible()
self._human_input.on("start_of_speech", _on_start_of_speech)
self._human_input.on("vad_inference_done", _on_vad_inference_done)
self._human_input.on("end_of_speech", _on_end_of_speech)
self._human_input.on("interim_transcript", _on_interim_transcript)
self._human_input.on("final_transcript", _on_final_transcript)
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
if self._opts.plotting:
await self._plotter.start()
self._update_state("initializing")
audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source)
self._agent_publication = await self._room.local_participant.publish_track(
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
)
agent_playout = AgentPlayout(audio_source=audio_source)
self._agent_output = AgentOutput(
room=self._room,
agent_playout=agent_playout,
llm=self._llm,
tts=self._tts,
)
def _on_playout_started() -> None:
self._plotter.plot_event("agent_started_speaking")
self.emit("agent_started_speaking")
self._update_state("speaking")
def _on_playout_stopped(interrupted: bool) -> None:
self._plotter.plot_event("agent_stopped_speaking")
self.emit("agent_stopped_speaking")
self._update_state("listening")
agent_playout.on("playout_started", _on_playout_started)
agent_playout.on("playout_stopped", _on_playout_stopped)
self._track_published_fut.set_result(None)
while True:
await self._speech_q_changed.wait()
while self._speech_q:
speech = self._speech_q[0]
self._playing_speech = speech
await self._play_speech(speech)
self._speech_q.pop(0) # Remove the element only after playing
self._playing_speech = None
self._speech_q_changed.clear()
def _synthesize_agent_reply(self):
"""Synthesize the agent reply to the user question, also make sure only one reply
is synthesized/played at a time"""
if self._pending_agent_reply is not None:
self._pending_agent_reply.cancel()
if self._human_input is not None and not self._human_input.speaking:
self._update_state("thinking", 0.2)
self._pending_agent_reply = new_handle = SpeechHandle.create_assistant_reply(
allow_interruptions=self._opts.allow_interruptions,
add_to_chat_ctx=True,
user_question=self._transcribed_text,
)
self._agent_reply_task = asyncio.create_task(
self._synthesize_answer_task(self._agent_reply_task, new_handle)
)
self._agent_reply_task.add_done_callback(
lambda t: new_handle.cancel() if t.cancelled() else None
)
@utils.log_exceptions(logger=logger)
async def _synthesize_answer_task(
self, old_task: asyncio.Task[None], handle: SpeechHandle
) -> None:
if old_task is not None:
await utils.aio.gracefully_cancel(old_task)
copied_ctx = self._chat_ctx.copy()
playing_speech = self._playing_speech
if playing_speech is not None and playing_speech.initialized:
if (
not playing_speech.user_question or playing_speech.user_committed
) and not playing_speech.speech_committed:
# the speech is playing but not committed yet, add it to the chat context for this new reply synthesis
# First add the previous function call message if any
if playing_speech.extra_tools_messages:
if playing_speech.fnc_text_message_id is not None:
# there is a message alongside the function calls
msgs = copied_ctx.messages
if msgs and msgs[-1].id == playing_speech.fnc_text_message_id:
# replace it with the tool call message if it's the last in the ctx
msgs.pop()
copied_ctx.messages.extend(playing_speech.extra_tools_messages)
# Then add the previous assistant message
copied_ctx.messages.append(
ChatMessage.create(
text=playing_speech.synthesis_handle.tts_forwarder.played_text,
role="assistant",
)
)
# when user_question is empty, it's due to a false-positive interruption.
# In that case, add a continue marker to the chat context: some LLMs can handle
# empty content in an inference request, but others would fail.
user_input = handle.user_question
if not user_input.strip():
user_input = "<continue>"
copied_ctx.messages.append(ChatMessage.create(text=user_input, role="user"))
tk = SpeechDataContextVar.set(SpeechData(sequence_id=handle.id))
try:
llm_stream = self._opts.before_llm_cb(self, copied_ctx)
if asyncio.iscoroutine(llm_stream):
llm_stream = await llm_stream
if llm_stream is False:
# user chose not to synthesize an answer, so we do not want to
# leave the same question in chat context. otherwise it would be
# unintentionally committed when the next set of speech comes in.
if len(self._transcribed_text) >= len(handle.user_question):
self._transcribed_text = self._transcribed_text[
len(handle.user_question) :
]
handle.cancel()
return
# fallback to default impl if no custom/user stream is returned
if not isinstance(llm_stream, LLMStream):
llm_stream = _default_before_llm_cb(self, chat_ctx=copied_ctx)
if handle.interrupted:
return
synthesis_handle = self._synthesize_agent_speech(handle.id, llm_stream)
handle.initialize(source=llm_stream, synthesis_handle=synthesis_handle)
finally:
SpeechDataContextVar.reset(tk)
async def _play_speech(self, speech_handle: SpeechHandle) -> None:
await self._agent_publication.wait_for_subscription()
fnc_done_fut = asyncio.Future[None]()
playing_lock = asyncio.Lock()
nested_speech_played = asyncio.Event()
async def _play_nested_speech():
speech_handle._nested_speech_done_fut = asyncio.Future[None]()
while not speech_handle.nested_speech_done:
nesting_changed = asyncio.create_task(
speech_handle.nested_speech_changed.wait()
)
nesting_done_fut: asyncio.Future = speech_handle._nested_speech_done_fut
await asyncio.wait(
[nesting_changed, nesting_done_fut, fnc_done_fut],
return_when=asyncio.FIRST_COMPLETED,
)
if not nesting_changed.done():
nesting_changed.cancel()
while speech_handle.nested_speech_handles:
nested_speech_played.clear()
speech = speech_handle.nested_speech_handles[0]
if speech_handle.nested_speech_done:
# in case a tool speech is added after the nested speech is done
speech.cancel(cancel_nested=True)
speech_handle.nested_speech_handles.pop(0)
continue
async with playing_lock:
self._playing_speech = speech
await self._play_speech(speech)
speech_handle.nested_speech_handles.pop(0)
self._playing_speech = speech_handle
nested_speech_played.set()
speech_handle.nested_speech_changed.clear()
# break if the function calls task is done
if fnc_done_fut.done():
speech_handle.mark_nested_speech_done()
nested_speech_task = asyncio.create_task(_play_nested_speech())
async def _stop_nesting_speech():
fnc_done_fut.set_result(None)
await nested_speech_task
try:
await speech_handle.wait_for_initialization()
except asyncio.CancelledError:
await _stop_nesting_speech()
return
# wait for all pre-added nested speech to be played
while speech_handle.nested_speech_handles:
await nested_speech_played.wait()
await playing_lock.acquire()
synthesis_handle = speech_handle.synthesis_handle
if synthesis_handle.interrupted:
playing_lock.release()
await _stop_nesting_speech()
return
user_question = speech_handle.user_question
play_handle = synthesis_handle.play()
join_fut = play_handle.join()
def _commit_user_question_if_needed() -> None:
if (
not user_question
or synthesis_handle.interrupted
or speech_handle.user_committed
):
return
is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
speech_handle.source.function_calls
)
# make sure at least some speech was played before committing the user message.
# since we try to validate as fast as possible, the agent can get interrupted
# almost immediately (barely audible); we don't want to mark such a question as "answered".
if (
speech_handle.allow_interruptions
and not is_using_tools
and (
play_handle.time_played < self.MIN_TIME_PLAYED_FOR_COMMIT
and not join_fut.done()
)
):
return
user_msg = ChatMessage.create(text=user_question, role="user")
self._chat_ctx.messages.append(user_msg)
self.emit("user_speech_committed", user_msg)
self._transcribed_text = self._transcribed_text[len(user_question) :]
speech_handle.mark_user_committed()
# wait for the play_handle to finish, checking every 0.2s whether the user question should be committed
_commit_user_question_if_needed()
while not join_fut.done():
await asyncio.wait(
[join_fut], return_when=asyncio.FIRST_COMPLETED, timeout=0.2
)
_commit_user_question_if_needed()
if speech_handle.interrupted:
break
_commit_user_question_if_needed()
collected_text = speech_handle.synthesis_handle.tts_forwarder.played_text
interrupted = speech_handle.interrupted
is_using_tools = isinstance(speech_handle.source, LLMStream) and len(
speech_handle.source.function_calls
)
# add tool calls and text message to the chat context
message_id_committed: str | None = None
if speech_handle.add_to_chat_ctx and (
not user_question or speech_handle.user_committed
):
if speech_handle.extra_tools_messages:
if speech_handle.fnc_text_message_id is not None:
# there is a message alongside the function calls
msgs = self._chat_ctx.messages
if msgs and msgs[-1].id == speech_handle.fnc_text_message_id:
# replace it with the tool call message if it's the last in the ctx
msgs.pop()
elif speech_handle.extra_tools_messages[0].tool_calls:
# remove the content of the tool call message
speech_handle.extra_tools_messages[0].content = ""
self._chat_ctx.messages.extend(speech_handle.extra_tools_messages)
if collected_text:
if interrupted:
collected_text += "..."
msg = ChatMessage.create(text=collected_text, role="assistant")
self._chat_ctx.messages.append(msg)
message_id_committed = msg.id
speech_handle.mark_speech_committed()
if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)
logger.debug(
"committed agent speech",
extra={
"agent_transcript": collected_text,
"interrupted": interrupted,
"speech_id": speech_handle.id,
},
)
playing_lock.release()
@utils.log_exceptions(logger=logger)
async def _execute_function_calls() -> None:
nonlocal interrupted, collected_text
# if the answer is using tools, execute the functions and automatically generate
# a response to the user question from the returned values
if not is_using_tools or interrupted:
return
if speech_handle.fnc_nested_depth >= self._opts.max_nested_fnc_calls:
logger.warning(
"max function calls nested depth reached",
extra={
"speech_id": speech_handle.id,
"fnc_nested_depth": speech_handle.fnc_nested_depth,
},
)
return
assert isinstance(speech_handle.source, LLMStream)
assert not user_question or speech_handle.user_committed, (
"user speech should have been committed before using tools"
)
llm_stream = speech_handle.source
# execute functions
call_ctx = AgentCallContext(self, llm_stream)
tk = _CallContextVar.set(call_ctx)
new_function_calls = llm_stream.function_calls
self.emit("function_calls_collected", new_function_calls)
called_fncs = []
for fnc in new_function_calls:
called_fnc = fnc.execute()
called_fncs.append(called_fnc)
logger.debug(
"executing ai function",
extra={
"function": fnc.function_info.name,
"speech_id": speech_handle.id,
},
)
try:
await called_fnc.task
except Exception as e:
logger.exception(
"error executing ai function",
extra={
"function": fnc.function_info.name,
"speech_id": speech_handle.id,
},
exc_info=e,
)
tool_calls_info = []
tool_calls_results = []
for called_fnc in called_fncs:
# ignore function calls that return None
if called_fnc.result is None and called_fnc.exception is None:
continue
tool_calls_info.append(called_fnc.call_info)
tool_calls_results.append(
ChatMessage.create_tool_from_called_function(called_fnc)
)
if not tool_calls_info:
return
# create a nested speech handle
extra_tools_messages = [
ChatMessage.create_tool_calls(tool_calls_info, text=collected_text)
]
extra_tools_messages.extend(tool_calls_results)
new_speech_handle = SpeechHandle.create_tool_speech(
allow_interruptions=speech_handle.allow_interruptions,
add_to_chat_ctx=speech_handle.add_to_chat_ctx,
extra_tools_messages=extra_tools_messages,
fnc_nested_depth=speech_handle.fnc_nested_depth + 1,
fnc_text_message_id=message_id_committed,
)
# synthesize the tool speech with the chat ctx from llm_stream
chat_ctx = call_ctx.chat_ctx.copy()
chat_ctx.messages.extend(extra_tools_messages)
chat_ctx.messages.extend(call_ctx.extra_chat_messages)
fnc_ctx = self.fnc_ctx
if (
fnc_ctx
and new_speech_handle.fnc_nested_depth
>= self._opts.max_nested_fnc_calls
and not self._llm.capabilities.requires_persistent_functions
):
if len(fnc_ctx.ai_functions) > 1:
logger.info(
"max function calls nested depth reached, dropping function context. increase max_nested_fnc_calls to enable additional nesting.",
extra={
"speech_id": speech_handle.id,
"fnc_nested_depth": speech_handle.fnc_nested_depth,
},
)
fnc_ctx = None
answer_llm_stream = self._llm.chat(
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
)
synthesis_handle = self._synthesize_agent_speech(
new_speech_handle.id, answer_llm_stream
)
new_speech_handle.initialize(
source=answer_llm_stream, synthesis_handle=synthesis_handle
)
speech_handle.add_nested_speech(new_speech_handle)
self.emit("function_calls_finished", called_fncs)
_CallContextVar.reset(tk)
if not is_using_tools:
# skip the function calls execution
await _stop_nesting_speech()
speech_handle._set_done()
return
fnc_task = asyncio.create_task(_execute_function_calls())
fnc_task.add_done_callback(lambda _: fnc_done_fut.set_result(None))
await nested_speech_task
if not fnc_task.done():
logger.debug(
"cancelling function calls task", extra={"speech_id": speech_handle.id}
)
fnc_task.cancel()
# mark the speech as done
speech_handle._set_done()
def _synthesize_agent_speech(
self,
speech_id: str,
source: str | LLMStream | AsyncIterable[str],
) -> SynthesisHandle:
assert self._agent_output is not None, (
"agent output should be initialized when ready"
)
tk = SpeechDataContextVar.set(SpeechData(speech_id))
async def _llm_stream_to_str_generator(
stream: LLMStream,
) -> AsyncGenerator[str]:
try:
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
yield content
finally:
await stream.aclose()
if isinstance(source, LLMStream):
source = _llm_stream_to_str_generator(source)
og_source = source
transcript_source = source
if isinstance(og_source, AsyncIterable):
og_source, transcript_source = utils.aio.itertools.tee(og_source, 2)
tts_source = self._opts.before_tts_cb(self, og_source)
if tts_source is None:
raise ValueError("before_tts_cb must return str or AsyncIterable[str]")
try:
return self._agent_output.synthesize(
speech_id=speech_id,
tts_source=tts_source,
transcript_source=transcript_source,
transcription=self._opts.transcription.agent_transcription,
transcription_speed=self._opts.transcription.agent_transcription_speed,
sentence_tokenizer=self._opts.transcription.sentence_tokenizer,
word_tokenizer=self._opts.transcription.word_tokenizer,
hyphenate_word=self._opts.transcription.hyphenate_word,
)
finally:
SpeechDataContextVar.reset(tk)
def _validate_reply_if_possible(self) -> None:
"""Check if the new agent speech should be played"""
if self._playing_speech and not self._playing_speech.interrupted:
should_ignore_input = False
if not self._playing_speech.allow_interruptions:
should_ignore_input = True
logger.debug(
"skipping validation, agent is speaking and does not allow interruptions",
extra={"speech_id": self._playing_speech.id},
)
elif not self._should_interrupt():
should_ignore_input = True
logger.debug(
"interrupt threshold is not met",
extra={"speech_id": self._playing_speech.id},
)
if should_ignore_input:
self._transcribed_text = ""
return
if self._pending_agent_reply is None:
if self._opts.preemptive_synthesis:
return
# as long as we don't have a pending reply, we need to synthesize it
# in order to keep the conversation flowing.
# transcript could be empty at this moment, if the user interrupted the agent
# but did not generate any transcribed text.
self._synthesize_agent_reply()
assert self._pending_agent_reply is not None
# due to timing, we could end up with two pushed agent replies inside the speech queue.
# so make sure we directly interrupt every reply when validating a new one
for speech in self._speech_q:
if not speech.is_reply:
continue
if speech.allow_interruptions:
speech.interrupt()
logger.debug(
"validated agent reply",
extra={
"speech_id": self._pending_agent_reply.id,
"text": self._transcribed_text,
},
)
if self._last_speech_time is not None:
time_since_last_speech = time.perf_counter() - self._last_speech_time
transcription_delay = max(
(self._last_final_transcript_time or 0) - self._last_speech_time, 0
)
eou_metrics = metrics.PipelineEOUMetrics(
timestamp=time.time(),
sequence_id=self._pending_agent_reply.id,
end_of_utterance_delay=time_since_last_speech,
transcription_delay=transcription_delay,
)
self.emit("metrics_collected", eou_metrics)
self._add_speech_for_playout(self._pending_agent_reply)
self._pending_agent_reply = None
self._transcribed_interim_text = ""
# self._transcribed_text is reset after MIN_TIME_PLAYED_FOR_COMMIT, see self._play_speech
def _interrupt_if_possible(self) -> None:
"""Check whether the current assistant speech should be interrupted"""
if self._playing_speech and self._should_interrupt():
self._playing_speech.interrupt()
def _should_interrupt(self) -> bool:
if self._playing_speech is None:
return False
if (
not self._playing_speech.allow_interruptions
or self._playing_speech.interrupted
):
return False
if self._opts.int_min_words != 0:
text = self._transcribed_interim_text or self._transcribed_text
interim_words = self._opts.transcription.word_tokenizer.tokenize(text=text)
if len(interim_words) < self._opts.int_min_words:
return False
return True
def _add_speech_for_playout(self, speech_handle: SpeechHandle) -> None:
self._speech_q.append(speech_handle)
self._speech_q_changed.set()
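# --- Illustrative sketch (not part of the library): the producer/consumer
# pattern used by the speech queue. _add_speech_for_playout appends a handle
# and sets an asyncio.Event; the playout loop in VoicePipelineAgent above
# waits on that event, drains the queue one handle at a time, then clears it.
# The helper below is a standalone toy model of that flow, not library API.
async def _speech_queue_sketch() -> None:
    queue: list[str] = []
    changed = asyncio.Event()

    def add(item: str) -> None:
        queue.append(item)
        changed.set()

    async def playout_loop() -> None:
        while True:
            await changed.wait()
            while queue:
                _ = queue[0]  # "play" the speech here
                queue.pop(0)  # pop only after it has been played
            changed.clear()

    task = asyncio.create_task(playout_loop())
    add("hello")
    await asyncio.sleep(0)  # give the loop a chance to drain the queue
    task.cancel()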
class _DeferredReplyValidation:
"""This class is used to try to find the best time to validate the agent reply."""
# if the STT gives us punctuation, we can try to validate the reply faster.
PUNCTUATION = ".!?"
PUNCTUATION_REDUCE_FACTOR = 0.75
FINAL_TRANSCRIPT_TIMEOUT = 5
def __init__(
self,
validate_fnc: Callable[[], None],
min_endpointing_delay: float,
max_endpointing_delay: float,
turn_detector: _TurnDetector | None,
agent: VoicePipelineAgent,
) -> None:
self._turn_detector = turn_detector
self._validate_fnc = validate_fnc
self._validating_task: asyncio.Task | None = None
self._last_final_transcript: str = ""
self._last_language: str | None = None
self._last_recv_start_of_speech_time: float = 0.0
self._last_recv_end_of_speech_time: float = 0.0
self._last_recv_transcript_time: float = 0.0
self._speaking = False
self._agent = agent
self._end_of_speech_delay = min_endpointing_delay
self._max_endpointing_delay = max_endpointing_delay
@property
def validating(self) -> bool:
return self._validating_task is not None and not self._validating_task.done()
def _compute_delay(self) -> float | None:
"""Computes the amount of time to wait before validating the agent reply.
This function should be called after the agent has received a final transcript, or after the VAD has detected the end of speech.
"""
# never interrupt the user while they are speaking
if self._speaking:
return None
# if the STT doesn't give us a final transcript after the end of speech, we still validate the reply
# to prevent the agent from getting "stuck".
# in this case there is no final transcript, so the user input will be triggered with an empty string
if not self._last_final_transcript:
return self.FINAL_TRANSCRIPT_TIMEOUT
delay = self._end_of_speech_delay
if self._end_with_punctuation():
delay = delay * self.PUNCTUATION_REDUCE_FACTOR
# the delay should be computed from the earlier of the two timestamps, which is the true end of user speech
end_of_speech_time = self._last_recv_end_of_speech_time
if (
self._last_recv_transcript_time > 0
and self._last_recv_transcript_time > self._last_recv_start_of_speech_time
and self._last_recv_transcript_time < end_of_speech_time
):
end_of_speech_time = self._last_recv_transcript_time
elapsed_time = time.perf_counter() - end_of_speech_time
if elapsed_time < delay:
delay -= elapsed_time
else:
delay = 0
return delay
def on_human_final_transcript(self, transcript: str, language: str | None) -> None:
self._last_final_transcript += " " + transcript.strip() # type: ignore
logger.debug(
"last language updated",
extra={"from": self._last_language, "to": language},
)
self._last_language = language
self._last_recv_transcript_time = time.perf_counter()
delay = self._compute_delay()
if delay is not None:
self._run(delay)
def on_human_start_of_speech(self, ev: vad.VADEvent) -> None:
self._speaking = True
self._last_recv_start_of_speech_time = time.perf_counter()
if self.validating:
assert self._validating_task is not None
self._validating_task.cancel()
def on_human_end_of_speech(self, ev: vad.VADEvent) -> None:
self._speaking = False
self._last_recv_end_of_speech_time = time.perf_counter()
delay = self._compute_delay()
if delay is not None:
self._run(delay)
async def aclose(self) -> None:
if self._validating_task is not None:
await utils.aio.gracefully_cancel(self._validating_task)
def _end_with_punctuation(self) -> bool:
return (
len(self._last_final_transcript) > 0
and self._last_final_transcript[-1] in self.PUNCTUATION
)
def _reset_states(self) -> None:
self._last_final_transcript = ""
self._last_recv_end_of_speech_time = 0.0
self._last_recv_transcript_time = 0.0
def _run(self, delay: float) -> None:
@utils.log_exceptions(logger=logger)
async def _run_task(chat_ctx: ChatContext, delay: float) -> None:
use_turn_detector = self._last_final_transcript and not self._speaking
if use_turn_detector and self._turn_detector is not None:
if not self._turn_detector.supports_language(self._last_language):
logger.debug(
"turn detector does not support language",
extra={"language": self._last_language},
)
else:
start_time = time.perf_counter()
try:
eot_prob = await self._turn_detector.predict_end_of_turn(
chat_ctx
)
unlikely_threshold = self._turn_detector.unlikely_threshold(
self._last_language
)
elapsed = time.perf_counter() - start_time
if eot_prob < unlikely_threshold:
delay = self._max_endpointing_delay
delay = max(0, delay - elapsed)
except (TimeoutError, AssertionError):
pass # inference process is unresponsive
await asyncio.sleep(delay)
self._reset_states()
self._validate_fnc()
if self._validating_task is not None:
self._validating_task.cancel()
detect_ctx = self._agent._chat_ctx.copy()
detect_ctx.messages.append(
ChatMessage.create(text=self._agent._transcribed_text, role="user")
)
self._validating_task = asyncio.create_task(_run_task(detect_ctx, delay))
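# --- Illustrative sketch (not part of the library): how the deferred reply
# validation delay behaves, mirroring _DeferredReplyValidation._compute_delay
# under simplifying assumptions (user not speaking, a final transcript exists,
# no turn detector). The helper below is hypothetical, not library API.
def _sketch_endpointing_delay(
    transcript: str,
    min_endpointing_delay: float,
    end_of_speech_time: float,
) -> float:
    delay = min_endpointing_delay
    if transcript and transcript[-1] in _DeferredReplyValidation.PUNCTUATION:
        # trailing punctuation suggests the turn is likely complete, so wait less
        delay *= _DeferredReplyValidation.PUNCTUATION_REDUCE_FACTOR
    elapsed = time.perf_counter() - end_of_speech_time
    # only wait for whatever portion of the delay has not already elapsed
    return max(0.0, delay - elapsed)


# e.g. with min_endpointing_delay=0.5 and a transcript ending in "?", the agent
# would wait roughly 0.5 * 0.75 = 0.375s (minus time already elapsed) before
# validating the pending reply.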
import asyncio
import contextlib
import io
import multiprocessing as mp
import selectors
import socket
import time
from dataclasses import dataclass
from typing import ClassVar, Literal, Tuple
from .. import utils
from ..ipc import channel
PlotType = Literal["vad_probability", "raw_vol", "smoothed_vol"]
EventType = Literal[
"user_started_speaking",
"user_stopped_speaking",
"agent_started_speaking",
"agent_stopped_speaking",
]
@dataclass
class PlotMessage:
MSG_ID: ClassVar[int] = 1
which: PlotType = "vad_probability"
x: float = 0.0
y: float = 0.0
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.which)
channel.write_float(b, self.x)
channel.write_float(b, self.y)
def read(self, b: io.BytesIO) -> None:
self.which = channel.read_string(b) # type: ignore
self.x = channel.read_float(b)
self.y = channel.read_float(b)
@dataclass
class PlotEventMessage:
MSG_ID: ClassVar[int] = 2
which: EventType = "user_started_speaking"
x: float = 0.0
def write(self, b: io.BytesIO) -> None:
channel.write_string(b, self.which)
channel.write_float(b, self.x)
def read(self, b: io.BytesIO) -> None:
self.which = channel.read_string(b) # type: ignore
self.x = channel.read_float(b)
PLT_MESSAGES: dict = {
PlotMessage.MSG_ID: PlotMessage,
PlotEventMessage.MSG_ID: PlotEventMessage,
}
def _draw_plot(mp_cch):
try:
import matplotlib as mpl # type: ignore
import matplotlib.pyplot as plt # type: ignore
except ImportError:
raise ImportError(
"matplotlib is required to run use the VoiceAssistant plotter"
)
plt.style.use("ggplot")
mpl.rcParams["toolbar"] = "None"
plot_data: dict[str, Tuple[list[float], list[float]]] = {}
plot_events: dict[str, list[float]] = {}
fig, (pv, sp) = plt.subplots(2, sharex="all")
fig.canvas.manager.set_window_title("Voice Assistant") # type: ignore
max_points = 250
duplex = utils.aio.duplex_unix._Duplex.open(mp_cch)
selector = selectors.DefaultSelector()
selector.register(mp_cch, selectors.EVENT_READ)
def _draw_cb(sp, pv):
while True:
events = selector.select(timeout=0.01)
if not events:
break
msg = channel.recv_message(duplex, PLT_MESSAGES)
if isinstance(msg, PlotMessage):
data = plot_data.setdefault(msg.which, ([], []))
data[0].append(msg.x)
data[1].append(msg.y)
data[0][:] = data[0][-max_points:]
data[1][:] = data[1][-max_points:]
# remove events older than 7.5s
for events in plot_events.values():
while events and events[0] < msg.x - 7.5:
events.pop(0)
elif isinstance(msg, PlotEventMessage):
events = plot_events.setdefault(msg.which, [])
events.append(msg.x)
vad_raw = plot_data.setdefault("vad_probability", ([], []))
raw_vol = plot_data.get("raw_vol", ([], []))
vol = plot_data.get("smoothed_vol", ([], []))
pv.clear()
pv.set_ylim(0, 1)
pv.set(ylabel="assistant volume")
pv.plot(vol[0], vol[1], label="volume")
pv.plot(raw_vol[0], raw_vol[1], label="target_volume")
pv.legend()
sp.clear()
sp.set_ylim(0, 1)
sp.set(xlabel="time (s)", ylabel="speech probability")
sp.plot(vad_raw[0], vad_raw[1], label="raw")
sp.legend()
for start in plot_events.get("agent_started_speaking", []):
pv.axvline(x=start, color="r", linestyle="--")
for stop in plot_events.get("agent_stopped_speaking", []):
pv.axvline(x=stop, color="r", linestyle="--")
for start in plot_events.get("user_started_speaking", []):
sp.axvline(x=start, color="r", linestyle="--")
for stop in plot_events.get("user_stopped_speaking", []):
sp.axvline(x=stop, color="r", linestyle="--")
fig.canvas.draw()
timer = fig.canvas.new_timer(interval=33)
timer.add_callback(_draw_cb, sp, pv)
timer.start()
plt.show()
class AssistantPlotter:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._started = False
async def start(self):
if self._started:
return
mp_pch, mp_cch = socket.socketpair()
self._duplex = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch)
self._plot_proc = mp.Process(target=_draw_plot, args=(mp_cch,), daemon=True)
self._plot_proc.start()
mp_cch.close()
self._started = True
self._closed = False
self._start_time = time.time()
def plot_value(self, which: PlotType, y: float):
if not self._started:
return
ts = time.time() - self._start_time
self._send_message(PlotMessage(which=which, x=ts, y=y))
def plot_event(self, which: EventType):
if not self._started:
return
ts = time.time() - self._start_time
self._send_message(PlotEventMessage(which=which, x=ts))
def _send_message(self, msg: channel.Message) -> None:
if self._closed:
return
async def _asend_message():
try:
await channel.asend_message(self._duplex, msg)
except Exception:
self._closed = True
asyncio.ensure_future(_asend_message())
async def terminate(self):
if not self._started:
return
self._plot_proc.terminate()
with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed):
await self._duplex.aclose()
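# --- Illustrative sketch (not part of the library): round-tripping a
# PlotMessage through an in-memory buffer using the write()/read() methods
# defined above. It only relies on read() consuming fields in the same order
# write() produced them; the exact wire format belongs to the ipc channel
# module. 1.25 and 0.5 are exactly representable in both 32- and 64-bit
# floats, so they survive the round trip regardless of the encoding width.
def _plot_message_roundtrip_sketch() -> None:
    original = PlotMessage(which="smoothed_vol", x=1.25, y=0.5)
    buf = io.BytesIO()
    original.write(buf)
    buf.seek(0)
    decoded = PlotMessage()
    decoded.read(buf)
    assert (decoded.which, decoded.x, decoded.y) == ("smoothed_vol", 1.25, 0.5)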
from __future__ import annotations
import asyncio
from typing import AsyncIterable
from .. import utils
from ..llm import ChatMessage, LLMStream
from .agent_output import SynthesisHandle
class SpeechHandle:
def __init__(
self,
*,
id: str,
allow_interruptions: bool,
add_to_chat_ctx: bool,
is_reply: bool,
user_question: str,
fnc_nested_depth: int = 0,
extra_tools_messages: list[ChatMessage] | None = None,
fnc_text_message_id: str | None = None,
) -> None:
self._id = id
self._allow_interruptions = allow_interruptions
self._add_to_chat_ctx = add_to_chat_ctx
# is_reply is True when the speech is answering a user question
self._is_reply = is_reply
self._user_question = user_question
self._user_committed = False
self._init_fut = asyncio.Future[None]()
self._done_fut = asyncio.Future[None]()
self._initialized = False
self._speech_committed = False # speech committed (interrupted or not)
# source and synthesis_handle are None until the speech is initialized
self._source: str | LLMStream | AsyncIterable[str] | None = None
self._synthesis_handle: SynthesisHandle | None = None
# nested speech handle and function calls
self._fnc_nested_depth = fnc_nested_depth
self._fnc_extra_tools_messages: list[ChatMessage] | None = extra_tools_messages
self._fnc_text_message_id: str | None = fnc_text_message_id
self._nested_speech_handles: list[SpeechHandle] = []
self._nested_speech_changed = asyncio.Event()
self._nested_speech_done_fut: asyncio.Future[None] | None = None
@staticmethod
def create_assistant_reply(
*,
allow_interruptions: bool,
add_to_chat_ctx: bool,
user_question: str,
) -> SpeechHandle:
return SpeechHandle(
id=utils.shortuuid(),
allow_interruptions=allow_interruptions,
add_to_chat_ctx=add_to_chat_ctx,
is_reply=True,
user_question=user_question,
)
@staticmethod
def create_assistant_speech(
*,
allow_interruptions: bool,
add_to_chat_ctx: bool,
) -> SpeechHandle:
return SpeechHandle(
id=utils.shortuuid(),
allow_interruptions=allow_interruptions,
add_to_chat_ctx=add_to_chat_ctx,
is_reply=False,
user_question="",
)
@staticmethod
def create_tool_speech(
*,
allow_interruptions: bool,
add_to_chat_ctx: bool,
fnc_nested_depth: int,
extra_tools_messages: list[ChatMessage],
fnc_text_message_id: str | None = None,
) -> SpeechHandle:
return SpeechHandle(
id=utils.shortuuid(),
allow_interruptions=allow_interruptions,
add_to_chat_ctx=add_to_chat_ctx,
is_reply=False,
user_question="",
fnc_nested_depth=fnc_nested_depth,
extra_tools_messages=extra_tools_messages,
fnc_text_message_id=fnc_text_message_id,
)
async def wait_for_initialization(self) -> None:
await asyncio.shield(self._init_fut)
def initialize(
self,
*,
source: str | LLMStream | AsyncIterable[str],
synthesis_handle: SynthesisHandle,
) -> None:
if self.interrupted:
raise RuntimeError("speech is interrupted")
self._source = source
self._synthesis_handle = synthesis_handle
self._initialized = True
self._init_fut.set_result(None)
def mark_user_committed(self) -> None:
self._user_committed = True
def mark_speech_committed(self) -> None:
self._speech_committed = True
@property
def user_committed(self) -> bool:
return self._user_committed
@property
def speech_committed(self) -> bool:
return self._speech_committed
@property
def id(self) -> str:
return self._id
@property
def allow_interruptions(self) -> bool:
return self._allow_interruptions
@property
def add_to_chat_ctx(self) -> bool:
return self._add_to_chat_ctx
@property
def source(self) -> str | LLMStream | AsyncIterable[str]:
if self._source is None:
raise RuntimeError("speech not initialized")
return self._source
@property
def synthesis_handle(self) -> SynthesisHandle:
if self._synthesis_handle is None:
raise RuntimeError("speech not initialized")
return self._synthesis_handle
@synthesis_handle.setter
def synthesis_handle(self, synthesis_handle: SynthesisHandle) -> None:
"""synthesis handle can be replaced for the same speech.
This is useful when we need to do a new generation. (e.g for automatic function call answers)"""
if self._synthesis_handle is None:
raise RuntimeError("speech not initialized")
self._synthesis_handle = synthesis_handle
@property
def initialized(self) -> bool:
return self._initialized
@property
def is_reply(self) -> bool:
return self._is_reply
@property
def user_question(self) -> str:
return self._user_question
@property
def interrupted(self) -> bool:
return self._init_fut.cancelled() or (
self._synthesis_handle is not None and self._synthesis_handle.interrupted
)
def join(self) -> asyncio.Future:
return self._done_fut
def _set_done(self) -> None:
self._done_fut.set_result(None)
def interrupt(self) -> None:
if not self.allow_interruptions:
raise RuntimeError("interruptions are not allowed")
self.cancel()
def cancel(self, cancel_nested: bool = False) -> None:
self._init_fut.cancel()
if self._synthesis_handle is not None:
self._synthesis_handle.interrupt()
if cancel_nested:
for speech in self._nested_speech_handles:
speech.cancel(cancel_nested=True)
self.mark_nested_speech_done()
@property
def fnc_nested_depth(self) -> int:
return self._fnc_nested_depth
@property
def extra_tools_messages(self) -> list[ChatMessage] | None:
return self._fnc_extra_tools_messages
@property
def fnc_text_message_id(self) -> str | None:
return self._fnc_text_message_id
def add_nested_speech(self, speech_handle: SpeechHandle) -> None:
self._nested_speech_handles.append(speech_handle)
self._nested_speech_changed.set()
@property
def nested_speech_handles(self) -> list[SpeechHandle]:
return self._nested_speech_handles
@property
def nested_speech_changed(self) -> asyncio.Event:
return self._nested_speech_changed
@property
def nested_speech_done(self) -> bool:
# True if not started or done
return (
self._nested_speech_done_fut is None or self._nested_speech_done_fut.done()
)
def mark_nested_speech_done(self) -> None:
if self._nested_speech_done_fut is None or self._nested_speech_done_fut.done():
return
self._nested_speech_done_fut.set_result(None)
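# --- Illustrative sketch (not part of the library): the SpeechHandle
# lifecycle as driven by the pipeline agent. A real SynthesisHandle comes
# from the agent output synthesizer, so it is only referenced in comments.
def _speech_handle_lifecycle_sketch() -> None:
    handle = SpeechHandle.create_assistant_reply(
        allow_interruptions=True,
        add_to_chat_ctx=True,
        user_question="what's the weather like?",
    )
    assert handle.is_reply and not handle.initialized
    # the agent later calls handle.initialize(source=llm_stream,
    # synthesis_handle=synthesis_handle) once synthesis starts, marks the user
    # question and agent speech as committed during playout, and resolves
    # handle.join() when _set_done() is called at the end of playout.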
from __future__ import annotations
import logging
import threading
from abc import ABC
from typing import List, Literal
from . import utils
EventTypes = Literal["plugin_registered",]
class Plugin(ABC):
registered_plugins: List["Plugin"] = []
emitter: utils.EventEmitter[EventTypes] = utils.EventEmitter()
# TODO(theomonnom): make logger mandatory once all plugins have been updated
def __init__(
self,
title: str,
version: str,
package: str,
logger: logging.Logger | None = None,
) -> None:
self._title = title
self._version = version
self._package = package
self._logger = logger
@classmethod
def register_plugin(cls, plugin: "Plugin") -> None:
if threading.current_thread() != threading.main_thread():
raise RuntimeError("Plugins must be registered on the main thread")
cls.registered_plugins.append(plugin)
cls.emitter.emit("plugin_registered", plugin)
# plugin can implement an optional download_files method
def download_files(self) -> None: ...
@property
def package(self) -> str:
return self._package
@property
def title(self) -> str:
return self._title
@property
def version(self) -> str:
return self._version
@property
def logger(self) -> logging.Logger | None:
return self._logger
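# --- Illustrative sketch (not part of the library): a minimal Plugin
# subclass and how it would be registered. The title/version/package below
# are hypothetical placeholders.
class _ExamplePlugin(Plugin):
    def __init__(self) -> None:
        super().__init__(
            title="example",
            version="0.0.1",
            package="livekit-plugins-example",
            logger=logging.getLogger("livekit.plugins.example"),
        )

    def download_files(self) -> None:
        # optional hook: pre-download model weights or other assets
        pass


# Plugin.register_plugin(_ExamplePlugin())  # must be called on the main thread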
from .fallback_adapter import AvailabilityChangedEvent, FallbackAdapter
from .stream_adapter import StreamAdapter, StreamAdapterWrapper
from .stt import (
STT,
RecognitionUsage,
RecognizeStream,
SpeechData,
SpeechEvent,
SpeechEventType,
SpeechStream,
STTCapabilities,
)
__all__ = [
"SpeechEventType",
"SpeechEvent",
"SpeechData",
"RecognizeStream",
"SpeechStream",
"STT",
"STTCapabilities",
"StreamAdapter",
"StreamAdapterWrapper",
"RecognitionUsage",
"FallbackAdapter",
"AvailabilityChangedEvent",
]
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import time
from dataclasses import dataclass
from typing import Literal
from livekit import rtc
from livekit.agents.utils.audio import AudioBuffer
from .. import utils
from .._exceptions import APIConnectionError, APIError
from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ..utils import aio
from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities
# don't retry when using the fallback adapter
DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
@dataclass
class AvailabilityChangedEvent:
stt: STT
available: bool
@dataclass
class _STTStatus:
available: bool
recovering_synthesize_task: asyncio.Task | None
recovering_stream_task: asyncio.Task | None
class FallbackAdapter(
STT[Literal["stt_availability_changed"]],
):
def __init__(
self,
stt: list[STT],
*,
attempt_timeout: float = 10.0,
max_retry_per_stt: int = 1,
retry_interval: float = 5,
) -> None:
if len(stt) < 1:
raise ValueError("At least one STT instance must be provided.")
non_streaming_stt = [t for t in stt if not t.capabilities.streaming]
if non_streaming_stt:
labels = ", ".join(t.label for t in non_streaming_stt)
raise ValueError(
f"STTs do not support streaming: {labels}. "
"Wrap them with stt.StreamAdapter to enable streaming."
)
super().__init__(
capabilities=STTCapabilities(
streaming=True,
interim_results=all(t.capabilities.interim_results for t in stt),
)
)
self._stt_instances = stt
self._attempt_timeout = attempt_timeout
self._max_retry_per_stt = max_retry_per_stt
self._retry_interval = retry_interval
self._status: list[_STTStatus] = [
_STTStatus(
available=True,
recovering_synthesize_task=None,
recovering_stream_task=None,
)
for _ in self._stt_instances
]
async def _try_recognize(
self,
*,
stt: STT,
buffer: utils.AudioBuffer,
language: str | None = None,
conn_options: APIConnectOptions,
recovering: bool = False,
) -> SpeechEvent:
try:
return await stt.recognize(
buffer,
language=language,
conn_options=dataclasses.replace(
conn_options,
max_retry=self._max_retry_per_stt,
timeout=self._attempt_timeout,
retry_interval=self._retry_interval,
),
)
except asyncio.TimeoutError:
if recovering:
logger.warning(
f"{stt.label} recovery timed out", extra={"streamed": False}
)
raise
logger.warning(
f"{stt.label} timed out, switching to next STT",
extra={"streamed": False},
)
raise
except APIError as e:
if recovering:
logger.warning(
f"{stt.label} recovery failed",
exc_info=e,
extra={"streamed": False},
)
raise
logger.warning(
f"{stt.label} failed, switching to next STT",
exc_info=e,
extra={"streamed": False},
)
raise
except Exception:
if recovering:
logger.exception(
f"{stt.label} recovery unexpected error", extra={"streamed": False}
)
raise
logger.exception(
f"{stt.label} unexpected error, switching to next STT",
extra={"streamed": False},
)
raise
def _try_recovery(
self,
*,
stt: STT,
buffer: utils.AudioBuffer,
language: str | None,
conn_options: APIConnectOptions,
) -> None:
stt_status = self._status[self._stt_instances.index(stt)]
if (
stt_status.recovering_synthesize_task is None
or stt_status.recovering_synthesize_task.done()
):
async def _recover_stt_task(stt: STT) -> None:
try:
await self._try_recognize(
stt=stt,
buffer=buffer,
language=language,
conn_options=conn_options,
recovering=True,
)
stt_status.available = True
logger.info(f"{stt.label} recovered")
self.emit(
"stt_availability_changed",
AvailabilityChangedEvent(stt=stt, available=True),
)
except Exception:
return
stt_status.recovering_synthesize_task = asyncio.create_task(
_recover_stt_task(stt)
)
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
):
start_time = time.time()
all_failed = all(not stt_status.available for stt_status in self._status)
if all_failed:
logger.error("all STTs are unavailable, retrying..")
for i, stt in enumerate(self._stt_instances):
stt_status = self._status[i]
if stt_status.available or all_failed:
try:
return await self._try_recognize(
stt=stt,
buffer=buffer,
language=language,
conn_options=conn_options,
recovering=False,
)
except Exception: # exceptions already logged inside _try_recognize
if stt_status.available:
stt_status.available = False
self.emit(
"stt_availability_changed",
AvailabilityChangedEvent(stt=stt, available=False),
)
self._try_recovery(
stt=stt, buffer=buffer, language=language, conn_options=conn_options
)
raise APIConnectionError(
"all STTs failed (%s) after %s seconds"
% (
[stt.label for stt in self._stt_instances],
time.time() - start_time,
)
)
async def recognize(
self,
buffer: AudioBuffer,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
) -> SpeechEvent:
return await super().recognize(
buffer, language=language, conn_options=conn_options
)
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
) -> RecognizeStream:
return FallbackRecognizeStream(
stt=self, language=language, conn_options=conn_options
)
async def aclose(self) -> None:
for stt_status in self._status:
if stt_status.recovering_synthesize_task is not None:
await aio.gracefully_cancel(stt_status.recovering_synthesize_task)
if stt_status.recovering_stream_task is not None:
await aio.gracefully_cancel(stt_status.recovering_stream_task)
class FallbackRecognizeStream(RecognizeStream):
def __init__(
self,
*,
stt: FallbackAdapter,
language: str | None,
conn_options: APIConnectOptions,
):
super().__init__(stt=stt, conn_options=conn_options, sample_rate=None)
self._language = language
self._fallback_adapter = stt
self._recovering_streams: list[RecognizeStream] = []
async def _run(self) -> None:
start_time = time.time()
all_failed = all(
not stt_status.available for stt_status in self._fallback_adapter._status
)
if all_failed:
logger.error("all STTs are unavailable, retrying..")
main_stream: RecognizeStream | None = None
forward_input_task: asyncio.Task | None = None
async def _forward_input_task() -> None:
with contextlib.suppress(RuntimeError): # stream might be closed
async for data in self._input_ch:
for stream in self._recovering_streams:
if isinstance(data, rtc.AudioFrame):
stream.push_frame(data)
elif isinstance(data, self._FlushSentinel):
stream.flush()
if main_stream is not None:
if isinstance(data, rtc.AudioFrame):
main_stream.push_frame(data)
elif isinstance(data, self._FlushSentinel):
main_stream.flush()
if main_stream is not None:
main_stream.end_input()
for i, stt in enumerate(self._fallback_adapter._stt_instances):
stt_status = self._fallback_adapter._status[i]
if stt_status.available or all_failed:
try:
main_stream = stt.stream(
language=self._language,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=self._fallback_adapter._max_retry_per_stt,
timeout=self._fallback_adapter._attempt_timeout,
retry_interval=self._fallback_adapter._retry_interval,
),
)
if forward_input_task is None or forward_input_task.done():
forward_input_task = asyncio.create_task(_forward_input_task())
try:
async with main_stream:
async for ev in main_stream:
self._event_ch.send_nowait(ev)
except asyncio.TimeoutError:
logger.warning(
f"{stt.label} timed out, switching to next STT",
extra={"streamed": True},
)
raise
except APIError as e:
logger.warning(
f"{stt.label} failed, switching to next STT",
exc_info=e,
extra={"streamed": True},
)
raise
except Exception:
logger.exception(
f"{stt.label} unexpected error, switching to next STT",
extra={"streamed": True},
)
raise
return
except Exception:
if stt_status.available:
stt_status.available = False
self._stt.emit(
"stt_availability_changed",
AvailabilityChangedEvent(stt=stt, available=False),
)
self._try_recovery(stt)
if forward_input_task is not None:
await aio.gracefully_cancel(forward_input_task)
await asyncio.gather(*[stream.aclose() for stream in self._recovering_streams])
raise APIConnectionError(
"all STTs failed (%s) after %s seconds"
% (
[stt.label for stt in self._fallback_adapter._stt_instances],
time.time() - start_time,
)
)
def _try_recovery(self, stt: STT) -> None:
stt_status = self._fallback_adapter._status[
self._fallback_adapter._stt_instances.index(stt)
]
if (
stt_status.recovering_stream_task is None
or stt_status.recovering_stream_task.done()
):
stream = stt.stream(
language=self._language,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=0,
timeout=self._fallback_adapter._attempt_timeout,
),
)
self._recovering_streams.append(stream)
async def _recover_stt_task() -> None:
try:
nb_transcript = 0
async with stream:
async for ev in stream:
if ev.type == SpeechEventType.FINAL_TRANSCRIPT:
if not ev.alternatives or not ev.alternatives[0].text:
continue
nb_transcript += 1
break
if nb_transcript == 0:
return
stt_status.available = True
logger.info(f"tts.FallbackAdapter, {stt.label} recovered")
self._fallback_adapter.emit(
"stt_availability_changed",
AvailabilityChangedEvent(stt=stt, available=True),
)
except asyncio.TimeoutError:
logger.warning(
f"{stream._stt.label} recovery timed out",
extra={"streamed": True},
)
except APIError as e:
logger.warning(
f"{stream._stt.label} recovery failed",
exc_info=e,
extra={"streamed": True},
)
except Exception:
logger.exception(
f"{stream._stt.label} recovery unexpected error",
extra={"streamed": True},
)
raise
stt_status.recovering_stream_task = task = asyncio.create_task(
_recover_stt_task()
)
task.add_done_callback(lambda _: self._recovering_streams.remove(stream))
from __future__ import annotations
import asyncio
from typing import AsyncIterable
from .. import utils
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ..vad import VAD, VADEventType
from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities
class StreamAdapter(STT):
def __init__(self, *, stt: STT, vad: VAD) -> None:
super().__init__(
capabilities=STTCapabilities(streaming=True, interim_results=False)
)
self._vad = vad
self._stt = stt
@self._stt.on("metrics_collected")
def _forward_metrics(*args, **kwargs):
self.emit("metrics_collected", *args, **kwargs)
@property
def wrapped_stt(self) -> STT:
return self._stt
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
):
return await self._stt.recognize(
buffer=buffer, language=language, conn_options=conn_options
)
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> RecognizeStream:
return StreamAdapterWrapper(
self,
vad=self._vad,
wrapped_stt=self._stt,
language=language,
conn_options=conn_options,
)
class StreamAdapterWrapper(RecognizeStream):
def __init__(
self,
stt: STT,
*,
vad: VAD,
wrapped_stt: STT,
language: str | None,
conn_options: APIConnectOptions,
) -> None:
super().__init__(stt=stt, conn_options=conn_options)
self._vad = vad
self._wrapped_stt = wrapped_stt
self._vad_stream = self._vad.stream()
self._language = language
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SpeechEvent]
) -> None:
pass # do nothing
async def _run(self) -> None:
async def _forward_input():
"""forward input to vad"""
async for input in self._input_ch:
if isinstance(input, self._FlushSentinel):
self._vad_stream.flush()
continue
self._vad_stream.push_frame(input)
self._vad_stream.end_input()
async def _recognize():
"""recognize speech from vad"""
async for event in self._vad_stream:
if event.type == VADEventType.START_OF_SPEECH:
self._event_ch.send_nowait(
SpeechEvent(SpeechEventType.START_OF_SPEECH)
)
elif event.type == VADEventType.END_OF_SPEECH:
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.END_OF_SPEECH,
)
)
merged_frames = utils.merge_frames(event.frames)
t_event = await self._wrapped_stt.recognize(
buffer=merged_frames,
language=self._language,
conn_options=self._conn_options,
)
if len(t_event.alternatives) == 0:
continue
elif not t_event.alternatives[0].text:
continue
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[t_event.alternatives[0]],
)
)
tasks = [
asyncio.create_task(_forward_input(), name="forward_input"),
asyncio.create_task(_recognize(), name="recognize"),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
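# --- Illustrative sketch (not part of the library): composing StreamAdapter
# and FallbackAdapter. It assumes FallbackAdapter is imported from
# .fallback_adapter, and that `primary`, `backup` and `vad` are real instances
# (e.g. built from livekit-plugins-*); nothing here constructs them.
def _build_fallback_stt_sketch(primary: STT, backup: STT, vad: VAD):
    # FallbackAdapter requires every STT to support streaming, so any
    # non-streaming recognizer is wrapped with StreamAdapter first.
    if not primary.capabilities.streaming:
        primary = StreamAdapter(stt=primary, vad=vad)
    if not backup.capabilities.streaming:
        backup = StreamAdapter(stt=backup, vad=vad)
    # STTs are tried in order; an "stt_availability_changed" event is emitted
    # whenever one of them goes down or recovers.
    return FallbackAdapter([primary, backup])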
from __future__ import annotations
import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, unique
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Generic, List, Literal, TypeVar, Union
from livekit import rtc
from .._exceptions import APIConnectionError, APIError
from ..log import logger
from ..metrics import STTMetrics
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ..utils import AudioBuffer, aio
from ..utils.audio import calculate_audio_duration
@unique
class SpeechEventType(str, Enum):
START_OF_SPEECH = "start_of_speech"
"""indicate the start of speech
if the STT doesn't support this event, it will be emitted at the same time as the first INTERIM_TRANSCRIPT"""
INTERIM_TRANSCRIPT = "interim_transcript"
"""interim transcript, useful for real-time transcription"""
FINAL_TRANSCRIPT = "final_transcript"
"""final transcript, emitted when the STT is confident enough that a certain
portion of speech will not change"""
RECOGNITION_USAGE = "recognition_usage"
"""usage event, emitted periodically to indicate usage metrics"""
END_OF_SPEECH = "end_of_speech"
"""indicate the end of speech, emitted when the user stops speaking"""
@dataclass
class SpeechData:
language: str
text: str
start_time: float = 0.0
end_time: float = 0.0
confidence: float = 0.0 # [0, 1]
@dataclass
class RecognitionUsage:
audio_duration: float
@dataclass
class SpeechEvent:
type: SpeechEventType
request_id: str = ""
alternatives: List[SpeechData] = field(default_factory=list)
recognition_usage: RecognitionUsage | None = None
@dataclass
class STTCapabilities:
streaming: bool
interim_results: bool
TEvent = TypeVar("TEvent")
class STT(
ABC,
rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
Generic[TEvent],
):
def __init__(self, *, capabilities: STTCapabilities) -> None:
super().__init__()
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"
@property
def label(self) -> str:
return self._label
@property
def capabilities(self) -> STTCapabilities:
return self._capabilities
@abstractmethod
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> SpeechEvent: ...
async def recognize(
self,
buffer: AudioBuffer,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> SpeechEvent:
for i in range(conn_options.max_retry + 1):
try:
start_time = time.perf_counter()
event = await self._recognize_impl(
buffer, language=language, conn_options=conn_options
)
duration = time.perf_counter() - start_time
stt_metrics = STTMetrics(
request_id=event.request_id,
timestamp=time.time(),
duration=duration,
label=self._label,
audio_duration=calculate_audio_duration(buffer),
streamed=False,
error=None,
)
self.emit("metrics_collected", stt_metrics)
return event
except APIError as e:
retry_interval = conn_options._interval_for_retry(i)
if conn_options.max_retry == 0:
raise
elif i == conn_options.max_retry:
raise APIConnectionError(
f"failed to recognize speech after {conn_options.max_retry + 1} attempts",
) from e
else:
logger.warning(
f"failed to recognize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._label,
"attempt": i + 1,
"streamed": False,
},
)
await asyncio.sleep(retry_interval)
raise RuntimeError("unreachable")
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "RecognizeStream":
raise NotImplementedError(
"streaming is not supported by this STT, please use a different STT or use a StreamAdapter"
)
async def aclose(self) -> None:
"""Close the STT, and every stream/requests associated with it"""
...
async def __aenter__(self) -> STT:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
class RecognizeStream(ABC):
class _FlushSentinel:
"""Sentinel to mark when it was flushed"""
pass
def __init__(
self,
*,
stt: STT,
conn_options: APIConnectOptions,
sample_rate: int | None = None,
):
"""
Args:
sample_rate : int or None, optional
The desired sample rate for the audio input.
If specified, the audio input will be automatically resampled to match
the given sample rate before being processed for Speech-to-Text.
If not provided (None), the input will retain its original sample rate.
"""
self._stt = stt
self._conn_options = conn_options
self._input_ch = aio.Chan[
Union[rtc.AudioFrame, RecognizeStream._FlushSentinel]
]()
self._event_ch = aio.Chan[SpeechEvent]()
self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(monitor_aiter), name="STT._metrics_task"
)
self._task = asyncio.create_task(self._main_task())
self._task.add_done_callback(lambda _: self._event_ch.close())
self._needed_sr = sample_rate
self._pushed_sr = 0
self._resampler: rtc.AudioResampler | None = None
@abstractmethod
async def _run(self) -> None: ...
async def _main_task(self) -> None:
max_retries = self._conn_options.max_retry
num_retries = 0
while num_retries <= max_retries:
try:
return await self._run()
except APIError as e:
if max_retries == 0:
raise
elif num_retries == max_retries:
raise APIConnectionError(
f"failed to recognize speech after {num_retries} attempts",
) from e
else:
retry_interval = self._conn_options._interval_for_retry(num_retries)
logger.warning(
f"failed to recognize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._stt._label,
"attempt": num_retries,
"streamed": True,
},
)
await asyncio.sleep(retry_interval)
num_retries += 1
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SpeechEvent]
) -> None:
"""Task used to collect metrics"""
start_time = time.perf_counter()
async for ev in event_aiter:
if ev.type == SpeechEventType.RECOGNITION_USAGE:
assert ev.recognition_usage is not None, (
"recognition_usage must be provided for RECOGNITION_USAGE event"
)
duration = time.perf_counter() - start_time
stt_metrics = STTMetrics(
request_id=ev.request_id,
timestamp=time.time(),
duration=duration,
label=self._stt._label,
audio_duration=ev.recognition_usage.audio_duration,
streamed=True,
error=None,
)
self._stt.emit("metrics_collected", stt_metrics)
def push_frame(self, frame: rtc.AudioFrame) -> None:
"""Push audio to be recognized"""
self._check_input_not_ended()
self._check_not_closed()
if self._pushed_sr and self._pushed_sr != frame.sample_rate:
raise ValueError("the sample rate of the input frames must be consistent")
self._pushed_sr = frame.sample_rate
if self._needed_sr and self._needed_sr != frame.sample_rate:
if not self._resampler:
self._resampler = rtc.AudioResampler(
frame.sample_rate,
self._needed_sr,
quality=rtc.AudioResamplerQuality.HIGH,
)
if self._resampler:
for frame in self._resampler.push(frame):
self._input_ch.send_nowait(frame)
else:
self._input_ch.send_nowait(frame)
def flush(self) -> None:
"""Mark the end of the current segment"""
self._check_input_not_ended()
self._check_not_closed()
if self._resampler:
for frame in self._resampler.flush():
self._input_ch.send_nowait(frame)
self._input_ch.send_nowait(self._FlushSentinel())
def end_input(self) -> None:
"""Mark the end of input, no more audio will be pushed"""
self.flush()
self._input_ch.close()
async def aclose(self) -> None:
"""Close ths stream immediately"""
self._input_ch.close()
await aio.gracefully_cancel(self._task)
if self._metrics_task is not None:
await self._metrics_task
async def __anext__(self) -> SpeechEvent:
try:
val = await self._event_aiter.__anext__()
except StopAsyncIteration:
if not self._task.cancelled() and (exc := self._task.exception()):
raise exc from None
raise StopAsyncIteration
return val
def __aiter__(self) -> AsyncIterator[SpeechEvent]:
return self
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
def _check_input_not_ended(self) -> None:
if self._input_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended")
async def __aenter__(self) -> RecognizeStream:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
SpeechStream = RecognizeStream # deprecated alias
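# --- Illustrative sketch (not part of the library): the minimal surface a
# concrete STT implementation must provide. This recognizer returns a canned
# transcript instead of calling a real provider, and is non-streaming, so in
# a pipeline it would be wrapped with stt.StreamAdapter.
class _CannedSTT(STT):
    def __init__(self) -> None:
        super().__init__(
            capabilities=STTCapabilities(streaming=False, interim_results=False)
        )

    async def _recognize_impl(
        self,
        buffer: AudioBuffer,
        *,
        language: str | None,
        conn_options: APIConnectOptions,
    ) -> SpeechEvent:
        # a real implementation would send `buffer` to a provider here
        return SpeechEvent(
            type=SpeechEventType.FINAL_TRANSCRIPT,
            alternatives=[SpeechData(language=language or "en", text="hello")],
        )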
from . import basic, utils
from .token_stream import (
BufferedSentenceStream,
BufferedWordStream,
)
from .tokenizer import (
SentenceStream,
SentenceTokenizer,
TokenData,
WordStream,
WordTokenizer,
)
__all__ = [
"SentenceTokenizer",
"SentenceStream",
"WordTokenizer",
"WordStream",
"TokenData",
"BufferedSentenceStream",
"BufferedWordStream",
"basic",
"utils",
]
from __future__ import annotations
import re
# Frank Liang's hyphenation algorithm; implementation from https://github.com/jfinkels/hyphenate
# This is English only, which is a good default.
# Users who want other languages or more advanced hyphenation should use the livekit-plugins-* packages
class Hyphenator:
def __init__(self, patterns, exceptions=""):
self.tree = {}
for pattern in patterns.split():
self._insert_pattern(pattern)
self.exceptions = {}
for ex in exceptions.split():
# Convert the hyphenated pattern into a point array for use later.
points = [0] + [int(h == "-") for h in re.split(r"[a-z]", ex)]
self.exceptions[ex.replace("-", "")] = points
def _insert_pattern(self, pattern):
# Convert a pattern like 'a1bc3d4' into a string of chars 'abcd'
# and a list of points [ 0, 1, 0, 3, 4 ].
chars = re.sub("[0-9]", "", pattern)
points = [int(d or 0) for d in re.split("[.a-z]", pattern)]
# Insert the pattern into the tree. Each character finds a dict
# another level down in the tree, and leaf nodes have the list of
# points.
t = self.tree
for c in chars:
if c not in t:
t[c] = {}
t = t[c]
t[None] = points
def hyphenate_word(self, word: str) -> list[str]:
"""Given a word, returns a list of pieces, broken at the possible
hyphenation points.
"""
# Short words aren't hyphenated.
if len(word) <= 4:
return [word]
# If the word is an exception, get the stored points.
if word.lower() in self.exceptions:
points = self.exceptions[word.lower()]
else:
work = "." + word.lower() + "."
points = [0] * (len(work) + 1)
for i in range(len(work)):
t = self.tree
for c in work[i:]:
if c in t:
t = t[c]
if None in t:
p = t[None]
for j, p_j in enumerate(p):
points[i + j] = max(points[i + j], p_j)
else:
break
# No hyphens in the first two chars or the last two.
points[1] = points[2] = points[-2] = points[-3] = 0
# Examine the points to build the pieces list.
pieces = [""]
for c, p in zip(word, points[2:]):
pieces[-1] += c
if p % 2:
pieces.append("")
return pieces
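# --- Illustrative sketch (not part of the library): splitting a word at its
# possible hyphenation points with the Knuth/Liang patterns defined below.
# Kept as a function so it is only evaluated after PATTERNS exists.
def _hyphenation_example() -> list[str]:
    hyphenator = Hyphenator(PATTERNS)
    # for "hyphenation" this typically yields pieces such as
    # ["hy", "phen", "ation"]; "".join(pieces) reconstructs the word
    return hyphenator.hyphenate_word("hyphenation")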
PATTERNS = (
# Knuth and Liang's original hyphenation patterns from classic TeX.
# In the public domain.
"""
.ach4 .ad4der .af1t .al3t .am5at .an5c .ang4 .ani5m .ant4 .an3te .anti5s
.ar5s .ar4tie .ar4ty .as3c .as1p .as1s .aster5 .atom5 .au1d .av4i .awn4
.ba4g .ba5na .bas4e .ber4 .be5ra .be3sm .be5sto .bri2 .but4ti .cam4pe
.can5c .capa5b .car5ol .ca4t .ce4la .ch4 .chill5i .ci2 .cit5r .co3e .co4r
.cor5ner .de4moi .de3o .de3ra .de3ri .des4c .dictio5 .do4t .du4c .dumb5
.earth5 .eas3i .eb4 .eer4 .eg2 .el5d .el3em .enam3 .en3g .en3s .eq5ui5t
.er4ri .es3 .eu3 .eye5 .fes3 .for5mer .ga2 .ge2 .gen3t4 .ge5og .gi5a .gi4b
.go4r .hand5i .han5k .he2 .hero5i .hes3 .het3 .hi3b .hi3er .hon5ey .hon3o
.hov5 .id4l .idol3 .im3m .im5pin .in1 .in3ci .ine2 .in2k .in3s .ir5r .is4i
.ju3r .la4cy .la4m .lat5er .lath5 .le2 .leg5e .len4 .lep5 .lev1 .li4g
.lig5a .li2n .li3o .li4t .mag5a5 .mal5o .man5a .mar5ti .me2 .mer3c .me5ter
.mis1 .mist5i .mon3e .mo3ro .mu5ta .muta5b .ni4c .od2 .odd5 .of5te .or5ato
.or3c .or1d .or3t .os3 .os4tl .oth3 .out3 .ped5al .pe5te .pe5tit .pi4e
.pio5n .pi2t .pre3m .ra4c .ran4t .ratio5na .ree2 .re5mit .res2 .re5stat
.ri4g .rit5u .ro4q .ros5t .row5d .ru4d .sci3e .self5 .sell5 .se2n .se5rie
.sh2 .si2 .sing4 .st4 .sta5bl .sy2 .ta4 .te4 .ten5an .th2 .ti2 .til4
.tim5o5 .ting4 .tin5k .ton4a .to4p .top5i .tou5s .trib5ut .un1a .un3ce
.under5 .un1e .un5k .un5o .un3u .up3 .ure3 .us5a .ven4de .ve5ra .wil5i .ye4
4ab. a5bal a5ban abe2 ab5erd abi5a ab5it5ab ab5lat ab5o5liz 4abr ab5rog
ab3ul a4car ac5ard ac5aro a5ceou ac1er a5chet 4a2ci a3cie ac1in a3cio
ac5rob act5if ac3ul ac4um a2d ad4din ad5er. 2adi a3dia ad3ica adi4er a3dio
a3dit a5diu ad4le ad3ow ad5ran ad4su 4adu a3duc ad5um ae4r aeri4e a2f aff4
a4gab aga4n ag5ell age4o 4ageu ag1i 4ag4l ag1n a2go 3agog ag3oni a5guer
ag5ul a4gy a3ha a3he ah4l a3ho ai2 a5ia a3ic. ai5ly a4i4n ain5in ain5o
ait5en a1j ak1en al5ab al3ad a4lar 4aldi 2ale al3end a4lenti a5le5o al1i
al4ia. ali4e al5lev 4allic 4alm a5log. a4ly. 4alys 5a5lyst 5alyt 3alyz 4ama
am5ab am3ag ama5ra am5asc a4matis a4m5ato am5era am3ic am5if am5ily am1in
ami4no a2mo a5mon amor5i amp5en a2n an3age 3analy a3nar an3arc anar4i
a3nati 4and ande4s an3dis an1dl an4dow a5nee a3nen an5est. a3neu 2ang
ang5ie an1gl a4n1ic a3nies an3i3f an4ime a5nimi a5nine an3io a3nip an3ish
an3it a3niu an4kli 5anniz ano4 an5ot anoth5 an2sa an4sco an4sn an2sp ans3po
an4st an4sur antal4 an4tie 4anto an2tr an4tw an3ua an3ul a5nur 4ao apar4
ap5at ap5ero a3pher 4aphi a4pilla ap5illar ap3in ap3ita a3pitu a2pl apoc5
ap5ola apor5i apos3t aps5es a3pu aque5 2a2r ar3act a5rade ar5adis ar3al
a5ramete aran4g ara3p ar4at a5ratio ar5ativ a5rau ar5av4 araw4 arbal4
ar4chan ar5dine ar4dr ar5eas a3ree ar3ent a5ress ar4fi ar4fl ar1i ar5ial
ar3ian a3riet ar4im ar5inat ar3io ar2iz ar2mi ar5o5d a5roni a3roo ar2p ar3q
arre4 ar4sa ar2sh 4as. as4ab as3ant ashi4 a5sia. a3sib a3sic 5a5si4t ask3i
as4l a4soc as5ph as4sh as3ten as1tr asur5a a2ta at3abl at5ac at3alo at5ap
ate5c at5ech at3ego at3en. at3era ater5n a5terna at3est at5ev 4ath ath5em
a5then at4ho ath5om 4ati. a5tia at5i5b at1ic at3if ation5ar at3itu a4tog
a2tom at5omiz a4top a4tos a1tr at5rop at4sk at4tag at5te at4th a2tu at5ua
at5ue at3ul at3ura a2ty au4b augh3 au3gu au4l2 aun5d au3r au5sib aut5en
au1th a2va av3ag a5van ave4no av3era av5ern av5ery av1i avi4er av3ig av5oc
a1vor 3away aw3i aw4ly aws4 ax4ic ax4id ay5al aye4 ays4 azi4er azz5i
5ba. bad5ger ba4ge bal1a ban5dag ban4e ban3i barbi5 bari4a bas4si 1bat ba4z
2b1b b2be b3ber bbi4na 4b1d 4be. beak4 beat3 4be2d be3da be3de be3di be3gi
be5gu 1bel be1li be3lo 4be5m be5nig be5nu 4bes4 be3sp be5str 3bet bet5iz
be5tr be3tw be3w be5yo 2bf 4b3h bi2b bi4d 3bie bi5en bi4er 2b3if 1bil
bi3liz bina5r4 bin4d bi5net bi3ogr bi5ou bi2t 3bi3tio bi3tr 3bit5ua b5itz
b1j bk4 b2l2 blath5 b4le. blen4 5blesp b3lis b4lo blun4t 4b1m 4b3n bne5g
3bod bod3i bo4e bol3ic bom4bi bon4a bon5at 3boo 5bor. 4b1ora bor5d 5bore
5bori 5bos4 b5ota both5 bo4to bound3 4bp 4brit broth3 2b5s2 bsor4 2bt bt4l
b4to b3tr buf4fer bu4ga bu3li bumi4 bu4n bunt4i bu3re bus5ie buss4e 5bust
4buta 3butio b5uto b1v 4b5w 5by. bys4 1ca cab3in ca1bl cach4 ca5den 4cag4
2c5ah ca3lat cal4la call5in 4calo can5d can4e can4ic can5is can3iz can4ty
cany4 ca5per car5om cast5er cas5tig 4casy ca4th 4cativ cav5al c3c ccha5
cci4a ccompa5 ccon4 ccou3t 2ce. 4ced. 4ceden 3cei 5cel. 3cell 1cen 3cenc
2cen4e 4ceni 3cent 3cep ce5ram 4cesa 3cessi ces5si5b ces5t cet4 c5e4ta cew4
2ch 4ch. 4ch3ab 5chanic ch5a5nis che2 cheap3 4ched che5lo 3chemi ch5ene
ch3er. ch3ers 4ch1in 5chine. ch5iness 5chini 5chio 3chit chi2z 3cho2 ch4ti
1ci 3cia ci2a5b cia5r ci5c 4cier 5cific. 4cii ci4la 3cili 2cim 2cin c4ina
3cinat cin3em c1ing c5ing. 5cino cion4 4cipe ci3ph 4cipic 4cista 4cisti
2c1it cit3iz 5ciz ck1 ck3i 1c4l4 4clar c5laratio 5clare cle4m 4clic clim4
cly4 c5n 1co co5ag coe2 2cog co4gr coi4 co3inc col5i 5colo col3or com5er
con4a c4one con3g con5t co3pa cop3ic co4pl 4corb coro3n cos4e cov1 cove4
cow5a coz5e co5zi c1q cras5t 5crat. 5cratic cre3at 5cred 4c3reta cre4v cri2
cri5f c4rin cris4 5criti cro4pl crop5o cros4e cru4d 4c3s2 2c1t cta4b ct5ang
c5tant c2te c3ter c4ticu ctim3i ctu4r c4tw cud5 c4uf c4ui cu5ity 5culi
cul4tis 3cultu cu2ma c3ume cu4mi 3cun cu3pi cu5py cur5a4b cu5ria 1cus
cuss4i 3c4ut cu4tie 4c5utiv 4cutr 1cy cze4 1d2a 5da. 2d3a4b dach4 4daf 2dag
da2m2 dan3g dard5 dark5 4dary 3dat 4dativ 4dato 5dav4 dav5e 5day d1b d5c
d1d4 2de. deaf5 deb5it de4bon decan4 de4cil de5com 2d1ed 4dee. de5if deli4e
del5i5q de5lo d4em 5dem. 3demic dem5ic. de5mil de4mons demor5 1den de4nar
de3no denti5f de3nu de1p de3pa depi4 de2pu d3eq d4erh 5derm dern5iz der5s
des2 d2es. de1sc de2s5o des3ti de3str de4su de1t de2to de1v dev3il 4dey
4d1f d4ga d3ge4t dg1i d2gy d1h2 5di. 1d4i3a dia5b di4cam d4ice 3dict 3did
5di3en d1if di3ge di4lato d1in 1dina 3dine. 5dini di5niz 1dio dio5g di4pl
dir2 di1re dirt5i dis1 5disi d4is3t d2iti 1di1v d1j d5k2 4d5la 3dle. 3dled
3dles. 4dless 2d3lo 4d5lu 2dly d1m 4d1n4 1do 3do. do5de 5doe 2d5of d4og
do4la doli4 do5lor dom5iz do3nat doni4 doo3d dop4p d4or 3dos 4d5out do4v
3dox d1p 1dr drag5on 4drai dre4 drea5r 5dren dri4b dril4 dro4p 4drow
5drupli 4dry 2d1s2 ds4p d4sw d4sy d2th 1du d1u1a du2c d1uca duc5er
4duct. 4ducts du5el du4g d3ule dum4be du4n 4dup du4pe d1v d1w d2y 5dyn
dy4se dys5p e1a4b e3act ead1 ead5ie ea4ge ea5ger ea4l eal5er eal3ou eam3er
e5and ear3a ear4c ear5es ear4ic ear4il ear5k ear2t eart3e ea5sp e3ass east3
ea2t eat5en eath3i e5atif e4a3tu ea2v eav3en eav5i eav5o 2e1b e4bel. e4bels
e4ben e4bit e3br e4cad ecan5c ecca5 e1ce ec5essa ec2i e4cib ec5ificat
ec5ifie ec5ify ec3im eci4t e5cite e4clam e4clus e2col e4comm e4compe e4conc
e2cor ec3ora eco5ro e1cr e4crem ec4tan ec4te e1cu e4cul ec3ula 2e2da 4ed3d
e4d1er ede4s 4edi e3dia ed3ib ed3ica ed3im ed1it edi5z 4edo e4dol edon2
e4dri e4dul ed5ulo ee2c eed3i ee2f eel3i ee4ly ee2m ee4na ee4p1 ee2s4 eest4
ee4ty e5ex e1f e4f3ere 1eff e4fic 5efici efil4 e3fine ef5i5nite 3efit
efor5es e4fuse. 4egal eger4 eg5ib eg4ic eg5ing e5git5 eg5n e4go. e4gos
eg1ul e5gur 5egy e1h4 eher4 ei2 e5ic ei5d eig2 ei5gl e3imb e3inf e1ing
e5inst eir4d eit3e ei3th e5ity e1j e4jud ej5udi eki4n ek4la e1la
e4la. e4lac elan4d el5ativ e4law elaxa4 e3lea el5ebra 5elec e4led el3ega
e5len e4l1er e1les el2f el2i e3libe e4l5ic. el3ica e3lier el5igib e5lim
e4l3ing e3lio e2lis el5ish e3liv3 4ella el4lab ello4 e5loc el5og
el3op. el2sh el4ta e5lud el5ug e4mac e4mag e5man em5ana em5b e1me e2mel
e4met em3ica emi4e em5igra em1in2 em5ine em3i3ni e4mis em5ish e5miss em3iz
5emniz emo4g emoni5o em3pi e4mul em5ula emu3n e3my en5amo e4nant ench4er
en3dic e5nea e5nee en3em en5ero en5esi en5est en3etr e3new en5ics e5nie
e5nil e3nio en3ish en3it e5niu 5eniz 4enn 4eno eno4g e4nos en3ov en4sw
ent5age 4enthes en3ua en5uf e3ny. 4en3z e5of eo2g e4oi4 e3ol eop3ar e1or
eo3re eo5rol eos4 e4ot eo4to e5out e5ow e2pa e3pai ep5anc e5pel e3pent
ep5etitio ephe4 e4pli e1po e4prec ep5reca e4pred ep3reh e3pro e4prob ep4sh
ep5ti5b e4put ep5uta e1q equi3l e4q3ui3s er1a era4b 4erand er3ar
4erati. 2erb er4bl er3ch er4che 2ere. e3real ere5co ere3in er5el. er3emo
er5ena er5ence 4erene er3ent ere4q er5ess er3est eret4 er1h er1i e1ria4
5erick e3rien eri4er er3ine e1rio 4erit er4iu eri4v e4riva er3m4 er4nis
4ernit 5erniz er3no 2ero er5ob e5roc ero4r er1ou er1s er3set ert3er 4ertl
er3tw 4eru eru4t 5erwau e1s4a e4sage. e4sages es2c e2sca es5can e3scr es5cu
e1s2e e2sec es5ecr es5enc e4sert. e4serts e4serva 4esh e3sha esh5en e1si
e2sic e2sid es5iden es5igna e2s5im es4i4n esis4te esi4u e5skin es4mi e2sol
es3olu e2son es5ona e1sp es3per es5pira es4pre 2ess es4si4b estan4 es3tig
es5tim 4es2to e3ston 2estr e5stro estruc5 e2sur es5urr es4w eta4b eten4d
e3teo ethod3 et1ic e5tide etin4 eti4no e5tir e5titio et5itiv 4etn et5ona
e3tra e3tre et3ric et5rif et3rog et5ros et3ua et5ym et5z 4eu e5un e3up
eu3ro eus4 eute4 euti5l eu5tr eva2p5 e2vas ev5ast e5vea ev3ell evel3o
e5veng even4i ev1er e5verb e1vi ev3id evi4l e4vin evi4v e5voc e5vu e1wa
e4wag e5wee e3wh ewil5 ew3ing e3wit 1exp 5eyc 5eye. eys4 1fa fa3bl fab3r
fa4ce 4fag fain4 fall5e 4fa4ma fam5is 5far far5th fa3ta fa3the 4fato fault5
4f5b 4fd 4fe. feas4 feath3 fe4b 4feca 5fect 2fed fe3li fe4mo fen2d fend5e
fer1 5ferr fev4 4f1f f4fes f4fie f5fin. f2f5is f4fly f2fy 4fh 1fi fi3a
2f3ic. 4f3ical f3ican 4ficate f3icen fi3cer fic4i 5ficia 5ficie 4fics fi3cu
fi5del fight5 fil5i fill5in 4fily 2fin 5fina fin2d5 fi2ne f1in3g fin4n
fis4ti f4l2 f5less flin4 flo3re f2ly5 4fm 4fn 1fo 5fon fon4de fon4t fo2r
fo5rat for5ay fore5t for4i fort5a fos5 4f5p fra4t f5rea fres5c fri2 fril4
frol5 2f3s 2ft f4to f2ty 3fu fu5el 4fug fu4min fu5ne fu3ri fusi4 fus4s
4futa 1fy 1ga gaf4 5gal. 3gali ga3lo 2gam ga5met g5amo gan5is ga3niz
gani5za 4gano gar5n4 gass4 gath3 4gativ 4gaz g3b gd4 2ge. 2ged geez4 gel4in
ge5lis ge5liz 4gely 1gen ge4nat ge5niz 4geno 4geny 1geo ge3om g4ery 5gesi
geth5 4geto ge4ty ge4v 4g1g2 g2ge g3ger gglu5 ggo4 gh3in gh5out gh4to
5gi. 1gi4a gia5r g1ic 5gicia g4ico gien5 5gies. gil4 g3imen 3g4in. gin5ge
5g4ins 5gio 3gir gir4l g3isl gi4u 5giv 3giz gl2 gla4 glad5i 5glas 1gle
gli4b g3lig 3glo glo3r g1m g4my gn4a g4na. gnet4t g1ni g2nin g4nio g1no
g4non 1go 3go. gob5 5goe 3g4o4g go3is gon2 4g3o3na gondo5 go3ni 5goo go5riz
gor5ou 5gos. gov1 g3p 1gr 4grada g4rai gran2 5graph. g5rapher 5graphic
4graphy 4gray gre4n 4gress. 4grit g4ro gruf4 gs2 g5ste gth3 gu4a 3guard
2gue 5gui5t 3gun 3gus 4gu4t g3w 1gy 2g5y3n gy5ra h3ab4l hach4 hae4m hae4t
h5agu ha3la hala3m ha4m han4ci han4cy 5hand. han4g hang5er hang5o h5a5niz
han4k han4te hap3l hap5t ha3ran ha5ras har2d hard3e har4le harp5en har5ter
has5s haun4 5haz haz3a h1b 1head 3hear he4can h5ecat h4ed he5do5 he3l4i
hel4lis hel4ly h5elo hem4p he2n hena4 hen5at heo5r hep5 h4era hera3p her4ba
here5a h3ern h5erou h3ery h1es he2s5p he4t het4ed heu4 h1f h1h hi5an hi4co
high5 h4il2 himer4 h4ina hion4e hi4p hir4l hi3ro hir4p hir4r his3el his4s
hith5er hi2v 4hk 4h1l4 hlan4 h2lo hlo3ri 4h1m hmet4 2h1n h5odiz h5ods ho4g
hoge4 hol5ar 3hol4e ho4ma home3 hon4a ho5ny 3hood hoon4 hor5at ho5ris
hort3e ho5ru hos4e ho5sen hos1p 1hous house3 hov5el 4h5p 4hr4 hree5 hro5niz
hro3po 4h1s2 h4sh h4tar ht1en ht5es h4ty hu4g hu4min hun5ke hun4t hus3t4
hu4t h1w h4wart hy3pe hy3ph hy2s 2i1a i2al iam4 iam5ete i2an 4ianc ian3i
4ian4t ia5pe iass4 i4ativ ia4tric i4atu ibe4 ib3era ib5ert ib5ia ib3in
ib5it. ib5ite i1bl ib3li i5bo i1br i2b5ri i5bun 4icam 5icap 4icar
i4car. i4cara icas5 i4cay iccu4 4iceo 4ich 2ici i5cid ic5ina i2cip ic3ipa
i4cly i2c5oc 4i1cr 5icra i4cry ic4te ictu2 ic4t3ua ic3ula ic4um ic5uo i3cur
2id i4dai id5anc id5d ide3al ide4s i2di id5ian idi4ar i5die id3io idi5ou
id1it id5iu i3dle i4dom id3ow i4dr i2du id5uo 2ie4 ied4e 5ie5ga ield3
ien5a4 ien4e i5enn i3enti i1er. i3esc i1est i3et 4if. if5ero iff5en if4fr
4ific. i3fie i3fl 4ift 2ig iga5b ig3era ight3i 4igi i3gib ig3il ig3in ig3it
i4g4l i2go ig3or ig5ot i5gre igu5i ig1ur i3h 4i5i4 i3j 4ik i1la il3a4b
i4lade i2l5am ila5ra i3leg il1er ilev4 il5f il1i il3ia il2ib il3io il4ist
2ilit il2iz ill5ab 4iln il3oq il4ty il5ur il3v i4mag im3age ima5ry imenta5r
4imet im1i im5ida imi5le i5mini 4imit im4ni i3mon i2mu im3ula 2in. i4n3au
4inav incel4 in3cer 4ind in5dling 2ine i3nee iner4ar i5ness 4inga 4inge
in5gen 4ingi in5gling 4ingo 4ingu 2ini i5ni. i4nia in3io in1is
i5nite. 5initio in3ity 4ink 4inl 2inn 2i1no i4no4c ino4s i4not 2ins in3se
insur5a 2int. 2in4th in1u i5nus 4iny 2io 4io. ioge4 io2gr i1ol io4m ion3at
ion4ery ion3i io5ph ior3i i4os io5th i5oti io4to i4our 2ip ipe4 iphras4
ip3i ip4ic ip4re4 ip3ul i3qua iq5uef iq3uid iq3ui3t 4ir i1ra ira4b i4rac
ird5e ire4de i4ref i4rel4 i4res ir5gi ir1i iri5de ir4is iri3tu 5i5r2iz
ir4min iro4g 5iron. ir5ul 2is. is5ag is3ar isas5 2is1c is3ch 4ise is3er
3isf is5han is3hon ish5op is3ib isi4d i5sis is5itiv 4is4k islan4 4isms i2so
iso5mer is1p is2pi is4py 4is1s is4sal issen4 is4ses is4ta. is1te is1ti
ist4ly 4istral i2su is5us 4ita. ita4bi i4tag 4ita5m i3tan i3tat 2ite it3era
i5teri it4es 2ith i1ti 4itia 4i2tic it3ica 5i5tick it3ig it5ill i2tim 2itio
4itis i4tism i2t5o5m 4iton i4tram it5ry 4itt it3uat i5tud it3ul 4itz. i1u
2iv iv3ell iv3en. i4v3er. i4vers. iv5il. iv5io iv1it i5vore iv3o3ro i4v3ot
4i5w ix4o 4iy 4izar izi4 5izont 5ja jac4q ja4p 1je jer5s 4jestie 4jesty
jew3 jo4p 5judg 3ka. k3ab k5ag kais4 kal4 k1b k2ed 1kee ke4g ke5li k3en4d
k1er kes4 k3est. ke4ty k3f kh4 k1i 5ki. 5k2ic k4ill kilo5 k4im k4in. kin4de
k5iness kin4g ki4p kis4 k5ish kk4 k1l 4kley 4kly k1m k5nes 1k2no ko5r kosh4
k3ou kro5n 4k1s2 k4sc ks4l k4sy k5t k1w lab3ic l4abo laci4 l4ade la3dy
lag4n lam3o 3land lan4dl lan5et lan4te lar4g lar3i las4e la5tan 4lateli
4lativ 4lav la4v4a 2l1b lbin4 4l1c2 lce4 l3ci 2ld l2de ld4ere ld4eri ldi4
ld5is l3dr l4dri le2a le4bi left5 5leg. 5legg le4mat lem5atic 4len. 3lenc
5lene. 1lent le3ph le4pr lera5b ler4e 3lerg 3l4eri l4ero les2 le5sco 5lesq
3less 5less. l3eva lev4er. lev4era lev4ers 3ley 4leye 2lf l5fr 4l1g4 l5ga
lgar3 l4ges lgo3 2l3h li4ag li2am liar5iz li4as li4ato li5bi 5licio li4cor
4lics 4lict. l4icu l3icy l3ida lid5er 3lidi lif3er l4iff li4fl 5ligate
3ligh li4gra 3lik 4l4i4l lim4bl lim3i li4mo l4im4p l4ina 1l4ine lin3ea
lin3i link5er li5og 4l4iq lis4p l1it l2it. 5litica l5i5tics liv3er l1iz 4lj
lka3 l3kal lka4t l1l l4law l2le l5lea l3lec l3leg l3lel l3le4n l3le4t ll2i
l2lin4 l5lina ll4o lloqui5 ll5out l5low 2lm l5met lm3ing l4mod lmon4 2l1n2
3lo. lob5al lo4ci 4lof 3logic l5ogo 3logu lom3er 5long lon4i l3o3niz lood5
5lope. lop3i l3opm lora4 lo4rato lo5rie lor5ou 5los. los5et 5losophiz
5losophy los4t lo4ta loun5d 2lout 4lov 2lp lpa5b l3pha l5phi lp5ing l3pit
l4pl l5pr 4l1r 2l1s2 l4sc l2se l4sie 4lt lt5ag ltane5 l1te lten4 ltera4
lth3i l5ties. ltis4 l1tr ltu2 ltur3a lu5a lu3br luch4 lu3ci lu3en luf4
lu5id lu4ma 5lumi l5umn. 5lumnia lu3o luo3r 4lup luss4 lus3te 1lut l5ven
l5vet4 2l1w 1ly 4lya 4lyb ly5me ly3no 2lys4 l5yse 1ma 2mab ma2ca ma5chine
ma4cl mag5in 5magn 2mah maid5 4mald ma3lig ma5lin mal4li mal4ty 5mania
man5is man3iz 4map ma5rine. ma5riz mar4ly mar3v ma5sce mas4e mas1t 5mate
math3 ma3tis 4matiza 4m1b mba4t5 m5bil m4b3ing mbi4v 4m5c 4me. 2med
4med. 5media me3die m5e5dy me2g mel5on mel4t me2m mem1o3 1men men4a men5ac
men4de 4mene men4i mens4 mensu5 3ment men4te me5on m5ersa 2mes 3mesti me4ta
met3al me1te me5thi m4etr 5metric me5trie me3try me4v 4m1f 2mh 5mi. mi3a
mid4a mid4g mig4 3milia m5i5lie m4ill min4a 3mind m5inee m4ingl min5gli
m5ingly min4t m4inu miot4 m2is mis4er. mis5l mis4ti m5istry 4mith m2iz 4mk
4m1l m1m mma5ry 4m1n mn4a m4nin mn4o 1mo 4mocr 5mocratiz mo2d1 mo4go mois2
moi5se 4mok mo5lest mo3me mon5et mon5ge moni3a mon4ism mon4ist mo3niz
monol4 mo3ny. mo2r 4mora. mos2 mo5sey mo3sp moth3 m5ouf 3mous mo2v 4m1p
mpara5 mpa5rab mpar5i m3pet mphas4 m2pi mpi4a mp5ies m4p1in m5pir mp5is
mpo3ri mpos5ite m4pous mpov5 mp4tr m2py 4m3r 4m1s2 m4sh m5si 4mt 1mu
mula5r4 5mult multi3 3mum mun2 4mup mu4u 4mw 1na 2n1a2b n4abu 4nac. na4ca
n5act nag5er. nak4 na4li na5lia 4nalt na5mit n2an nanci4 nan4it nank4 nar3c
4nare nar3i nar4l n5arm n4as nas4c nas5ti n2at na3tal nato5miz n2au nau3se
3naut nav4e 4n1b4 ncar5 n4ces. n3cha n5cheo n5chil n3chis nc1in nc4it
ncour5a n1cr n1cu n4dai n5dan n1de nd5est. ndi4b n5d2if n1dit n3diz n5duc
ndu4r nd2we 2ne. n3ear ne2b neb3u ne2c 5neck 2ned ne4gat neg5ativ 5nege
ne4la nel5iz ne5mi ne4mo 1nen 4nene 3neo ne4po ne2q n1er nera5b n4erar
n2ere n4er5i ner4r 1nes 2nes. 4nesp 2nest 4nesw 3netic ne4v n5eve ne4w n3f
n4gab n3gel nge4n4e n5gere n3geri ng5ha n3gib ng1in n5git n4gla ngov4 ng5sh
n1gu n4gum n2gy 4n1h4 nha4 nhab3 nhe4 3n4ia ni3an ni4ap ni3ba ni4bl ni4d
ni5di ni4er ni2fi ni5ficat n5igr nik4 n1im ni3miz n1in 5nine. nin4g ni4o
5nis. nis4ta n2it n4ith 3nitio n3itor ni3tr n1j 4nk2 n5kero n3ket nk3in
n1kl 4n1l n5m nme4 nmet4 4n1n2 nne4 nni3al nni4v nob4l no3ble n5ocl 4n3o2d
3noe 4nog noge4 nois5i no5l4i 5nologis 3nomic n5o5miz no4mo no3my no4n
non4ag non5i n5oniz 4nop 5nop5o5li nor5ab no4rary 4nosc nos4e nos5t no5ta
1nou 3noun nov3el3 nowl3 n1p4 npi4 npre4c n1q n1r nru4 2n1s2 ns5ab nsati4
ns4c n2se n4s3es nsid1 nsig4 n2sl ns3m n4soc ns4pe n5spi nsta5bl n1t nta4b
nter3s nt2i n5tib nti4er nti2f n3tine n4t3ing nti4p ntrol5li nt4s ntu3me
nu1a nu4d nu5en nuf4fe n3uin 3nu3it n4um nu1me n5umi 3nu4n n3uo nu3tr n1v2
n1w4 nym4 nyp4 4nz n3za 4oa oad3 o5a5les oard3 oas4e oast5e oat5i ob3a3b
o5bar obe4l o1bi o2bin ob5ing o3br ob3ul o1ce och4 o3chet ocif3 o4cil
o4clam o4cod oc3rac oc5ratiz ocre3 5ocrit octor5a oc3ula o5cure od5ded
od3ic odi3o o2do4 odor3 od5uct. od5ucts o4el o5eng o3er oe4ta o3ev o2fi
of5ite ofit4t o2g5a5r og5ativ o4gato o1ge o5gene o5geo o4ger o3gie 1o1gis
og3it o4gl o5g2ly 3ogniz o4gro ogu5i 1ogy 2ogyn o1h2 ohab5 oi2 oic3es
oi3der oiff4 oig4 oi5let o3ing oint5er o5ism oi5son oist5en oi3ter o5j 2ok
o3ken ok5ie o1la o4lan olass4 ol2d old1e ol3er o3lesc o3let ol4fi ol2i
o3lia o3lice ol5id. o3li4f o5lil ol3ing o5lio o5lis. ol3ish o5lite o5litio
o5liv olli4e ol5ogiz olo4r ol5pl ol2t ol3ub ol3ume ol3un o5lus ol2v o2ly
om5ah oma5l om5atiz om2be om4bl o2me om3ena om5erse o4met om5etry o3mia
om3ic. om3ica o5mid om1in o5mini 5ommend omo4ge o4mon om3pi ompro5 o2n on1a
on4ac o3nan on1c 3oncil 2ond on5do o3nen on5est on4gu on1ic o3nio on1is
o5niu on3key on4odi on3omy on3s onspi4 onspir5a onsu4 onten4 on3t4i ontif5
on5um onva5 oo2 ood5e ood5i oo4k oop3i o3ord oost5 o2pa ope5d op1er 3opera
4operag 2oph o5phan o5pher op3ing o3pit o5pon o4posi o1pr op1u opy5 o1q
o1ra o5ra. o4r3ag or5aliz or5ange ore5a o5real or3ei ore5sh or5est. orew4
or4gu 4o5ria or3ica o5ril or1in o1rio or3ity o3riu or2mi orn2e o5rof or3oug
or5pe 3orrh or4se ors5en orst4 or3thi or3thy or4ty o5rum o1ry os3al os2c
os4ce o3scop 4oscopi o5scr os4i4e os5itiv os3ito os3ity osi4u os4l o2so
os4pa os4po os2ta o5stati os5til os5tit o4tan otele4g ot3er. ot5ers o4tes
4oth oth5esi oth3i4 ot3ic. ot5ica o3tice o3tif o3tis oto5s ou2 ou3bl ouch5i
ou5et ou4l ounc5er oun2d ou5v ov4en over4ne over3s ov4ert o3vis oviti4
o5v4ol ow3der ow3el ow5est ow1i own5i o4wo oy1a 1pa pa4ca pa4ce pac4t p4ad
5pagan p3agat p4ai pain4 p4al pan4a pan3el pan4ty pa3ny pa1p pa4pu para5bl
par5age par5di 3pare par5el p4a4ri par4is pa2te pa5ter 5pathic pa5thy
pa4tric pav4 3pay 4p1b pd4 4pe. 3pe4a pear4l pe2c 2p2ed 3pede 3pedi pedia4
ped4ic p4ee pee4d pek4 pe4la peli4e pe4nan p4enc pen4th pe5on
p4era. pera5bl p4erag p4eri peri5st per4mal perme5 p4ern per3o per3ti pe5ru
per1v pe2t pe5ten pe5tiz 4pf 4pg 4ph. phar5i phe3no ph4er ph4es. ph1ic
5phie ph5ing 5phisti 3phiz ph2l 3phob 3phone 5phoni pho4r 4phs ph3t 5phu
1phy pi3a pian4 pi4cie pi4cy p4id p5ida pi3de 5pidi 3piec pi3en pi4grap
pi3lo pi2n p4in. pind4 p4ino 3pi1o pion4 p3ith pi5tha pi2tu 2p3k2 1p2l2
3plan plas5t pli3a pli5er 4plig pli4n ploi4 plu4m plum4b 4p1m 2p3n po4c
5pod. po5em po3et5 5po4g poin2 5point poly5t po4ni po4p 1p4or po4ry 1pos
pos1s p4ot po4ta 5poun 4p1p ppa5ra p2pe p4ped p5pel p3pen p3per p3pet
ppo5site pr2 pray4e 5preci pre5co pre3em pref5ac pre4la pre3r p3rese 3press
pre5ten pre3v 5pri4e prin4t3 pri4s pris3o p3roca prof5it pro3l pros3e pro1t
2p1s2 p2se ps4h p4sib 2p1t pt5a4b p2te p2th pti3m ptu4r p4tw pub3 pue4 puf4
pul3c pu4m pu2n pur4r 5pus pu2t 5pute put3er pu3tr put4ted put4tin p3w qu2
qua5v 2que. 3quer 3quet 2rab ra3bi rach4e r5acl raf5fi raf4t r2ai ra4lo
ram3et r2ami rane5o ran4ge r4ani ra5no rap3er 3raphy rar5c rare4 rar5ef
4raril r2as ration4 rau4t ra5vai rav3el ra5zie r1b r4bab r4bag rbi2 rbi4f
r2bin r5bine rb5ing. rb4o r1c r2ce rcen4 r3cha rch4er r4ci4b rc4it rcum3
r4dal rd2i rdi4a rdi4er rdin4 rd3ing 2re. re1al re3an re5arr 5reav re4aw
r5ebrat rec5oll rec5ompe re4cre 2r2ed re1de re3dis red5it re4fac re2fe
re5fer. re3fi re4fy reg3is re5it re1li re5lu r4en4ta ren4te re1o re5pin
re4posi re1pu r1er4 r4eri rero4 re5ru r4es. re4spi ress5ib res2t re5stal
re3str re4ter re4ti4z re3tri reu2 re5uti rev2 re4val rev3el
r5ev5er. re5vers re5vert re5vil rev5olu re4wh r1f rfu4 r4fy rg2 rg3er r3get
r3gic rgi4n rg3ing r5gis r5git r1gl rgo4n r3gu rh4 4rh. 4rhal ri3a ria4b
ri4ag r4ib rib3a ric5as r4ice 4rici 5ricid ri4cie r4ico rid5er ri3enc
ri3ent ri1er ri5et rig5an 5rigi ril3iz 5riman rim5i 3rimo rim4pe r2ina
5rina. rin4d rin4e rin4g ri1o 5riph riph5e ri2pl rip5lic r4iq r2is
r4is. ris4c r3ish ris4p ri3ta3b r5ited. rit5er. rit5ers rit3ic ri2tu rit5ur
riv5el riv3et riv3i r3j r3ket rk4le rk4lin r1l rle4 r2led r4lig r4lis
rl5ish r3lo4 r1m rma5c r2me r3men rm5ers rm3ing r4ming. r4mio r3mit r4my
r4nar r3nel r4ner r5net r3ney r5nic r1nis4 r3nit r3niv rno4 r4nou r3nu
rob3l r2oc ro3cr ro4e ro1fe ro5fil rok2 ro5ker 5role. rom5ete rom4i rom4p
ron4al ron4e ro5n4is ron4ta 1room 5root ro3pel rop3ic ror3i ro5ro ros5per
ros4s ro4the ro4ty ro4va rov5el rox5 r1p r4pea r5pent rp5er. r3pet rp4h4
rp3ing r3po r1r4 rre4c rre4f r4reo rre4st rri4o rri4v rron4 rros4 rrys4
4rs2 r1sa rsa5ti rs4c r2se r3sec rse4cr rs5er. rs3es rse5v2 r1sh r5sha r1si
r4si4b rson3 r1sp r5sw rtach4 r4tag r3teb rten4d rte5o r1ti rt5ib rti4d
r4tier r3tig rtil3i rtil4l r4tily r4tist r4tiv r3tri rtroph4 rt4sh ru3a
ru3e4l ru3en ru4gl ru3in rum3pl ru2n runk5 run4ty r5usc ruti5n rv4e rvel4i
r3ven rv5er. r5vest r3vey r3vic rvi4v r3vo r1w ry4c 5rynge ry3t sa2 2s1ab
5sack sac3ri s3act 5sai salar4 sal4m sa5lo sal4t 3sanc san4de s1ap sa5ta
5sa3tio sat3u sau4 sa5vor 5saw 4s5b scan4t5 sca4p scav5 s4ced 4scei s4ces
sch2 s4cho 3s4cie 5scin4d scle5 s4cli scof4 4scopy scour5a s1cu 4s5d
4se. se4a seas4 sea5w se2c3o 3sect 4s4ed se4d4e s5edl se2g seg3r 5sei se1le
5self 5selv 4seme se4mol sen5at 4senc sen4d s5ened sen5g s5enin 4sentd
4sentl sep3a3 4s1er. s4erl ser4o 4servo s1e4s se5sh ses5t 5se5um 5sev
sev3en sew4i 5sex 4s3f 2s3g s2h 2sh. sh1er 5shev sh1in sh3io 3ship shiv5
sho4 sh5old shon3 shor4 short5 4shw si1b s5icc 3side. 5sides 5sidi si5diz
4signa sil4e 4sily 2s1in s2ina 5sine. s3ing 1sio 5sion sion5a si2r sir5a
1sis 3sitio 5siu 1siv 5siz sk2 4ske s3ket sk5ine sk5ing s1l2 s3lat s2le
slith5 2s1m s3ma small3 sman3 smel4 s5men 5smith smol5d4 s1n4 1so so4ce
soft3 so4lab sol3d2 so3lic 5solv 3som 3s4on. sona4 son4g s4op 5sophic
s5ophiz s5ophy sor5c sor5d 4sov so5vi 2spa 5spai spa4n spen4d 2s5peo 2sper
s2phe 3spher spho5 spil4 sp5ing 4spio s4ply s4pon spor4 4spot squal4l s1r
2ss s1sa ssas3 s2s5c s3sel s5seng s4ses. s5set s1si s4sie ssi4er ss5ily
s4sl ss4li s4sn sspend4 ss2t ssur5a ss5w 2st. s2tag s2tal stam4i 5stand
s4ta4p 5stat. s4ted stern5i s5tero ste2w stew5a s3the st2i s4ti. s5tia
s1tic 5stick s4tie s3tif st3ing 5stir s1tle 5stock stom3a 5stone s4top
3store st4r s4trad 5stratu s4tray s4trid 4stry 4st3w s2ty 1su su1al su4b3
su2g3 su5is suit3 s4ul su2m sum3i su2n su2r 4sv sw2 4swo s4y 4syc 3syl
syn5o sy5rin 1ta 3ta. 2tab ta5bles 5taboliz 4taci ta5do 4taf4 tai5lo ta2l
ta5la tal5en tal3i 4talk tal4lis ta5log ta5mo tan4de tanta3 ta5per ta5pl
tar4a 4tarc 4tare ta3riz tas4e ta5sy 4tatic ta4tur taun4 tav4 2taw tax4is
2t1b 4tc t4ch tch5et 4t1d 4te. tead4i 4teat tece4 5tect 2t1ed te5di 1tee
teg4 te5ger te5gi 3tel. teli4 5tels te2ma2 tem3at 3tenan 3tenc 3tend 4tenes
1tent ten4tag 1teo te4p te5pe ter3c 5ter3d 1teri ter5ies ter3is teri5za
5ternit ter5v 4tes. 4tess t3ess. teth5e 3teu 3tex 4tey 2t1f 4t1g
2th. than4 th2e 4thea th3eas the5at the3is 3thet th5ic. th5ica 4thil 5think
4thl th5ode 5thodic 4thoo thor5it tho5riz 2ths 1tia ti4ab ti4ato 2ti2b
4tick t4ico t4ic1u 5tidi 3tien tif2 ti5fy 2tig 5tigu till5in 1tim 4timp
tim5ul 2t1in t2ina 3tine. 3tini 1tio ti5oc tion5ee 5tiq ti3sa 3tise tis4m
ti5so tis4p 5tistica ti3tl ti4u 1tiv tiv4a 1tiz ti3za ti3zen 2tl t5la tlan4
3tle. 3tled 3tles. t5let. t5lo 4t1m tme4 2t1n2 1to to3b to5crat 4todo 2tof
to2gr to5ic to2ma tom4b to3my ton4ali to3nat 4tono 4tony to2ra to3rie
tor5iz tos2 5tour 4tout to3war 4t1p 1tra tra3b tra5ch traci4 trac4it
trac4te tras4 tra5ven trav5es5 tre5f tre4m trem5i 5tria tri5ces 5tricia
4trics 2trim tri4v tro5mi tron5i 4trony tro5phe tro3sp tro3v tru5i trus4
4t1s2 t4sc tsh4 t4sw 4t3t2 t4tes t5to ttu4 1tu tu1a tu3ar tu4bi tud2 4tue
4tuf4 5tu3i 3tum tu4nis 2t3up. 3ture 5turi tur3is tur5o tu5ry 3tus 4tv tw4
4t1wa twis4 4two 1ty 4tya 2tyl type3 ty5ph 4tz tz4e 4uab uac4 ua5na uan4i
uar5ant uar2d uar3i uar3t u1at uav4 ub4e u4bel u3ber u4bero u1b4i u4b5ing
u3ble. u3ca uci4b uc4it ucle3 u3cr u3cu u4cy ud5d ud3er ud5est udev4 u1dic
ud3ied ud3ies ud5is u5dit u4don ud4si u4du u4ene uens4 uen4te uer4il 3ufa
u3fl ugh3en ug5in 2ui2 uil5iz ui4n u1ing uir4m uita4 uiv3 uiv4er. u5j 4uk
u1la ula5b u5lati ulch4 5ulche ul3der ul4e u1len ul4gi ul2i u5lia ul3ing
ul5ish ul4lar ul4li4b ul4lis 4ul3m u1l4o 4uls uls5es ul1ti ultra3 4ultu
u3lu ul5ul ul5v um5ab um4bi um4bly u1mi u4m3ing umor5o um2p unat4 u2ne
un4er u1ni un4im u2nin un5ish uni3v un3s4 un4sw unt3ab un4ter. un4tes unu4
un5y un5z u4ors u5os u1ou u1pe uper5s u5pia up3ing u3pl up3p upport5 upt5ib
uptu4 u1ra 4ura. u4rag u4ras ur4be urc4 ur1d ure5at ur4fer ur4fr u3rif
uri4fic ur1in u3rio u1rit ur3iz ur2l url5ing. ur4no uros4 ur4pe ur4pi
urs5er ur5tes ur3the urti4 ur4tie u3ru 2us u5sad u5san us4ap usc2 us3ci
use5a u5sia u3sic us4lin us1p us5sl us5tere us1tr u2su usur4 uta4b u3tat
4ute. 4utel 4uten uten4i 4u1t2i uti5liz u3tine ut3ing ution5a u4tis 5u5tiz
u4t1l ut5of uto5g uto5matic u5ton u4tou uts4 u3u uu4m u1v2 uxu3 uz4e 1va
5va. 2v1a4b vac5il vac3u vag4 va4ge va5lie val5o val1u va5mo va5niz va5pi
var5ied 3vat 4ve. 4ved veg3 v3el. vel3li ve4lo v4ely ven3om v5enue v4erd
5vere. v4erel v3eren ver5enc v4eres ver3ie vermi4n 3verse ver3th v4e2s
4ves. ves4te ve4te vet3er ve4ty vi5ali 5vian 5vide. 5vided 4v3iden 5vides
5vidi v3if vi5gn vik4 2vil 5vilit v3i3liz v1in 4vi4na v2inc vin5d 4ving
vio3l v3io4r vi1ou vi4p vi5ro vis3it vi3so vi3su 4viti vit3r 4vity 3viv
5vo. voi4 3vok vo4la v5ole 5volt 3volv vom5i vor5ab vori4 vo4ry vo4ta
4votee 4vv4 v4y w5abl 2wac wa5ger wag5o wait5 w5al. wam4 war4t was4t wa1te
wa5ver w1b wea5rie weath3 wed4n weet3 wee5v wel4l w1er west3 w3ev whi4 wi2
wil2 will5in win4de win4g wir4 3wise with3 wiz5 w4k wl4es wl3in w4no 1wo2
wom1 wo5ven w5p wra4 wri4 writa4 w3sh ws4l ws4pe w5s4t 4wt wy4 x1a xac5e
x4ago xam3 x4ap xas5 x3c2 x1e xe4cuto x2ed xer4i xe5ro x1h xhi2 xhil5 xhu4
x3i xi5a xi5c xi5di x4ime xi5miz x3o x4ob x3p xpan4d xpecto5 xpe3d x1t2
x3ti x1u xu3a xx4 y5ac 3yar4 y5at y1b y1c y2ce yc5er y3ch ych4e ycom4 ycot4
y1d y5ee y1er y4erf yes4 ye4t y5gi 4y3h y1i y3la ylla5bl y3lo y5lu ymbol5
yme4 ympa3 yn3chr yn5d yn5g yn5ic 5ynx y1o4 yo5d y4o5g yom4 yo5net y4ons
y4os y4ped yper5 yp3i y3po y4poc yp2ta y5pu yra5m yr5ia y3ro yr4r ys4c
y3s2e ys3ica ys3io 3ysis y4so yss4 ys1t ys3ta ysur4 y3thin yt3ic y1w za1
z5a2b zar2 4zb 2ze ze4n ze4p z1er ze3ro zet4 2z1i z4il z4is 5zl 4zm 1zo
zo4m zo5ol zte4 4z1z2 z4zy
"""
# Extra patterns, from ushyphmax.tex, dated 2005-05-30.
# Copyright (C) 1990, 2004, 2005 Gerard D.C. Kuiken.
# Copying and distribution of this file, with or without modification,
# are permitted in any medium without royalty provided the copyright
# notice and this notice are preserved.
#
# These patterns are based on the Hyphenation Exception Log
# published in TUGboat, Volume 10 (1989), No. 3, pp. 337-341,
# and a large number of incorrectly hyphenated words not yet published.
"""
.con5gr .de5riva .dri5v4 .eth1y6l1 .eu4ler .ev2 .ever5si5b .ga4s1om1
.ge4ome .ge5ot1 .he3mo1 .he3p6a .he3roe .in5u2t .kil2n3i .ko6r1te1 .le6ices
.me4ga1l .met4ala .mim5i2c1 .mi1s4ers .ne6o3f .noe1th .non1e2m .poly1s
.post1am .pre1am .rav5en1o .semi5 .sem4ic .semid6 .semip4 .semir4 .sem6is4
.semiv4 .sph6in1 .spin1o .ta5pes1tr .te3legr .to6pog .to2q .un3at5t
.un5err5 .vi2c3ar .we2b1l .re1e4c a5bolic a2cabl af6fish am1en3ta5b anal6ys
ano5a2c ans5gr ans3v anti1d an3ti1n2 anti1re a4pe5able ar3che5t ar2range
as5ymptot ath3er1o1s at6tes. augh4tl au5li5f av3iou back2er. ba6r1onie
ba1thy bbi4t be2vie bi5d2if bil2lab bio5m bi1orb bio1rh b1i3tive blan2d1
blin2d1 blon2d2 bor1no5 bo2t1u1l brus4q bus6i2er bus6i2es buss4ing
but2ed. but4ted cad5e1m cat1a1s2 4chs. chs3hu chie5vo cig3a3r cin2q cle4ar
co6ph1o3n cous2ti cri3tie croc1o1d cro5e2co c2tro3me6c 1cu2r1ance 2d3alone
data1b dd5a5b d2d5ib de4als. de5clar1 de2c5lina de3fin3iti de2mos des3ic
de2tic dic1aid dif5fra 3di1methy di2ren di2rer 2d1lead 2d1li2e 3do5word
dren1a5l drif2t1a d1ri3pleg5 drom3e5d d3tab du2al. du1op1o1l ea4n3ies
e3chas edg1l ed1uling eli2t1is e1loa en1dix eo3grap 1e6p3i3neph1 e2r3i4an.
e3spac6i eth1y6l1ene 5eu2clid1 feb1rua fermi1o 3fich fit5ted. fla1g6el
flow2er. 3fluor gen2cy. ge3o1d ght1we g1lead get2ic. 4g1lish 5glo5bin
1g2nac gnet1ism gno5mo g2n1or. g2noresp 2g1o4n3i1za graph5er. griev1 g1utan
hair1s ha2p3ar5r hatch1 hex2a3 hite3sid h3i5pel1a4 hnau3z ho6r1ic. h2t1eou
hypo1tha id4ios ifac1et ign4it ignit1er i4jk im3ped3a infra1s2
i5nitely. irre6v3oc i1tesima ith5i2l itin5er5ar janu3a japan1e2s je1re1m
1ke6ling 1ki5netic 1kovian k3sha la4c3i5e lai6n3ess lar5ce1n l3chai
l3chil6d1 lead6er. lea4s1a 1lec3ta6b le3g6en2dre 1le1noid lith1o5g ll1fl
l2l3ish l5mo3nell lo1bot1o1 lo2ges. load4ed. load6er. l3tea lth5i2ly lue1p
1lunk3er 1lum5bia. 3lyg1a1mi ly5styr ma1la1p m2an. man3u1sc mar1gin1
medi2c med3i3cin medio6c1 me3gran3 m2en. 3mi3da5b 3milita mil2l1ag
mil5li5li mi6n3is. mi1n2ut1er mi1n2ut1est m3ma1b 5maph1ro1 5moc1ra1t
mo5e2las mol1e5c mon4ey1l mono3ch mo4no1en moro6n5is mono1s6 moth4et2
m1ou3sin m5shack2 mu2dro mul2ti5u n3ar4chs. n3ch2es1t ne3back 2ne1ski
n1dieck nd3thr nfi6n3ites 4n5i4an. nge5nes ng1ho ng1spr nk3rup n5less
5noc3er1os nom1a6l nom5e1no n1o1mist non1eq non1i4so 5nop1oly. no1vemb
ns5ceiv ns4moo ntre1p obli2g1 o3chas odel3li odit1ic oerst2 oke1st
o3les3ter oli3gop1o1 o1lo3n4om o3mecha6 onom1ic o3norma o3no2t1o3n o3nou
op1ism. or4tho3ni4t orth1ri or5tively o4s3pher o5test1er o5tes3tor
oth3e1o1s ou3ba3do o6v3i4an. oxi6d1ic pal6mat parag6ra4 par4a1le param4
para3me pee2v1 phi2l3ant phi5lat1e3l pi2c1a3d pli2c1ab pli5nar poin3ca
1pole. poly1e po3lyph1ono 1prema3c pre1neu pres2pli pro2cess
proc3i3ty. pro2g1e 3pseu2d pseu3d6o3d2 pseu3d6o3f2 pto3mat4 p5trol3
pu5bes5c quain2t1e qu6a3si3 quasir6 quasis6 quin5tes5s qui3v4ar r1abolic
3rab1o1loi ra3chu r3a3dig radi1o6g r2amen 3ra4m5e1triz ra3mou ra5n2has
ra1or r3bin1ge re2c3i1pr rec5t6ang re4t1ribu r3ial. riv1o1l 6rk. rk1ho
r1krau 6rks. r5le5qu ro1bot1 ro5e2las ro5epide1 ro3mesh ro1tron r3pau5li
rse1rad1i r1thou r1treu r1veil rz1sc sales3c sales5w 5sa3par5il sca6p1er
sca2t1ol s4chitz schro1ding1 1sci2utt scrap4er. scy4th1 sem1a1ph se3mes1t
se1mi6t5ic sep3temb shoe1st sid2ed. side5st side5sw si5resid sky1sc
3slova1kia 3s2og1a1my so2lute 3s2pace 1s2pacin spe3cio spher1o spi2c1il
spokes5w sports3c sports3w s3qui3to s2s1a3chu1 ss3hat s2s3i4an. s5sign5a3b
1s2tamp s2t1ant5shi star3tli sta1ti st5b 1stor1ab strat1a1g strib5ut st5scr
stu1pi4d1 styl1is su2per1e6 1sync 1syth3i2 swimm6 5tab1o1lism
ta3gon. talk1a5 t1a1min t6ap6ath 5tar2rh tch1c tch3i1er t1cr
teach4er. tele2g tele1r6o 3ter1gei ter2ic. t3ess2es tha4l1am tho3don
th1o5gen1i tho1k2er thy4l1an thy3sc 2t3i4an. ti2n3o1m t1li2er tolo2gy
tot3ic trai3tor1 tra1vers travers3a3b treach1e tr4ial. 3tro1le1um
trof4ic. tro3fit tro1p2is 3trop1o5les 3trop1o5lis t1ro1pol3it tsch3ie
ttrib1ut1 turn3ar t1wh ty2p5al ua3drati uad1ratu u5do3ny uea1m
u2r1al. uri4al. us2er. v1ativ v1oir5du1 va6guer vaude3v 1verely. v1er1eig
ves1tite vi1vip3a3r voice1p waste3w6a2 wave1g4 w3c week1n wide5sp wo4k1en
wrap3aro writ6er. x1q xquis3 y5che3d ym5e5try y1stro yes5ter1y
z3ian. z3o1phr z2z3w
"""
)
EXCEPTIONS = """
as-so-ciate as-so-ciates dec-li-na-tion oblig-a-tory phil-an-thropic present
presents project projects reci-procity re-cog-ni-zance ref-or-ma-tion
ret-ri-bu-tion ta-ble
"""
hyphenator = Hyphenator(PATTERNS, EXCEPTIONS)
hyphenate_word = hyphenator.hyphenate_word
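# Illustrative usage sketch (not part of the original module): the patterns above
# drive a Liang-style hyphenator; hyphenate_word returns the pieces a word may be
# broken into. The exact splits depend on the patterns and exceptions.
if __name__ == "__main__":
    for word in ("project", "associate", "hyphenation"):
        print(word, "->", hyphenate_word(word))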
import re
def split_paragraphs(text: str) -> list[tuple[str, int, int]]:
"""
Split the text into paragraphs.
Returns a list of paragraphs with their start and end indices in the original text.
"""
# Use a regex pattern to split on one or more blank lines
pattern = r"\n\s*\n"
# Find all splits in the text
splits = list(re.finditer(pattern, text))
paragraphs: list[tuple[str, int, int]] = []
start = 0
# Handle the case where there are no splits (i.e., single paragraph)
if not splits:
stripped = text.strip()
# skip empty
if not stripped:
return paragraphs
start_index = text.index(stripped)
return [(stripped, start_index, start_index + len(stripped))]
# Process each split
for split in splits:
end = split.start()
paragraph = text[start:end].strip()
if paragraph: # Only add non-empty paragraphs
para_start = start + text[start:end].index(paragraph)
para_end = para_start + len(paragraph)
paragraphs.append((paragraph, para_start, para_end))
start = split.end()
# Add the last paragraph
last_paragraph = text[start:].strip()
if last_paragraph:
para_start = start + text[start:].index(last_paragraph)
para_end = para_start + len(last_paragraph)
paragraphs.append((last_paragraph, para_start, para_end))
return paragraphs
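# Illustrative usage sketch (not part of the original module): the returned
# (start, end) indices point back into the original text, so slicing recovers
# each paragraph verbatim.
if __name__ == "__main__":
    sample = "First paragraph.\n\nSecond paragraph."
    for para, start, end in split_paragraphs(sample):
        assert sample[start:end] == para
        print(repr(para), start, end)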
import re
# rule-based segmentation based on https://stackoverflow.com/a/31505798; works surprisingly well
def split_sentences(
text: str, min_sentence_len: int = 20
) -> list[tuple[str, int, int]]:
"""
Split the text into sentences. The input must not contain the substrings "<prd>" or "<stop>", which are used internally as markers.
"""
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov|edu|me)"
digits = r"([0-9])"
multiple_dots = r"\.{2,}"
# fmt: off
text = text.replace("\n"," ")
text = re.sub(prefixes,"\\1<prd>", text)
text = re.sub(websites,"<prd>\\1", text)
text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
# text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
# TODO(theomonnom): needs improvement for "..." (multiple dots): check that the
# next sentence starts with a capital letter and is not too short
text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)), text)
if "Ph.D" in text:
text = text.replace("Ph.D.","Ph<prd>D<prd>")
text = re.sub(r"\s" + alphabets + "[.] "," \\1<prd> ",text)
text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
text = re.sub(r" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
text = re.sub(r" "+suffixes+"[.]"," \\1<prd>",text)
text = re.sub(r" " + alphabets + "[.]"," \\1<prd>",text)
if "”" in text:
text = text.replace(".”","”.")
if "\"" in text:
text = text.replace(".\"","\".")
if "!" in text:
text = text.replace("!\"","\"!")
if "?" in text:
text = text.replace("?\"","\"?")
text = text.replace(".",".<stop>")
text = text.replace("?","?<stop>")
text = text.replace("!","!<stop>")
text = text.replace("<prd>",".")
# fmt: on
splitted_sentences = text.split("<stop>")
text = text.replace("<stop>", "")
sentences: list[tuple[str, int, int]] = []
buff = ""
start_pos = 0
end_pos = 0
for match in splitted_sentences:
sentence = match.strip()
if not sentence:
continue
buff += " " + sentence
end_pos += len(match)
if len(buff) > min_sentence_len:
sentences.append((buff[1:], start_pos, end_pos))
start_pos = end_pos
buff = ""
if buff:
sentences.append((buff[1:], start_pos, len(text) - 1))
return sentences
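# Illustrative usage sketch (not part of the original module): sentences shorter
# than min_sentence_len are buffered and merged with the following one.
if __name__ == "__main__":
    demo = "Hi. This is a longer second sentence. Dr. Smith arrived at 3.5 p.m. sharp."
    for sentence, start, end in split_sentences(demo, min_sentence_len=15):
        print(repr(sentence))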
import re
from . import tokenizer
def split_words(
text: str, ignore_punctuation: bool = True
) -> list[tuple[str, int, int]]:
"""
Split the text into words.
Returns a list of words with their start and end indices in the original text.
"""
matches = re.finditer(r"\S+", text)
words: list[tuple[str, int, int]] = []
for match in matches:
word = match.group(0)
start_pos = match.start()
end_pos = match.end()
if ignore_punctuation:
# TODO(theomonnom): acronyms passthrough
translation_table = str.maketrans("", "", "".join(tokenizer.PUNCTUATIONS))
word = word.translate(translation_table)
if not word:
continue
words.append((word, start_pos, end_pos))
return words
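# Illustrative usage sketch (not part of the original module): the offsets always
# refer to the original text, even when punctuation is stripped from the token.
if __name__ == "__main__":
    print(split_words("Hello, world!"))
    print(split_words("Hello, world!", ignore_punctuation=False))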
from __future__ import annotations
import functools
from dataclasses import dataclass
from . import (
_basic_hyphenator,
_basic_paragraph,
_basic_sent,
_basic_word,
token_stream,
tokenizer,
)
# Naive, rule-based implementations of SentenceTokenizer, WordTokenizer and hyphenate_word.
# Only English has really been tested.
__all__ = [
"SentenceTokenizer",
"WordTokenizer",
"hyphenate_word",
"tokenize_paragraphs",
]
@dataclass
class _TokenizerOptions:
language: str
min_sentence_len: int
stream_context_len: int
class SentenceTokenizer(tokenizer.SentenceTokenizer):
def __init__(
self,
*,
language: str = "english",
min_sentence_len: int = 20,
stream_context_len: int = 10,
) -> None:
self._config = _TokenizerOptions(
language=language,
min_sentence_len=min_sentence_len,
stream_context_len=stream_context_len,
)
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
return [
tok[0]
for tok in _basic_sent.split_sentences(
text, min_sentence_len=self._config.min_sentence_len
)
]
def stream(self, *, language: str | None = None) -> tokenizer.SentenceStream:
return token_stream.BufferedSentenceStream(
tokenizer=functools.partial(
_basic_sent.split_sentences,
min_sentence_len=self._config.min_sentence_len,
),
min_token_len=self._config.min_sentence_len,
min_ctx_len=self._config.stream_context_len,
)
class WordTokenizer(tokenizer.WordTokenizer):
def __init__(self, *, ignore_punctuation: bool = True) -> None:
self._ignore_punctuation = ignore_punctuation
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
return [
tok[0]
for tok in _basic_word.split_words(
text, ignore_punctuation=self._ignore_punctuation
)
]
def stream(self, *, language: str | None = None) -> tokenizer.WordStream:
return token_stream.BufferedWordStream(
tokenizer=functools.partial(
_basic_word.split_words, ignore_punctuation=self._ignore_punctuation
),
min_token_len=1,
min_ctx_len=1, # ignore
)
def hyphenate_word(word: str) -> list[str]:
return _basic_hyphenator.hyphenate_word(word)
def tokenize_paragraphs(text: str) -> list[str]:
return [tok[0] for tok in _basic_paragraph.split_paragraphs(text)]
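# Illustrative usage sketch (not part of the original module): the non-streaming
# entry points of this module.
if __name__ == "__main__":
    sent_tok = SentenceTokenizer(min_sentence_len=10)
    word_tok = WordTokenizer()
    print(sent_tok.tokenize("One sentence here. And another one follows right after."))
    print(word_tok.tokenize("Split me into words, please."))
    print(hyphenate_word("tokenizer"))
    print(tokenize_paragraphs("First paragraph.\n\nSecond paragraph."))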
from __future__ import annotations
import typing
from typing import Callable, Union
from ..utils import aio, shortuuid
from .tokenizer import SentenceStream, TokenData, WordStream
# Tokenizers can either provide us with a list of tokens or a list of tokens along with their start and end indices.
# If the start and end indices are not available, we attempt to locate the token within the text using str.find.
TokenizeCallable = Callable[[str], Union[list[str], list[tuple[str, int, int]]]]
class BufferedTokenStream:
def __init__(
self,
*,
tokenize_fnc: TokenizeCallable,
min_token_len: int,
min_ctx_len: int,
) -> None:
self._event_ch = aio.Chan[TokenData]()
self._tokenize_fnc = tokenize_fnc
self._min_ctx_len = min_ctx_len
self._min_token_len = min_token_len
self._current_segment_id = shortuuid()
self._buf_tokens: list[str] = [] # <= min_token_len
self._in_buf = ""
self._out_buf = ""
@typing.no_type_check
def push_text(self, text: str) -> None:
self._check_not_closed()
self._in_buf += text
if len(self._in_buf) < self._min_ctx_len:
return
while True:
tokens = self._tokenize_fnc(self._in_buf)
if len(tokens) <= 1:
break
if self._out_buf:
self._out_buf += " "
tok = tokens.pop(0)
tok_text = tok
if isinstance(tok, tuple):
tok_text = tok[0]
self._out_buf += tok_text
if len(self._out_buf) >= self._min_token_len:
self._event_ch.send_nowait(
TokenData(token=self._out_buf, segment_id=self._current_segment_id)
)
self._out_buf = ""
if isinstance(tok, tuple):
self._in_buf = self._in_buf[tok[2] :]
else:
tok_i = max(self._in_buf.find(tok), 0)
self._in_buf = self._in_buf[tok_i + len(tok) :].lstrip()
@typing.no_type_check
def flush(self) -> None:
self._check_not_closed()
if self._in_buf or self._out_buf:
tokens = self._tokenize_fnc(self._in_buf)
if tokens:
if self._out_buf:
self._out_buf += " "
if isinstance(tokens[0], tuple):
self._out_buf += " ".join([tok[0] for tok in tokens])
else:
self._out_buf += " ".join(tokens)
if self._out_buf:
self._event_ch.send_nowait(
TokenData(token=self._out_buf, segment_id=self._current_segment_id)
)
self._current_segment_id = shortuuid()
self._in_buf = ""
self._out_buf = ""
def end_input(self) -> None:
self.flush()
self._event_ch.close()
async def aclose(self) -> None:
self._event_ch.close()
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
def __aiter__(self) -> "BufferedTokenStream":
return self
async def __anext__(self) -> TokenData:
return await self._event_ch.__anext__()
class BufferedSentenceStream(BufferedTokenStream, SentenceStream):
def __init__(
self,
*,
tokenizer: TokenizeCallable,
min_token_len: int,
min_ctx_len: int,
) -> None:
super().__init__(
tokenize_fnc=tokenizer,
min_token_len=min_token_len,
min_ctx_len=min_ctx_len,
)
class BufferedWordStream(BufferedTokenStream, WordStream):
def __init__(
self,
*,
tokenizer: TokenizeCallable,
min_token_len: int,
min_ctx_len: int,
) -> None:
super().__init__(
tokenize_fnc=tokenizer,
min_token_len=min_token_len,
min_ctx_len=min_ctx_len,
)
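# Illustrative usage sketch (not part of the original module): pushing incremental
# text into a BufferedWordStream and consuming TokenData events. It relies on the
# basic word splitter defined earlier in this package.
async def _demo_buffered_word_stream() -> None:
    from . import _basic_word  # local import to keep the sketch self-contained

    stream = BufferedWordStream(
        tokenizer=_basic_word.split_words, min_token_len=1, min_ctx_len=1
    )
    stream.push_text("hello wor")
    stream.push_text("ld again")
    stream.end_input()  # flush the remaining buffer and close the channel
    async for ev in stream:
        print(ev.token, ev.segment_id)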
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncIterator
from ..utils import aio
# fmt: off
PUNCTUATIONS = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>',
'?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', '±', '—', '‘', '’', '“', '”', '…']
# fmt: on
@dataclass
class TokenData:
segment_id: str = ""
token: str = ""
class SentenceTokenizer(ABC):
@abstractmethod
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
pass
@abstractmethod
def stream(self, *, language: str | None = None) -> "SentenceStream":
pass
class SentenceStream(ABC):
def __init__(self) -> None:
self._event_ch = aio.Chan[TokenData]()
@abstractmethod
def push_text(self, text: str) -> None: ...
@abstractmethod
def flush(self) -> None: ...
@abstractmethod
def end_input(self) -> None: ...
@abstractmethod
async def aclose(self) -> None: ...
async def __anext__(self) -> TokenData:
return await self._event_ch.__anext__()
def __aiter__(self) -> AsyncIterator[TokenData]:
return self
def _do_close(self) -> None:
self._event_ch.close()
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
class WordTokenizer(ABC):
@abstractmethod
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
pass
@abstractmethod
def stream(self, *, language: str | None = None) -> "WordStream":
pass
def format_words(self, words: list[str]) -> str:
return " ".join(words)
class WordStream(ABC):
def __init__(self) -> None:
self._event_ch = aio.Chan[TokenData]()
@abstractmethod
def push_text(self, text: str) -> None: ...
@abstractmethod
def flush(self) -> None: ...
@abstractmethod
def end_input(self) -> None: ...
@abstractmethod
async def aclose(self) -> None: ...
async def __anext__(self) -> TokenData:
return await self._event_ch.__anext__()
def __aiter__(self) -> AsyncIterator[TokenData]:
return self
def _do_close(self) -> None:
self._event_ch.close()
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
from __future__ import annotations
from typing import AsyncIterable, overload
from . import _basic_word, tokenizer
@overload
def replace_words(
*,
text: str,
replacements: dict[str, str],
) -> str: ...
@overload
def replace_words(
*,
text: AsyncIterable[str],
replacements: dict[str, str],
) -> AsyncIterable[str]: ...
def replace_words(
*,
text: str | AsyncIterable[str],
replacements: dict[str, str],
) -> str | AsyncIterable[str]:
"""
Replace words in the given (async) text. The lookup is case-insensitive; the
replacement string is inserted as provided.
Args:
text: text to replace words in
replacements: mapping of words to their replacement strings
"""
replacements = {k.lower(): v for k, v in replacements.items()}
def _process_words(text, words):
offset = 0
processed_index = 0
for word, start_index, end_index in words:
no_punctuation = word.rstrip("".join(tokenizer.PUNCTUATIONS))
punctuation_off = len(word) - len(no_punctuation)
replacement = replacements.get(no_punctuation.lower())
if replacement:
text = (
text[: start_index + offset]
+ replacement
+ text[end_index + offset - punctuation_off :]
)
offset += len(replacement) - len(word) + punctuation_off
processed_index = end_index + offset
return text, processed_index
if isinstance(text, str):
words = _basic_word.split_words(text, ignore_punctuation=False)
text, _ = _process_words(text, words)
return text
else:
async def _replace_words():
buffer = ""
async for chunk in text:
buffer += chunk
words = _basic_word.split_words(buffer, ignore_punctuation=False)
if len(words) <= 1:
continue
buffer, processed_index = _process_words(buffer, words[:-1])
yield buffer[:processed_index]
buffer = buffer[processed_index:]
if buffer:
words = _basic_word.split_words(buffer, ignore_punctuation=False)
buffer, _ = _process_words(buffer, words)
yield buffer
return _replace_words()
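# Illustrative usage sketch (not part of the original module): the string overload.
# The AsyncIterable overload applies the same replacement incrementally to chunks.
if __name__ == "__main__":
    print(replace_words(text="hello livekit agents", replacements={"livekit": "LiveKit"}))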
from .stt_forwarder import STTSegmentsForwarder
from .tts_forwarder import TTSSegmentsForwarder
__all__ = [
"TTSSegmentsForwarder",
"STTSegmentsForwarder",
]
from __future__ import annotations
from livekit import rtc
from ..utils import shortuuid
def find_micro_track_id(room: rtc.Room, identity: str) -> str:
p: rtc.RemoteParticipant | rtc.LocalParticipant | None = (
room.remote_participants.get(identity)
)
if identity == room.local_participant.identity:
p = room.local_participant
if p is None:
raise ValueError(f"participant {identity} not found")
# find first micro track
track_id = None
for track in p.track_publications.values():
if track.source == rtc.TrackSource.SOURCE_MICROPHONE:
track_id = track.sid
break
if track_id is None:
raise ValueError(f"participant {identity} does not have a microphone track")
return track_id
def segment_uuid() -> str:
return shortuuid("SG_")
from __future__ import annotations
import asyncio
import contextlib
from typing import Awaitable, Callable, Optional, Union
from livekit import rtc
from .. import stt
from ..log import logger
from . import _utils
BeforeForwardCallback = Callable[
["STTSegmentsForwarder", rtc.Transcription],
Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]],
]
WillForwardTranscription = BeforeForwardCallback
def _default_before_forward_cb(
fwd: STTSegmentsForwarder, transcription: rtc.Transcription
) -> rtc.Transcription:
return transcription
class STTSegmentsForwarder:
"""
Forward STT transcription to the users. (Useful for client-side rendering)
"""
def __init__(
self,
*,
room: rtc.Room,
participant: rtc.Participant | str,
track: rtc.Track | rtc.TrackPublication | str | None = None,
before_forward_cb: BeforeForwardCallback = _default_before_forward_cb,
# backward compatibility
will_forward_transcription: WillForwardTranscription | None = None,
):
identity = participant if isinstance(participant, str) else participant.identity
if track is None:
track = _utils.find_micro_track_id(room, identity)
elif isinstance(track, (rtc.TrackPublication, rtc.Track)):
track = track.sid
if will_forward_transcription is not None:
logger.warning(
"will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead",
)
before_forward_cb = will_forward_transcription
self._room, self._participant_identity, self._track_id = room, identity, track
self._before_forward_cb = before_forward_cb
self._queue = asyncio.Queue[Optional[rtc.TranscriptionSegment]]()
self._main_task = asyncio.create_task(self._run())
self._current_id = _utils.segment_uuid()
async def _run(self):
try:
while True:
seg = await self._queue.get()
if seg is None:
break
base_transcription = rtc.Transcription(
participant_identity=self._participant_identity,
track_sid=self._track_id,
segments=[seg], # no history for now
)
transcription = self._before_forward_cb(self, base_transcription)
if asyncio.iscoroutine(transcription):
transcription = await transcription
if not isinstance(transcription, rtc.Transcription):
transcription = _default_before_forward_cb(self, base_transcription)
if transcription.segments and self._room.isconnected():
await self._room.local_participant.publish_transcription(
transcription
)
except Exception:
logger.exception("error in stt transcription")
def update(self, ev: stt.SpeechEvent):
if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
# TODO(theomonnom): we always take the first alternative; we should maybe expose
# an option to let the user choose
text = ev.alternatives[0].text
self._queue.put_nowait(
rtc.TranscriptionSegment(
id=self._current_id,
text=text,
start_time=0,
end_time=0,
final=False,
language="", # TODO
)
)
elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
text = ev.alternatives[0].text
self._queue.put_nowait(
rtc.TranscriptionSegment(
id=self._current_id,
text=text,
start_time=0,
end_time=0,
final=True,
language="", # TODO
)
)
self._current_id = _utils.segment_uuid()
async def aclose(self, *, wait: bool = True) -> None:
self._queue.put_nowait(None)
if not wait:
self._main_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._main_task
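# Illustrative usage sketch (not part of the original module): forwarding events
# from an STT stream inside a job. `room`, `participant_identity` and `stt_stream`
# are placeholders for objects provided by the application.
async def _demo_stt_forwarding(
    room: rtc.Room, participant_identity: str, stt_stream
) -> None:
    forwarder = STTSegmentsForwarder(room=room, participant=participant_identity)
    async for ev in stt_stream:  # yields stt.SpeechEvent
        forwarder.update(ev)
    await forwarder.aclose()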
from __future__ import annotations
import asyncio
import contextlib
import time
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional, Union
from livekit import rtc
from livekit.rtc.participant import PublishTranscriptionError
from .. import tokenize, utils
from ..log import logger
from ..tokenize.tokenizer import PUNCTUATIONS
from . import _utils
# 3.83 is the "baseline", the number of hyphens per second TTS returns in avg.
STANDARD_SPEECH_RATE = 3.83
BeforeForwardCallback = Callable[
["TTSSegmentsForwarder", rtc.Transcription],
Union[rtc.Transcription, Awaitable[Optional[rtc.Transcription]]],
]
WillForwardTranscription = BeforeForwardCallback
def _default_before_forward_callback(
fwd: TTSSegmentsForwarder, transcription: rtc.Transcription
) -> rtc.Transcription:
return transcription
@dataclass
class _TTSOptions:
room: rtc.Room
participant_identity: str
track_id: str
language: str
speed: float
word_tokenizer: tokenize.WordTokenizer
sentence_tokenizer: tokenize.SentenceTokenizer
hyphenate_word: Callable[[str], list[str]]
new_sentence_delay: float
before_forward_cb: BeforeForwardCallback
@dataclass
class _AudioData:
pushed_duration: float = 0.0
done: bool = False
@dataclass
class _TextData:
sentence_stream: tokenize.SentenceStream
pushed_text: str = ""
done: bool = False
forwarded_hyphens: int = 0
forwarded_sentences: int = 0
class TTSSegmentsForwarder:
"""
Forward TTS transcription to the users. This class tries to imitate the timing of
the synthesized speech. The first estimate is based on the speed argument; once the
full audio of a text segment has been received, the average speech speed is recalculated
from the text and audio lengths, and the transcription catches up or slows down as needed.
"""
def __init__(
self,
*,
room: rtc.Room,
participant: rtc.Participant | str,
track: rtc.Track | rtc.TrackPublication | str | None = None,
language: str = "",
speed: float = 1.0,
new_sentence_delay: float = 0.4,
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(),
sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer(),
hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word,
before_forward_cb: BeforeForwardCallback = _default_before_forward_callback,
loop: asyncio.AbstractEventLoop | None = None,
# backward compatibility
will_forward_transcription: WillForwardTranscription | None = None,
):
"""
Args:
room: room where the transcription will be sent
participant: participant or identity that is pushing the TTS
track: track where the TTS audio is being sent
language: language of the text
speed: relative speech speed factor; 1.0 corresponds to the standard rate of ~3.83 hyphens per second (used until the full audio of a segment has been received)
new_sentence_delay: delay in seconds between sentences
word_tokenizer: word tokenizer used to split the text into words
sentence_tokenizer: sentence tokenizer used to split the text into sentences
hyphenate_word: function that returns a list of hyphens for a given word
"""
identity = participant if isinstance(participant, str) else participant.identity
if track is None:
track = _utils.find_micro_track_id(room, identity)
elif isinstance(track, (rtc.TrackPublication, rtc.Track)):
track = track.sid
if will_forward_transcription is not None:
logger.warning(
"will_forward_transcription is deprecated and will be removed in 1.5.0, use before_forward_cb instead",
)
before_forward_cb = will_forward_transcription
speed = speed * STANDARD_SPEECH_RATE
self._opts = _TTSOptions(
room=room,
participant_identity=identity,
track_id=track,
language=language,
speed=speed,
word_tokenizer=word_tokenizer,
sentence_tokenizer=sentence_tokenizer,
hyphenate_word=hyphenate_word,
new_sentence_delay=new_sentence_delay,
before_forward_cb=before_forward_cb,
)
self._closed = False
self._loop = loop or asyncio.get_event_loop()
self._close_future = asyncio.Future[None]()
self._playing_seg_index = -1
self._finished_seg_index = -1
self._text_q_changed = asyncio.Event()
self._text_q = list[Union[_TextData, None]]()
self._audio_q_changed = asyncio.Event()
self._audio_q = list[Union[_AudioData, None]]()
self._text_data: _TextData | None = None
self._audio_data: _AudioData | None = None
self._played_text = ""
self._main_atask: asyncio.Task | None = None
self._task_set = utils.aio.TaskSet(loop)
def segment_playout_started(self) -> None:
"""
Notify that the playout of the audio segment has started.
This will start forwarding the transcription for the current segment.
"""
self._check_not_closed()
self._playing_seg_index += 1
if self._main_atask is None:
self._main_atask = asyncio.create_task(self._main_task())
def segment_playout_finished(self) -> None:
"""
Notify that the playout of the audio segment has finished.
This will catch up and directly send the final transcription in case the forwarder is too
late.
"""
self._check_not_closed()
self._finished_seg_index += 1
def push_audio(self, frame: rtc.AudioFrame) -> None:
self._check_not_closed()
if self._audio_data is None:
self._audio_data = _AudioData()
self._audio_q.append(self._audio_data)
self._audio_q_changed.set()
frame_duration = frame.samples_per_channel / frame.sample_rate
self._audio_data.pushed_duration += frame_duration
def mark_audio_segment_end(self) -> None:
self._check_not_closed()
if self._audio_data is None:
self.push_audio(rtc.AudioFrame(bytes(), 24000, 1, 0))
assert self._audio_data is not None
self._audio_data.done = True
self._audio_data = None
def push_text(self, text: str) -> None:
self._check_not_closed()
if self._text_data is None:
self._text_data = _TextData(
sentence_stream=self._opts.sentence_tokenizer.stream()
)
self._text_q.append(self._text_data)
self._text_q_changed.set()
self._text_data.pushed_text += text
self._text_data.sentence_stream.push_text(text)
def mark_text_segment_end(self) -> None:
self._check_not_closed()
if self._text_data is None:
self.push_text("")
assert self._text_data is not None
self._text_data.done = True
self._text_data.sentence_stream.end_input()
self._text_data = None
@property
def closed(self) -> bool:
return self._closed
@property
def played_text(self) -> str:
return self._played_text
async def aclose(self) -> None:
if self._closed:
return
self._closed = True
self._close_future.set_result(None)
for text_data in self._text_q:
assert text_data is not None
await text_data.sentence_stream.aclose()
self._text_q.append(None)
self._audio_q.append(None)
self._text_q_changed.set()
self._audio_q_changed.set()
await self._task_set.aclose()
if self._main_atask is not None:
await self._main_atask
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
"""Main task that forwards the transcription to the room."""
rtc_seg_ch = utils.aio.Chan[rtc.TranscriptionSegment]()
forward_task = None
try:
@utils.log_exceptions(logger=logger)
async def _forward_task():
async for rtc_seg in rtc_seg_ch:
base_transcription = rtc.Transcription(
participant_identity=self._opts.participant_identity,
track_sid=self._opts.track_id,
segments=[rtc_seg], # no history for now
)
transcription = self._opts.before_forward_cb(
self, base_transcription
)
if asyncio.iscoroutine(transcription):
transcription = await transcription
# fall back to the default implementation if the callback does not return a Transcription
if not isinstance(transcription, rtc.Transcription):
transcription = _default_before_forward_callback(
self, base_transcription
)
if transcription.segments and self._opts.room.isconnected():
try:
await (
self._opts.room.local_participant.publish_transcription(
transcription
)
)
except PublishTranscriptionError:
continue
forward_task = asyncio.create_task(_forward_task())
seg_index = 0
q_done = False
while not q_done:
await self._text_q_changed.wait()
await self._audio_q_changed.wait()
while self._text_q and self._audio_q:
text_data = self._text_q.pop(0)
audio_data = self._audio_q.pop(0)
if text_data is None or audio_data is None:
q_done = True
break
# wait until the segment is validated and has started playing
while not self._closed:
if self._playing_seg_index >= seg_index:
break
await self._sleep_if_not_closed(0.125)
sentence_stream = text_data.sentence_stream
forward_start_time = time.time()
async for ev in sentence_stream:
await self._sync_sentence_co(
seg_index,
forward_start_time,
text_data,
audio_data,
ev.token,
rtc_seg_ch,
)
seg_index += 1
self._text_q_changed.clear()
self._audio_q_changed.clear()
finally:
rtc_seg_ch.close()
if forward_task:
await forward_task
async def _sync_sentence_co(
self,
segment_index: int,
segment_start_time: float,
text_data: _TextData,
audio_data: _AudioData,
sentence: str,
rtc_seg_ch: utils.aio.Chan[rtc.TranscriptionSegment],
):
"""Synchronize the transcription with the audio playout for a given sentence."""
# put each sentence in a different transcription segment
real_speed = None
if audio_data.pushed_duration > 0 and audio_data.done:
real_speed = (
len(self._calc_hyphens(text_data.pushed_text))
/ audio_data.pushed_duration
)
seg_id = _utils.segment_uuid()
words = self._opts.word_tokenizer.tokenize(text=sentence)
processed_words: list[str] = []
og_text = self._played_text
for word in words:
if segment_index <= self._finished_seg_index:
# playout of the audio segment already finished
# break the loop and send the final transcription
break
if self._closed:
# transcription closed, return early
return
word_hyphens = len(self._opts.hyphenate_word(word))
processed_words.append(word)
# elapsed time since the start of the seg
elapsed_time = time.time() - segment_start_time
text = self._opts.word_tokenizer.format_words(processed_words)
# remove any punctuation at the end of a non-final transcript
text = text.rstrip("".join(PUNCTUATIONS))
speed = self._opts.speed
if real_speed is not None:
speed = real_speed
estimated_pauses_s = (
text_data.forwarded_sentences * self._opts.new_sentence_delay
)
hyph_pauses = estimated_pauses_s * speed
target_hyphens = round(speed * elapsed_time)
dt = target_hyphens - text_data.forwarded_hyphens - hyph_pauses
to_wait_hyphens = max(0.0, word_hyphens - dt)
delay = to_wait_hyphens / speed
else:
delay = word_hyphens / speed
first_delay = min(delay / 2, 2 / speed)
await self._sleep_if_not_closed(first_delay)
rtc_seg_ch.send_nowait(
rtc.TranscriptionSegment(
id=seg_id,
text=text,
start_time=0,
end_time=0,
final=False,
language=self._opts.language,
)
)
# add space if there is text before it
self._played_text = f"{og_text}{' ' if og_text else ''}{text}"
await self._sleep_if_not_closed(delay - first_delay)
text_data.forwarded_hyphens += word_hyphens
rtc_seg_ch.send_nowait(
rtc.TranscriptionSegment(
id=seg_id,
text=sentence,
start_time=0,
end_time=0,
final=True,
language=self._opts.language,
)
)
self._played_text = f"{og_text}{' ' if og_text else ''}{sentence}"
await self._sleep_if_not_closed(self._opts.new_sentence_delay)
text_data.forwarded_sentences += 1
async def _sleep_if_not_closed(self, delay: float) -> None:
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait([self._close_future], timeout=delay)
def _calc_hyphens(self, text: str) -> list[str]:
hyphens: list[str] = []
words = self._opts.word_tokenizer.tokenize(text=text)
for word in words:
new = self._opts.hyphenate_word(word)
hyphens.extend(new)
return hyphens
def _check_not_closed(self) -> None:
if self._closed:
raise RuntimeError("TTSForwarder is closed")
from .fallback_adapter import (
AvailabilityChangedEvent,
FallbackAdapter,
FallbackChunkedStream,
FallbackSynthesizeStream,
)
from .stream_adapter import StreamAdapter, StreamAdapterWrapper
from .tts import (
TTS,
ChunkedStream,
SynthesizedAudio,
SynthesizedAudioEmitter,
SynthesizeStream,
TTSCapabilities,
)
__all__ = [
"TTS",
"SynthesizedAudio",
"SynthesizeStream",
"TTSCapabilities",
"StreamAdapterWrapper",
"StreamAdapter",
"ChunkedStream",
"AvailabilityChangedEvent",
"FallbackAdapter",
"FallbackChunkedStream",
"FallbackSynthesizeStream",
"SynthesizedAudioEmitter",
]
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import time
from dataclasses import dataclass
from typing import AsyncGenerator, Literal, Optional, Union
from livekit import rtc
from .. import utils
from .._exceptions import APIConnectionError, APIError
from ..log import logger
from ..utils import aio
from .tts import (
DEFAULT_API_CONNECT_OPTIONS,
TTS,
APIConnectOptions,
ChunkedStream,
SynthesizedAudio,
SynthesizeStream,
TTSCapabilities,
)
# don't retry when using the fallback adapter
DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)
@dataclass
class _TTSStatus:
available: bool
recovering_task: asyncio.Task | None
resampler: rtc.AudioResampler | None
@dataclass
class AvailabilityChangedEvent:
tts: TTS
available: bool
class FallbackAdapter(
TTS[Literal["tts_availability_changed"]],
):
"""
Manages multiple TTS instances, providing a fallback mechanism to ensure continuous TTS service.
"""
def __init__(
self,
tts: list[TTS],
*,
attempt_timeout: float = 10.0,
max_retry_per_tts: int = 1, # only retry once by default
retry_interval: float = 5,
no_fallback_after_audio_duration: float | None = 3.0,
sample_rate: int | None = None,
) -> None:
"""
Initialize a FallbackAdapter that manages multiple TTS instances.
Args:
tts (list[TTS]): A list of TTS instances to use for fallback.
attempt_timeout (float, optional): Timeout for each synthesis attempt in seconds. Defaults to 10.0.
max_retry_per_tts (int, optional): Maximum number of retries per TTS instance. Defaults to 1.
retry_interval (float, optional): Delay between retries for the same TTS instance, in seconds. Defaults to 5.
no_fallback_after_audio_duration (float | None, optional): Disables fallback after this duration of audio is synthesized. Defaults to 3.0.
    This prevents the same text from being unnaturally repeated when the first TTS
    instance fails partway through.
sample_rate (int | None, optional): Desired sample rate for the synthesized audio. If None, uses the maximum sample rate among the TTS instances.
Raises:
ValueError: If no TTS instances are provided.
ValueError: If TTS instances have different numbers of channels.
"""
if len(tts) < 1:
raise ValueError("at least one TTS instance must be provided.")
if len(set(t.num_channels for t in tts)) != 1:
raise ValueError("all TTS must have the same number of channels")
if sample_rate is None:
sample_rate = max(t.sample_rate for t in tts)
num_channels = tts[0].num_channels
super().__init__(
capabilities=TTSCapabilities(
streaming=all(t.capabilities.streaming for t in tts),
),
sample_rate=sample_rate,
num_channels=num_channels,
)
self._tts_instances = tts
self._attempt_timeout = attempt_timeout
self._max_retry_per_tts = max_retry_per_tts
self._retry_interval = retry_interval
self._no_fallback_after_audio_duration = no_fallback_after_audio_duration
self._status: list[_TTSStatus] = []
for t in tts:
resampler = None
if sample_rate != t.sample_rate:
logger.info(
f"resampling {t.label} from {t.sample_rate}Hz to {sample_rate}Hz"
)
resampler = rtc.AudioResampler(
input_rate=t.sample_rate, output_rate=sample_rate
)
self._status.append(
_TTSStatus(available=True, recovering_task=None, resampler=resampler)
)
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "FallbackChunkedStream":
return FallbackChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options or DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
)
def stream(
self,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "FallbackSynthesizeStream":
return FallbackSynthesizeStream(
tts=self,
conn_options=conn_options or DEFAULT_FALLBACK_API_CONNECT_OPTIONS,
)
def prewarm(self) -> None:
if self._tts_instances:
self._tts_instances[0].prewarm()
async def aclose(self) -> None:
for tts_status in self._status:
if tts_status.recovering_task is not None:
await aio.gracefully_cancel(tts_status.recovering_task)
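# Illustrative sketch (not part of the original module): constructing a
# FallbackAdapter over two arbitrary TTS instances. `primary` and `backup` are
# hypothetical placeholders for any objects implementing the TTS interface.
def _example_build_fallback(primary: TTS, backup: TTS) -> FallbackAdapter:
    return FallbackAdapter(
        tts=[primary, backup],  # tried in order; `backup` is used when `primary` fails
        attempt_timeout=10.0,  # per-attempt timeout in seconds
        max_retry_per_tts=1,  # retry each instance once before moving on
        no_fallback_after_audio_duration=3.0,
    )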
class FallbackChunkedStream(ChunkedStream):
def __init__(
self,
*,
tts: FallbackAdapter,
input_text: str,
conn_options: Optional[APIConnectOptions],
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._fallback_adapter = tts
async def _try_synthesize(
self, *, tts: TTS, recovering: bool = False
) -> AsyncGenerator[SynthesizedAudio, None]:
try:
audio_duration = 0.0
async with tts.synthesize(
self._input_text,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=self._fallback_adapter._max_retry_per_tts,
timeout=self._fallback_adapter._attempt_timeout,
retry_interval=self._fallback_adapter._retry_interval,
),
) as stream:
while True:
try:
audio = await asyncio.wait_for(
stream.__anext__(),
self._fallback_adapter._attempt_timeout
if audio_duration == 0.0
else None,
)
audio_duration += audio.frame.duration
yield audio
except StopAsyncIteration:
break
if audio_duration == 0.0:
raise APIConnectionError("no audio received")
except asyncio.TimeoutError:
if recovering:
logger.warning(
f"{tts.label} recovery timed out", extra={"streamed": False}
)
raise
logger.warning(
f"{tts.label} timed out, switching to next TTS",
extra={"streamed": False},
)
raise
except APIError as e:
if recovering:
logger.warning(
f"{tts.label} recovery failed",
exc_info=e,
extra={"streamed": False},
)
raise
logger.warning(
f"{tts.label} failed, switching to next TTS",
exc_info=e,
extra={"streamed": False},
)
raise
except Exception:
if recovering:
logger.exception(
f"{tts.label} recovery unexpected error", extra={"streamed": False}
)
raise
logger.exception(
f"{tts.label} unexpected error, switching to next TTS",
extra={"streamed": False},
)
raise
def _try_recovery(self, tts: TTS) -> None:
assert isinstance(self._tts, FallbackAdapter)
tts_status = self._tts._status[self._tts._tts_instances.index(tts)]
if tts_status.recovering_task is None or tts_status.recovering_task.done():
async def _recover_tts_task(tts: TTS) -> None:
try:
async for _ in self._try_synthesize(tts=tts, recovering=True):
pass
tts_status.available = True
logger.info(f"tts.FallbackAdapter, {tts.label} recovered")
self._tts.emit(
"tts_availability_changed",
AvailabilityChangedEvent(tts=tts, available=True),
)
except Exception:
return
tts_status.recovering_task = asyncio.create_task(_recover_tts_task(tts))
async def _run(self) -> None:
assert isinstance(self._tts, FallbackAdapter)
start_time = time.time()
all_failed = all(not tts_status.available for tts_status in self._tts._status)
if all_failed:
logger.error("all TTSs are unavailable, retrying..")
for i, tts in enumerate(self._tts._tts_instances):
tts_status = self._tts._status[i]
if tts_status.available or all_failed:
audio_duration = 0.0
try:
request_id: str | None = None
resampler = tts_status.resampler
async for synthesized_audio in self._try_synthesize(
tts=tts, recovering=False
):
audio_duration += synthesized_audio.frame.duration
request_id = synthesized_audio.request_id
if resampler is not None:
for rf in resampler.push(synthesized_audio.frame):
self._event_ch.send_nowait(
SynthesizedAudio(
frame=rf,
request_id=synthesized_audio.request_id,
)
)
continue
self._event_ch.send_nowait(synthesized_audio)
if resampler is not None and request_id is not None:
for rf in resampler.flush():
self._event_ch.send_nowait(
SynthesizedAudio(
frame=rf,
request_id=request_id,
)
)
return
except Exception: # exceptions already logged inside _try_synthesize
if tts_status.available:
tts_status.available = False
self._tts.emit(
"tts_availability_changed",
AvailabilityChangedEvent(tts=tts, available=False),
)
if self._tts._no_fallback_after_audio_duration is not None:
if (
audio_duration
>= self._tts._no_fallback_after_audio_duration
):
logger.warning(
f"{tts.label} already synthesized {audio_duration}s of audio, ignoring fallback"
)
return
self._try_recovery(tts)
raise APIConnectionError(
"all TTSs failed (%s) after %s seconds"
% (
[tts.label for tts in self._tts._tts_instances],
time.time() - start_time,
)
)
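# Illustrative sketch (not part of the original module): synthesizing a single
# utterance through the fallback adapter and collecting the result into one
# audio frame. `adapter` is a hypothetical, already-configured FallbackAdapter.
async def _example_synthesize_once(adapter: FallbackAdapter, text: str) -> rtc.AudioFrame:
    async with adapter.synthesize(text) as stream:
        # collect() concatenates every SynthesizedAudio frame emitted by the stream
        return await stream.collect()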
class FallbackSynthesizeStream(SynthesizeStream):
def __init__(
self,
*,
tts: FallbackAdapter,
conn_options: Optional[APIConnectOptions] = None,
):
super().__init__(
tts=tts, conn_options=conn_options or DEFAULT_FALLBACK_API_CONNECT_OPTIONS
)
self._fallback_adapter = tts
self._total_segments: list[list[str]] = []
self._pending_segments_chunks: list[list[str]] = []
self._current_segment_text: list[str] = []
async def _try_synthesize(
self,
*,
tts: TTS,
input_ch: aio.ChanReceiver[str | SynthesizeStream._FlushSentinel],
conn_options: APIConnectOptions,
recovering: bool = False,
) -> AsyncGenerator[SynthesizedAudio, None]:
stream = tts.stream(conn_options=conn_options)
input_sent_fut = asyncio.Future() # type: ignore
@utils.log_exceptions(logger=logger)
async def _input_task() -> None:
try:
segment = ""
async for data in input_ch:
if isinstance(data, str):
segment += data
stream.push_text(data)
elif isinstance(data, self._FlushSentinel):
# start the timeout on flush
if segment:
segment = ""
with contextlib.suppress(asyncio.InvalidStateError):
input_sent_fut.set_result(True)
stream.flush()
finally:
with contextlib.suppress(RuntimeError):
stream.end_input()
with contextlib.suppress(asyncio.InvalidStateError):
input_sent_fut.set_result(False)
input_task = asyncio.create_task(_input_task())
next_audio_task: asyncio.Future[SynthesizedAudio] | None = None
try:
audio_duration = 0.0
async with stream:
while True:
if next_audio_task is None or next_audio_task.done():
next_audio_task = asyncio.ensure_future(stream.__anext__())
try:
if not input_sent_fut.done():
await asyncio.wait(
[input_sent_fut, next_audio_task],
return_when=asyncio.FIRST_COMPLETED,
)
if not next_audio_task.done():
continue
audio = next_audio_task.result()
else:
audio = await asyncio.wait_for(
next_audio_task, self._fallback_adapter._attempt_timeout
)
audio_duration += audio.frame.duration
if audio.is_final:
input_sent_fut = asyncio.Future()
audio_duration = 0.0
yield audio
except StopAsyncIteration:
break
if (
audio_duration == 0.0
and input_sent_fut.done()
and input_sent_fut.result()
):
raise APIConnectionError("no audio received")
except asyncio.TimeoutError:
if recovering:
logger.warning(
f"{tts.label} recovery timed out", extra={"streamed": True}
)
raise
logger.warning(
f"{tts.label} timed out, switching to next TTS",
extra={"streamed": True},
)
raise
except APIError as e:
if recovering:
logger.warning(
f"{tts.label} recovery failed", exc_info=e, extra={"streamed": True}
)
raise
logger.warning(
f"{tts.label} failed, switching to next TTS",
exc_info=e,
extra={"streamed": True},
)
raise
except Exception:
if recovering:
logger.exception(
f"{tts.label} recovery unexpected error",
extra={"streamed": True},
)
raise
logger.exception(
f"{tts.label} unexpected error, switching to next TTS",
extra={"streamed": True},
)
raise
finally:
if next_audio_task is not None:
await utils.aio.gracefully_cancel(next_audio_task)
await utils.aio.gracefully_cancel(input_task)
async def _run(self) -> None:
start_time = time.time()
all_failed = all(
not tts_status.available for tts_status in self._fallback_adapter._status
)
if all_failed:
logger.error("all TTSs are unavailable, retrying..")
new_input_ch: aio.Chan[str | SynthesizeStream._FlushSentinel] | None = None
async def _forward_input_task():
nonlocal new_input_ch
async for data in self._input_ch:
if new_input_ch:
new_input_ch.send_nowait(data)
if isinstance(data, str) and data:
self._current_segment_text.append(data)
elif (
isinstance(data, self._FlushSentinel) and self._current_segment_text
):
self._total_segments.append(self._current_segment_text)
self._pending_segments_chunks.append(self._current_segment_text)
self._current_segment_text = []
if new_input_ch:
new_input_ch.close()
input_task = asyncio.create_task(_forward_input_task())
try:
for i, tts in enumerate(self._fallback_adapter._tts_instances):
tts_status = self._fallback_adapter._status[i]
if tts_status.available or all_failed:
audio_duration = 0.0
try:
new_input_ch = aio.Chan[
Union[str, SynthesizeStream._FlushSentinel]
]()
for text in self._pending_segments_chunks:
for chunk in text:
new_input_ch.send_nowait(chunk)
new_input_ch.send_nowait(self._FlushSentinel())
for chunk in self._current_segment_text:
new_input_ch.send_nowait(chunk)
if input_task.done():
new_input_ch.close()
last_segment_id: str | None = None
resampler = tts_status.resampler
async for synthesized_audio in self._try_synthesize(
tts=tts,
input_ch=new_input_ch,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=self._fallback_adapter._max_retry_per_tts,
timeout=self._fallback_adapter._attempt_timeout,
retry_interval=self._fallback_adapter._retry_interval,
),
recovering=False,
):
audio_duration += synthesized_audio.frame.duration
if resampler is not None:
for resampled_frame in resampler.push(
synthesized_audio.frame
):
self._event_ch.send_nowait(
dataclasses.replace(
synthesized_audio, frame=resampled_frame
)
)
if synthesized_audio.is_final:
for resampled_frame in resampler.flush():
self._event_ch.send_nowait(
dataclasses.replace(
synthesized_audio, frame=resampled_frame
)
)
else:
self._event_ch.send_nowait(synthesized_audio)
if (
synthesized_audio.is_final
or (
last_segment_id is not None
and synthesized_audio.segment_id != last_segment_id
)
) and self._pending_segments_chunks:
audio_duration = 0.0
self._pending_segments_chunks.pop(0)
last_segment_id = synthesized_audio.segment_id
return
except Exception:
if tts_status.available:
tts_status.available = False
self._tts.emit(
"tts_availability_changed",
AvailabilityChangedEvent(tts=tts, available=False),
)
if (
self._fallback_adapter._no_fallback_after_audio_duration
is not None
):
if (
audio_duration
>= self._fallback_adapter._no_fallback_after_audio_duration
and self._pending_segments_chunks
):
logger.warning(
f"{tts.label} already synthesized {audio_duration}s of audio, ignoring the current segment for the tts fallback"
)
return
self._try_recovery(tts)
raise APIConnectionError(
"all TTSs failed (%s) after %s seconds"
% (
[tts.label for tts in self._fallback_adapter._tts_instances],
time.time() - start_time,
)
)
finally:
await utils.aio.gracefully_cancel(input_task)
def _try_recovery(self, tts: TTS) -> None:
assert isinstance(self._tts, FallbackAdapter)
retry_segments = [self._current_segment_text.copy()]
if self._total_segments:
retry_segments.insert(0, self._total_segments[-1])
tts_status = self._tts._status[self._tts._tts_instances.index(tts)]
if tts_status.recovering_task is None or tts_status.recovering_task.done():
async def _recover_tts_task(tts: TTS) -> None:
try:
input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
for segment in retry_segments:
for t in segment:
input_ch.send_nowait(t)
input_ch.send_nowait(self._FlushSentinel())
input_ch.close()
async for _ in self._try_synthesize(
tts=tts,
input_ch=input_ch,
recovering=True,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=0,
timeout=self._fallback_adapter._attempt_timeout,
retry_interval=self._fallback_adapter._retry_interval,
),
):
pass
tts_status.available = True
logger.info(f"tts.FallbackAdapter, {tts.label} recovered")
self._tts.emit(
"tts_availability_changed",
AvailabilityChangedEvent(tts=tts, available=True),
)
except Exception:
return
tts_status.recovering_task = asyncio.create_task(_recover_tts_task(tts))
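# Illustrative sketch (not part of the original module): driving the streamed
# fallback path. Text is pushed incrementally, end_input() flushes the final
# segment, and audio frames arrive from whichever TTS is currently healthy.
# `adapter` and `text_chunks` are hypothetical inputs.
async def _example_stream_fallback(adapter: FallbackAdapter, text_chunks: list[str]) -> float:
    stream = adapter.stream()
    for chunk in text_chunks:
        stream.push_text(chunk)
    stream.end_input()
    total_audio = 0.0
    async for audio in stream:
        total_audio += audio.frame.duration
    await stream.aclose()
    return total_audio  # seconds of synthesized audio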
from __future__ import annotations
import asyncio
from typing import AsyncIterable, Optional
from .. import tokenize, utils
from ..types import APIConnectOptions
from .tts import (
TTS,
ChunkedStream,
SynthesizedAudio,
SynthesizeStream,
TTSCapabilities,
)
class StreamAdapter(TTS):
def __init__(
self,
*,
tts: TTS,
sentence_tokenizer: tokenize.SentenceTokenizer,
) -> None:
super().__init__(
capabilities=TTSCapabilities(
streaming=True,
),
sample_rate=tts.sample_rate,
num_channels=tts.num_channels,
)
self._tts = tts
self._sentence_tokenizer = sentence_tokenizer
@self._tts.on("metrics_collected")
def _forward_metrics(*args, **kwargs):
self.emit("metrics_collected", *args, **kwargs)
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return self._tts.synthesize(text=text, conn_options=conn_options)
def stream(
self,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "StreamAdapterWrapper":
return StreamAdapterWrapper(
tts=self,
conn_options=conn_options,
wrapped_tts=self._tts,
sentence_tokenizer=self._sentence_tokenizer,
)
def prewarm(self) -> None:
self._tts.prewarm()
class StreamAdapterWrapper(SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
wrapped_tts: TTS,
sentence_tokenizer: tokenize.SentenceTokenizer,
conn_options: Optional[APIConnectOptions],
) -> None:
super().__init__(tts=tts, conn_options=conn_options)
self._wrapped_tts = wrapped_tts
self._sent_stream = sentence_tokenizer.stream()
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
) -> None:
pass # do nothing
async def _run(self) -> None:
async def _forward_input():
"""forward input to vad"""
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_stream.flush()
continue
self._sent_stream.push_text(data)
self._sent_stream.end_input()
async def _synthesize():
async for ev in self._sent_stream:
last_audio: SynthesizedAudio | None = None
async for audio in self._wrapped_tts.synthesize(ev.token):
if last_audio is not None:
self._event_ch.send_nowait(last_audio)
last_audio = audio
if last_audio is not None:
last_audio.is_final = True
self._event_ch.send_nowait(last_audio)
tasks = [
asyncio.create_task(_forward_input()),
asyncio.create_task(_synthesize()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await self._wrapped_tts.aclose()
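# Illustrative sketch (not part of the original module): wrapping a TTS that
# only implements synthesize() so it can also be used through the streaming
# API. `chunked_only_tts` is a hypothetical placeholder; tokenize.basic is
# assumed to expose the framework's default SentenceTokenizer.
def _example_wrap_chunked_tts(chunked_only_tts: TTS) -> StreamAdapter:
    return StreamAdapter(
        tts=chunked_only_tts,
        sentence_tokenizer=tokenize.basic.SentenceTokenizer(),
    )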
from __future__ import annotations
import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import TracebackType
from typing import (
AsyncIterable,
AsyncIterator,
Generic,
Literal,
Optional,
TypeVar,
Union,
)
from livekit import rtc
from .._exceptions import APIConnectionError, APIError
from ..log import logger
from ..metrics import TTSMetrics
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ..utils import aio
@dataclass
class SynthesizedAudio:
frame: rtc.AudioFrame
"""Synthesized audio frame"""
request_id: str
"""Request ID (one segment could be made up of multiple requests)"""
is_final: bool = False
"""Whether this is latest frame of the segment (streaming only)"""
segment_id: str = ""
"""Segment ID, each segment is separated by a flush (streaming only)"""
delta_text: str = ""
"""Current segment of the synthesized audio (streaming only)"""
@dataclass
class TTSCapabilities:
streaming: bool
"""Whether this TTS supports streaming (generally using websockets)"""
TEvent = TypeVar("TEvent")
class TTS(
ABC,
rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
Generic[TEvent],
):
def __init__(
self,
*,
capabilities: TTSCapabilities,
sample_rate: int,
num_channels: int,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__()
self._capabilities = capabilities
self._sample_rate = sample_rate
self._num_channels = num_channels
self._label = f"{type(self).__module__}.{type(self).__name__}"
self._conn_options = conn_options or DEFAULT_API_CONNECT_OPTIONS
@property
def label(self) -> str:
return self._label
@property
def capabilities(self) -> TTSCapabilities:
return self._capabilities
@property
def sample_rate(self) -> int:
return self._sample_rate
@property
def num_channels(self) -> int:
return self._num_channels
@abstractmethod
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> ChunkedStream: ...
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> SynthesizeStream:
raise NotImplementedError(
"streaming is not supported by this TTS, please use a different TTS or use a StreamAdapter"
)
def prewarm(self) -> None:
"""Pre-warm connection to the TTS service"""
pass
async def aclose(self) -> None: ...
async def __aenter__(self) -> TTS:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
class ChunkedStream(ABC):
"""Used by the non-streamed synthesize API, some providers support chunked http responses"""
def __init__(
self,
*,
tts: TTS,
input_text: str,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
self._input_text = input_text
self._tts = tts
self._conn_options = conn_options or DEFAULT_API_CONNECT_OPTIONS
self._event_ch = aio.Chan[SynthesizedAudio]()
self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(monitor_aiter), name="TTS._metrics_task"
)
self._synthesize_task = asyncio.create_task(
self._main_task(), name="TTS._synthesize_task"
)
self._synthesize_task.add_done_callback(lambda _: self._event_ch.close())
@property
def input_text(self) -> str:
return self._input_text
@property
def done(self) -> bool:
return self._synthesize_task.done()
@property
def exception(self) -> BaseException | None:
return self._synthesize_task.exception()
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
) -> None:
"""Task used to collect metrics"""
start_time = time.perf_counter()
audio_duration = 0.0
ttfb = -1.0
request_id = ""
async for ev in event_aiter:
request_id = ev.request_id
if ttfb == -1.0:
ttfb = time.perf_counter() - start_time
audio_duration += ev.frame.duration
duration = time.perf_counter() - start_time
metrics = TTSMetrics(
timestamp=time.time(),
request_id=request_id,
ttfb=ttfb,
duration=duration,
characters_count=len(self._input_text),
audio_duration=audio_duration,
cancelled=self._synthesize_task.cancelled(),
label=self._tts._label,
streamed=False,
error=None,
)
self._tts.emit("metrics_collected", metrics)
async def collect(self) -> rtc.AudioFrame:
"""Utility method to collect every frame in a single call"""
frames = []
async for ev in self:
frames.append(ev.frame)
return rtc.combine_audio_frames(frames)
@abstractmethod
async def _run(self) -> None: ...
async def _main_task(self) -> None:
for i in range(self._conn_options.max_retry + 1):
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0:
raise
elif i == self._conn_options.max_retry:
raise APIConnectionError(
f"failed to synthesize speech after {self._conn_options.max_retry + 1} attempts",
) from e
else:
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._tts._label,
"attempt": i + 1,
"streamed": False,
},
)
await asyncio.sleep(retry_interval)
async def aclose(self) -> None:
"""Close is automatically called if the stream is completely collected"""
await aio.gracefully_cancel(self._synthesize_task)
self._event_ch.close()
await self._metrics_task
async def __anext__(self) -> SynthesizedAudio:
try:
val = await self._event_aiter.__anext__()
except StopAsyncIteration:
if not self._synthesize_task.cancelled() and (
exc := self._synthesize_task.exception()
):
raise exc from None
raise StopAsyncIteration
return val
def __aiter__(self) -> AsyncIterator[SynthesizedAudio]:
return self
async def __aenter__(self) -> ChunkedStream:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
class SynthesizeStream(ABC):
class _FlushSentinel: ...
def __init__(
self, *, tts: TTS, conn_options: Optional[APIConnectOptions] = None
) -> None:
super().__init__()
self._tts = tts
self._conn_options = conn_options or DEFAULT_API_CONNECT_OPTIONS
self._input_ch = aio.Chan[Union[str, SynthesizeStream._FlushSentinel]]()
self._event_ch = aio.Chan[SynthesizedAudio]()
self._event_aiter, self._monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._task = asyncio.create_task(self._main_task(), name="TTS._main_task")
self._task.add_done_callback(lambda _: self._event_ch.close())
self._metrics_task: asyncio.Task | None = None # started on first push
self._started_time: float = 0
# used to track metrics
self._mtc_pending_texts: list[str] = []
self._mtc_text = ""
@abstractmethod
async def _run(self) -> None: ...
async def _main_task(self) -> None:
for i in range(self._conn_options.max_retry + 1):
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0:
raise
elif i == self._conn_options.max_retry:
raise APIConnectionError(
f"failed to synthesize speech after {self._conn_options.max_retry + 1} attempts",
) from e
else:
logger.warning(
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._tts._label,
"attempt": i + 1,
"streamed": True,
},
)
await asyncio.sleep(retry_interval)
def _mark_started(self) -> None:
# only set the started time once, it'll get reset after we emit metrics
if self._started_time == 0:
self._started_time = time.perf_counter()
async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
) -> None:
"""Task used to collect metrics"""
audio_duration = 0.0
ttfb = -1.0
request_id = ""
def _emit_metrics():
nonlocal audio_duration, ttfb, request_id
if not self._started_time:
return
duration = time.perf_counter() - self._started_time
if not self._mtc_pending_texts:
return
text = self._mtc_pending_texts.pop(0)
if not text:
return
metrics = TTSMetrics(
timestamp=time.time(),
request_id=request_id,
ttfb=ttfb,
duration=duration,
characters_count=len(text),
audio_duration=audio_duration,
cancelled=self._task.cancelled(),
label=self._tts._label,
streamed=True,
error=None,
)
self._tts.emit("metrics_collected", metrics)
audio_duration = 0.0
ttfb = -1.0
request_id = ""
self._started_time = 0
async for ev in event_aiter:
if ttfb == -1.0:
ttfb = time.perf_counter() - self._started_time
audio_duration += ev.frame.duration
request_id = ev.request_id
if ev.is_final:
_emit_metrics()
if request_id:
_emit_metrics()
def push_text(self, token: str) -> None:
"""Push some text to be synthesized"""
if self._metrics_task is None:
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(self._monitor_aiter),
name="TTS._metrics_task",
)
self._mtc_text += token
self._check_input_not_ended()
self._check_not_closed()
self._input_ch.send_nowait(token)
def flush(self) -> None:
"""Mark the end of the current segment"""
if self._mtc_text:
self._mtc_pending_texts.append(self._mtc_text)
self._mtc_text = ""
self._check_input_not_ended()
self._check_not_closed()
self._input_ch.send_nowait(self._FlushSentinel())
def end_input(self) -> None:
"""Mark the end of input, no more text will be pushed"""
self.flush()
self._input_ch.close()
async def aclose(self) -> None:
"""Close ths stream immediately"""
self._input_ch.close()
await aio.gracefully_cancel(self._task)
if self._metrics_task is not None:
await self._metrics_task
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
def _check_input_not_ended(self) -> None:
if self._input_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended")
async def __anext__(self) -> SynthesizedAudio:
try:
val = await self._event_aiter.__anext__()
except StopAsyncIteration:
if not self._task.cancelled() and (exc := self._task.exception()):
raise exc from None
raise StopAsyncIteration
return val
def __aiter__(self) -> AsyncIterator[SynthesizedAudio]:
return self
async def __aenter__(self) -> SynthesizeStream:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
class SynthesizedAudioEmitter:
"""Utility for buffering and emitting audio frames with metadata to a channel.
This class helps TTS implementers correctly handle is_final logic when streaming responses.
"""
def __init__(
self,
*,
event_ch: aio.Chan[SynthesizedAudio],
request_id: str,
segment_id: str = "",
) -> None:
self._event_ch = event_ch
self._frame: rtc.AudioFrame | None = None
self._request_id = request_id
self._segment_id = segment_id
def push(self, frame: Optional[rtc.AudioFrame]):
"""Emits any buffered frame and stores the new frame for later emission.
The buffered frame is emitted as not final.
"""
self._emit_frame(is_final=False)
self._frame = frame
def _emit_frame(self, is_final: bool = False):
"""Sends the buffered frame to the event channel if one exists."""
if self._frame is None:
return
self._event_ch.send_nowait(
SynthesizedAudio(
frame=self._frame,
request_id=self._request_id,
segment_id=self._segment_id,
is_final=is_final,
)
)
self._frame = None
def flush(self):
"""Emits any buffered frame as final."""
self._emit_frame(is_final=True)
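# Illustrative sketch (not part of the original module): how a plugin's stream
# implementation could route decoded frames through SynthesizedAudioEmitter so
# that the last frame of each segment is marked is_final=True. `event_ch`,
# `frames`, and the ids are hypothetical inputs.
def _example_emit_segment(
    event_ch: aio.Chan[SynthesizedAudio],
    frames: list[rtc.AudioFrame],
    request_id: str,
    segment_id: str,
) -> None:
    emitter = SynthesizedAudioEmitter(
        event_ch=event_ch, request_id=request_id, segment_id=segment_id
    )
    for frame in frames:
        emitter.push(frame)  # emits the previously buffered frame as non-final
    emitter.flush()  # emits the last buffered frame with is_final=True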
from dataclasses import dataclass
from typing import Literal, TypeVar, Union
AgentState = Union[Literal["initializing", "listening", "thinking", "speaking"], str]
ATTRIBUTE_AGENT_STATE = "lk.agent.state"
"""
The state of the agent, stored in the agent's attributes.
This can be retrieved on the client side by using `RemoteParticipant.attributes`.
With components-js, this can be easily retrieved using:
```js
const { state, ... } = useVoiceAssistant();
```
"""
_T = TypeVar("_T")
class NotGiven:
    def __bool__(self) -> Literal[False]:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"

NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()

@dataclass(frozen=True)
class APIConnectOptions:
    max_retry: int = 3
    """
    Maximum number of retries to connect to the API.
    """
retry_interval: float = 2.0
"""
Interval between retries to connect to the API in seconds.
"""
timeout: float = 10.0
"""
Timeout for connecting to the API in seconds.
"""
def __post_init__(self):
if self.max_retry < 0:
raise ValueError("max_retry must be greater than or equal to 0")
if self.retry_interval < 0:
raise ValueError("retry_interval must be greater than or equal to 0")
if self.timeout < 0:
raise ValueError("timeout must be greater than or equal to 0")
def _interval_for_retry(self, num_retries: int) -> float:
"""
Return the interval for the given number of retries.
The first retry is immediate, and then uses specified retry_interval
"""
if num_retries == 0:
return 0.1
return self.retry_interval
DEFAULT_API_CONNECT_OPTIONS = APIConnectOptions()
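# Illustrative sketch (not part of the original module): building custom
# connection options for a slower provider; the values here are arbitrary.
def _example_patient_options() -> APIConnectOptions:
    return APIConnectOptions(max_retry=5, retry_interval=1.0, timeout=30.0)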
## livekit-agents/livekit/agents/utils/__init__.py
```py
from livekit import rtc
from . import aio, audio, codecs, http_context, hw, images
from ._message_change import compute_changes as _compute_changes # keep internal
from .audio import AudioBuffer, combine_frames, merge_frames
from .connection_pool import ConnectionPool
from .exp_filter import ExpFilter
from .log import log_exceptions
from .misc import is_given, shortuuid, time_ms
from .moving_average import MovingAverage
EventEmitter = rtc.EventEmitter
__all__ = [
"AudioBuffer",
"merge_frames",
"combine_frames",
"time_ms",
"shortuuid",
"http_context",
"ExpFilter",
"MovingAverage",
"EventEmitter",
"log_exceptions",
"codecs",
"images",
"audio",
"aio",
"hw",
"is_given",
"_compute_changes",
"ConnectionPool",
]
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar, Union
T = TypeVar("T")
@dataclass
class MessageChange(Generic[T]):
"""Represents changes needed to transform one list into another
The changes must be applied in order:
1. First apply all deletions
2. Then apply all insertions with their previous_item_id
"""
to_delete: list[T]
"""Items to delete from old list"""
to_add: list[tuple[Union[T, None], T]]
"""Items to add as (previous_item, new_item) pairs"""
def compute_changes(
old_list: list[T], new_list: list[T], key_fnc: Callable[[T], str]
) -> MessageChange[T]:
"""Compute minimum changes needed to transform old list into new list"""
# Convert to lists of ids
old_ids = [key_fnc(msg) for msg in old_list]
new_ids = [key_fnc(msg) for msg in new_list]
# Create lookup maps
old_msgs = {key_fnc(msg): msg for msg in old_list}
new_msgs = {key_fnc(msg): msg for msg in new_list}
# Compute changes using ids
changes = _compute_list_changes(old_ids, new_ids)
# Convert back to items
return MessageChange(
to_delete=[old_msgs[id] for id in changes.to_delete],
to_add=[
(
None if prev is None else old_msgs.get(prev) or new_msgs[prev],
new_msgs[new],
)
for prev, new in changes.to_add
],
)
def _compute_list_changes(old_list: list[T], new_list: list[T]) -> MessageChange[T]:
"""Compute minimum changes needed to transform old_list into new_list
Rules:
- Delete first, then insert
- Can't insert at start if list not empty (must delete all first)
- Each insert needs previous item except for first item in new list
- If an item changes position relative to others, it must be deleted and reinserted
- If first item in new list exists in old list, must delete all items before it
Examples:
old [a b c d] new [b c d e] -> delete a, insert (d,e)
old [a b c d] new [e a b c d] -> delete all, insert (None,e),(e,a),(a,b),(b,c),(c,d)
old [a b c d] new [a b d e c] -> delete d, insert (b,d),(d,e)
old [a b c d] new [a d c b] -> delete c,d, insert (a,d),(d,c)
"""
if not new_list:
return MessageChange(to_delete=old_list, to_add=[])
# Find first item's position in old list
try:
first_idx = old_list.index(new_list[0])
except ValueError:
# Special case: if first item is new, delete everything
prev_item: Union[T, None] = None
to_add: list[tuple[Union[T, None], T]] = []
for x in new_list:
to_add.append((prev_item, x))
prev_item = x
return MessageChange(to_delete=old_list, to_add=to_add)
# Delete all items before first_idx
to_delete = old_list[:first_idx]
remaining_old = old_list[first_idx:]
# Get positions of remaining items in new list
indices = []
items = []
new_positions = {x: i for i, x in enumerate(new_list)}
for x in remaining_old:
if x in new_positions:
indices.append(new_positions[x])
items.append(x)
# Try fast path first - check if remaining order is preserved
if _check_order_preserved(indices):
kept_indices = list(range(len(indices)))
else:
# Order changed, need to find kept items using LIS
# First item must be kept since we've already handled items before it
kept_indices = _find_longest_increasing_subsequence(indices)
# Convert kept indices back to items
kept_items = {items[i] for i in kept_indices}
# Add items that need to be deleted from remaining list
to_delete.extend(x for x in remaining_old if x not in kept_items)
# Compute items to add by following new list order
to_add = []
prev_item = None
for x in new_list:
if x not in kept_items:
to_add.append((prev_item, x))
prev_item = x
return MessageChange(to_delete=to_delete, to_add=to_add)
def _check_order_preserved(indices: list[int]) -> bool:
"""Check if indices form an increasing sequence"""
if not indices:
return True
# Check if indices form an increasing sequence
for i in range(1, len(indices)):
if indices[i] <= indices[i - 1]:
return False
return True
def _find_longest_increasing_subsequence(indices: list[int]) -> list[int]:
"""Find indices of the longest increasing subsequence
Args:
indices: List of indices to find LIS from
Returns:
List of indices into the input list that form the LIS
For example, indices = [0, 4, 1, 2] -> [0, 2, 3]
"""
if not indices:
return []
# Must include first index, find LIS starting from it
first_val = indices[0]
dp = [1] * len(indices)
prev = [-1] * len(indices)
best_len = 1 # At minimum we keep the first index
best_end = 0 # Start with first index
# Start from second element
for i in range(1, len(indices)):
# Only consider sequences starting from first index
if indices[i] > first_val:
dp[i] = 2
prev[i] = 0
if dp[i] > best_len:
best_len = dp[i]
best_end = i
# Try extending existing sequences
for j in range(1, i):
if indices[j] < indices[i] and prev[j] != -1 and dp[j] + 1 > dp[i]:
dp[i] = dp[j] + 1
prev[i] = j
if dp[i] > best_len:
best_len = dp[i]
best_end = i
# Reconstruct sequence
result = []
while best_end != -1:
result.append(best_end)
best_end = prev[best_end]
result.reverse()
return result
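# Illustrative sketch (not part of the original module): the third docstring
# example above, run through compute_changes with plain strings as items and
# an identity key function.
def _example_compute_changes() -> MessageChange[str]:
    old = ["a", "b", "c", "d"]
    new = ["a", "b", "d", "e", "c"]
    # per the rules documented above: to_delete == ["d"],
    # to_add == [("b", "d"), ("d", "e")]
    return compute_changes(old, new, key_fnc=lambda item: item)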
import asyncio
import functools
from . import debug, duplex_unix, itertools
from .channel import Chan, ChanClosed, ChanReceiver, ChanSender
from .interval import Interval, interval
from .sleep import Sleep, SleepFinished, sleep
from .task_set import TaskSet
async def gracefully_cancel(*futures: asyncio.Future):
loop = asyncio.get_running_loop()
waiters = []
for fut in futures:
waiter = loop.create_future()
cb = functools.partial(_release_waiter, waiter)
waiters.append((waiter, cb))
fut.add_done_callback(cb)
fut.cancel()
try:
for waiter, _ in waiters:
await waiter
finally:
for i, fut in enumerate(futures):
_, cb = waiters[i]
fut.remove_done_callback(cb)
def _release_waiter(waiter, *args):
if not waiter.done():
waiter.set_result(None)
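# Illustrative sketch (not part of the original module): cancelling a pending
# task and waiting until its cancellation has fully completed, the typical use
# of gracefully_cancel across this codebase.
async def _example_graceful_cancel() -> None:
    pending = asyncio.create_task(asyncio.sleep(3600))
    await gracefully_cancel(pending)  # returns once the task has finished unwinding
    assert pending.cancelled()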
__all__ = [
"ChanClosed",
"Chan",
"ChanSender",
"ChanReceiver",
"channel",
"Interval",
"interval",
"Sleep",
"SleepFinished",
"sleep",
"TaskSet",
"debug",
"gracefully_cancel",
"duplex_unix",
"itertools",
]
from __future__ import annotations
import asyncio
import contextlib
from collections import deque
from typing import AsyncIterator, Deque, Generic, Protocol, TypeVar
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
# Based on asyncio.Queue, see https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py
class ChanClosed(Exception):
pass
class ChanFull(Exception):
pass
class ChanEmpty(Exception):
pass
class ChanSender(Protocol[T_contra]):
async def send(self, value: T_contra) -> None: ...
def send_nowait(self, value: T_contra) -> None: ...
def close(self) -> None: ...
class ChanReceiver(Protocol[T_co]):
async def recv(self) -> T_co: ...
def recv_nowait(self) -> T_co: ...
def close(self) -> None: ...
def __aiter__(self) -> AsyncIterator[T_co]: ...
async def __anext__(self) -> T_co: ...
class Chan(Generic[T]):
def __init__(
self, maxsize: int = 0, loop: asyncio.AbstractEventLoop | None = None
) -> None:
self._loop = loop or asyncio.get_event_loop()
self._maxsize = max(maxsize, 0)
# self._finished_ev = asyncio.Event()
self._close_ev = asyncio.Event()
self._closed = False
self._gets: Deque[asyncio.Future[T | None]] = deque()
self._puts: Deque[asyncio.Future[T | None]] = deque()
self._queue: Deque[T] = deque()
def _wakeup_next(self, waiters: deque[asyncio.Future[T | None]]):
while waiters:
waiter = waiters.popleft()
if not waiter.done():
waiter.set_result(None)
break
async def send(self, value: T) -> None:
while self.full() and not self._close_ev.is_set():
p = self._loop.create_future()
self._puts.append(p)
try:
await p
except ChanClosed:
raise
except:
p.cancel()
with contextlib.suppress(ValueError):
self._puts.remove(p)
if not self.full() and not p.cancelled():
self._wakeup_next(self._puts)
raise
self.send_nowait(value)
def send_nowait(self, value: T) -> None:
if self.full():
raise ChanFull
if self._close_ev.is_set():
raise ChanClosed
self._queue.append(value)
self._wakeup_next(self._gets)
async def recv(self) -> T:
while self.empty() and not self._close_ev.is_set():
g = self._loop.create_future()
self._gets.append(g)
try:
await g
except ChanClosed:
raise
except Exception:
g.cancel()
with contextlib.suppress(ValueError):
self._gets.remove(g)
if not self.empty() and not g.cancelled():
self._wakeup_next(self._gets)
raise
return self.recv_nowait()
def recv_nowait(self) -> T:
if self.empty():
if self._close_ev.is_set():
raise ChanClosed
else:
raise ChanEmpty
item = self._queue.popleft()
# if self.empty() and self._close_ev.is_set():
# self._finished_ev.set()
self._wakeup_next(self._puts)
return item
def close(self) -> None:
self._closed = True
self._close_ev.set()
for putter in self._puts:
if not putter.cancelled():
putter.set_exception(ChanClosed())
while len(self._gets) > self.qsize():
getter = self._gets.pop()
if not getter.cancelled():
getter.set_exception(ChanClosed())
while self._gets:
self._wakeup_next(self._gets)
# if self.empty():
# self._finished_ev.set()
@property
def closed(self) -> bool:
return self._closed
# async def join(self) -> None:
# await self._finished_ev.wait()
def qsize(self) -> int:
"""the number of elements queued (unread) in the channel buffer"""
return len(self._queue)
def full(self) -> bool:
if self._maxsize <= 0:
return False
else:
return self.qsize() >= self._maxsize
def empty(self) -> bool:
return not self._queue
def __aiter__(self) -> AsyncIterator[T]:
return self
async def __anext__(self) -> T:
try:
return await self.recv()
except ChanClosed:
raise StopAsyncIteration
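# Illustrative sketch (not part of the original module): a minimal
# producer/consumer pair over Chan; close() terminates the consumer loop.
async def _example_chan_roundtrip() -> list[int]:
    ch = Chan[int]()
    async def _producer() -> None:
        for i in range(3):
            await ch.send(i)
        ch.close()
    producer = asyncio.create_task(_producer())
    received = [value async for value in ch]
    await producer
    return received  # [0, 1, 2]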
from __future__ import annotations
import asyncio
import time
from asyncio.base_events import _format_handle # type: ignore
from typing import Any
from ...log import logger
def hook_slow_callbacks(slow_duration: float) -> None:
_run = asyncio.events.Handle._run
def instrumented(self: Any):
start = time.monotonic()
val = _run(self)
dt = time.monotonic() - start
if dt >= slow_duration:
logger.warning(
"Running %s took too long: %.2f seconds",
_format_handle(self), # type: ignore
dt,
)
return val
asyncio.events.Handle._run = instrumented # type: ignore
from __future__ import annotations
import asyncio
import socket
import struct
class DuplexClosed(Exception):
"""Exception raised when the duplex connection is closed."""
pass
class _AsyncDuplex:
def __init__(
self,
sock: socket.socket,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
self._loop = loop
self._sock = sock
self._reader = reader
self._writer = writer
@staticmethod
async def open(sock: socket.socket) -> _AsyncDuplex:
loop = asyncio.get_running_loop()
reader, writer = await asyncio.open_connection(sock=sock)
return _AsyncDuplex(sock, reader, writer, loop)
async def recv_bytes(self) -> bytes:
try:
len_bytes = await self._reader.readexactly(4)
len = struct.unpack("!I", len_bytes)[0]
return await self._reader.readexactly(len)
except (
OSError,
EOFError,
asyncio.IncompleteReadError,
):
raise DuplexClosed()
async def send_bytes(self, data: bytes) -> None:
try:
len_bytes = struct.pack("!I", len(data))
self._writer.write(len_bytes)
self._writer.write(data)
await self._writer.drain()
except OSError:
raise DuplexClosed()
async def aclose(self) -> None:
try:
self._writer.close()
await self._writer.wait_closed()
self._sock.close()
except OSError:
raise DuplexClosed()
def _read_exactly(sock: socket.socket, num_bytes: int) -> bytes:
data = bytearray()
while len(data) < num_bytes:
packet = sock.recv(num_bytes - len(data))
if not packet:
raise EOFError()
data.extend(packet)
return bytes(data)
class _Duplex:
def __init__(self, sock: socket.socket) -> None:
self._sock: socket.socket | None = sock
@staticmethod
def open(sock: socket.socket) -> _Duplex:
return _Duplex(sock)
def recv_bytes(self) -> bytes:
if self._sock is None:
raise DuplexClosed()
try:
len_bytes = _read_exactly(self._sock, 4)
len = struct.unpack("!I", len_bytes)[0]
return _read_exactly(self._sock, len)
except (OSError, EOFError):
raise DuplexClosed()
def send_bytes(self, data: bytes) -> None:
if self._sock is None:
raise DuplexClosed()
try:
len_bytes = struct.pack("!I", len(data))
self._sock.sendall(len_bytes)
self._sock.sendall(data)
except OSError:
raise DuplexClosed()
def detach(self) -> socket.socket:
if self._sock is None:
raise DuplexClosed()
sock = self._sock
self._sock = None
return sock
def close(self) -> None:
try:
if self._sock is not None:
self._sock.close()
self._sock = None
except OSError:
raise DuplexClosed()
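# Illustrative sketch (not part of the original module): a length-prefixed
# round trip over a local socket pair using the synchronous _Duplex helper.
def _example_duplex_roundtrip() -> bytes:
    left_sock, right_sock = socket.socketpair()
    left, right = _Duplex.open(left_sock), _Duplex.open(right_sock)
    try:
        left.send_bytes(b"hello")
        return right.recv_bytes()  # b"hello"
    finally:
        left.close()
        right.close()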
from __future__ import annotations
import asyncio
from typing import Any
def _finish_fut(fut: asyncio.Future[Any]):
if fut.cancelled():
return
fut.set_result(None)
# Missed-tick behaviour is "Delay": a late tick fires immediately and subsequent ticks are measured from the actual fire time
class Interval:
def __init__(self, interval: float) -> None:
self._interval = interval
self._last_sleep = 0.0
self._i = 0
self._handler: asyncio.TimerHandle | None = None
def reset(self) -> None:
if self._fut and self._handler and not self._handler.cancelled():
self._handler.cancel()
loop = asyncio.get_event_loop()
self._handler = loop.call_later(self._interval, _finish_fut, self._fut)
else:
self._last_sleep = 0
async def tick(self) -> int:
loop = asyncio.get_event_loop()
if self._last_sleep:
self._fut = loop.create_future()
delay = self._last_sleep - loop.time() + self._interval
self._handler = loop.call_later(delay, _finish_fut, self._fut)
try:
await self._fut
finally:
self._handler.cancel()
self._i += 1
self._last_sleep = loop.time()
return self._i
def __aiter__(self) -> "Interval":
return self
async def __anext__(self):
return await self.tick()
def interval(interval: float) -> Interval:
return Interval(interval)
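# Illustrative sketch (not part of the original module): a periodic loop that
# ticks roughly once per second. The first tick fires immediately; reset()
# could be used to postpone the next one.
async def _example_interval_loop() -> None:
    timer = interval(1.0)
    async for tick in timer:
        if tick >= 3:  # stop after three ticks
            break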
import asyncio
from collections import deque
from typing import (
Any,
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Deque,
Generic,
Iterator,
List,
Protocol,
Tuple,
TypeVar,
Union,
overload,
runtime_checkable,
)
from typing_extensions import AsyncContextManager
# based on https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
@runtime_checkable
class _ACloseable(Protocol):
async def aclose(self) -> None:
"""Asynchronously close this object"""
T = TypeVar("T")
async def tee_peer(
iterator: AsyncIterator[T],
buffer: Deque[T],
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
try:
while True:
if not buffer:
async with lock:
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
if not peers and isinstance(iterator, _ACloseable):
await iterator.aclose()
class Tee(Generic[T]):
__slots__ = ("_iterator", "_buffers", "_children")
def __init__(
self,
iterator: AsyncIterable[T],
n: int = 2,
):
self._iterator = iterator.__aiter__()
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
lock = asyncio.Lock()
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock,
)
for buffer in self._buffers
)
def __len__(self) -> int:
return len(self._children)
@overload
def __getitem__(self, item: int) -> AsyncIterator[T]: ...
@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ...
def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]:
yield from self._children
async def __aenter__(self) -> "Tee[T]":
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.aclose()
async def aclose(self) -> None:
for child in self._children:
await child.aclose()
tee = Tee
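# Illustrative sketch (not part of the original module): duplicating a single
# async iterator into two independent consumers, mirroring how the TTS streams
# tee their event channel for metrics collection.
async def _example_tee_consumers() -> None:
    async def _numbers() -> AsyncIterator[int]:
        for i in range(3):
            yield i
    first, second = tee(_numbers(), 2)
    assert [v async for v in first] == [0, 1, 2]
    assert [v async for v in second] == [0, 1, 2]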
from __future__ import annotations
import asyncio
from typing import Any
def _finish_fut(fut: asyncio.Future[Any]):
if fut.cancelled():
return
fut.set_result(None)
class SleepFinished(Exception):
pass
class Sleep:
"""Same as asyncio.sleep except it is resettable"""
def __init__(self, delay: float) -> None:
self._delay = delay
self._handler: asyncio.TimerHandle | None = None
def reset(self, new_delay: float | None = None) -> None:
if new_delay is None:
new_delay = self._delay
self._delay = new_delay
if self._handler is None:
return
if self._handler.cancelled() or self._fut.done():
raise SleepFinished
self._handler.cancel()
loop = asyncio.get_event_loop()
self._handler = loop.call_later(new_delay, _finish_fut, self._fut)
def cancel(self) -> None:
if self._handler is None:
return
self._handler.cancel()
self._fut.cancel()
async def _sleep(self) -> None:
if self._delay <= 0:
self._fut = asyncio.Future[None]()
self._fut.set_result(None)
return
loop = asyncio.get_event_loop()
self._fut = loop.create_future()
self._handler = loop.call_later(self._delay, _finish_fut, self._fut)
try:
await self._fut
finally:
self._handler.cancel()
def __await__(self):
return self._sleep().__await__()
def sleep(delay: float) -> Sleep:
return Sleep(delay)
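# Illustrative sketch (not part of the original module): a debounce-style wait.
# Calling reset() before the delay expires pushes the wakeup further out.
async def _example_resettable_sleep() -> None:
    timer = sleep(1.0)
    waiter = asyncio.ensure_future(timer)  # ensure_future wraps any awaitable
    await asyncio.sleep(0.5)
    timer.reset()  # restart the full 1.0s delay from now
    await waiter  # completes roughly 1.5s after the start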
from __future__ import annotations
import asyncio
from typing import Any, Coroutine, TypeVar
_T = TypeVar("_T")
class TaskSet:
"""
Small utility to create tasks in a fire-and-forget fashion.
"""
def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._set = set[asyncio.Task[Any]]()
self._closed = False
def create_task(self, coro: Coroutine[Any, Any, _T]) -> asyncio.Task[_T]:
if self._closed:
raise RuntimeError("TaskSet is closed")
task = self._loop.create_task(coro)
self._set.add(task)
task.add_done_callback(self._set.remove)
return task
async def aclose(self) -> None:
self._closed = True
await asyncio.gather(*self._set, return_exceptions=True)
self._set.clear()
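# Illustrative sketch (not part of the original module): fire-and-forget
# background work; aclose() waits for anything still in flight.
async def _example_task_set() -> None:
    background = TaskSet()
    background.create_task(asyncio.sleep(0.1))
    background.create_task(asyncio.sleep(0.2))
    await background.aclose()  # returns once both sleeps have completed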
from __future__ import annotations
import ctypes
from typing import List, Union
from livekit import rtc
from ..log import logger
# deprecated aliases
AudioBuffer = Union[List[rtc.AudioFrame], rtc.AudioFrame]
combine_frames = rtc.combine_audio_frames
merge_frames = rtc.combine_audio_frames
def calculate_audio_duration(frames: AudioBuffer) -> float:
"""
Calculate the total duration of audio frames.
This function computes the total duration of audio frames in seconds.
It accepts either a list of `rtc.AudioFrame` objects or a single `rtc.AudioFrame` object.
Parameters:
- frames (AudioBuffer): A list of `rtc.AudioFrame` instances or a single `rtc.AudioFrame` instance.
Returns:
- float: The total duration in seconds of all frames provided.
"""
if isinstance(frames, list):
return sum(frame.duration for frame in frames)
else:
return frames.duration
class AudioByteStream:
"""
Buffer and chunk audio byte data into fixed-size frames.
This class is designed to handle incoming audio data in bytes,
buffering it and producing audio frames of a consistent size.
It is mainly used to easily chunk big or too small audio frames
into a fixed size, helping to avoid processing very small frames
(which can be inefficient) and very large frames (which can cause
latency or processing delays). By normalizing frame sizes, it
facilitates consistent and efficient audio data processing.
"""
def __init__(
self,
sample_rate: int,
num_channels: int,
samples_per_channel: int | None = None,
) -> None:
"""
Initialize an AudioByteStream instance.
Parameters:
sample_rate (int): The audio sample rate in Hz.
num_channels (int): The number of audio channels.
samples_per_channel (int, optional): The number of samples per channel in each frame.
If None, defaults to `sample_rate // 10` (i.e., 100ms of audio data).
The constructor sets up the internal buffer and calculates the size of each frame in bytes.
The frame size is determined by the number of channels, samples per channel, and the size
of each sample (assumed to be 16 bits or 2 bytes).
"""
self._sample_rate = sample_rate
self._num_channels = num_channels
if samples_per_channel is None:
samples_per_channel = sample_rate // 10 # 100ms by default
self._bytes_per_frame = (
num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
)
self._buf = bytearray()
def push(self, data: bytes) -> list[rtc.AudioFrame]:
"""
Add audio data to the buffer and retrieve fixed-size frames.
Parameters:
data (bytes): The incoming audio data to buffer.
Returns:
list[rtc.AudioFrame]: A list of `AudioFrame` objects of fixed size.
The method appends the incoming data to the internal buffer.
While the buffer contains enough data to form complete frames,
it extracts the data for each frame, creates an `AudioFrame` object,
and appends it to the list of frames to return.
This allows you to feed in variable-sized chunks of audio data
(e.g., from a stream or file) and receive back a list of
fixed-size audio frames ready for processing or transmission.
"""
self._buf.extend(data)
frames = []
while len(self._buf) >= self._bytes_per_frame:
frame_data = self._buf[: self._bytes_per_frame]
self._buf = self._buf[self._bytes_per_frame :]
frames.append(
rtc.AudioFrame(
data=frame_data,
sample_rate=self._sample_rate,
num_channels=self._num_channels,
samples_per_channel=len(frame_data) // 2,
)
)
return frames
write = push # Alias for the push method.
def flush(self) -> list[rtc.AudioFrame]:
"""
Flush the buffer and retrieve any remaining audio data as a frame.
Returns:
list[rtc.AudioFrame]: A list containing any remaining `AudioFrame` objects.
This method processes any remaining data in the buffer that does not
fill a complete frame. If the remaining data forms a partial frame
(i.e., its size is not a multiple of the expected sample size), a warning is
logged and an empty list is returned. Otherwise, it returns the final
`AudioFrame` containing the remaining data.
Use this method when you have no more data to push and want to ensure
that all buffered audio data has been processed.
"""
if len(self._buf) == 0:
return []
if len(self._buf) % (2 * self._num_channels) != 0:
logger.warning("AudioByteStream: incomplete frame during flush, dropping")
return []
return [
rtc.AudioFrame(
data=self._buf,
sample_rate=self._sample_rate,
num_channels=self._num_channels,
samples_per_channel=len(self._buf) // 2,
)
]
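# Illustrative sketch (not part of the original module): chunking an arbitrary
# 16-bit PCM payload into fixed 100ms frames at 16kHz mono. `pcm_bytes` is a
# hypothetical raw audio buffer.
def _example_chunk_pcm(pcm_bytes: bytes) -> list[rtc.AudioFrame]:
    bstream = AudioByteStream(sample_rate=16000, num_channels=1)
    frames = bstream.push(pcm_bytes)  # all complete 100ms frames
    frames.extend(bstream.flush())  # plus any whole-sample remainder
    return frames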
# Copyright 2024 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .decoder import AudioStreamDecoder, StreamBuffer
__all__ = ["AudioStreamDecoder", "StreamBuffer"]
# Copyright 2024 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import io
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Optional
import av
import av.container
from livekit import rtc
from livekit.agents.log import logger
from livekit.agents.utils import aio
try:
# preload to ensure faster startup
import av # noqa
except ImportError:
pass
import threading
class StreamBuffer:
"""
A thread-safe buffer that behaves like an IO stream.
Allows writing from one thread and reading from another.
"""
def __init__(self):
self._buffer = io.BytesIO()
self._lock = threading.Lock()
self._data_available = threading.Condition(self._lock)
self._eof = False
def write(self, data: bytes):
"""Write data to the buffer from a writer thread."""
with self._data_available:
self._buffer.seek(0, io.SEEK_END)
self._buffer.write(data)
self._data_available.notify_all()
def read(self, size: int = -1) -> bytes:
"""Read data from the buffer in a reader thread."""
if self._buffer.closed:
return b""
with self._data_available:
while True:
if self._buffer.closed:
return b""
# always read from beginning
self._buffer.seek(0)
data = self._buffer.read(size)
if data:
# shrink the buffer to remove already-read data
remaining = self._buffer.read()
self._buffer = io.BytesIO(remaining)
return data
if self._eof:
return b""
self._data_available.wait()
def end_input(self):
"""Signal that no more data will be written."""
with self._data_available:
self._eof = True
self._data_available.notify_all()
def close(self):
self._buffer.close()
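A minimal sketch of the producer/consumer pattern StreamBuffer supports: one thread writes encoded bytes while another blocks in `read()` until data arrives or `end_input()` signals EOF.

```py
# Minimal sketch: write from a worker thread, read until EOF on the caller side.
import threading

buf = StreamBuffer()

def producer() -> None:
    for chunk in (b"abc", b"def"):
        buf.write(chunk)
    buf.end_input()  # unblocks any pending read() once all data is written

t = threading.Thread(target=producer)
t.start()

received = b""
while True:
    piece = buf.read(4096)
    if not piece:  # empty bytes means EOF (or a closed buffer)
        break
    received += piece
t.join()
assert received == b"abcdef"
```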
class AudioStreamDecoder:
"""A class that can be used to decode audio stream into PCM AudioFrames.
Decoders are stateful, and it should not be reused across multiple streams. Each decoder
is designed to decode a single stream.
"""
_max_workers: int = 10
_executor: Optional[ThreadPoolExecutor] = None
def __init__(self, *, sample_rate: int = 48000, num_channels: int = 1):
try:
import av # noqa
except ImportError:
raise ImportError(
"You haven't included the 'codecs' optional dependencies. Please install the 'codecs' extra by running `pip install livekit-agents[codecs]`"
)
self._sample_rate = sample_rate
self._layout = "mono"
if num_channels == 2:
self._layout = "stereo"
elif num_channels != 1:
raise ValueError(f"Invalid number of channels: {num_channels}")
self._output_ch = aio.Chan[rtc.AudioFrame]()
self._closed = False
self._started = False
self._input_buf = StreamBuffer()
self._loop = asyncio.get_event_loop()
if self.__class__._executor is None:
# each decoder instance will submit jobs to the shared pool
self.__class__._executor = ThreadPoolExecutor(
max_workers=self.__class__._max_workers
)
def push(self, chunk: bytes):
self._input_buf.write(chunk)
if not self._started:
self._started = True
self._loop.run_in_executor(self.__class__._executor, self._decode_loop)
def end_input(self):
self._input_buf.end_input()
if not self._started:
# if no data was pushed, close the output channel
self._output_ch.close()
def _decode_loop(self):
container: av.container.InputContainer | None = None
resampler: av.AudioResampler | None = None
try:
container = av.open(self._input_buf, mode="r")
if len(container.streams.audio) == 0:
raise ValueError("no audio stream found")
audio_stream = container.streams.audio[0]
resampler = av.AudioResampler(
format="s16", layout=self._layout, rate=self._sample_rate
)
for frame in container.decode(audio_stream):
if self._closed:
return
for resampled_frame in resampler.resample(frame):
nchannels = len(resampled_frame.layout.channels)
self._loop.call_soon_threadsafe(
self._output_ch.send_nowait,
rtc.AudioFrame(
data=resampled_frame.to_ndarray().tobytes(),
num_channels=nchannels,
sample_rate=int(resampled_frame.sample_rate),
samples_per_channel=int(
resampled_frame.samples / nchannels
),
),
)
except Exception:
logger.exception("error decoding audio")
finally:
self._loop.call_soon_threadsafe(self._output_ch.close)
if container:
container.close()
def __aiter__(self) -> AsyncIterator[rtc.AudioFrame]:
return self
async def __anext__(self) -> rtc.AudioFrame:
return await self._output_ch.__anext__()
async def aclose(self):
if self._closed:
return
self.end_input()
self._closed = True
self._input_buf.close()
# wait for decode loop to finish, only if anything's been pushed
with contextlib.suppress(aio.ChanClosed):
if self._started:
await self._output_ch.recv()
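As a usage sketch (not part of the module), decoding a compressed payload into PCM frames looks roughly like the following; the input file name is a placeholder and the `codecs` extra (PyAV) must be installed.

```py
# Usage sketch, assuming `payload` holds a complete encoded audio file (e.g. MP3 bytes).
import asyncio

async def decode_payload(payload: bytes) -> int:
    decoder = AudioStreamDecoder(sample_rate=48000, num_channels=1)
    decoder.push(payload)
    decoder.end_input()

    total_samples = 0
    async for frame in decoder:  # 48 kHz mono PCM frames
        total_samples += frame.samples_per_channel
    await decoder.aclose()
    return total_samples

# asyncio.run(decode_payload(open("clip.mp3", "rb").read()))  # placeholder file name
```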
import asyncio
import time
import weakref
from contextlib import asynccontextmanager
from typing import (
AsyncGenerator,
Awaitable,
Callable,
Generic,
Optional,
Set,
TypeVar,
)
from . import aio
T = TypeVar("T")
class ConnectionPool(Generic[T]):
"""Helper class to manage persistent connections like websockets.
Handles connection pooling and reconnection after max duration.
Can be used as an async context manager to automatically return connections to the pool.
"""
def __init__(
self,
*,
max_session_duration: Optional[float] = None,
mark_refreshed_on_get: bool = False,
connect_cb: Optional[Callable[[], Awaitable[T]]] = None,
close_cb: Optional[Callable[[T], Awaitable[None]]] = None,
) -> None:
"""Initialize the connection wrapper.
Args:
max_session_duration: Maximum duration in seconds before forcing reconnection
mark_refreshed_on_get: If True, the session will be marked as fresh when get() is called. Only used when max_session_duration is set.
connect_cb: Optional async callback to create new connections
close_cb: Optional async callback to close connections
"""
self._max_session_duration = max_session_duration
self._mark_refreshed_on_get = mark_refreshed_on_get
self._connect_cb = connect_cb
self._close_cb = close_cb
self._connections: dict[T, float] = {} # conn -> connected_at timestamp
self._available: Set[T] = set()
# store connections to be reaped (closed) later.
self._to_close: Set[T] = set()
self._prewarm_task: Optional[weakref.ref[asyncio.Task]] = None
async def _connect(self) -> T:
"""Create a new connection.
Returns:
The new connection object
Raises:
NotImplementedError: If no connect callback was provided
"""
if self._connect_cb is None:
raise NotImplementedError("Must provide connect_cb or implement connect()")
connection = await self._connect_cb()
self._connections[connection] = time.time()
return connection
async def _drain_to_close(self) -> None:
"""Drain and close all the connections queued for closing."""
for conn in list(self._to_close):
await self._maybe_close_connection(conn)
self._to_close.clear()
@asynccontextmanager
async def connection(self) -> AsyncGenerator[T, None]:
"""Get a connection from the pool and automatically return it when done.
Yields:
An active connection object
"""
conn = await self.get()
try:
yield conn
except Exception:
self.remove(conn)
raise
else:
self.put(conn)
async def get(self) -> T:
"""Get an available connection or create a new one if needed.
Returns:
An active connection object
"""
await self._drain_to_close()
now = time.time()
# try to reuse an available connection that hasn't expired
while self._available:
conn = self._available.pop()
if (
self._max_session_duration is None
or now - self._connections[conn] <= self._max_session_duration
):
if self._mark_refreshed_on_get:
self._connections[conn] = now
return conn
# connection expired; remove it so it gets closed on the next drain.
self.remove(conn)
return await self._connect()
def put(self, conn: T) -> None:
"""Mark a connection as available for reuse.
If the connection has already been removed from the pool, it will not be re-added.
Args:
conn: The connection to make available
"""
if conn in self._connections:
self._available.add(conn)
async def _maybe_close_connection(self, conn: T) -> None:
"""Close a connection if close_cb is provided.
Args:
conn: The connection to close
"""
if self._close_cb is not None:
await self._close_cb(conn)
def remove(self, conn: T) -> None:
"""Remove a specific connection from the pool.
Marks the connection to be closed during the next drain cycle.
Args:
conn: The connection to remove
"""
self._available.discard(conn)
if conn in self._connections:
self._to_close.add(conn)
self._connections.pop(conn, None)
def invalidate(self) -> None:
"""Clear all existing connections.
Marks all current connections to be closed during the next drain cycle.
"""
for conn in list(self._connections.keys()):
self._to_close.add(conn)
self._connections.clear()
self._available.clear()
def prewarm(self) -> None:
"""Initiate prewarming of the connection pool without blocking.
This method starts a background task that creates a new connection if none exist.
The task automatically cleans itself up when the connection pool is closed.
"""
if self._prewarm_task is not None or self._connections:
return
async def _prewarm_impl():
if not self._connections:
conn = await self._connect()
self._available.add(conn)
task = asyncio.create_task(_prewarm_impl())
self._prewarm_task = weakref.ref(task)
async def aclose(self):
"""Close all connections, draining any pending connection closures."""
if self._prewarm_task is not None:
task = self._prewarm_task()
if task:
await aio.gracefully_cancel(task)
self.invalidate()
await self._drain_to_close()
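A hedged sketch of how the pool is meant to be wired up, here with aiohttp websockets; the URL is a placeholder and the callbacks are illustrative only.

```py
# Sketch: pool websocket connections and recycle them every 5 minutes.
import aiohttp

async def example() -> None:
    http = aiohttp.ClientSession()

    async def connect() -> aiohttp.ClientWebSocketResponse:
        return await http.ws_connect("wss://example.invalid/stream")  # placeholder URL

    async def close(ws: aiohttp.ClientWebSocketResponse) -> None:
        await ws.close()

    pool = ConnectionPool[aiohttp.ClientWebSocketResponse](
        max_session_duration=300,  # force reconnects after 5 minutes
        connect_cb=connect,
        close_cb=close,
    )
    async with pool.connection() as ws:
        await ws.send_str("ping")  # returned to the pool on success, removed on error
    await pool.aclose()
    await http.close()

# asyncio.run(example())
```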
class ExpFilter:
def __init__(self, alpha: float, max_val: float = -1.0) -> None:
self._alpha = alpha
self._filtered = -1.0
self._max_val = max_val
def reset(self, alpha: float = -1.0) -> None:
if alpha != -1.0:
self._alpha = alpha
self._filtered = -1.0
def apply(self, exp: float, sample: float) -> float:
if self._filtered == -1.0:
self._filtered = sample
else:
a = self._alpha**exp
self._filtered = a * self._filtered + (1 - a) * sample
if self._max_val != -1.0 and self._filtered > self._max_val:
self._filtered = self._max_val
return self._filtered
def filtered(self) -> float:
return self._filtered
def update_base(self, alpha: float) -> None:
self._alpha = alpha
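A tiny worked example of the exponential filter: with `alpha=0.9` and `exp=1.0`, each new sample is blended as `0.9 * previous + 0.1 * sample` after the first one.

```py
# Worked example: three samples through an ExpFilter with alpha=0.9.
f = ExpFilter(alpha=0.9)
for sample in (1.0, 0.0, 0.0):
    f.apply(1.0, sample)
print(round(f.filtered(), 2))  # 0.81 (1.0 -> 0.9 -> 0.81)
```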
from __future__ import annotations
import contextvars
from typing import Callable
import aiohttp
from ..log import logger
_ClientFactory = Callable[[], aiohttp.ClientSession]
_ContextVar = contextvars.ContextVar("agent_http_session") # type: ignore
def _new_session_ctx() -> _ClientFactory:
g_session: aiohttp.ClientSession | None = None
def _new_session() -> aiohttp.ClientSession:
nonlocal g_session
if g_session is None:
logger.debug("http_session(): creating a new httpclient ctx")
g_session = aiohttp.ClientSession()
return g_session
_ContextVar.set(_new_session) # type: ignore
return _new_session
def http_session() -> aiohttp.ClientSession:
"""Optional utility function to avoid having to manually manage an aiohttp.ClientSession lifetime.
On job processes, this http session will be bound to the main event loop.
"""
val = _ContextVar.get(None) # type: ignore
if val is None:
raise RuntimeError(
"Attempted to use an http session outside of a job context. This is probably because you are trying to use a plugin without using the agent worker api. You may need to create your own aiohttp.ClientSession, pass it into the plugin constructor as a kwarg, and manage its lifecycle."
)
return val() # type: ignore
async def _close_http_ctx():
val = _ContextVar.get(None) # type: ignore
if val is not None:
logger.debug("http_session(): closing the httpclient ctx")
await val().close() # type: ignore
_ContextVar.set(None) # type: ignore
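A hedged sketch of how this is typically used from inside a job entrypoint, where the worker has already installed the session context; the URL is a placeholder, and the re-export path is assumed.

```py
# Sketch: reuse the per-process aiohttp session inside a job entrypoint.
from livekit.agents import utils  # assumes http_session is re-exported from utils

async def entrypoint(ctx) -> None:  # ctx: JobContext, simplified for illustration
    session = utils.http_session()
    async with session.get("https://example.invalid/health") as resp:  # placeholder URL
        print(resp.status)
```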
from .cpu import CGroupV2CPUMonitor, CPUMonitor, DefaultCPUMonitor, get_cpu_monitor
__all__ = ["get_cpu_monitor", "CPUMonitor", "CGroupV2CPUMonitor", "DefaultCPUMonitor"]
import os
import time
from abc import ABC, abstractmethod
import psutil
class CPUMonitor(ABC):
@abstractmethod
def cpu_count(self) -> float:
"""Number of logical CPUs.
Returns a float to allow for fractional CPUs (in the case of cgroups)."""
pass
@abstractmethod
def cpu_percent(self, interval: float = 0.5) -> float:
"""CPU usage percentage between 0 and 1"""
pass
class DefaultCPUMonitor(CPUMonitor):
def cpu_count(self) -> float:
return psutil.cpu_count() or 1.0
def cpu_percent(self, interval: float = 0.5) -> float:
return psutil.cpu_percent(interval) / 100.0
class CGroupV2CPUMonitor(CPUMonitor):
def cpu_count(self) -> float:
# quota: The maximum CPU time in microseconds that the cgroup can use within a given period.
# period: The period of time in microseconds over which the quota applies.
# If the quota is set to "max", it means the cgroup is allowed to use all available CPUs without restriction.
# Otherwise, the quota is a number that represents the maximum CPU time in microseconds that the cgroup can use within a given period.
quota, period = self._read_cpu_max()
if quota == "max":
return os.cpu_count() or 1
return 1.0 * int(quota) / period
def cpu_percent(self, interval: float = 0.5) -> float:
cpu_usage_start = self._read_cpu_usage()
time.sleep(interval)
cpu_usage_end = self._read_cpu_usage()
cpu_usage_diff = cpu_usage_end - cpu_usage_start
# Convert microseconds to seconds
cpu_usage_seconds = cpu_usage_diff / 1_000_000
# Get the number of CPUs available to the container
num_cpus = self.cpu_count()
# Calculate the percentage
cpu_usage_percent = cpu_usage_seconds / (interval * num_cpus)
return min(cpu_usage_percent, 1)
def _read_cpu_max(self) -> tuple[str, int]:
try:
with open("/sys/fs/cgroup/cpu.max", "r") as f:
data = f.read().strip().split()
quota = data[0]
period = int(data[1])
except FileNotFoundError:
quota = "max"
period = 100000
return quota, period
def _read_cpu_usage(self) -> int:
with open("/sys/fs/cgroup/cpu.stat", "r") as f:
for line in f:
if line.startswith("usage_usec"):
return int(line.split()[1])
raise RuntimeError("Failed to read CPU usage")
def get_cpu_monitor() -> CPUMonitor:
if _is_cgroup_v2():
return CGroupV2CPUMonitor()
return DefaultCPUMonitor()
def _is_cgroup_v2() -> bool:
return os.path.exists("/sys/fs/cgroup/cpu.stat")
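For example, the selection logic can be exercised directly; inside a cgroup-v2 container the quota-aware monitor is returned, otherwise the psutil-based default.

```py
# Quick check of the monitor selection and its readings.
monitor = get_cpu_monitor()
print("logical CPUs:", monitor.cpu_count())
print("usage:", f"{monitor.cpu_percent(interval=0.5):.0%}")
```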
# Copyright 2024 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .image import EncodeOptions, ResizeOptions, encode
__all__ = ["EncodeOptions", "ResizeOptions", "encode"]
# Copyright 2024 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Any, Literal, Optional
from livekit import rtc
if TYPE_CHECKING:
from PIL import Image
@dataclass
class EncodeOptions:
"""Options for encoding rtc.VideoFrame to portable image formats."""
format: Literal["JPEG", "PNG"] = "JPEG"
"""The format to encode the image."""
resize_options: Optional["ResizeOptions"] = None
"""Options for resizing the image."""
quality: Optional[int] = 75
"""Image compression quality, 0-100. Only applies to JPEG."""
@dataclass
class ResizeOptions:
"""Options for resizing rtc.VideoFrame as part of encoding to a portable image format."""
width: int
"""The desired resize width (in)"""
height: int
"""The desired height to resize the image to."""
strategy: Literal[
"center_aspect_fit",
"center_aspect_cover",
"scale_aspect_fit",
"scale_aspect_cover",
"skew",
]
"""The strategy to use when resizing the image:
- center_aspect_fit: Fit the image into the provided dimensions, with letterboxing
- center_aspect_cover: Fill the provided dimensions, with cropping
- scale_aspect_fit: Fit the image into the provided dimensions, preserving its original aspect ratio
- scale_aspect_cover: Fill the provided dimensions, preserving its original aspect ratio (image will be larger than the provided dimensions)
- skew: Precisely resize the image to the provided dimensions
"""
def import_pil():
try:
if "Image" not in globals():
globals()["Image"] = import_module("PIL.Image")
except ImportError:
raise ImportError(
"You haven't included the 'images' optional dependencies. Please install the 'codecs' extra by running `pip install livekit-agents[images]`"
)
def encode(frame: rtc.VideoFrame, options: EncodeOptions) -> bytes:
"""Encode a rtc.VideoFrame to a portable image format (JPEG or PNG).
See EncodeOptions for more details.
"""
import_pil()
img = _image_from_frame(frame)
resized = _resize_image(img, options)
buffer = io.BytesIO()
kwargs = {}
if options.format == "JPEG" and options.quality is not None:
kwargs["quality"] = options.quality
resized.save(buffer, options.format, **kwargs)
buffer.seek(0)
return buffer.read()
def _image_from_frame(frame: rtc.VideoFrame):
converted = frame
if frame.type != rtc.VideoBufferType.RGBA:
converted = frame.convert(rtc.VideoBufferType.RGBA)
rgb_image = Image.frombytes( # type: ignore
"RGBA", (frame.width, frame.height), converted.data
).convert("RGB")
return rgb_image
def _resize_image(image: Any, options: EncodeOptions):
if options.resize_options is None:
return image
resize_opts = options.resize_options
if resize_opts.strategy == "skew":
return image.resize((resize_opts.width, resize_opts.height))
elif resize_opts.strategy == "center_aspect_fit":
result = Image.new("RGB", (resize_opts.width, resize_opts.height)) # noqa
# Start with assuming the new image is narrower than the original
new_width = resize_opts.width
new_height = int(image.height * (resize_opts.width / image.width))
# If the new image is wider than the original
if resize_opts.width / resize_opts.height > image.width / image.height:
new_height = resize_opts.height
new_width = int(image.width * (resize_opts.height / image.height))
resized = image.resize((new_width, new_height))
Image.Image.paste(
result,
resized,
(
(resize_opts.width - new_width) // 2,
(resize_opts.height - new_height) // 2,
),
)
return result
elif resize_opts.strategy == "center_aspect_cover":
result = Image.new("RGB", (resize_opts.width, resize_opts.height)) # noqa
# Start with assuming the new image is shorter than the original
new_height = int(image.height * (resize_opts.width / image.width))
new_width = resize_opts.width
# If the new image is taller than the original
if resize_opts.height / resize_opts.width > image.height / image.width:
new_width = int(image.width * (resize_opts.height / image.height))
new_height = resize_opts.height
resized = image.resize((new_width, new_height))
Image.Image.paste( # noqa
result,
resized,
(
(resize_opts.width - new_width) // 2,
(resize_opts.height - new_height) // 2,
),
)
return result
elif resize_opts.strategy == "scale_aspect_fill":
# Start with assuming width is the limiting dimension
new_width = resize_opts.width
new_height = int(image.height * (resize_opts.width / image.width))
# If height is under the limit, scale based on height instead
if new_height < resize_opts.height:
new_height = resize_opts.height
new_width = int(image.width * (resize_opts.height / image.height))
return image.resize((new_width, new_height))
elif resize_opts.strategy == "scale_aspect_fit":
# Start with assuming width is the limiting dimension
new_width = resize_opts.width
new_height = int(image.height * (resize_opts.width / image.width))
# If height would exceed the limit, scale based on height instead
if new_height > resize_opts.height:
new_height = resize_opts.height
new_width = int(image.width * (resize_opts.height / image.height))
return image.resize((new_width, new_height))
raise ValueError(f"Unknown resize strategy: {resize_opts.strategy}")
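A short usage sketch for the encoder above; `frame` stands in for an rtc.VideoFrame captured from a video track.

```py
# Hypothetical usage: letterbox a frame into 512x512 and encode it as JPEG.
opts = EncodeOptions(
    format="JPEG",
    quality=80,
    resize_options=ResizeOptions(width=512, height=512, strategy="center_aspect_fit"),
)
# jpeg_bytes = encode(frame, opts)  # `frame` is an rtc.VideoFrame (placeholder)
```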
import asyncio
import functools
import logging
from typing import Any, Callable
def log_exceptions(
msg: str = "", logger: logging.Logger = logging.getLogger()
) -> Callable[[Any], Any]:
def deco(fn: Callable[[Any], Any]):
if asyncio.iscoroutinefunction(fn):
@functools.wraps(fn)
async def async_fn_logs(*args: Any, **kwargs: Any):
try:
return await fn(*args, **kwargs)
except Exception:
err = f"Error in {fn.__name__}"
if msg:
err += f" – {msg}"
logger.exception(err)
raise
return async_fn_logs
else:
@functools.wraps(fn)
def fn_logs(*args: Any, **kwargs: Any):
try:
return fn(*args, **kwargs)
except Exception:
err = f"Error in {fn.__name__}"
if msg:
err += f" – {msg}"
logger.exception(err)
raise
return fn_logs
return deco
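The decorator works on both sync and async callables; a small sketch follows (the logger name and task body are placeholders).

```py
# Sketch: log and re-raise unexpected errors from a background coroutine.
import logging

@log_exceptions(msg="while syncing state", logger=logging.getLogger("my-agent"))
async def sync_state() -> None:
    ...  # placeholder body; any exception here is logged with context, then re-raised
```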
from __future__ import annotations
import time
import uuid
from typing import TypeVar
from typing_extensions import TypeGuard
from ..types import NotGiven, NotGivenOr
_T = TypeVar("_T")
def time_ms() -> int:
return int(time.time() * 1000 + 0.5)
def shortuuid(prefix: str = "") -> str:
return prefix + str(uuid.uuid4().hex)[:12]
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven)
from __future__ import annotations
class MovingAverage:
def __init__(self, window_size: int) -> None:
self._hist: list[float] = [0] * window_size
self._sum: float = 0
self._count: int = 0
def add_sample(self, sample: float) -> None:
self._count += 1
index = self._count % len(self._hist)
if self._count > len(self._hist):
self._sum -= self._hist[index]
self._sum += sample
self._hist[index] = sample
def get_avg(self) -> float:
if self._count == 0:
return 0
return self._sum / self.size()
def reset(self):
self._count = 0
self._sum = 0
def size(self) -> int:
return min(self._count, len(self._hist))
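A worked example of the ring-buffer average: with a window of 3, only the last three samples contribute.

```py
avg = MovingAverage(3)
for s in (1.0, 2.0, 3.0, 4.0):
    avg.add_sample(s)
print(avg.get_avg())  # (2.0 + 3.0 + 4.0) / 3 = 3.0
```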
from __future__ import annotations
import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, unique
from typing import AsyncIterable, AsyncIterator, List, Literal, Union
from livekit import rtc
from .metrics import VADMetrics
from .utils import aio
@unique
class VADEventType(str, Enum):
START_OF_SPEECH = "start_of_speech"
INFERENCE_DONE = "inference_done"
END_OF_SPEECH = "end_of_speech"
@dataclass
class VADEvent:
"""
Represents an event detected by the Voice Activity Detector (VAD).
"""
type: VADEventType
"""Type of the VAD event (e.g., start of speech, end of speech, inference done)."""
samples_index: int
"""Index of the audio sample where the event occurred, relative to the inference sample rate."""
timestamp: float
"""Timestamp (in seconds) when the event was fired."""
speech_duration: float
"""Duration of the speech segment in seconds."""
silence_duration: float
"""Duration of the silence segment in seconds."""
frames: List[rtc.AudioFrame] = field(default_factory=list)
"""
List of audio frames associated with the speech.
- For `start_of_speech` events, this contains the audio chunks that triggered the detection.
- For `inference_done` events, this contains the audio chunks that were processed.
- For `end_of_speech` events, this contains the complete user speech.
"""
probability: float = 0.0
"""Probability that speech is present (only for `INFERENCE_DONE` events)."""
inference_duration: float = 0.0
"""Time taken to perform the inference, in seconds (only for `INFERENCE_DONE` events)."""
speaking: bool = False
"""Indicates whether speech was detected in the frames."""
raw_accumulated_silence: float = 0.0
"""Raw accumulated silence duration, in seconds, as measured by the detector."""
raw_accumulated_speech: float = 0.0
"""Raw accumulated speech duration, in seconds, as measured by the detector."""
@dataclass
class VADCapabilities:
update_interval: float
class VAD(ABC, rtc.EventEmitter[Literal["metrics_collected"]]):
def __init__(self, *, capabilities: VADCapabilities) -> None:
super().__init__()
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"
@property
def capabilities(self) -> VADCapabilities:
return self._capabilities
@abstractmethod
def stream(self) -> "VADStream": ...
class VADStream(ABC):
class _FlushSentinel:
pass
def __init__(self, vad: VAD) -> None:
self._vad = vad
self._last_activity_time = time.perf_counter()
self._input_ch = aio.Chan[Union[rtc.AudioFrame, VADStream._FlushSentinel]]()
self._event_ch = aio.Chan[VADEvent]()
self._event_aiter, monitor_aiter = aio.itertools.tee(self._event_ch, 2)
self._metrics_task = asyncio.create_task(
self._metrics_monitor_task(monitor_aiter), name="VAD._metrics_task"
)
self._task = asyncio.create_task(self._main_task())
self._task.add_done_callback(lambda _: self._event_ch.close())
@abstractmethod
async def _main_task(self) -> None: ...
async def _metrics_monitor_task(self, event_aiter: AsyncIterable[VADEvent]) -> None:
"""Task used to collect metrics"""
inference_duration_total = 0.0
inference_count = 0
async for ev in event_aiter:
if ev.type == VADEventType.INFERENCE_DONE:
inference_duration_total += ev.inference_duration
inference_count += 1
if inference_count >= 1 / self._vad.capabilities.update_interval:
vad_metrics = VADMetrics(
timestamp=time.time(),
idle_time=time.perf_counter() - self._last_activity_time,
inference_duration_total=inference_duration_total,
inference_count=inference_count,
label=self._vad._label,
)
self._vad.emit("metrics_collected", vad_metrics)
inference_duration_total = 0.0
inference_count = 0
elif ev.type in [VADEventType.START_OF_SPEECH, VADEventType.END_OF_SPEECH]:
self._last_activity_time = time.perf_counter()
def push_frame(self, frame: rtc.AudioFrame) -> None:
"""Push some text to be synthesized"""
self._check_input_not_ended()
self._check_not_closed()
self._input_ch.send_nowait(frame)
def flush(self) -> None:
"""Mark the end of the current segment"""
self._check_input_not_ended()
self._check_not_closed()
self._input_ch.send_nowait(self._FlushSentinel())
def end_input(self) -> None:
"""Mark the end of input, no more text will be pushed"""
self.flush()
self._input_ch.close()
async def aclose(self) -> None:
"""Close ths stream immediately"""
self._input_ch.close()
await aio.gracefully_cancel(self._task)
self._event_ch.close()
await self._metrics_task
async def __anext__(self) -> VADEvent:
try:
val = await self._event_aiter.__anext__()
except StopAsyncIteration:
if not self._task.cancelled() and (exc := self._task.exception()):
raise exc from None
raise StopAsyncIteration
return val
def __aiter__(self) -> AsyncIterator[VADEvent]:
return self
def _check_not_closed(self) -> None:
if self._event_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} is closed")
def _check_input_not_ended(self) -> None:
if self._input_ch.closed:
cls = type(self)
raise RuntimeError(f"{cls.__module__}.{cls.__name__} input ended")
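A hedged usage sketch of the stream API above, assuming some concrete `VAD` implementation (for example the Silero plugin) and an iterable of `rtc.AudioFrame`s.

```py
# Sketch: feed audio frames to a VAD stream and react to speech boundaries.
async def watch_speech(vad: VAD, frames) -> None:
    stream = vad.stream()
    for frame in frames:
        stream.push_frame(frame)
    stream.end_input()

    async for ev in stream:
        if ev.type == VADEventType.START_OF_SPEECH:
            print("speech started")
        elif ev.type == VADEventType.END_OF_SPEECH:
            print(f"speech ended after {ev.speech_duration:.2f}s")
    await stream.aclose()
```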
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.12.20"
from ..pipeline import AgentCallContext, AgentTranscriptionOptions, VoicePipelineAgent
AssistantTranscriptionOptions = AgentTranscriptionOptions
AssistantCallContext = AgentCallContext
VoiceAssistant = VoicePipelineAgent
__all__ = [
"AssistantTranscriptionOptions",
"AssistantCallContext",
"VoiceAssistant",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import datetime
import inspect
import math
import multiprocessing as mp
import os
import sys
import threading
from dataclasses import dataclass, field
from enum import Enum
from functools import reduce
from typing import (
Any,
Awaitable,
Callable,
Generic,
Literal,
TypeVar,
)
from urllib.parse import urljoin, urlparse
import aiohttp
import jwt
from livekit import api, rtc
from livekit.protocol import agent, models
from . import http_server, ipc, utils
from ._exceptions import AssignmentTimeoutError
from .inference_runner import _InferenceRunner
from .job import (
JobAcceptArguments,
JobContext,
JobExecutorType,
JobProcess,
JobRequest,
RunningJobInfo,
)
from .log import DEV_LEVEL, logger
from .utils.hw import get_cpu_monitor
from .version import __version__
ASSIGNMENT_TIMEOUT = 7.5
UPDATE_LOAD_INTERVAL = 2.5
def _default_initialize_process_fnc(proc: JobProcess) -> Any:
return
async def _default_request_fnc(ctx: JobRequest) -> None:
await ctx.accept()
class WorkerType(Enum):
ROOM = agent.JobType.JT_ROOM
PUBLISHER = agent.JobType.JT_PUBLISHER
class _DefaultLoadCalc:
_instance = None
def __init__(self) -> None:
self._m_avg = utils.MovingAverage(5)  # avg over 2.5s (5 samples at 0.5s intervals)
self._cpu_monitor = get_cpu_monitor()
self._thread = threading.Thread(
target=self._calc_load, daemon=True, name="worker_cpu_load_monitor"
)
self._lock = threading.Lock()
self._thread.start()
def _calc_load(self) -> None:
while True:
cpu_p = self._cpu_monitor.cpu_percent(interval=0.5)
with self._lock:
self._m_avg.add_sample(cpu_p)
def _get_avg(self) -> float:
with self._lock:
return self._m_avg.get_avg()
@classmethod
def get_load(cls, worker: Worker) -> float:
if cls._instance is None:
cls._instance = _DefaultLoadCalc()
return cls._instance._get_avg()
@dataclass
class WorkerPermissions:
can_publish: bool = True
can_subscribe: bool = True
can_publish_data: bool = True
can_update_metadata: bool = True
can_publish_sources: list[models.TrackSource] = field(default_factory=list)
hidden: bool = False
if sys.platform.startswith("win"):
# Some Python versions on Windows get a BrokenPipeError when creating a new process
_default_job_executor_type = JobExecutorType.THREAD
else:
_default_job_executor_type = JobExecutorType.PROCESS
T = TypeVar("T")
@dataclass(frozen=True)
class _WorkerEnvOption(Generic[T]):
dev_default: T
prod_default: T
@staticmethod
def getvalue(opt: T | _WorkerEnvOption[T], devmode: bool) -> T:
if isinstance(opt, _WorkerEnvOption):
return opt.dev_default if devmode else opt.prod_default
return opt
# NOTE: this object must be pickle-able
@dataclass
class WorkerOptions:
entrypoint_fnc: Callable[[JobContext], Awaitable[None]]
"""Entrypoint function that will be called when a job is assigned to this worker."""
request_fnc: Callable[[JobRequest], Awaitable[None]] = _default_request_fnc
"""Inspect the request and decide if the current worker should handle it.
When left empty, all jobs are accepted."""
prewarm_fnc: Callable[[JobProcess], Any] = _default_initialize_process_fnc
"""A function to perform any necessary initialization before the job starts."""
load_fnc: Callable[[Worker], float] | Callable[[], float] = (
_DefaultLoadCalc.get_load
)
"""Called to determine the current load of the worker. Should return a value between 0 and 1."""
job_executor_type: JobExecutorType = _default_job_executor_type
"""Which executor to use to run jobs. (currently thread or process are supported)"""
load_threshold: float | _WorkerEnvOption[float] = _WorkerEnvOption(
dev_default=math.inf, prod_default=0.75
)
"""When the load exceeds this threshold, the worker will be marked as unavailable.
Defaults to 0.75 on "production" mode, and is disabled in "development" mode.
"""
job_memory_warn_mb: float = 300
"""Memory warning threshold in MB. If the job process exceeds this limit, a warning will be logged."""
job_memory_limit_mb: float = 0
"""Maximum memory usage for a job in MB, the job process will be killed if it exceeds this limit.
Defaults to 0 (disabled).
"""
"""Number of idle processes to keep warm."""
num_idle_processes: int | _WorkerEnvOption[int] = _WorkerEnvOption(
dev_default=0, prod_default=3
)
"""Number of idle processes to keep warm."""
shutdown_process_timeout: float = 60.0
"""Maximum amount of time to wait for a job to shut down gracefully"""
initialize_process_timeout: float = 10.0
"""Maximum amount of time to wait for a process to initialize/prewarm"""
permissions: WorkerPermissions = field(default_factory=WorkerPermissions)
"""Permissions that the agent should join the room with."""
agent_name: str = ""
"""Set agent_name to enable explicit dispatch. When explicit dispatch is enabled, jobs will not be dispatched to rooms automatically. Instead, you can either specify the agent(s) to be dispatched in the end-user's token, or use the AgentDispatch.createDispatch API"""
worker_type: WorkerType = WorkerType.ROOM
"""Whether to spin up an agent for each room or publisher."""
max_retry: int = 16
"""Maximum number of times to retry connecting to LiveKit."""
ws_url: str = "ws://localhost:7880"
"""URL to connect to the LiveKit server.
By default it uses ``LIVEKIT_URL`` from environment"""
api_key: str | None = None
"""API key to authenticate with LiveKit.
By default it uses ``LIVEKIT_API_KEY`` from environment"""
api_secret: str | None = None
"""API secret to authenticate with LiveKit.
By default it uses ``LIVEKIT_API_SECRET`` from environment"""
host: str = "" # default to all interfaces
port: int | _WorkerEnvOption[int] = _WorkerEnvOption(
dev_default=0, prod_default=8081
)
"""Port for local HTTP server to listen on.
The HTTP server is used as a health check endpoint.
"""
def validate_config(self, devmode: bool):
load_threshold = _WorkerEnvOption.getvalue(self.load_threshold, devmode)
if load_threshold > 1 and not devmode:
logger.warning(
f"load_threshold in prod env must be less than 1, current value: {load_threshold}"
)
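A minimal sketch of filling in WorkerOptions; the entrypoint below is a placeholder, and credentials fall back to the environment variables described above when left unset.

```py
# Sketch: minimal worker configuration with an illustrative entrypoint.
async def my_entrypoint(ctx: JobContext) -> None:
    await ctx.connect()  # join the room assigned to this job

opts = WorkerOptions(
    entrypoint_fnc=my_entrypoint,
    num_idle_processes=2,   # keep two prewarmed processes around
    load_threshold=0.75,    # mark the worker as full above 75% load
)
```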
EventTypes = Literal["worker_registered"]
class Worker(utils.EventEmitter[EventTypes]):
def __init__(
self,
opts: WorkerOptions,
*,
devmode: bool = True,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
super().__init__()
opts.ws_url = opts.ws_url or os.environ.get("LIVEKIT_URL") or ""
opts.api_key = opts.api_key or os.environ.get("LIVEKIT_API_KEY") or ""
opts.api_secret = opts.api_secret or os.environ.get("LIVEKIT_API_SECRET") or ""
if not opts.ws_url:
raise ValueError(
"ws_url is required, or add LIVEKIT_URL in your environment"
)
if not opts.api_key:
raise ValueError(
"api_key is required, or add LIVEKIT_API_KEY in your environment"
)
if not opts.api_secret:
raise ValueError(
"api_secret is required, or add LIVEKIT_API_SECRET in your environment"
)
if (
opts.job_memory_limit_mb > 0
and opts.job_executor_type != JobExecutorType.PROCESS
):
logger.warning(
"max_job_memory_usage is only supported for process-based job executors, "
"ignoring max_job_memory_usage"
)
self._opts = opts
self._loop = loop or asyncio.get_event_loop()
self._id = "unregistered"
self._closed, self._draining, self._connecting = True, False, False
self._tasks = set[asyncio.Task[Any]]()
self._pending_assignments: dict[str, asyncio.Future[agent.JobAssignment]] = {}
self._close_future: asyncio.Future[None] | None = None
self._msg_chan = utils.aio.Chan[agent.WorkerMessage](128, loop=self._loop)
self._devmode = devmode
# using spawn context for all platforms. We may have further optimizations for
# Linux with forkserver, but for now, this is the safest option
mp_ctx = mp.get_context("spawn")
self._inference_executor: (
ipc.inference_proc_executor.InferenceProcExecutor | None
) = None
if len(_InferenceRunner.registered_runners) > 0:
self._inference_executor = (
ipc.inference_proc_executor.InferenceProcExecutor(
runners=_InferenceRunner.registered_runners,
initialize_timeout=30,
close_timeout=5,
memory_warn_mb=2000,
memory_limit_mb=0, # no limit
ping_interval=5,
ping_timeout=60,
high_ping_threshold=2.5,
mp_ctx=mp_ctx,
loop=self._loop,
)
)
self._proc_pool = ipc.proc_pool.ProcPool(
initialize_process_fnc=opts.prewarm_fnc,
job_entrypoint_fnc=opts.entrypoint_fnc,
num_idle_processes=_WorkerEnvOption.getvalue(
opts.num_idle_processes, self._devmode
),
loop=self._loop,
job_executor_type=opts.job_executor_type,
inference_executor=self._inference_executor,
mp_ctx=mp_ctx,
initialize_timeout=opts.initialize_process_timeout,
close_timeout=opts.shutdown_process_timeout,
memory_warn_mb=opts.job_memory_warn_mb,
memory_limit_mb=opts.job_memory_limit_mb,
)
self._previous_status = agent.WorkerStatus.WS_AVAILABLE
self._api: api.LiveKitAPI | None = None
self._http_session: aiohttp.ClientSession | None = None
self._http_server = http_server.HttpServer(
opts.host,
_WorkerEnvOption.getvalue(opts.port, self._devmode),
loop=self._loop,
)
self._main_task: asyncio.Task[None] | None = None
async def run(self):
if not self._closed:
raise Exception("worker is already running")
logger.info(
"starting worker",
extra={"version": __version__, "rtc-version": rtc.__version__},
)
if self._inference_executor is not None:
logger.info("starting inference executor")
await self._inference_executor.start()
await self._inference_executor.initialize()
self._closed = False
def _update_job_status(proc: ipc.job_executor.JobExecutor) -> None:
t = self._loop.create_task(self._update_job_status(proc))
self._tasks.add(t)
t.add_done_callback(self._tasks.discard)
self._proc_pool.on("process_started", _update_job_status)
self._proc_pool.on("process_closed", _update_job_status)
self._proc_pool.on("process_job_launched", _update_job_status)
self._proc_pool.start()
self._api = api.LiveKitAPI(
self._opts.ws_url, self._opts.api_key, self._opts.api_secret
)
self._http_session = aiohttp.ClientSession()
self._close_future = asyncio.Future(loop=self._loop)
self._main_task = asyncio.create_task(self._worker_task(), name="worker_task")
tasks = [
self._main_task,
asyncio.create_task(self._http_server.run(), name="http_server"),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
if not self._close_future.done():
self._close_future.set_result(None)
@property
def id(self) -> str:
return self._id
@property
def active_jobs(self) -> list[RunningJobInfo]:
return [
proc.running_job for proc in self._proc_pool.processes if proc.running_job
]
async def drain(self, timeout: int | None = None) -> None:
"""When timeout isn't None, it will raise asyncio.TimeoutError if the processes didn't finish in time."""
if self._draining:
return
logger.info("draining worker", extra={"id": self.id, "timeout": timeout})
self._draining = True
await self._update_worker_status()
async def _join_jobs():
for proc in self._proc_pool.processes:
if proc.running_job:
await proc.join()
if timeout:
await asyncio.wait_for(
_join_jobs(), timeout
) # raises asyncio.TimeoutError on timeout
else:
await _join_jobs()
async def simulate_job(
self, room: str, participant_identity: str | None = None
) -> None:
assert self._api is not None
room_obj = await self._api.room.create_room(api.CreateRoomRequest(name=room))
participant = None
if participant_identity:
participant = await self._api.room.get_participant(
api.RoomParticipantIdentity(room=room, identity=participant_identity)
)
msg = agent.WorkerMessage()
msg.simulate_job.room.CopyFrom(room_obj)
if participant:
msg.simulate_job.participant.CopyFrom(participant)
await self._queue_msg(msg)
async def aclose(self) -> None:
if self._closed:
if self._close_future is not None:
await self._close_future
return
logger.info("shutting down worker", extra={"id": self.id})
assert self._close_future is not None
assert self._http_session is not None
assert self._api is not None
assert self._main_task is not None
self._closed = True
self._main_task.cancel()
await self._proc_pool.aclose()
if self._inference_executor is not None:
await self._inference_executor.aclose()
await self._http_session.close()
await self._http_server.aclose()
await self._api.aclose()
await asyncio.gather(*self._tasks, return_exceptions=True)
# await asyncio.sleep(0.25) # see https://github.com/aio-libs/aiohttp/issues/1925
self._msg_chan.close()
await self._close_future
async def _queue_msg(self, msg: agent.WorkerMessage) -> None:
"""_queue_msg raises aio.ChanClosed when the worker is closing/closed"""
if self._connecting:
which = msg.WhichOneof("message")
if which == "update_worker":
return
elif which == "ping":
return
await self._msg_chan.send(msg)
async def _worker_task(self) -> None:
assert self._http_session is not None
retry_count = 0
ws: aiohttp.ClientWebSocketResponse | None = None
while not self._closed:
try:
self._connecting = True
join_jwt = (
api.AccessToken(self._opts.api_key, self._opts.api_secret)
.with_grants(api.VideoGrants(agent=True))
.to_jwt()
)
headers = {"Authorization": f"Bearer {join_jwt}"}
parse = urlparse(self._opts.ws_url)
scheme = parse.scheme
if scheme.startswith("http"):
scheme = scheme.replace("http", "ws")
path_parts = [f"{scheme}://{parse.netloc}", parse.path, "/agent"]
agent_url = reduce(urljoin, path_parts)
ws = await self._http_session.ws_connect(
agent_url, headers=headers, autoping=True
)
retry_count = 0
# register the worker
req = agent.WorkerMessage()
req.register.type = self._opts.worker_type.value
req.register.allowed_permissions.CopyFrom(
models.ParticipantPermission(
can_publish=self._opts.permissions.can_publish,
can_subscribe=self._opts.permissions.can_subscribe,
can_publish_data=self._opts.permissions.can_publish_data,
can_update_metadata=self._opts.permissions.can_update_metadata,
can_publish_sources=self._opts.permissions.can_publish_sources,
hidden=self._opts.permissions.hidden,
agent=True,
)
)
req.register.agent_name = self._opts.agent_name
req.register.version = __version__
await ws.send_bytes(req.SerializeToString())
# wait for the register response before running this connection
first_msg_b = await ws.receive_bytes()
msg = agent.ServerMessage()
msg.ParseFromString(first_msg_b)
if not msg.HasField("register"):
raise Exception("expected register response as first message")
self._handle_register(msg.register)
self._connecting = False
await self._run_ws(ws)
except Exception as e:
if self._closed:
break
if retry_count >= self._opts.max_retry:
raise RuntimeError(
f"failed to connect to livekit after {retry_count} attempts",
)
retry_delay = min(retry_count * 2, 10)
retry_count += 1
logger.warning(
f"failed to connect to livekit, retrying in {retry_delay}s: {e}"
)
await asyncio.sleep(retry_delay)
finally:
if ws is not None:
await ws.close()
async def _run_ws(self, ws: aiohttp.ClientWebSocketResponse):
closing_ws = False
async def _load_task():
"""periodically check load and update worker status"""
interval = utils.aio.interval(UPDATE_LOAD_INTERVAL)
while True:
await interval.tick()
await self._update_worker_status()
async def _send_task():
nonlocal closing_ws
while True:
try:
msg = await self._msg_chan.recv()
await ws.send_bytes(msg.SerializeToString())
except utils.aio.ChanClosed:
closing_ws = True
return
async def _recv_task():
nonlocal closing_ws
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws:
return
raise Exception("worker connection closed unexpectedly")
if msg.type != aiohttp.WSMsgType.BINARY:
logger.warning("unexpected message type: %s", msg.type)
continue
data = msg.data
msg = agent.ServerMessage()
msg.ParseFromString(data)
which = msg.WhichOneof("message")
if which == "availability":
self._handle_availability(msg.availability)
elif which == "assignment":
self._handle_assignment(msg.assignment)
elif which == "termination":
user_task = self._loop.create_task(
self._handle_termination(msg.termination),
name="agent_job_termination",
)
self._tasks.add(user_task)
user_task.add_done_callback(self._tasks.discard)
tasks = [
asyncio.create_task(_load_task()),
asyncio.create_task(_send_task()),
asyncio.create_task(_recv_task()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
async def _reload_jobs(self, jobs: list[RunningJobInfo]) -> None:
if not self._opts.api_secret:
raise RuntimeError("api_secret is required to reload jobs")
for aj in jobs:
logger.log(
DEV_LEVEL,
"reloading job",
extra={"job_id": aj.job.id, "agent_name": aj.job.agent_name},
)
url = self._opts.ws_url
# take the original jwt token and extend it while keeping all the same data that was generated
# by the SFU for the original join token.
original_token = aj.token
decoded = jwt.decode(
original_token, self._opts.api_secret, algorithms=["HS256"]
)
decoded["exp"] = (
int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + 3600
)
running_info = RunningJobInfo(
accept_arguments=aj.accept_arguments,
job=aj.job,
url=url,
token=jwt.encode(decoded, self._opts.api_secret, algorithm="HS256"),
worker_id=aj.worker_id,
)
await self._proc_pool.launch_job(running_info)
def _handle_register(self, reg: agent.RegisterWorkerResponse):
self._id = reg.worker_id
logger.info(
"registered worker",
extra={
"id": reg.worker_id,
"region": reg.server_info.region,
"protocol": reg.server_info.protocol,
"node_id": reg.server_info.node_id,
},
)
self.emit("worker_registered", reg.worker_id, reg.server_info)
def _handle_availability(self, msg: agent.AvailabilityRequest):
task = self._loop.create_task(self._answer_availability(msg))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
async def _answer_availability(self, msg: agent.AvailabilityRequest):
"""Ask the user if they want to accept this job and forward the answer to the server.
If we get the job assigned, we start a new process."""
answered = False
async def _on_reject() -> None:
nonlocal answered
answered = True
availability_resp = agent.WorkerMessage()
availability_resp.availability.job_id = msg.job.id
availability_resp.availability.available = False
await self._queue_msg(availability_resp)
async def _on_accept(args: JobAcceptArguments) -> None:
nonlocal answered
answered = True
availability_resp = agent.WorkerMessage()
availability_resp.availability.job_id = msg.job.id
availability_resp.availability.available = True
availability_resp.availability.participant_identity = args.identity
availability_resp.availability.participant_name = args.name
availability_resp.availability.participant_metadata = args.metadata
if args.attributes:
availability_resp.availability.participant_attributes.update(
args.attributes
)
await self._queue_msg(availability_resp)
wait_assignment = asyncio.Future[agent.JobAssignment]()
self._pending_assignments[job_req.id] = wait_assignment
# the job was accepted by the user, wait for the server assignment
try:
await asyncio.wait_for(wait_assignment, ASSIGNMENT_TIMEOUT)
except asyncio.TimeoutError:
logger.warning(
f"assignment for job {job_req.id} timed out",
extra={"job_request": job_req, "agent_name": self._opts.agent_name},
)
raise AssignmentTimeoutError()
job_assign = wait_assignment.result()
running_info = RunningJobInfo(
accept_arguments=args,
job=msg.job,
url=job_assign.url or self._opts.ws_url,
token=job_assign.token,
worker_id=self._id,
)
await self._proc_pool.launch_job(running_info)
job_req = JobRequest(job=msg.job, on_reject=_on_reject, on_accept=_on_accept)
logger.info(
"received job request",
extra={
"job_id": msg.job.id,
"dispatch_id": msg.job.dispatch_id,
"room_name": msg.job.room.name,
"agent_name": self._opts.agent_name,
"resuming": msg.resuming,
},
)
@utils.log_exceptions(logger=logger)
async def _job_request_task():
try:
await self._opts.request_fnc(job_req)
except Exception:
logger.exception(
"job_request_fnc failed",
extra={"job_request": job_req, "agent_name": self._opts.agent_name},
)
if not answered:
logger.warning(
"no answer was given inside the job_request_fnc, automatically rejecting the job",
extra={"job_request": job_req, "agent_name": self._opts.agent_name},
)
await _on_reject()
user_task = self._loop.create_task(_job_request_task(), name="job_request")
self._tasks.add(user_task)
user_task.add_done_callback(self._tasks.discard)
def _handle_assignment(self, assignment: agent.JobAssignment):
if assignment.job.id in self._pending_assignments:
with contextlib.suppress(asyncio.InvalidStateError):
fut = self._pending_assignments.pop(assignment.job.id)
fut.set_result(assignment)
else:
logger.warning(
"received assignment for an unknown job",
extra={"job": assignment.job, "agent_name": self._opts.agent_name},
)
async def _handle_termination(self, msg: agent.JobTermination):
proc = self._proc_pool.get_by_job_id(msg.job_id)
if not proc:
# safe to ignore
return
await proc.aclose()
async def _update_worker_status(self):
job_cnt = len(self.active_jobs)
if self._draining:
update = agent.UpdateWorkerStatus(
status=agent.WorkerStatus.WS_FULL, job_count=job_cnt
)
msg = agent.WorkerMessage(update_worker=update)
await self._queue_msg(msg)
return
def load_fnc():
signature = inspect.signature(self._opts.load_fnc)
parameters = list(signature.parameters.values())
if len(parameters) == 0:
return self._opts.load_fnc() # type: ignore
return self._opts.load_fnc(self) # type: ignore
current_load = await asyncio.get_event_loop().run_in_executor(None, load_fnc)
is_full = current_load >= _WorkerEnvOption.getvalue(
self._opts.load_threshold, self._devmode
)
currently_available = not is_full and not self._draining
status = (
agent.WorkerStatus.WS_AVAILABLE
if currently_available
else agent.WorkerStatus.WS_FULL
)
update = agent.UpdateWorkerStatus(
load=current_load, status=status, job_count=job_cnt
)
# only log if status has changed
if self._previous_status != status and not self._draining:
self._previous_status = status
extra = {
"load": current_load,
"threshold": self._opts.load_threshold,
}
if is_full:
logger.info(
"worker is at full capacity, marking as unavailable",
extra=extra,
)
else:
logger.info(
"worker is below capacity, marking as available",
extra=extra,
)
msg = agent.WorkerMessage(update_worker=update)
with contextlib.suppress(utils.aio.ChanClosed):
await self._queue_msg(msg)
async def _update_job_status(self, proc: ipc.job_executor.JobExecutor) -> None:
job_info = proc.running_job
if job_info is None:
return
status: agent.JobStatus = agent.JobStatus.JS_RUNNING
if proc.status == ipc.job_executor.JobStatus.FAILED:
status = agent.JobStatus.JS_FAILED
elif proc.status == ipc.job_executor.JobStatus.SUCCESS:
status = agent.JobStatus.JS_SUCCESS
elif proc.status == ipc.job_executor.JobStatus.RUNNING:
status = agent.JobStatus.JS_RUNNING
update = agent.UpdateJobStatus(job_id=job_info.job.id, status=status, error="")
msg = agent.WorkerMessage(update_job=update)
await self._queue_msg(msg)
{
"name": "livekit-agents",
"private": true,
"version": "0.12.20"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "agents", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-agents",
version=about["__version__"],
description="LiveKit Python Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit", "agents", "AI"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"click~=8.1",
"livekit>=0.18.1",
"livekit-api~=0.8",
"livekit-protocol~=0.7",
"protobuf>=3",
"pyjwt>=2.0.0",
"types-protobuf>=4,<5",
"watchfiles>=0.22",
"psutil>=5.9",
"aiohttp>=3.10",
"typing-extensions>=4.12",
],
extras_require={
':sys_platform=="win32"': [
"colorama"
], # fix logs color on windows (devmode only)
"codecs": ["av>=12.0.0", "numpy>=1.26.0"],
"images": ["pillow>=10.3.0"],
},
package_data={"livekit.agents": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
#!/bin/bash
set -e
# Get the directory where the script is located
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
pip install \
"${SCRIPT_DIR}/livekit-plugins-anthropic" \
"${SCRIPT_DIR}/livekit-plugins-assemblyai" \
"${SCRIPT_DIR}/livekit-plugins-aws" \
"${SCRIPT_DIR}/livekit-plugins-azure" \
"${SCRIPT_DIR}/livekit-plugins-cartesia" \
"${SCRIPT_DIR}/livekit-plugins-clova" \
"${SCRIPT_DIR}/livekit-plugins-deepgram" \
"${SCRIPT_DIR}/livekit-plugins-elevenlabs" \
"${SCRIPT_DIR}/livekit-plugins-fal" \
"${SCRIPT_DIR}/livekit-plugins-google" \
"${SCRIPT_DIR}/livekit-plugins-groq" \
"${SCRIPT_DIR}/livekit-plugins-llama-index" \
"${SCRIPT_DIR}/livekit-plugins-neuphonic" \
"${SCRIPT_DIR}/livekit-plugins-nltk" \
"${SCRIPT_DIR}/livekit-plugins-openai" \
"${SCRIPT_DIR}/livekit-plugins-playai" \
"${SCRIPT_DIR}/livekit-plugins-rag" \
"${SCRIPT_DIR}/livekit-plugins-rime" \
"${SCRIPT_DIR}/livekit-plugins-silero" \
"${SCRIPT_DIR}/livekit-plugins-speechmatics" \
"${SCRIPT_DIR}/livekit-plugins-turn-detector" \
"${SCRIPT_DIR}/livekit-plugins-resemble"
#!/bin/bash
set -e
if [[ -z "$VIRTUAL_ENV" ]]; then
echo "You are not in a virtual environment."
exit 1
fi
pip install -e ./livekit-plugins-anthropic --config-settings editable_mode=strict
pip install -e ./livekit-plugins-aws --config-settings editable_mode=strict
pip install -e ./livekit-plugins-assemblyai --config-settings editable_mode=strict
pip install -e ./livekit-plugins-azure --config-settings editable_mode=strict
pip install -e ./livekit-plugins-cartesia --config-settings editable_mode=strict
pip install -e ./livekit-plugins-deepgram --config-settings editable_mode=strict
pip install -e ./livekit-plugins-elevenlabs --config-settings editable_mode=strict
pip install -e ./livekit-plugins-fal --config-settings editable_mode=strict
pip install -e ./livekit-plugins-google --config-settings editable_mode=strict
pip install -e ./livekit-plugins-minimal --config-settings editable_mode=strict
pip install -e ./livekit-plugins-nltk --config-settings editable_mode=strict
pip install -e ./livekit-plugins-openai --config-settings editable_mode=strict
pip install -e ./livekit-plugins-rag --config-settings editable_mode=strict
pip install -e ./livekit-plugins-rime --config-settings editable_mode=strict
pip install -e ./livekit-plugins-llama-index --config-settings editable_mode=strict
pip install -e ./livekit-plugins-turn-detector --config-settings editable_mode=strict
pip install -e ./livekit-plugins-silero --config-settings editable_mode=strict
pip install -e ./livekit-plugins-speechmatics --config-settings editable_mode=strict
pip install -e ./livekit-plugins-neuphonic --config-settings editable_mode=strict
pip install -e ./livekit-plugins-resemble --config-settings editable_mode=strict
pip install -e ./livekit-plugins-browser --config-settings editable_mode=strict
# livekit-plugins-anthropic
## 0.2.13
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.12
### Patch Changes
- don't pass functions in params when tool choice is set to none - [#1552](https://github.com/livekit/agents/pull/1552) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.11
### Patch Changes
- Add cache support for Anthropic - [#1478](https://github.com/livekit/agents/pull/1478) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.10
### Patch Changes
- Added an additional field in LLM capabilities class to check if model providers support function call history within chat context without needing function definitions. - [#1441](https://github.com/livekit/agents/pull/1441) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.9
### Patch Changes
- improved handling of LLM errors, do not retry if already began - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao))
## 0.2.8
### Patch Changes
- Moved create_ai_function_info to function_context.py for better reusability and reduced repetition - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19))
- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry))
Add support for data URLs on ChatImage in the Anthropic plugin.
- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19))
- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry))
Make scale_aspect_fit the new default resizing option for video frames.
## 0.2.7
### Patch Changes
- fix: return structured output from func calls - [#1187](https://github.com/livekit/agents/pull/1187) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.6
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.2.5
### Patch Changes
- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19))
- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom))
## 0.2.4
### Patch Changes
- anthropic tool fix - [#1051](https://github.com/livekit/agents/pull/1051) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.3
### Patch Changes
- fix: invalid request on anthropic - [#1018](https://github.com/livekit/agents/pull/1018) ([@theomonnom](https://github.com/theomonnom))
## 0.2.2
### Patch Changes
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
## 0.2.1
### Patch Changes
- Fixes to Anthropic Function Calling - [#708](https://github.com/livekit/agents/pull/708) ([@keepingitneil](https://github.com/keepingitneil))
## 0.2.0
### Minor Changes
- bump anthropic for release - [#724](https://github.com/livekit/agents/pull/724) ([@theomonnom](https://github.com/theomonnom))
# LiveKit Plugins Anthropic
Agent Framework plugin for services from Anthropic.
## Installation
```bash
pip install livekit-plugins-anthropic
```

You’ll need an API key from Anthropic. It can be set as an environment variable: ANTHROPIC_API_KEY
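As a minimal, illustrative sketch (not part of the original README), assuming the package is installed and ANTHROPIC_API_KEY is exported, the plugin's `LLM` class defined in `llm.py` below can be constructed directly:

```py
from livekit.plugins import anthropic

# A minimal sketch: the API key is read from the ANTHROPIC_API_KEY environment variable.
llm = anthropic.LLM(
    model="claude-3-5-sonnet-20241022",  # default model
    temperature=0.7,                     # optional sampling temperature
    caching="ephemeral",                 # optional: cache system prompt, tools, and chat history
)
```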
## livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llm import LLM, LLMStream
from .log import logger
from .models import ChatModels
from .version import __version__
__all__ = [
"LLM",
"LLMStream",
"ChatModels",
"logger",
"__version__",
]
from livekit.agents import Plugin
class AnthropicPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(AnthropicPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import inspect
import json
import os
from dataclasses import dataclass
from typing import (
Any,
Awaitable,
List,
Literal,
Union,
cast,
get_args,
get_origin,
)
import httpx
from livekit import rtc
from livekit.agents import (
APIConnectionError,
APIStatusError,
APITimeoutError,
llm,
utils,
)
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.llm.function_context import (
_create_ai_function_info,
_is_optional_type,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
import anthropic
from .log import logger
from .models import (
ChatModels,
)
CACHE_CONTROL_EPHEMERAL = anthropic.types.CacheControlEphemeralParam(type="ephemeral")
@dataclass
class LLMOptions:
model: str | ChatModels
user: str | None
temperature: float | None
parallel_tool_calls: bool | None
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] | None
caching: Literal["ephemeral"] | None = None
"""If set to "ephemeral", the system prompt, tools, and chat history will be cached."""
class LLM(llm.LLM):
def __init__(
self,
*,
model: str | ChatModels = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
user: str | None = None,
client: anthropic.AsyncClient | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
caching: Literal["ephemeral"] | None = None,
) -> None:
"""
Create a new instance of Anthropic LLM.
``api_key`` must be set to your Anthropic API key, either using the argument or by setting
the ``ANTHROPIC_API_KEY`` environment variable.
model (str | ChatModels): The model to use. Defaults to "claude-3-5-sonnet-20241022".
api_key (str | None): The Anthropic API key. Defaults to the ANTHROPIC_API_KEY environment variable.
base_url (str | None): The base URL for the Anthropic API. Defaults to None.
user (str | None): The user for the Anthropic API. Defaults to None.
client (anthropic.AsyncClient | None): The Anthropic client to use. Defaults to None.
temperature (float | None): The temperature for the Anthropic API. Defaults to None.
parallel_tool_calls (bool | None): Whether to parallelize tool calls. Defaults to None.
tool_choice (Union[ToolChoice, Literal["auto", "required", "none"]] | None): The tool choice for the Anthropic API. Defaults to "auto".
caching (Literal["ephemeral"] | None): If set to "ephemeral", caching will be enabled for the system prompt, tools, and chat history.
"""
super().__init__(
capabilities=LLMCapabilities(
requires_persistent_functions=True,
supports_choices_on_int=True,
)
)
# raise an error on our end if the API key is missing
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
if api_key is None:
raise ValueError("Anthropic API key is required")
self._opts = LLMOptions(
model=model,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
caching=caching,
)
self._client = client or anthropic.AsyncClient(
api_key=api_key,
base_url=base_url,
http_client=httpx.AsyncClient(
timeout=5.0,
follow_redirects=True,
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100,
keepalive_expiry=120,
),
),
)
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if temperature is None:
temperature = self._opts.temperature
if parallel_tool_calls is None:
parallel_tool_calls = self._opts.parallel_tool_calls
if tool_choice is None:
tool_choice = self._opts.tool_choice
opts: dict[str, Any] = dict()
if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
fncs_desc: list[anthropic.types.ToolParam] = []
for i, fnc in enumerate(fnc_ctx.ai_functions.values()):
# marking the last tool with cache_control caches all tools when caching is enabled
cache_ctrl = (
CACHE_CONTROL_EPHEMERAL
if (i == len(fnc_ctx.ai_functions) - 1)
and self._opts.caching == "ephemeral"
else None
)
fncs_desc.append(
_build_function_description(
fnc,
cache_ctrl=cache_ctrl,
)
)
opts["tools"] = fncs_desc
if tool_choice is not None:
anthropic_tool_choice: dict[str, Any] | None = {"type": "auto"}
if isinstance(tool_choice, ToolChoice):
if tool_choice.type == "function":
anthropic_tool_choice = {
"type": "tool",
"name": tool_choice.name,
}
elif isinstance(tool_choice, str):
if tool_choice == "required":
anthropic_tool_choice = {"type": "any"}
elif tool_choice == "none":
opts["tools"] = []
anthropic_tool_choice = None
if anthropic_tool_choice is not None:
if parallel_tool_calls is False:
anthropic_tool_choice["disable_parallel_tool_use"] = True
opts["tool_choice"] = anthropic_tool_choice
latest_system_message: anthropic.types.TextBlockParam | None = (
_latest_system_message(chat_ctx, caching=self._opts.caching)
)
if latest_system_message:
opts["system"] = [latest_system_message]
anthropic_ctx = _build_anthropic_context(
chat_ctx.messages,
id(self),
caching=self._opts.caching,
)
collapsed_anthropic_ctx = _merge_messages(anthropic_ctx)
stream = self._client.messages.create(
max_tokens=opts.get("max_tokens", 1024),
messages=collapsed_anthropic_ctx,
model=self._opts.model,
temperature=temperature or anthropic.NOT_GIVEN,
top_k=n or anthropic.NOT_GIVEN,
stream=True,
**opts,
)
return LLMStream(
self,
anthropic_stream=stream,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
conn_options=conn_options,
)
class LLMStream(llm.LLMStream):
def __init__(
self,
llm: LLM,
*,
anthropic_stream: Awaitable[
anthropic.AsyncStream[anthropic.types.RawMessageStreamEvent]
],
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None,
conn_options: APIConnectOptions,
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._awaitable_anthropic_stream = anthropic_stream
self._anthropic_stream: (
anthropic.AsyncStream[anthropic.types.RawMessageStreamEvent] | None
) = None
# current function call we're waiting to complete (arguments are streamed incrementally)
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._request_id: str = ""
self._ignoring_cot = False # ignore chain of thought
self._input_tokens = 0
self._cache_creation_tokens = 0
self._cache_read_tokens = 0
self._output_tokens = 0
async def _run(self) -> None:
retryable = True
try:
if not self._anthropic_stream:
self._anthropic_stream = await self._awaitable_anthropic_stream
async with self._anthropic_stream as stream:
async for event in stream:
chat_chunk = self._parse_event(event)
if chat_chunk is not None:
self._event_ch.send_nowait(chat_chunk)
retryable = False
self._event_ch.send_nowait(
llm.ChatChunk(
request_id=self._request_id,
usage=llm.CompletionUsage(
completion_tokens=self._output_tokens,
prompt_tokens=self._input_tokens,
total_tokens=self._input_tokens
+ self._output_tokens
+ self._cache_creation_tokens
+ self._cache_read_tokens,
cache_creation_input_tokens=self._cache_creation_tokens,
cache_read_input_tokens=self._cache_read_tokens,
),
)
)
except anthropic.APITimeoutError:
raise APITimeoutError(retryable=retryable)
except anthropic.APIStatusError as e:
raise APIStatusError(
e.message,
status_code=e.status_code,
request_id=e.request_id,
body=e.body,
)
except Exception as e:
raise APIConnectionError(retryable=retryable) from e
def _parse_event(
self, event: anthropic.types.RawMessageStreamEvent
) -> llm.ChatChunk | None:
if event.type == "message_start":
self._request_id = event.message.id
self._input_tokens = event.message.usage.input_tokens
self._output_tokens = event.message.usage.output_tokens
if event.message.usage.cache_creation_input_tokens:
self._cache_creation_tokens = (
event.message.usage.cache_creation_input_tokens
)
if event.message.usage.cache_read_input_tokens:
self._cache_read_tokens = event.message.usage.cache_read_input_tokens
elif event.type == "message_delta":
self._output_tokens += event.usage.output_tokens
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
self._tool_call_id = event.content_block.id
self._fnc_name = event.content_block.name
self._fnc_raw_arguments = ""
elif event.type == "content_block_delta":
delta = event.delta
if delta.type == "text_delta":
text = delta.text
if self._fnc_ctx is not None:
# anthropic may inject chain-of-thought (<thinking>) blocks when using functions
if text.startswith("<thinking>"):
self._ignoring_cot = True
elif self._ignoring_cot and "</thinking>" in text:
text = text.split("</thinking>")[-1]
self._ignoring_cot = False
if self._ignoring_cot:
return None
return llm.ChatChunk(
request_id=self._request_id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(content=text, role="assistant")
)
],
)
elif delta.type == "input_json_delta":
assert self._fnc_raw_arguments is not None
self._fnc_raw_arguments += delta.partial_json
elif event.type == "content_block_stop":
if self._tool_call_id is not None and self._fnc_ctx:
assert self._fnc_name is not None
assert self._fnc_raw_arguments is not None
fnc_info = _create_ai_function_info(
self._fnc_ctx,
self._tool_call_id,
self._fnc_name,
self._fnc_raw_arguments,
)
self._function_calls_info.append(fnc_info)
chat_chunk = llm.ChatChunk(
request_id=self._request_id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
role="assistant", tool_calls=[fnc_info]
),
)
],
)
self._tool_call_id = self._fnc_raw_arguments = self._fnc_name = None
return chat_chunk
return None
def _latest_system_message(
chat_ctx: llm.ChatContext, caching: Literal["ephemeral"] | None = None
) -> anthropic.types.TextBlockParam | None:
latest_system_message: llm.ChatMessage | None = None
for m in chat_ctx.messages:
if m.role == "system":
latest_system_message = m
continue
latest_system_str = ""
if latest_system_message:
if isinstance(latest_system_message.content, str):
latest_system_str = latest_system_message.content
elif isinstance(latest_system_message.content, list):
latest_system_str = " ".join(
[c for c in latest_system_message.content if isinstance(c, str)]
)
if latest_system_str:
system_text_block = anthropic.types.TextBlockParam(
text=latest_system_str,
type="text",
cache_control=CACHE_CONTROL_EPHEMERAL if caching == "ephemeral" else None,
)
return system_text_block
return None
def _merge_messages(
messages: List[anthropic.types.MessageParam],
) -> List[anthropic.types.MessageParam]:
# Anthropic enforces alternating user/assistant messages, so merge consecutive same-role messages
combined_messages: list[anthropic.types.MessageParam] = []
for m in messages:
if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
combined_messages.append(m)
continue
last_message = combined_messages[-1]
if not isinstance(last_message["content"], list) or not isinstance(
m["content"], list
):
logger.error("message content is not a list")
continue
last_message["content"].extend(m["content"])
if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
combined_messages.insert(
0, {"role": "user", "content": [{"type": "text", "text": "(empty)"}]}
)
return combined_messages
def _build_anthropic_context(
chat_ctx: List[llm.ChatMessage],
cache_key: Any,
caching: Literal["ephemeral"] | None,
) -> List[anthropic.types.MessageParam]:
result: List[anthropic.types.MessageParam] = []
for i, msg in enumerate(chat_ctx):
# marking the last message with cache_control caches the whole chat history when caching is enabled
cache_ctrl = (
CACHE_CONTROL_EPHEMERAL
if ((i == len(chat_ctx) - 1) and caching == "ephemeral")
else None
)
a_msg = _build_anthropic_message(msg, cache_key, cache_ctrl=cache_ctrl)
if a_msg:
result.append(a_msg)
return result
def _build_anthropic_message(
msg: llm.ChatMessage,
cache_key: Any,
cache_ctrl: anthropic.types.CacheControlEphemeralParam | None,
) -> anthropic.types.MessageParam | None:
if msg.role == "user" or msg.role == "assistant":
a_msg: anthropic.types.MessageParam = {
"role": msg.role,
"content": [],
}
assert isinstance(a_msg["content"], list)
a_content = a_msg["content"]
# add content if provided
if isinstance(msg.content, str) and msg.content:
a_msg["content"].append(
anthropic.types.TextBlockParam(
text=msg.content,
type="text",
cache_control=cache_ctrl,
)
)
elif isinstance(msg.content, list):
for cnt in msg.content:
if isinstance(cnt, str) and cnt:
content: anthropic.types.TextBlockParam = (
anthropic.types.TextBlockParam(
text=cnt,
type="text",
cache_control=cache_ctrl,
)
)
a_content.append(content)
elif isinstance(cnt, llm.ChatImage):
a_content.append(
_build_anthropic_image_content(cnt, cache_key, cache_ctrl)
)
if msg.tool_calls is not None:
for fnc in msg.tool_calls:
tool_use = anthropic.types.ToolUseBlockParam(
id=fnc.tool_call_id,
type="tool_use",
name=fnc.function_info.name,
input=fnc.arguments,
cache_control=cache_ctrl,
)
a_content.append(tool_use)
return a_msg
elif msg.role == "tool":
if isinstance(msg.content, dict):
msg.content = json.dumps(msg.content)
if not isinstance(msg.content, str):
logger.warning("tool message content is not a string or dict")
return None
if not msg.tool_call_id:
return None
u_content = anthropic.types.ToolResultBlockParam(
tool_use_id=msg.tool_call_id,
type="tool_result",
content=msg.content,
is_error=msg.tool_exception is not None,
cache_control=cache_ctrl,
)
return {
"role": "user",
"content": [u_content],
}
return None
def _build_anthropic_image_content(
image: llm.ChatImage,
cache_key: Any,
cache_ctrl: anthropic.types.CacheControlEphemeralParam | None,
) -> anthropic.types.ImageBlockParam:
if isinstance(image.image, str): # image is a URL
if not image.image.startswith("data:"):
raise ValueError("LiveKit Anthropic Plugin: Image URLs must be data URLs")
try:
header, b64_data = image.image.split(",", 1)
media_type = header.split(";")[0].split(":")[1]
supported_types = {"image/jpeg", "image/png", "image/webp", "image/gif"}
if media_type not in supported_types:
raise ValueError(
f"LiveKit Anthropic Plugin: Unsupported media type {media_type}. Must be jpeg, png, webp, or gif"
)
return {
"type": "image",
"source": {
"type": "base64",
"data": b64_data,
"media_type": cast(
Literal["image/jpeg", "image/png", "image/gif", "image/webp"],
media_type,
),
},
"cache_control": cache_ctrl,
}
except (ValueError, IndexError) as e:
raise ValueError(
f"LiveKit Anthropic Plugin: Invalid image data URL {str(e)}"
)
elif isinstance(image.image, rtc.VideoFrame): # image is a VideoFrame
if cache_key not in image._cache:
# inside our internal implementation, we allow extra metadata to be attached to
# each ChatImage (to avoid re-encoding it on every chat completion request)
opts = utils.images.EncodeOptions()
if image.inference_width and image.inference_height:
opts.resize_options = utils.images.ResizeOptions(
width=image.inference_width,
height=image.inference_height,
strategy="scale_aspect_fit",
)
encoded_data = utils.images.encode(image.image, opts)
image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
return {
"type": "image",
"source": {
"type": "base64",
"data": image._cache[cache_key],
"media_type": "image/jpeg",
},
"cache_control": cache_ctrl,
}
raise ValueError(
"LiveKit Anthropic Plugin: ChatImage must be an rtc.VideoFrame or a data URL"
)
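# Illustrative note (added, not part of the original source): string images must be
# data URLs with a supported media type, e.g.
#   llm.ChatImage(image="data:image/jpeg;base64," + base64.b64encode(jpeg_bytes).decode())
# where jpeg_bytes is hypothetical raw JPEG data. rtc.VideoFrame inputs are encoded
# to JPEG via utils.images.encode and cached per LLM instance (cache_key), so repeated
# chat completion requests avoid re-encoding the same frame.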
def _build_function_description(
fnc_info: llm.function_context.FunctionInfo,
cache_ctrl: anthropic.types.CacheControlEphemeralParam | None,
) -> anthropic.types.ToolParam:
def build_schema_field(arg_info: llm.function_context.FunctionArgInfo):
def type2str(t: type) -> str:
if t is str:
return "string"
elif t in (int, float):
return "number"
elif t is bool:
return "boolean"
raise ValueError(f"unsupported type {t} for ai_property")
p: dict[str, Any] = {}
if arg_info.default is inspect.Parameter.empty:
p["required"] = True
else:
p["required"] = False
if arg_info.description:
p["description"] = arg_info.description
_, inner_th = _is_optional_type(arg_info.type)
if get_origin(inner_th) is list:
inner_type = get_args(inner_th)[0]
p["type"] = "array"
p["items"] = {}
p["items"]["type"] = type2str(inner_type)
if arg_info.choices:
p["items"]["enum"] = arg_info.choices
else:
p["type"] = type2str(inner_th)
if arg_info.choices:
p["enum"] = arg_info.choices
return p
input_schema: dict[str, object] = {"type": "object"}
for arg_info in fnc_info.arguments.values():
input_schema[arg_info.name] = build_schema_field(arg_info)
return anthropic.types.ToolParam(
name=fnc_info.name,
description=fnc_info.description,
input_schema=input_schema,
cache_control=cache_ctrl,
)
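# Illustrative example (hypothetical function, added for clarity): for an AI function
# named "get_weather" with a single required string argument "location", this helper
# returns an anthropic.types.ToolParam roughly equivalent to:
#   {"name": "get_weather", "description": "...",
#    "input_schema": {"type": "object",
#                     "location": {"required": True, "type": "string"}},
#    "cache_control": None}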
import logging
logger = logging.getLogger("livekit.plugins.anthropic")
from typing import Literal
ChatModels = Literal[
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-5-sonnet-20241022",
"claude-3-haiku-20240307",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.13"
{
"name": "livekit-plugins-anthropic",
"private": true,
"version": "0.2.13"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "anthropic", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-anthropic",
version=about["__version__"],
description="Agent Framework plugin for services from Anthropic",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "anthropic>=0.34"],
package_data={"livekit.plugins.anthropic": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-assemblyai
## 0.2.3
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.2
### Patch Changes
- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao))
- assemblyai: encode boost words - [#1284](https://github.com/livekit/agents/pull/1284) ([@jmugicagonz](https://github.com/jmugicagonz))
## 0.2.1
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.1.1
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
## 0.1.0
### Minor Changes
- Introduce assembly.ai plugin - [#1082](https://github.com/livekit/agents/pull/1082) ([@davidzhao](https://github.com/davidzhao))
# LiveKit Plugins AssemblyAI
Agent Framework plugin for AssemblyAI. Currently supports Streaming Speech-to-Text.
## Installation
```bash
pip install livekit-plugins-assemblyai
```

You’ll need to specify an AssemblyAI API key. It can be set as an environment variable: ASSEMBLYAI_API_KEY
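As an illustrative sketch (not part of the original README), assuming the package is installed and ASSEMBLYAI_API_KEY is exported, the `STT` class defined in `stt.py` below can be constructed and used to open a streaming transcription session:

```py
from livekit.plugins import assemblyai

# A minimal sketch: the API key is read from the ASSEMBLYAI_API_KEY environment variable.
stt = assemblyai.STT(
    sample_rate=16000,
    word_boost=["LiveKit", "AssemblyAI"],  # optional keyword boosting
)
# Push audio frames into the returned stream; interim and final transcript
# events are emitted asynchronously.
stream = stt.stream()
```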
## livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/__init__.py
```py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .log import logger
from .stt import STT, SpeechStream
from .version import __version__
__all__ = [
"STT",
"SpeechStream",
"logger",
"__version__",
]
from livekit.agents import Plugin
class AssemblyAIPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(AssemblyAIPlugin())
import logging
logger = logging.getLogger("livekit.plugins.assemblyai")
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import dataclasses
import json
import os
import weakref
from dataclasses import dataclass
from typing import List, Literal, Optional
from urllib.parse import urlencode
import aiohttp
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
APIStatusError,
stt,
utils,
)
from livekit.agents.stt import SpeechEvent
from livekit.agents.utils import AudioBuffer
from .log import logger
ENGLISH = "en"
# Define bytes per frame for different encoding types
bytes_per_frame = {
"pcm_s16le": 2,
"pcm_mulaw": 1,
}
@dataclass
class STTOptions:
sample_rate: int
buffer_size_seconds: float
word_boost: Optional[List[str]] = None
encoding: Optional[Literal["pcm_s16le", "pcm_mulaw"]] = None
disable_partial_transcripts: bool = False
enable_extra_session_information: bool = False
end_utterance_silence_threshold: Optional[int] = None
# Buffer to collect frames to send to AssemblyAI
def __post_init__(self):
if self.encoding not in (None, "pcm_s16le", "pcm_mulaw"):
raise ValueError(f"Invalid encoding: {self.encoding}")
class STT(stt.STT):
def __init__(
self,
*,
api_key: Optional[str] = None,
sample_rate: int = 16000,
word_boost: Optional[List[str]] = None,
encoding: Optional[Literal["pcm_s16le", "pcm_mulaw"]] = "pcm_s16le",
disable_partial_transcripts: bool = False,
enable_extra_session_information: bool = False,
end_utterance_silence_threshold: Optional[int] = 500,
http_session: Optional[aiohttp.ClientSession] = None,
buffer_size_seconds: float = 0.05,
):
super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
),
)
api_key = api_key or os.environ.get("ASSEMBLYAI_API_KEY")
if api_key is None:
raise ValueError(
"AssemblyAI API key is required. "
"Pass one in via the `api_key` parameter, "
"or set it as the `ASSEMBLYAI_API_KEY` environment variable"
)
self._api_key = api_key
self._opts = STTOptions(
sample_rate=sample_rate,
word_boost=word_boost,
encoding=encoding,
disable_partial_transcripts=disable_partial_transcripts,
enable_extra_session_information=enable_extra_session_information,
buffer_size_seconds=buffer_size_seconds,
end_utterance_silence_threshold=end_utterance_silence_threshold,
)
self._session = http_session
self._streams = weakref.WeakSet[SpeechStream]()
@property
def session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
raise NotImplementedError("Not implemented")
def stream(
self,
*,
language: Optional[str] = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = dataclasses.replace(self._opts)
stream = SpeechStream(
stt=self,
conn_options=conn_options,
opts=config,
api_key=self._api_key,
http_session=self.session,
)
self._streams.add(stream)
return stream
def update_options(
self,
*,
disable_partial_transcripts: Optional[bool] = None,
word_boost: Optional[List[str]] = None,
end_utterance_silence_threshold: Optional[int] = None,
enable_extra_session_information: Optional[bool] = None,
buffer_size_seconds: Optional[float] = None,
):
if disable_partial_transcripts is not None:
self._opts.disable_partial_transcripts = disable_partial_transcripts
if word_boost is not None:
self._opts.word_boost = word_boost
if end_utterance_silence_threshold is not None:
self._opts.end_utterance_silence_threshold = end_utterance_silence_threshold
if enable_extra_session_information is not None:
self._opts.enable_extra_session_information = (
enable_extra_session_information
)
if buffer_size_seconds is not None:
self._opts.buffer_size_seconds = buffer_size_seconds
for stream in self._streams:
stream.update_options(
disable_partial_transcripts=disable_partial_transcripts,
word_boost=word_boost,
end_utterance_silence_threshold=end_utterance_silence_threshold,
enable_extra_session_information=enable_extra_session_information,
buffer_size_seconds=buffer_size_seconds,
)
class SpeechStream(stt.SpeechStream):
# Used to close websocket
_CLOSE_MSG: str = json.dumps({"terminate_session": True})
def __init__(
self,
*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
http_session: aiohttp.ClientSession,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
self._opts = opts
self._api_key = api_key
self._session = http_session
self._speech_duration: float = 0
# keep a list of final transcripts to combine them inside the END_OF_SPEECH event
self._final_events: List[SpeechEvent] = []
self._reconnect_event = asyncio.Event()
def update_options(
self,
*,
disable_partial_transcripts: Optional[bool] = None,
word_boost: Optional[List[str]] = None,
end_utterance_silence_threshold: Optional[int] = None,
enable_extra_session_information: Optional[bool] = None,
buffer_size_seconds: Optional[float] = None,
):
if disable_partial_transcripts is not None:
self._opts.disable_partial_transcripts = disable_partial_transcripts
if word_boost is not None:
self._opts.word_boost = word_boost
if end_utterance_silence_threshold is not None:
self._opts.end_utterance_silence_threshold = end_utterance_silence_threshold
if enable_extra_session_information is not None:
self._opts.enable_extra_session_information = (
enable_extra_session_information
)
if buffer_size_seconds is not None:
self._opts.buffer_size_seconds = buffer_size_seconds
self._reconnect_event.set()
async def _run(self) -> None:
"""
Run a single websocket connection to AssemblyAI and make sure to reconnect
when something goes wrong.
"""
closing_ws = False
async def send_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
if self._opts.end_utterance_silence_threshold:
await ws.send_str(
json.dumps(
{
"end_utterance_silence_threshold": self._opts.end_utterance_silence_threshold
}
)
)
samples_per_buffer = self._opts.sample_rate // round(
1 / self._opts.buffer_size_seconds
)
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=1,
samples_per_channel=samples_per_buffer,
)
# forward inputs to AssemblyAI
# if we receive a close message, signal it to AssemblyAI and break.
# the recv task will then make sure to process the remaining audio and stop
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
frames = audio_bstream.flush()
else:
frames = audio_bstream.write(data.data.tobytes())
for frame in frames:
self._speech_duration += frame.duration
await ws.send_bytes(frame.data.tobytes())
closing_ws = True
await ws.send_str(SpeechStream._CLOSE_MSG)
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
while True:
try:
msg = await asyncio.wait_for(ws.receive(), timeout=5)
except asyncio.TimeoutError:
if closing_ws:
break
continue
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws: # close is expected, see SpeechStream.aclose
return
raise APIStatusError(
"AssemblyAI connection closed unexpectedly",
) # this will trigger a reconnection, see the _run loop
if msg.type != aiohttp.WSMsgType.TEXT:
logger.error("unexpected AssemblyAI message type %s", msg.type)
continue
try:
# received a message from AssemblyAI
data = json.loads(msg.data)
self._process_stream_event(data, closing_ws)
except Exception:
logger.exception("failed to process AssemblyAI message")
ws: aiohttp.ClientWebSocketResponse | None = None
while True:
try:
ws = await self._connect_ws()
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
]
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
done, _ = await asyncio.wait(
[asyncio.gather(*tasks), wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
) # type: ignore
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
finally:
if ws is not None:
await ws.close()
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
live_config = {
"sample_rate": self._opts.sample_rate,
"word_boost": json.dumps(self._opts.word_boost)
if self._opts.word_boost is not None
else None,
"encoding": self._opts.encoding,
"disable_partial_transcripts": self._opts.disable_partial_transcripts,
"enable_extra_session_information": self._opts.enable_extra_session_information,
}
headers = {
"Authorization": self._api_key,
"Content-Type": "application/json",
}
ws_url = "wss://api.assemblyai.com/v2/realtime/ws"
filtered_config = {k: v for k, v in live_config.items() if v is not None}
url = f"{ws_url}?{urlencode(filtered_config).lower()}"
ws = await self._session.ws_connect(url, headers=headers)
return ws
def _process_stream_event(self, data: dict, closing_ws: bool) -> None:
# see this page:
# https://www.assemblyai.com/docs/api-reference/streaming/realtime
# for more information about the different types of events
if "error" in data:
logger.error("Received error from AssemblyAI: %s", data["error"])
return
message_type = data.get("message_type")
if message_type == "SessionBegins":
start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
self._event_ch.send_nowait(start_event)
elif message_type == "PartialTranscript":
alts = live_transcription_to_speech_data(ENGLISH, data)
if len(alts) > 0 and alts[0].text:
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=alts,
)
self._event_ch.send_nowait(interim_event)
elif message_type == "FinalTranscript":
alts = live_transcription_to_speech_data(ENGLISH, data)
if len(alts) > 0 and alts[0].text:
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=alts,
)
self._final_events.append(final_event)
self._event_ch.send_nowait(final_event)
# log metrics
if self._speech_duration > 0:
usage_event = stt.SpeechEvent(
type=stt.SpeechEventType.RECOGNITION_USAGE,
alternatives=[],
recognition_usage=stt.RecognitionUsage(
audio_duration=self._speech_duration
),
)
self._event_ch.send_nowait(usage_event)
self._speech_duration = 0
elif message_type == "SessionTerminated":
if closing_ws:
pass
else:
raise Exception("AssemblyAI connection closed unexpectedly")
elif message_type == "SessionInformation":
logger.debug("AssemblyAI Session Information: %s", str(data))
else:
logger.warning(
"Received unexpected message type from AssemblyAI: %s",
message_type or "No message_type field",
)
def live_transcription_to_speech_data(
language: str,
data: dict,
) -> List[stt.SpeechData]:
return [
stt.SpeechData(
language=language,
start_time=data["words"][0]["start"] / 1000 if data["words"] else 0,
end_time=data["words"][-1]["end"] / 1000 if data["words"] else 0,
confidence=data["confidence"],
text=data["text"],
),
]
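# Illustrative payload (hypothetical values, added for clarity) that this helper
# expects from AssemblyAI:
#   {"text": "hello world", "confidence": 0.93,
#    "words": [{"start": 120, "end": 480}, {"start": 500, "end": 900}]}
# word start/end times are reported in milliseconds and converted to seconds above.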
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.3"
{
"name": "livekit-plugins-assemblyai",
"private": true,
"version": "0.2.3"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "assemblyai", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-assemblyai",
version=about["__version__"],
description="Agent Framework plugin for AssemblyAI",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
],
package_data={},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-aws
## 0.1.1
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
# LiveKit Plugins AWS
Agent Framework plugin for services from AWS.
- AWS Polly for TTS
- AWS Transcribe for STT
- AWS Bedrock for LLM
## Installation
```bash
pip install livekit-plugins-aws
```

You’ll need to specify an AWS access key, a secret access key, and a deployment region. They can be set as environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_DEFAULT_REGION, respectively.
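As an illustrative sketch (not part of the original README), assuming the package is installed and the credentials above are configured, the plugin's `LLM` and `STT` classes defined below can be constructed directly:

```py
from livekit.plugins import aws

# A minimal sketch: credentials are resolved from the environment (or boto3's default chain).
llm = aws.LLM(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # Bedrock model id or inference profile ARN
    region="us-east-1",
    temperature=0.8,
)
stt = aws.STT(speech_region="us-east-1", language="en-US", sample_rate=48000)
```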
## livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .llm import LLM
from .stt import STT, SpeechStream
from .tts import TTS, ChunkedStream
from .version import __version__
__all__ = ["STT", "SpeechStream", "TTS", "ChunkedStream", "LLM", "__version__"]
from livekit.agents import Plugin
class AWSPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(AWSPlugin())
from __future__ import annotations
import base64
import inspect
import json
import os
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin
import boto3
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm.function_context import _is_optional_type
__all__ = ["_build_aws_ctx", "_build_tools", "_get_aws_credentials"]
def _get_aws_credentials(
api_key: Optional[str], api_secret: Optional[str], region: Optional[str]
):
region = region or os.environ.get("AWS_DEFAULT_REGION")
if not region:
raise ValueError(
"AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
)
# If API key and secret are provided, create a session with them
if api_key and api_secret:
session = boto3.Session(
aws_access_key_id=api_key,
aws_secret_access_key=api_secret,
region_name=region,
)
else:
session = boto3.Session(region_name=region)
credentials = session.get_credentials()
if not credentials or not credentials.access_key or not credentials.secret_key:
raise ValueError("No valid AWS credentials found.")
return credentials.access_key, credentials.secret_key
JSON_SCHEMA_TYPE_MAP: Dict[type, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
dict: "object",
list: "array",
}
def _build_parameters(arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
properties: Dict[str, dict] = {}
required: List[str] = []
for arg_name, arg_info in arguments.items():
prop = {}
if hasattr(arg_info, "description") and arg_info.description:
prop["description"] = arg_info.description
_, py_type = _is_optional_type(arg_info.type)
origin = get_origin(py_type)
if origin is list:
item_type = get_args(py_type)[0]
if item_type not in JSON_SCHEMA_TYPE_MAP:
raise ValueError(f"Unsupported type: {item_type}")
prop["type"] = "array"
prop["items"] = {"type": JSON_SCHEMA_TYPE_MAP[item_type]}
if hasattr(arg_info, "choices") and arg_info.choices:
prop["items"]["enum"] = list(arg_info.choices)
else:
if py_type not in JSON_SCHEMA_TYPE_MAP:
raise ValueError(f"Unsupported type: {py_type}")
prop["type"] = JSON_SCHEMA_TYPE_MAP[py_type]
if arg_info.choices:
prop["enum"] = list(arg_info.choices)
properties[arg_name] = prop
if arg_info.default is inspect.Parameter.empty:
required.append(arg_name)
if properties:
parameters = {"json": {"type": "object", "properties": properties}}
if required:
parameters["json"]["required"] = required
return parameters
return None
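# Illustrative example (hypothetical argument, added for clarity): for a function
# argument `location: str` with description "The city to look up" and no default,
# this helper returns:
#   {"json": {"type": "object",
#             "properties": {"location": {"description": "The city to look up", "type": "string"}},
#             "required": ["location"]}}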
def _build_tools(fnc_ctx: Any) -> List[dict]:
tools: List[dict] = []
for fnc_info in fnc_ctx.ai_functions.values():
parameters = _build_parameters(fnc_info.arguments)
func_decl = {
"toolSpec": {
"name": fnc_info.name,
"description": fnc_info.description,
"inputSchema": parameters
if parameters
else {"json": {"type": "object", "properties": {}}},
}
}
tools.append(func_decl)
return tools
def _build_image(image: llm.ChatImage, cache_key: Any) -> dict:
if isinstance(image.image, str):
if image.image.startswith("data:image/jpeg;base64,"):
base64_data = image.image.split(",", 1)[1]
try:
image_bytes = base64.b64decode(base64_data)
except Exception as e:
raise ValueError("Invalid base64 data in image URL") from e
return {"image": {"format": "jpeg", "source": {"bytes": image_bytes}}}
else:
return {"image": {"format": "jpeg", "source": {"uri": image.image}}}
elif isinstance(image.image, rtc.VideoFrame):
if cache_key not in image._cache:
opts = utils.images.EncodeOptions()
if image.inference_width and image.inference_height:
opts.resize_options = utils.images.ResizeOptions(
width=image.inference_width,
height=image.inference_height,
strategy="scale_aspect_fit",
)
image._cache[cache_key] = utils.images.encode(image.image, opts)
return {
"image": {
"format": "jpeg",
"source": {
"bytes": image._cache[cache_key],
},
}
}
raise ValueError(f"Unsupported image type: {type(image.image)}")
def _build_aws_ctx(
chat_ctx: llm.ChatContext, cache_key: Any
) -> Tuple[List[dict], Optional[dict]]:
messages: List[dict] = []
system: Optional[dict] = None
current_role: Optional[str] = None
current_content: List[dict] = []
for msg in chat_ctx.messages:
if msg.role == "system":
if isinstance(msg.content, str):
system = {"text": msg.content}
continue
if msg.role == "assistant":
role = "assistant"
else:
role = "user"
if role != current_role:
if current_role is not None and current_content:
messages.append({"role": current_role, "content": current_content})
current_role = role
current_content = []
if msg.tool_calls:
for fnc in msg.tool_calls:
current_content.append(
{
"toolUse": {
"toolUseId": fnc.tool_call_id,
"name": fnc.function_info.name,
"input": fnc.arguments,
}
}
)
if msg.role == "tool":
tool_response: dict = {
"toolResult": {
"toolUseId": msg.tool_call_id,
"content": [],
"status": "success",
}
}
if isinstance(msg.content, dict):
tool_response["toolResult"]["content"].append({"json": msg.content})
elif isinstance(msg.content, str):
tool_response["toolResult"]["content"].append({"text": msg.content})
current_content.append(tool_response)
else:
if msg.content:
if isinstance(msg.content, str):
current_content.append({"text": msg.content})
elif isinstance(msg.content, dict):
current_content.append({"text": json.dumps(msg.content)})
elif isinstance(msg.content, list):
for item in msg.content:
if isinstance(item, str):
current_content.append({"text": item})
elif isinstance(item, llm.ChatImage):
current_content.append(_build_image(item, cache_key))
if current_role is not None and current_content:
messages.append({"role": current_role, "content": current_content})
return messages, system
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Literal, MutableSet, Union
import boto3
from livekit.agents import (
APIConnectionError,
APIStatusError,
llm,
)
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from ._utils import _build_aws_ctx, _build_tools, _get_aws_credentials
from .log import logger
TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
DEFAULT_REGION = "us-east-1"
@dataclass
class LLMOptions:
model: TEXT_MODEL | str
temperature: float | None
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
max_output_tokens: int | None = None
top_p: float | None = None
additional_request_fields: dict[str, Any] | None = None
class LLM(llm.LLM):
def __init__(
self,
*,
model: TEXT_MODEL | str = "anthropic.claude-3-5-sonnet-20240620-v1:0",
api_key: str | None = None,
api_secret: str | None = None,
region: str = "us-east-1",
temperature: float = 0.8,
max_output_tokens: int | None = None,
top_p: float | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
additional_request_fields: dict[str, Any] | None = None,
) -> None:
"""
Create a new instance of AWS Bedrock LLM.
``api_key`` and ``api_secret`` must be set to your AWS access key ID and secret access key, either using the arguments or by setting the
``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environment variables.
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.
Args:
model (TEXT_MODEL, optional): model or inference profile arn to use (https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html). Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
api_key (str, optional): AWS access key ID.
api_secret (str, optional): AWS secret access key.
region (str, optional): The region to use for AWS API requests. Defaults to "us-east-1".
temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
"""
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=True,
)
)
self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, region
)
self._model = model or os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
if not self._model:
raise ValueError(
"model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable."
)
self._opts = LLMOptions(
model=self._model,
temperature=temperature,
tool_choice=tool_choice,
max_output_tokens=max_output_tokens,
top_p=top_p,
additional_request_fields=additional_request_fields,
)
self._region = region
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if tool_choice is None:
tool_choice = self._opts.tool_choice
if temperature is None:
temperature = self._opts.temperature
return LLMStream(
self,
model=self._opts.model,
aws_access_key_id=self._api_key,
aws_secret_access_key=self._api_secret,
region_name=self._region,
max_output_tokens=self._opts.max_output_tokens,
top_p=self._opts.top_p,
additional_request_fields=self._opts.additional_request_fields,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
conn_options=conn_options,
temperature=temperature,
tool_choice=tool_choice,
)
class LLMStream(llm.LLMStream):
def __init__(
self,
llm: LLM,
*,
model: str | TEXT_MODEL,
aws_access_key_id: str | None,
aws_secret_access_key: str | None,
region_name: str,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions,
fnc_ctx: llm.FunctionContext | None,
temperature: float | None,
max_output_tokens: int | None,
top_p: float | None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
additional_request_fields: dict[str, Any] | None,
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = boto3.client(
"bedrock-runtime",
region_name=region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._model = model
self._llm: LLM = llm
self._max_output_tokens = max_output_tokens
self._top_p = top_p
self._temperature = temperature
self._tool_choice = tool_choice
self._additional_request_fields = additional_request_fields
async def _run(self) -> None:
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._text: str = ""
retryable = True
try:
opts: dict[str, Any] = {}
messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
messages = _merge_messages(messages)
def _get_tool_config() -> dict[str, Any] | None:
if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
return None
tools = _build_tools(self._fnc_ctx)
config: dict[str, Any] = {"tools": tools}
if isinstance(self._tool_choice, ToolChoice):
config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
elif self._tool_choice == "required":
config["toolChoice"] = {"any": {}}
elif self._tool_choice == "auto":
config["toolChoice"] = {"auto": {}}
else:
return None
return config
tool_config = _get_tool_config()
if tool_config:
opts["toolConfig"] = tool_config
if self._additional_request_fields:
opts["additionalModelRequestFields"] = _strip_nones(
self._additional_request_fields
)
if system_instruction:
opts["system"] = [system_instruction]
inference_config = _strip_nones(
{
"maxTokens": self._max_output_tokens,
"temperature": self._temperature,
"topP": self._top_p,
}
)
response = self._client.converse_stream(
modelId=self._model,
messages=messages,
inferenceConfig=inference_config,
**_strip_nones(opts),
) # type: ignore
request_id = response["ResponseMetadata"]["RequestId"]
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise APIStatusError(
f"aws bedrock llm: error generating content: {response}",
retryable=False,
request_id=request_id,
)
for chunk in response["stream"]:
chat_chunk = self._parse_chunk(request_id, chunk)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)
# Let other coroutines run
await asyncio.sleep(0)
except Exception as e:
raise APIConnectionError(
f"aws bedrock llm: error generating content: {e}",
retryable=retryable,
) from e
def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
if "contentBlockStart" in chunk:
tool_use = chunk["contentBlockStart"]["start"]["toolUse"]
self._tool_call_id = tool_use["toolUseId"]
self._fnc_name = tool_use["name"]
self._fnc_raw_arguments = ""
elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "toolUse" in delta:
self._fnc_raw_arguments += delta["toolUse"]["input"]
elif "text" in delta:
self._text += delta["text"]
elif "contentBlockStop" in chunk:
if self._text:
chat_chunk = llm.ChatChunk(
request_id=request_id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(content=self._text, role="assistant"),
index=chunk["contentBlockStop"]["contentBlockIndex"],
)
],
)
self._text = ""
return chat_chunk
elif self._tool_call_id:
return self._try_build_function(request_id, chunk)
elif "metadata" in chunk:
metadata = chunk["metadata"]
return llm.ChatChunk(
request_id=request_id,
usage=llm.CompletionUsage(
completion_tokens=metadata["usage"]["outputTokens"],
prompt_tokens=metadata["usage"]["inputTokens"],
total_tokens=metadata["usage"]["totalTokens"],
),
)
return None
def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
if self._tool_call_id is None:
logger.warning("aws bedrock llm: no tool call id in the response")
return None
if self._fnc_name is None:
logger.warning("aws bedrock llm: no function name in the response")
return None
if self._fnc_raw_arguments is None:
logger.warning("aws bedrock llm: no function arguments in the response")
return None
if self._fnc_ctx is None:
logger.warning(
"aws bedrock llm: stream tried to run function without function context"
)
return None
fnc_info = _create_ai_function_info(
self._fnc_ctx,
self._tool_call_id,
self._fnc_name,
self._fnc_raw_arguments,
)
self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
self._function_calls_info.append(fnc_info)
return llm.ChatChunk(
request_id=request_id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
role="assistant",
tool_calls=[fnc_info],
),
index=chunk["contentBlockStop"]["contentBlockIndex"],
)
],
)
def _merge_messages(
messages: list[dict],
) -> list[dict]:
# Claude via Bedrock enforces alternating user/assistant messages, so merge consecutive same-role messages
combined_messages: list[dict] = []
for m in messages:
if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
combined_messages.append(m)
continue
last_message = combined_messages[-1]
if not isinstance(last_message["content"], list) or not isinstance(
m["content"], list
):
logger.error("message content is not a list")
continue
last_message["content"].extend(m["content"])
if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
return combined_messages
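# Minimal sketch of what _merge_messages() produces for Converse-style message dicts:
# consecutive same-role turns are folded into one content list, and a placeholder
# "(empty)" user turn is prepended when the history does not start with a user message.
#
#   _merge_messages([
#       {"role": "user", "content": [{"text": "hi"}]},
#       {"role": "user", "content": [{"text": "are you there?"}]},
#       {"role": "assistant", "content": [{"text": "yes"}]},
#   ])
#   -> [
#       {"role": "user", "content": [{"text": "hi"}, {"text": "are you there?"}]},
#       {"role": "assistant", "content": [{"text": "yes"}]},
#   ]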
def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in d.items() if v is not None}
import logging
logger = logging.getLogger("livekit.plugins.aws")
from typing import Literal
TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]
TTS_LANGUAGE = Literal[
"arb",
"cmn-CN",
"cy-GB",
"da-DK",
"de-DE",
"en-AU",
"en-GB",
"en-GB-WLS",
"en-IN",
"en-US",
"es-ES",
"es-MX",
"es-US",
"fr-CA",
"fr-FR",
"is-IS",
"it-IT",
"ja-JP",
"hi-IN",
"ko-KR",
"nb-NO",
"nl-NL",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"ru-RU",
"sv-SE",
"tr-TR",
"en-NZ",
"en-ZA",
"ca-ES",
"de-AT",
"yue-CN",
"ar-AE",
"fi-FI",
"en-IE",
"nl-BE",
"fr-BE",
"cs-CZ",
"de-CH",
]
TTS_OUTPUT_FORMAT = Literal["mp3"]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Optional
from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.model import Result, TranscriptEvent
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
stt,
utils,
)
from ._utils import _get_aws_credentials
from .log import logger
@dataclass
class STTOptions:
speech_region: str
sample_rate: int
language: str
encoding: str
vocabulary_name: Optional[str]
session_id: Optional[str]
vocab_filter_method: Optional[str]
vocab_filter_name: Optional[str]
show_speaker_label: Optional[bool]
enable_channel_identification: Optional[bool]
number_of_channels: Optional[int]
enable_partial_results_stabilization: Optional[bool]
partial_results_stability: Optional[str]
language_model_name: Optional[str]
class STT(stt.STT):
def __init__(
self,
*,
speech_region: str = "us-east-1",
api_key: str | None = None,
api_secret: str | None = None,
sample_rate: int = 48000,
language: str = "en-US",
encoding: str = "pcm",
vocabulary_name: Optional[str] = None,
session_id: Optional[str] = None,
vocab_filter_method: Optional[str] = None,
vocab_filter_name: Optional[str] = None,
show_speaker_label: Optional[bool] = None,
enable_channel_identification: Optional[bool] = None,
number_of_channels: Optional[int] = None,
enable_partial_results_stabilization: Optional[bool] = None,
partial_results_stability: Optional[str] = None,
language_model_name: Optional[str] = None,
):
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
)
self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, speech_region
)
self._config = STTOptions(
speech_region=speech_region,
language=language,
sample_rate=sample_rate,
encoding=encoding,
vocabulary_name=vocabulary_name,
session_id=session_id,
vocab_filter_method=vocab_filter_method,
vocab_filter_name=vocab_filter_name,
show_speaker_label=show_speaker_label,
enable_channel_identification=enable_channel_identification,
number_of_channels=number_of_channels,
enable_partial_results_stabilization=enable_partial_results_stabilization,
partial_results_stability=partial_results_stability,
language_model_name=language_model_name,
)
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
raise NotImplementedError(
"Amazon Transcribe does not support single frame recognition"
)
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
return SpeechStream(
stt=self,
conn_options=conn_options,
opts=self._config,
)
class SpeechStream(stt.SpeechStream):
def __init__(
self,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
self._opts = opts
self._client = TranscribeStreamingClient(region=self._opts.speech_region)
async def _run(self) -> None:
stream = await self._client.start_stream_transcription(
language_code=self._opts.language,
media_sample_rate_hz=self._opts.sample_rate,
media_encoding=self._opts.encoding,
vocabulary_name=self._opts.vocabulary_name,
session_id=self._opts.session_id,
vocab_filter_method=self._opts.vocab_filter_method,
vocab_filter_name=self._opts.vocab_filter_name,
show_speaker_label=self._opts.show_speaker_label,
enable_channel_identification=self._opts.enable_channel_identification,
number_of_channels=self._opts.number_of_channels,
enable_partial_results_stabilization=self._opts.enable_partial_results_stabilization,
partial_results_stability=self._opts.partial_results_stability,
language_model_name=self._opts.language_model_name,
)
@utils.log_exceptions(logger=logger)
async def input_generator():
async for frame in self._input_ch:
if isinstance(frame, rtc.AudioFrame):
await stream.input_stream.send_audio_event(
audio_chunk=frame.data.tobytes()
)
await stream.input_stream.end_stream()
@utils.log_exceptions(logger=logger)
async def handle_transcript_events():
async for event in stream.output_stream:
if isinstance(event, TranscriptEvent):
self._process_transcript_event(event)
tasks = [
asyncio.create_task(input_generator()),
asyncio.create_task(handle_transcript_events()),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
def _process_transcript_event(self, transcript_event: TranscriptEvent):
stream = transcript_event.transcript.results
for resp in stream:
if resp.start_time is not None and resp.start_time == 0.0:
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
)
if resp.end_time and resp.end_time > 0.0:
if resp.is_partial:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[
_streaming_recognize_response_to_speech_data(resp)
],
)
)
else:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
_streaming_recognize_response_to_speech_data(resp)
],
)
)
if not resp.is_partial:
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
)
def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
data = stt.SpeechData(
language="en-US",
start_time=resp.start_time if resp.start_time else 0.0,
end_time=resp.end_time if resp.end_time else 0.0,
confidence=0.0,
text=resp.alternatives[0].transcript if resp.alternatives else "",
)
return data
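# --- usage sketch (illustrative only, not part of the plugin API) ---
# Drives the streaming STT defined above end-to-end. It assumes AWS credentials are
# available in the environment and that `frames` is an iterable of rtc.AudioFrame
# objects (hypothetical here; in an agent they normally come from a room audio track).
async def _example_transcribe(frames: list[rtc.AudioFrame]) -> None:
    transcribe_stt = STT(sample_rate=48000, language="en-US")
    stream = transcribe_stt.stream()
    for frame in frames:
        stream.push_frame(frame)  # feed 16-bit PCM audio
    stream.end_input()
    async for event in stream:  # START_OF_SPEECH / INTERIM_TRANSCRIPT / FINAL_TRANSCRIPT / END_OF_SPEECH
        if event.alternatives:
            print(event.type, event.alternatives[0].text)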
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Callable, Optional
import aiohttp
from aiobotocore.session import AioSession, get_session
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from ._utils import _get_aws_credentials
from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
TTS_NUM_CHANNELS: int = 1
DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
DEFAULT_SPEECH_REGION = "us-east-1"
DEFAULT_VOICE = "Ruth"
DEFAULT_SAMPLE_RATE = 16000
@dataclass
class _TTSOptions:
# https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
voice: str | None
speech_engine: TTS_SPEECH_ENGINE
speech_region: str
sample_rate: int
language: TTS_LANGUAGE | str | None
class TTS(tts.TTS):
def __init__(
self,
*,
voice: str | None = DEFAULT_VOICE,
language: TTS_LANGUAGE | str | None = None,
speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
sample_rate: int = DEFAULT_SAMPLE_RATE,
speech_region: str = DEFAULT_SPEECH_REGION,
api_key: str | None = None,
api_secret: str | None = None,
session: AioSession | None = None,
) -> None:
"""
Create a new instance of AWS Polly TTS.
``api_key`` and ``api_secret`` must be set to your AWS access key ID and secret access key, either using the arguments or by setting the
``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.
See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on AWS Polly TTS.
Args:
voice (str, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
api_key(str, optional): AWS access key id.
api_secret(str, optional): AWS secret access key.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=sample_rate,
num_channels=TTS_NUM_CHANNELS,
)
self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, speech_region
)
self._opts = _TTSOptions(
voice=voice,
speech_engine=speech_engine,
speech_region=speech_region,
language=language,
sample_rate=sample_rate,
)
self._session = session or get_session()
def _get_client(self):
return self._session.create_client(
"polly",
region_name=self._opts.speech_region,
aws_access_key_id=self._api_key,
aws_secret_access_key=self._api_secret,
)
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
text=text,
conn_options=conn_options,
opts=self._opts,
get_client=self._get_client,
)
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
text: str,
conn_options: Optional[APIConnectOptions] = None,
opts: _TTSOptions,
get_client: Callable[[], Any],
) -> None:
super().__init__(tts=tts, input_text=text, conn_options=conn_options)
self._opts = opts
self._get_client = get_client
self._segment_id = utils.shortuuid()
async def _run(self):
request_id = utils.shortuuid()
try:
async with self._get_client() as client:
params = {
"Text": self._input_text,
"OutputFormat": "mp3",
"Engine": self._opts.speech_engine,
"VoiceId": self._opts.voice,
"TextType": "text",
"SampleRate": str(self._opts.sample_rate),
"LanguageCode": self._opts.language,
}
response = await client.synthesize_speech(**_strip_nones(params))
if "AudioStream" in response:
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=1,
)
# Create a task to push data to the decoder
async def push_data():
try:
async with response["AudioStream"] as resp:
async for data, _ in resp.content.iter_chunks():
decoder.push(data)
finally:
decoder.end_input()
# Start pushing data to the decoder
push_task = asyncio.create_task(push_data())
try:
# Create emitter and process decoded frames
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=self._segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
await push_task
finally:
await utils.aio.gracefully_cancel(push_task)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in d.items() if v is not None}
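# --- usage sketch (illustrative only, not part of the plugin API) ---
# Synthesizes a short sentence with the Polly TTS defined above, assuming AWS
# credentials are present in the environment; the resulting frames are simply
# counted instead of being published to a room.
async def _example_synthesize() -> None:
    polly = TTS(voice="Ruth", speech_engine="generative", sample_rate=16000)
    frame_count = 0
    async for _audio in polly.synthesize("Hello from Amazon Polly"):
        frame_count += 1  # each item is a tts.SynthesizedAudio carrying an rtc.AudioFrame
    print(f"received {frame_count} audio frames")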
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.1"
{
"name": "livekit-plugins-aws",
"private": true,
"version": "0.1.1"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "aws", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-aws",
version=about["__version__"],
description="LiveKit Agents Plugin for services from AWS",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit", "aws"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents[codecs]>=0.12.16,<1.0.0",
"aiobotocore==2.19.0",
"boto3==1.36.3",
"amazon-transcribe>=0.6.2",
],
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-azure
## 0.5.7
### Patch Changes
- add speech endpoint in azure ctor and azure speech sdk version upgrade - [#2007](https://github.com/livekit/agents/pull/2007) ([@jayeshp19](https://github.com/jayeshp19))
## 0.5.6
### Patch Changes
- Add callbacks as updatable Azure TTS options - [#1645](https://github.com/livekit/agents/pull/1645) ([@anishnag](https://github.com/anishnag))
## 0.5.5
### Patch Changes
- feat: Azure.STT support profanity_option - [#1540](https://github.com/livekit/agents/pull/1540) ([@shiftu](https://github.com/shiftu))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.5.4
### Patch Changes
- Add handlers for supported synthesis events for Azure TTS - [#1486](https://github.com/livekit/agents/pull/1486) ([@anishnag](https://github.com/anishnag))
## 0.5.3
### Patch Changes
- azure speech support all different configs - [#1362](https://github.com/livekit/agents/pull/1362) ([@longcw](https://github.com/longcw))
- reduces initial delay before model retries - [#1337](https://github.com/livekit/agents/pull/1337) ([@davidzhao](https://github.com/davidzhao))
## 0.5.2
### Patch Changes
- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao))
## 0.5.1
### Patch Changes
- fix azure stt language autodetection - [#1246](https://github.com/livekit/agents/pull/1246) ([@davidzhao](https://github.com/davidzhao))
## 0.5.0
### Minor Changes
- Improvements to end of turn plugin, ensure STT language settings. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao))
## 0.4.4
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.4.3
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
- azure: support auth entra token for TTS - [#1134](https://github.com/livekit/agents/pull/1134) ([@nfma](https://github.com/nfma))
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
## 0.4.2
### Patch Changes
- add support for azure speech containers - [#1043](https://github.com/livekit/agents/pull/1043) ([@longcw](https://github.com/longcw))
- fix azure sample_rate parameter - [#1072](https://github.com/livekit/agents/pull/1072) ([@theomonnom](https://github.com/theomonnom))
## 0.4.1
### Patch Changes
- add update_options to TTS - [#922](https://github.com/livekit/agents/pull/922) ([@theomonnom](https://github.com/theomonnom))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- azure tts: fix SSML Implementation by Adding <voice> Tag - [#929](https://github.com/livekit/agents/pull/929) ([@samirsalman](https://github.com/samirsalman))
- azure tts: fix Prosody Config Validation - [#918](https://github.com/livekit/agents/pull/918) ([@samirsalman](https://github.com/samirsalman))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
## 0.4.0
### Minor Changes
- Azure TTS Prosody SSML support #912 - [#914](https://github.com/livekit/agents/pull/914) ([@theomonnom](https://github.com/theomonnom))
## 0.3.2
### Patch Changes
- avoid returning tiny frames from TTS - [#747](https://github.com/livekit/agents/pull/747) ([@theomonnom](https://github.com/theomonnom))
## 0.3.1
### Patch Changes
- fix end_input not flushing & unhandled flush messages - [#528](https://github.com/livekit/agents/pull/528) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins Azure
Agent Framework plugin for services from Azure Cognitive Services. Currently supports STT and TTS.
## Installation
```bash
pip install livekit-plugins-azure
```

You’ll need to specify an Azure Speech Key and a Deployment Region. They can be set as environment variables: `AZURE_SPEECH_KEY` and `AZURE_SPEECH_REGION`, respectively.
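A minimal usage sketch (assuming the environment variables above are set; see the modules below for the full set of constructor options):

```python
from livekit.plugins import azure

stt = azure.STT()                           # reads AZURE_SPEECH_KEY / AZURE_SPEECH_REGION
tts = azure.TTS(voice="en-US-JennyNeural")  # voice name is just an example
```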
## livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/__init__.py
```py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .stt import STT, SpeechStream
from .tts import TTS
from .version import __version__
__all__ = ["STT", "SpeechStream", "TTS", "__version__"]
from livekit.agents import Plugin
from .log import logger
class AzurePlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(AzurePlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.azure")
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import os
import weakref
from copy import deepcopy
from dataclasses import dataclass
from livekit import rtc
from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
import azure.cognitiveservices.speech as speechsdk # type: ignore
from .log import logger
@dataclass
class STTOptions:
speech_key: str | None
speech_region: str | None
# see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-container-stt?tabs=container#use-the-container
speech_host: str | None
# for using Microsoft Entra auth (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-configure-azure-ad-auth?tabs=portal&pivots=programming-language-python)
speech_auth_token: str | None
sample_rate: int
num_channels: int
segmentation_silence_timeout_ms: int | None
segmentation_max_time_ms: int | None
segmentation_strategy: str | None
languages: list[
str
] # see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=stt
speech_endpoint: str | None = None
profanity: speechsdk.enums.ProfanityOption | None = None
class STT(stt.STT):
def __init__(
self,
*,
speech_key: str | None = None,
speech_region: str | None = None,
speech_host: str | None = None,
speech_endpoint: str | None = None,
speech_auth_token: str | None = None,
sample_rate: int = 16000,
num_channels: int = 1,
segmentation_silence_timeout_ms: int | None = None,
segmentation_max_time_ms: int | None = None,
segmentation_strategy: str | None = None,
# Azure supports multiple languages and can auto-detect which one is spoken, but it requires a candidate set of languages.
languages: list[str] = ["en-US"],
# for compatibility with other STT plugins
language: str | None = None,
profanity: speechsdk.enums.ProfanityOption | None = None,
):
"""
Create a new instance of Azure STT.
Either ``speech_host`` or ``speech_key`` and ``speech_region`` or
``speech_auth_token`` and ``speech_region`` must be set using arguments.
Alternatively, set the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY``
and ``AZURE_SPEECH_REGION`` environmental variables, respectively.
``speech_auth_token`` must be set using the arguments as it's an ephemeral token.
"""
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
)
speech_host = speech_host or os.environ.get("AZURE_SPEECH_HOST")
speech_key = speech_key or os.environ.get("AZURE_SPEECH_KEY")
speech_region = speech_region or os.environ.get("AZURE_SPEECH_REGION")
if not (
speech_host
or (speech_key and speech_region)
or (speech_auth_token and speech_region)
or (speech_endpoint and speech_key)
):
raise ValueError(
"AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set"
)
if speech_region and speech_endpoint:
logger.warning(
"speech_region and speech_endpoint are both set. Using speech_endpoint."
)
speech_region = None
if language:
languages = [language]
self._config = STTOptions(
speech_key=speech_key,
speech_region=speech_region,
speech_host=speech_host,
speech_auth_token=speech_auth_token,
speech_endpoint=speech_endpoint,
languages=languages,
sample_rate=sample_rate,
num_channels=num_channels,
segmentation_silence_timeout_ms=segmentation_silence_timeout_ms,
segmentation_max_time_ms=segmentation_max_time_ms,
segmentation_strategy=segmentation_strategy,
profanity=profanity,
)
self._streams = weakref.WeakSet[SpeechStream]()
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
raise NotImplementedError("Azure STT does not support single frame recognition")
def stream(
self,
*,
languages: list[str] | None = None,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = deepcopy(self._config)
if language and not languages:
languages = [language]
if languages:
config.languages = languages
stream = SpeechStream(stt=self, opts=config, conn_options=conn_options)
self._streams.add(stream)
return stream
def update_options(
self, *, language: str | None = None, languages: list[str] | None = None
):
if language and not languages:
languages = [language]
if languages is not None:
self._config.languages = languages
for stream in self._streams:
stream.update_options(languages=languages)
class SpeechStream(stt.SpeechStream):
def __init__(
self, *, stt: STT, opts: STTOptions, conn_options: APIConnectOptions
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
self._opts = opts
self._speaking = False
self._session_stopped_event = asyncio.Event()
self._session_started_event = asyncio.Event()
self._loop = asyncio.get_running_loop()
self._reconnect_event = asyncio.Event()
def update_options(
self, *, language: str | None = None, languages: list[str] | None = None
):
if language and not languages:
languages = [language]
if languages:
self._opts.languages = languages
self._reconnect_event.set()
async def _run(self) -> None:
while True:
self._stream = speechsdk.audio.PushAudioInputStream(
stream_format=speechsdk.audio.AudioStreamFormat(
samples_per_second=self._opts.sample_rate,
bits_per_sample=16,
channels=self._opts.num_channels,
)
)
self._recognizer = _create_speech_recognizer(
config=self._opts, stream=self._stream
)
self._recognizer.recognizing.connect(self._on_recognizing)
self._recognizer.recognized.connect(self._on_recognized)
self._recognizer.speech_start_detected.connect(self._on_speech_start)
self._recognizer.speech_end_detected.connect(self._on_speech_end)
self._recognizer.session_started.connect(self._on_session_started)
self._recognizer.session_stopped.connect(self._on_session_stopped)
self._recognizer.start_continuous_recognition()
try:
await asyncio.wait_for(
self._session_started_event.wait(), self._conn_options.timeout
)
async def process_input():
async for input in self._input_ch:
if isinstance(input, rtc.AudioFrame):
self._stream.write(input.data.tobytes())
process_input_task = asyncio.create_task(process_input())
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
done, _ = await asyncio.wait(
[process_input_task, wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
)
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(
process_input_task, wait_reconnect_task
)
self._stream.close()
await self._session_stopped_event.wait()
finally:
def _cleanup():
self._recognizer.stop_continuous_recognition()
del self._recognizer
await asyncio.to_thread(_cleanup)
def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs):
detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language
text = evt.result.text.strip()
if not text:
return
if not detected_lg and self._opts.languages:
detected_lg = self._opts.languages[0]
final_data = stt.SpeechData(
language=detected_lg, confidence=1.0, text=evt.result.text
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(
self._event_ch.send_nowait,
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives=[final_data]
),
)
def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs):
detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language
text = evt.result.text.strip()
if not text:
return
if not detected_lg and self._opts.languages:
detected_lg = self._opts.languages[0]
interim_data = stt.SpeechData(
language=detected_lg, confidence=0.0, text=evt.result.text
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(
self._event_ch.send_nowait,
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[interim_data],
),
)
def _on_speech_start(self, evt: speechsdk.SpeechRecognitionEventArgs):
if self._speaking:
return
self._speaking = True
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(
self._event_ch.send_nowait,
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH),
)
def _on_speech_end(self, evt: speechsdk.SpeechRecognitionEventArgs):
if not self._speaking:
return
self._speaking = False
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(
self._event_ch.send_nowait,
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH),
)
def _on_session_started(self, evt: speechsdk.SpeechRecognitionEventArgs):
self._session_started_event.set()
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._session_started_event.set)
def _on_session_stopped(self, evt: speechsdk.SpeechRecognitionEventArgs):
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._session_stopped_event.set)
def _create_speech_recognizer(
*, config: STTOptions, stream: speechsdk.audio.AudioInputStream
) -> speechsdk.SpeechRecognizer:
# let the SpeechConfig constructor validate the arguments
speech_config = speechsdk.SpeechConfig(
subscription=config.speech_key,
region=config.speech_region,
endpoint=config.speech_endpoint,
host=config.speech_host,
auth_token=config.speech_auth_token,
)
if config.segmentation_silence_timeout_ms:
speech_config.set_property(
speechsdk.enums.PropertyId.Speech_SegmentationSilenceTimeoutMs,
str(config.segmentation_silence_timeout_ms),
)
if config.segmentation_max_time_ms:
speech_config.set_property(
speechsdk.enums.PropertyId.Speech_SegmentationMaximumTimeMs,
str(config.segmentation_max_time_ms),
)
if config.segmentation_strategy:
speech_config.set_property(
speechsdk.enums.PropertyId.Speech_SegmentationStrategy,
str(config.segmentation_strategy),
)
if config.profanity is not None:
speech_config.set_profanity(config.profanity)
auto_detect_source_language_config = None
if config.languages and len(config.languages) >= 1:
auto_detect_source_language_config = (
speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
languages=config.languages
)
)
audio_config = speechsdk.audio.AudioConfig(stream=stream)
speech_recognizer = speechsdk.SpeechRecognizer(
speech_config=speech_config,
audio_config=audio_config,
auto_detect_source_language_config=auto_detect_source_language_config, # type: ignore
)
return speech_recognizer
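# --- usage sketch (illustrative only, not part of the plugin API) ---
# Shows the multi-language auto-detection described above, assuming AZURE_SPEECH_KEY
# and AZURE_SPEECH_REGION are set in the environment. Changing the candidate languages
# through update_options() propagates to open streams and triggers a reconnect.
async def _example_language_switch() -> None:
    azure_stt = STT(languages=["en-US", "fr-FR"])  # candidate set for auto-detection
    stream = azure_stt.stream()
    # ...push audio frames into `stream` and consume its events here...
    azure_stt.update_options(languages=["de-DE"])  # switches all open streams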
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import os
from dataclasses import dataclass
from typing import Callable, Literal, Optional
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APITimeoutError,
tts,
utils,
)
import azure.cognitiveservices.speech as speechsdk # type: ignore
from .log import logger
# only raw & pcm
SUPPORTED_SAMPLE_RATE = {
8000: speechsdk.SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm,
16000: speechsdk.SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm,
22050: speechsdk.SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm,
24000: speechsdk.SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm,
44100: speechsdk.SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm,
48000: speechsdk.SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm,
}
@dataclass
class ProsodyConfig:
"""
Prosody configuration for Azure TTS.
Args:
rate: Speaking rate. Can be one of "x-slow", "slow", "medium", "fast", "x-fast", or a float. A float value of 1.0 represents normal speed.
volume: Speaking volume. Can be one of "silent", "x-soft", "soft", "medium", "loud", "x-loud", or a float. A float value of 100 (x-loud) represents the highest volume and is the default.
pitch: Speaking pitch. Can be one of "x-low", "low", "medium", "high", "x-high". The default pitch is "medium".
"""
rate: Literal["x-slow", "slow", "medium", "fast", "x-fast"] | float | None = None
volume: (
Literal["silent", "x-soft", "soft", "medium", "loud", "x-loud"] | float | None
) = None
pitch: Literal["x-low", "low", "medium", "high", "x-high"] | None = None
def validate(self) -> None:
if self.rate:
if isinstance(self.rate, float) and not 0.5 <= self.rate <= 2:
raise ValueError("Prosody rate must be between 0.5 and 2")
if isinstance(self.rate, str) and self.rate not in [
"x-slow",
"slow",
"medium",
"fast",
"x-fast",
]:
raise ValueError(
"Prosody rate must be one of 'x-slow', 'slow', 'medium', 'fast', 'x-fast'"
)
if self.volume:
if isinstance(self.volume, float) and not 0 <= self.volume <= 100:
raise ValueError("Prosody volume must be between 0 and 100")
if isinstance(self.volume, str) and self.volume not in [
"silent",
"x-soft",
"soft",
"medium",
"loud",
"x-loud",
]:
raise ValueError(
"Prosody volume must be one of 'silent', 'x-soft', 'soft', 'medium', 'loud', 'x-loud'"
)
if self.pitch and self.pitch not in [
"x-low",
"low",
"medium",
"high",
"x-high",
]:
raise ValueError(
"Prosody pitch must be one of 'x-low', 'low', 'medium', 'high', 'x-high'"
)
def __post_init__(self):
self.validate()
@dataclass
class StyleConfig:
"""
Style configuration for Azure TTS neural voices.
Args:
style: Speaking style for neural voices. Examples: "cheerful", "sad", "angry", etc.
degree: Intensity of the style, from 0.1 to 2.0.
"""
style: str
degree: float | None = None
def validate(self) -> None:
if self.degree is not None and not 0.1 <= self.degree <= 2.0:
raise ValueError("Style degree must be between 0.1 and 2.0")
def __post_init__(self):
self.validate()
@dataclass
class _TTSOptions:
sample_rate: int
speech_key: str | None = None
speech_region: str | None = None
# see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-container-ntts?tabs=container#use-the-container
speech_host: str | None = None
# see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts
voice: str | None = None
# for using custom voices (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis?tabs=browserjs%2Cterminal&pivots=programming-language-python#use-a-custom-endpoint)
endpoint_id: str | None = None
# for using Microsoft Entra auth (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-configure-azure-ad-auth?tabs=portal&pivots=programming-language-python)
speech_auth_token: str | None = None
# Useful to specify the language with multi-language voices
language: str | None = None
# See https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice#adjust-prosody
prosody: ProsodyConfig | None = None
speech_endpoint: str | None = None
style: StyleConfig | None = None
# See https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis?tabs=browserjs%2Cterminal&pivots=programming-language-python
on_bookmark_reached_event: Callable | None = None
on_synthesis_canceled_event: Callable | None = None
on_synthesis_completed_event: Callable | None = None
on_synthesis_started_event: Callable | None = None
on_synthesizing_event: Callable | None = None
on_viseme_event: Callable | None = None
on_word_boundary_event: Callable | None = None
class TTS(tts.TTS):
def __init__(
self,
*,
sample_rate: int = 24000,
voice: str | None = None,
language: str | None = None,
prosody: ProsodyConfig | None = None,
speech_key: str | None = None,
speech_region: str | None = None,
speech_host: str | None = None,
speech_auth_token: str | None = None,
endpoint_id: str | None = None,
style: StyleConfig | None = None,
on_bookmark_reached_event: Callable | None = None,
on_synthesis_canceled_event: Callable | None = None,
on_synthesis_completed_event: Callable | None = None,
on_synthesis_started_event: Callable | None = None,
on_synthesizing_event: Callable | None = None,
on_viseme_event: Callable | None = None,
on_word_boundary_event: Callable | None = None,
) -> None:
"""
Create a new instance of Azure TTS.
Either ``speech_host`` or ``speech_key`` and ``speech_region`` or
``speech_auth_token`` and ``speech_region`` must be set using arguments.
Alternatively, set the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY``
and ``AZURE_SPEECH_REGION`` environmental variables, respectively.
``speech_auth_token`` must be set using the arguments as it's an ephemeral token.
"""
if sample_rate not in SUPPORTED_SAMPLE_RATE:
raise ValueError(
f"Unsupported sample rate {sample_rate}. Supported sample rates: {list(SUPPORTED_SAMPLE_RATE.keys())}"
)
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=sample_rate,
num_channels=1,
)
speech_host = speech_host or os.environ.get("AZURE_SPEECH_HOST")
speech_key = speech_key or os.environ.get("AZURE_SPEECH_KEY")
speech_region = speech_region or os.environ.get("AZURE_SPEECH_REGION")
if not (
speech_host
or (speech_key and speech_region)
or (speech_auth_token and speech_region)
):
raise ValueError(
"AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION must be set"
)
if prosody:
prosody.validate()
if style:
style.validate()
self._opts = _TTSOptions(
sample_rate=sample_rate,
speech_key=speech_key,
speech_region=speech_region,
speech_host=speech_host,
speech_auth_token=speech_auth_token,
voice=voice,
endpoint_id=endpoint_id,
language=language,
prosody=prosody,
style=style,
on_bookmark_reached_event=on_bookmark_reached_event,
on_synthesis_canceled_event=on_synthesis_canceled_event,
on_synthesis_completed_event=on_synthesis_completed_event,
on_synthesis_started_event=on_synthesis_started_event,
on_synthesizing_event=on_synthesizing_event,
on_viseme_event=on_viseme_event,
on_word_boundary_event=on_word_boundary_event,
)
def update_options(
self,
*,
voice: str | None = None,
language: str | None = None,
prosody: ProsodyConfig | None = None,
style: StyleConfig | None = None,
on_bookmark_reached_event: Callable | None = None,
on_synthesis_canceled_event: Callable | None = None,
on_synthesis_completed_event: Callable | None = None,
on_synthesis_started_event: Callable | None = None,
on_synthesizing_event: Callable | None = None,
on_viseme_event: Callable | None = None,
on_word_boundary_event: Callable | None = None,
) -> None:
self._opts.voice = voice or self._opts.voice
self._opts.language = language or self._opts.language
self._opts.prosody = prosody or self._opts.prosody
self._opts.style = style or self._opts.style
self._opts.on_bookmark_reached_event = (
on_bookmark_reached_event or self._opts.on_bookmark_reached_event
)
self._opts.on_synthesis_canceled_event = (
on_synthesis_canceled_event or self._opts.on_synthesis_canceled_event
)
self._opts.on_synthesis_completed_event = (
on_synthesis_completed_event or self._opts.on_synthesis_completed_event
)
self._opts.on_synthesis_started_event = (
on_synthesis_started_event or self._opts.on_synthesis_started_event
)
self._opts.on_synthesizing_event = (
on_synthesizing_event or self._opts.on_synthesizing_event
)
self._opts.on_viseme_event = on_viseme_event or self._opts.on_viseme_event
self._opts.on_word_boundary_event = (
on_word_boundary_event or self._opts.on_word_boundary_event
)
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self, input_text=text, conn_options=conn_options, opts=self._opts
)
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
async def _run(self):
stream_callback = speechsdk.audio.PushAudioOutputStream(
_PushAudioOutputStreamCallback(
self._opts, asyncio.get_running_loop(), self._event_ch
)
)
synthesizer = _create_speech_synthesizer(
config=self._opts,
stream=stream_callback,
)
def _synthesize() -> speechsdk.SpeechSynthesisResult:
if self._opts.prosody or self._opts.style:
ssml = (
'<speak version="1.0" '
'xmlns="http://www.w3.org/2001/10/synthesis" '
'xmlns:mstts="http://www.w3.org/2001/mstts" '
f'xml:lang="{self._opts.language or "en-US"}">'
)
ssml += f'<voice name="{self._opts.voice}">'
# Add style if specified
if self._opts.style:
style_degree = (
f' styledegree="{self._opts.style.degree}"'
if self._opts.style.degree
else ""
)
ssml += f'<mstts:express-as style="{self._opts.style.style}"{style_degree}>'
# Add prosody if specified
if self._opts.prosody:
ssml += "<prosody"
if self._opts.prosody.rate:
ssml += f' rate="{self._opts.prosody.rate}"'
if self._opts.prosody.volume:
ssml += f' volume="{self._opts.prosody.volume}"'
if self._opts.prosody.pitch:
ssml += f' pitch="{self._opts.prosody.pitch}"'
ssml += ">"
ssml += self._input_text
ssml += "</prosody>"
else:
ssml += self._input_text
# Close style tag if it was opened
if self._opts.style:
ssml += "</mstts:express-as>"
ssml += "</voice></speak>"
return synthesizer.speak_ssml_async(ssml).get() # type: ignore
return synthesizer.speak_text_async(self.input_text).get() # type: ignore
result = None
try:
result = await asyncio.to_thread(_synthesize)
if result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
if (
result.cancellation_details.error_code
== speechsdk.CancellationErrorCode.ServiceTimeout
):
raise APITimeoutError()
else:
cancel_details = result.cancellation_details
raise APIConnectionError(cancel_details.error_details)
finally:
def _cleanup() -> None:
# cleanup resources inside an Executor
# to avoid blocking the event loop
nonlocal synthesizer, stream_callback, result
del synthesizer
del stream_callback
if result is not None:
del result
try:
await asyncio.to_thread(_cleanup)
except Exception:
logger.exception("failed to cleanup Azure TTS resources")
class _PushAudioOutputStreamCallback(speechsdk.audio.PushAudioOutputStreamCallback):
def __init__(
self,
opts: _TTSOptions,
loop: asyncio.AbstractEventLoop,
event_ch: utils.aio.ChanSender[tts.SynthesizedAudio],
):
super().__init__()
self._event_ch = event_ch
self._opts = opts
self._loop = loop
self._request_id = utils.shortuuid()
self._bstream = utils.audio.AudioByteStream(
sample_rate=opts.sample_rate, num_channels=1
)
def write(self, audio_buffer: memoryview) -> int:
for frame in self._bstream.write(audio_buffer.tobytes()):
audio = tts.SynthesizedAudio(
request_id=self._request_id,
frame=frame,
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
return audio_buffer.nbytes
def close(self) -> None:
for frame in self._bstream.flush():
audio = tts.SynthesizedAudio(
request_id=self._request_id,
frame=frame,
)
with contextlib.suppress(RuntimeError):
self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
def _create_speech_synthesizer(
*, config: _TTSOptions, stream: speechsdk.audio.AudioOutputStream
) -> speechsdk.SpeechSynthesizer:
# let the SpeechConfig constructor validate the arguments
speech_config = speechsdk.SpeechConfig(
subscription=config.speech_key,
region=config.speech_region,
endpoint=config.speech_endpoint,
host=config.speech_host,
auth_token=config.speech_auth_token,
speech_recognition_language=config.language or "en-US",
)
speech_config.set_speech_synthesis_output_format(
SUPPORTED_SAMPLE_RATE[config.sample_rate]
)
stream_config = speechsdk.audio.AudioOutputConfig(stream=stream)
if config.voice is not None:
speech_config.speech_synthesis_voice_name = config.voice
if config.endpoint_id is not None:
speech_config.endpoint_id = config.endpoint_id
synthesizer = speechsdk.SpeechSynthesizer(
speech_config=speech_config, audio_config=stream_config
)
if config.on_bookmark_reached_event:
synthesizer.bookmark_reached.connect(config.on_bookmark_reached_event)
if config.on_synthesis_canceled_event:
synthesizer.synthesis_canceled.connect(config.on_synthesis_canceled_event)
if config.on_synthesis_completed_event:
synthesizer.synthesis_completed.connect(config.on_synthesis_completed_event)
if config.on_synthesis_started_event:
synthesizer.synthesis_started.connect(config.on_synthesis_started_event)
if config.on_synthesizing_event:
synthesizer.synthesizing.connect(config.on_synthesizing_event)
if config.on_viseme_event:
synthesizer.viseme_received.connect(config.on_viseme_event)
if config.on_word_boundary_event:
synthesizer.synthesis_word_boundary.connect(config.on_word_boundary_event)
return synthesizer
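# --- usage sketch (illustrative only, not part of the plugin API) ---
# Demonstrates the ProsodyConfig/StyleConfig options documented above, assuming
# AZURE_SPEECH_KEY and AZURE_SPEECH_REGION are set in the environment. The voice and
# style names are examples; styles are only honored by voices that support them.
async def _example_prosody_synthesis() -> None:
    azure_tts = TTS(
        voice="en-US-JennyNeural",
        prosody=ProsodyConfig(rate="fast", volume="loud", pitch="high"),
        style=StyleConfig(style="cheerful", degree=1.5),
    )
    async for _audio in azure_tts.synthesize("Hello from Azure TTS"):
        pass  # each item is a tts.SynthesizedAudio carrying an rtc.AudioFrame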
# Copyright 2024 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.5.7"
{
"name": "livekit-plugins-azure",
"private": true,
"version": "0.5.7"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "azure", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-azure",
version=about["__version__"],
description="Agent Framework plugin for services from Azure",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
"azure-cognitiveservices-speech>=1.43.0",
],
package_data={},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# Defines the Chromium style for automatic reformatting.
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
BasedOnStyle: Chromium
---
Language: ObjC
BasedOnStyle: Google
BinPackParameters: false
BinPackArguments: false
ColumnLimit: 100
ObjCBlockIndentWidth: 2
AllowAllParametersOfDeclarationOnNextLine: true
AlignOperands: false
AlwaysBreakBeforeMultilineStrings: false
AllowShortFunctionsOnASingleLine: Inline
BreakBeforeTernaryOperators: false
IndentWrappedFunctionNames: true
ContinuationIndentWidth: 4
ObjCSpaceBeforeProtocolList: true
---
Language: Cpp
IncludeBlocks: Regroup
# livekit-plugins-browser
## 0.0.6
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.0.5
### Patch Changes
- fix: fix `imgui` setup - [#1226](https://github.com/livekit/agents/pull/1226) ([@mbukeRepo](https://github.com/mbukeRepo))
## 0.0.4
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.0.3
### Patch Changes
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
## 0.0.2
### Patch Changes
- livekit-plugins-browser: prepare for release - [#659](https://github.com/livekit/agents/pull/659) ([@theomonnom](https://github.com/theomonnom))
cmake_minimum_required(VERSION 3.19)
set(CMAKE_CONFIGURATION_TYPES Debug Release)
project(livekit-cef)
set_property(GLOBAL PROPERTY OS_FOLDERS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # useful for clangd as the language server
set(USE_SANDBOX OFF) # TODO(theomonnom): I don't think we want to enable sandbox
# for now, it adds complexity
# Specify the CEF distribution version.
if(NOT DEFINED CEF_VERSION)
# set(CEF_VERSION "122.1.10+gc902316+chromium-122.0.6261.112")
set(CEF_VERSION "127.3.5+g114ea2a+chromium-127.0.6533.120")
endif()
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
if("${PROJECT_ARCH}" STREQUAL "arm64")
set(CEF_PLATFORM "macosarm64")
elseif("${PROJECT_ARCH}" STREQUAL "x86_64")
set(CEF_PLATFORM "macosx64")
elseif("${CMAKE_HOST_SYSTEM_PROCESSOR}" STREQUAL "arm64")
set(PROJECT_ARCH "arm64")
set(CEF_PLATFORM "macosarm64")
else()
set(PROJECT_ARCH "x86_64")
set(CEF_PLATFORM "macosx64")
endif()
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
if(CMAKE_SIZEOF_VOID_P MATCHES 8)
set(CEF_PLATFORM "linux64")
else()
set(CEF_PLATFORM "linux32")
endif()
elseif("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
if(CMAKE_SIZEOF_VOID_P MATCHES 8)
set(CEF_PLATFORM "windows64")
else()
set(CEF_PLATFORM "windows32")
endif()
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
# Download and extract the CEF binary distribution (executes DownloadCEF.cmake).
include(DownloadCEF)
downloadcef("${CEF_PLATFORM}" "${CEF_VERSION}"
"${CMAKE_SOURCE_DIR}/third_party/cef")
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CEF_ROOT}/cmake")
# Load the CEF configuration (executes FindCEF.cmake).
find_package(CEF REQUIRED)
# Python
find_package(PythonInterp REQUIRED)
find_package(pybind11 REQUIRED)
message(STATUS "Using Python: ${PYTHON_EXECUTABLE}")
add_subdirectory(${CEF_LIBCEF_DLL_WRAPPER_PATH} libcef_dll_wrapper)
add_subdirectory(src)
print_cef_config()
// Copyright (c) 2008-2016 Marshall A. Greenblatt. Portions Copyright (c)
// 2006-2009 Google Inc. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the name Chromium Embedded
// Framework nor the names of its contributors may be used to endorse
// or promote products derived from this software without specific prior
// written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# LiveKit Plugins Browser
Chromium Embedded Framework (CEF) for LiveKit Agents
# Copyright (c) 2016 The Chromium Embedded Framework Authors. All rights
# reserved. Use of this source code is governed by a BSD-style license that
# can be found in the LICENSE file.
# Download the CEF binary distribution for |platform| and |version| to
# |download_dir|. The |CEF_ROOT| variable will be set in global scope pointing
# to the extracted location.
# Visit https://cef-builds.spotifycdn.com/index.html for the list of
# supported platforms and versions.
function(DownloadCEF platform version download_dir)
# Specify the binary distribution type and download directory.
set(CEF_DISTRIBUTION "cef_binary_${version}_${platform}")
set(CEF_DOWNLOAD_DIR "${download_dir}")
# The location where we expect the extracted binary distribution.
set(CEF_ROOT "${CEF_DOWNLOAD_DIR}/${CEF_DISTRIBUTION}" CACHE INTERNAL "CEF_ROOT")
# Download and/or extract the binary distribution if necessary.
if(NOT IS_DIRECTORY "${CEF_ROOT}")
set(CEF_DOWNLOAD_FILENAME "${CEF_DISTRIBUTION}.tar.bz2")
set(CEF_DOWNLOAD_PATH "${CEF_DOWNLOAD_DIR}/${CEF_DOWNLOAD_FILENAME}")
if(NOT EXISTS "${CEF_DOWNLOAD_PATH}")
set(CEF_DOWNLOAD_URL "https://cef-builds.spotifycdn.com/${CEF_DOWNLOAD_FILENAME}")
string(REPLACE "+" "%2B" CEF_DOWNLOAD_URL_ESCAPED ${CEF_DOWNLOAD_URL})
# Download the SHA1 hash for the binary distribution.
message(STATUS "Downloading ${CEF_DOWNLOAD_PATH}.sha1 from ${CEF_DOWNLOAD_URL_ESCAPED}...")
file(DOWNLOAD "${CEF_DOWNLOAD_URL_ESCAPED}.sha1" "${CEF_DOWNLOAD_PATH}.sha1")
file(READ "${CEF_DOWNLOAD_PATH}.sha1" CEF_SHA1)
# Download the binary distribution and verify the hash.
message(STATUS "Downloading ${CEF_DOWNLOAD_PATH}...")
file(
DOWNLOAD "${CEF_DOWNLOAD_URL_ESCAPED}" "${CEF_DOWNLOAD_PATH}"
EXPECTED_HASH SHA1=${CEF_SHA1}
SHOW_PROGRESS
)
endif()
# Extract the binary distribution.
message(STATUS "Extracting ${CEF_DOWNLOAD_PATH}...")
execute_process(
COMMAND ${CMAKE_COMMAND} -E tar xzf "${CEF_DOWNLOAD_DIR}/${CEF_DOWNLOAD_FILENAME}"
WORKING_DIRECTORY ${CEF_DOWNLOAD_DIR}
)
endif()
endfunction()
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from livekit.agents import Plugin
from .log import logger
from .proc import BrowserContext, BrowserPage
from .version import __version__
__all__ = ["BrowserContext", "BrowserPage"]
class BrowserPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(BrowserPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.browser")
from __future__ import annotations
import asyncio
import contextlib
import multiprocessing as mp
import multiprocessing.context as mpc
import multiprocessing.shared_memory as mp_shm
import socket
import tempfile
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Callable, Literal
from livekit import rtc
from livekit.agents import ipc, utils
from . import logger, proc_main, proto
@dataclass
class _PageOptions:
page_id: int
url: str
width: int
height: int
framerate: int
EventTypes = Literal["paint"]
@dataclass
class PaintData:
dirty_rects: list[tuple[int, int, int, int]]
frame: rtc.VideoFrame
width: int
height: int
@dataclass
class BrowserOptions:
url: str
framerate: int
width: int
height: int
paint_callback: Callable[[PaintData], None]
class BrowserPage(utils.EventEmitter[EventTypes]):
def __init__(
self,
mp_ctx: mpc.SpawnContext,
opts: _PageOptions,
ctx_duplex: utils.aio.duplex_unix._AsyncDuplex,
) -> None:
super().__init__()
self._mp_ctx = mp_ctx
self._opts = opts
self._ctx_duplex = ctx_duplex
self._view_width = 0
self._view_height = 0
self._created_fut = asyncio.Future()
self._close_fut = asyncio.Future()
@property
def id(self) -> int:
return self._opts.page_id
async def start(self) -> None:
shm_name = f"lkcef_browser_{utils.shortuuid()}"
self._shm = mp_shm.SharedMemory(
create=True,
size=proto.SHM_MAX_WIDTH * proto.SHM_MAX_HEIGHT * 4,
name=shm_name,
)
self._framebuffer = rtc.VideoFrame(
proto.SHM_MAX_WIDTH,
proto.SHM_MAX_HEIGHT,
rtc.VideoBufferType.BGRA,
bytearray(proto.SHM_MAX_WIDTH * proto.SHM_MAX_HEIGHT * 4),
)
req = proto.CreateBrowserRequest(
page_id=self._opts.page_id,
width=self._opts.width,
height=self._opts.height,
shm_name=shm_name,
url=self._opts.url,
framerate=self._opts.framerate,
)
await ipc.channel.asend_message(self._ctx_duplex, req)
        # TODO(theomonnom): add a timeout so this future cannot hang forever if the
        # browser process crashes
await asyncio.shield(self._created_fut)
async def aclose(self) -> None:
await ipc.channel.asend_message(
self._ctx_duplex, proto.CloseBrowserRequest(page_id=self.id)
)
await asyncio.shield(self._close_fut)
self._shm.unlink()
self._shm.close()
async def _handle_created(self, msg: proto.CreateBrowserResponse) -> None:
self._created_fut.set_result(None)
async def _handle_paint(self, acq: proto.AcquirePaintData) -> None:
old_width = self._view_width
old_height = self._view_height
self._view_width = acq.width
self._view_height = acq.height
# TODO(theomonnom): remove hacky alloc-free resizing
self._framebuffer._width = acq.width
self._framebuffer._height = acq.height
proto.copy_paint_data(
acq, old_width, old_height, self._shm.buf, self._framebuffer.data
)
paint_data = PaintData(
dirty_rects=acq.dirty_rects,
frame=self._framebuffer,
width=acq.width,
height=acq.height,
)
self.emit("paint", paint_data)
release_paint = proto.ReleasePaintData(page_id=acq.page_id)
await ipc.channel.asend_message(self._ctx_duplex, release_paint)
async def _handle_close(self, msg: proto.BrowserClosed) -> None:
logger.debug("browser page closed", extra={"page_id": self.id})
self._close_fut.set_result(None)
class BrowserContext:
def __init__(self, *, dev_mode: bool, remote_debugging_port: int = 0) -> None:
self._mp_ctx = mp.get_context("spawn")
self._pages: dict[int, BrowserPage] = {}
self._dev_mode = dev_mode
self._initialized = False
self._next_page_id = 1
self._remote_debugging_port = remote_debugging_port
async def initialize(self) -> None:
mp_pch, mp_cch = socket.socketpair()
self._duplex = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch)
self._proc = self._mp_ctx.Process(target=proc_main.main, args=(mp_cch,))
self._proc.start()
mp_cch.close()
if not self._remote_debugging_port:
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._remote_debugging_port = s.getsockname()[1]
logger.debug("using remote debugging port %d", self._remote_debugging_port)
await ipc.channel.asend_message(
self._duplex,
proto.InitializeContextRequest(
dev_mode=self._dev_mode,
remote_debugging_port=self._remote_debugging_port,
root_cache_path=tempfile.mkdtemp(), # TODO(theomonnom): cleanup
),
)
resp = await ipc.channel.arecv_message(self._duplex, proto.IPC_MESSAGES)
assert isinstance(resp, proto.ContextInitializedResponse)
self._initialized = True
logger.debug("browser context initialized", extra={"pid": self._proc.pid})
self._main_atask = asyncio.create_task(self._main_task(self._duplex))
@asynccontextmanager
async def playwright(self, timeout: float | None = None):
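        # Attach a Playwright client to the embedded Chromium over CDP on the remote
        # debugging port. Illustrative usage (names outside this module are assumed):
        #
        #   async with browser_ctx.playwright() as browser:
        #       cdp_ctx = browser.contexts[0]
        #       ...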
if not self._initialized:
raise RuntimeError("BrowserContext not initialized")
from playwright.async_api import async_playwright
async with async_playwright() as p:
url = f"http://localhost:{self._remote_debugging_port}"
browser = await p.chromium.connect_over_cdp(url, timeout=timeout)
try:
yield browser
finally:
await browser.close()
@utils.log_exceptions(logger)
async def _main_task(self, duplex: utils.aio.duplex_unix._AsyncDuplex) -> None:
while True:
try:
msg = await ipc.channel.arecv_message(duplex, proto.IPC_MESSAGES)
except utils.aio.duplex_unix.DuplexClosed:
break
if isinstance(msg, proto.CreateBrowserResponse):
page = self._pages[msg.page_id]
await page._handle_created(msg)
elif isinstance(msg, proto.AcquirePaintData):
page = self._pages[msg.page_id]
await page._handle_paint(msg)
elif isinstance(msg, proto.BrowserClosed):
page = self._pages[msg.page_id]
await page._handle_close(msg)
async def new_page(
self, *, url: str, width: int = 800, height: int = 600, framerate: int = 30
) -> BrowserPage:
if not self._initialized:
raise RuntimeError("BrowserContext not initialized")
page_id = self._next_page_id
self._next_page_id += 1
page = BrowserPage(
self._mp_ctx,
_PageOptions(
page_id=page_id,
url=url,
width=width,
height=height,
framerate=framerate,
),
self._duplex,
)
self._pages[page_id] = page
await page.start()
return page
import importlib.resources
import multiprocessing.shared_memory as mp_shm
import socket
import threading
from livekit.agents import ipc, utils
from . import logger, proto
class BrowserServer:
def __init__(
self,
duplex: utils.aio.duplex_unix._Duplex,
shm: mp_shm.SharedMemory,
page_id: int,
):
self._duplex = duplex
self._shm = shm
self._page_id = page_id
self._view_width = 0
self._view_height = 0
self._closing = False
self._release_paint_e = threading.Event()
@staticmethod
def create(
*,
duplex: utils.aio.duplex_unix._Duplex,
create_req: proto.CreateBrowserRequest,
browser_app,
) -> "BrowserServer":
logger.debug(
"creating browser",
extra={
"page_id": create_req.page_id,
"url": create_req.url,
"framerate": create_req.framerate,
"width": create_req.width,
"height": create_req.height,
"shm_name": create_req.shm_name,
},
)
import lkcef_python as lkcef
opts = lkcef.BrowserOptions()
opts.framerate = create_req.framerate
opts.width = create_req.width
opts.height = create_req.height
shm = mp_shm.SharedMemory(name=create_req.shm_name)
bserver = BrowserServer(duplex, shm, create_req.page_id)
opts.created_callback = bserver._browser_created
opts.paint_callback = bserver._paint
opts.close_callback = bserver._closed
browser_app.create_browser(create_req.url, opts)
return bserver
def _browser_created(self, impl):
browser_id = impl.identifier()
logger.debug(
"browser created",
extra={"browser_id": browser_id, "page_id": self._page_id},
)
self._impl = impl
try:
ipc.channel.send_message(
self._duplex,
proto.CreateBrowserResponse(
page_id=self._page_id, browser_id=browser_id
),
)
except utils.aio.duplex_unix.DuplexClosed:
logger.exception("failed to send CreateBrowserResponse")
def _paint(self, frame_data):
if self._closing:
            return  # don't touch the shared memory once the page is closing
acq = proto.AcquirePaintData()
acq.page_id = self._page_id
acq.width = frame_data.width
acq.height = frame_data.height
dirty_rects = []
for rect in frame_data.dirty_rects:
dirty_rects.append((rect.x, rect.y, rect.width, rect.height))
acq.dirty_rects = dirty_rects
old_width = self._view_width
old_height = self._view_height
self._view_width = frame_data.width
self._view_height = frame_data.height
proto.copy_paint_data(
acq, old_width, old_height, frame_data.buffer, self._shm.buf
)
try:
ipc.channel.send_message(self._duplex, acq)
self._release_paint_e.wait() # wait for release
self._release_paint_e.clear()
except utils.aio.duplex_unix.DuplexClosed:
logger.exception("failed to send AcquirePaintData")
def _closed(self) -> None:
ipc.channel.send_message(
self._duplex, proto.BrowserClosed(page_id=self._page_id)
)
def handle_release_paint(self, msg: proto.ReleasePaintData):
self._release_paint_e.set()
def handle_close(self, msg: proto.CloseBrowserRequest):
self._closing = True
self._impl.close()
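# Runs on a dedicated thread inside the browser subprocess: dispatches IPC requests
# from the parent process to the per-page BrowserServer instances.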
def _manager_thread(duplex: utils.aio.duplex_unix._Duplex, browser_app):
browsers: dict[int, BrowserServer] = {}
while True:
try:
msg = ipc.channel.recv_message(duplex, proto.IPC_MESSAGES)
except utils.aio.duplex_unix.DuplexClosed:
break
if isinstance(msg, proto.CreateBrowserRequest):
server = BrowserServer.create(
duplex=duplex, create_req=msg, browser_app=browser_app
)
browsers[msg.page_id] = server
elif isinstance(msg, proto.ReleasePaintData):
server = browsers[msg.page_id]
server.handle_release_paint(msg)
elif isinstance(msg, proto.CloseBrowserRequest):
server = browsers[msg.page_id]
server.handle_close(msg)
del browsers[msg.page_id]
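# Entrypoint of the spawned browser subprocess: performs the IPC handshake, builds
# the lkcef BrowserApp, and runs its message loop on the main thread while the
# manager thread above services per-page requests.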
def main(mp_cch: socket.socket):
import lkcef_python as lkcef
duplex = utils.aio.duplex_unix._Duplex.open(mp_cch)
init_req = ipc.channel.recv_message(duplex, proto.IPC_MESSAGES)
assert isinstance(init_req, proto.InitializeContextRequest)
logger.debug("initializing browser context", extra={"dev_mode": init_req.dev_mode})
def _context_initialized():
try:
ipc.channel.send_message(duplex, proto.ContextInitializedResponse())
except utils.aio.duplex_unix.DuplexClosed:
logger.exception("failed to send ContextInitializedResponse")
opts = lkcef.AppOptions()
opts.dev_mode = init_req.dev_mode
opts.remote_debugging_port = init_req.remote_debugging_port
opts.root_cache_path = init_req.root_cache_path
opts.initialized_callback = _context_initialized
res = (
importlib.resources.files("livekit.plugins.browser.resources") / "lkcef_app.app"
)
with importlib.resources.as_file(res) as path:
opts.framework_path = str(
path / "Contents" / "Frameworks" / "Chromium Embedded Framework.framework"
)
opts.main_bundle_path = str(path)
opts.subprocess_path = str(
path
/ "Contents"
/ "Frameworks"
/ "lkcef Helper.app"
/ "Contents"
/ "MacOS"
/ "lkcef Helper"
)
app = lkcef.BrowserApp(opts)
man_t = threading.Thread(target=_manager_thread, args=(duplex, app))
man_t.start()
app.run() # run indefinitely
import io
from dataclasses import dataclass, field
from typing import ClassVar
import numpy as np
from livekit.agents.ipc import channel
# These values can safely be increased; the defaults below are just a reasonable starting point.
SHM_MAX_WIDTH = 1920
SHM_MAX_HEIGHT = 1080
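# Each page maps SHM_MAX_WIDTH * SHM_MAX_HEIGHT * 4 bytes of shared memory: one
# 32-bit BGRA pixel per cell of the maximum framebuffer.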
@dataclass
class InitializeContextRequest:
MSG_ID: ClassVar[int] = 0
dev_mode: bool = False
remote_debugging_port: int = 0
root_cache_path: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_bool(b, self.dev_mode)
channel.write_int(b, self.remote_debugging_port)
channel.write_string(b, self.root_cache_path)
def read(self, b: io.BytesIO) -> None:
self.dev_mode = channel.read_bool(b)
self.remote_debugging_port = channel.read_int(b)
self.root_cache_path = channel.read_string(b)
@dataclass
class ContextInitializedResponse:
MSG_ID: ClassVar[int] = 1
@dataclass
class CreateBrowserRequest:
MSG_ID: ClassVar[int] = 2
page_id: int = -1
url: str = ""
framerate: int = 0
width: int = 0
height: int = 0
shm_name: str = ""
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
channel.write_string(b, self.url)
channel.write_int(b, self.framerate)
channel.write_int(b, self.width)
channel.write_int(b, self.height)
channel.write_string(b, self.shm_name)
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
self.url = channel.read_string(b)
self.framerate = channel.read_int(b)
self.width = channel.read_int(b)
self.height = channel.read_int(b)
self.shm_name = channel.read_string(b)
@dataclass
class CreateBrowserResponse:
"""
    Sent by the browser process once its created_callback has fired; the client
    awaits this response, which makes create_browser effectively asynchronous.
"""
MSG_ID: ClassVar[int] = 3
page_id: int = -1
browser_id: int = 0
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
channel.write_int(b, self.browser_id)
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
self.browser_id = channel.read_int(b)
@dataclass
class AcquirePaintData:
MSG_ID: ClassVar[int] = 4
page_id: int = -1
width: int = 0
height: int = 0
dirty_rects: list[tuple[int, int, int, int]] = field(default_factory=list)
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
channel.write_int(b, self.width)
channel.write_int(b, self.height)
channel.write_int(b, len(self.dirty_rects))
for rect in self.dirty_rects:
channel.write_int(b, rect[0])
channel.write_int(b, rect[1])
channel.write_int(b, rect[2])
channel.write_int(b, rect[3])
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
self.width = channel.read_int(b)
self.height = channel.read_int(b)
num_rects = channel.read_int(b)
self.dirty_rects = []
for _ in range(num_rects):
x = channel.read_int(b)
y = channel.read_int(b)
width = channel.read_int(b)
height = channel.read_int(b)
self.dirty_rects.append((x, y, width, height))
@dataclass
class ReleasePaintData:
MSG_ID: ClassVar[int] = 5
page_id: int = -1
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
@dataclass
class CloseBrowserRequest:
MSG_ID: ClassVar[int] = 6
page_id: int = -1
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
@dataclass
class BrowserClosed:
MSG_ID: ClassVar[int] = 7
page_id: int = -1
def write(self, b: io.BytesIO) -> None:
channel.write_int(b, self.page_id)
def read(self, b: io.BytesIO) -> None:
self.page_id = channel.read_int(b)
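# Maps MSG_ID -> message class so ipc.channel.recv_message/arecv_message can decode
# incoming frames into the corresponding dataclass.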
IPC_MESSAGES = {
InitializeContextRequest.MSG_ID: InitializeContextRequest,
ContextInitializedResponse.MSG_ID: ContextInitializedResponse,
CreateBrowserRequest.MSG_ID: CreateBrowserRequest,
CreateBrowserResponse.MSG_ID: CreateBrowserResponse,
AcquirePaintData.MSG_ID: AcquirePaintData,
ReleasePaintData.MSG_ID: ReleasePaintData,
CloseBrowserRequest.MSG_ID: CloseBrowserRequest,
BrowserClosed.MSG_ID: BrowserClosed,
}
def copy_paint_data(
acq: AcquirePaintData,
old_width: int,
old_height: int,
source: memoryview,
dest: memoryview,
):
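    # Copy freshly painted pixels from `source` into `dest`: if the view was resized
    # or a single dirty rect covers the whole frame, copy everything; otherwise copy
    # only the dirty rectangles.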
dirty_rects = acq.dirty_rects
# source_arr = np.frombuffer(source, dtype=np.uint32).reshape((acq.height, acq.width))
source_arr = np.ndarray(
(acq.height, acq.width),
dtype=np.uint32,
buffer=source,
)
dest_arr = np.ndarray(
(acq.height, acq.width),
dtype=np.uint32,
buffer=dest,
)
has_fullscreen_rect = len(dirty_rects) == 1 and dirty_rects[0] == (
0,
0,
acq.width,
acq.height,
)
if old_width != acq.width or old_height != acq.height or has_fullscreen_rect:
np.copyto(dest_arr, source_arr)
else:
for rect in dirty_rects:
x, y, w, h = rect
dest_arr[y : y + h, x : x + w] = source_arr[y : y + h, x : x + w]
"""Used by importlib.resources and setuptools"""
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.0.6"
{
"name": "livekit-plugins-browser",
"private": true,
"version": "0.0.6"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.cibuildwheel.macos]
repair-wheel-command = "" # disabled: wheel repair fails on unresolved files
[tool.cibuildwheel]
before-build = "pip install pybind11[global]"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import re
import subprocess
import sys
from pathlib import Path
import setuptools
from setuptools import Extension
from setuptools.command.build_ext import build_ext
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "browser", "version.py"), "r") as f:
exec(f.read(), about)
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
self.sourcedir = os.fspath(Path(sourcedir).resolve())
class CMakeBuild(build_ext):
def build_extension(self, ext: CMakeExtension) -> None:
        # Must be in this form due to a bug in .resolve() that was only fixed in Python 3.10+
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
extdir = ext_fullpath.parent.resolve()
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release"
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}",
]
print(f"cmake_args: {cmake_args}")
if sys.platform.startswith("darwin"):
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
self.build_temp = Path(self.build_temp) / ext.name
if not self.build_temp.exists():
self.build_temp.mkdir(parents=True)
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=self.build_temp, check=True
)
subprocess.run(["cmake", "--build", "."], cwd=self.build_temp, check=True)
build_output = self.build_temp / "src" / cfg
for f in build_output.iterdir():
if f.suffix == ".so":
self.copy_file(f, extdir / f.name)
if sys.platform.startswith("darwin"):
            # On macOS, also copy the dummy app bundle into the package resources
app = build_output / "lkcef_app.app"
self.copy_tree(
app,
str(
extdir
/ "livekit"
/ "plugins"
/ "browser"
/ "resources"
/ "lkcef_app.app"
),
)
setuptools.setup(
name="livekit-plugins-browser",
version=about["__version__"],
description="Chromium Embedded Framework (CEF) for LiveKit Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
ext_modules=[CMakeExtension("lkcef_python")],
cmdclass={"build_ext": CMakeBuild},
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0"],
package_data={
"livekit.plugins.browser": ["py.typed"],
"livekit.plugins.browser.resources": ["**", "lkcef_app.app"],
},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
include(FetchContent)
set(FETCHCONTENT_QUIET off)
# Avoid writing platform-specific windowing code for the dev mode window:
# use GLFW and Dear ImGui on every platform instead.
set(GLFW_BUILD_DOCS OFF CACHE BOOL "" FORCE)
set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(GLFW_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(GLFW_INSTALL OFF CACHE BOOL "" FORCE)
FetchContent_Declare(glfw GIT_REPOSITORY https://github.com/glfw/glfw.git GIT_TAG 3.4)
FetchContent_MakeAvailable(glfw)
FetchContent_Declare(
imgui
GIT_REPOSITORY https://github.com/ocornut/imgui
GIT_TAG origin/docking
GIT_SHALLOW TRUE
)
FetchContent_GetProperties(imgui)
FetchContent_Populate(imgui)
FetchContent_MakeAvailable(imgui)
file(GLOB IMGUI_SOURCES ${imgui_SOURCE_DIR}/*.cpp)
add_library(imgui STATIC ${IMGUI_SOURCES}
${imgui_SOURCE_DIR}/backends/imgui_impl_glfw.cpp
${imgui_SOURCE_DIR}/backends/imgui_impl_opengl3.cpp
${imgui_SOURCE_DIR}/misc/cpp/imgui_stdlib.cpp
)
set_target_properties(imgui PROPERTIES CXX_STANDARD 17)
target_include_directories(imgui PUBLIC ${imgui_SOURCE_DIR} ${imgui_SOURCE_DIR}/misc/cpp ${imgui_SOURCE_DIR}/backends ${GLFW_INCLUDE_DIR})
target_link_libraries(imgui PRIVATE glfw)
set(LKCEF_SRCS app.cpp app.hpp handler.hpp handler.cpp dev_renderer.hpp dev_renderer.cpp gleq.h browser_handle.hpp browser_handle.cpp)
set(LKCEF_SRCS_LINUX main_linux.cpp)
set(LKCEF_SRCS_MAC app_mac.mm)
set(LKCEF_SRCS_WINDOWS main_win.cpp )
append_platform_sources(LKCEF_SRCS)
source_group(lkcef FILES ${LKCEF_SRCS})
set(LKCEF_HELPER_SRCS )
set(LKCEF_HELPER_SRCS_LINUX helper_main_linux.cpp)
set(LKCEF_HELPER_SRCS_MAC helper_main_mac.mm)
set(LKCEF_HELPER_SRCS_WINDOWS helper_main_win.cpp)
append_platform_sources(LKCEF_HELPER_SRCS)
source_group(lkcef FILES ${LKCEF_HELPER_SRCS})
set(LKCEF_PYTHON_SRCS agents_python.hpp
agents_python.cpp)
if(OS_LINUX OR OS_WINDOWS)
# Logical target used to link the libcef library on Linux and Windows. On
# macOS the CEF framework is loaded dynamically at startup.
add_logical_target("libcef_lib" "${CEF_LIB_DEBUG}" "${CEF_LIB_RELEASE}")
endif()
set_cef_target_out_dir() # Determine the target output directory.
if(OS_LINUX)
# Helper executable target.
add_executable(lkcef_helper ${LKCEF_HELPER_SRCS})
set_executable_target_properties(lkcef_helper)
add_dependencies(lkcef_helper libcef_dll_wrapper)
target_link_libraries(lkcef_helper libcef_lib libcef_dll_wrapper
${CEF_STANDARD_LIBS})
# Set rpath so that libraries can be placed next to the executable.
set_target_properties(lkcef_helper PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(lkcef_helper PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE)
# library target.
add_library(lkcef SHARED ${LKCEF_SRCS})
set_library_target_properties(lkcef)
add_dependencies(lkcef libcef_dll_wrapper lkcef_helper)
target_link_libraries(lkcef libcef_lib libcef_dll_wrapper
${CEF_STANDARD_LIBS})
# Set rpath so that libraries can be placed next to the library.
set_target_properties(lkcef PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(lkcef PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE)
# Copy binary and resource files to the target output directory.
copy_files("lkcef" "${CEF_BINARY_FILES}" "${CEF_BINARY_DIR}"
"${CEF_TARGET_OUT_DIR}")
copy_files("lkcef" "${CEF_RESOURCE_FILES}" "${CEF_RESOURCE_DIR}"
"${CEF_TARGET_OUT_DIR}")
endif()
if(OS_MAC)
# Avoid CMP0042 policy errors.
set(CMAKE_MACOSX_RPATH 1)
# Avoid CMP0068 policy errors.
if(POLICY CMP0068)
cmake_policy(SET CMP0068 NEW)
endif()
add_executable(lkcef_app MACOSX_BUNDLE dummy.cpp) # dummy app
set_target_properties(lkcef_app PROPERTIES
MACOSX_BUNDLE_INFO_PLIST "${CMAKE_CURRENT_SOURCE_DIR}/resources/lkcefapp-Info.plist"
OUTPUT_NAME "lkcef_app"
)
# library target.
add_library(lkcef STATIC ${LKCEF_SRCS})
set_library_target_properties(lkcef)
add_dependencies(lkcef libcef_dll_wrapper)
target_include_directories(lkcef PRIVATE ${GLFW_INCLUDE_DIR})
target_link_libraries(lkcef libcef_dll_wrapper ${CEF_STANDARD_LIBS} glfw imgui)
add_custom_command(
TARGET lkcef
POST_BUILD
# Copy the CEF framework into the main app bundle.
COMMAND
${CMAKE_COMMAND} -E copy_directory
"${CEF_BINARY_DIR}/Chromium Embedded Framework.framework"
"$<TARGET_BUNDLE_DIR:lkcef_app>/Contents/Frameworks/Chromium Embedded Framework.framework"
VERBATIM)
# Create the multiple Helper app bundle targets.
foreach(_suffix_list ${CEF_HELPER_APP_SUFFIXES})
# Convert to a list and extract the suffix values.
string(REPLACE ":" ";" _suffix_list ${_suffix_list})
list(GET _suffix_list 0 _name_suffix)
list(GET _suffix_list 1 _target_suffix)
list(GET _suffix_list 2 _plist_suffix)
# Define Helper target and output names.
set(_helper_target "lkcef_Helper${_target_suffix}")
set(_helper_output_name "lkcef Helper${_name_suffix}")
# Create Helper-specific variants of the helper-Info.plist file.
set(_helper_info_plist
"${CMAKE_CURRENT_BINARY_DIR}/lkcef-Info${_target_suffix}.plist")
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/resources/lkcefhelper-Info.plist"
_plist_contents)
string(REPLACE "\${EXECUTABLE_NAME}" "${_helper_output_name}"
_plist_contents ${_plist_contents})
string(REPLACE "\${PRODUCT_NAME}" "${_helper_output_name}" _plist_contents
${_plist_contents})
string(REPLACE "\${BUNDLE_ID_SUFFIX}" "${_plist_suffix}" _plist_contents
${_plist_contents})
file(WRITE ${_helper_info_plist} ${_plist_contents})
# Create Helper executable target.
add_executable(${_helper_target} MACOSX_BUNDLE ${LKCEF_HELPER_SRCS})
set_executable_target_properties(${_helper_target})
add_dependencies(${_helper_target} libcef_dll_wrapper)
target_link_libraries(${_helper_target} libcef_dll_wrapper
${CEF_STANDARD_LIBS})
set_target_properties(
${_helper_target}
PROPERTIES MACOSX_BUNDLE_INFO_PLIST ${_helper_info_plist}
OUTPUT_NAME ${_helper_output_name})
# Add the Helper as a dependency of the main executable target.
add_dependencies(lkcef "${_helper_target}")
# Copy the Helper app bundle into the Frameworks directory.
add_custom_command(
TARGET lkcef
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy_directory
"${CEF_TARGET_OUT_DIR}/${_helper_output_name}.app"
"$<TARGET_BUNDLE_DIR:lkcef_app>/Contents/Frameworks/${_helper_output_name}.app"
VERBATIM)
endforeach()
endif()
if(OS_WINDOWS)
# Helper executable target.
add_executable(lkcef_helper WIN32 ${LKCEF_HELPER_SRCS})
set_executable_target_properties(lkcef_helper)
add_dependencies(lkcef_helper libcef_dll_wrapper)
target_link_libraries(lkcef_helper libcef_lib libcef_dll_wrapper
${CEF_STANDARD_LIBS})
# library target.
add_library(lkcef SHARED ${LKCEF_SRCS})
set_library_target_properties(lkcef)
add_dependencies(lkcef libcef_dll_wrapper lkcef_helper)
target_link_libraries(lkcef libcef_lib libcef_dll_wrapper
${CEF_STANDARD_LIBS})
# Add the custom manifest files to the DLL and helper EXE.
add_windows_manifest("${CMAKE_CURRENT_SOURCE_DIR}" "lkcef" "dll")
add_windows_manifest("${CMAKE_CURRENT_SOURCE_DIR}" "lkcef_helper" "exe")
# Copy binary and resource files to the target output directory.
copy_files("lkcef" "${CEF_BINARY_FILES}" "${CEF_BINARY_DIR}"
"${CEF_TARGET_OUT_DIR}")
copy_files("lkcef" "${CEF_RESOURCE_FILES}" "${CEF_RESOURCE_DIR}"
"${CEF_TARGET_OUT_DIR}")
endif()
# TODO(theomonnom): should be pretty similar for NodeJS
pybind11_add_module(lkcef_python ${LKCEF_PYTHON_SRCS})
set_target_properties(lkcef_python PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(lkcef_python PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE)
target_include_directories(lkcef_python PRIVATE ${CEF_INCLUDE_PATH})
target_link_libraries(lkcef_python PUBLIC lkcef)
target_link_libraries(lkcef_python PUBLIC libcef_dll_wrapper ${CEF_STANDARD_LIBS})
add_dependencies(lkcef_python libcef_dll_wrapper)
#include "agents_python.hpp"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "app.hpp"
#include "include/base/cef_callback.h"
#include "include/internal/cef_mac.h"
#include "include/wrapper/cef_closure_task.h"
namespace py = pybind11;
BrowserApp::BrowserApp(const AppOptions& options) : options_(options) {
app_ = new AgentApp(options_.dev_mode, options.remote_debugging_port,
options.root_cache_path, options.framework_path,
options.main_bundle_path, options.subprocess_path,
options_.initialized_callback);
}
bool BrowserApp::CreateBrowser(const std::string& url,
const BrowserOptions& options) {
if (CefCurrentlyOn(TID_UI)) {
CreateBrowserOnUIThread(url, options);
return true;
}
// TODO(theomonnom): Document base::Unretained
CefPostTask(TID_UI, base::BindOnce(&BrowserApp::CreateBrowserOnUIThread,
base::Unretained(this), url, options));
return true;
}
void BrowserApp::CreateBrowserOnUIThread(const std::string& url,
const BrowserOptions& options) {
std::shared_ptr<BrowserImpl> browser_impl = std::make_shared<BrowserImpl>();
browsers_.push_back(browser_impl);
CefRefPtr<BrowserHandle> handle = app_->CreateBrowser(
url, options.framerate, options.width, options.height,
[options, browser_impl]() { options.created_callback(browser_impl); },
[options](std::vector<CefRect> dirtyRects, const void* buffer, int width,
int height) {
PaintData event{};
std::vector<PaintRect> rects;
rects.reserve(dirtyRects.size());
for (const auto& rect : dirtyRects) {
rects.push_back({rect.x, rect.y, rect.width, rect.height});
}
event.dirtyRect = rects;
event.buffer = buffer;
event.width = width;
event.height = height;
options.paint_callback(event);
},
options.close_callback);
browser_impl->handle = handle;
}
int BrowserApp::Run() {
return RunAgentApp(app_);
}
BrowserImpl::BrowserImpl() {}
void BrowserImpl::SetSize(int width, int height) {
if (handle)
handle->SetSize(width, height);
}
void BrowserImpl::Close() {
if (handle)
handle->Close();
}
int BrowserImpl::Identifier() const {
return handle->GetBrowser()->GetIdentifier();
}
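// Expose the raw paint buffer to Python as a flat uint32 memoryview (no copy).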
py::memoryview paint_data_to_memoryview(const PaintData& event) {
return py::memoryview::from_buffer(
const_cast<uint32_t*>(static_cast<const uint32_t*>(event.buffer)),
{event.height * event.width}, {sizeof(uint32_t)}, true);
}
PYBIND11_MODULE(lkcef_python, m) {
  // Python bindings for driving embedded Chromium browsers from LiveKit agents.
m.doc() = "Chromium Embedded Framework (CEF) for LiveKit Agents";
py::class_<AppOptions>(m, "AppOptions")
.def(py::init())
.def_readwrite("dev_mode", &AppOptions::dev_mode)
.def_readwrite("remote_debugging_port",
&AppOptions::remote_debugging_port)
.def_readwrite("root_cache_path", &AppOptions::root_cache_path)
.def_readwrite("framework_path", &AppOptions::framework_path)
.def_readwrite("main_bundle_path", &AppOptions::main_bundle_path)
.def_readwrite("subprocess_path", &AppOptions::subprocess_path)
.def_readwrite("initialized_callback", &AppOptions::initialized_callback);
py::class_<BrowserOptions>(m, "BrowserOptions")
.def(py::init())
.def_readwrite("framerate", &BrowserOptions::framerate)
.def_readwrite("width", &BrowserOptions::width)
.def_readwrite("height", &BrowserOptions::height)
.def_readwrite("created_callback", &BrowserOptions::created_callback)
.def_readwrite("paint_callback", &BrowserOptions::paint_callback)
.def_readwrite("close_callback", &BrowserOptions::close_callback);
py::class_<BrowserApp>(m, "BrowserApp")
.def(py::init<const AppOptions&>())
.def("create_browser", &BrowserApp::CreateBrowser)
.def("run", &BrowserApp::Run, py::call_guard<py::gil_scoped_release>());
py::class_<BrowserImpl, std::shared_ptr<BrowserImpl>>(m, "BrowserImpl")
.def("set_size", &BrowserImpl::SetSize)
.def("close", &BrowserImpl::Close)
.def("identifier", &BrowserImpl::Identifier);
py::class_<PaintRect>(m, "PaintRect")
.def_readwrite("x", &PaintRect::x)
.def_readwrite("y", &PaintRect::y)
.def_readwrite("width", &PaintRect::width)
.def_readwrite("height", &PaintRect::height);
py::class_<PaintData>(m, "PaintData")
.def(py::init())
.def_readwrite("dirty_rects", &PaintData::dirtyRect)
.def_readwrite("width", &PaintData::width)
.def_readwrite("height", &PaintData::height)
.def_property_readonly("buffer", [](const PaintData& event) {
return paint_data_to_memoryview(event);
});
}
#ifndef LKCEF_AGENTS_PYTHON_HPP
#define LKCEF_AGENTS_PYTHON_HPP
#include <functional>
#include <memory>
#include "app.hpp"
class BrowserImpl;
struct PaintData;
struct AppOptions {
bool dev_mode = false;
int remote_debugging_port = 0;
std::string root_cache_path;
std::string framework_path;
std::string main_bundle_path;
std::string subprocess_path;
std::function<void()> initialized_callback = nullptr;
};
struct BrowserOptions {
int framerate = 30;
int width = 800;
int height = 600;
std::function<void(std::shared_ptr<BrowserImpl>)> created_callback = nullptr;
std::function<void(const PaintData&)> paint_callback = nullptr;
std::function<void()> close_callback = nullptr;
};
struct BrowserApp {
BrowserApp(const AppOptions& options);
bool CreateBrowser(const std::string& url, const BrowserOptions& options);
void CreateBrowserOnUIThread(const std::string& url, const BrowserOptions& options);
int Run();
private:
AppOptions options_;
CefRefPtr<AgentApp> app_;
std::list<std::shared_ptr<BrowserImpl>> browsers_;
};
struct BrowserImpl {
BrowserImpl();
void SetSize(int width, int height);
void Close();
int Identifier() const;
CefRefPtr<BrowserHandle> handle = nullptr;
};
struct PaintRect {
int x = 0;
int y = 0;
int width = 0;
int height = 0;
};
struct PaintData {
std::vector<PaintRect> dirtyRect;
const void* buffer;
int width;
int height;
};
#endif // LKCEF_AGENTS_PYTHON_HPP
#include "app.hpp"
#include <iostream>
#include <string>
#include <utility>
#include "include/cef_command_line.h"
#include "include/views/cef_window.h"
#include "include/wrapper/cef_helpers.h"
AgentApp::AgentApp(bool dev_mode,
int remote_debugging_port,
std::string root_cache_path,
std::string framework_path,
std::string main_bundle_path,
std::string subprocess_path,
std::function<void()> initialized_callback)
: dev_mode_(dev_mode),
remote_debugging_port_(remote_debugging_port),
root_cache_path_(std::move(root_cache_path)),
framework_path_(std::move(framework_path)),
main_bundle_path_(std::move(main_bundle_path)),
subprocess_path_(std::move(subprocess_path)),
initialized_callback_(std::move(initialized_callback)) {
browser_store_ = CefRefPtr<BrowserStore>(new BrowserStore());
if (dev_mode)
dev_renderer_ = CefRefPtr<DevRenderer>(new DevRenderer(browser_store_));
}
void AgentApp::OnBeforeCommandLineProcessing(
const CefString& process_type,
CefRefPtr<CefCommandLine> command_line) {
command_line->AppendSwitch("--disable-gpu");
command_line->AppendSwitch("--disable-gpu-compositing");
command_line->AppendSwitch("--enable-chrome-runtime");
// command_line->AppendSwitch("--enable-begin-frame-scheduling");
}
void AgentApp::OnContextInitialized() {
CEF_REQUIRE_UI_THREAD(); // Main thread in our case
client_ =
CefRefPtr<AgentHandler>(new AgentHandler(browser_store_, dev_renderer_));
dev_client_ = CefRefPtr<DevToolsHandler>(new DevToolsHandler());
if (initialized_callback_)
initialized_callback_();
}
CefRefPtr<CefClient> AgentApp::GetDefaultClient() {
return client_;
}
CefRefPtr<BrowserHandle> AgentApp::CreateBrowser(
const std::string& url,
int framerate,
int width,
int height,
std::function<void()> created_callback,
std::function<void(std::vector<CefRect> dirtyRects,
const void* buffer,
int width,
int height)> paint_callback,
std::function<void()> close_callback) {
CEF_REQUIRE_UI_THREAD();
// windowInfo.SetAsWindowless(dev_renderer_->getNativeWindowHandle());
CefWindowInfo windowInfo;
windowInfo.SetAsWindowless(nullptr);
CefBrowserSettings settings;
settings.windowless_frame_rate = framerate;
settings.background_color = CefColorSetARGB(255, 255, 255, 255);
CefRefPtr<BrowserHandle> browser_handle =
new BrowserHandle(std::move(created_callback), std::move(paint_callback),
std::move(close_callback), width, height);
browser_store_->AddPendingHandle(browser_handle);
bool result = CefBrowserHost::CreateBrowser(windowInfo, client_, url,
settings, nullptr, nullptr);
if (!result) {
browser_store_->RemovePendingHandle(browser_handle);
return nullptr;
}
return browser_handle;
}
int AgentApp::Run() {
if (dev_mode_) {
dev_renderer_->Run();
} else {
CefRunMessageLoop();
}
// Close all browsers
return 0;
}
#ifndef LKCEF_APP_HPP
#define LKCEF_APP_HPP
#include "browser_handle.hpp"
#include "dev_renderer.hpp"
#include "handler.hpp"
#include "include/cef_app.h"
#include "include/cef_base.h"
#include "include/cef_browser_process_handler.h"
#include "include/cef_client.h"
#include "include/internal/cef_ptr.h"
class AgentApp : public CefApp, public CefBrowserProcessHandler {
public:
AgentApp(bool dev_mode,
int remote_debugging_port,
std::string root_cache_path,
std::string framework_path,
std::string main_bundle_path,
std::string subprocess_path,
std::function<void()> initialized_callback);
CefRefPtr<CefBrowserProcessHandler> GetBrowserProcessHandler() override {
return this;
}
void OnBeforeCommandLineProcessing(
const CefString& process_type,
CefRefPtr<CefCommandLine> command_line) override;
void OnContextInitialized() override;
CefRefPtr<CefClient> GetDefaultClient() override;
CefRefPtr<BrowserHandle> CreateBrowser(
const std::string& url,
int framerate,
int width,
int height,
std::function<void()> created_callback,
std::function<void(std::vector<CefRect> dirtyRect,
const void* buffer,
int width,
int height)> paint_callback,
std::function<void()> close_callback);
int Run();
bool IsDevMode() const { return dev_mode_; }
int GetRemoteDebuggingPort() const { return remote_debugging_port_; }
std::string GetRootCachePath() const { return root_cache_path_; }
std::string GetFrameworkPath() const { return framework_path_; }
std::string GetMainBundlePath() const { return main_bundle_path_; }
std::string GetSubprocessPath() const { return subprocess_path_; }
private:
IMPLEMENT_REFCOUNTING(AgentApp);
CefRefPtr<BrowserStore> browser_store_;
CefRefPtr<AgentHandler> client_;
CefRefPtr<DevToolsHandler> dev_client_;
CefRefPtr<DevRenderer> dev_renderer_;
bool dev_mode_;
int remote_debugging_port_;
std::string root_cache_path_;
std::string framework_path_;
std::string main_bundle_path_;
std::string subprocess_path_;
std::function<void()> initialized_callback_;
};
int RunAgentApp(CefRefPtr<AgentApp> app);
#endif // LKCEF_APP_HPP
#import <Cocoa/Cocoa.h>
#include <iostream>
#include <objc/runtime.h>
#include "app.hpp"
#include "handler.hpp"
#include "include/cef_application_mac.h"
#include "include/cef_command_line.h"
#include "include/wrapper/cef_library_loader.h"
BOOL g_handling_send_event = false;
@interface NSApplication (AgentsApplication) <CefAppProtocol>
- (BOOL)isHandlingSendEvent;
- (void)setHandlingSendEvent:(BOOL)handlingSendEvent;
- (void)_swizzled_sendEvent:(NSEvent*)event;
- (void)_swizzled_terminate:(id)sender;
@end
@implementation NSApplication (AgentsApplication)
// This selector is called very early during the application initialization.
+ (void)load {
NSLog(@"AgentsApplication::load");
// Swap NSApplication::sendEvent with _swizzled_sendEvent.
Method original = class_getInstanceMethod(self, @selector(sendEvent));
Method swizzled =
class_getInstanceMethod(self, @selector(_swizzled_sendEvent));
method_exchangeImplementations(original, swizzled);
Method originalTerm = class_getInstanceMethod(self, @selector(terminate:));
Method swizzledTerm =
class_getInstanceMethod(self, @selector(_swizzled_terminate:));
method_exchangeImplementations(originalTerm, swizzledTerm);
}
- (BOOL)isHandlingSendEvent {
return g_handling_send_event;
}
- (void)setHandlingSendEvent:(BOOL)handlingSendEvent {
g_handling_send_event = handlingSendEvent;
}
- (void)_swizzled_sendEvent:(NSEvent*)event {
CefScopedSendingEvent sendingEventScoper;
// Calls NSApplication::sendEvent due to the swizzling.
[self _swizzled_sendEvent:event];
}
- (void)_swizzled_terminate:(id)sender {
[self _swizzled_terminate:sender];
}
@end
// Entry point function for the browser process.
int RunAgentApp(CefRefPtr<AgentApp> app) {
CefMainArgs main_args(0, nullptr);
@autoreleasepool {
[NSApplication sharedApplication];
    // If NSApp was referenced before this point, it will be a plain NSApplication
    // rather than an AgentsApplication, which is undesirable; enforce that this
    // has not happened.
CHECK([NSApp isKindOfClass:[NSApplication class]]);
std::string framework_lib = app->GetFrameworkPath() + "/Chromium Embedded Framework";
if (!cef_load_library(framework_lib.c_str())) {
std::cerr << "lkcef: Failed to load CEF library" << std::endl;
return 1;
}
CefSettings settings{};
settings.chrome_runtime = true;
settings.external_message_pump = app->IsDevMode();
settings.remote_debugging_port = app->GetRemoteDebuggingPort();
CefString(&settings.root_cache_path).FromString(app->GetRootCachePath());
CefString(&settings.framework_dir_path).FromString(app->GetFrameworkPath());
CefString(&settings.main_bundle_path).FromString(app->GetMainBundlePath());
CefString(&settings.browser_subprocess_path).FromString(app->GetSubprocessPath());
    settings.no_sandbox = true;  // Disable the sandbox on macOS; livekit-agents
                                 // deployments are only expected to target Linux.
settings.windowless_rendering_enabled = true;
// Initialize the CEF browser process. May return false if initialization
// fails or if early exit is desired (for example, due to process singleton
// relaunch behavior).
if (!CefInitialize(main_args, settings, app.get(), nullptr)) {
std::cerr << "lkcef: Failed to initialize CEF" << std::endl;
// TODO(theomonnom): Use CefGetExitCode();
return 1;
}
app->Run();
CefShutdown();
cef_unload_library();
} // @autoreleasepool
return 0;
}
#include "browser_handle.hpp"
void BrowserHandle::SetSize(int width, int height) {
width_ = width;
height_ = height;
if (browser_)
browser_->GetHost()->WasResized();
}
void BrowserHandle::Close() {
if (browser_)
browser_->GetHost()->CloseBrowser(true);
}
#ifndef LKCEF_BROWSER_HANDLE_HPP
#define LKCEF_BROWSER_HANDLE_HPP
#include <list>
#include "include/cef_client.h"
#include "include/wrapper/cef_helpers.h"
class BrowserHandle : public CefBaseRefCounted {
public:
BrowserHandle(
std::function<void()> created_callback,
std::function<void(std::vector<CefRect> dirtyRects,
const void* buffer,
int width,
int height)> paint_callback,
std::function<void()> close_callback,
int width,
int height)
: created_callback_(std::move(created_callback)),
paint_callback_(std::move(paint_callback)),
close_callback_(std::move(close_callback)),
width_(width),
height_(height) {}
CefRefPtr<CefBrowser> browser_ = nullptr;
std::function<void()> created_callback_ = nullptr;
std::function<void(std::vector<CefRect> dirtyRect,
const void* buffer,
int width,
int height)>
paint_callback_ = nullptr;
std::function<void()> close_callback_ = nullptr;
void SetSize(int width, int height);
void Close();
int GetWidth() const { return width_; }
int GetHeight() const { return height_; }
CefRefPtr<CefBrowser> GetBrowser() const { return browser_; }
private:
int width_ = 0;
int height_ = 0;
IMPLEMENT_REFCOUNTING(BrowserHandle);
};
struct BrowserStore : public CefBaseRefCounted {
std::unordered_map<int, CefRefPtr<BrowserHandle>> browser_handles_;
std::list<CefRefPtr<BrowserHandle>> pending_handles_;
void AddPendingHandle(CefRefPtr<BrowserHandle> handle) {
CEF_REQUIRE_UI_THREAD();
pending_handles_.push_back(handle);
}
void RemovePendingHandle(CefRefPtr<BrowserHandle> handle) {
CEF_REQUIRE_UI_THREAD();
pending_handles_.remove(handle);
}
CefRefPtr<BrowserHandle> GetBrowserHandle(int identifier) {
CEF_REQUIRE_UI_THREAD();
return browser_handles_[identifier];
}
IMPLEMENT_REFCOUNTING(BrowserStore);
};
#endif // LKCEF_BROWSER_HANDLE_HPP
#include "dev_renderer.hpp"
#include <iostream>
#include "handler.hpp"
#define IMGUI_DEFINE_MATH_OPERATORS
#include "imgui.h"
#include "imgui_impl_glfw.h"
#include "imgui_impl_opengl3.h"
#include "imgui_stdlib.h"
#include "include/cef_app.h"
#include "include/wrapper/cef_helpers.h"
#include "keyboard_codes.h"
#define GLEQ_IMPLEMENTATION
#define GLEQ_STATIC
#include "gleq.h"
// DCHECK on gl errors.
#if DCHECK_IS_ON()
#define VERIFY_NO_ERROR \
{ \
int _gl_error = glGetError(); \
DCHECK(_gl_error == GL_NO_ERROR) << "glGetError returned " << _gl_error; \
}
#else
#define VERIFY_NO_ERROR
#endif
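// Map GLFW key codes to the Windows virtual-key codes (WebCore::VK_*) that CEF
// expects in CefKeyEvent::windows_key_code.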
int glfw_key_to_cef_key(int glfwKey) {
switch (glfwKey) {
case GLFW_KEY_SPACE:
return WebCore::VK_SPACE;
case GLFW_KEY_APOSTROPHE:
return WebCore::VK_OEM_7;
case GLFW_KEY_COMMA:
return WebCore::VK_OEM_COMMA;
case GLFW_KEY_MINUS:
return WebCore::VK_OEM_MINUS;
case GLFW_KEY_PERIOD:
return WebCore::VK_OEM_PERIOD;
case GLFW_KEY_SLASH:
return WebCore::VK_OEM_2;
case GLFW_KEY_0:
return WebCore::VK_0;
case GLFW_KEY_1:
return WebCore::VK_1;
case GLFW_KEY_2:
return WebCore::VK_2;
case GLFW_KEY_3:
return WebCore::VK_3;
case GLFW_KEY_4:
return WebCore::VK_4;
case GLFW_KEY_5:
return WebCore::VK_5;
case GLFW_KEY_6:
return WebCore::VK_6;
case GLFW_KEY_7:
return WebCore::VK_7;
case GLFW_KEY_8:
return WebCore::VK_8;
case GLFW_KEY_9:
return WebCore::VK_9;
case GLFW_KEY_SEMICOLON:
return WebCore::VK_OEM_1;
case GLFW_KEY_EQUAL:
return WebCore::VK_OEM_PLUS;
case GLFW_KEY_A:
return WebCore::VK_A;
case GLFW_KEY_B:
return WebCore::VK_B;
case GLFW_KEY_C:
return WebCore::VK_C;
case GLFW_KEY_D:
return WebCore::VK_D;
case GLFW_KEY_E:
return WebCore::VK_E;
case GLFW_KEY_F:
return WebCore::VK_F;
case GLFW_KEY_G:
return WebCore::VK_G;
case GLFW_KEY_H:
return WebCore::VK_H;
case GLFW_KEY_I:
return WebCore::VK_I;
case GLFW_KEY_J:
return WebCore::VK_J;
case GLFW_KEY_K:
return WebCore::VK_K;
case GLFW_KEY_L:
return WebCore::VK_L;
case GLFW_KEY_M:
return WebCore::VK_M;
case GLFW_KEY_N:
return WebCore::VK_N;
case GLFW_KEY_O:
return WebCore::VK_O;
case GLFW_KEY_P:
return WebCore::VK_P;
case GLFW_KEY_Q:
return WebCore::VK_Q;
case GLFW_KEY_R:
return WebCore::VK_R;
case GLFW_KEY_S:
return WebCore::VK_S;
case GLFW_KEY_T:
return WebCore::VK_T;
case GLFW_KEY_U:
return WebCore::VK_U;
case GLFW_KEY_V:
return WebCore::VK_V;
case GLFW_KEY_W:
return WebCore::VK_W;
case GLFW_KEY_X:
return WebCore::VK_X;
case GLFW_KEY_Y:
return WebCore::VK_Y;
case GLFW_KEY_Z:
return WebCore::VK_Z;
case GLFW_KEY_LEFT_BRACKET:
return WebCore::VK_OEM_4;
case GLFW_KEY_BACKSLASH:
return WebCore::VK_OEM_5;
case GLFW_KEY_RIGHT_BRACKET:
return WebCore::VK_OEM_6;
case GLFW_KEY_GRAVE_ACCENT:
return WebCore::VK_OEM_3;
case GLFW_KEY_ESCAPE:
return WebCore::VK_ESCAPE;
case GLFW_KEY_ENTER:
return WebCore::VK_RETURN;
case GLFW_KEY_TAB:
return WebCore::VK_TAB;
case GLFW_KEY_BACKSPACE:
return WebCore::VK_BACK;
case GLFW_KEY_INSERT:
return WebCore::VK_INSERT;
case GLFW_KEY_DELETE:
return WebCore::VK_DELETE;
case GLFW_KEY_RIGHT:
return WebCore::VK_RIGHT;
case GLFW_KEY_LEFT:
return WebCore::VK_LEFT;
case GLFW_KEY_DOWN:
return WebCore::VK_DOWN;
case GLFW_KEY_UP:
return WebCore::VK_UP;
case GLFW_KEY_PAGE_UP:
return WebCore::VK_PRIOR;
case GLFW_KEY_PAGE_DOWN:
return WebCore::VK_NEXT;
case GLFW_KEY_HOME:
return WebCore::VK_HOME;
case GLFW_KEY_END:
return WebCore::VK_END;
case GLFW_KEY_CAPS_LOCK:
return WebCore::VK_CAPITAL;
case GLFW_KEY_SCROLL_LOCK:
return WebCore::VK_SCROLL;
case GLFW_KEY_NUM_LOCK:
return WebCore::VK_NUMLOCK;
case GLFW_KEY_PRINT_SCREEN:
return WebCore::VK_SNAPSHOT;
case GLFW_KEY_PAUSE:
return WebCore::VK_PAUSE;
case GLFW_KEY_F1:
return WebCore::VK_F1;
case GLFW_KEY_F2:
return WebCore::VK_F2;
case GLFW_KEY_F3:
return WebCore::VK_F3;
case GLFW_KEY_F4:
return WebCore::VK_F4;
case GLFW_KEY_F5:
return WebCore::VK_F5;
case GLFW_KEY_F6:
return WebCore::VK_F6;
case GLFW_KEY_F7:
return WebCore::VK_F7;
case GLFW_KEY_F8:
return WebCore::VK_F8;
case GLFW_KEY_F9:
return WebCore::VK_F9;
case GLFW_KEY_F10:
return WebCore::VK_F10;
case GLFW_KEY_F11:
return WebCore::VK_F11;
case GLFW_KEY_F12:
return WebCore::VK_F12;
// Add more cases as needed
default:
return WebCore::VK_UNKNOWN;
}
}
static uint32_t glfw_mods_to_cef_mods(int glfw_mods) {
uint32_t cef_flags = 0;
if (glfw_mods & 0x0001) { // GLFW_MOD_SHIFT
cef_flags |= (1 << 1); // EVENTFLAG_SHIFT_DOWN
}
if (glfw_mods & 0x0002) { // GLFW_MOD_CONTROL
cef_flags |= (1 << 2); // EVENTFLAG_CONTROL_DOWN
}
if (glfw_mods & 0x0004) { // GLFW_MOD_ALT
cef_flags |= (1 << 3); // EVENTFLAG_ALT_DOWN
}
if (glfw_mods & 0x0008) { // GLFW_MOD_SUPER
cef_flags |=
(1 << 7); // EVENTFLAG_COMMAND_DOWN (Super key -> Command on Mac)
}
if (glfw_mods & 0x0010) { // GLFW_MOD_CAPS_LOCK
cef_flags |= (1 << 0); // EVENTFLAG_CAPS_LOCK_ON
}
if (glfw_mods & 0x0020) { // GLFW_MOD_NUM_LOCK
cef_flags |= (1 << 8); // EVENTFLAG_NUM_LOCK_ON
}
return cef_flags;
}
static std::optional<CefBrowserHost::MouseButtonType> glfw_button_to_cef_button(
int button) {
switch (button) {
case GLFW_MOUSE_BUTTON_LEFT:
return CefBrowserHost::MouseButtonType::MBT_LEFT;
case GLFW_MOUSE_BUTTON_MIDDLE:
return CefBrowserHost::MouseButtonType::MBT_MIDDLE;
case GLFW_MOUSE_BUTTON_RIGHT:
return CefBrowserHost::MouseButtonType::MBT_RIGHT;
default:
return std::nullopt;
}
}
static void glfw_error_callback(int error, const char* description) {
fprintf(stderr, "GLFW Error %d: %s\n", error, description);
}
DevRenderer::DevRenderer(CefRefPtr<BrowserStore> browser_store)
: browser_store_(browser_store) {}
void DevRenderer::OnTitleChange(CefRefPtr<CefBrowser> browser,
const CefString& title) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
BrowserData* data = &browser_data_[identifier];
data->title = title;
}
void DevRenderer::OnLoadingStateChange(CefRefPtr<CefBrowser> browser,
bool isLoading,
bool canGoBack,
bool canGoForward) {
if (!isLoading) {
int identifier = browser->GetIdentifier();
BrowserData* data = &browser_data_[identifier];
data->url = browser->GetMainFrame()->GetURL();
}
}
void DevRenderer::OnAfterCreated(CefRefPtr<CefBrowser> browser) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
unsigned int texture_id;
glGenTextures(1, &texture_id);
VERIFY_NO_ERROR;
BrowserData data{};
data.browser = browser;
data.texture_id = texture_id;
browser_data_.insert({identifier, data});
glBindTexture(GL_TEXTURE_2D, texture_id);
VERIFY_NO_ERROR;
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
VERIFY_NO_ERROR;
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
}
void DevRenderer::OnPaint(CefRefPtr<CefBrowser> browser,
CefRenderHandler::PaintElementType type,
const CefRenderHandler::RectList& dirtyRects,
const void* buffer,
int width,
int height) {
CEF_REQUIRE_UI_THREAD();
if (type != CefRenderHandler::PaintElementType::PET_VIEW) {
    return;  // Ignore PET_POPUP for now; popup rendering is not supported yet.
}
int identifier = browser->GetIdentifier();
BrowserData* data = &browser_data_[identifier];
int old_width = data->view_width;
int old_height = data->view_height;
data->view_width = width;
data->view_height = height;
glBindTexture(GL_TEXTURE_2D, data->texture_id);
glPixelStorei(GL_UNPACK_ROW_LENGTH, width);
VERIFY_NO_ERROR;
bool has_fullscreen_rect =
dirtyRects.size() == 1 && dirtyRects[0] == CefRect(0, 0, width, height);
if (old_width != width || old_height != height || has_fullscreen_rect) {
glPixelStorei(GL_UNPACK_SKIP_PIXELS, 0);
VERIFY_NO_ERROR;
glPixelStorei(GL_UNPACK_SKIP_ROWS, 0);
VERIFY_NO_ERROR;
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, width, height, 0, GL_BGRA,
GL_UNSIGNED_INT_8_8_8_8_REV, buffer);
VERIFY_NO_ERROR;
} else {
CefRenderHandler::RectList::const_iterator i = dirtyRects.begin();
for (; i != dirtyRects.end(); ++i) {
const CefRect& rect = *i;
glPixelStorei(GL_UNPACK_SKIP_PIXELS, rect.x);
VERIFY_NO_ERROR;
glPixelStorei(GL_UNPACK_SKIP_ROWS, rect.y);
VERIFY_NO_ERROR;
glTexSubImage2D(GL_TEXTURE_2D, 0, rect.x, rect.y, rect.width, rect.height,
GL_BGRA, GL_UNSIGNED_INT_8_8_8_8_REV, buffer);
VERIFY_NO_ERROR;
}
}
}
void DevRenderer::OnBeforeClose(CefRefPtr<CefBrowser> browser) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
BrowserData* data = &browser_data_[identifier];
glDeleteTextures(1, &data->texture_id);
browser_data_.erase(identifier);
}
void DevRenderer::Run() {
glfwSetErrorCallback(glfw_error_callback);
if (!glfwInit()) {
std::cerr << "Failed to initialize GLFW" << std::endl;
return;
}
gleqInit();
const char* glsl_version = "#version 150";
glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 2);
glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE);
window_ =
glfwCreateWindow(800, 600, "livekit-plugins-browser (Development Window)",
nullptr, nullptr);
gleqTrackWindow(window_);
if (!window_) {
std::cerr << "Failed to create GLFW window" << std::endl;
glfwTerminate();
return;
}
glfwMakeContextCurrent(window_);
glfwSwapInterval(1); // Enable vsync
IMGUI_CHECKVERSION();
ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard;
io.ConfigFlags |= ImGuiConfigFlags_DockingEnable;
ImGui_ImplGlfw_InitForOpenGL(window_, true);
ImGui_ImplOpenGL3_Init(glsl_version);
ImVec4 clear_color = ImVec4(0.03f, 0.03f, 0.03f, 1.0f);
while (!glfwWindowShouldClose(window_)) {
glfwPollEvents();
CefDoMessageLoopWork();
ImGui_ImplOpenGL3_NewFrame();
ImGui_ImplGlfw_NewFrame();
ImGui::NewFrame();
// Flags used for the "invisible" dockspace frame
ImGuiWindowFlags windowFlags =
ImGuiWindowFlags_NoDocking | ImGuiWindowFlags_NoTitleBar |
ImGuiWindowFlags_NoCollapse | ImGuiWindowFlags_NoResize |
ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoBringToFrontOnFocus |
ImGuiWindowFlags_NoNavFocus | ImGuiWindowFlags_NoBackground;
ImGuiViewport* viewport = ImGui::GetMainViewport();
ImGui::SetNextWindowPos(viewport->Pos);
ImGui::SetNextWindowSize(viewport->Size);
ImGui::SetNextWindowViewport(viewport->ID);
ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0);
ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0);
ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0));
ImGui::Begin("Editor", nullptr, windowFlags);
ImGui::PopStyleVar(3);
ImGui::DockSpace(ImGui::GetID("EditorDockSpace"), ImVec2(),
ImGuiDockNodeFlags_PassthruCentralNode);
// Focused browser input states
BrowserData* focused_browser = nullptr;
int browser_view_x = 0;
int browser_view_y = 0;
for (auto& [identifier, data] : browser_data_) {
std::string name =
(data.title.empty() ? "Browser #" + std::to_string(identifier)
: data.title) +
"###Browser" + std::to_string(identifier);
ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0));
if (ImGui::Begin(name.c_str())) {
ImGui::BeginDisabled(!data.browser->CanGoBack());
if (ImGui::ArrowButton("##BrowserBack", ImGuiDir_Left)) {
data.browser->GoBack();
}
ImGui::EndDisabled();
ImGui::SameLine();
ImGui::BeginDisabled(!data.browser->CanGoForward());
if (ImGui::ArrowButton("##BrowserForward", ImGuiDir_Right)) {
data.browser->GoForward();
}
ImGui::EndDisabled();
ImGui::SameLine();
if (ImGui::InputText("##BrowserURL", &data.url,
ImGuiInputTextFlags_EnterReturnsTrue)) {
data.browser->GetMainFrame()->LoadURL(data.url);
}
ImGui::SameLine();
if (ImGui::Button("Show DevTools")) {
CefWindowInfo windowInfo{};
CefBrowserSettings settings{};
data.browser->GetHost()->ShowDevTools(
windowInfo, DevToolsHandler::GetInstance(), settings, CefPoint());
}
ImVec2 size = ImGui::GetContentRegionAvail();
// Resize the browser view if needed
if (size.x > 0 && size.y > 0 &&
(data.view_width != static_cast<int>(size.x) ||
data.view_height != static_cast<int>(size.y))) {
browser_store_->GetBrowserHandle(identifier)
->SetSize(static_cast<int>(size.x), static_cast<int>(size.y));
}
ImVec2 cursor_pos = ImGui::GetCursorScreenPos();
bool is_focused = ImGui::IsWindowFocused();
if (is_focused) {
focused_browser = &data;
browser_view_x = static_cast<int>(cursor_pos.x);
browser_view_y = static_cast<int>(cursor_pos.y);
data.browser->GetHost()->SetFocus(true);
}
// Render the browser texture
ImGui::Image((ImTextureID)(intptr_t)data.texture_id,
ImVec2((float)data.view_width, (float)data.view_height));
}
ImGui::End();
ImGui::PopStyleVar();
}
GLEQevent event;
while (gleqNextEvent(&event)) {
switch (event.type) {
case GLEQ_CURSOR_MOVED:
case GLEQ_BUTTON_PRESSED:
case GLEQ_SCROLLED:
case GLEQ_BUTTON_RELEASED:
if (focused_browser) {
CefMouseEvent cef_event;
if (event.type == GLEQ_CURSOR_MOVED) {
cef_event.x = event.pos.x - browser_view_x;
cef_event.y = event.pos.y - browser_view_y;
focused_browser->browser->GetHost()->SendMouseMoveEvent(cef_event,
false);
} else if (event.type == GLEQ_SCROLLED) {
double xpos, ypos;
glfwGetCursorPos(window_, &xpos, &ypos);
cef_event.x = static_cast<int>(xpos) - browser_view_x;
cef_event.y = static_cast<int>(ypos) - browser_view_y;
static const int scrollbarPixelsPerTick = 20;
int scroll_x =
static_cast<int>(event.scroll.x * scrollbarPixelsPerTick);
int scroll_y =
static_cast<int>(event.scroll.y * scrollbarPixelsPerTick);
focused_browser->browser->GetHost()->SendMouseWheelEvent(
cef_event, scroll_x, scroll_y);
} else {
double xpos, ypos;
glfwGetCursorPos(window_, &xpos, &ypos);
cef_event.x = static_cast<int>(xpos) - browser_view_x;
cef_event.y = static_cast<int>(ypos) - browser_view_y;
cef_event.modifiers = glfw_mods_to_cef_mods(event.mouse.mods);
std::optional<CefBrowserHost::MouseButtonType> cef_button =
glfw_button_to_cef_button(event.mouse.button);
if (cef_button.has_value()) {
focused_browser->browser->GetHost()->SendMouseClickEvent(
cef_event, cef_button.value(),
event.type == GLEQ_BUTTON_RELEASED, 1);
}
}
}
break;
case GLEQ_KEY_PRESSED:
case GLEQ_KEY_RELEASED:
if (focused_browser) {
CefKeyEvent cef_event;
cef_event.windows_key_code =
glfw_key_to_cef_key(event.keyboard.key);
cef_event.native_key_code = event.keyboard.scancode;
cef_event.modifiers = glfw_mods_to_cef_mods(event.keyboard.mods);
cef_event.is_system_key = false;
if (event.type == GLEQ_KEY_PRESSED) {
cef_event.type = KEYEVENT_RAWKEYDOWN;
focused_browser->browser->GetHost()->SendKeyEvent(cef_event);
} else {
cef_event.type = KEYEVENT_KEYUP;
focused_browser->browser->GetHost()->SendKeyEvent(cef_event);
}
}
break;
case GLEQ_CODEPOINT_INPUT:
if (focused_browser) {
CefKeyEvent cef_event;
cef_event.type = KEYEVENT_CHAR;
cef_event.windows_key_code = 0;
cef_event.native_key_code = 0;
cef_event.modifiers = 0;
cef_event.is_system_key = false;
cef_event.unmodified_character = event.codepoint;
cef_event.character = event.codepoint;
focused_browser->browser->GetHost()->SendKeyEvent(cef_event);
}
break;
default:
break;
}
gleqFreeEvent(&event);
}
ImGui::End();
ImGui::Render();
int display_w, display_h;
glfwGetFramebufferSize(window_, &display_w, &display_h);
glViewport(0, 0, display_w, display_h);
glClearColor(clear_color.x * clear_color.w, clear_color.y * clear_color.w,
clear_color.z * clear_color.w, clear_color.w);
glClear(GL_COLOR_BUFFER_BIT);
ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData());
glfwSwapBuffers(window_);
}
ImGui_ImplOpenGL3_Shutdown();
ImGui_ImplGlfw_Shutdown();
ImGui::DestroyContext();
glfwDestroyWindow(window_);
glfwTerminate();
}
void DevRenderer::Close() {
// glfwSetWindowShouldClose(window_, GLFW_TRUE);
}
#ifndef LKCEF_DEV_RENDERER_HPP
#define LKCEF_DEV_RENDERER_HPP
#include "include/cef_app.h"
#include "browser_handle.hpp"
#define GL_SILENCE_DEPRECATION
#include <GLFW/glfw3.h> // Will drag system OpenGL headers
#define GLFW_EXPOSE_NATIVE_COCOA
//#define GLFW_NATIVE_INCLUDE_NONE
#include <GLFW/glfw3native.h>
class DevRenderer: public CefBaseRefCounted {
public:
DevRenderer(CefRefPtr<BrowserStore> browser_store);
void Run();
void Close();
void OnTitleChange(CefRefPtr<CefBrowser> browser,
const CefString &title);
void OnLoadingStateChange(CefRefPtr<CefBrowser> browser,
bool isLoading,
bool canGoBack,
bool canGoForward);
void OnAfterCreated(CefRefPtr<CefBrowser> browser);
void OnPaint(CefRefPtr<CefBrowser> browser,
CefRenderHandler::PaintElementType type,
const CefRenderHandler::RectList& dirtyRects,
const void* buffer,
int width,
int height);
void OnBeforeClose(CefRefPtr<CefBrowser> browser);
void* getNativeWindowHandle() const {
return glfwGetCocoaWindow(window_);
}
private:
struct BrowserData{
CefRefPtr<CefBrowser> browser;
unsigned int texture_id;
int view_width;
int view_height;
std::string title;
std::string url;
};
GLFWwindow* window_ = nullptr;
std::unordered_map<int, BrowserData> browser_data_;
CefRefPtr<BrowserStore> browser_store_;
IMPLEMENT_REFCOUNTING(DevRenderer);
};
#endif // LKCEF_DEV_RENDERER_HPP
int main() {
return 0;
}
/*
* GLEQ - A basic event queue for GLFW 3
* Copyright © Camilla Löwy <elmindreda@glfw.org>
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would
* be appreciated but is not required.
*
* 2. Altered source versions must be plainly marked as such, and must not
* be misrepresented as being the original software.
*
* 3. This notice may not be removed or altered from any source
* distribution.
*/
#ifndef GLEQ_HEADER_FILE
#define GLEQ_HEADER_FILE
#include <GLFW/glfw3.h>
#ifdef GLEQ_STATIC
#define GLEQDEF static
#else
#define GLEQDEF extern
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef enum
{
GLEQ_NONE,
GLEQ_WINDOW_MOVED,
GLEQ_WINDOW_RESIZED,
GLEQ_WINDOW_CLOSED,
GLEQ_WINDOW_REFRESH,
GLEQ_WINDOW_FOCUSED,
GLEQ_WINDOW_DEFOCUSED,
GLEQ_WINDOW_ICONIFIED,
GLEQ_WINDOW_UNICONIFIED,
GLEQ_FRAMEBUFFER_RESIZED,
GLEQ_BUTTON_PRESSED,
GLEQ_BUTTON_RELEASED,
GLEQ_CURSOR_MOVED,
GLEQ_CURSOR_ENTERED,
GLEQ_CURSOR_LEFT,
GLEQ_SCROLLED,
GLEQ_KEY_PRESSED,
GLEQ_KEY_REPEATED,
GLEQ_KEY_RELEASED,
GLEQ_CODEPOINT_INPUT,
GLEQ_MONITOR_CONNECTED,
GLEQ_MONITOR_DISCONNECTED,
#if GLFW_VERSION_MINOR >= 1
GLEQ_FILE_DROPPED,
#endif
#if GLFW_VERSION_MINOR >= 2
GLEQ_JOYSTICK_CONNECTED,
GLEQ_JOYSTICK_DISCONNECTED,
#endif
#if GLFW_VERSION_MINOR >= 3
GLEQ_WINDOW_MAXIMIZED,
GLEQ_WINDOW_UNMAXIMIZED,
GLEQ_WINDOW_SCALE_CHANGED,
#endif
} GLEQtype;
typedef struct GLEQevent
{
GLEQtype type;
union {
GLFWwindow* window;
GLFWmonitor* monitor;
int joystick;
};
union {
struct {
int x;
int y;
} pos;
struct {
int width;
int height;
} size;
struct {
double x;
double y;
} scroll;
struct {
int key;
int scancode;
int mods;
} keyboard;
struct {
int button;
int mods;
} mouse;
unsigned int codepoint;
#if GLFW_VERSION_MINOR >= 1
struct {
char** paths;
int count;
} file;
#endif
#if GLFW_VERSION_MINOR >= 3
struct {
float x;
float y;
} scale;
#endif
};
} GLEQevent;
GLEQDEF void gleqInit(void);
GLEQDEF void gleqTrackWindow(GLFWwindow* window);
GLEQDEF int gleqNextEvent(GLEQevent* event);
GLEQDEF void gleqFreeEvent(GLEQevent* event);
#ifdef __cplusplus
}
#endif
#ifdef GLEQ_IMPLEMENTATION
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#ifndef GLEQ_CAPACITY
#define GLEQ_CAPACITY 1024
#endif
static struct
{
GLEQevent events[GLEQ_CAPACITY];
size_t head;
size_t tail;
} gleq_queue = { {}, 0, 0 };
static char* gleq_strdup(const char* string)
{
const size_t size = strlen(string) + 1;
char* result = (char*) malloc(size);
memcpy(result, string, size);
return result;
}
static GLEQevent* gleq_new_event(void)
{
GLEQevent* event = gleq_queue.events + gleq_queue.head;
gleq_queue.head = (gleq_queue.head + 1) % GLEQ_CAPACITY;
assert(gleq_queue.head != gleq_queue.tail);
memset(event, 0, sizeof(GLEQevent));
return event;
}
static void gleq_window_pos_callback(GLFWwindow* window, int x, int y)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_WINDOW_MOVED;
event->window = window;
event->pos.x = x;
event->pos.y = y;
}
static void gleq_window_size_callback(GLFWwindow* window, int width, int height)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_WINDOW_RESIZED;
event->window = window;
event->size.width = width;
event->size.height = height;
}
static void gleq_window_close_callback(GLFWwindow* window)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_WINDOW_CLOSED;
event->window = window;
}
static void gleq_window_refresh_callback(GLFWwindow* window)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_WINDOW_REFRESH;
event->window = window;
}
static void gleq_window_focus_callback(GLFWwindow* window, int focused)
{
GLEQevent* event = gleq_new_event();
event->window = window;
if (focused)
event->type = GLEQ_WINDOW_FOCUSED;
else
event->type = GLEQ_WINDOW_DEFOCUSED;
}
static void gleq_window_iconify_callback(GLFWwindow* window, int iconified)
{
GLEQevent* event = gleq_new_event();
event->window = window;
if (iconified)
event->type = GLEQ_WINDOW_ICONIFIED;
else
event->type = GLEQ_WINDOW_UNICONIFIED;
}
static void gleq_framebuffer_size_callback(GLFWwindow* window, int width, int height)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_FRAMEBUFFER_RESIZED;
event->window = window;
event->size.width = width;
event->size.height = height;
}
static void gleq_mouse_button_callback(GLFWwindow* window, int button, int action, int mods)
{
GLEQevent* event = gleq_new_event();
event->window = window;
event->mouse.button = button;
event->mouse.mods = mods;
if (action == GLFW_PRESS)
event->type = GLEQ_BUTTON_PRESSED;
else if (action == GLFW_RELEASE)
event->type = GLEQ_BUTTON_RELEASED;
}
static void gleq_cursor_pos_callback(GLFWwindow* window, double x, double y)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_CURSOR_MOVED;
event->window = window;
event->pos.x = (int) x;
event->pos.y = (int) y;
}
static void gleq_cursor_enter_callback(GLFWwindow* window, int entered)
{
GLEQevent* event = gleq_new_event();
event->window = window;
if (entered)
event->type = GLEQ_CURSOR_ENTERED;
else
event->type = GLEQ_CURSOR_LEFT;
}
static void gleq_scroll_callback(GLFWwindow* window, double x, double y)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_SCROLLED;
event->window = window;
event->scroll.x = x;
event->scroll.y = y;
}
static void gleq_key_callback(GLFWwindow* window, int key, int scancode, int action, int mods)
{
GLEQevent* event = gleq_new_event();
event->window = window;
event->keyboard.key = key;
event->keyboard.scancode = scancode;
event->keyboard.mods = mods;
if (action == GLFW_PRESS)
event->type = GLEQ_KEY_PRESSED;
else if (action == GLFW_RELEASE)
event->type = GLEQ_KEY_RELEASED;
else if (action == GLFW_REPEAT)
event->type = GLEQ_KEY_REPEATED;
}
static void gleq_char_callback(GLFWwindow* window, unsigned int codepoint)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_CODEPOINT_INPUT;
event->window = window;
event->codepoint = codepoint;
}
static void gleq_monitor_callback(GLFWmonitor* monitor, int action)
{
GLEQevent* event = gleq_new_event();
event->monitor = monitor;
if (action == GLFW_CONNECTED)
event->type = GLEQ_MONITOR_CONNECTED;
else if (action == GLFW_DISCONNECTED)
event->type = GLEQ_MONITOR_DISCONNECTED;
}
#if GLFW_VERSION_MINOR >= 1
static void gleq_file_drop_callback(GLFWwindow* window, int count, const char** paths)
{
GLEQevent* event = gleq_new_event();
event->type = GLEQ_FILE_DROPPED;
event->window = window;
event->file.paths = (char**) malloc(count * sizeof(char*));
event->file.count = count;
while (count--)
event->file.paths[count] = gleq_strdup(paths[count]);
}
#endif
#if GLFW_VERSION_MINOR >= 2
static void gleq_joystick_callback(int jid, int action)
{
GLEQevent* event = gleq_new_event();
event->joystick = jid;
if (action == GLFW_CONNECTED)
event->type = GLEQ_JOYSTICK_CONNECTED;
else if (action == GLFW_DISCONNECTED)
event->type = GLEQ_JOYSTICK_DISCONNECTED;
}
#endif
#if GLFW_VERSION_MINOR >= 3
static void gleq_window_maximize_callback(GLFWwindow* window, int maximized)
{
GLEQevent* event = gleq_new_event();
event->window = window;
if (maximized)
event->type = GLEQ_WINDOW_MAXIMIZED;
else
event->type = GLEQ_WINDOW_UNMAXIMIZED;
}
static void gleq_window_content_scale_callback(GLFWwindow* window, float xscale, float yscale)
{
GLEQevent* event = gleq_new_event();
event->window = window;
event->type = GLEQ_WINDOW_SCALE_CHANGED;
event->scale.x = xscale;
event->scale.y = yscale;
}
#endif
GLEQDEF void gleqInit(void)
{
glfwSetMonitorCallback(gleq_monitor_callback);
#if GLFW_VERSION_MINOR >= 2
glfwSetJoystickCallback(gleq_joystick_callback);
#endif
}
GLEQDEF void gleqTrackWindow(GLFWwindow* window)
{
glfwSetWindowPosCallback(window, gleq_window_pos_callback);
glfwSetWindowSizeCallback(window, gleq_window_size_callback);
glfwSetWindowCloseCallback(window, gleq_window_close_callback);
glfwSetWindowRefreshCallback(window, gleq_window_refresh_callback);
glfwSetWindowFocusCallback(window, gleq_window_focus_callback);
glfwSetWindowIconifyCallback(window, gleq_window_iconify_callback);
glfwSetFramebufferSizeCallback(window, gleq_framebuffer_size_callback);
glfwSetMouseButtonCallback(window, gleq_mouse_button_callback);
glfwSetCursorPosCallback(window, gleq_cursor_pos_callback);
glfwSetCursorEnterCallback(window, gleq_cursor_enter_callback);
glfwSetScrollCallback(window, gleq_scroll_callback);
glfwSetKeyCallback(window, gleq_key_callback);
glfwSetCharCallback(window, gleq_char_callback);
#if GLFW_VERSION_MINOR >= 1
glfwSetDropCallback(window, gleq_file_drop_callback);
#endif
#if GLFW_VERSION_MINOR >= 3
glfwSetWindowMaximizeCallback(window, gleq_window_maximize_callback);
glfwSetWindowContentScaleCallback(window, gleq_window_content_scale_callback);
#endif
}
GLEQDEF int gleqNextEvent(GLEQevent* event)
{
memset(event, 0, sizeof(GLEQevent));
if (gleq_queue.head != gleq_queue.tail)
{
*event = gleq_queue.events[gleq_queue.tail];
gleq_queue.tail = (gleq_queue.tail + 1) % GLEQ_CAPACITY;
}
return event->type != GLEQ_NONE;
}
GLEQDEF void gleqFreeEvent(GLEQevent* event)
{
#if GLFW_VERSION_MINOR >= 1
if (event->type == GLEQ_FILE_DROPPED)
{
while (event->file.count--)
free(event->file.paths[event->file.count]);
free(event->file.paths);
}
#endif
memset(event, 0, sizeof(GLEQevent));
}
#endif /* GLEQ_IMPLEMENTATION */
#endif /* GLEQ_HEADER_FILE */
#include "handler.hpp"
#include <iostream>
#include "include/base/cef_callback.h"
#include "include/cef_parser.h"
#include "include/views/cef_browser_view.h"
#include "include/wrapper/cef_closure_task.h"
#include "include/wrapper/cef_helpers.h"
DevToolsHandler* g_dev_instance = nullptr;
DevToolsHandler::DevToolsHandler() {
g_dev_instance = this;
}
DevToolsHandler::~DevToolsHandler() {
g_dev_instance = nullptr;
}
DevToolsHandler* DevToolsHandler::GetInstance() {
return g_dev_instance;
}
AgentHandler* g_instance = nullptr;
AgentHandler::AgentHandler(CefRefPtr<BrowserStore> browser_store,
CefRefPtr<DevRenderer> dev_renderer)
: browser_store_(std::move(browser_store)),
dev_renderer_(std::move(dev_renderer)) {
g_instance = this;
}
AgentHandler::~AgentHandler() {
g_instance = nullptr;
}
AgentHandler* AgentHandler::GetInstance() {
return g_instance;
}
void AgentHandler::OnTitleChange(CefRefPtr<CefBrowser> browser,
const CefString& title) {
CEF_REQUIRE_UI_THREAD();
if (dev_renderer_)
dev_renderer_->OnTitleChange(browser, title);
}
void AgentHandler::OnPaint(CefRefPtr<CefBrowser> browser,
PaintElementType type,
const RectList& dirtyRects,
const void* buffer,
int width,
int height) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
CefRefPtr<BrowserHandle> handle =
browser_store_->browser_handles_[identifier];
if (handle->paint_callback_)
handle->paint_callback_(dirtyRects, buffer, width, height);
if (dev_renderer_)
dev_renderer_->OnPaint(browser, type, dirtyRects, buffer, width, height);
}
void AgentHandler::GetViewRect(CefRefPtr<CefBrowser> browser, CefRect& rect) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
CefRefPtr<BrowserHandle>& handle =
browser_store_->browser_handles_[identifier];
rect.Set(0, 0, handle->GetWidth(), handle->GetHeight());
}
void AgentHandler::OnAudioStreamPacket(CefRefPtr<CefBrowser> browser,
const float** data,
int frames,
int64_t pts) {
// std::cout << "OnAudioStreamPacket" << std::endl;
}
void AgentHandler::OnAudioStreamStarted(CefRefPtr<CefBrowser> browser,
const CefAudioParameters& params,
int channels) {}
void AgentHandler::OnAudioStreamStopped(CefRefPtr<CefBrowser> browser) {}
void AgentHandler::OnAudioStreamError(CefRefPtr<CefBrowser> browser,
const CefString& message) {}
bool AgentHandler::OnBeforePopup(CefRefPtr<CefBrowser> browser,
CefRefPtr<CefFrame> frame,
const CefString& target_url,
const CefString& target_frame_name,
WindowOpenDisposition target_disposition,
bool user_gesture,
const CefPopupFeatures& popupFeatures,
CefWindowInfo& windowInfo,
CefRefPtr<CefClient>& client,
CefBrowserSettings& settings,
CefRefPtr<CefDictionaryValue>& extra_info,
bool* no_javascript_access) {
browser->GetMainFrame()->LoadURL(target_url);
return true;
}
void AgentHandler::OnAfterCreated(CefRefPtr<CefBrowser> browser) {
CEF_REQUIRE_UI_THREAD();
if (browser->IsPopup()) {
return;
}
int identifier = browser->GetIdentifier();
CefRefPtr<BrowserHandle> handle = browser_store_->pending_handles_.front();
browser_store_->pending_handles_.pop_front();
handle->browser_ = browser;
browser_store_->browser_handles_[identifier] = handle;
if (handle->created_callback_)
handle->created_callback_();
if (dev_renderer_)
dev_renderer_->OnAfterCreated(browser);
}
bool AgentHandler::DoClose(CefRefPtr<CefBrowser> browser) {
CEF_REQUIRE_UI_THREAD();
int identifier = browser->GetIdentifier();
CefRefPtr<BrowserHandle> handle =
browser_store_->browser_handles_[identifier];
browser_store_->browser_handles_.erase(identifier);
if (handle->close_callback_)
handle->close_callback_();
return false;
}
void AgentHandler::OnBeforeClose(CefRefPtr<CefBrowser> browser) {
CEF_REQUIRE_UI_THREAD();
if (dev_renderer_)
dev_renderer_->OnBeforeClose(browser);
}
void AgentHandler::OnLoadingStateChange(CefRefPtr<CefBrowser> browser,
bool isLoading,
bool canGoBack,
bool canGoForward) {
CEF_REQUIRE_UI_THREAD();
if (dev_renderer_)
dev_renderer_->OnLoadingStateChange(browser, isLoading, canGoBack,
canGoForward);
}
void AgentHandler::CloseAllBrowsers(bool force_close) {
if (!CefCurrentlyOn(TID_UI)) {
// Execute on the UI thread.
CefPostTask(TID_UI, base::BindOnce(&AgentHandler::CloseAllBrowsers, this,
force_close));
return;
}
if (browser_store_->browser_handles_.empty()) {
return;
}
for (const auto& pair : browser_store_->browser_handles_) {
pair.second->browser_->GetHost()->CloseBrowser(force_close);
}
}
#if !defined(OS_MAC)
void AgentHandler::PlatformShowWindow(CefRefPtr<CefBrowser> browser) {
NOTIMPLEMENTED();
}
#endif
#ifndef LKCEF_HANDLER_HPP
#define LKCEF_HANDLER_HPP
#include <list>
#include "dev_renderer.hpp"
#include "browser_handle.hpp"
#include "include/cef_client.h"
#include "include/wrapper/cef_helpers.h"
class DevToolsHandler : public CefClient {
public:
DevToolsHandler();
~DevToolsHandler();
static DevToolsHandler* GetInstance();
private:
IMPLEMENT_REFCOUNTING(DevToolsHandler);
};
class AgentHandler : public CefClient,
public CefDisplayHandler,
public CefRenderHandler,
public CefAudioHandler,
public CefLifeSpanHandler,
public CefLoadHandler {
public:
AgentHandler(CefRefPtr<BrowserStore> browser_store, CefRefPtr<DevRenderer> dev_renderer);
~AgentHandler();
static AgentHandler* GetInstance();
CefRefPtr<CefDisplayHandler> GetDisplayHandler() override { return this; }
CefRefPtr<CefRenderHandler> GetRenderHandler() override { return this; }
CefRefPtr<CefAudioHandler> GetAudioHandler() override { return this; }
CefRefPtr<CefLifeSpanHandler> GetLifeSpanHandler() override { return this; }
CefRefPtr<CefLoadHandler> GetLoadHandler() override { return this; }
// CefDisplayHandler methods
void OnTitleChange(CefRefPtr<CefBrowser> browser,
const CefString& title) override;
// CefRenderHandler methods
void OnPaint(CefRefPtr<CefBrowser> browser,
PaintElementType type,
const RectList& dirtyRects,
const void* buffer,
int width,
int height) override;
void GetViewRect(CefRefPtr<CefBrowser> browser, CefRect& rect) override;
// CefAudioHandler methods
void OnAudioStreamPacket(CefRefPtr<CefBrowser> browser,
const float** data,
int frames,
int64_t pts) override;
void OnAudioStreamStarted(CefRefPtr<CefBrowser> browser,
const CefAudioParameters& params,
int channels) override;
void OnAudioStreamStopped(CefRefPtr<CefBrowser> browser) override;
void OnAudioStreamError(CefRefPtr<CefBrowser> browser,
const CefString& message) override;
// CefLifeSpanHandler methods
bool OnBeforePopup(CefRefPtr<CefBrowser> browser,
CefRefPtr<CefFrame> frame,
const CefString& target_url,
const CefString& target_frame_name,
WindowOpenDisposition target_disposition,
bool user_gesture,
const CefPopupFeatures& popupFeatures,
CefWindowInfo& windowInfo,
CefRefPtr<CefClient>& client,
CefBrowserSettings& settings,
CefRefPtr<CefDictionaryValue>& extra_info,
bool* no_javascript_access) override;
void OnAfterCreated(CefRefPtr<CefBrowser> browser) override;
bool DoClose(CefRefPtr<CefBrowser> browser) override;
void OnBeforeClose(CefRefPtr<CefBrowser> browser) override;
// CefLoadHandler methods
void OnLoadingStateChange(CefRefPtr<CefBrowser> browser,
bool isLoading,
bool canGoBack,
bool canGoForward) override;
void CloseAllBrowsers(bool force_close);
private:
CefRefPtr<BrowserStore> browser_store_;
CefRefPtr<DevRenderer> dev_renderer_;
IMPLEMENT_REFCOUNTING(AgentHandler);
};
#endif // LKCEF_HANDLER_HPP
#include "include/cef_app.h"
#include "include/wrapper/cef_library_loader.h"
int main(int argc, char* argv[]) {
CefScopedLibraryLoader library_loader;
if (!library_loader.LoadInHelper()) {
return 1;
}
CefMainArgs main_args(argc, argv);
return CefExecuteProcess(main_args, nullptr, nullptr);
}
#ifndef LKCEF_KEYBOARD_CODES_H
#define LKCEF_KEYBOARD_CODES_H
namespace WebCore {
// VK_LBUTTON (01) Left mouse button
// VK_RBUTTON (02) Right mouse button
// VK_CANCEL (03) Control-break processing
// VK_MBUTTON (04) Middle mouse button (three-button mouse)
// VK_XBUTTON1 (05)
// VK_XBUTTON2 (06)
// VK_BACK (08) BACKSPACE key
const int VK_BACK = 0x08;
// VK_TAB (09) TAB key
const int VK_TAB = 0x09;
// VK_CLEAR (0C) CLEAR key
const int VK_CLEAR = 0x0C;
// VK_RETURN (0D)
const int VK_RETURN = 0x0D;
// VK_SHIFT (10) SHIFT key
const int VK_SHIFT = 0x10;
// VK_CONTROL (11) CTRL key
const int VK_CONTROL = 0x11;
// VK_MENU (12) ALT key
const int VK_MENU = 0x12;
// VK_PAUSE (13) PAUSE key
const int VK_PAUSE = 0x13;
// VK_CAPITAL (14) CAPS LOCK key
const int VK_CAPITAL = 0x14;
// VK_KANA (15) Input Method Editor (IME) Kana mode
const int VK_KANA = 0x15;
// VK_HANGUEL (15) IME Hanguel mode (maintained for compatibility; use
// VK_HANGUL) VK_HANGUL (15) IME Hangul mode
const int VK_HANGUL = 0x15;
// VK_JUNJA (17) IME Junja mode
const int VK_JUNJA = 0x17;
// VK_FINAL (18) IME final mode
const int VK_FINAL = 0x18;
// VK_HANJA (19) IME Hanja mode
const int VK_HANJA = 0x19;
// VK_KANJI (19) IME Kanji mode
const int VK_KANJI = 0x19;
// VK_ESCAPE (1B) ESC key
const int VK_ESCAPE = 0x1B;
// VK_CONVERT (1C) IME convert
const int VK_CONVERT = 0x1C;
// VK_NONCONVERT (1D) IME nonconvert
const int VK_NONCONVERT = 0x1D;
// VK_ACCEPT (1E) IME accept
const int VK_ACCEPT = 0x1E;
// VK_MODECHANGE (1F) IME mode change request
const int VK_MODECHANGE = 0x1F;
// VK_SPACE (20) SPACEBAR
const int VK_SPACE = 0x20;
// VK_PRIOR (21) PAGE UP key
const int VK_PRIOR = 0x21;
// VK_NEXT (22) PAGE DOWN key
const int VK_NEXT = 0x22;
// VK_END (23) END key
const int VK_END = 0x23;
// VK_HOME (24) HOME key
const int VK_HOME = 0x24;
// VK_LEFT (25) LEFT ARROW key
const int VK_LEFT = 0x25;
// VK_UP (26) UP ARROW key
const int VK_UP = 0x26;
// VK_RIGHT (27) RIGHT ARROW key
const int VK_RIGHT = 0x27;
// VK_DOWN (28) DOWN ARROW key
const int VK_DOWN = 0x28;
// VK_SELECT (29) SELECT key
const int VK_SELECT = 0x29;
// VK_PRINT (2A) PRINT key
const int VK_PRINT = 0x2A;
// VK_EXECUTE (2B) EXECUTE key
const int VK_EXECUTE = 0x2B;
// VK_SNAPSHOT (2C) PRINT SCREEN key
const int VK_SNAPSHOT = 0x2C;
// VK_INSERT (2D) INS key
const int VK_INSERT = 0x2D;
// VK_DELETE (2E) DEL key
const int VK_DELETE = 0x2E;
// VK_HELP (2F) HELP key
const int VK_HELP = 0x2F;
// (30) 0 key
const int VK_0 = 0x30;
// (31) 1 key
const int VK_1 = 0x31;
// (32) 2 key
const int VK_2 = 0x32;
// (33) 3 key
const int VK_3 = 0x33;
// (34) 4 key
const int VK_4 = 0x34;
// (35) 5 key
const int VK_5 = 0x35;
// (36) 6 key
const int VK_6 = 0x36;
// (37) 7 key
const int VK_7 = 0x37;
// (38) 8 key
const int VK_8 = 0x38;
// (39) 9 key
const int VK_9 = 0x39;
// (41) A key
const int VK_A = 0x41;
// (42) B key
const int VK_B = 0x42;
// (43) C key
const int VK_C = 0x43;
// (44) D key
const int VK_D = 0x44;
// (45) E key
const int VK_E = 0x45;
// (46) F key
const int VK_F = 0x46;
// (47) G key
const int VK_G = 0x47;
// (48) H key
const int VK_H = 0x48;
// (49) I key
const int VK_I = 0x49;
// (4A) J key
const int VK_J = 0x4A;
// (4B) K key
const int VK_K = 0x4B;
// (4C) L key
const int VK_L = 0x4C;
// (4D) M key
const int VK_M = 0x4D;
// (4E) N key
const int VK_N = 0x4E;
// (4F) O key
const int VK_O = 0x4F;
// (50) P key
const int VK_P = 0x50;
// (51) Q key
const int VK_Q = 0x51;
// (52) R key
const int VK_R = 0x52;
// (53) S key
const int VK_S = 0x53;
// (54) T key
const int VK_T = 0x54;
// (55) U key
const int VK_U = 0x55;
// (56) V key
const int VK_V = 0x56;
// (57) W key
const int VK_W = 0x57;
// (58) X key
const int VK_X = 0x58;
// (59) Y key
const int VK_Y = 0x59;
// (5A) Z key
const int VK_Z = 0x5A;
// VK_LWIN (5B) Left Windows key (Microsoft Natural keyboard)
const int VK_LWIN = 0x5B;
// VK_RWIN (5C) Right Windows key (Natural keyboard)
const int VK_RWIN = 0x5C;
// VK_APPS (5D) Applications key (Natural keyboard)
const int VK_APPS = 0x5D;
// VK_SLEEP (5F) Computer Sleep key
const int VK_SLEEP = 0x5F;
// VK_NUMPAD0 (60) Numeric keypad 0 key
const int VK_NUMPAD0 = 0x60;
// VK_NUMPAD1 (61) Numeric keypad 1 key
const int VK_NUMPAD1 = 0x61;
// VK_NUMPAD2 (62) Numeric keypad 2 key
const int VK_NUMPAD2 = 0x62;
// VK_NUMPAD3 (63) Numeric keypad 3 key
const int VK_NUMPAD3 = 0x63;
// VK_NUMPAD4 (64) Numeric keypad 4 key
const int VK_NUMPAD4 = 0x64;
// VK_NUMPAD5 (65) Numeric keypad 5 key
const int VK_NUMPAD5 = 0x65;
// VK_NUMPAD6 (66) Numeric keypad 6 key
const int VK_NUMPAD6 = 0x66;
// VK_NUMPAD7 (67) Numeric keypad 7 key
const int VK_NUMPAD7 = 0x67;
// VK_NUMPAD8 (68) Numeric keypad 8 key
const int VK_NUMPAD8 = 0x68;
// VK_NUMPAD9 (69) Numeric keypad 9 key
const int VK_NUMPAD9 = 0x69;
// VK_MULTIPLY (6A) Multiply key
const int VK_MULTIPLY = 0x6A;
// VK_ADD (6B) Add key
const int VK_ADD = 0x6B;
// VK_SEPARATOR (6C) Separator key
const int VK_SEPARATOR = 0x6C;
// VK_SUBTRACT (6D) Subtract key
const int VK_SUBTRACT = 0x6D;
// VK_DECIMAL (6E) Decimal key
const int VK_DECIMAL = 0x6E;
// VK_DIVIDE (6F) Divide key
const int VK_DIVIDE = 0x6F;
// VK_F1 (70) F1 key
const int VK_F1 = 0x70;
// VK_F2 (71) F2 key
const int VK_F2 = 0x71;
// VK_F3 (72) F3 key
const int VK_F3 = 0x72;
// VK_F4 (73) F4 key
const int VK_F4 = 0x73;
// VK_F5 (74) F5 key
const int VK_F5 = 0x74;
// VK_F6 (75) F6 key
const int VK_F6 = 0x75;
// VK_F7 (76) F7 key
const int VK_F7 = 0x76;
// VK_F8 (77) F8 key
const int VK_F8 = 0x77;
// VK_F9 (78) F9 key
const int VK_F9 = 0x78;
// VK_F10 (79) F10 key
const int VK_F10 = 0x79;
// VK_F11 (7A) F11 key
const int VK_F11 = 0x7A;
// VK_F12 (7B) F12 key
const int VK_F12 = 0x7B;
// VK_F13 (7C) F13 key
const int VK_F13 = 0x7C;
// VK_F14 (7D) F14 key
const int VK_F14 = 0x7D;
// VK_F15 (7E) F15 key
const int VK_F15 = 0x7E;
// VK_F16 (7F) F16 key
const int VK_F16 = 0x7F;
// VK_F17 (80H) F17 key
const int VK_F17 = 0x80;
// VK_F18 (81H) F18 key
const int VK_F18 = 0x81;
// VK_F19 (82H) F19 key
const int VK_F19 = 0x82;
// VK_F20 (83H) F20 key
const int VK_F20 = 0x83;
// VK_F21 (84H) F21 key
const int VK_F21 = 0x84;
// VK_F22 (85H) F22 key
const int VK_F22 = 0x85;
// VK_F23 (86H) F23 key
const int VK_F23 = 0x86;
// VK_F24 (87H) F24 key
const int VK_F24 = 0x87;
// VK_NUMLOCK (90) NUM LOCK key
const int VK_NUMLOCK = 0x90;
// VK_SCROLL (91) SCROLL LOCK key
const int VK_SCROLL = 0x91;
// VK_LSHIFT (A0) Left SHIFT key
const int VK_LSHIFT = 0xA0;
// VK_RSHIFT (A1) Right SHIFT key
const int VK_RSHIFT = 0xA1;
// VK_LCONTROL (A2) Left CONTROL key
const int VK_LCONTROL = 0xA2;
// VK_RCONTROL (A3) Right CONTROL key
const int VK_RCONTROL = 0xA3;
// VK_LMENU (A4) Left MENU key
const int VK_LMENU = 0xA4;
// VK_RMENU (A5) Right MENU key
const int VK_RMENU = 0xA5;
// VK_BROWSER_BACK (A6) Windows 2000/XP: Browser Back key
const int VK_BROWSER_BACK = 0xA6;
// VK_BROWSER_FORWARD (A7) Windows 2000/XP: Browser Forward key
const int VK_BROWSER_FORWARD = 0xA7;
// VK_BROWSER_REFRESH (A8) Windows 2000/XP: Browser Refresh key
const int VK_BROWSER_REFRESH = 0xA8;
// VK_BROWSER_STOP (A9) Windows 2000/XP: Browser Stop key
const int VK_BROWSER_STOP = 0xA9;
// VK_BROWSER_SEARCH (AA) Windows 2000/XP: Browser Search key
const int VK_BROWSER_SEARCH = 0xAA;
// VK_BROWSER_FAVORITES (AB) Windows 2000/XP: Browser Favorites key
const int VK_BROWSER_FAVORITES = 0xAB;
// VK_BROWSER_HOME (AC) Windows 2000/XP: Browser Start and Home key
const int VK_BROWSER_HOME = 0xAC;
// VK_VOLUME_MUTE (AD) Windows 2000/XP: Volume Mute key
const int VK_VOLUME_MUTE = 0xAD;
// VK_VOLUME_DOWN (AE) Windows 2000/XP: Volume Down key
const int VK_VOLUME_DOWN = 0xAE;
// VK_VOLUME_UP (AF) Windows 2000/XP: Volume Up key
const int VK_VOLUME_UP = 0xAF;
// VK_MEDIA_NEXT_TRACK (B0) Windows 2000/XP: Next Track key
const int VK_MEDIA_NEXT_TRACK = 0xB0;
// VK_MEDIA_PREV_TRACK (B1) Windows 2000/XP: Previous Track key
const int VK_MEDIA_PREV_TRACK = 0xB1;
// VK_MEDIA_STOP (B2) Windows 2000/XP: Stop Media key
const int VK_MEDIA_STOP = 0xB2;
// VK_MEDIA_PLAY_PAUSE (B3) Windows 2000/XP: Play/Pause Media key
const int VK_MEDIA_PLAY_PAUSE = 0xB3;
// VK_LAUNCH_MAIL (B4) Windows 2000/XP: Start Mail key
const int VK_MEDIA_LAUNCH_MAIL = 0xB4;
// VK_LAUNCH_MEDIA_SELECT (B5) Windows 2000/XP: Select Media key
const int VK_MEDIA_LAUNCH_MEDIA_SELECT = 0xB5;
// VK_LAUNCH_APP1 (B6) Windows 2000/XP: Start Application 1 key
const int VK_MEDIA_LAUNCH_APP1 = 0xB6;
// VK_LAUNCH_APP2 (B7) Windows 2000/XP: Start Application 2 key
const int VK_MEDIA_LAUNCH_APP2 = 0xB7;
// VK_OEM_1 (BA) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the ';:' key
const int VK_OEM_1 = 0xBA;
// VK_OEM_PLUS (BB) Windows 2000/XP: For any country/region, the '+' key
const int VK_OEM_PLUS = 0xBB;
// VK_OEM_COMMA (BC) Windows 2000/XP: For any country/region, the ',' key
const int VK_OEM_COMMA = 0xBC;
// VK_OEM_MINUS (BD) Windows 2000/XP: For any country/region, the '-' key
const int VK_OEM_MINUS = 0xBD;
// VK_OEM_PERIOD (BE) Windows 2000/XP: For any country/region, the '.' key
const int VK_OEM_PERIOD = 0xBE;
// VK_OEM_2 (BF) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the '/?' key
const int VK_OEM_2 = 0xBF;
// VK_OEM_3 (C0) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the '`~' key
const int VK_OEM_3 = 0xC0;
// VK_OEM_4 (DB) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the '[{' key
const int VK_OEM_4 = 0xDB;
// VK_OEM_5 (DC) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the '\|' key
const int VK_OEM_5 = 0xDC;
// VK_OEM_6 (DD) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the ']}' key
const int VK_OEM_6 = 0xDD;
// VK_OEM_7 (DE) Used for miscellaneous characters; it can vary by keyboard.
// Windows 2000/XP: For the US standard keyboard, the
// 'single-quote/double-quote' key
const int VK_OEM_7 = 0xDE;
// VK_OEM_8 (DF) Used for miscellaneous characters; it can vary by keyboard.
const int VK_OEM_8 = 0xDF;
// VK_OEM_102 (E2) Windows 2000/XP: Either the angle bracket key or the
// backslash key on the RT 102-key keyboard
const int VK_OEM_102 = 0xE2;
// VK_PROCESSKEY (E5) Windows 95/98/Me, Windows NT 4.0, Windows 2000/XP: IME
// PROCESS key
const int VK_PROCESSKEY = 0xE5;
// VK_PACKET (E7) Windows 2000/XP: Used to pass Unicode characters as if they
// were keystrokes. The VK_PACKET key is the low word of a 32-bit Virtual Key
// value used for non-keyboard input methods. For more information, see Remark
// in KEYBDINPUT,SendInput, WM_KEYDOWN, and WM_KEYUP
const int VK_PACKET = 0xE7;
// VK_ATTN (F6) Attn key
const int VK_ATTN = 0xF6;
// VK_CRSEL (F7) CrSel key
const int VK_CRSEL = 0xF7;
// VK_EXSEL (F8) ExSel key
const int VK_EXSEL = 0xF8;
// VK_EREOF (F9) Erase EOF key
const int VK_EREOF = 0xF9;
// VK_PLAY (FA) Play key
const int VK_PLAY = 0xFA;
// VK_ZOOM (FB) Zoom key
const int VK_ZOOM = 0xFB;
// VK_NONAME (FC) Reserved for future use
const int VK_NONAME = 0xFC;
// VK_PA1 (FD) PA1 key
const int VK_PA1 = 0xFD;
// VK_OEM_CLEAR (FE) Clear key
const int VK_OEM_CLEAR = 0xFE;
const int VK_UNKNOWN = 0;
} // namespace WebCore
#endif // LKCEF_KEYBOARD_CODES_H
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleExecutable</key>
<string>lkcef_app</string>
<key>CFBundleIdentifier</key>
<string>io.livekit.cef</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>lkcef-agents</string>
<key>CFBundlePackageType</key>
<string>APPL</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>LSEnvironment</key>
<dict>
<key>MallocNanoZone</key>
<string>0</string>
</dict>
<key>LSFileQuarantineEnabled</key>
<true/>
<key>LSMinimumSystemVersion</key>
<string>10.11.0</string>
<key>LSUIElement</key>
<string>1</string>
<key>NSSupportsAutomaticGraphicsSwitching</key>
<true/>
</dict>
</plist>
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
<string>${EXECUTABLE_NAME}</string>
<key>CFBundleExecutable</key>
<string>${EXECUTABLE_NAME}</string>
<key>CFBundleIdentifier</key>
<string>io.livekit.cef.helper${BUNDLE_ID_SUFFIX}</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>${PRODUCT_NAME}</string>
<key>CFBundlePackageType</key>
<string>APPL</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>LSEnvironment</key>
<dict>
<key>MallocNanoZone</key>
<string>0</string>
</dict>
<key>LSFileQuarantineEnabled</key>
<true/>
<key>LSMinimumSystemVersion</key>
<string>10.11.0</string>
<key>LSUIElement</key>
<string>1</string>
<key>NSSupportsAutomaticGraphicsSwitching</key>
<true/>
</dict>
</plist>
# flake8: noqa
import sys
print("cwd: ", sys.path[0])
sys.path.insert(0, "./Debug")
import lkcef_python as lkcef
print("lkcef __dict__: ", lkcef.__dict__)
print("BrowserImpl __dict__: ", lkcef.BrowserImpl.__dict__)
def _context_initialized():
opts = lkcef.BrowserOptions()
opts.framerate = 30
def _browser_created(browser_impl):
print("run_browser.py - Browser created")
opts.created_callback = _browser_created
def _on_paint(frame_data):
pass
opts.paint_callback = _on_paint
def _on_closed():
print("run_browser.py - Browser closed")
opts.close_callback = _on_closed
app.create_browser("http://www.livekit.io", opts)
print("run_browser.py - Context initialized")
opts = lkcef.AppOptions()
opts.dev_mode = True
opts.initialized_callback = _context_initialized
opts.framework_path = "/Users/theomonnom/livekit/agents/livekit-plugins/livekit-plugins-browser/cef/src/Debug/lkcef_app.app/Contents/Frameworks/Chromium Embedded Framework.framework"
opts.main_bundle_path = "/Users/theomonnom/livekit/agents/livekit-plugins/livekit-plugins-browser/cef/src/Debug/lkcef_app.app"
opts.subprocess_path = "/Users/theomonnom/livekit/agents/livekit-plugins/livekit-plugins-browser/cef/src/Debug/lkcef_app.app/Contents/Frameworks/lkcef Helper.app/Contents/MacOS/lkcef Helper"
app = lkcef.BrowserApp(opts)
app.run()
# livekit-plugins-cartesia
## 0.4.11
### Patch Changes
- Add string type support to model parameter - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
## 0.4.10
### Patch Changes
- Adding new model literals, updating default to sonic-2 - [#1627](https://github.com/livekit/agents/pull/1627) ([@longcw](https://github.com/longcw))
## 0.4.9
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- added a tts.prewarm method to start the connection pool early. - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- update pool configuration for deepgram and cartesia - [#1605](https://github.com/livekit/agents/pull/1605) ([@jayeshp19](https://github.com/jayeshp19))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.4.8
### Patch Changes
- feat: connection pooling. speeds up generation with STT/TTS providers - [#1538](https://github.com/livekit/agents/pull/1538) ([@davidzhao](https://github.com/davidzhao))
- remove update options from tts synthesis stream - [#1546](https://github.com/livekit/agents/pull/1546) ([@jayeshp19](https://github.com/jayeshp19))
## 0.4.7
### Patch Changes
- improved TTFB metrics for streaming TTS - [#1431](https://github.com/livekit/agents/pull/1431) ([@davidzhao](https://github.com/davidzhao))
## 0.4.6
### Patch Changes
- update Cartesia plugin default model and voice id - [#1346](https://github.com/livekit/agents/pull/1346) ([@noahlt](https://github.com/noahlt))
## 0.4.5
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.4.4
### Patch Changes
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
## 0.4.3
### Patch Changes
- add update_options to TTS - [#922](https://github.com/livekit/agents/pull/922) ([@theomonnom](https://github.com/theomonnom))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
## 0.4.2
### Patch Changes
- Add support for cartesia voice control - [#740](https://github.com/livekit/agents/pull/740) ([@bcherry](https://github.com/bcherry))
## 0.4.1
### Patch Changes
- Switch Cartesia to a sentence tokenizer and keep the same context id throughout. - [#608](https://github.com/livekit/agents/pull/608) ([@keepingitneil](https://github.com/keepingitneil))
Propagate segment_id through the basic sentence tokenizer
## 0.3.0
### Minor Changes
- cartesia: correctly add spaces & fix tests - [#591](https://github.com/livekit/agents/pull/591) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- fix log warnings & cartesia end of speech - [#603](https://github.com/livekit/agents/pull/603) ([@theomonnom](https://github.com/theomonnom))
- stt/tts: fix unread inputs when the input channel is closed - [#594](https://github.com/livekit/agents/pull/594) ([@theomonnom](https://github.com/theomonnom))
- Adds websockets streaming to Cartesia plugin - [#544](https://github.com/livekit/agents/pull/544) ([@sauhardjain](https://github.com/sauhardjain))
## 0.2.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.1.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins Cartesia
Agent Framework plugin for voice synthesis with [Cartesia](https://cartesia.ai/) API.
## Installation
```bash
pip install livekit-plugins-cartesia
```

You’ll need an API key from Cartesia. It can be set as an environment variable: `CARTESIA_API_KEY`.
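Below is a minimal usage sketch (not part of the original README). It assumes the package is installed as above, that `CARTESIA_API_KEY` is exported, and it relies only on the constructor defaults and the `stream()`/`prewarm()` methods shown in `tts.py` further down.

```python
# Minimal sketch, assuming CARTESIA_API_KEY is set in the environment.
from livekit.plugins import cartesia

# All arguments are optional; these mirror the defaults documented in tts.py below.
tts = cartesia.TTS(model="sonic-2", language="en", sample_rate=24000)

# Optionally open the websocket connection pool ahead of time.
tts.prewarm()

# stream() returns a SynthesizeStream that accepts incremental text and
# emits synthesized audio frames; see the SynthesizeStream class below.
stream = tts.stream()
```

`prewarm()` simply starts the websocket connection pool early (added in 0.4.9), so the first synthesis request does not pay the connection cost.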
## livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tts import TTS, ChunkedStream
from .version import __version__
__all__ = ["TTS", "ChunkedStream", "__version__"]
from livekit.agents import Plugin
from .log import logger
class CartesiaPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(CartesiaPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.cartesia")
from typing import Literal
TTSEncoding = Literal[
"pcm_s16le",
# Not yet supported
# "pcm_f32le",
# "pcm_mulaw",
# "pcm_alaw",
]
TTSModels = Literal["sonic", "sonic-2", "sonic-lite", "sonic-preview", "sonic-turbo"]
TTSLanguages = Literal["en", "es", "fr", "de", "pt", "zh", "ja"]
TTSDefaultVoiceId = "794f9389-aac1-45b6-b726-9d9369183238"
TTSVoiceSpeed = Literal["fastest", "fast", "normal", "slow", "slowest"]
TTSVoiceEmotion = Literal[
"anger:lowest",
"anger:low",
"anger",
"anger:high",
"anger:highest",
"positivity:lowest",
"positivity:low",
"positivity",
"positivity:high",
"positivity:highest",
"surprise:lowest",
"surprise:low",
"surprise",
"surprise:high",
"surprise:highest",
"sadness:lowest",
"sadness:low",
"sadness",
"sadness:high",
"sadness:highest",
"curiosity:lowest",
"curiosity:low",
"curiosity",
"curiosity:high",
"curiosity:highest",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import base64
import json
import os
import weakref
from dataclasses import dataclass
from typing import Any, Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tokenize,
tts,
utils,
)
from .log import logger
from .models import (
TTSDefaultVoiceId,
TTSEncoding,
TTSModels,
TTSVoiceEmotion,
TTSVoiceSpeed,
)
API_AUTH_HEADER = "X-API-Key"
API_VERSION_HEADER = "Cartesia-Version"
API_VERSION = "2024-06-10"
NUM_CHANNELS = 1
BUFFERED_WORDS_COUNT = 3
@dataclass
class _TTSOptions:
model: TTSModels | str
encoding: TTSEncoding
sample_rate: int
voice: str | list[float]
speed: TTSVoiceSpeed | float | None
emotion: list[TTSVoiceEmotion | str] | None
api_key: str
language: str
base_url: str
def get_http_url(self, path: str) -> str:
return f"{self.base_url}{path}"
def get_ws_url(self, path: str) -> str:
return f"{self.base_url.replace('http', 'ws', 1)}{path}"
class TTS(tts.TTS):
def __init__(
self,
*,
model: TTSModels | str = "sonic-2",
language: str = "en",
encoding: TTSEncoding = "pcm_s16le",
voice: str | list[float] = TTSDefaultVoiceId,
speed: TTSVoiceSpeed | float | None = None,
emotion: list[TTSVoiceEmotion | str] | None = None,
sample_rate: int = 24000,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
base_url: str = "https://api.cartesia.ai",
) -> None:
"""
Create a new instance of Cartesia TTS.
See https://docs.cartesia.ai/reference/web-socket/stream-speech/stream-speech for more details on the Cartesia API.
Args:
model (TTSModels, optional): The Cartesia TTS model to use. Defaults to "sonic-2".
language (str, optional): The language code for synthesis. Defaults to "en".
encoding (TTSEncoding, optional): The audio encoding format. Defaults to "pcm_s16le".
voice (str | list[float], optional): The voice ID or embedding array.
speed (TTSVoiceSpeed | float, optional): Voice Control - Speed (https://docs.cartesia.ai/user-guides/voice-control)
emotion (list[TTSVoiceEmotion], optional): Voice Control - Emotion (https://docs.cartesia.ai/user-guides/voice-control)
sample_rate (int, optional): The audio sample rate in Hz. Defaults to 24000.
api_key (str, optional): The Cartesia API key. If not provided, it will be read from the CARTESIA_API_KEY environment variable.
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
base_url (str, optional): The base URL for the Cartesia API. Defaults to "https://api.cartesia.ai".
"""
super().__init__(
capabilities=tts.TTSCapabilities(streaming=True),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
api_key = api_key or os.environ.get("CARTESIA_API_KEY")
if not api_key:
raise ValueError("CARTESIA_API_KEY must be set")
self._opts = _TTSOptions(
model=model,
language=language,
encoding=encoding,
sample_rate=sample_rate,
voice=voice,
speed=speed,
emotion=emotion,
api_key=api_key,
base_url=base_url,
)
self._session = http_session
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
max_session_duration=300,
mark_refreshed_on_get=True,
)
self._streams = weakref.WeakSet[SynthesizeStream]()
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
session = self._ensure_session()
url = self._opts.get_ws_url(
f"/tts/websocket?api_key={self._opts.api_key}&cartesia_version={API_VERSION}"
)
return await asyncio.wait_for(
session.ws_connect(url), self._conn_options.timeout
)
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
await ws.close()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def prewarm(self) -> None:
self._pool.prewarm()
def update_options(
self,
*,
model: TTSModels | str | None = None,
language: str | None = None,
voice: str | list[float] | None = None,
speed: TTSVoiceSpeed | float | None = None,
emotion: list[TTSVoiceEmotion | str] | None = None,
) -> None:
"""
Update the Text-to-Speech (TTS) configuration options.
This method allows updating the TTS settings, including model type, language, voice, speed,
and emotion. If any parameter is not provided, the existing value will be retained.
Args:
model (TTSModels, optional): The Cartesia TTS model to use. If not provided, the current model is kept.
language (str, optional): The language code for synthesis. If not provided, the current language is kept.
voice (str | list[float], optional): The voice ID or embedding array.
speed (TTSVoiceSpeed | float, optional): Voice Control - Speed (https://docs.cartesia.ai/user-guides/voice-control)
emotion (list[TTSVoiceEmotion], optional): Voice Control - Emotion (https://docs.cartesia.ai/user-guides/voice-control)
"""
self._opts.model = model or self._opts.model
self._opts.language = language or self._opts.language
self._opts.voice = voice or self._opts.voice
self._opts.speed = speed or self._opts.speed
if emotion is not None:
self._opts.emotion = emotion
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> ChunkedStream:
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> "SynthesizeStream":
stream = SynthesizeStream(
tts=self,
pool=self._pool,
opts=self._opts,
)
self._streams.add(stream)
return stream
async def aclose(self) -> None:
for stream in list(self._streams):
await stream.aclose()
self._streams.clear()
await self._pool.aclose()
await super().aclose()
class ChunkedStream(tts.ChunkedStream):
"""Synthesize chunked text using the bytes endpoint"""
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
session: aiohttp.ClientSession,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts, self._session = opts, session
async def _run(self) -> None:
request_id = utils.shortuuid()
bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate, num_channels=NUM_CHANNELS
)
json = _to_cartesia_options(self._opts)
json["transcript"] = self._input_text
headers = {
API_AUTH_HEADER: self._opts.api_key,
API_VERSION_HEADER: API_VERSION,
}
try:
async with self._session.post(
self._opts.get_http_url("/tts/bytes"),
headers=headers,
json=json,
timeout=aiohttp.ClientTimeout(
total=30,
sock_connect=self._conn_options.timeout,
),
) as resp:
resp.raise_for_status()
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for data, _ in resp.content.iter_chunks():
for frame in bstream.write(data):
emitter.push(frame)
for frame in bstream.flush():
emitter.push(frame)
emitter.flush()
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=None,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
class SynthesizeStream(tts.SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
opts: _TTSOptions,
pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
):
super().__init__(tts=tts)
self._opts, self._pool = opts, pool
self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
min_sentence_len=BUFFERED_WORDS_COUNT
).stream()
async def _run(self) -> None:
request_id = utils.shortuuid()
async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse):
base_pkt = _to_cartesia_options(self._opts)
async for ev in self._sent_tokenizer_stream:
token_pkt = base_pkt.copy()
token_pkt["context_id"] = request_id
token_pkt["transcript"] = ev.token + " "
token_pkt["continue"] = True
self._mark_started()
await ws.send_str(json.dumps(token_pkt))
end_pkt = base_pkt.copy()
end_pkt["context_id"] = request_id
end_pkt["transcript"] = " "
end_pkt["continue"] = False
await ws.send_str(json.dumps(end_pkt))
async def _input_task():
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
self._sent_tokenizer_stream.flush()
continue
self._sent_tokenizer_stream.push_text(data)
self._sent_tokenizer_stream.end_input()
async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise APIStatusError(
"Cartesia connection closed unexpectedly",
request_id=request_id,
)
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected Cartesia message type %s", msg.type)
continue
data = json.loads(msg.data)
segment_id = data.get("context_id")
emitter._segment_id = segment_id
if data.get("data"):
b64data = base64.b64decode(data["data"])
for frame in audio_bstream.write(b64data):
emitter.push(frame)
elif data.get("done"):
for frame in audio_bstream.flush():
emitter.push(frame)
emitter.flush()
if segment_id == request_id:
# we're not going to receive more frames, end stream
break
else:
logger.error("unexpected Cartesia message %s", data)
async with self._pool.connection() as ws:
tasks = [
asyncio.create_task(_input_task()),
asyncio.create_task(_sentence_stream_task(ws)),
asyncio.create_task(_recv_task(ws)),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
def _to_cartesia_options(opts: _TTSOptions) -> dict[str, Any]:
voice: dict[str, Any] = {}
if isinstance(opts.voice, str):
voice["mode"] = "id"
voice["id"] = opts.voice
else:
voice["mode"] = "embedding"
voice["embedding"] = opts.voice
voice_controls: dict = {}
if opts.speed is not None:
voice_controls["speed"] = opts.speed
if opts.emotion is not None:
voice_controls["emotion"] = opts.emotion
if voice_controls:
voice["__experimental_controls"] = voice_controls
return {
"model_id": opts.model,
"voice": voice,
"output_format": {
"container": "raw",
"encoding": opts.encoding,
"sample_rate": opts.sample_rate,
},
"language": opts.language,
}
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.4.11"
{
"name": "livekit-plugins-cartesia",
"private": true,
"version": "0.4.11"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "cartesia", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-cartesia",
version=about["__version__"],
description="LiveKit Agents Plugin for Cartesia",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0"],
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# LiveKit Plugins Clova
Agent Framework plugin for speech-to-text with [Clova](https://api.ncloud-docs.com/docs/)'s API. Currently supports non-streaming (batch) speech-to-text.
## Installation
```bash
pip install livekit-plugins-clova
```

You need an invoke URL and a secret key from Naver Cloud Platform -> Clova Speech, set as the environment variables `CLOVA_STT_INVOKE_URL` and `CLOVA_STT_SECRET_KEY`.
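Assuming those environment variables are set, a minimal construction sketch looks like this (the import path follows this repo's package layout; the language value is just an example):

```python
from livekit.plugins import clova

# Reads CLOVA_STT_SECRET_KEY and CLOVA_STT_INVOKE_URL from the environment.
stt = clova.STT(language="ko-KR")
```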
## livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/__init__.py
```py
from .stt import STT
from .version import __version__
__all__ = [
"STT",
"__version__",
]
from livekit.agents import Plugin
class ClovaSTTPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__)
def download_files(self):
pass
Plugin.register_plugin(ClovaSTTPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
```
import io
from pydub import AudioSegment
def resample_audio(audio_bytes, original_sample_rate, target_sample_rate):
resampled_audio = AudioSegment.from_raw(
io.BytesIO(audio_bytes),
sample_width=2,
frame_rate=original_sample_rate,
channels=1,
).set_frame_rate(target_sample_rate)
return resampled_audio.raw_data
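As a quick illustration (the input buffer is made up), downsampling 48 kHz mono 16-bit PCM to Clova's expected 16 kHz could look like:

```python
# 20 ms of silence at 48 kHz, 16-bit mono (hypothetical input buffer)
pcm_48k = b"\x00\x00" * 960
pcm_16k = resample_audio(pcm_48k, original_sample_rate=48000, target_sample_rate=16000)
# pcm_16k now holds roughly one third as many samples, still 16-bit mono
```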
CLOVA_INPUT_SAMPLE_RATE = 16000
LIVEKIT_INPUT_SAMPLE_RATE = 48000
import logging
logger = logging.getLogger("livekit.plugins.clova")
from typing import Literal
ClovaSttLanguages = Literal["ko-KR", "en-US", "enko", "ja", "zh-cn", "zh-tw"]
ClovaSpeechAPIType = Literal[
"recognizer/object-storage", "recognizer/url", "recognizer/upload"
]
clova_languages_mapping = {
"en": "en-US",
"ko-KR": "ko-KR",
"en-US": "en-US",
"enko": "enko",
"ja": "ja",
"zh-cn": "zh-cn",
"zh-tw": "zh-tw",
}
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import io
import json
import os
import time
import wave
from typing import Optional, Union
import aiohttp
from livekit.agents import (
APIConnectOptions,
APIStatusError,
APITimeoutError,
stt,
utils,
)
from livekit.agents.stt import SpeechEventType, STTCapabilities
from livekit.agents.utils import AudioBuffer, merge_frames
from livekit.plugins.clova.constants import CLOVA_INPUT_SAMPLE_RATE
from .common import resample_audio
from .log import logger
from .models import ClovaSpeechAPIType, ClovaSttLanguages, clova_languages_mapping
class STT(stt.STT):
def __init__(
self,
*,
language: ClovaSttLanguages | str = "en-US",
secret: Optional[str] = None,
invoke_url: Optional[str] = None,
http_session: Optional[aiohttp.ClientSession] = None,
threshold: float = 0.5,
):
"""
Create a new instance of Clova STT.
``secret`` and ``invoke_url`` must be set, either using arguments or by setting the
``CLOVA_STT_SECRET_KEY`` and ``CLOVA_STT_INVOKE_URL`` environmental variables, respectively.
"""
super().__init__(
capabilities=STTCapabilities(streaming=False, interim_results=True)
)
self._secret = secret or os.environ.get("CLOVA_STT_SECRET_KEY")
self._invoke_url = invoke_url or os.environ.get("CLOVA_STT_INVOKE_URL")
self._language = clova_languages_mapping.get(language, language)
self._session = http_session
if self._secret is None:
raise ValueError(
"Clova STT secret key is required. It should be set with env CLOVA_STT_SECRET_KEY"
)
self.threshold = threshold
def update_options(self, *, language: str | None = None) -> None:
self._language = (
clova_languages_mapping.get(language, language) or self._language
)
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def url_builder(
self, process_method: ClovaSpeechAPIType = "recognizer/upload"
) -> str:
return f"{self._invoke_url}/{process_method}"
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: Union[ClovaSttLanguages, str, None],
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
try:
url = self.url_builder()
payload = json.dumps({"language": self._language, "completion": "sync"})
buffer = merge_frames(buffer)
buffer_bytes = resample_audio(
buffer.data.tobytes(), buffer.sample_rate, CLOVA_INPUT_SAMPLE_RATE
)
io_buffer = io.BytesIO()
with wave.open(io_buffer, "wb") as wav:
wav.setnchannels(1)
wav.setsampwidth(2) # 16-bit
wav.setframerate(CLOVA_INPUT_SAMPLE_RATE)
wav.writeframes(buffer_bytes)
io_buffer.seek(0)
headers = {"X-CLOVASPEECH-API-KEY": self._secret}
form_data = aiohttp.FormData()
form_data.add_field("params", payload)
form_data.add_field(
"media", io_buffer, filename="audio.wav", content_type="audio/wav"
)
start = time.time()
async with self._ensure_session().post(
url,
data=form_data,
headers=headers,
timeout=aiohttp.ClientTimeout(
total=30,
sock_connect=conn_options.timeout,
),
) as response:
response_data = await response.json()
end = time.time()
text = response_data.get("text")
confidence = response_data.get("confidence")
logger.info(f"{text} | {confidence} | total_seconds: {end - start}")
if not text or "error" in response_data:
raise ValueError(f"Unexpected response: {response_data}")
if confidence < self.threshold:
raise ValueError(
f"Confidence: {confidence} is bellow threshold {self.threshold}. Skipping."
)
logger.info(f"final event: {response_data}")
return self._transcription_to_speech_event(text=text)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=None,
body=None,
) from e
def _transcription_to_speech_event(
self,
event_type: SpeechEventType = stt.SpeechEventType.INTERIM_TRANSCRIPT,
text: str | None = None,
) -> stt.SpeechEvent:
return stt.SpeechEvent(
type=event_type,
alternatives=[stt.SpeechData(text=text, language=self._language)],
)
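For context, a non-streaming recognition pass through this plugin would roughly look like the sketch below; `recognize` is the generic entry point provided by the `stt.STT` base class in livekit-agents, and the audio buffer is assumed to already exist elsewhere.

```python
# sketch only: `frames` is assumed to be an AudioBuffer captured elsewhere
async def transcribe(frames):
    clova_stt = STT(language="enko")
    event = await clova_stt.recognize(buffer=frames)
    print(event.alternatives[0].text)
```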
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.0.2"
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "clova", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-clova",
version=about["__version__"],
description="LiveKit Agents Plugin for LINE Clova STT",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "pydub~=0.25.1"],
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-deepgram
## 0.7.3
### Patch Changes
- support multilingual with Nova-3 model - [#1736](https://github.com/livekit/agents/pull/1736) ([@jeradf](https://github.com/jeradf))
## 0.7.2
### Patch Changes
- Added optional parameter to opt out of Deepgram's model improvement plan - [#1713](https://github.com/livekit/agents/pull/1713) ([@MatthiasGruba](https://github.com/MatthiasGruba))
## 0.7.1
### Patch Changes
- add `nova-3-medical` to stt models - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
- Add string type support to model parameter - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
- support numerals deepgram stt - [#1667](https://github.com/livekit/agents/pull/1667) ([@jayeshp19](https://github.com/jayeshp19))
## 0.7.0
### Minor Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
### Patch Changes
- added a tts.prewarm method to start the connection pool early. - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- update pool configuration for deepgram and cartesia - [#1605](https://github.com/livekit/agents/pull/1605) ([@jayeshp19](https://github.com/jayeshp19))
- set max session duration to 1 hour in deepgram connection pool - [#1582](https://github.com/livekit/agents/pull/1582) ([@jayeshp19](https://github.com/jayeshp19))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.6.20
### Patch Changes
- fix(deepgram): fix STT keyterm parameter - [#1535](https://github.com/livekit/agents/pull/1535) ([@wdhwg001](https://github.com/wdhwg001))
- use connection pool for deepgram tts - [#1523](https://github.com/livekit/agents/pull/1523) ([@jayeshp19](https://github.com/jayeshp19))
- remove update options from tts synthesis stream - [#1546](https://github.com/livekit/agents/pull/1546) ([@jayeshp19](https://github.com/jayeshp19))
## 0.6.19
### Patch Changes
- deepgram: support for Nova-3 keyterms - [#1484](https://github.com/livekit/agents/pull/1484) ([@davidzhao](https://github.com/davidzhao))
## 0.6.18
### Patch Changes
- chore(Deepgram STT): add nova-3 model to type literal - [#1464](https://github.com/livekit/agents/pull/1464) ([@chasemcdo](https://github.com/chasemcdo))
## 0.6.17
### Patch Changes
- improved TTFB metrics for streaming TTS - [#1431](https://github.com/livekit/agents/pull/1431) ([@davidzhao](https://github.com/davidzhao))
## 0.6.16
### Patch Changes
- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao))
## 0.6.15
### Patch Changes
- added streaming audio decoder for compressed audio. - [#1236](https://github.com/livekit/agents/pull/1236) ([@davidzhao](https://github.com/davidzhao))
- Support Deepgram TTS - [#1201](https://github.com/livekit/agents/pull/1201) ([@jayeshp19](https://github.com/jayeshp19))
## 0.6.14
### Patch Changes
- enable deepgram filler words by default to improve end of turn accuracy - [#1190](https://github.com/livekit/agents/pull/1190) ([@davidzhao](https://github.com/davidzhao))
## 0.6.13
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.6.12
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
- Added support for custom deepgram base url - [#1137](https://github.com/livekit/agents/pull/1137) ([@theomonnom](https://github.com/theomonnom))
## 0.6.11
### Patch Changes
- add PeriodicCollector utility for metrics - [#1094](https://github.com/livekit/agents/pull/1094) ([@davidzhao](https://github.com/davidzhao))
## 0.6.10
### Patch Changes
- fix Deepgram missing first word, disabled energy filter by default - [#1090](https://github.com/livekit/agents/pull/1090) ([@davidzhao](https://github.com/davidzhao))
## 0.6.9
### Patch Changes
- stt: reduce bandwidth usage by reducing sample_rate to 16khz - [#920](https://github.com/livekit/agents/pull/920) ([@theomonnom](https://github.com/theomonnom))
- deepgram: send finalize each time we stop sending audio - [#1004](https://github.com/livekit/agents/pull/1004) ([@theomonnom](https://github.com/theomonnom))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
## 0.6.8
### Patch Changes
- accepts parameter profanity_filter - [#811](https://github.com/livekit/agents/pull/811) ([@jebjebs](https://github.com/jebjebs))
## 0.6.7
### Patch Changes
- Only send actual audio to Deepgram using a basic audio RMS filter - [#738](https://github.com/livekit/agents/pull/738) ([@keepingitneil](https://github.com/keepingitneil))
- defaults to nova-2-general model - [#726](https://github.com/livekit/agents/pull/726) ([@davidzhao](https://github.com/davidzhao))
## 0.6.6
### Patch Changes
- deepgram: switch the default model to phonecall - [#676](https://github.com/livekit/agents/pull/676) ([@theomonnom](https://github.com/theomonnom))
## 0.6.5
### Patch Changes
- deepgram: fallback to nova-2-general when the language isn't supported - [#623](https://github.com/livekit/agents/pull/623) ([@theomonnom](https://github.com/theomonnom))
## 0.6.4
### Patch Changes
- deepgram: add support for keywords boost/penalty - [#599](https://github.com/livekit/agents/pull/599) ([@theomonnom](https://github.com/theomonnom))
- fix log warnings & cartesia end of speech - [#603](https://github.com/livekit/agents/pull/603) ([@theomonnom](https://github.com/theomonnom))
- stt/tts: fix unread inputs when the input channel is closed - [#594](https://github.com/livekit/agents/pull/594) ([@theomonnom](https://github.com/theomonnom))
## 0.6.3
### Patch Changes
- deepgram: update default model to nova-2-conversationalai - [#576](https://github.com/livekit/agents/pull/576) ([@theomonnom](https://github.com/theomonnom))
## 0.6.2
### Patch Changes
- deepgram: reduce chunks size to 100ms - [#561](https://github.com/livekit/agents/pull/561) ([@theomonnom](https://github.com/theomonnom))
- deepgram: segment audio frames into 200ms intervals before sending to the websocket #549 - [#553](https://github.com/livekit/agents/pull/553) ([@theomonnom](https://github.com/theomonnom))
## 0.6.1
### Patch Changes
- fix end_input not flushing & unhandled flush messages - [#528](https://github.com/livekit/agents/pull/528) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.5.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins DeepGram
Agent Framework plugin for [Deepgram](https://deepgram.com/)'s API. Currently supports speech-to-text and text-to-speech.
## Installation
```bash
pip install livekit-plugins-deepgram
```

You’ll need an API key from Deepgram. It can be set as the environment variable `DEEPGRAM_API_KEY`.
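With the key in place, a minimal sketch of constructing the STT and opening a streaming session (illustrative only; the model, language, and key terms are example values):

```python
from livekit.plugins import deepgram

stt = deepgram.STT(
    model="nova-3",                      # keyterms below require a Nova-3 model
    language="en-US",
    keyterms=["LiveKit", "Deepgram"],
)
stream = stt.stream()                    # push audio frames in, read SpeechEvents back
```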
## livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/__init__.py
```py
from .stt import STT, AudioEnergyFilter, SpeechStream
from .tts import TTS
from .version import __version__
__all__ = ["STT", "SpeechStream", "AudioEnergyFilter", "__version__", "TTS"]
from livekit.agents import Plugin
from .log import logger
class DeepgramPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(DeepgramPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
```
import time
from typing import Callable, Generic, Optional, TypeVar
T = TypeVar("T")
class PeriodicCollector(Generic[T]):
def __init__(self, callback: Callable[[T], None], *, duration: float) -> None:
"""
Create a new periodic collector that accumulates values and calls the callback
after the specified duration if there are values to report.
Args:
duration: Time in seconds between callback invocations
callback: Function to call with accumulated value when duration expires
"""
self._duration = duration
self._callback = callback
self._last_flush_time = time.monotonic()
self._total: Optional[T] = None
def push(self, value: T) -> None:
"""Add a value to the accumulator"""
if self._total is None:
self._total = value
else:
self._total += value # type: ignore
if time.monotonic() - self._last_flush_time >= self._duration:
self.flush()
def flush(self) -> None:
"""Force callback to be called with current total if non-zero"""
if self._total is not None:
self._callback(self._total)
self._total = None
self._last_flush_time = time.monotonic()
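A small usage sketch (example values only): accumulate per-frame durations and report the running total at most once per period.

```python
def report(total: float) -> None:
    print(f"sent {total:.2f}s of audio")

collector = PeriodicCollector(report, duration=5.0)
collector.push(0.05)   # e.g. one 50 ms frame
collector.push(0.05)
collector.flush()      # force the callback with the accumulated 0.10s
```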
import logging
logger = logging.getLogger("livekit.plugins.deepgram")
from typing import Literal
DeepgramModels = Literal[
"nova-general",
"nova-phonecall",
"nova-meeting",
"nova-2-general",
"nova-2-meeting",
"nova-2-phonecall",
"nova-2-finance",
"nova-2-conversationalai",
"nova-2-voicemail",
"nova-2-video",
"nova-2-medical",
"nova-2-drivethru",
"nova-2-automotive",
"nova-3",
"nova-3-general",
"nova-3-medical",
"enhanced-general",
"enhanced-meeting",
"enhanced-phonecall",
"enhanced-finance",
"base",
"meeting",
"phonecall",
"finance",
"conversationalai",
"voicemail",
"video",
"whisper-tiny",
"whisper-base",
"whisper-small",
"whisper-medium",
"whisper-large",
]
DeepgramLanguages = Literal[
"zh",
"zh-CN",
"zh-TW",
"da",
"nl",
"en",
"en-US",
"en-AU",
"en-GB",
"en-NZ",
"en-IN",
"fr",
"fr-CA",
"de",
"hi",
"hi-Latn",
"pt",
"pt-BR",
"es",
"es-419",
"es-LATAM",
"it",
"ja",
"ko",
"no",
"pl",
"sv",
"ta",
"taq",
"uk",
"tr",
"id",
"ru",
"th",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import dataclasses
import json
import os
import weakref
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple
from urllib.parse import urlencode
import aiohttp
import numpy as np
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
stt,
utils,
)
from livekit.agents.utils import AudioBuffer
from ._utils import PeriodicCollector
from .log import logger
from .models import DeepgramLanguages, DeepgramModels
BASE_URL = "https://api.deepgram.com/v1/listen"
# This magic number, determined during testing, is used to decide whether a frame is loud enough
# to possibly contain speech. It is deliberately conservative.
MAGIC_NUMBER_THRESHOLD = 0.004**2
class AudioEnergyFilter:
class State(Enum):
START = 0
SPEAKING = 1
SILENCE = 2
END = 3
def __init__(
self, *, min_silence: float = 1.5, rms_threshold: float = MAGIC_NUMBER_THRESHOLD
):
self._cooldown_seconds = min_silence
self._cooldown = min_silence
self._state = self.State.SILENCE
self._rms_threshold = rms_threshold
def update(self, frame: rtc.AudioFrame) -> State:
arr = np.frombuffer(frame.data, dtype=np.int16)
float_arr = arr.astype(np.float32) / 32768.0
rms = np.mean(np.square(float_arr))
if rms > self._rms_threshold:
self._cooldown = self._cooldown_seconds
if self._state in (self.State.SILENCE, self.State.END):
self._state = self.State.START
else:
self._state = self.State.SPEAKING
else:
if self._cooldown <= 0:
if self._state in (self.State.SPEAKING, self.State.START):
self._state = self.State.END
elif self._state == self.State.END:
self._state = self.State.SILENCE
else:
# keep speaking during cooldown
self._cooldown -= frame.duration
self._state = self.State.SPEAKING
return self._state
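The per-frame energy check above boils down to a normalized mean-square comparison; the standalone sketch below reproduces that arithmetic on a hypothetical buffer of int16 samples.

```python
import numpy as np

MAGIC_NUMBER_THRESHOLD = 0.004**2                  # same conservative default as above

samples = np.zeros(480, dtype=np.int16)            # 10 ms of silence at 48 kHz (made-up frame)
mean_square = np.mean(np.square(samples.astype(np.float32) / 32768.0))
loud_enough = mean_square > MAGIC_NUMBER_THRESHOLD # False for this all-zero frame
```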
@dataclass
class STTOptions:
language: DeepgramLanguages | str | None
detect_language: bool
interim_results: bool
punctuate: bool
model: DeepgramModels | str
smart_format: bool
no_delay: bool
endpointing_ms: int
filler_words: bool
sample_rate: int
num_channels: int
keywords: list[Tuple[str, float]]
keyterms: list[str]
profanity_filter: bool
energy_filter: AudioEnergyFilter | bool = False
numerals: bool = False
mip_opt_out: bool = False
class STT(stt.STT):
def __init__(
self,
*,
model: DeepgramModels | str = "nova-2-general",
language: DeepgramLanguages | str = "en-US",
detect_language: bool = False,
interim_results: bool = True,
punctuate: bool = True,
smart_format: bool = True,
sample_rate: int = 16000,
no_delay: bool = True,
endpointing_ms: int = 25,
# enable filler words by default to improve turn detector accuracy
filler_words: bool = True,
keywords: list[Tuple[str, float]] | None = None,
keyterms: list[str] | None = None,
profanity_filter: bool = False,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
base_url: str = BASE_URL,
energy_filter: AudioEnergyFilter | bool = False,
numerals: bool = False,
mip_opt_out: bool = False,
) -> None:
"""Create a new instance of Deepgram STT.
Args:
model: The Deepgram model to use for speech recognition. Defaults to "nova-2-general".
language: The language code for recognition. Defaults to "en-US".
detect_language: Whether to enable automatic language detection. Defaults to False.
interim_results: Whether to return interim (non-final) transcription results. Defaults to True.
punctuate: Whether to add punctuation to the transcription. Defaults to True. The turn detector works better with punctuation.
smart_format: Whether to apply smart formatting to numbers, dates, etc. Defaults to True.
sample_rate: The sample rate of the audio in Hz. Defaults to 16000.
no_delay: When smart_format is used, ensures it does not wait for sequence to be complete before returning results. Defaults to True.
endpointing_ms: Time in milliseconds of silence to consider end of speech. Set to 0 to disable. Defaults to 25.
filler_words: Whether to include filler words (um, uh, etc.) in transcription. Defaults to True.
keywords: List of tuples containing keywords and their boost values for improved recognition.
Each tuple should be (keyword: str, boost: float). Defaults to None.
`keywords` does not work with Nova-3 models. Use `keyterms` instead.
keyterms: List of key terms to improve recognition accuracy. Defaults to None.
`keyterms` is supported by Nova-3 models.
profanity_filter: Whether to filter profanity from the transcription. Defaults to False.
api_key: Your Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY environment variable.
http_session: Optional aiohttp ClientSession to use for requests.
base_url: The base URL for Deepgram API. Defaults to "https://api.deepgram.com/v1/listen".
energy_filter: Audio energy filter configuration for voice activity detection.
Can be a boolean or AudioEnergyFilter instance. Defaults to False.
numerals: Whether to include numerals in the transcription. Defaults to False.
mip_opt_out: Whether to opt out of Deepgram's model improvement program. Defaults to False.
Raises:
ValueError: If no API key is provided or found in environment variables.
Note:
The api_key must be set either through the constructor argument or by setting
the DEEPGRAM_API_KEY environmental variable.
"""
super().__init__(
capabilities=stt.STTCapabilities(
streaming=True, interim_results=interim_results
)
)
self._base_url = base_url
api_key = api_key or os.environ.get("DEEPGRAM_API_KEY")
if api_key is None:
raise ValueError("Deepgram API key is required")
model = _validate_model(model, language)
self._api_key = api_key
self._opts = STTOptions(
language=language,
detect_language=detect_language,
interim_results=interim_results,
punctuate=punctuate,
model=model,
smart_format=smart_format,
no_delay=no_delay,
endpointing_ms=endpointing_ms,
filler_words=filler_words,
sample_rate=sample_rate,
num_channels=1,
keywords=keywords or [],
keyterms=keyterms or [],
profanity_filter=profanity_filter,
energy_filter=energy_filter,
numerals=numerals,
mip_opt_out=mip_opt_out,
)
self._session = http_session
self._streams = weakref.WeakSet[SpeechStream]()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: DeepgramLanguages | str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
config = self._sanitize_options(language=language)
recognize_config = {
"model": str(config.model),
"punctuate": config.punctuate,
"detect_language": config.detect_language,
"smart_format": config.smart_format,
"keywords": self._opts.keywords,
"profanity_filter": config.profanity_filter,
"numerals": config.numerals,
}
if config.language:
recognize_config["language"] = config.language
try:
async with self._ensure_session().post(
url=_to_deepgram_url(recognize_config, self._base_url, websocket=False),
data=rtc.combine_audio_frames(buffer).to_wav_bytes(),
headers={
"Authorization": f"Token {self._api_key}",
"Accept": "application/json",
"Content-Type": "audio/wav",
},
timeout=aiohttp.ClientTimeout(
total=30,
sock_connect=conn_options.timeout,
),
) as res:
return prerecorded_transcription_to_speech_event(
config.language,
await res.json(),
)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=None,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
def stream(
self,
*,
language: DeepgramLanguages | str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = self._sanitize_options(language=language)
stream = SpeechStream(
stt=self,
conn_options=conn_options,
opts=config,
api_key=self._api_key,
http_session=self._ensure_session(),
base_url=self._base_url,
)
self._streams.add(stream)
return stream
def update_options(
self,
*,
language: DeepgramLanguages | str | None = None,
model: DeepgramModels | str | None = None,
interim_results: bool | None = None,
punctuate: bool | None = None,
smart_format: bool | None = None,
sample_rate: int | None = None,
no_delay: bool | None = None,
endpointing_ms: int | None = None,
filler_words: bool | None = None,
keywords: list[Tuple[str, float]] | None = None,
keyterms: list[str] | None = None,
profanity_filter: bool | None = None,
numerals: bool | None = None,
mip_opt_out: bool | None = None,
):
if language is not None:
self._opts.language = language
if model is not None:
self._opts.model = _validate_model(model, language)
if interim_results is not None:
self._opts.interim_results = interim_results
if punctuate is not None:
self._opts.punctuate = punctuate
if smart_format is not None:
self._opts.smart_format = smart_format
if sample_rate is not None:
self._opts.sample_rate = sample_rate
if no_delay is not None:
self._opts.no_delay = no_delay
if endpointing_ms is not None:
self._opts.endpointing_ms = endpointing_ms
if filler_words is not None:
self._opts.filler_words = filler_words
if keywords is not None:
self._opts.keywords = keywords
if keyterms is not None:
self._opts.keyterms = keyterms
if profanity_filter is not None:
self._opts.profanity_filter = profanity_filter
if numerals is not None:
self._opts.numerals = numerals
if mip_opt_out is not None:
self._opts.mip_opt_out = mip_opt_out
for stream in self._streams:
stream.update_options(
language=language,
model=model,
interim_results=interim_results,
punctuate=punctuate,
smart_format=smart_format,
sample_rate=sample_rate,
no_delay=no_delay,
endpointing_ms=endpointing_ms,
filler_words=filler_words,
keywords=keywords,
keyterms=keyterms,
profanity_filter=profanity_filter,
numerals=numerals,
mip_opt_out=mip_opt_out,
)
def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
config = dataclasses.replace(self._opts)
config.language = language or config.language
if config.detect_language:
config.language = None
return config
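Options can be changed at runtime: `update_options` rewrites the shared `STTOptions` and forwards the change to every open `SpeechStream`, which then reconnects with the new query parameters. A brief sketch (values are illustrative):

```python
stt = STT(model="nova-2-general", language="en-US")
stream = stt.stream()

# later, e.g. after discovering the caller speaks Spanish;
# the open stream picks up the change and reconnects
stt.update_options(language="es")
```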
class SpeechStream(stt.SpeechStream):
_KEEPALIVE_MSG: str = json.dumps({"type": "KeepAlive"})
_CLOSE_MSG: str = json.dumps({"type": "CloseStream"})
_FINALIZE_MSG: str = json.dumps({"type": "Finalize"})
def __init__(
self,
*,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions,
api_key: str,
http_session: aiohttp.ClientSession,
base_url: str,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
if opts.detect_language and opts.language is None:
raise ValueError("language detection is not supported in streaming mode")
self._opts = opts
self._api_key = api_key
self._session = http_session
self._base_url = base_url
self._speaking = False
self._audio_duration_collector = PeriodicCollector(
callback=self._on_audio_duration_report,
duration=5.0,
)
self._audio_energy_filter: Optional[AudioEnergyFilter] = None
if opts.energy_filter:
if isinstance(opts.energy_filter, AudioEnergyFilter):
self._audio_energy_filter = opts.energy_filter
else:
self._audio_energy_filter = AudioEnergyFilter()
self._request_id = ""
self._reconnect_event = asyncio.Event()
def update_options(
self,
*,
language: DeepgramLanguages | str | None = None,
model: DeepgramModels | str | None = None,
interim_results: bool | None = None,
punctuate: bool | None = None,
smart_format: bool | None = None,
sample_rate: int | None = None,
no_delay: bool | None = None,
endpointing_ms: int | None = None,
filler_words: bool | None = None,
keywords: list[Tuple[str, float]] | None = None,
keyterms: list[str] | None = None,
profanity_filter: bool | None = None,
numerals: bool | None = None,
mip_opt_out: bool | None = None,
):
if language is not None:
self._opts.language = language
if model is not None:
self._opts.model = _validate_model(model, language)
if interim_results is not None:
self._opts.interim_results = interim_results
if punctuate is not None:
self._opts.punctuate = punctuate
if smart_format is not None:
self._opts.smart_format = smart_format
if sample_rate is not None:
self._opts.sample_rate = sample_rate
if no_delay is not None:
self._opts.no_delay = no_delay
if endpointing_ms is not None:
self._opts.endpointing_ms = endpointing_ms
if filler_words is not None:
self._opts.filler_words = filler_words
if keywords is not None:
self._opts.keywords = keywords
if keyterms is not None:
self._opts.keyterms = keyterms
if profanity_filter is not None:
self._opts.profanity_filter = profanity_filter
if numerals is not None:
self._opts.numerals = numerals
if mip_opt_out is not None:
self._opts.mip_opt_out = mip_opt_out
self._reconnect_event.set()
async def _run(self) -> None:
closing_ws = False
async def keepalive_task(ws: aiohttp.ClientWebSocketResponse):
# if we want to keep the connection alive even if no audio is sent,
# Deepgram expects a keepalive message.
# https://developers.deepgram.com/reference/listen-live#stream-keepalive
try:
while True:
await ws.send_str(SpeechStream._KEEPALIVE_MSG)
await asyncio.sleep(5)
except Exception:
return
@utils.log_exceptions(logger=logger)
async def send_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
# forward audio to deepgram in chunks of 50ms
samples_50ms = self._opts.sample_rate // 20
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=self._opts.num_channels,
samples_per_channel=samples_50ms,
)
has_ended = False
last_frame: Optional[rtc.AudioFrame] = None
async for data in self._input_ch:
frames: list[rtc.AudioFrame] = []
if isinstance(data, rtc.AudioFrame):
state = self._check_energy_state(data)
if state in (
AudioEnergyFilter.State.START,
AudioEnergyFilter.State.SPEAKING,
):
if last_frame:
frames.extend(
audio_bstream.write(last_frame.data.tobytes())
)
last_frame = None
frames.extend(audio_bstream.write(data.data.tobytes()))
elif state == AudioEnergyFilter.State.END:
# no need to buffer as we have cooldown period
frames.extend(audio_bstream.flush())
has_ended = True
elif state == AudioEnergyFilter.State.SILENCE:
# buffer the last silence frame, since it could contain beginning of speech
# TODO: improve accuracy by using a ring buffer with longer window
last_frame = data
elif isinstance(data, self._FlushSentinel):
frames.extend(audio_bstream.flush())
has_ended = True
for frame in frames:
self._audio_duration_collector.push(frame.duration)
await ws.send_bytes(frame.data.tobytes())
if has_ended:
self._audio_duration_collector.flush()
await ws.send_str(SpeechStream._FINALIZE_MSG)
has_ended = False
# tell deepgram we are done sending audio/inputs
closing_ws = True
await ws.send_str(SpeechStream._CLOSE_MSG)
@utils.log_exceptions(logger=logger)
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws: # close is expected, see SpeechStream.aclose
return
# this will trigger a reconnection, see the _run loop
raise APIStatusError(
message="deepgram connection closed unexpectedly"
)
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected deepgram message type %s", msg.type)
continue
try:
self._process_stream_event(json.loads(msg.data))
except Exception:
logger.exception("failed to process deepgram message")
ws: aiohttp.ClientWebSocketResponse | None = None
while True:
try:
ws = await self._connect_ws()
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
asyncio.create_task(keepalive_task(ws)),
]
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
done, _ = await asyncio.wait(
[asyncio.gather(*tasks), wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
) # type: ignore
# propagate exceptions from completed tasks
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
finally:
if ws is not None:
await ws.close()
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
live_config: dict[str, Any] = {
"model": self._opts.model,
"punctuate": self._opts.punctuate,
"smart_format": self._opts.smart_format,
"no_delay": self._opts.no_delay,
"interim_results": self._opts.interim_results,
"encoding": "linear16",
"vad_events": True,
"sample_rate": self._opts.sample_rate,
"channels": self._opts.num_channels,
"endpointing": False
if self._opts.endpointing_ms == 0
else self._opts.endpointing_ms,
"filler_words": self._opts.filler_words,
"profanity_filter": self._opts.profanity_filter,
"numerals": self._opts.numerals,
"mip_opt_out": self._opts.mip_opt_out,
}
if self._opts.keywords:
live_config["keywords"] = self._opts.keywords
if self._opts.keyterms:
# the query param is `keyterm`
# See: https://developers.deepgram.com/docs/keyterm
live_config["keyterm"] = self._opts.keyterms
if self._opts.language:
live_config["language"] = self._opts.language
ws = await asyncio.wait_for(
self._session.ws_connect(
_to_deepgram_url(live_config, base_url=self._base_url, websocket=True),
headers={"Authorization": f"Token {self._api_key}"},
),
self._conn_options.timeout,
)
return ws
def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State:
if self._audio_energy_filter:
return self._audio_energy_filter.update(frame)
return AudioEnergyFilter.State.SPEAKING
def _on_audio_duration_report(self, duration: float) -> None:
usage_event = stt.SpeechEvent(
type=stt.SpeechEventType.RECOGNITION_USAGE,
request_id=self._request_id,
alternatives=[],
recognition_usage=stt.RecognitionUsage(audio_duration=duration),
)
self._event_ch.send_nowait(usage_event)
def _process_stream_event(self, data: dict) -> None:
assert self._opts.language is not None
if data["type"] == "SpeechStarted":
# This is a normal case. Deepgram's SpeechStarted events
# are not correlated with speech_final or utterance end.
# It's possible that we receive two in a row without an endpoint
# It's also possible we receive a transcript without a SpeechStarted event.
if self._speaking:
return
self._speaking = True
start_event = stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
self._event_ch.send_nowait(start_event)
# see this page:
# https://developers.deepgram.com/docs/understand-endpointing-interim-results#using-endpointing-speech_final
# for more information about the different types of events
elif data["type"] == "Results":
metadata = data["metadata"]
request_id = metadata["request_id"]
is_final_transcript = data["is_final"]
is_endpoint = data["speech_final"]
self._request_id = request_id
alts = live_transcription_to_speech_data(self._opts.language, data)
# If, for some reason, we didn't get a SpeechStarted event but we got
# a transcript with text, we should start speaking. It's rare but has
# been observed.
if len(alts) > 0 and alts[0].text:
if not self._speaking:
self._speaking = True
start_event = stt.SpeechEvent(
type=stt.SpeechEventType.START_OF_SPEECH
)
self._event_ch.send_nowait(start_event)
if is_final_transcript:
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
request_id=request_id,
alternatives=alts,
)
self._event_ch.send_nowait(final_event)
else:
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
request_id=request_id,
alternatives=alts,
)
self._event_ch.send_nowait(interim_event)
# if we receive an endpoint, only end the speech if
# we either had a SpeechStarted event or we have a seen
# a non-empty transcript (deepgram doesn't have a SpeechEnded event)
if is_endpoint and self._speaking:
self._speaking = False
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
)
elif data["type"] == "Metadata":
pass # metadata is too noisy
else:
logger.warning("received unexpected message from deepgram %s", data)
def live_transcription_to_speech_data(
language: str, data: dict
) -> List[stt.SpeechData]:
dg_alts = data["channel"]["alternatives"]
speech_data = []
for alt in dg_alts:
sd = stt.SpeechData(
language=language,
start_time=alt["words"][0]["start"] if alt["words"] else 0,
end_time=alt["words"][-1]["end"] if alt["words"] else 0,
confidence=alt["confidence"],
text=alt["transcript"],
)
if language == "multi" and "languages" in alt:
sd.language = alt["languages"][0] # TODO: handle multiple languages
speech_data.append(sd)
return speech_data
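To make the shape of the incoming payload concrete, here is a hand-written (not captured) fragment of a Deepgram `Results` message and what the helper above extracts from it:

```python
results_fragment = {
    "channel": {
        "alternatives": [
            {
                "transcript": "hello world",
                "confidence": 0.98,
                "words": [
                    {"word": "hello", "start": 0.12, "end": 0.45},
                    {"word": "world", "start": 0.50, "end": 0.92},
                ],
            }
        ]
    }
}
alts = live_transcription_to_speech_data("en-US", results_fragment)
# alts[0].text == "hello world", start_time == 0.12, end_time == 0.92
```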
def prerecorded_transcription_to_speech_event(
language: str | None, # language should be None when 'detect_language' is enabled
data: dict,
) -> stt.SpeechEvent:
# We only support one channel for now
request_id = data["metadata"]["request_id"]
channel = data["results"]["channels"][0]
dg_alts = channel["alternatives"]
# Use the detected language if enabled
# https://developers.deepgram.com/docs/language-detection
detected_language = channel.get("detected_language")
return stt.SpeechEvent(
request_id=request_id,
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=language or detected_language,
start_time=alt["words"][0]["start"] if alt["words"] else 0,
end_time=alt["words"][-1]["end"] if alt["words"] else 0,
confidence=alt["confidence"],
text=alt["transcript"],
)
for alt in dg_alts
],
)
def _to_deepgram_url(opts: dict, base_url: str, *, websocket: bool) -> str:
# don't modify the original opts
opts = opts.copy()
if opts.get("keywords"):
# convert keywords to a list of "keyword:intensifier"
opts["keywords"] = [
f"{keyword}:{intensifier}" for (keyword, intensifier) in opts["keywords"]
]
# lowercase bools
opts = {k: str(v).lower() if isinstance(v, bool) else v for k, v in opts.items()}
if websocket and base_url.startswith("http"):
base_url = base_url.replace("http", "ws", 1)
elif not websocket and base_url.startswith("ws"):
base_url = base_url.replace("ws", "http", 1)
return f"{base_url}?{urlencode(opts, doseq=True)}"
def _validate_model(
model: DeepgramModels | str, language: DeepgramLanguages | str | None
) -> DeepgramModels | str:
en_only_models = {
"nova-2-meeting",
"nova-2-phonecall",
"nova-2-finance",
"nova-2-conversationalai",
"nova-2-voicemail",
"nova-2-video",
"nova-2-medical",
"nova-2-drivethru",
"nova-2-automotive",
# nova-3 will support more languages, but english-only for now
"nova-3-general",
}
if language not in ("en-US", "en") and model in en_only_models:
logger.warning(
f"{model} does not support language {language}, falling back to nova-2-general"
)
return "nova-2-general"
return model
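For instance (hand-computed, not output captured from the plugin), a small option dict routed through `_to_deepgram_url` for a websocket connection becomes:

```python
url = _to_deepgram_url(
    {"model": "nova-2-general", "punctuate": True, "sample_rate": 16000},
    "https://api.deepgram.com/v1/listen",
    websocket=True,
)
# -> "wss://api.deepgram.com/v1/listen?model=nova-2-general&punctuate=true&sample_rate=16000"
```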
from __future__ import annotations
import asyncio
import json
import os
import weakref
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlencode
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tokenize,
tts,
utils,
)
from .log import logger
BASE_URL = "https://api.deepgram.com/v1/speak"
NUM_CHANNELS = 1
@dataclass
class _TTSOptions:
model: str
encoding: str
sample_rate: int
word_tokenizer: tokenize.WordTokenizer
class TTS(tts.TTS):
def __init__(
self,
*,
model: str = "aura-asteria-en",
encoding: str = "linear16",
sample_rate: int = 24000,
api_key: str | None = None,
base_url: str = BASE_URL,
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False
),
http_session: aiohttp.ClientSession | None = None,
) -> None:
"""
Create a new instance of Deepgram TTS.
Args:
model (str): TTS model to use. Defaults to "aura-asteria-en".
encoding (str): Audio encoding to use. Defaults to "linear16".
sample_rate (int): Sample rate of audio. Defaults to 24000.
api_key (str): Deepgram API key. If not provided, will look for DEEPGRAM_API_KEY in environment.
base_url (str): Base URL for Deepgram TTS API. Defaults to "https://api.deepgram.com/v1/speak"
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
http_session (aiohttp.ClientSession): Optional aiohttp session to use for requests.
"""
super().__init__(
capabilities=tts.TTSCapabilities(streaming=True),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
api_key = api_key or os.environ.get("DEEPGRAM_API_KEY")
if not api_key:
raise ValueError(
"Deepgram API key required. Set DEEPGRAM_API_KEY or provide api_key."
)
self._opts = _TTSOptions(
model=model,
encoding=encoding,
sample_rate=sample_rate,
word_tokenizer=word_tokenizer,
)
self._session = http_session
self._api_key = api_key
self._base_url = base_url
self._streams = weakref.WeakSet[SynthesizeStream]()
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
max_session_duration=3600, # 1 hour
mark_refreshed_on_get=False,
)
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
session = self._ensure_session()
config = {
"encoding": self._opts.encoding,
"model": self._opts.model,
"sample_rate": self._opts.sample_rate,
}
return await asyncio.wait_for(
session.ws_connect(
_to_deepgram_url(config, self._base_url, websocket=True),
headers={"Authorization": f"Token {self._api_key}"},
),
self._conn_options.timeout,
)
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
await ws.close()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def update_options(
self,
*,
model: str | None = None,
sample_rate: int | None = None,
) -> None:
"""
args:
model (str): TTS model to use.
sample_rate (int): Sample rate of audio.
"""
if model is not None:
self._opts.model = model
if sample_rate is not None:
self._opts.sample_rate = sample_rate
# deepgram sets options upon connection, so we need to invalidate the pool
# to get a new connection with the updated options
self._pool.invalidate()
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
base_url=self._base_url,
api_key=self._api_key,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> "SynthesizeStream":
stream = SynthesizeStream(
tts=self,
pool=self._pool,
opts=self._opts,
)
self._streams.add(stream)
return stream
def prewarm(self) -> None:
self._pool.prewarm()
async def aclose(self) -> None:
for stream in list(self._streams):
await stream.aclose()
self._streams.clear()
await self._pool.aclose()
await super().aclose()
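Putting the pieces together, a short sketch of how this TTS might be used; the async iteration over the returned stream follows the framework's ChunkedStream interface, so treat it as illustrative rather than definitive, and `handle_frame` is a hypothetical consumer.

```python
def handle_frame(frame) -> None:
    # hypothetical consumer; a real agent would forward frames to an audio track
    pass

async def speak():
    dg_tts = TTS(model="aura-asteria-en", sample_rate=24000)
    dg_tts.prewarm()  # warms the websocket pool used by the streaming path (stream())
    async for audio in dg_tts.synthesize("Hello from LiveKit!"):  # chunked HTTP request
        handle_frame(audio.frame)
```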
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
base_url: str,
api_key: str,
input_text: str,
opts: _TTSOptions,
session: aiohttp.ClientSession,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
self._session = session
self._base_url = base_url
self._api_key = api_key
async def _run(self) -> None:
request_id = utils.shortuuid()
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
try:
config = {
"encoding": self._opts.encoding,
"model": self._opts.model,
"sample_rate": self._opts.sample_rate,
}
async with self._session.post(
_to_deepgram_url(config, self._base_url, websocket=False),
headers={
"Authorization": f"Token {self._api_key}",
"Content-Type": "application/json",
},
json={"text": self._input_text},
timeout=self._conn_options.timeout,
) as res:
if res.status != 200:
raise APIStatusError(
message=res.reason or "Unknown error occurred.",
status_code=res.status,
request_id=request_id,
body=await res.json(),
)
async for bytes_data, _ in res.content.iter_chunks():
for frame in audio_bstream.write(bytes_data):
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id,
frame=frame,
)
)
for frame in audio_bstream.flush():
self._event_ch.send_nowait(
tts.SynthesizedAudio(request_id=request_id, frame=frame)
)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
class SynthesizeStream(tts.SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
opts: _TTSOptions,
pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
):
super().__init__(tts=tts)
self._opts = opts
self._pool = pool
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
async def _run(self) -> None:
request_id = utils.shortuuid()
@utils.log_exceptions(logger=logger)
async def _tokenize_input():
# Converts incoming text into WordStreams and sends them into _segments_ch
word_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if word_stream is None:
word_stream = self._opts.word_tokenizer.stream()
self._segments_ch.send_nowait(word_stream)
word_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if word_stream:
word_stream.end_input()
word_stream = None
self._segments_ch.close()
@utils.log_exceptions(logger=logger)
async def _run_segments():
async for word_stream in self._segments_ch:
await self._run_ws(word_stream, request_id)
tasks = [
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_run_segments()),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
async def _run_ws(self, word_stream: tokenize.WordStream, request_id: str):
segment_id = utils.shortuuid()
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
async def send_task(ws: aiohttp.ClientWebSocketResponse):
async for word in word_stream:
speak_msg = {"type": "Speak", "text": f"{word.token} "}
self._mark_started()
await ws.send_str(json.dumps(speak_msg))
# Always flush after a segment
flush_msg = {"type": "Flush"}
await ws.send_str(json.dumps(flush_msg))
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=segment_id,
)
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
):
raise APIStatusError(
"Deepgram websocket connection closed unexpectedly",
request_id=request_id,
)
if msg.type == aiohttp.WSMsgType.BINARY:
data = msg.data
for frame in audio_bstream.write(data):
emitter.push(frame)
elif msg.type == aiohttp.WSMsgType.TEXT:
resp = json.loads(msg.data)
mtype = resp.get("type")
if mtype == "Flushed":
for frame in audio_bstream.flush():
emitter.push(frame)
emitter.flush()
break
elif mtype == "Warning":
logger.warning("Deepgram warning: %s", resp.get("warn_msg"))
elif mtype == "Metadata":
pass
else:
logger.debug("Unknown message type: %s", resp)
async with self._pool.connection() as ws:
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
def _to_deepgram_url(
opts: dict,
base_url: str,
*,
websocket: bool,
) -> str:
if websocket and base_url.startswith("http"):
base_url = base_url.replace("http", "ws", 1)
elif not websocket and base_url.startswith("ws"):
base_url = base_url.replace("ws", "http", 1)
return f"{base_url}?{urlencode(opts, doseq=True)}"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.7.3"
{
"name": "livekit-plugins-deepgram",
"private": true,
"version": "0.7.3"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "deepgram", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-deepgram",
version=about["__version__"],
description="Agent Framework plugin for services using Deepgram's API.",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "numpy>=1.26"],
package_data={"livekit.plugins.deepgram": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-elevenlabs
## 0.8.2
### Patch Changes
- use 22.05khz by default for 11labs - [`a294d28c2af672a47f88f598f9fdb3fb13c39c38`](https://github.com/livekit/agents/commit/a294d28c2af672a47f88f598f9fdb3fb13c39c38) ([@davidzhao](https://github.com/davidzhao))
## 0.8.1
### Patch Changes
- Revert to using 'isFinal' in ElevenLabs for reliable audio packet completion detection - [#1676](https://github.com/livekit/agents/pull/1676) ([@jayeshp19](https://github.com/jayeshp19))
## 0.8.0
### Minor Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
### Patch Changes
- added a tts.prewarm method to start the connection pool early. - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- deprecated elevenlabs' optimize_stream_latency option - [#1587](https://github.com/livekit/agents/pull/1587) ([@davidzhao](https://github.com/davidzhao))
- increase elevenlabs websocket connection timeout to default 300 seconds - [#1582](https://github.com/livekit/agents/pull/1582) ([@jayeshp19](https://github.com/jayeshp19))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
- Added speed parameter for voices. - [#1574](https://github.com/livekit/agents/pull/1574) ([@MatthiasGruba](https://github.com/MatthiasGruba))
E.g.:
```python
voice = Voice(
id="EXAVITQu4vr4xnSDxMaL",
name="Bella",
category="premade",
settings=VoiceSettings(
stability=0.71,
speed=1.2,
similarity_boost=0.5,
style=0.0,
use_speaker_boost=True,
),
)
```
- use connection pool for elevenlabs websocket persistent connection - #1546 (@jayeshp19)
- remove update options from tts synthesis stream - #1546 (@jayeshp19)
- add update_options to TTS - #922 (@theomonnom)
- pipelineagent: expose timing metrics & api errors wip - #957 (@theomonnom)
- expose usage metrics - #984 (@theomonnom)
- avoid returning tiny frames from TTS - #747 (@theomonnom)
- 11labs: send phoneme in one entire xml chunk - #766 (@theomonnom)
- elevenlabs: fix send_task not closing properly - #596 (@theomonnom)
- elevenlabs: update default model to eleven_turbo_v2_5 - #578 (@theomonnom)
- test release - #435 (@theomonnom)
- pull: '--rebase --autostash ...' - #435 (@theomonnom)
- bump versions to update dependencies - #510 (@theomonnom)
- test release - #435 (@theomonnom)
- fix changesets release CI - #435 (@theomonnom)
- release v0.8.0 - 6e74aa714c2dfaa8212db4528d7b59d095b6c660 (@theomonnom)
- dev fixes - multiprocessing & voiceassistant - #493 (@theomonnom)
## livekit-plugins/livekit-plugins-elevenlabs/README.md
```md
# LiveKit Plugins Elevenlabs
Agent Framework plugin for voice synthesis with [ElevenLabs](https://elevenlabs.io/) API.
## Installation
```bash
pip install livekit-plugins-elevenlabs
```
You’ll need an API key from ElevenLabs. It can be set as an environment variable: ELEVEN_API_KEY
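A minimal construction sketch, assuming the plugin is installed and `ELEVEN_API_KEY` is exported (the values shown mirror the defaults in `tts.py` below):

```python
import os

from livekit.plugins import elevenlabs

# The key can be passed explicitly or picked up from the ELEVEN_API_KEY environment variable.
tts = elevenlabs.TTS(
    model="eleven_flash_v2_5",
    voice=elevenlabs.DEFAULT_VOICE,
    api_key=os.environ.get("ELEVEN_API_KEY"),
)
```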
## livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import TTSEncoding, TTSModels
from .tts import DEFAULT_VOICE, TTS, Voice, VoiceSettings
from .version import __version__
__all__ = [
"TTS",
"Voice",
"VoiceSettings",
"TTSEncoding",
"TTSModels",
"DEFAULT_VOICE",
"__version__",
]
from livekit.agents import Plugin
from .log import logger
class ElevenLabsPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(ElevenLabsPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.elevenlabs")
from typing import Literal
TTSModels = Literal[
"eleven_monolingual_v1",
"eleven_multilingual_v1",
"eleven_multilingual_v2",
"eleven_turbo_v2",
"eleven_turbo_v2_5",
"eleven_flash_v2_5",
"eleven_flash_v2",
]
TTSEncoding = Literal[
"mp3_22050_32",
"mp3_44100",
"mp3_44100_32",
"mp3_44100_64",
"mp3_44100_96",
"mp3_44100_128",
"mp3_44100_192",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import base64
import dataclasses
import json
import os
import weakref
from dataclasses import dataclass
from typing import Any, List, Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tokenize,
tts,
utils,
)
from .log import logger
from .models import TTSEncoding, TTSModels
# by default, use 22.05kHz sample rate at 32kbps
# in our testing, this reduces TTFB by about 110ms
_DefaultEncoding: TTSEncoding = "mp3_22050_32"
def _sample_rate_from_format(output_format: TTSEncoding) -> int:
split = output_format.split("_") # e.g: mp3_44100
return int(split[1])
@dataclass
class VoiceSettings:
stability: float # [0.0 - 1.0]
similarity_boost: float # [0.0 - 1.0]
style: float | None = None # [0.0 - 1.0]
speed: float | None = 1.0 # [0.8 - 1.2]
use_speaker_boost: bool | None = False
@dataclass
class Voice:
id: str
name: str
category: str
settings: VoiceSettings | None = None
DEFAULT_VOICE = Voice(
id="EXAVITQu4vr4xnSDxMaL",
name="Bella",
category="premade",
settings=VoiceSettings(
stability=0.71,
speed=1.0,
similarity_boost=0.5,
style=0.0,
use_speaker_boost=True,
),
)
API_BASE_URL_V1 = "https://api.elevenlabs.io/v1"
AUTHORIZATION_HEADER = "xi-api-key"
WS_INACTIVITY_TIMEOUT = 300
@dataclass
class _TTSOptions:
api_key: str
voice: Voice
model: TTSModels | str
language: str | None
base_url: str
encoding: TTSEncoding
sample_rate: int
streaming_latency: int
word_tokenizer: tokenize.WordTokenizer
chunk_length_schedule: list[int]
enable_ssml_parsing: bool
inactivity_timeout: int
class TTS(tts.TTS):
def __init__(
self,
*,
voice: Voice = DEFAULT_VOICE,
model: TTSModels | str = "eleven_flash_v2_5",
encoding: TTSEncoding | None = None,
api_key: str | None = None,
base_url: str | None = None,
streaming_latency: int = 0,
inactivity_timeout: int = WS_INACTIVITY_TIMEOUT,
word_tokenizer: Optional[tokenize.WordTokenizer] = None,
enable_ssml_parsing: bool = False,
chunk_length_schedule: list[int] = [80, 120, 200, 260], # range is [50, 500]
http_session: aiohttp.ClientSession | None = None,
# deprecated
model_id: TTSModels | str | None = None,
language: str | None = None,
) -> None:
"""
Create a new instance of ElevenLabs TTS.
Args:
voice (Voice): Voice configuration. Defaults to `DEFAULT_VOICE`.
model (TTSModels | str): TTS model to use. Defaults to "eleven_flash_v2_5".
api_key (str | None): ElevenLabs API key. Can be set via argument or `ELEVEN_API_KEY` environment variable.
base_url (str | None): Custom base URL for the API. Optional.
streaming_latency (int): Latency optimization level. Defaults to 0 (disabled); 4 applies the maximum latency optimizations. Deprecated.
inactivity_timeout (int): Inactivity timeout in seconds for the websocket connection. Defaults to 300.
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
enable_ssml_parsing (bool): Enable SSML parsing for input text. Defaults to False.
chunk_length_schedule (list[int]): Schedule for chunk lengths, ranging from 50 to 500. Defaults to [80, 120, 200, 260].
http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional.
language (str | None): Language code for the TTS model, as of 10/24/24 only valid for "eleven_turbo_v2_5". Optional.
"""
if not encoding:
encoding = _DefaultEncoding
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=True,
),
sample_rate=_sample_rate_from_format(encoding),
num_channels=1,
)
if model_id is not None:
logger.warning(
"model_id is deprecated and will be removed in 1.5.0, use model instead",
)
model = model_id
api_key = api_key or os.environ.get("ELEVEN_API_KEY")
if not api_key:
raise ValueError(
"ElevenLabs API key is required, either as argument or set ELEVEN_API_KEY environmental variable"
)
if word_tokenizer is None:
word_tokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False # punctuation can help for intonation
)
self._opts = _TTSOptions(
voice=voice,
model=model,
api_key=api_key,
base_url=base_url or API_BASE_URL_V1,
encoding=encoding,
sample_rate=self.sample_rate,
streaming_latency=streaming_latency,
word_tokenizer=word_tokenizer,
chunk_length_schedule=chunk_length_schedule,
enable_ssml_parsing=enable_ssml_parsing,
language=language,
inactivity_timeout=inactivity_timeout,
)
self._session = http_session
self._streams = weakref.WeakSet[SynthesizeStream]()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
async def list_voices(self) -> List[Voice]:
async with self._ensure_session().get(
f"{self._opts.base_url}/voices",
headers={AUTHORIZATION_HEADER: self._opts.api_key},
) as resp:
return _dict_to_voices_list(await resp.json())
def update_options(
self,
*,
voice: Voice = DEFAULT_VOICE,
model: TTSModels | str = "eleven_turbo_v2_5",
language: str | None = None,
) -> None:
"""
Args:
voice (Voice): Voice configuration. Defaults to `DEFAULT_VOICE`.
model (TTSModels | str): TTS model to use. Defaults to "eleven_turbo_v2_5".
language (str | None): Language code for the TTS model. Optional.
"""
self._opts.model = model or self._opts.model
self._opts.voice = voice or self._opts.voice
self._opts.language = language or self._opts.language
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> "SynthesizeStream":
stream = SynthesizeStream(
tts=self,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
)
self._streams.add(stream)
return stream
async def aclose(self) -> None:
for stream in list(self._streams):
await stream.aclose()
self._streams.clear()
await super().aclose()
class ChunkedStream(tts.ChunkedStream):
"""Synthesize using the chunked api endpoint"""
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
conn_options: Optional[APIConnectOptions] = None,
session: aiohttp.ClientSession,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts, self._session = opts, session
async def _run(self) -> None:
request_id = utils.shortuuid()
voice_settings = (
_strip_nones(dataclasses.asdict(self._opts.voice.settings))
if self._opts.voice.settings
else None
)
data = {
"text": self._input_text,
"model_id": self._opts.model,
"voice_settings": voice_settings,
}
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=1,
)
decode_task: asyncio.Task | None = None
try:
async with self._session.post(
_synthesize_url(self._opts),
headers={AUTHORIZATION_HEADER: self._opts.api_key},
json=data,
) as resp:
if not resp.content_type.startswith("audio/"):
content = await resp.text()
logger.error("11labs returned non-audio data: %s", content)
return
async def _decode_loop():
try:
async for bytes_data, _ in resp.content.iter_chunks():
decoder.push(bytes_data)
finally:
decoder.end_input()
decode_task = asyncio.create_task(_decode_loop())
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=None,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
if decode_task:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
class SynthesizeStream(tts.SynthesizeStream):
"""Streamed API using websockets"""
def __init__(
self,
*,
tts: TTS,
session: aiohttp.ClientSession,
opts: _TTSOptions,
conn_options: Optional[APIConnectOptions] = None,
):
super().__init__(tts=tts, conn_options=conn_options)
self._opts, self._session = opts, session
async def _run(self) -> None:
request_id = utils.shortuuid()
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
@utils.log_exceptions(logger=logger)
async def _tokenize_input():
"""tokenize text from the input_ch to words"""
word_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if word_stream is None:
# new segment (e.g. after a flush)
word_stream = self._opts.word_tokenizer.stream()
self._segments_ch.send_nowait(word_stream)
word_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if word_stream is not None:
word_stream.end_input()
word_stream = None
if word_stream is not None:
word_stream.end_input()
self._segments_ch.close()
@utils.log_exceptions(logger=logger)
async def _process_segments():
async for word_stream in self._segments_ch:
await self._run_ws(word_stream, request_id)
tasks = [
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_process_segments()),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
async def _run_ws(
self,
word_stream: tokenize.WordStream,
request_id: str,
) -> None:
ws_conn = await self._session.ws_connect(
_stream_url(self._opts),
headers={AUTHORIZATION_HEADER: self._opts.api_key},
)
segment_id = utils.shortuuid()
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=1,
)
# 11labs protocol expects the first message to be an "init msg"
init_pkt = dict(
text=" ",
voice_settings=_strip_nones(dataclasses.asdict(self._opts.voice.settings))
if self._opts.voice.settings
else None,
generation_config=dict(
chunk_length_schedule=self._opts.chunk_length_schedule
),
)
await ws_conn.send_str(json.dumps(init_pkt))
eos_sent = False
@utils.log_exceptions(logger=logger)
async def send_task():
nonlocal eos_sent
xml_content = []
async for data in word_stream:
text = data.token
# send the xml phoneme in one go
if (
self._opts.enable_ssml_parsing
and data.token.startswith("<phoneme")
or xml_content
):
xml_content.append(text)
if data.token.find("</phoneme>") > -1:
text = self._opts.word_tokenizer.format_words(xml_content)
xml_content = []
else:
continue
data_pkt = dict(text=f"{text} ") # must always end with a space
self._mark_started()
await ws_conn.send_str(json.dumps(data_pkt))
if xml_content:
logger.warning("11labs stream ended with incomplete xml content")
# no more tokens, mark EOS
eos_pkt = dict(text="")
await ws_conn.send_str(json.dumps(eos_pkt))
eos_sent = True
# consumes from decoder and generates events
@utils.log_exceptions(logger=logger)
async def generate_task():
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
# receives from ws and decodes audio
@utils.log_exceptions(logger=logger)
async def recv_task():
nonlocal eos_sent
while True:
msg = await ws_conn.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if not eos_sent:
raise APIStatusError(
"11labs connection closed unexpectedly, not all tokens have been consumed",
request_id=request_id,
)
return
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected 11labs message type %s", msg.type)
continue
data = json.loads(msg.data)
if data.get("audio"):
b64data = base64.b64decode(data["audio"])
decoder.push(b64data)
elif data.get("isFinal"):
decoder.end_input()
break
elif data.get("error"):
raise APIStatusError(
message=data["error"],
status_code=500,
request_id=request_id,
body=None,
)
else:
raise APIStatusError(
message=f"unexpected 11labs message {data}",
status_code=500,
request_id=request_id,
body=None,
)
tasks = [
asyncio.create_task(send_task()),
asyncio.create_task(recv_task()),
asyncio.create_task(generate_task()),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except APIStatusError:
raise
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
await decoder.aclose()
if ws_conn is not None:
await ws_conn.close()
def _dict_to_voices_list(data: dict[str, Any]):
voices: List[Voice] = []
for voice in data["voices"]:
voices.append(
Voice(
id=voice["voice_id"],
name=voice["name"],
category=voice["category"],
settings=None,
)
)
return voices
def _strip_nones(data: dict[str, Any]):
return {k: v for k, v in data.items() if v is not None}
def _synthesize_url(opts: _TTSOptions) -> str:
base_url = opts.base_url
voice_id = opts.voice.id
model_id = opts.model
output_format = opts.encoding
url = (
f"{base_url}/text-to-speech/{voice_id}/stream?"
f"model_id={model_id}&output_format={output_format}"
)
if opts.streaming_latency:
url += f"&optimize_streaming_latency={opts.streaming_latency}"
return url
def _stream_url(opts: _TTSOptions) -> str:
base_url = opts.base_url
voice_id = opts.voice.id
model_id = opts.model
output_format = opts.encoding
enable_ssml = str(opts.enable_ssml_parsing).lower()
language = opts.language
inactivity_timeout = opts.inactivity_timeout
url = (
f"{base_url}/text-to-speech/{voice_id}/stream-input?"
f"model_id={model_id}&output_format={output_format}&"
f"enable_ssml_parsing={enable_ssml}&inactivity_timeout={inactivity_timeout}"
)
if language is not None:
url += f"&language_code={language}"
if opts.streaming_latency:
url += f"&optimize_streaming_latency={opts.streaming_latency}"
return url
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.8.2"
{
"name": "livekit-plugins-elevenlabs",
"private": true,
"version": "0.8.2"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "elevenlabs", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-elevenlabs",
version=about["__version__"],
description="Agent Framework plugin for voice synthesis with ElevenLabs' API.",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit", "elevenlabs"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents[codecs]>=0.12.16,<1.0.0"],
package_data={"livekit.plugins.elevenlabs": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-fal
## 0.2.4
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.3
### Patch Changes
- publish package - [`ed974f81a2eab7c1b2d7cff3a27c868ddebb45ee`](https://github.com/livekit/agents/commit/ed974f81a2eab7c1b2d7cff3a27c868ddebb45ee) ([@davidzhao](https://github.com/davidzhao))
## 0.2.2
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.2.1
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0
### Minor Changes
- initial version - [#991](https://github.com/livekit/agents/pull/991) ([@jayeshp19](https://github.com/jayeshp19))
## 0.1.0
- Initial release
# LiveKit Plugins fal
This plugin provides a simple way to integrate fal.ai models into the LiveKit Agent Framework. It currently supports the Wizper model for STT.
## Installation
```bash
pip install livekit-plugins-fal
```
You’ll need an API key from fal. It can be set as an environment variable: FAL_KEY
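A minimal usage sketch, assuming `FAL_KEY` is exported (the constructor arguments follow `stt.py` below):

```python
from livekit.plugins import fal

# FAL_KEY must be set in the environment; the constructor raises ValueError otherwise.
stt = fal.WizperSTT(language="en", task="transcribe", chunk_level="segment", version="3")

# The recognition language can be changed later without recreating the instance.
stt.update_options(language="fr")
```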
## livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .stt import WizperSTT
from .version import __version__
__all__ = ["WizperSTT", "__version__"]
from livekit.agents import Plugin
from .log import logger
class FalPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(FalPlugin())
import logging
logger = logging.getLogger("livekit.plugins.fal")
from __future__ import annotations
import dataclasses
import os
from dataclasses import dataclass
from typing import Optional
import fal_client
from livekit import rtc
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
stt,
)
from livekit.agents.stt import SpeechEventType, STTCapabilities
from livekit.agents.utils import AudioBuffer
@dataclass
class _STTOptions:
language: str
task: str
chunk_level: str
version: str
class WizperSTT(stt.STT):
def __init__(
self,
*,
language: Optional[str] = "en",
task: Optional[str] = "transcribe",
chunk_level: Optional[str] = "segment",
version: Optional[str] = "3",
):
super().__init__(
capabilities=STTCapabilities(streaming=False, interim_results=True)
)
self._api_key = os.getenv("FAL_KEY")
self._opts = _STTOptions(
language=language or "en",
task=task or "transcribe",
chunk_level=chunk_level or "segment",
version=version or "3",
)
self._fal_client = fal_client.AsyncClient()
if not self._api_key:
raise ValueError(
"fal AI API key is required. It should be set with env FAL_KEY"
)
def update_options(self, *, language: Optional[str] = None) -> None:
self._opts.language = language or self._opts.language
def _sanitize_options(
self,
*,
language: Optional[str] = None,
task: Optional[str] = None,
chunk_level: Optional[str] = None,
version: Optional[str] = None,
) -> _STTOptions:
config = dataclasses.replace(self._opts)
config.language = language or config.language
config.task = task or config.task
config.chunk_level = chunk_level or config.chunk_level
config.version = version or config.version
return config
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
try:
config = self._sanitize_options(language=language)
data_uri = fal_client.encode(
rtc.combine_audio_frames(buffer).to_wav_bytes(), "audio/x-wav"
)
response = await self._fal_client.run(
"fal-ai/wizper",
arguments={
"audio_url": data_uri,
"task": config.task,
"language": config.language,
"chunk_level": config.chunk_level,
"version": config.version,
},
timeout=conn_options.timeout,
)
text = response.get("text", "")
return self._transcription_to_speech_event(text=text)
except fal_client.client.FalClientError as e:
raise APIConnectionError() from e
def _transcription_to_speech_event(
self, event_type=SpeechEventType.FINAL_TRANSCRIPT, text=None
) -> stt.SpeechEvent:
return stt.SpeechEvent(
type=event_type,
alternatives=[stt.SpeechData(text=text, language=self._opts.language)],
)
async def aclose(self) -> None:
await self._fal_client._client.aclose()
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.4"
{
"name": "livekit-plugins-fal",
"private": true,
"version": "0.2.4"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "fal", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-fal",
version=about["__version__"],
description="fal plugin template for LiveKit Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "fal_client"],
package_data={"livekit.plugins.fal": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-google
## 0.11.3
### Patch Changes
- google tts: configure api_endpoint based on location - [#1890](https://github.com/livekit/agents/pull/1890) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.2
### Patch Changes
- fix: double transcript issue for google stt - [#1694](https://github.com/livekit/agents/pull/1694) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.1
### Patch Changes
- allow configurable api version in gemini realtime - [#1656](https://github.com/livekit/agents/pull/1656) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.0
### Minor Changes
- Add simple video input support for gemini live - [#1536](https://github.com/livekit/agents/pull/1536) ([@bcherry](https://github.com/bcherry))
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.10.6
### Patch Changes
- google stt: change default model to `latest_long` - [#1552](https://github.com/livekit/agents/pull/1552) ([@jayeshp19](https://github.com/jayeshp19))
- feat: connection pooling. speeds up generation with STT/TTS providers - [#1538](https://github.com/livekit/agents/pull/1538) ([@davidzhao](https://github.com/davidzhao))
- fix: functioncall cancellation ids in realtime - [#1572](https://github.com/livekit/agents/pull/1572) ([@jayeshp19](https://github.com/jayeshp19))
- google-genai version bump & remove id field from function call and function response - [#1559](https://github.com/livekit/agents/pull/1559) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.5
### Patch Changes
- fix(google): require min confidence score due to aggressive generation - [#1507](https://github.com/livekit/agents/pull/1507) ([@davidzhao](https://github.com/davidzhao))
## 0.10.4
### Patch Changes
- Gemini realtime : rollback default model to `gemini-2.0-flash-exp` - [#1489](https://github.com/livekit/agents/pull/1489) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.3
### Patch Changes
- Gemini Realtime: Transcribe model audio via gemini api & use latest model as default for google plugin - [#1446](https://github.com/livekit/agents/pull/1446) ([@jayeshp19](https://github.com/jayeshp19))
- Update to support passing chirp_2 location for other STT credentials - [#1098](https://github.com/livekit/agents/pull/1098) ([@brightsparc](https://github.com/brightsparc))
- Added an additional field in LLM capabilities class to check if model providers support function call history within chat context without needing function definitions. - [#1441](https://github.com/livekit/agents/pull/1441) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.2
### Patch Changes
- gemini-realtime: fix input audio sample rate - [#1411](https://github.com/livekit/agents/pull/1411) ([@jayeshp19](https://github.com/jayeshp19))
- chore: Replace ValueError with logger.warning for missing GOOGLE_APPLICATION_CREDENTIALS environment variable - [#1415](https://github.com/livekit/agents/pull/1415) ([@hironow](https://github.com/hironow))
## 0.10.1
### Patch Changes
- fix: update default model to chirp2 in google stt & update generate_reply method in gemini realtime - [#1401](https://github.com/livekit/agents/pull/1401) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.0
### Minor Changes
- support gemini LLM - [#1382](https://github.com/livekit/agents/pull/1382) ([@jayeshp19](https://github.com/jayeshp19))
### Patch Changes
- fix: address breaking change from google-genai >= 0.3.0 - [#1383](https://github.com/livekit/agents/pull/1383) ([@jayeshp19](https://github.com/jayeshp19))
- gemini improvements: exception handling, transcription & ensure contents.parts is non-empty in gemini context - [#1398](https://github.com/livekit/agents/pull/1398) ([@jayeshp19](https://github.com/jayeshp19))
- support transcriber session for user/agent audio - [#1321](https://github.com/livekit/agents/pull/1321) ([@jayeshp19](https://github.com/jayeshp19))
## 0.9.1
### Patch Changes
- fetch fresh client on update location and small fix for max_session_duration (4 mins) - [#1342](https://github.com/livekit/agents/pull/1342) ([@jayeshp19](https://github.com/jayeshp19))
- fix Google STT handling of session timeouts - [#1337](https://github.com/livekit/agents/pull/1337) ([@davidzhao](https://github.com/davidzhao))
## 0.9.0
### Minor Changes
- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) ([@jayeshp19](https://github.com/jayeshp19))
### Patch Changes
- fix: Ensure STT exceptions are being propagated - [#1291](https://github.com/livekit/agents/pull/1291) ([@davidzhao](https://github.com/davidzhao))
## 0.8.1
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.8.0
### Minor Changes
- Add support for google STT chirp_2 model. - [#1089](https://github.com/livekit/agents/pull/1089) ([@brightsparc](https://github.com/brightsparc))
### Patch Changes
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
- fix: add retry logic for google stt abort exception - [#1100](https://github.com/livekit/agents/pull/1100) ([@jayeshp19](https://github.com/jayeshp19))
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
- google STT - use the baseclass resampler - [#1106](https://github.com/livekit/agents/pull/1106) ([@jayeshp19](https://github.com/jayeshp19))
## 0.7.3
### Patch Changes
- added catch for aborted speech - [#1055](https://github.com/livekit/agents/pull/1055) ([@jayeshp19](https://github.com/jayeshp19))
- Make Google STT keywords match Deepgram - [#1067](https://github.com/livekit/agents/pull/1067) ([@martin-purplefish](https://github.com/martin-purplefish))
- Add support for boosting phrases in Google STT - [#1066](https://github.com/livekit/agents/pull/1066) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.7.2
### Patch Changes
- add update_options to TTS - [#922](https://github.com/livekit/agents/pull/922) ([@theomonnom](https://github.com/theomonnom))
- Additional options enabled on Google TTS - [#945](https://github.com/livekit/agents/pull/945) ([@hari-truviz](https://github.com/hari-truviz))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
## 0.7.1
### Patch Changes
- avoid returning tiny frames from TTS - [#747](https://github.com/livekit/agents/pull/747) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0
### Minor Changes
- Enable use of Google STT with Application Default Credentials. - [#721](https://github.com/livekit/agents/pull/721) ([@rsinnet](https://github.com/rsinnet))
### Patch Changes
- google-tts: ignore wav header - [#703](https://github.com/livekit/agents/pull/703) ([@theomonnom](https://github.com/theomonnom))
## 0.6.3
### Patch Changes
- Fix Google STT exception when no valid speech is recognized - [#680](https://github.com/livekit/agents/pull/680) ([@davidzhao](https://github.com/davidzhao))
## 0.6.2
### Patch Changes
- stt/tts: fix unread inputs when the input channel is closed - [#594](https://github.com/livekit/agents/pull/594) ([@theomonnom](https://github.com/theomonnom))
## 0.6.1
### Patch Changes
- fix end_input not flushing & unhandled flush messages - [#528](https://github.com/livekit/agents/pull/528) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.5.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins Google
Agent Framework plugin for services from Google Cloud. Currently supporting Google's [Speech-to-Text](https://cloud.google.com/speech-to-text) API.
## Installation
```bash
pip install livekit-plugins-google
```
To authenticate, you'll need a Google Cloud account and the appropriate credentials. Credentials can be passed to the plugin directly or provided through Application Default Credentials, as described in "How Application Default Credentials works".
To use the STT and TTS API, you’ll need to enable the respective services for your Google Cloud project.
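As a minimal sketch, assuming Application Default Credentials are already configured for such a project, the STT and TTS classes can be constructed with their defaults (constructor options are documented in the plugin's stt.py and tts.py):

```python
from livekit.plugins import google

# Assumes GOOGLE_APPLICATION_CREDENTIALS (or other Application Default Credentials)
# points at a project with the Speech-to-Text and Text-to-Speech APIs enabled.
stt = google.STT()
tts = google.TTS()
```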
Gemini Multimodal Live can be used with the MultimodalAgent class. See examples/multimodal_agent/gemini_agent.py for an example.
You can push video frames to your Gemini Multimodal Live session alongside the audio automatically handled by the MultimodalAgent. The basic approach is to subscribe to the video track, create a video stream, sample frames at a suitable frame rate, and push them into the RealtimeSession:
# Make sure you subscribe to audio and video tracks
await ctx.connect(auto_subscribe=AutoSubscribe.SUBSCRIBE_ALL)
# Create your RealtimeModel and store a reference
model = google.beta.realtime.RealtimeModel(
# ...
)
# Create your MultimodalAgent as usual
agent = MultimodalAgent(
model=model,
# ...
)
# Async method to process the video track and push frames to Gemini
async def _process_video_track(self, track: Track):
video_stream = VideoStream(track)
last_frame_time = 0
async for event in video_stream:
current_time = asyncio.get_event_loop().time()
# Sample at 1 FPS
if current_time - last_frame_time < 1.0:
continue
last_frame_time = current_time
frame = event.frame
# Push the frame into the RealtimeSession
model.sessions[0].push_video(frame)
await video_stream.aclose()
# Subscribe to new tracks and process them
@ctx.room.on("track_subscribed")
def _on_track_subscribed(track: Track, pub, participant):
if track.kind == TrackKind.KIND_VIDEO:
asyncio.create_task(self._process_video_track(track))
## livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import beta
from .llm import LLM
from .stt import STT, SpeechStream
from .tts import TTS
from .version import __version__
__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta", "LLM"]
from livekit.agents import Plugin
from .log import logger
class GooglePlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(GooglePlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
from __future__ import annotations
import base64
import inspect
import json
from typing import Any, Dict, List, Optional, get_args, get_origin
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm.function_context import _is_optional_type
from google.genai import types
from google.genai.types import Type as GenaiType
JSON_SCHEMA_TYPE_MAP: dict[type, GenaiType] = {
str: GenaiType.STRING,
int: GenaiType.INTEGER,
float: GenaiType.NUMBER,
bool: GenaiType.BOOLEAN,
dict: GenaiType.OBJECT,
list: GenaiType.ARRAY,
}
__all__ = ["_build_gemini_ctx", "_build_tools"]
def _build_parameters(arguments: Dict[str, Any]) -> types.Schema | None:
properties: Dict[str, types.Schema] = {}
required: List[str] = []
for arg_name, arg_info in arguments.items():
prop = types.Schema()
if arg_info.description:
prop.description = arg_info.description
_, py_type = _is_optional_type(arg_info.type)
origin = get_origin(py_type)
if origin is list:
item_type = get_args(py_type)[0]
if item_type not in JSON_SCHEMA_TYPE_MAP:
raise ValueError(f"Unsupported type: {item_type}")
prop.type = GenaiType.ARRAY
prop.items = types.Schema(type=JSON_SCHEMA_TYPE_MAP[item_type])
if arg_info.choices:
prop.items.enum = arg_info.choices
else:
if py_type not in JSON_SCHEMA_TYPE_MAP:
raise ValueError(f"Unsupported type: {py_type}")
prop.type = JSON_SCHEMA_TYPE_MAP[py_type]
if arg_info.choices:
prop.enum = arg_info.choices
if py_type is int:
raise ValueError(
f"Parameter '{arg_info.name}' uses integer choices, not supported by this model."
)
properties[arg_name] = prop
if arg_info.default is inspect.Parameter.empty:
required.append(arg_name)
if properties:
parameters = types.Schema(type=GenaiType.OBJECT, properties=properties)
if required:
parameters.required = required
return parameters
return None
def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclaration]:
function_declarations: List[types.FunctionDeclaration] = []
for fnc_info in fnc_ctx.ai_functions.values():
parameters = _build_parameters(fnc_info.arguments)
func_decl = types.FunctionDeclaration(
name=fnc_info.name,
description=fnc_info.description,
parameters=parameters,
)
function_declarations.append(func_decl)
return function_declarations
def _build_gemini_ctx(
chat_ctx: llm.ChatContext, cache_key: Any
) -> tuple[list[types.Content], Optional[types.Content]]:
turns: list[types.Content] = []
system_instruction: Optional[types.Content] = None
current_role: Optional[str] = None
parts: list[types.Part] = []
for msg in chat_ctx.messages:
if msg.role == "system":
if isinstance(msg.content, str):
system_instruction = types.Content(parts=[types.Part(text=msg.content)])
continue
if msg.role == "assistant":
role = "model"
elif msg.role == "tool":
role = "user"
else:
role = "user"
# If role changed, finalize previous parts into a turn
if role != current_role:
if current_role is not None and parts:
turns.append(types.Content(role=current_role, parts=parts))
current_role = role
parts = []
if msg.tool_calls:
for fnc in msg.tool_calls:
parts.append(
types.Part(
function_call=types.FunctionCall(
name=fnc.function_info.name,
args=fnc.arguments,
)
)
)
if msg.role == "tool":
if msg.content:
if isinstance(msg.content, dict):
parts.append(
types.Part(
function_response=types.FunctionResponse(
name=msg.name,
response=msg.content,
)
)
)
elif isinstance(msg.content, str):
parts.append(
types.Part(
function_response=types.FunctionResponse(
name=msg.name,
response={"result": msg.content},
)
)
)
else:
if msg.content:
if isinstance(msg.content, str):
parts.append(types.Part(text=msg.content))
elif isinstance(msg.content, dict):
parts.append(types.Part(text=json.dumps(msg.content)))
elif isinstance(msg.content, list):
for item in msg.content:
if isinstance(item, str):
parts.append(types.Part(text=item))
elif isinstance(item, llm.ChatImage):
parts.append(_build_gemini_image_part(item, cache_key))
# Finalize last role's parts if any remain
if current_role is not None and parts:
turns.append(types.Content(role=current_role, parts=parts))
return turns, system_instruction
def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part:
if isinstance(image.image, str):
# Check if the string is a Data URL
if image.image.startswith("data:image/jpeg;base64,"):
# Extract the base64 part after the comma
base64_data = image.image.split(",", 1)[1]
try:
image_bytes = base64.b64decode(base64_data)
except Exception as e:
raise ValueError("Invalid base64 data in image URL") from e
return types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
else:
# Assume it's a regular URL
return types.Part.from_uri(file_uri=image.image, mime_type="image/jpeg")
elif isinstance(image.image, rtc.VideoFrame):
if cache_key not in image._cache:
opts = utils.images.EncodeOptions()
if image.inference_width and image.inference_height:
opts.resize_options = utils.images.ResizeOptions(
width=image.inference_width,
height=image.inference_height,
strategy="scale_aspect_fit",
)
image._cache[cache_key] = utils.images.encode(image.image, opts)
return types.Part.from_bytes(
data=image._cache[cache_key], mime_type="image/jpeg"
)
raise ValueError(f"Unsupported image type: {type(image.image)}")
from . import realtime
__all__ = ["realtime"]
from .api_proto import (
ClientEvents,
LiveAPIModels,
Voice,
)
from .realtime_api import RealtimeModel
__all__ = [
"RealtimeModel",
"ClientEvents",
"LiveAPIModels",
"Voice",
]
from __future__ import annotations
from typing import Literal, Sequence, Union
from google.genai import types
from ..._utils import _build_gemini_ctx, _build_tools
LiveAPIModels = Literal["gemini-2.0-flash-exp"]
Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
__all__ = ["_build_tools", "ClientEvents", "_build_gemini_ctx"]
ClientEvents = Union[
types.ContentListUnion,
types.ContentListUnionDict,
types.LiveClientContentOrDict,
types.LiveClientRealtimeInput,
types.LiveClientRealtimeInputOrDict,
types.LiveClientToolResponseOrDict,
types.FunctionResponseOrDict,
Sequence[types.FunctionResponseOrDict],
]
from __future__ import annotations
import asyncio
import json
import os
from dataclasses import dataclass
from typing import AsyncIterable, Literal
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm.function_context import _create_ai_function_info
from livekit.agents.utils import images
from google import genai
from google.genai.types import (
Blob,
Content,
FunctionResponse,
GenerationConfig,
HttpOptions,
LiveClientContent,
LiveClientRealtimeInput,
LiveClientToolResponse,
LiveConnectConfig,
Modality,
Part,
PrebuiltVoiceConfig,
SpeechConfig,
Tool,
VoiceConfig,
)
from ...log import logger
from .api_proto import (
ClientEvents,
LiveAPIModels,
Voice,
_build_gemini_ctx,
_build_tools,
)
from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent
EventTypes = Literal[
"start_session",
"input_speech_started",
"response_content_added",
"response_content_done",
"function_calls_collected",
"function_calls_finished",
"function_calls_cancelled",
"input_speech_transcription_completed",
"agent_speech_transcription_completed",
"agent_speech_stopped",
]
@dataclass
class GeminiContent:
response_id: str
item_id: str
output_index: int
content_index: int
text: str
audio: list[rtc.AudioFrame]
text_stream: AsyncIterable[str]
audio_stream: AsyncIterable[rtc.AudioFrame]
content_type: Literal["text", "audio"]
@dataclass
class InputTranscription:
item_id: str
transcript: str
@dataclass
class Capabilities:
supports_truncate: bool
input_audio_sample_rate: int | None = None
@dataclass
class ModelOptions:
model: LiveAPIModels | str
api_key: str | None
api_version: str
voice: Voice | str
response_modalities: list[Modality] | None
vertexai: bool
project: str | None
location: str | None
candidate_count: int
temperature: float | None
max_output_tokens: int | None
top_p: float | None
top_k: int | None
presence_penalty: float | None
frequency_penalty: float | None
instructions: Content | None
enable_user_audio_transcription: bool
enable_agent_audio_transcription: bool
class RealtimeModel:
def __init__(
self,
*,
instructions: str | None = None,
model: LiveAPIModels | str = "gemini-2.0-flash-exp",
api_key: str | None = None,
api_version: str = "v1alpha",
voice: Voice | str = "Puck",
modalities: list[Modality] = [Modality.AUDIO],
enable_user_audio_transcription: bool = True,
enable_agent_audio_transcription: bool = True,
vertexai: bool = False,
project: str | None = None,
location: str | None = None,
candidate_count: int = 1,
temperature: float | None = None,
max_output_tokens: int | None = None,
top_p: float | None = None,
top_k: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
loop: asyncio.AbstractEventLoop | None = None,
):
"""
Initializes a RealtimeModel instance for interacting with Google's Realtime API.
Environment Requirements:
- For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
`GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
and the location defaults to "us-central1".
- For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
Args:
instructions (str, optional): Initial system instructions for the model. Defaults to "".
api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
api_version (str, optional): The version of the API to use. Defaults to "v1alpha".
modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
top_p (float, optional): The top-p value for response generation
top_k (int, optional): The top-k value for response generation
presence_penalty (float, optional): The presence penalty for response generation
frequency_penalty (float, optional): The frequency penalty for response generation
loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
Raises:
ValueError: If the API key is not provided and cannot be found in environment variables.
"""
super().__init__()
self._capabilities = Capabilities(
supports_truncate=False,
input_audio_sample_rate=16000,
)
self._model = model
self._loop = loop or asyncio.get_event_loop()
self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
if vertexai:
if not self._project or not self._location:
raise ValueError(
"Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
)
self._api_key = None # VertexAI does not require an API key
else:
self._project = None
self._location = None
if not self._api_key:
raise ValueError(
"API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
)
instructions_content = (
Content(parts=[Part(text=instructions)]) if instructions else None
)
self._rt_sessions: list[GeminiRealtimeSession] = []
self._opts = ModelOptions(
model=model,
api_version=api_version,
api_key=self._api_key,
voice=voice,
enable_user_audio_transcription=enable_user_audio_transcription,
enable_agent_audio_transcription=enable_agent_audio_transcription,
response_modalities=modalities,
vertexai=vertexai,
project=self._project,
location=self._location,
candidate_count=candidate_count,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
instructions=instructions_content,
)
@property
def sessions(self) -> list[GeminiRealtimeSession]:
return self._rt_sessions
@property
def capabilities(self) -> Capabilities:
return self._capabilities
def session(
self,
*,
chat_ctx: llm.ChatContext | None = None,
fnc_ctx: llm.FunctionContext | None = None,
) -> GeminiRealtimeSession:
session = GeminiRealtimeSession(
opts=self._opts,
chat_ctx=chat_ctx or llm.ChatContext(),
fnc_ctx=fnc_ctx,
loop=self._loop,
)
self._rt_sessions.append(session)
return session
async def aclose(self) -> None:
for session in self._rt_sessions:
await session.aclose()
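For reference, a minimal usage sketch for the realtime model defined above. The import path and option values are assumptions; in practice the model is typically handed to a multimodal agent rather than used directly:

```py
# Minimal sketch only; assumes GOOGLE_API_KEY is set in the environment and that
# RealtimeModel is importable from the Google plugin's beta realtime module.
from livekit.plugins.google.beta.realtime import RealtimeModel

model = RealtimeModel(
    instructions="You are a helpful voice assistant.",
    modalities=["AUDIO"],
    voice="Puck",
    temperature=0.8,
)
session = model.session()  # GeminiRealtimeSession, as defined below
```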
class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
opts: ModelOptions,
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None,
loop: asyncio.AbstractEventLoop,
):
"""
Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
Args:
opts (ModelOptions): The model options for the session.
chat_ctx (llm.ChatContext): The chat context for the session.
fnc_ctx (llm.FunctionContext or None): The function context for the session.
loop (asyncio.AbstractEventLoop): The event loop for the session.
"""
super().__init__()
self._loop = loop
self._opts = opts
self._chat_ctx = chat_ctx
self._fnc_ctx = fnc_ctx
self._fnc_tasks = utils.aio.TaskSet()
self._is_interrupted = False
self._playout_complete = asyncio.Event()
self._playout_complete.set()
tools = []
if self._fnc_ctx is not None:
functions = _build_tools(self._fnc_ctx)
tools.append(Tool(function_declarations=functions))
self._config = LiveConnectConfig(
response_modalities=self._opts.response_modalities,
generation_config=GenerationConfig(
candidate_count=self._opts.candidate_count,
temperature=self._opts.temperature,
max_output_tokens=self._opts.max_output_tokens,
top_p=self._opts.top_p,
top_k=self._opts.top_k,
presence_penalty=self._opts.presence_penalty,
frequency_penalty=self._opts.frequency_penalty,
),
system_instruction=self._opts.instructions,
speech_config=SpeechConfig(
voice_config=VoiceConfig(
prebuilt_voice_config=PrebuiltVoiceConfig(
voice_name=self._opts.voice
)
)
),
tools=tools,
)
self._client = genai.Client(
http_options=HttpOptions(api_version=self._opts.api_version),
api_key=self._opts.api_key,
vertexai=self._opts.vertexai,
project=self._opts.project,
location=self._opts.location,
)
self._main_atask = asyncio.create_task(
self._main_task(), name="gemini-realtime-session"
)
if self._opts.enable_user_audio_transcription:
self._transcriber = TranscriberSession(
client=self._client, model=self._opts.model
)
self._transcriber.on("input_speech_done", self._on_input_speech_done)
if self._opts.enable_agent_audio_transcription:
self._agent_transcriber = ModelTranscriber(
client=self._client, model=self._opts.model
)
self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
# init dummy task
self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
self._send_ch = utils.aio.Chan[ClientEvents]()
self._active_response_id = None
async def aclose(self) -> None:
if self._send_ch.closed:
return
self._send_ch.close()
await self._main_atask
@property
def playout_complete(self) -> asyncio.Event | None:
return self._playout_complete
@property
def fnc_ctx(self) -> llm.FunctionContext | None:
return self._fnc_ctx
@fnc_ctx.setter
def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
self._fnc_ctx = value
def _push_media_chunk(self, data: bytes, mime_type: str) -> None:
realtime_input = LiveClientRealtimeInput(
media_chunks=[Blob(data=data, mime_type=mime_type)],
)
self._queue_msg(realtime_input)
DEFAULT_ENCODE_OPTIONS = images.EncodeOptions(
format="JPEG",
quality=75,
resize_options=images.ResizeOptions(
width=1024, height=1024, strategy="scale_aspect_fit"
),
)
def push_video(
self,
frame: rtc.VideoFrame,
encode_options: images.EncodeOptions = DEFAULT_ENCODE_OPTIONS,
) -> None:
"""Push a video frame to the Gemini Multimodal Live session.
Args:
frame (rtc.VideoFrame): The video frame to push.
encode_options (images.EncodeOptions, optional): The encode options for the video frame. Defaults to 1024x1024 JPEG.
Notes:
- Frames are sent immediately, so use a sampling frame rate that makes sense for your application and Gemini's constraints; 1 FPS is a good starting point.
"""
encoded_data = images.encode(
frame,
encode_options,
)
mime_type = (
"image/jpeg"
if encode_options.format == "JPEG"
else "image/png"
if encode_options.format == "PNG"
else "image/jpeg"
)
self._push_media_chunk(encoded_data, mime_type)
def _push_audio(self, frame: rtc.AudioFrame) -> None:
if self._opts.enable_user_audio_transcription:
self._transcriber._push_audio(frame)
self._push_media_chunk(frame.data.tobytes(), "audio/pcm")
def _queue_msg(self, msg: ClientEvents) -> None:
self._send_ch.send_nowait(msg)
def chat_ctx_copy(self) -> llm.ChatContext:
return self._chat_ctx.copy()
async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
self._chat_ctx = ctx.copy()
def cancel_response(self) -> None:
raise NotImplementedError("cancel_response is not supported yet")
def create_response(
self,
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "keep_both",
) -> None:
turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
if not ctx:
logger.warning(
"gemini-realtime-session: No chat context to send, sending dummy content."
)
ctx = [Content(parts=[Part(text=".")])]
self._queue_msg(LiveClientContent(turns=ctx, turn_complete=True))
def commit_audio_buffer(self) -> None:
raise NotImplementedError("commit_audio_buffer is not supported yet")
def server_vad_enabled(self) -> bool:
return True
def _on_input_speech_done(self, content: TranscriptionContent) -> None:
if content.response_id and content.text:
self.emit(
"input_speech_transcription_completed",
InputTranscription(
item_id=content.response_id,
transcript=content.text,
),
)
# self._chat_ctx.append(text=content.text, role="user")
# TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
if content.response_id and content.text:
self.emit(
"agent_speech_transcription_completed",
InputTranscription(
item_id=content.response_id,
transcript=content.text,
),
)
# self._chat_ctx.append(text=content.text, role="assistant")
@utils.log_exceptions(logger=logger)
async def _main_task(self):
@utils.log_exceptions(logger=logger)
async def _send_task():
async for msg in self._send_ch:
await self._session.send(input=msg)
await self._session.send(input=".", end_of_turn=True)
@utils.log_exceptions(logger=logger)
async def _recv_task():
while True:
async for response in self._session.receive():
if self._active_response_id is None:
self._is_interrupted = False
self._active_response_id = utils.shortuuid()
text_stream = utils.aio.Chan[str]()
audio_stream = utils.aio.Chan[rtc.AudioFrame]()
content = GeminiContent(
response_id=self._active_response_id,
item_id=self._active_response_id,
output_index=0,
content_index=0,
text="",
audio=[],
text_stream=text_stream,
audio_stream=audio_stream,
content_type="audio",
)
self.emit("response_content_added", content)
server_content = response.server_content
if server_content:
model_turn = server_content.model_turn
if model_turn:
for part in model_turn.parts:
if part.text:
content.text_stream.send_nowait(part.text)
if part.inline_data:
frame = rtc.AudioFrame(
data=part.inline_data.data,
sample_rate=24000,
num_channels=1,
samples_per_channel=len(part.inline_data.data)
// 2,
)
if self._opts.enable_agent_audio_transcription:
content.audio.append(frame)
content.audio_stream.send_nowait(frame)
if server_content.interrupted or server_content.turn_complete:
if self._opts.enable_agent_audio_transcription:
self._agent_transcriber._push_audio(content.audio)
for stream in (content.text_stream, content.audio_stream):
if isinstance(stream, utils.aio.Chan):
stream.close()
self.emit("agent_speech_stopped")
self._is_interrupted = True
self._active_response_id = None
if response.tool_call:
if self._fnc_ctx is None:
raise ValueError("Function context is not set")
fnc_calls = []
for fnc_call in response.tool_call.function_calls:
fnc_call_info = _create_ai_function_info(
self._fnc_ctx,
fnc_call.id,
fnc_call.name,
json.dumps(fnc_call.args),
)
fnc_calls.append(fnc_call_info)
self.emit("function_calls_collected", fnc_calls)
for fnc_call_info in fnc_calls:
self._fnc_tasks.create_task(
self._run_fnc_task(fnc_call_info, content.item_id)
)
# Handle function call cancellations
if response.tool_call_cancellation:
logger.warning(
"function call cancelled",
extra={
"function_call_ids": response.tool_call_cancellation.ids,
},
)
self.emit(
"function_calls_cancelled",
response.tool_call_cancellation.ids,
)
async with self._client.aio.live.connect(
model=self._opts.model, config=self._config
) as session:
self._session = session
tasks = [
asyncio.create_task(_send_task(), name="gemini-realtime-send"),
asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await self._session.close()
if self._opts.enable_user_audio_transcription:
await self._transcriber.aclose()
if self._opts.enable_agent_audio_transcription:
await self._agent_transcriber.aclose()
@utils.log_exceptions(logger=logger)
async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
logger.debug(
"executing ai function",
extra={
"function": fnc_call_info.function_info.name,
},
)
called_fnc = fnc_call_info.execute()
try:
await called_fnc.task
except Exception as e:
logger.exception(
"error executing ai function",
extra={
"function": fnc_call_info.function_info.name,
},
exc_info=e,
)
tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
if tool_call.content is not None:
tool_response = LiveClientToolResponse(
function_responses=[
FunctionResponse(
name=tool_call.name,
id=tool_call.tool_call_id,
response={"result": tool_call.content},
)
]
)
await self._session.send(input=tool_response)
self.emit("function_calls_finished", [called_fnc])
from __future__ import annotations
import asyncio
import re
from dataclasses import dataclass
from typing import Literal
import websockets
from livekit import rtc
from livekit.agents import APIConnectionError, APIStatusError, utils
from google import genai
from google.genai import types
from google.genai.errors import APIError, ClientError, ServerError
from ...log import logger
from .api_proto import ClientEvents, LiveAPIModels
EventTypes = Literal["input_speech_started", "input_speech_done"]
DEFAULT_LANGUAGE = "English"
SYSTEM_INSTRUCTIONS = f"""
You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text.
- Transcribe verbatim; exclude non-speech sounds.
- Provide only transcription; no extra text or explanations.
- If audio is unclear, respond with: `...`
- Ensure error-free transcription, preserving meaning and context.
- Use proper punctuation and formatting.
- Do not add explanations, comments, or extra information.
- Do not include timestamps, speaker labels, or annotations unless specified.
- Audio Language: {DEFAULT_LANGUAGE}
"""
@dataclass
class TranscriptionContent:
response_id: str
text: str
class TranscriberSession(utils.EventEmitter[EventTypes]):
"""
Handles live audio transcription using the realtime API.
"""
def __init__(self, *, client: genai.Client, model: LiveAPIModels | str):
super().__init__()
self._client = client
self._model = model
self._needed_sr = 16000
self._closed = False
system_instructions = types.Content(
parts=[types.Part(text=SYSTEM_INSTRUCTIONS)]
)
self._config = types.LiveConnectConfig(
response_modalities=[types.Modality.TEXT],
system_instruction=system_instructions,
generation_config=types.GenerationConfig(temperature=0.0),
)
self._main_atask = asyncio.create_task(
self._main_task(), name="gemini-realtime-transcriber"
)
self._send_ch = utils.aio.Chan[ClientEvents]()
self._resampler: rtc.AudioResampler | None = None
self._active_response_id = None
def _push_audio(self, frame: rtc.AudioFrame) -> None:
if self._closed:
return
if frame.sample_rate != self._needed_sr:
if not self._resampler:
self._resampler = rtc.AudioResampler(
frame.sample_rate,
self._needed_sr,
quality=rtc.AudioResamplerQuality.HIGH,
)
if self._resampler:
for f in self._resampler.push(frame):
self._queue_msg(
types.LiveClientRealtimeInput(
media_chunks=[
types.Blob(data=f.data.tobytes(), mime_type="audio/pcm")
]
)
)
else:
self._queue_msg(
types.LiveClientRealtimeInput(
media_chunks=[
types.Blob(data=frame.data.tobytes(), mime_type="audio/pcm")
]
)
)
def _queue_msg(self, msg: ClientEvents) -> None:
if not self._closed:
self._send_ch.send_nowait(msg)
async def aclose(self) -> None:
if self._send_ch.closed:
return
self._closed = True
self._send_ch.close()
await self._main_atask
@utils.log_exceptions(logger=logger)
async def _main_task(self):
@utils.log_exceptions(logger=logger)
async def _send_task():
try:
async for msg in self._send_ch:
if self._closed:
break
await self._session.send(input=msg)
except websockets.exceptions.ConnectionClosedError as e:
logger.exception(f"Transcriber session closed in _send_task: {e}")
self._closed = True
except Exception as e:
logger.exception(f"Uncaught error in transcriber _send_task: {e}")
self._closed = True
@utils.log_exceptions(logger=logger)
async def _recv_task():
try:
while not self._closed:
async for response in self._session.receive():
if self._closed:
break
if self._active_response_id is None:
self._active_response_id = utils.shortuuid()
content = TranscriptionContent(
response_id=self._active_response_id,
text="",
)
self.emit("input_speech_started", content)
server_content = response.server_content
if server_content:
model_turn = server_content.model_turn
if model_turn:
for part in model_turn.parts:
if part.text:
content.text += part.text
if server_content.turn_complete:
content.text = clean_transcription(content.text)
self.emit("input_speech_done", content)
self._active_response_id = None
except websockets.exceptions.ConnectionClosedError as e:
logger.exception(f"Transcriber session closed in _recv_task: {e}")
self._closed = True
except Exception as e:
logger.exception(f"Uncaught error in transcriber _recv_task: {e}")
self._closed = True
async with self._client.aio.live.connect(
model=self._model, config=self._config
) as session:
self._session = session
tasks = [
asyncio.create_task(
_send_task(), name="gemini-realtime-transcriber-send"
),
asyncio.create_task(
_recv_task(), name="gemini-realtime-transcriber-recv"
),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
await self._session.close()
class ModelTranscriber(utils.EventEmitter[EventTypes]):
"""
Transcribes agent audio using model generation.
"""
def __init__(self, *, client: genai.Client, model: LiveAPIModels | str):
super().__init__()
self._client = client
self._model = model
self._needed_sr = 16000
self._system_instructions = types.Content(
parts=[types.Part(text=SYSTEM_INSTRUCTIONS)]
)
self._config = types.GenerateContentConfig(
temperature=0.0,
system_instruction=self._system_instructions,
# TODO: add response_schema
)
self._resampler: rtc.AudioResampler | None = None
self._buffer: rtc.AudioFrame | None = None
self._audio_ch = utils.aio.Chan[rtc.AudioFrame]()
self._main_atask = asyncio.create_task(
self._main_task(), name="gemini-model-transcriber"
)
async def aclose(self) -> None:
if self._audio_ch.closed:
return
self._audio_ch.close()
await self._main_atask
def _push_audio(self, frames: list[rtc.AudioFrame]) -> None:
if not frames:
return
buffer = utils.merge_frames(frames)
if buffer.sample_rate != self._needed_sr:
if self._resampler is None:
self._resampler = rtc.AudioResampler(
input_rate=buffer.sample_rate,
output_rate=self._needed_sr,
quality=rtc.AudioResamplerQuality.HIGH,
)
buffer = utils.merge_frames(self._resampler.push(buffer))
self._audio_ch.send_nowait(buffer)
@utils.log_exceptions(logger=logger)
async def _main_task(self):
request_id = utils.shortuuid()
try:
async for buffer in self._audio_ch:
# TODO: stream content for better latency
response = await self._client.aio.models.generate_content(
model=self._model,
contents=[
types.Content(
parts=[
types.Part(text=SYSTEM_INSTRUCTIONS),
types.Part.from_bytes(
data=buffer.to_wav_bytes(),
mime_type="audio/wav",
),
],
role="user",
)
],
config=self._config,
)
content = TranscriptionContent(
response_id=request_id, text=clean_transcription(response.text)
)
self.emit("input_speech_done", content)
except (ClientError, ServerError, APIError) as e:
raise APIStatusError(
f"model transcriber error: {e}",
status_code=e.code,
body=e.message,
request_id=request_id,
) from e
except Exception as e:
raise APIConnectionError("Error generating transcription") from e
def clean_transcription(text: str) -> str:
text = text.replace("\n", " ")
text = re.sub(r"\s+", " ", text)
return text.strip()
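The transcriber helpers above are internal to the realtime session, but a rough sketch of the event flow they implement looks like this (names are the ones defined in this module; the client construction is an assumption, since the realtime session normally builds it itself):

```py
# Rough sketch of the internal wiring, not a public API.
from google import genai

client = genai.Client(api_key="...")  # assumed; normally created by GeminiRealtimeSession
transcriber = TranscriberSession(client=client, model="gemini-2.0-flash-exp")

def _on_done(content: TranscriptionContent) -> None:
    print(content.response_id, content.text)

transcriber.on("input_speech_done", _on_done)
# the realtime session then feeds rtc.AudioFrame objects via transcriber._push_audio(frame)
```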
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import os
from dataclasses import dataclass
from typing import Any, Literal, MutableSet, Union, cast
from livekit.agents import (
APIConnectionError,
APIStatusError,
llm,
utils,
)
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from google import genai
from google.auth._default_async import default_async
from google.genai import types
from google.genai.errors import APIError, ClientError, ServerError
from ._utils import _build_gemini_ctx, _build_tools
from .log import logger
from .models import ChatModels
@dataclass
class LLMOptions:
model: ChatModels | str
temperature: float | None
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
vertexai: bool = False
project: str | None = None
location: str | None = None
candidate_count: int = 1
max_output_tokens: int | None = None
top_p: float | None = None
top_k: float | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
class LLM(llm.LLM):
def __init__(
self,
*,
model: ChatModels | str = "gemini-2.0-flash-001",
api_key: str | None = None,
vertexai: bool = False,
project: str | None = None,
location: str | None = None,
candidate_count: int = 1,
temperature: float = 0.8,
max_output_tokens: int | None = None,
top_p: float | None = None,
top_k: float | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> None:
"""
Create a new instance of Google GenAI LLM.
Environment Requirements:
- For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
`GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
and the location defaults to "us-central1".
- For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
Args:
model (ChatModels | str, optional): The model name to use. Defaults to "gemini-2.0-flash-001".
api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable.
vertexai (bool, optional): Whether to use VertexAI. Defaults to False.
project (str, optional): The Google Cloud project to use (only for VertexAI). Defaults to None.
location (str, optional): The location to use for VertexAI API requests. Defaults to "us-central1".
candidate_count (int, optional): Number of candidate responses to generate. Defaults to 1.
temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
top_k (float, optional): The top-k sampling value for response generation. Defaults to None.
presence_penalty (float, optional): Penalizes the model for generating previously mentioned concepts. Defaults to None.
frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
"""
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=False,
requires_persistent_functions=False,
)
)
self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
self._location = location or os.environ.get(
"GOOGLE_CLOUD_LOCATION", "us-central1"
)
self._api_key = api_key or os.environ.get("GOOGLE_API_KEY", None)
_gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
if _gac is None:
logger.warning(
"`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file. Otherwise, use any of the other Google Cloud auth methods."
)
if vertexai:
if not self._project_id:
_, self._project_id = default_async(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
self._api_key = None # VertexAI does not require an API key
else:
self._project_id = None
self._location = None
if not self._api_key:
raise ValueError(
"API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
)
self._opts = LLMOptions(
model=model,
temperature=temperature,
tool_choice=tool_choice,
vertexai=vertexai,
project=project,
location=location,
candidate_count=candidate_count,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
self._client = genai.Client(
api_key=self._api_key,
vertexai=vertexai,
project=self._project_id,
location=self._location,
)
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if tool_choice is None:
tool_choice = self._opts.tool_choice
if temperature is None:
temperature = self._opts.temperature
return LLMStream(
self,
client=self._client,
model=self._opts.model,
max_output_tokens=self._opts.max_output_tokens,
top_p=self._opts.top_p,
top_k=self._opts.top_k,
presence_penalty=self._opts.presence_penalty,
frequency_penalty=self._opts.frequency_penalty,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
conn_options=conn_options,
n=n,
temperature=temperature,
tool_choice=tool_choice,
)
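A minimal usage sketch for the LLM wrapper above, assuming `GOOGLE_API_KEY` is set and the plugin is importable as `livekit.plugins.google` (import path assumed):

```py
# Minimal sketch: build a chat context and open a streaming chat.
from livekit.agents import llm
from livekit.plugins.google import LLM

gemini = LLM(model="gemini-2.0-flash-001", temperature=0.8)
chat_ctx = llm.ChatContext().append(role="user", text="Say hello in one sentence.")

stream = gemini.chat(chat_ctx=chat_ctx)
# `stream` is the LLMStream defined below; iterate it in an async context
# to receive llm.ChatChunk objects (text deltas and usage metadata).
```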
class LLMStream(llm.LLMStream):
def __init__(
self,
llm: LLM,
*,
client: genai.Client,
model: str | ChatModels,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions,
fnc_ctx: llm.FunctionContext | None,
temperature: float | None,
n: int | None,
max_output_tokens: int | None,
top_p: float | None,
top_k: float | None,
presence_penalty: float | None,
frequency_penalty: float | None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = client
self._model = model
self._llm: LLM = llm
self._max_output_tokens = max_output_tokens
self._top_p = top_p
self._top_k = top_k
self._presence_penalty = presence_penalty
self._frequency_penalty = frequency_penalty
self._temperature = temperature
self._n = n
self._tool_choice = tool_choice
async def _run(self) -> None:
retryable = True
request_id = utils.shortuuid()
try:
opts: dict[str, Any] = dict()
turns, system_instruction = _build_gemini_ctx(self._chat_ctx, id(self))
if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
functions = _build_tools(self._fnc_ctx)
opts["tools"] = [types.Tool(function_declarations=functions)]
if self._tool_choice is not None:
if isinstance(self._tool_choice, ToolChoice):
# specific function
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY,
allowed_function_names=[self._tool_choice.name],
)
)
elif self._tool_choice == "required":
# model must call any function
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY,
allowed_function_names=[
fnc.name
for fnc in self._fnc_ctx.ai_functions.values()
],
)
)
elif self._tool_choice == "auto":
# model can call any function
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.AUTO
)
)
elif self._tool_choice == "none":
# model cannot call any function
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.NONE,
)
)
opts["tool_config"] = tool_config
config = types.GenerateContentConfig(
candidate_count=self._n,
temperature=self._temperature,
max_output_tokens=self._max_output_tokens,
top_p=self._top_p,
top_k=self._top_k,
presence_penalty=self._presence_penalty,
frequency_penalty=self._frequency_penalty,
system_instruction=system_instruction,
**opts,
)
stream = await self._client.aio.models.generate_content_stream(
model=self._model,
contents=cast(types.ContentListUnion, turns),
config=config,
)
async for response in stream: # type: ignore
if response.prompt_feedback:
raise APIStatusError(
response.prompt_feedback.json(),
retryable=False,
request_id=request_id,
)
if (
not response.candidates
or not response.candidates[0].content
or not response.candidates[0].content.parts
):
raise APIStatusError(
"No candidates in the response",
retryable=True,
request_id=request_id,
)
if len(response.candidates) > 1:
logger.warning(
"gemini llm: there are multiple candidates in the response, returning response from the first one."
)
for index, part in enumerate(response.candidates[0].content.parts):
chat_chunk = self._parse_part(request_id, index, part)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)
if response.usage_metadata is not None:
usage = response.usage_metadata
self._event_ch.send_nowait(
llm.ChatChunk(
request_id=request_id,
usage=llm.CompletionUsage(
completion_tokens=usage.candidates_token_count or 0,
prompt_tokens=usage.prompt_token_count or 0,
total_tokens=usage.total_token_count or 0,
),
)
)
except ClientError as e:
raise APIStatusError(
"gemini llm: client error",
status_code=e.code,
body=e.message,
request_id=request_id,
retryable=False if e.code != 429 else True,
) from e
except ServerError as e:
raise APIStatusError(
"gemini llm: server error",
status_code=e.code,
body=e.message,
request_id=request_id,
retryable=retryable,
) from e
except APIError as e:
raise APIStatusError(
"gemini llm: api error",
status_code=e.code,
body=e.message,
request_id=request_id,
retryable=retryable,
) from e
except Exception as e:
raise APIConnectionError(
"gemini llm: error generating content",
retryable=retryable,
) from e
def _parse_part(
self, id: str, index: int, part: types.Part
) -> llm.ChatChunk | None:
if part.function_call:
return self._try_build_function(id, index, part)
return llm.ChatChunk(
request_id=id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(content=part.text, role="assistant"),
index=index,
)
],
)
def _try_build_function(
self, id: str, index: int, part: types.Part
) -> llm.ChatChunk | None:
if part.function_call is None:
logger.warning("gemini llm: no function call in the response")
return None
if part.function_call.name is None:
logger.warning("gemini llm: no function name in the response")
return None
if part.function_call.id is None:
part.function_call.id = utils.shortuuid()
if self._fnc_ctx is None:
logger.warning(
"google stream tried to run function without function context"
)
return None
fnc_info = _create_ai_function_info(
self._fnc_ctx,
part.function_call.id,
part.function_call.name,
json.dumps(part.function_call.args),
)
self._function_calls_info.append(fnc_info)
return llm.ChatChunk(
request_id=id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
role="assistant",
tool_calls=[fnc_info],
content=part.text,
),
index=index,
)
],
)
import logging
logger = logging.getLogger("livekit.plugins.google")
from typing import Literal
# Speech to Text v2
SpeechModels = Literal[
"long",
"short",
"telephony",
"medical_dictation",
"medical_conversation",
"chirp",
"chirp_2",
"latest_long",
"latest_short",
]
SpeechLanguages = Literal[
"en-US",
"ja-JP",
"en-IN",
"en-GB",
"hi-IN",
"af-ZA",
"sq-AL",
"am-ET",
"ar-EG",
"hy-AM",
"ast-ES",
"az-AZ",
"eu-ES",
"be-BY",
"bs-BA",
"bg-BG",
"my-MM",
"ca-ES",
"ceb-PH",
"ckb-IQ",
"zh-Hans-CN",
"yue-Hant-HK",
"zh-TW",
"hr-HR",
"cs-CZ",
"da-DK",
"nl-NL",
"en-AU",
"et-EE",
"fil-PH",
"fi-FI",
"fr-CA",
"fr-FR",
"gl-ES",
"ka-GE",
"de-DE",
"el-GR",
"gu-IN",
"ha-NG",
"iw-IL",
"hi-IN",
"hu-HU",
"is-IS",
"id-ID",
"it-IT",
"ja-JP",
"jv-ID",
"kea-CV",
"kam-KE",
"kn-IN",
"kk-KZ",
"km-KH",
"ko-KR",
"ky-KG",
"lo-LA",
"lv-LV",
"ln-CD",
"lt-LT",
"luo-KE",
"lb-LU",
"mk-MK",
"no-NO",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"ru-RU",
"es-CO",
"es-MX",
"es-US",
"th-TH",
"tr-TR",
"uk-UA",
"vi-VN",
"da-DK",
]
Gender = Literal["male", "female", "neutral"]
ChatModels = Literal[
"gemini-2.0-flash-001",
"gemini-2.0-flash-lite-preview-02-05",
"gemini-2.0-pro-exp-02-05",
"gemini-1.5-pro",
]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import dataclasses
import time
import weakref
from dataclasses import dataclass
from typing import Callable, List, Union
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
stt,
utils,
)
from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
from google.auth import default as gauth_default
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.speech_v2 import SpeechAsyncClient
from google.cloud.speech_v2.types import cloud_speech
from .log import logger
from .models import SpeechLanguages, SpeechModels
LgType = Union[SpeechLanguages, str]
LanguageCode = Union[LgType, List[LgType]]
# Google STT sessions time out after 5 minutes, so we attempt to restart the session
# before that timeout is reached
_max_session_duration = 240
# Google is very sensitive to background noise, so we'll ignore results with low confidence
_min_confidence = 0.65
# This class is only used internally to encapsulate the options
@dataclass
class STTOptions:
languages: List[LgType]
detect_language: bool
interim_results: bool
punctuate: bool
spoken_punctuation: bool
model: SpeechModels | str
sample_rate: int
keywords: List[tuple[str, float]] | None
def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
if self.keywords:
return cloud_speech.SpeechAdaptation(
phrase_sets=[
cloud_speech.SpeechAdaptation.AdaptationPhraseSet(
inline_phrase_set=cloud_speech.PhraseSet(
phrases=[
cloud_speech.PhraseSet.Phrase(
value=keyword, boost=boost
)
for keyword, boost in self.keywords
]
)
)
]
)
return None
class STT(stt.STT):
def __init__(
self,
*,
languages: LanguageCode = "en-US", # Google STT can accept multiple languages
detect_language: bool = True,
interim_results: bool = True,
punctuate: bool = True,
spoken_punctuation: bool = False,
model: SpeechModels | str = "latest_long",
location: str = "global",
sample_rate: int = 16000,
credentials_info: dict | None = None,
credentials_file: str | None = None,
keywords: List[tuple[str, float]] | None = None,
):
"""
Create a new instance of Google STT.
Credentials must be provided, either by using the ``credentials_info`` dict, or reading
from the file specified in ``credentials_file`` or via Application Default Credentials as
described in https://cloud.google.com/docs/authentication/application-default-credentials
Args:
languages(LanguageCode): list of language codes to recognize (default: "en-US")
detect_language(bool): whether to detect the language of the audio (default: True)
interim_results(bool): whether to return interim results (default: True)
punctuate(bool): whether to punctuate the audio (default: True)
spoken_punctuation(bool): whether to use spoken punctuation (default: False)
model(SpeechModels): the model to use for recognition (default: "latest_long")
location(str): the location to use for recognition (default: "global")
sample_rate(int): the sample rate of the audio (default: 16000)
credentials_info(dict): the credentials info to use for recognition (default: None)
credentials_file(str): the credentials file to use for recognition (default: None)
keywords(List[tuple[str, float]]): list of (keyword, boost) pairs used for speech adaptation (default: None)
"""
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
)
self._location = location
self._credentials_info = credentials_info
self._credentials_file = credentials_file
if credentials_file is None and credentials_info is None:
try:
gauth_default()
except DefaultCredentialsError:
raise ValueError(
"Application default credentials must be available "
"when using Google STT without explicitly passing "
"credentials through credentials_info or credentials_file."
)
if isinstance(languages, str):
languages = [languages]
self._config = STTOptions(
languages=languages,
detect_language=detect_language,
interim_results=interim_results,
punctuate=punctuate,
spoken_punctuation=spoken_punctuation,
model=model,
sample_rate=sample_rate,
keywords=keywords,
)
self._streams = weakref.WeakSet[SpeechStream]()
self._pool = utils.ConnectionPool[SpeechAsyncClient](
max_session_duration=_max_session_duration,
connect_cb=self._create_client,
)
async def _create_client(self) -> SpeechAsyncClient:
# Add support for passing a specific location that matches recognizer
# see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
client_options = None
client: SpeechAsyncClient | None = None
if self._location != "global":
client_options = ClientOptions(
api_endpoint=f"{self._location}-speech.googleapis.com"
)
if self._credentials_info:
client = SpeechAsyncClient.from_service_account_info(
self._credentials_info,
client_options=client_options,
)
elif self._credentials_file:
client = SpeechAsyncClient.from_service_account_file(
self._credentials_file,
client_options=client_options,
)
else:
client = SpeechAsyncClient(
client_options=client_options,
)
assert client is not None
return client
def _get_recognizer(self, client: SpeechAsyncClient) -> str:
# TODO(theomonnom): should we use recognizers?
# recognizers may improve latency https://cloud.google.com/speech-to-text/v2/docs/recognizers#understand_recognizers
# TODO(theomonnom): find a better way to access the project_id
try:
project_id = client.transport._credentials.project_id # type: ignore
except AttributeError:
from google.auth import default as ga_default
_, project_id = ga_default()
return f"projects/{project_id}/locations/{self._location}/recognizers/_"
def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
config = dataclasses.replace(self._config)
if language:
config.languages = [language]
if not isinstance(config.languages, list):
config.languages = [config.languages]
elif not config.detect_language:
if len(config.languages) > 1:
logger.warning(
"multiple languages provided, but language detection is disabled"
)
config.languages = [config.languages[0]]
return config
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: SpeechLanguages | str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
config = self._sanitize_options(language=language)
frame = rtc.combine_audio_frames(buffer)
config = cloud_speech.RecognitionConfig(
explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
sample_rate_hertz=frame.sample_rate,
audio_channel_count=frame.num_channels,
),
adaptation=config.build_adaptation(),
features=cloud_speech.RecognitionFeatures(
enable_automatic_punctuation=config.punctuate,
enable_spoken_punctuation=config.spoken_punctuation,
enable_word_time_offsets=True,
),
model=config.model,
language_codes=config.languages,
)
try:
async with self._pool.connection() as client:
raw = await client.recognize(
cloud_speech.RecognizeRequest(
recognizer=self._get_recognizer(client),
config=config,
content=frame.data.tobytes(),
),
timeout=conn_options.timeout,
)
return _recognize_response_to_speech_event(raw)
except DeadlineExceeded:
raise APITimeoutError()
except GoogleAPICallError as e:
raise APIStatusError(
e.message,
status_code=e.code or -1,
)
except Exception as e:
raise APIConnectionError() from e
def stream(
self,
*,
language: SpeechLanguages | str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = self._sanitize_options(language=language)
stream = SpeechStream(
stt=self,
pool=self._pool,
recognizer_cb=self._get_recognizer,
config=config,
conn_options=conn_options,
)
self._streams.add(stream)
return stream
def update_options(
self,
*,
languages: LanguageCode | None = None,
detect_language: bool | None = None,
interim_results: bool | None = None,
punctuate: bool | None = None,
spoken_punctuation: bool | None = None,
model: SpeechModels | None = None,
location: str | None = None,
keywords: List[tuple[str, float]] | None = None,
):
if languages is not None:
if isinstance(languages, str):
languages = [languages]
self._config.languages = languages
if detect_language is not None:
self._config.detect_language = detect_language
if interim_results is not None:
self._config.interim_results = interim_results
if punctuate is not None:
self._config.punctuate = punctuate
if spoken_punctuation is not None:
self._config.spoken_punctuation = spoken_punctuation
if model is not None:
self._config.model = model
if location is not None:
self._location = location
# if location is changed, fetch a new client and recognizer as per the new location
self._pool.invalidate()
if keywords is not None:
self._config.keywords = keywords
for stream in self._streams:
stream.update_options(
languages=languages,
detect_language=detect_language,
interim_results=interim_results,
punctuate=punctuate,
spoken_punctuation=spoken_punctuation,
model=model,
keywords=keywords,
)
async def aclose(self) -> None:
await self._pool.aclose()
await super().aclose()
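A minimal usage sketch for the STT class above, assuming Application Default Credentials are available (the import path and values shown are illustrative):

```py
# Minimal sketch: create the recognizer and open a streaming session.
from livekit.plugins.google import STT

google_stt = STT(
    languages="en-US",
    model="latest_long",
    keywords=[("LiveKit", 15.0)],  # (phrase, boost) pairs for speech adaptation
)

stream = google_stt.stream()
# push rtc.AudioFrame objects into `stream` and consume stt.SpeechEvent results
# (INTERIM_TRANSCRIPT / FINAL_TRANSCRIPT) from it inside an async task.
```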
class SpeechStream(stt.SpeechStream):
def __init__(
self,
*,
stt: STT,
conn_options: APIConnectOptions,
pool: utils.ConnectionPool[SpeechAsyncClient],
recognizer_cb: Callable[[SpeechAsyncClient], str],
config: STTOptions,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
)
self._pool = pool
self._recognizer_cb = recognizer_cb
self._config = config
self._reconnect_event = asyncio.Event()
self._session_connected_at: float = 0
def update_options(
self,
*,
languages: LanguageCode | None = None,
detect_language: bool | None = None,
interim_results: bool | None = None,
punctuate: bool | None = None,
spoken_punctuation: bool | None = None,
model: SpeechModels | None = None,
keywords: List[tuple[str, float]] | None = None,
):
if languages is not None:
if isinstance(languages, str):
languages = [languages]
self._config.languages = languages
if detect_language is not None:
self._config.detect_language = detect_language
if interim_results is not None:
self._config.interim_results = interim_results
if punctuate is not None:
self._config.punctuate = punctuate
if spoken_punctuation is not None:
self._config.spoken_punctuation = spoken_punctuation
if model is not None:
self._config.model = model
if keywords is not None:
self._config.keywords = keywords
self._reconnect_event.set()
async def _run(self) -> None:
# Google requires an async generator when calling streaming_recognize;
# this function converts the input queue into an async generator
async def input_generator(
client: SpeechAsyncClient, should_stop: asyncio.Event
):
try:
# first request should contain the config
yield cloud_speech.StreamingRecognizeRequest(
recognizer=self._recognizer_cb(client),
streaming_config=self._streaming_config,
)
async for frame in self._input_ch:
# when the stream is aborted due to reconnect, this input_generator
# needs to stop consuming frames
# when the generator stops, the previous gRPC stream will close
if should_stop.is_set():
return
if isinstance(frame, rtc.AudioFrame):
yield cloud_speech.StreamingRecognizeRequest(
audio=frame.data.tobytes()
)
except Exception:
logger.exception(
"an error occurred while streaming input to google STT"
)
async def process_stream(client: SpeechAsyncClient, stream):
has_started = False
async for resp in stream:
if (
resp.speech_event_type
== cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
):
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
)
has_started = True
if (
resp.speech_event_type
== cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
):
result = resp.results[0]
speech_data = _streaming_recognize_response_to_speech_data(resp)
if speech_data is None:
continue
if not result.is_final:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[speech_data],
)
)
else:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[speech_data],
)
)
if (
time.time() - self._session_connected_at
> _max_session_duration
):
logger.debug(
"Google STT maximum connection time reached. Reconnecting..."
)
self._pool.remove(client)
if has_started:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.END_OF_SPEECH
)
)
has_started = False
self._reconnect_event.set()
return
if (
resp.speech_event_type
== cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
):
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
)
has_started = False
while True:
try:
async with self._pool.connection() as client:
self._streaming_config = cloud_speech.StreamingRecognitionConfig(
config=cloud_speech.RecognitionConfig(
explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
sample_rate_hertz=self._config.sample_rate,
audio_channel_count=1,
),
adaptation=self._config.build_adaptation(),
language_codes=self._config.languages,
model=self._config.model,
features=cloud_speech.RecognitionFeatures(
enable_automatic_punctuation=self._config.punctuate,
enable_word_time_offsets=True,
),
),
streaming_features=cloud_speech.StreamingRecognitionFeatures(
interim_results=self._config.interim_results,
),
)
should_stop = asyncio.Event()
stream = await client.streaming_recognize(
requests=input_generator(client, should_stop),
)
self._session_connected_at = time.time()
process_stream_task = asyncio.create_task(
process_stream(client, stream)
)
wait_reconnect_task = asyncio.create_task(
self._reconnect_event.wait()
)
try:
done, _ = await asyncio.wait(
[process_stream_task, wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
)
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(
process_stream_task, wait_reconnect_task
)
should_stop.set()
except DeadlineExceeded:
raise APITimeoutError()
except GoogleAPICallError as e:
raise APIStatusError(
e.message,
status_code=e.code or -1,
)
except Exception as e:
raise APIConnectionError() from e
def _recognize_response_to_speech_event(
resp: cloud_speech.RecognizeResponse,
) -> stt.SpeechEvent:
text = ""
confidence = 0.0
for result in resp.results:
text += result.alternatives[0].transcript
confidence += result.alternatives[0].confidence
# not sure why start_offset and end_offset return a timedelta
start_offset = resp.results[0].alternatives[0].words[0].start_offset
end_offset = resp.results[-1].alternatives[0].words[-1].end_offset
confidence /= len(resp.results)
lg = resp.results[0].language_code
return stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
stt.SpeechData(
language=lg,
start_time=start_offset.total_seconds(), # type: ignore
end_time=end_offset.total_seconds(), # type: ignore
confidence=confidence,
text=text,
)
],
)
def _streaming_recognize_response_to_speech_data(
resp: cloud_speech.StreamingRecognizeResponse,
) -> stt.SpeechData | None:
text = ""
confidence = 0.0
for result in resp.results:
if len(result.alternatives) == 0:
continue
text += result.alternatives[0].transcript
confidence += result.alternatives[0].confidence
confidence /= len(resp.results)
lg = resp.results[0].language_code
if confidence < _min_confidence:
return None
if text == "":
return None
data = stt.SpeechData(
language=lg, start_time=0, end_time=0, confidence=confidence, text=text
)
return data
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
from google.cloud import texttospeech
from google.cloud.texttospeech_v1.types import SsmlVoiceGender, SynthesizeSpeechResponse
from .models import Gender, SpeechLanguages
@dataclass
class _TTSOptions:
voice: texttospeech.VoiceSelectionParams
audio_config: texttospeech.AudioConfig
class TTS(tts.TTS):
def __init__(
self,
*,
language: SpeechLanguages | str = "en-US",
gender: Gender | str = "neutral",
voice_name: str = "", # Not required
sample_rate: int = 24000,
pitch: int = 0,
effects_profile_id: str = "",
speaking_rate: float = 1.0,
location: str = "global",
credentials_info: dict | None = None,
credentials_file: str | None = None,
) -> None:
"""
Create a new instance of Google TTS.
Credentials must be provided, either by using the ``credentials_info`` dict, or reading
from the file specified in ``credentials_file`` or the ``GOOGLE_APPLICATION_CREDENTIALS``
environment variable.
Args:
language (SpeechLanguages | str, optional): Language code (e.g., "en-US"). Default is "en-US".
gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
voice_name (str, optional): Specific voice name. Default is an empty string.
sample_rate (int, optional): Audio sample rate in Hz. Default is 24000.
location (str, optional): Location for the TTS client. Default is "global".
pitch (int, optional): Speaking pitch, ranging from -20.0 to 20.0 semitones relative to the original pitch. Default is 0.
effects_profile_id (str): Optional identifier for selecting audio effects profiles to apply to the synthesized speech.
speaking_rate (float, optional): Speed of speech. Default is 1.0.
credentials_info (dict, optional): Dictionary containing Google Cloud credentials. Default is None.
credentials_file (str, optional): Path to the Google Cloud credentials JSON file. Default is None.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=sample_rate,
num_channels=1,
)
self._client: texttospeech.TextToSpeechAsyncClient | None = None
self._credentials_info = credentials_info
self._credentials_file = credentials_file
self._location = location
voice = texttospeech.VoiceSelectionParams(
name=voice_name,
language_code=language,
ssml_gender=_gender_from_str(gender),
)
self._opts = _TTSOptions(
voice=voice,
audio_config=texttospeech.AudioConfig(
audio_encoding=texttospeech.AudioEncoding.OGG_OPUS,
sample_rate_hertz=sample_rate,
pitch=pitch,
effects_profile_id=effects_profile_id,
speaking_rate=speaking_rate,
),
)
def update_options(
self,
*,
language: SpeechLanguages | str = "en-US",
gender: Gender | str = "neutral",
voice_name: str = "", # Not required
speaking_rate: float = 1.0,
) -> None:
"""
Update the TTS options.
Args:
language (SpeechLanguages | str, optional): Language code (e.g., "en-US"). Default is "en-US".
gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
voice_name (str, optional): Specific voice name. Default is an empty string.
speaking_rate (float, optional): Speed of speech. Default is 1.0.
"""
self._opts.voice = texttospeech.VoiceSelectionParams(
name=voice_name,
language_code=language,
ssml_gender=_gender_from_str(gender),
)
self._opts.audio_config.speaking_rate = speaking_rate
def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
api_endpoint = "texttospeech.googleapis.com"
if self._location != "global":
api_endpoint = f"{self._location}-texttospeech.googleapis.com"
if self._client is None:
if self._credentials_info:
self._client = (
texttospeech.TextToSpeechAsyncClient.from_service_account_info(
self._credentials_info,
client_options=ClientOptions(api_endpoint=api_endpoint),
)
)
elif self._credentials_file:
self._client = (
texttospeech.TextToSpeechAsyncClient.from_service_account_file(
self._credentials_file,
client_options=ClientOptions(api_endpoint=api_endpoint),
)
)
else:
self._client = texttospeech.TextToSpeechAsyncClient(
client_options=ClientOptions(api_endpoint=api_endpoint)
)
assert self._client is not None
return self._client
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
client=self._ensure_client(),
)
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
client: texttospeech.TextToSpeechAsyncClient,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts, self._client = opts, client
async def _run(self) -> None:
request_id = utils.shortuuid()
try:
response: SynthesizeSpeechResponse = await self._client.synthesize_speech(
input=texttospeech.SynthesisInput(text=self._input_text),
voice=self._opts.voice,
audio_config=self._opts.audio_config,
timeout=self._conn_options.timeout,
)
# Create AudioStreamDecoder for OGG format
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.audio_config.sample_rate_hertz,
num_channels=1,
)
try:
decoder.push(response.audio_content)
decoder.end_input()
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
finally:
await decoder.aclose()
except DeadlineExceeded:
raise APITimeoutError()
except GoogleAPICallError as e:
raise APIStatusError(
e.message,
status_code=e.code or -1,
request_id=None,
body=None,
)
except Exception as e:
raise APIConnectionError() from e
def _gender_from_str(gender: str) -> SsmlVoiceGender:
ssml_gender = SsmlVoiceGender.NEUTRAL
if gender == "male":
ssml_gender = SsmlVoiceGender.MALE
elif gender == "female":
ssml_gender = SsmlVoiceGender.FEMALE
return ssml_gender # type: ignore
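A minimal usage sketch for the TTS class above, assuming credentials are provided via `GOOGLE_APPLICATION_CREDENTIALS` (the import path and values shown are illustrative):

```py
# Minimal sketch: synthesize a single utterance.
from livekit.plugins.google import TTS

google_tts = TTS(language="en-US", gender="female", speaking_rate=1.1)

stream = google_tts.synthesize("Hello from LiveKit!")
# `stream` is the ChunkedStream defined above; iterate it in an async context
# to receive decoded audio frames from the OGG/Opus response.
```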
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.11.3"
{
"name": "livekit-plugins-google",
"private": true,
"version": "0.11.3"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "google", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-google",
version=about["__version__"],
description="Agent Framework plugin for services from Google Cloud",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"google-auth >= 2, < 3",
"google-cloud-speech >= 2, < 3",
"google-cloud-texttospeech >= 2, < 3",
"google-genai == 1.3.0",
"livekit-agents>=0.12.16,<1.0.0",
],
package_data={"livekit.plugins.google": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-minimal
## 0.1.2
### Patch Changes
- update to tts model and voices - [#1725](https://github.com/livekit/agents/pull/1725) ([@davidzhao](https://github.com/davidzhao))
## 0.1.1
### Patch Changes
- initial version - [#1689](https://github.com/livekit/agents/pull/1689) ([@davidzhao](https://github.com/davidzhao))
## 0.1.0
Initial version
# LiveKit Plugins Groq
Agent Framework plugin for services from Groq. Currently supports STT, LLM, and TTS.
## Installation
```bash
pip install livekit-plugins-groq
```

You’ll need a Groq Cloud account for credentials. The API key can be passed directly as an argument or set via the `GROQ_API_KEY` environment variable.
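As a quick usage sketch (not from the original README; it assumes `GROQ_API_KEY` is exported and mirrors the defaults in the plugin source below):

```python
from livekit.plugins import groq

# minimal construction; each class wraps Groq's OpenAI-compatible endpoint
llm = groq.LLM(model="llama-3.3-70b-versatile")
stt = groq.STT(model="whisper-large-v3-turbo", language="en")
tts = groq.TTS(model="playai-tts", voice="Arista-PlayAI")
```

These drop into a voice pipeline the same way as any other STT/LLM/TTS plugin.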
## livekit-plugins/livekit-plugins-groq/livekit/plugins/groq/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from livekit.agents import Plugin
from .log import logger
from .services import LLM, STT
from .tts import TTS
from .version import __version__
__all__ = ["TTS", "LLM", "STT", "__version__"]
class GroqPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(GroqPlugin())
import logging
logger = logging.getLogger("livekit.plugins.groq")
from typing import Literal
# listing production models from https://console.groq.com/docs/models
STTModels = Literal[
"whisper-large-v3",
"whisper-large-v3-turbo",
"distil-whisper-large-v3-en",
]
LLMModels = Literal[
"llama3-8b-8192",
"llama3-70b-8192",
"llama-guard-3-8b",
"llama-3.1-8b-instant",
"llama-3.3-70b-versatile",
]
TTSModels = Literal[
"playai-tts",
"playai-tts-arabic",
]
TTSVoices = Literal[
# english voices
"Arista-PlayAI",
"Atlas-PlayAI",
"Basil-PlayAI",
"Briggs-PlayAI",
"Calum-PlayAI",
"Celeste-PlayAI",
"Cheyenne-PlayAI",
"Chip-PlayAI",
"Cillian-PlayAI",
"Deedee-PlayAI",
"Fritz-PlayAI",
"Gail-PlayAI",
"Indigo-PlayAI",
"Mamaw-PlayAI",
"Mason-PlayAI",
"Mikail-PlayAI",
"Mitch-PlayAI",
"Quinn-PlayAI",
"Thunder-PlayAI",
# arabic voices
"Nasser-PlayAI",
"Khalid-PlayAI",
"Amira-PlayAI",
"Ahmad-PlayAI",
]
import os
from typing import Literal, Union
import openai
from livekit.agents.llm import ToolChoice
from livekit.plugins.openai import LLM as OpenAILLM
from livekit.plugins.openai import STT as OpenAISTT
from .models import LLMModels, STTModels
class LLM(OpenAILLM):
def __init__(
self,
*,
model: str | LLMModels = "llama-3.3-70b-versatile",
api_key: str | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
max_tokens: int | None = None,
base_url: str | None = "https://api.groq.com/openai/v1",
client: openai.AsyncClient | None = None,
):
"""
Create a new instance of Groq LLM.
``api_key`` must be set to your Groq API key, either using the argument or by setting
the ``GROQ_API_KEY`` environmental variable.
"""
super().__init__(
model=model,
api_key=_get_api_key(api_key),
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
max_tokens=max_tokens,
)
class STT(OpenAISTT):
def __init__(
self,
*,
model: STTModels | str = "whisper-large-v3-turbo",
api_key: str | None = None,
base_url: str | None = "https://api.groq.com/openai/v1",
client: openai.AsyncClient | None = None,
language: str = "en",
prompt: str | None = None,
detect_language: bool = False,
):
"""
Create a new instance of Groq STT.
``api_key`` must be set to your Groq API key, either using the argument or by setting
the ``GROQ_API_KEY`` environmental variable.
"""
super().__init__(
model=model,
api_key=_get_api_key(api_key),
base_url=base_url,
client=client,
language=language,
detect_language=detect_language,
prompt=prompt,
use_realtime=False,
)
def _get_api_key(key: str | None) -> str:
key = key or os.environ.get("GROQ_API_KEY")
if not key:
raise ValueError(
"GROQ_API_KEY is required, either as argument or set GROQ_API_KEY environmental variable"
)
return key
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
from dataclasses import dataclass
from typing import Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from .log import logger
from .models import TTSModels, TTSVoices
DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
SAMPLE_RATE = 48000
NUM_CHANNELS = 1
@dataclass
class _TTSOptions:
model: TTSModels | str
voice: TTSVoices | str
api_key: str
base_url: str
class TTS(tts.TTS):
def __init__(
self,
*,
base_url: str | None = None,
model: TTSModels | str = "playai-tts",
voice: TTSVoices | str = "Arista-PlayAI",
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
"""
Create a new instance of Groq TTS.
if `api_key` is not provided, it will be read from the ``GROQ_API_KEY``
environmental variable.
Args:
model (TTSModels | str, optional): Model to use. Default is "playai-tts".
voice (TTSVoices | str, optional): Voice to use. Default is "Arista-PlayAI".
api_key (str | None, optional): API key to use. Default is None.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=SAMPLE_RATE,
num_channels=1,
)
self._session = http_session
if not base_url:
base_url = DEFAULT_BASE_URL
if not api_key:
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
raise ValueError("GROQ_API_KEY is not set")
self._opts = _TTSOptions(
model=model,
voice=voice,
api_key=api_key,
base_url=base_url,
)
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def update_options(
self,
*,
model: TTSModels | None = None,
voice: TTSVoices | None = None,
) -> None:
"""
Update the TTS options.
Args:
model (TTSModels | str, optional): Model to use. Default is None.
voice (TTSVoices | str, optional): Voice to use. Default is None.
"""
if model:
self._opts.model = model
if voice:
self._opts.voice = voice
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
segment_id: str | None = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
segment_id=segment_id,
)
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
input_text: str,
conn_options: Optional[APIConnectOptions] = None,
opts: _TTSOptions,
session: aiohttp.ClientSession,
segment_id: str | None = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
self._session = session
self._segment_id = segment_id
async def _run(self) -> None:
request_id = utils.shortuuid()
headers = {
"Authorization": f"Bearer {self._opts.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self._opts.model,
"voice": self._opts.voice,
"input": self._input_text,
"response_format": "wav",
}
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=SAMPLE_RATE,
num_channels=NUM_CHANNELS,
)
decode_task: Optional[asyncio.Task] = None
api_url = f"{self._opts.base_url}/audio/speech"
try:
async with self._session.post(
api_url, headers=headers, json=payload
) as response:
if not response.content_type.startswith("audio"):
content = await response.text()
logger.error("Groq returned non-audio data: %s", content)
return
async def _decode_loop():
try:
async for bytes_data, _ in response.content.iter_chunks():
decoder.push(bytes_data)
finally:
decoder.end_input()
decode_task = asyncio.create_task(_decode_loop())
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=self._segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
if decode_task:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.2"
{
"name": "livekit-plugins-groq",
"private": true,
"version": "0.1.2"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "groq", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-groq",
version=about["__version__"],
description="Groq inference plugin for LiveKit",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["groq", "llm", "stt", "tts", "webrtc", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
"livekit-plugins-openai>=0.12.0,<1.0.0",
],
package_data={"livekit.plugins.groq": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-llama-index
## 0.2.3
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.2
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.2.1
### Patch Changes
- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19))
- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0
### Minor Changes
- prepare for release - [#1007](https://github.com/livekit/agents/pull/1007) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- Publish llama-index plugin - [#924](https://github.com/livekit/agents/pull/924) ([@davidzhao](https://github.com/davidzhao))
## 0.1.1
### Patch Changes
- livekit-plugins-llama-index - [#696](https://github.com/livekit/agents/pull/696) ([@theomonnom](https://github.com/theomonnom))
# LiveKit Plugins Llama Index
Agent Framework plugin for using Llama Index. Currently supports [Query Engine](https://docs.llamaindex.ai/en/stable/module_guides/deploying/query_engine/) and [Chat Engine](https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/).
## Install
```bash
pip install livekit-plugins-llama-index
```

Query Engine is primarily used for RAG. See the example voice agent.

Chat Engine can be used as an LLM within the framework:

```python
# load the existing index
storage_context = StorageContext.from_defaults(persist_dir=<mydir>)
index = load_index_from_storage(storage_context)

async def entrypoint(ctx: JobContext):
    ...
    chat_engine = index.as_chat_engine(chat_mode=ChatMode.CONTEXT)
    assistant = VoicePipelineAgent(
        vad=silero.VAD.load(),
        stt=deepgram.STT(),
        llm=llama_index.LLM(chat_engine=chat_engine),
        tts=openai.TTS(),
        chat_ctx=initial_ctx,
    )
```

full example here
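For context (not part of the original README), the persisted index that `load_index_from_storage` expects can be built once with standard `llama_index` core APIs; `"data"` and `"storage"` below are placeholder paths:

```python
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

# one-time build: read local documents, index them, and persist to disk
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
index.storage_context.persist(persist_dir="storage")
```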
## livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from livekit.agents import Plugin
from .llm import LLM, LLMStream
from .log import logger
from .version import __version__
__all__ = ["LLM", "LLMStream"]
class LlamaIndexPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(LlamaIndexPlugin())
from __future__ import annotations
from typing import Literal, Union
from livekit.agents import (
APIConnectionError,
llm,
)
from livekit.agents.llm import ToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from llama_index.core.chat_engine.types import (
BaseChatEngine,
StreamingAgentChatResponse,
)
from llama_index.core.llms import ChatMessage, MessageRole
from .log import logger
class LLM(llm.LLM):
def __init__(
self,
*,
chat_engine: BaseChatEngine,
) -> None:
super().__init__()
self._chat_engine = chat_engine
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if fnc_ctx is not None:
logger.warning("fnc_ctx is currently not supported with llama_index.LLM")
return LLMStream(
self,
chat_engine=self._chat_engine,
chat_ctx=chat_ctx,
conn_options=conn_options,
)
class LLMStream(llm.LLMStream):
def __init__(
self,
llm: LLM,
*,
chat_engine: BaseChatEngine,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions,
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=None, conn_options=conn_options
)
self._chat_engine = chat_engine
self._stream: StreamingAgentChatResponse | None = None
async def _run(self) -> None:
chat_ctx = self._chat_ctx.copy()
user_msg = chat_ctx.messages.pop()
if user_msg.role != "user":
raise ValueError(
"The last message in the chat context must be from the user"
)
assert isinstance(user_msg.content, str), (
"user message content must be a string"
)
try:
if not self._stream:
self._stream = await self._chat_engine.astream_chat(
user_msg.content,
chat_history=_to_llama_chat_messages(self._chat_ctx),
)
async for delta in self._stream.async_response_gen():
self._event_ch.send_nowait(
llm.ChatChunk(
request_id="",
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
role="assistant",
content=delta,
)
)
],
)
)
except Exception as e:
raise APIConnectionError() from e
def _to_llama_chat_messages(chat_ctx: llm.ChatContext) -> list[ChatMessage]:
return [
ChatMessage(content=msg.content, role=_to_llama_message_role(msg.role))
for msg in chat_ctx.messages
]
def _to_llama_message_role(chat_role: llm.ChatRole) -> MessageRole:
if chat_role == "system":
return MessageRole.SYSTEM
elif chat_role == "user":
return MessageRole.USER
elif chat_role == "assistant":
return MessageRole.ASSISTANT
elif chat_role == "tool":
return MessageRole.TOOL
import logging
logger = logging.getLogger("livekit.plugins.llama_index")
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.3"
{
"name": "livekit-plugins-llama-index",
"private": true,
"version": "0.2.3"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "llama_index", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-llama-index",
version=about["__version__"],
description="Llama Index plugin for LiveKit Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0"],
package_data={"livekit.plugins.llama_index": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-minimal
## 0.2.2
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.1
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.1.1-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins Minimal
This is a minimal example of a LiveKit plugin for Agents.
### Developer note
When copying this directory over to create a new `livekit-plugins` package, make sure it's nested within the `livekit-plugins` folder and that the `"name"` field in `package.json` follows the proper naming convention for CI:
```json
{
"name": "livekit-plugins-<name>",
"private": true
}
```
## livekit-plugins/livekit-plugins-minimal/livekit/plugins/minimal/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from livekit.agents import Plugin
from .log import logger
from .version import __version__
class MinimalPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(MinimalPlugin())
import logging
logger = logging.getLogger("livekit.plugins.minimal")
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.2"
{
"name": "livekit-plugins-minimal",
"private": true,
"version": "0.2.2"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "minimal", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-minimal",
version=about["__version__"],
description="Minimal plugin template for LiveKit Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0"],
package_data={"livekit.plugins.minimal": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-neuphonic
## 0.1.1
### Patch Changes
- Add string type support to model parameter - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
- rename NEUPHONIC_API_TOKEN to NEUPHONIC_API_KEY - [#1642](https://github.com/livekit/agents/pull/1642) ([@davidzhao](https://github.com/davidzhao))
# LiveKit Plugins Neuphonic
Agent Framework plugin for voice synthesis with the [Neuphonic](https://neuphonic.com) API.
## Installation
```bash
pip install livekit-plugins-neuphonic
```

You’ll need an API key from Neuphonic. It can be passed as an argument or set as the `NEUPHONIC_API_KEY` environment variable.
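A brief sketch (not part of the original README; the parameters mirror the `TTS` constructor shown below, and `NEUPHONIC_API_KEY` is assumed to be set):

```python
from livekit.plugins import neuphonic

# minimal construction; unspecified options fall back to the constructor defaults
tts = neuphonic.TTS(
    lang_code="en",
    sample_rate=22050,
)
```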
## livekit-plugins/livekit-plugins-neuphonic/livekit/plugins/neuphonic/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tts import TTS, ChunkedStream
from .version import __version__
__all__ = ["TTS", "ChunkedStream", "__version__"]
from livekit.agents import Plugin
from .log import logger
class NeuphonicPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(NeuphonicPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.neuphonic")
from typing import Literal
TTSEncodings = Literal[
"pcm_linear",
"pcm_mulaw",
]
TTSModels = Literal["neu-fast", "neu-hq"]
TTSLangCodes = Literal["en", "nl", "es", "de", "hi", "en-hi", "ar"]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import base64
import json
import os
import weakref
from dataclasses import dataclass
from typing import Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from .log import logger
from .models import TTSEncodings, TTSLangCodes, TTSModels
API_BASE_URL = "api.neuphonic.com"
AUTHORIZATION_HEADER = "X-API-KEY"
NUM_CHANNELS = 1
@dataclass
class _TTSOptions:
base_url: str
api_key: str
model: TTSModels | str
lang_code: TTSLangCodes | str
encoding: TTSEncodings | str
sampling_rate: int
speed: float
voice_id: str | None = None
@property
def model_params(self) -> dict:
"""Returns a dict of all model parameters and their values."""
params = [
"voice_id",
"model",
"lang_code",
"encoding",
"sampling_rate",
"speed",
]
values = {}
for param in params:
if hasattr(self, param) and getattr(self, param) is not None:
values[param] = getattr(self, param)
return values
def get_query_param_string(self):
"""Forms the query parameter string from all model parameters."""
queries = []
for key, value in self.model_params.items():
queries.append(f"{key}={value}")
return "?" + "&".join(queries)
def _parse_sse_message(message: str) -> Optional[dict]:
"""
Parse each response from the SSE endpoint.
The message will either be a string reading:
- `event: error`
- `event: message`
- `data: { "status_code": 200, "data": {"audio": ... } }`
"""
message = message.strip()
if not message or "data" not in message:
return None
_, value = message.split(": ", 1)
message = json.loads(value)
if message.get("errors") is not None:
raise Exception(
f"Status {message.status_code} error received: {message.errors}."
)
return message
class TTS(tts.TTS):
def __init__(
self,
*,
model: TTSModels | str = "neu_hq",
voice_id: str | None = None,
lang_code: TTSLangCodes | str = "en",
encoding: TTSEncodings | str = "pcm_linear",
speed: float = 1.0,
sample_rate: int = 22050,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
base_url: str = API_BASE_URL,
) -> None:
"""
Create a new instance of the Neuphonic TTS.
See https://docs.neuphonic.com for more documentation on all of these options, or go to https://app.neuphonic.com/ to test out different options.
Args:
model (TTSModels | str, optional): The Neuphonic model to use. Defaults to "neu_hq".
voice_id (str, optional): The voice ID for the desired voice. Defaults to None.
lang_code (TTSLangCodes | str, optional): The language code for synthesis. Defaults to "en".
encoding (TTSEncodings | str, optional): The audio encoding format. Defaults to "pcm_linear".
speed (float, optional): The audio playback speed. Defaults to 1.0.
sample_rate (int, optional): The audio sample rate in Hz. Defaults to 22050.
api_key (str | None, optional): The Neuphonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
base_url (str, optional): The base URL for the Neuphonic API. Defaults to "api.neuphonic.com".
"""
super().__init__(
capabilities=tts.TTSCapabilities(streaming=True),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
api_key = api_key or os.environ.get("NEUPHONIC_API_KEY")
if not api_key:
raise ValueError(
"NEUPHONIC_API_KEY must be set using the argument or by setting the NEUPHONIC_API_KEY environment variable."
)
self._opts = _TTSOptions(
model=model,
voice_id=voice_id,
lang_code=lang_code,
encoding=encoding,
speed=speed,
sampling_rate=sample_rate,
api_key=api_key,
base_url=base_url,
)
self._session = http_session
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
max_session_duration=90,
mark_refreshed_on_get=True,
)
self._streams = weakref.WeakSet[SynthesizeStream]()
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
session = self._ensure_session()
url = f"wss://{self._opts.base_url}/speak/{self._opts.lang_code}{self._opts.get_query_param_string()}"
return await asyncio.wait_for(
session.ws_connect(url, headers={AUTHORIZATION_HEADER: self._opts.api_key}),
self._conn_options.timeout,
)
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
await ws.close()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def prewarm(self) -> None:
self._pool.prewarm()
def update_options(
self,
*,
model: TTSModels | str | None = None,
voice_id: str | None = None,
lang_code: TTSLangCodes | str | None = None,
encoding: TTSEncodings | str | None = None,
speed: float | None = None,
sample_rate: int | None = None,
) -> None:
"""
Update the Text-to-Speech (TTS) configuration options.
This method allows updating the TTS settings, including model type, voice_id, lang_code,
encoding, speed and sample_rate. If any parameter is not provided, the existing value will be
retained.
Args:
model (TTSModels | str, optional): The Neuphonic model to use.
voice_id (str, optional): The voice ID for the desired voice.
lang_code (TTSLangCodes | str, optional): The language code for synthesis.
encoding (TTSEncodings | str, optional): The audio encoding format.
speed (float, optional): The audio playback speed.
sample_rate (int, optional): The audio sample rate in Hz.
"""
self._opts.model = model or self._opts.model
self._opts.voice_id = voice_id or self._opts.voice_id
self._opts.lang_code = lang_code or self._opts.lang_code
self._opts.encoding = encoding or self._opts.encoding
self._opts.speed = speed or self._opts.speed
self._opts.sampling_rate = sample_rate or self._opts.sampling_rate
self._pool.invalidate()
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> ChunkedStream:
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> SynthesizeStream:
stream = SynthesizeStream(
tts=self,
pool=self._pool,
opts=self._opts,
)
self._streams.add(stream)
return stream
async def aclose(self) -> None:
for stream in list(self._streams):
await stream.aclose()
self._streams.clear()
await self._pool.aclose()
await super().aclose()
class ChunkedStream(tts.ChunkedStream):
"""Synthesize chunked text using the SSE endpoint"""
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
session: aiohttp.ClientSession,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts, self._session = opts, session
async def _run(self) -> None:
request_id = utils.shortuuid()
bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sampling_rate, num_channels=NUM_CHANNELS
)
json_data = {
"text": self._input_text,
**self._opts.model_params,
}
headers = {
AUTHORIZATION_HEADER: self._opts.api_key,
}
try:
async with self._session.post(
f"https://{self._opts.base_url}/sse/speak/{self._opts.lang_code}",
headers=headers,
json=json_data,
timeout=aiohttp.ClientTimeout(
total=30,
sock_connect=self._conn_options.timeout,
),
read_bufsize=10
* 1024
* 1024, # large read_bufsize to avoid `ValueError: Chunk too big`
) as response:
response.raise_for_status()
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for line in response.content:
message = line.decode("utf-8").strip()
if message:
parsed_message = _parse_sse_message(message)
if (
parsed_message is not None
and parsed_message.get("data", {}).get("audio") is not None
):
audio_bytes = base64.b64decode(
parsed_message["data"]["audio"]
)
for frame in bstream.write(audio_bytes):
emitter.push(frame)
for frame in bstream.flush():
emitter.push(frame)
emitter.flush()
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=None,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
class SynthesizeStream(tts.SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
opts: _TTSOptions,
pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
):
super().__init__(tts=tts)
self._opts, self._pool = opts, pool
async def _run(self) -> None:
request_id = utils.shortuuid()
async def _send_task(ws: aiohttp.ClientWebSocketResponse):
"""Stream text to the websocket."""
async for data in self._input_ch:
self._mark_started()
if isinstance(data, self._FlushSentinel):
await ws.send_str(json.dumps({"text": "<STOP>"}))
continue
await ws.send_str(json.dumps({"text": data}))
async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sampling_rate,
num_channels=NUM_CHANNELS,
)
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise APIStatusError(
"Neuphonic connection closed unexpectedly",
request_id=request_id,
)
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("Unexpected Neuphonic message type %s", msg.type)
continue
data = json.loads(msg.data)
if data.get("data"):
b64data = base64.b64decode(data["data"]["audio"])
for frame in audio_bstream.write(b64data):
emitter.push(frame)
if data["data"].get(
"stop"
): # A bool flag, is True when audio reaches "<STOP>"
for frame in audio_bstream.flush():
emitter.push(frame)
emitter.flush()
break # we are not going to receive any more audio
else:
logger.error("Unexpected Neuphonic message %s", data)
async with self._pool.connection() as ws:
tasks = [
asyncio.create_task(_send_task(ws)),
asyncio.create_task(_recv_task(ws)),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.1"
{
"name": "livekit-plugins-neuphonic",
"private": true,
"version": "0.1.1"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "neuphonic", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-neuphonic",
version=about["__version__"],
description="LiveKit Agents Plugin for Neuphonic",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0"],
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-nltk
## 0.7.4
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.7.3
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.7.2
### Patch Changes
- fix another semver break - [#659](https://github.com/livekit/agents/pull/659) ([@theomonnom](https://github.com/theomonnom))
## 0.7.1
### Patch Changes
- Revert "nltk: fix broken punkt download" - [#630](https://github.com/livekit/agents/pull/630) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins NLTK
Agent Framework plugin for [NLTK](https://www.nltk.org/)-based text processing. Currently featuring a `SentenceTokenizer`.
## Installation
```bash
pip install livekit-plugins-nltk
```
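A short sketch of the tokenizer (not in the original README; it assumes the NLTK `punkt_tab` data has already been downloaded, which the plugin otherwise handles in `download_files()`):

```python
from livekit.plugins import nltk

# merge short fragments until each chunk reaches min_sentence_len characters
tokenizer = nltk.SentenceTokenizer(min_sentence_len=20)
print(tokenizer.tokenize("Hi there. This is a short demo. It has three sentences."))
```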
## livekit-plugins/livekit-plugins-nltk/livekit/plugins/nltk/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sentence_tokenizer import SentenceTokenizer
from .version import __version__
__all__ = ["SentenceTokenizer", "__version__"]
from livekit.agents import Plugin
import nltk # type: ignore
from .log import logger
class NltkPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
def download_files(self):
try:
_ = nltk.data.find("tokenizers/punkt_tab")
except LookupError:
nltk.download("punkt_tab")
Plugin.register_plugin(NltkPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.nltk")
from __future__ import annotations
import dataclasses
import functools
from dataclasses import dataclass
from livekit import agents
import nltk # type: ignore
# nltk is using the punkt tokenizer
# https://www.nltk.org/_modules/nltk/tokenize/punkt.html
# this code is using a whitespace to concatenate small sentences together
# (languages such as Chinese and Japanese are not yet supported)
@dataclass
class _TokenizerOptions:
language: str
min_sentence_len: int
stream_context_len: int
class SentenceTokenizer(agents.tokenize.SentenceTokenizer):
def __init__(
self,
*,
language: str = "english",
min_sentence_len: int = 20,
stream_context_len: int = 10,
) -> None:
super().__init__()
self._config = _TokenizerOptions(
language=language,
min_sentence_len=min_sentence_len,
stream_context_len=stream_context_len,
)
def _sanitize_options(self, language: str | None = None) -> _TokenizerOptions:
config = dataclasses.replace(self._config)
if language:
config.language = language
return config
def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
config = self._sanitize_options(language=language)
sentences = nltk.tokenize.sent_tokenize(text, config.language)
new_sentences = []
buff = ""
for sentence in sentences:
buff += sentence + " "
if len(buff) - 1 >= config.min_sentence_len:
new_sentences.append(buff.rstrip())
buff = ""
if buff:
new_sentences.append(buff.rstrip())
return new_sentences
def stream(self, *, language: str | None = None) -> agents.tokenize.SentenceStream:
config = self._sanitize_options(language=language)
return agents.tokenize.BufferedSentenceStream(
tokenizer=functools.partial(
nltk.tokenize.sent_tokenize, language=config.language
),
min_token_len=self._config.min_sentence_len,
min_ctx_len=self._config.stream_context_len,
)
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.7.4"
{
"name": "livekit-plugins-nltk",
"private": true,
"version": "0.7.4"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "nltk", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-nltk",
version=about["__version__"],
description="Agent Framework plugin for NLTK-based text processing.",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "nltk >= 3.9.1, < 4"],
package_data={"livekit.plugins.nltk": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-openai
## 0.12.3
### Patch Changes
- openai: default to use_realtime=False - [#1783](https://github.com/livekit/agents/pull/1783) ([@davidzhao](https://github.com/davidzhao))
- fix(openai): pass NotGiven to OpenAI when instructions are omitted - [#1834](https://github.com/livekit/agents/pull/1834) ([@davidzhao](https://github.com/davidzhao))
## 0.12.2
### Patch Changes
- fix: openai stt error when using detect language - [#1755](https://github.com/livekit/agents/pull/1755) ([@jayeshp19](https://github.com/jayeshp19))
## 0.12.1
### Patch Changes
- expose turn_detection options with openai STT - [#1726](https://github.com/livekit/agents/pull/1726) ([@davidzhao](https://github.com/davidzhao))
- feat(OpenAI STT): add support for semantic_vad - [#1707](https://github.com/livekit/agents/pull/1707) ([@chasemcdo](https://github.com/chasemcdo))
## 0.12.0
### Minor Changes
- support for streaming STT, new STT/TTS models - [#1701](https://github.com/livekit/agents/pull/1701) ([@davidzhao](https://github.com/davidzhao))
### Patch Changes
- openai new STT model and voices - [#1691](https://github.com/livekit/agents/pull/1691) ([@lundin](https://github.com/lundin))
- Make azure and openai take a timeout optionally. Also update the default timeout for Azure OpenAI to 5s from 10 minutes. - [#1674](https://github.com/livekit/agents/pull/1674) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.11.3
### Patch Changes
- Support more input transcription parameters for openai realtime - [#1637](https://github.com/livekit/agents/pull/1637) ([@adambenali](https://github.com/adambenali))
- Add string type support to model parameter - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
## 0.11.2
### Patch Changes
- version bump to 0.11.1 - [#1640](https://github.com/livekit/agents/pull/1640) ([@davidzhao](https://github.com/davidzhao))
## 1.0.1
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- fix multimodal agent interrupts itself when creating function call response - [#1585](https://github.com/livekit/agents/pull/1585) ([@longcw](https://github.com/longcw))
- feat: add max_tokens option to LLM and LLMStream classes - [#1576](https://github.com/livekit/agents/pull/1576) ([@davidzhao](https://github.com/davidzhao))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 1.0.0
### Major Changes
- feat: add max_tokens option to LLM and LLMStream classes - [#1545](https://github.com/livekit/agents/pull/1545) ([@dorlanpabon](https://github.com/dorlanpabon))
## 0.11.0
### Minor Changes
- openai tts: switch to using Opus encoding - [#1494](https://github.com/livekit/agents/pull/1494) ([@davidzhao](https://github.com/davidzhao))
## 0.10.19
### Patch Changes
- fix: [openai] only send params when set - [#1474](https://github.com/livekit/agents/pull/1474) ([@jayeshp19](https://github.com/jayeshp19))
- fix response create for openai realtime model - [#1469](https://github.com/livekit/agents/pull/1469) ([@longcw](https://github.com/longcw))
## 0.10.18
### Patch Changes
- Added an additional field in LLM capabilities class to check if model providers support function call history within chat context without needing function definitions. - [#1441](https://github.com/livekit/agents/pull/1441) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.17
### Patch Changes
- gemini-realtime: fix input audio sample rate - [#1411](https://github.com/livekit/agents/pull/1411) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.16
### Patch Changes
- add generate_reply api for multimodal agent - [#1359](https://github.com/livekit/agents/pull/1359) ([@longcw](https://github.com/longcw))
## 0.10.15
### Patch Changes
- support disabling server VAD for OpenAI realtime model - [#1347](https://github.com/livekit/agents/pull/1347) ([@longcw](https://github.com/longcw))
## 0.10.14
### Patch Changes
- fix: revert from weakset to list in multimodal for maintaining sessions - [#1326](https://github.com/livekit/agents/pull/1326) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.13
### Patch Changes
- improved handling of LLM errors, do not retry if already began - [#1298](https://github.com/livekit/agents/pull/1298) ([@davidzhao](https://github.com/davidzhao))
- make multimodal class generic and support gemini live api - [#1240](https://github.com/livekit/agents/pull/1240) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.12
### Patch Changes
- fix unknown `metadata` & `store` fields on OpenAI-like API - [#1276](https://github.com/livekit/agents/pull/1276) ([@theomonnom](https://github.com/theomonnom))
## 0.10.11
### Patch Changes
- Moved create_ai_function_info to function_context.py for better reusability and reduce repetation - [#1260](https://github.com/livekit/agents/pull/1260) ([@jayeshp19](https://github.com/jayeshp19))
- add on_duplicate option for multimodal agent response create - [#1204](https://github.com/livekit/agents/pull/1204) ([@longcw](https://github.com/longcw))
- Add support for OpenAI's "detail" parameter to ChatImage - [#1213](https://github.com/livekit/agents/pull/1213) ([@bcherry](https://github.com/bcherry))
Add support for data URLs on ChatImage in the Anthropic plugin.
- filter out empty message for set chat ctx in realtime model - [#1245](https://github.com/livekit/agents/pull/1245) ([@longcw](https://github.com/longcw))
- fix: correctly parse function argument types - [#1221](https://github.com/livekit/agents/pull/1221) ([@jayeshp19](https://github.com/jayeshp19))
- add session_updated event for RealtimeSession - [#1253](https://github.com/livekit/agents/pull/1253) ([@longcw](https://github.com/longcw))
- added llama 3.3 70b to model definitions - [#1233](https://github.com/livekit/agents/pull/1233) ([@davidzhao](https://github.com/davidzhao))
- update default realtime model to gpt-4o-realtime-preview-2024-12-17 - [#1250](https://github.com/livekit/agents/pull/1250) ([@davidzhao](https://github.com/davidzhao))
- Fix center_aspect_fit bug, add scale_aspect_fit and scale_aspect_fill resizing options. - [#1222](https://github.com/livekit/agents/pull/1222) ([@bcherry](https://github.com/bcherry))
Make scale_aspect_fit the new default resizing option for video frames.
## 0.10.10
### Patch Changes
- add `google/gemini-2.0-flash-exp` as default model for vertex - [#1214](https://github.com/livekit/agents/pull/1214) ([@jayeshp19](https://github.com/jayeshp19))
- emit error event for realtime model - [#1200](https://github.com/livekit/agents/pull/1200) ([@longcw](https://github.com/longcw))
- fix: return structured output from func calls - [#1187](https://github.com/livekit/agents/pull/1187) ([@jayeshp19](https://github.com/jayeshp19))
- Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19))
- fix: openai llm retries - [#1196](https://github.com/livekit/agents/pull/1196) ([@theomonnom](https://github.com/theomonnom))
- Improvements to the end-of-turn plugin; ensure STT language settings are applied. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao))
- fix: Handle optional func args in tool calls when set to `None` - [#1211](https://github.com/livekit/agents/pull/1211) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.9
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.10.8
### Patch Changes
- fix uncaught OAI errors - [#1158](https://github.com/livekit/agents/pull/1158) ([@theomonnom](https://github.com/theomonnom))
- feat: stt retry & stt.FallbackAdapter - [#1114](https://github.com/livekit/agents/pull/1114) ([@theomonnom](https://github.com/theomonnom))
- project id fix for google - [#1115](https://github.com/livekit/agents/pull/1115) ([@jayeshp19](https://github.com/jayeshp19))
- Add retries to recover from text mode to audio model for realtime API - [#1121](https://github.com/livekit/agents/pull/1121) ([@longcw](https://github.com/longcw))
- Support for Python 3.13, relaxed Pillow version requirement for 10.x - [#1127](https://github.com/livekit/agents/pull/1127) ([@davidzhao](https://github.com/davidzhao))
- support for custom tool use in LLMs - [#1102](https://github.com/livekit/agents/pull/1102) ([@jayeshp19](https://github.com/jayeshp19))
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
- Add new OpenAI realtime voices - [#1116](https://github.com/livekit/agents/pull/1116) ([@bcherry](https://github.com/bcherry))
- Expose multimodal agent metrics - [#1080](https://github.com/livekit/agents/pull/1080) ([@longcw](https://github.com/longcw))
- feat: llm retry & llm.FallbackAdapter - [#1132](https://github.com/livekit/agents/pull/1132) ([@theomonnom](https://github.com/theomonnom))
- vertex ai support with openai library - [#1084](https://github.com/livekit/agents/pull/1084) ([@jayeshp19](https://github.com/jayeshp19))
## 0.10.7
### Patch Changes
- fix realtime API audio format values - [#1092](https://github.com/livekit/agents/pull/1092) ([@longcw](https://github.com/longcw))
- make ConversationItem.create and delete return a Future in Realtime model - [#1085](https://github.com/livekit/agents/pull/1085) ([@longcw](https://github.com/longcw))
## 0.10.6
### Patch Changes
- Expose usage metrics for Realtime model - [#1036](https://github.com/livekit/agents/pull/1036) ([@yuyuma](https://github.com/yuyuma))
- sync the Realtime API conversation items and add set_chat_ctx - [#1015](https://github.com/livekit/agents/pull/1015) ([@longcw](https://github.com/longcw))
## 0.10.5
### Patch Changes
- fix: Azure realtime model does not accept null for max_response_output_tokens - [#927](https://github.com/livekit/agents/pull/927) ([@davidzhao](https://github.com/davidzhao))
- add update_options to TTS - [#922](https://github.com/livekit/agents/pull/922) ([@theomonnom](https://github.com/theomonnom))
- Groq integration with Whisper-compatible STT endpoints - [#986](https://github.com/livekit/agents/pull/986) ([@jayeshp19](https://github.com/jayeshp19))
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- openai: fix low timeouts - [#926](https://github.com/livekit/agents/pull/926) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
## 0.10.4
### Patch Changes
- add x.ai support - [#907](https://github.com/livekit/agents/pull/907) ([@theomonnom](https://github.com/theomonnom))
- Fix functions to include content - [#897](https://github.com/livekit/agents/pull/897) ([@martin-purplefish](https://github.com/martin-purplefish))
## 0.10.3
### Patch Changes
- fix: handle when STT does not return any speech - [#854](https://github.com/livekit/agents/pull/854) ([@davidzhao](https://github.com/davidzhao))
- Support for Realtime API with Azure OpenAI - [#848](https://github.com/livekit/agents/pull/848) ([@davidzhao](https://github.com/davidzhao))
## 0.10.2
### Patch Changes
- oai-realtime: fix function calls - [#826](https://github.com/livekit/agents/pull/826) ([@KillianLucas](https://github.com/KillianLucas))
## 0.10.1
### Patch Changes
- oai-realtime: log response errors - [#819](https://github.com/livekit/agents/pull/819) ([@theomonnom](https://github.com/theomonnom))
## 0.10.0
### Minor Changes
- OpenAI Realtime API support - [#814](https://github.com/livekit/agents/pull/814) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- Add Telnyx integration for LLM - [#803](https://github.com/livekit/agents/pull/803) ([@jamestwhedbee](https://github.com/jamestwhedbee))
## 0.8.5
### Patch Changes
- Fix function for OpenAI Assistants - [#784](https://github.com/livekit/agents/pull/784) ([@keepingitneil](https://github.com/keepingitneil))
## 0.8.4
### Patch Changes
- avoid returning tiny frames from TTS - [#747](https://github.com/livekit/agents/pull/747) ([@theomonnom](https://github.com/theomonnom))
- Fixing Assistant API Vision Capabilities - [#771](https://github.com/livekit/agents/pull/771) ([@keepingitneil](https://github.com/keepingitneil))
## 0.8.3
### Patch Changes
- Introduce function calling to OpenAI Assistants - [#710](https://github.com/livekit/agents/pull/710) ([@keepingitneil](https://github.com/keepingitneil))
- Add Cerebras to OpenAI Plugin - [#731](https://github.com/livekit/agents/pull/731) ([@henrytwo](https://github.com/henrytwo))
## 0.8.2
### Patch Changes
- Add deepseek LLMs at OpenAI plugin - [#714](https://github.com/livekit/agents/pull/714) ([@lenage](https://github.com/lenage))
- skip processing of choice.delta when it is None - [#705](https://github.com/livekit/agents/pull/705) ([@theomonnom](https://github.com/theomonnom))
## 0.8.1
### Patch Changes
- add support for Ollama, Perplexity, Fireworks, Octo, Together, and Groq LLMs through the OpenAI API - [#611](https://github.com/livekit/agents/pull/611) ([@nbsp](https://github.com/nbsp))
- allow sending user IDs - [#633](https://github.com/livekit/agents/pull/633) ([@nbsp](https://github.com/nbsp))
- Support OpenAI Assistants API as a beta feature under `livekit.plugins.openai.beta` - [#601](https://github.com/livekit/agents/pull/601) ([@keepingitneil](https://github.com/keepingitneil))
Add `_metadata` to ChatCtx and ChatMessage, which can be used (in the case of OpenAI Assistants) for bookkeeping to sync local state with the remote OpenAI state
## 0.8.0
### Minor Changes
- openai: use openai client for stt - [#583](https://github.com/livekit/agents/pull/583) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- openai: add api_key argument - [#580](https://github.com/livekit/agents/pull/580) ([@theomonnom](https://github.com/theomonnom))
- openai: fix incorrect API urls on Windows - [#575](https://github.com/livekit/agents/pull/575) ([@theomonnom](https://github.com/theomonnom))
## 0.7.1
### Patch Changes
- set timeout to 5 seconds - [#524](https://github.com/livekit/agents/pull/524) ([@nbsp](https://github.com/nbsp))
## 0.7.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.7.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.1-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins OpenAI
Agent Framework plugin for services from OpenAI. Currently supports STT, TTS, and Dalle 3.
## Installation
```bash
pip install livekit-plugins-openai
```
You'll need an API key from OpenAI. It can be set as an environment variable: `OPENAI_API_KEY`.
In addition to LLM, STT, and TTS, this package also supports using OpenAI's Assistants API as an LLM.
The Assistants API is a stateful API that holds the conversation state on the server side.
The `AssistantLLM` class gives you an LLM-like interface for interacting with the Assistants API.
For examples of using the Assistants API with `VoicePipelineAgent`, see the OpenAI Assistants API example.
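Below is a minimal sketch (not from the original README; the assistant name, instructions, and model are illustrative, and `OPENAI_API_KEY` is assumed to be set) of constructing an `AssistantLLM` that creates an assistant on the fly. The resulting object can be used wherever an `llm.LLM` is accepted, for example as the `llm` of a `VoicePipelineAgent`.
```py
from livekit.plugins.openai import beta

# Creates a new assistant (and thread) on first use. To reuse an existing
# assistant instead, pass AssistantLoadOptions(assistant_id=..., thread_id=...)
# via load_options rather than create_options.
assistant_llm = beta.AssistantLLM(
    assistant_opts=beta.AssistantOptions(
        create_options=beta.AssistantCreateOptions(
            name="KITT",  # illustrative
            instructions="You are a helpful voice assistant.",
            model="gpt-4o",
        ),
    ),
)
```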
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import beta, realtime
from .embeddings import EmbeddingData, create_embeddings
from .llm import LLM, LLMStream
from .models import STTModels, TTSModels, TTSVoices
from .stt import STT
from .tts import TTS
from .version import __version__
__all__ = [
"STT",
"TTS",
"LLM",
"LLMStream",
"STTModels",
"beta",
"TTSModels",
"TTSVoices",
"create_embeddings",
"EmbeddingData",
"realtime",
"__version__",
]
from livekit.agents import Plugin
from .log import logger
class OpenAIPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(OpenAIPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
```
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/_oai_api.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import typing
from typing import Any
from livekit.agents.llm import function_context, llm
from livekit.agents.llm.function_context import _is_optional_type
__all__ = ["build_oai_function_description"]
def build_oai_function_description(
fnc_info: function_context.FunctionInfo,
capabilities: llm.LLMCapabilities | None = None,
) -> dict[str, Any]:
def build_oai_property(arg_info: function_context.FunctionArgInfo):
def type2str(t: type) -> str:
if t is str:
return "string"
elif t in (int, float):
return "number"
elif t is bool:
return "boolean"
raise ValueError(f"unsupported type {t} for ai_property")
p: dict[str, Any] = {}
if arg_info.description:
p["description"] = arg_info.description
is_optional, inner_th = _is_optional_type(arg_info.type)
if typing.get_origin(inner_th) is list:
inner_type = typing.get_args(inner_th)[0]
p["type"] = "array"
p["items"] = {}
p["items"]["type"] = type2str(inner_type)
if arg_info.choices:
p["items"]["enum"] = arg_info.choices
else:
p["type"] = type2str(inner_th)
if arg_info.choices:
p["enum"] = arg_info.choices
if (
inner_th is int
and capabilities
and not capabilities.supports_choices_on_int
):
raise ValueError(
f"Parameter '{arg_info.name}' uses 'choices' with 'int', which is not supported by this model."
)
return p
properties_info: dict[str, dict[str, Any]] = {}
required_properties: list[str] = []
for arg_info in fnc_info.arguments.values():
if arg_info.default is inspect.Parameter.empty:
required_properties.append(arg_info.name)
properties_info[arg_info.name] = build_oai_property(arg_info)
return {
"type": "function",
"function": {
"name": fnc_info.name,
"description": fnc_info.description,
"parameters": {
"type": "object",
"properties": properties_info,
"required": required_properties,
},
},
}
```
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/__init__.py
```py
from .assistant_llm import (
AssistantCreateOptions,
AssistantLLM,
AssistantLoadOptions,
AssistantOptions,
OnFileUploaded,
OnFileUploadedInfo,
)
__all__ = [
"AssistantLLM",
"AssistantOptions",
"AssistantCreateOptions",
"AssistantLoadOptions",
"OnFileUploaded",
"OnFileUploadedInfo",
]
```
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, MutableSet, Union
import httpx
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from openai import AsyncAssistantEventHandler, AsyncClient
from openai.types.beta.threads import Text, TextDelta
from openai.types.beta.threads.run_create_params import AdditionalMessage
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
from openai.types.beta.threads.runs import (
CodeInterpreterToolCall,
FileSearchToolCall,
FunctionToolCall,
ToolCall,
)
from openai.types.file_object import FileObject
from .._oai_api import build_oai_function_description
from ..log import logger
from ..models import ChatModels
DEFAULT_MODEL = "gpt-4o"
OPENAI_MESSAGE_ID_KEY = "__openai_message_id__"
LIVEKIT_MESSAGE_ID_KEY = "__livekit_message_id__"
OPENAI_MESSAGES_ADDED_KEY = "__openai_messages_added__"
OPENAI_FILE_ID_KEY = "__openai_file_id__"
@dataclass
class LLMOptions:
model: str | ChatModels
@dataclass
class AssistantOptions:
"""Options for creating (on-the-fly) or loading an assistant. Only one of create_options or load_options should be set."""
create_options: AssistantCreateOptions | None = None
load_options: AssistantLoadOptions | None = None
@dataclass
class AssistantCreateOptions:
name: str
instructions: str
model: ChatModels
temperature: float | None = None
# TODO: when we implement code_interpreter and file_search tools
# tool_resources: ToolResources | None = None
# tools: list[AssistantTools] = field(default_factory=list)
@dataclass
class AssistantLoadOptions:
assistant_id: str
thread_id: str | None
@dataclass
class OnFileUploadedInfo:
type: Literal["image"]
original_file: llm.ChatImage
openai_file_object: FileObject
OnFileUploaded = Callable[[OnFileUploadedInfo], None]
class AssistantLLM(llm.LLM):
def __init__(
self,
*,
assistant_opts: AssistantOptions,
client: AsyncClient | None = None,
api_key: str | None = None,
base_url: str | None = None,
on_file_uploaded: OnFileUploaded | None = None,
) -> None:
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)
test_ctx = llm.ChatContext()
if not hasattr(test_ctx, "_metadata"):
raise Exception(
"This beta feature of 'livekit-plugins-openai' requires a newer version of 'livekit-agents'"
)
self._client = client or AsyncClient(
api_key=api_key,
base_url=base_url,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(timeout=30, connect=10, read=5, pool=5),
follow_redirects=True,
limits=httpx.Limits(
max_connections=1000,
max_keepalive_connections=100,
keepalive_expiry=120,
),
),
)
self._assistant_opts = assistant_opts
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
self._on_file_uploaded = on_file_uploaded
self._tool_call_run_id_lookup = dict[str, str]()
self._submitted_tool_calls = set[str]()
self._sync_openai_task: asyncio.Task[AssistantLoadOptions] | None = None
try:
self._sync_openai_task = asyncio.create_task(self._sync_openai())
except Exception:
logger.error(
"failed to create sync openai task. This can happen when instantiating without a running asyncio event loop (such has when running tests)"
)
self._done_futures = list[asyncio.Future[None]]()
async def _sync_openai(self) -> AssistantLoadOptions:
if self._assistant_opts.create_options:
kwargs: Dict[str, Any] = {
"model": self._assistant_opts.create_options.model,
"name": self._assistant_opts.create_options.name,
"instructions": self._assistant_opts.create_options.instructions,
# "tools": [
# {"type": t} for t in self._assistant_opts.create_options.tools
# ],
# "tool_resources": self._assistant_opts.create_options.tool_resources,
}
# TODO when we implement code_interpreter and file_search tools
# if self._assistant_opts.create_options.tool_resources:
# kwargs["tool_resources"] = (
# self._assistant_opts.create_options.tool_resources
# )
if self._assistant_opts.create_options.temperature:
kwargs["temperature"] = self._assistant_opts.create_options.temperature
assistant = await self._client.beta.assistants.create(**kwargs)
thread = await self._client.beta.threads.create()
return AssistantLoadOptions(assistant_id=assistant.id, thread_id=thread.id)
elif self._assistant_opts.load_options:
if not self._assistant_opts.load_options.thread_id:
thread = await self._client.beta.threads.create()
self._assistant_opts.load_options.thread_id = thread.id
return self._assistant_opts.load_options
raise Exception("One of create_options or load_options must be set")
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
):
if n is not None:
logger.warning("OpenAI Assistants does not support the 'n' parameter")
if parallel_tool_calls is not None:
logger.warning(
"OpenAI Assistants does not support the 'parallel_tool_calls' parameter"
)
if not self._sync_openai_task:
self._sync_openai_task = asyncio.create_task(self._sync_openai())
return AssistantLLMStream(
temperature=temperature,
assistant_llm=self,
sync_openai_task=self._sync_openai_task,
client=self._client,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
on_file_uploaded=self._on_file_uploaded,
conn_options=conn_options,
)
async def _register_tool_call(self, tool_call_id: str, run_id: str) -> None:
self._tool_call_run_id_lookup[tool_call_id] = run_id
async def _submit_tool_call_result(self, tool_call_id: str, result: str) -> None:
if tool_call_id in self._submitted_tool_calls:
return
logger.debug(f"submitting tool call {tool_call_id} result")
run_id = self._tool_call_run_id_lookup.get(tool_call_id)
if not run_id:
logger.error(f"tool call {tool_call_id} not found")
return
if not self._sync_openai_task:
logger.error("sync_openai_task not set")
return
thread_id = (await self._sync_openai_task).thread_id
if not thread_id:
logger.error("thread_id not set")
return
tool_output = ToolOutput(output=result, tool_call_id=tool_call_id)
await self._client.beta.threads.runs.submit_tool_outputs_and_poll(
tool_outputs=[tool_output], run_id=run_id, thread_id=thread_id
)
self._submitted_tool_calls.add(tool_call_id)
logger.debug(f"submitted tool call {tool_call_id} result")
class AssistantLLMStream(llm.LLMStream):
class EventHandler(AsyncAssistantEventHandler):
def __init__(
self,
llm: AssistantLLM,
llm_stream: AssistantLLMStream,
event_ch: utils.aio.Chan[llm.ChatChunk],
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None = None,
):
super().__init__()
self._llm = llm
self._llm_stream = llm_stream
self._chat_ctx = chat_ctx
self._event_ch = event_ch
self._fnc_ctx = fnc_ctx
async def on_text_delta(self, delta: TextDelta, snapshot: Text):
assert self.current_run is not None
self._event_ch.send_nowait(
llm.ChatChunk(
request_id=self.current_run.id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(role="assistant", content=delta.value)
)
],
)
)
async def on_tool_call_created(self, tool_call: ToolCall):
if not self.current_run:
logger.error("tool call created without run")
return
await self._llm._register_tool_call(tool_call.id, self.current_run.id)
async def on_tool_call_done(
self,
tool_call: CodeInterpreterToolCall | FileSearchToolCall | FunctionToolCall,
) -> None:
assert self.current_run is not None
if tool_call.type == "code_interpreter":
logger.warning("code interpreter tool call not yet implemented")
elif tool_call.type == "file_search":
logger.warning("file_search tool call not yet implemented")
elif tool_call.type == "function":
if not self._fnc_ctx:
logger.error("function tool called without function context")
return
fnc = llm.FunctionCallInfo(
function_info=self._fnc_ctx.ai_functions[tool_call.function.name],
arguments=json.loads(tool_call.function.arguments),
tool_call_id=tool_call.id,
raw_arguments=tool_call.function.arguments,
)
self._llm_stream._function_calls_info.append(fnc)
chunk = llm.ChatChunk(
request_id=self.current_run.id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(role="assistant", tool_calls=[fnc]),
index=0,
)
],
)
self._event_ch.send_nowait(chunk)
def __init__(
self,
*,
assistant_llm: AssistantLLM,
client: AsyncClient,
sync_openai_task: asyncio.Task[AssistantLoadOptions],
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None,
temperature: float | None,
on_file_uploaded: OnFileUploaded | None,
conn_options: APIConnectOptions,
) -> None:
super().__init__(
assistant_llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = client
self._temperature = temperature
self._on_file_uploaded = on_file_uploaded
# current function call that we're waiting for full completion (args are streamed)
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._create_stream_task = asyncio.create_task(self._main_task())
self._sync_openai_task = sync_openai_task
# Running stream is used to ensure that we only have one stream running at a time
self._done_future: asyncio.Future[None] = asyncio.Future()
async def _run(self) -> None:
assert isinstance(self._llm, AssistantLLM)
# This function's complexity is due to the fact that we need to sync chat_ctx messages with OpenAI.
# OpenAI also does not allow us to modify messages while a stream is running. So we need to make sure streams run
# sequentially. The strategy is as follows:
#
# 1. ensure that we have a thread_id and assistant_id from OpenAI. This comes from the _sync_openai_task
# 2. make sure all previous streams are done before starting a new one
# 3. delete messages that are no longer in the chat_ctx but are still in OpenAI by using the OpenAI message id
# 4. add new messages to OpenAI that are in the chat_ctx but not in OpenAI. We don't know the OpenAI message id yet
# so we create a random uuid (we call it the LiveKit message id) and set that in the metadata.
# 5. start the stream and wait for it to finish
# 6. get the OpenAI message ids for the messages we added to OpenAI by using the metadata
# 7. Resolve the OpenAI message id with all messages that have a LiveKit message id.
try:
load_options = await self._sync_openai_task
# The assistants api does not let us modify messages while a stream is running.
# So we have to make sure previous streams are done before starting a new one.
await asyncio.gather(*self._llm._done_futures)
self._llm._done_futures.clear()
self._llm._done_futures.append(self._done_future)
# OpenAI requires submitting tool call outputs manually. We iterate
# tool outputs in the chat_ctx (from previous runs) and submit them
# before continuing.
for msg in self._chat_ctx.messages:
if msg.role == "tool":
if not msg.tool_call_id:
logger.error("tool message without tool_call_id")
continue
if not isinstance(msg.content, str):
logger.error("tool message content is not str")
continue
await self._llm._submit_tool_call_result(
msg.tool_call_id, msg.content
)
# At the chat_ctx level, create a map of thread_id to message_ids
# This is used to keep track of which messages have been added to the thread
# and which we may need to delete from OpenAI
if OPENAI_MESSAGES_ADDED_KEY not in self._chat_ctx._metadata:
self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY] = dict()
if (
load_options.thread_id
not in self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY]
):
self._chat_ctx._metadata[OPENAI_MESSAGES_ADDED_KEY][
load_options.thread_id
] = set()
# Keep this handy to make the code more readable later on
openai_addded_messages_set: set[str] = self._chat_ctx._metadata[
OPENAI_MESSAGES_ADDED_KEY
][load_options.thread_id]
# Keep track of messages that are no longer in the chat_ctx but are still in OpenAI
# Note: unfortunately, this adds latency. It's usually just one message, so we loop over it,
# but it creates an extra round trip to OpenAI before we can run inference.
# TODO: parallelize it?
for msg in self._chat_ctx.messages:
msg_id = msg._metadata.get(OPENAI_MESSAGE_ID_KEY, {}).get(
load_options.thread_id
)
assert load_options.thread_id
if msg_id and msg_id not in openai_addded_messages_set:
await self._client.beta.threads.messages.delete(
thread_id=load_options.thread_id,
message_id=msg_id,
)
logger.debug(
f"Deleted message '{msg_id}' in thread '{load_options.thread_id}'"
)
openai_addded_messages_set.remove(msg_id)
# Upload any images in the chat_ctx that have not been uploaded to OpenAI
for msg in self._chat_ctx.messages:
if msg.role != "user":
continue
if not isinstance(msg.content, list):
continue
for cnt in msg.content:
if (
not isinstance(cnt, llm.ChatImage)
or OPENAI_FILE_ID_KEY in cnt._cache
):
continue
if isinstance(cnt.image, str):
continue
file_obj = await self._upload_frame(
cnt.image, cnt.inference_width, cnt.inference_height
)
cnt._cache[OPENAI_FILE_ID_KEY] = file_obj.id
if self._on_file_uploaded:
self._on_file_uploaded(
OnFileUploadedInfo(
type="image",
original_file=cnt,
openai_file_object=file_obj,
)
)
# Keep track of the new messages in the chat_ctx that we need to add to OpenAI
additional_messages: list[AdditionalMessage] = []
for msg in self._chat_ctx.messages:
if msg.role != "user":
continue
msg_id = str(uuid.uuid4())
if OPENAI_MESSAGE_ID_KEY not in msg._metadata:
msg._metadata[OPENAI_MESSAGE_ID_KEY] = dict[str, str]()
if LIVEKIT_MESSAGE_ID_KEY not in msg._metadata:
msg._metadata[LIVEKIT_MESSAGE_ID_KEY] = dict[str, str]()
oai_msg_id_dict = msg._metadata[OPENAI_MESSAGE_ID_KEY]
lk_msg_id_dict = msg._metadata[LIVEKIT_MESSAGE_ID_KEY]
if load_options.thread_id not in oai_msg_id_dict:
converted_msg = build_oai_message(msg)
converted_msg["private_message_id"] = msg_id
additional_messages.append(
AdditionalMessage(
role="user",
content=converted_msg["content"],
metadata={LIVEKIT_MESSAGE_ID_KEY: msg_id},
)
)
lk_msg_id_dict[load_options.thread_id] = msg_id
eh = AssistantLLMStream.EventHandler(
llm=self._llm,
event_ch=self._event_ch,
chat_ctx=self._chat_ctx,
fnc_ctx=self._fnc_ctx,
llm_stream=self,
)
assert load_options.thread_id
kwargs: dict[str, Any] = {
"additional_messages": additional_messages,
"thread_id": load_options.thread_id,
"assistant_id": load_options.assistant_id,
"event_handler": eh,
"temperature": self._temperature,
}
if self._fnc_ctx:
kwargs["tools"] = [
build_oai_function_description(f)
for f in self._fnc_ctx.ai_functions.values()
]
async with self._client.beta.threads.runs.stream(**kwargs) as stream:
await stream.until_done()
# Populate the openai_message_id for the messages we added to OpenAI. Note, we do this after
# sending None to close the iterator so that it is done in parallel with any users of
# the stream. However, the next stream will not start until this is done.
lk_to_oai_lookup = dict[str, str]()
messages = await self._client.beta.threads.messages.list(
thread_id=load_options.thread_id,
limit=10, # We could be smarter and make a more exact query, but this is probably fine
)
for oai_msg in messages.data:
if oai_msg.metadata.get(LIVEKIT_MESSAGE_ID_KEY): # type: ignore
lk_to_oai_lookup[oai_msg.metadata[LIVEKIT_MESSAGE_ID_KEY]] = ( # type: ignore
oai_msg.id
)
for msg in self._chat_ctx.messages:
if msg.role != "user":
continue
oai_msg_id_dict = msg._metadata.get(OPENAI_MESSAGE_ID_KEY)
lk_msg_id_dict = msg._metadata.get(LIVEKIT_MESSAGE_ID_KEY)
if oai_msg_id_dict is None or lk_msg_id_dict is None:
continue
lk_msg_id = lk_msg_id_dict.get(load_options.thread_id)
if lk_msg_id and lk_msg_id in lk_to_oai_lookup:
oai_msg_id = lk_to_oai_lookup[lk_msg_id]
oai_msg_id_dict[load_options.thread_id] = oai_msg_id
openai_addded_messages_set.add(oai_msg_id)
# We don't need the LiveKit message id anymore
lk_msg_id_dict.pop(load_options.thread_id)
finally:
self._done_future.set_result(None)
async def _upload_frame(
self,
frame: rtc.VideoFrame,
inference_width: int | None,
inference_height: int | None,
):
# Internally we allow extra metadata to be attached to each ChatImage
# (to avoid re-encoding the frame for every chat completion request)
opts = utils.images.EncodeOptions()
if inference_width and inference_height:
opts.resize_options = utils.images.ResizeOptions(
width=inference_width,
height=inference_height,
strategy="scale_aspect_fit",
)
encoded_data = utils.images.encode(frame, opts)
fileObj = await self._client.files.create(
file=("image.jpg", encoded_data),
purpose="vision",
)
return fileObj
def build_oai_message(msg: llm.ChatMessage):
oai_msg: dict[str, Any] = {"role": msg.role}
if msg.name:
oai_msg["name"] = msg.name
# add content if provided
if isinstance(msg.content, str):
oai_msg["content"] = msg.content
elif isinstance(msg.content, list):
oai_content: list[dict[str, Any]] = []
for cnt in msg.content:
if isinstance(cnt, str):
oai_content.append({"type": "text", "text": cnt})
elif isinstance(cnt, llm.ChatImage):
if cnt._cache[OPENAI_FILE_ID_KEY]:
oai_content.append(
{
"type": "image_file",
"image_file": {"file_id": cnt._cache[OPENAI_FILE_ID_KEY]},
}
)
oai_msg["content"] = oai_content
# make sure to provide when function has been called inside the context
# (+ raw_arguments)
if msg.tool_calls is not None:
tool_calls: list[dict[str, Any]] = []
oai_msg["tool_calls"] = tool_calls
for fnc in msg.tool_calls:
tool_calls.append(
{
"id": fnc.tool_call_id,
"type": "function",
"function": {
"name": fnc.function_info.name,
"arguments": fnc.raw_arguments,
},
}
)
# tool_call_id is set when the message is a response/result to a function call
# (content is a string in this case)
if msg.tool_call_id:
oai_msg["tool_call_id"] = msg.tool_call_id
return oai_msg
```
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/embeddings.py
```py
from __future__ import annotations
import base64
import os
import struct
from dataclasses import dataclass
import aiohttp
from livekit.agents import utils
from . import models
@dataclass
class EmbeddingData:
index: int
embedding: list[float]
async def create_embeddings(
*,
input: list[str],
model: models.EmbeddingModels = "text-embedding-3-small",
dimensions: int | None = None,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> list[EmbeddingData]:
http_session = http_session or utils.http_context.http_session()
api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY must be set")
async with http_session.post(
"https://api.openai.com/v1/embeddings",
headers={"Authorization": f"Bearer {api_key}"},
json={
"model": model,
"input": input,
"encoding_format": "base64",
"dimensions": dimensions,
},
) as resp:
json = await resp.json()
data = json["data"]
list_data = []
for d in data:
bytes = base64.b64decode(d["embedding"])
num_floats = len(bytes) // 4
floats = list(struct.unpack("f" * num_floats, bytes))
list_data.append(EmbeddingData(index=d["index"], embedding=floats))
return list_data
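# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how the
# create_embeddings() helper above can be called. Assumes OPENAI_API_KEY is
# set; the helper name `_example_create_embeddings` is hypothetical.
async def _example_create_embeddings() -> None:
    # pass an explicit aiohttp session so the example does not depend on the
    # job-scoped http_context session
    async with aiohttp.ClientSession() as session:
        data = await create_embeddings(
            input=["LiveKit Agents", "OpenAI plugin"],
            model="text-embedding-3-small",
            dimensions=256,
            http_session=session,
        )
        # each EmbeddingData carries the original input index and the decoded vector
        print(data[0].index, len(data[0].embedding))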
```
## livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import datetime
import os
from dataclasses import dataclass
from typing import Any, Literal, MutableSet, Union
import aiohttp
import httpx
from livekit.agents import (
APIConnectionError,
APIStatusError,
APITimeoutError,
llm,
)
from livekit.agents.llm import (
LLMCapabilities,
ToolChoice,
_create_ai_function_info,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
import openai
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import Choice
from ._oai_api import build_oai_function_description
from .log import logger
from .models import (
CerebrasChatModels,
ChatModels,
DeepSeekChatModels,
GroqChatModels,
OctoChatModels,
PerplexityChatModels,
TelnyxChatModels,
TogetherChatModels,
VertexModels,
XAIChatModels,
)
from .utils import AsyncAzureADTokenProvider, build_oai_message
@dataclass
class LLMOptions:
model: str | ChatModels
user: str | None
temperature: float | None
parallel_tool_calls: bool | None
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
store: bool | None = None
metadata: dict[str, str] | None = None
max_tokens: int | None = None
class LLM(llm.LLM):
def __init__(
self,
*,
model: str | ChatModels = "gpt-4o",
api_key: str | None = None,
base_url: str | None = None,
user: str | None = None,
client: openai.AsyncClient | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
store: bool | None = None,
metadata: dict[str, str] | None = None,
max_tokens: int | None = None,
timeout: httpx.Timeout | None = None,
) -> None:
"""
Create a new instance of OpenAI LLM.
``api_key`` must be set to your OpenAI API key, either using the argument or by setting the
``OPENAI_API_KEY`` environmental variable.
"""
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)
self._opts = LLMOptions(
model=model,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
store=store,
metadata=metadata,
max_tokens=max_tokens,
)
self._client = client or openai.AsyncClient(
api_key=api_key,
base_url=base_url,
max_retries=0,
http_client=httpx.AsyncClient(
timeout=timeout
if timeout
else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
max_keepalive_connections=50,
keepalive_expiry=120,
),
),
)
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
@staticmethod
def with_azure(
*,
model: str | ChatModels = "gpt-4o",
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
timeout: httpx.Timeout | None = None,
) -> LLM:
"""
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
"""
azure_client = openai.AsyncAzureOpenAI(
max_retries=0,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout
if timeout
else httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
) # type: ignore
return LLM(
model=model,
client=azure_client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_cerebras(
*,
model: str | CerebrasChatModels = "llama3.1-8b",
api_key: str | None = None,
base_url: str | None = "https://api.cerebras.ai/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of Cerebras LLM.
``api_key`` must be set to your Cerebras API key, either using the argument or by setting
the ``CEREBRAS_API_KEY`` environmental variable.
"""
api_key = _get_api_key("CEREBRAS_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_vertex(
*,
model: str | VertexModels = "google/gemini-2.0-flash-exp",
project_id: str | None = None,
location: str = "us-central1",
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of VertexAI LLM.
`GOOGLE_APPLICATION_CREDENTIALS` environment variable must be set to the path of the service account key file.
"""
logger.warning(
"`openai.LLM.with_vertex()` is deprecated. Use `google.LLM()` instead."
)
project_id = project_id
location = location
_gac = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
if _gac is None:
logger.warning(
"`GOOGLE_APPLICATION_CREDENTIALS` environment variable is not set. please set it to the path of the service account key file. Otherwise, use any of the other Google Cloud auth methods."
)
try:
from google.auth._default_async import default_async
from google.auth.transport._aiohttp_requests import Request
except ImportError:
raise ImportError(
"Google Auth dependencies not found. Please install with: `pip install livekit-plugins-openai[vertex]`"
)
class AuthTokenRefresher(openai.AsyncClient):
def __init__(self, **kwargs: Any) -> None:
self.creds, self.project = default_async(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
project = project_id or self.project
base_url = f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}/endpoints/openapi"
kwargs.update({"base_url": base_url})
super().__init__(api_key="DUMMY", **kwargs)
self.refresh_threshold = 600 # 10 minutes
def _token_needs_refresh(self) -> bool:
if not self.creds or not self.creds.valid:
return True
expiry = self.creds.expiry
if expiry is None:
return True
remaining = (expiry - datetime.datetime.utcnow()).total_seconds()
return remaining < self.refresh_threshold
async def _refresh_credentials(self) -> None:
if self.creds and self.creds.valid and not self._token_needs_refresh():
return
async with aiohttp.ClientSession(auto_decompress=False) as session:
auth_req = Request(session=session)
await self.creds.refresh(auth_req)
self.api_key = self.creds.token
client = AuthTokenRefresher(
max_retries=0,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
max_keepalive_connections=50,
keepalive_expiry=120,
),
),
)
vertex_llm = LLM(
model=model,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
vertex_llm._capabilities = llm.LLMCapabilities(supports_choices_on_int=False)
return vertex_llm
@staticmethod
def with_fireworks(
*,
model: str = "accounts/fireworks/models/llama-v3p3-70b-instruct",
api_key: str | None = None,
base_url: str | None = "https://api.fireworks.ai/inference/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of Fireworks LLM.
``api_key`` must be set to your Fireworks API key, either using the argument or by setting
the ``FIREWORKS_API_KEY`` environmental variable.
"""
api_key = _get_api_key("FIREWORKS_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_x_ai(
*,
model: str | XAIChatModels = "grok-2-public",
api_key: str | None = None,
base_url: str | None = "https://api.x.ai/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of XAI LLM.
``api_key`` must be set to your XAI API key, either using the argument or by setting
the ``XAI_API_KEY`` environmental variable.
"""
api_key = _get_api_key("XAI_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_groq(
*,
model: str | GroqChatModels = "llama3-8b-8192",
api_key: str | None = None,
base_url: str | None = "https://api.groq.com/openai/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
max_tokens: int | None = None,
) -> LLM:
"""
Create a new instance of Groq LLM.
``api_key`` must be set to your Groq API key, either using the argument or by setting
the ``GROQ_API_KEY`` environmental variable.
"""
api_key = _get_api_key("GROQ_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
max_tokens=max_tokens,
)
@staticmethod
def with_deepseek(
*,
model: str | DeepSeekChatModels = "deepseek-chat",
api_key: str | None = None,
base_url: str | None = "https://api.deepseek.com/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of DeepSeek LLM.
``api_key`` must be set to your DeepSeek API key, either using the argument or by setting
the ``DEEPSEEK_API_KEY`` environmental variable.
"""
api_key = _get_api_key("DEEPSEEK_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_octo(
*,
model: str | OctoChatModels = "llama-2-13b-chat",
api_key: str | None = None,
base_url: str | None = "https://text.octoai.run/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of OctoAI LLM.
``api_key`` must be set to your OctoAI API key, either using the argument or by setting
the ``OCTOAI_TOKEN`` environmental variable.
"""
api_key = _get_api_key("OCTOAI_TOKEN", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_ollama(
*,
model: str = "llama3.1",
base_url: str | None = "http://localhost:11434/v1",
client: openai.AsyncClient | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of Ollama LLM.
"""
return LLM(
model=model,
api_key="ollama",
base_url=base_url,
client=client,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_perplexity(
*,
model: str | PerplexityChatModels = "llama-3.1-sonar-small-128k-chat",
api_key: str | None = None,
base_url: str | None = "https://api.perplexity.ai",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of PerplexityAI LLM.
``api_key`` must be set to your Perplexity API key, either using the argument or by setting
the ``PERPLEXITY_API_KEY`` environmental variable.
"""
api_key = _get_api_key("PERPLEXITY_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_together(
*,
model: str | TogetherChatModels = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
api_key: str | None = None,
base_url: str | None = "https://api.together.xyz/v1",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of TogetherAI LLM.
``api_key`` must be set to your TogetherAI API key, either using the argument or by setting
the ``TOGETHER_API_KEY`` environmental variable.
"""
api_key = _get_api_key("TOGETHER_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def with_telnyx(
*,
model: str | TelnyxChatModels = "meta-llama/Meta-Llama-3.1-70B-Instruct",
api_key: str | None = None,
base_url: str | None = "https://api.telnyx.com/v2/ai",
client: openai.AsyncClient | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
"""
Create a new instance of Telnyx LLM.
``api_key`` must be set to your Telnyx API key, either using the argument or by setting
the ``TELNYX_API_KEY`` environmental variable.
"""
api_key = _get_api_key("TELNYX_API_KEY", api_key)
return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
@staticmethod
def create_azure_client(
*,
model: str | ChatModels = "gpt-4o",
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
user: str | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
) -> LLM:
logger.warning("This alias is deprecated. Use LLM.with_azure() instead")
return LLM.with_azure(
model=model,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
project=project,
base_url=base_url,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
def chat(
self,
*,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
fnc_ctx: llm.FunctionContext | None = None,
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if parallel_tool_calls is None:
parallel_tool_calls = self._opts.parallel_tool_calls
if tool_choice is None:
tool_choice = self._opts.tool_choice
if temperature is None:
temperature = self._opts.temperature
return LLMStream(
self,
client=self._client,
model=self._opts.model,
user=self._opts.user,
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
conn_options=conn_options,
n=n,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
class LLMStream(llm.LLMStream):
def __init__(
self,
llm: LLM,
*,
client: openai.AsyncClient,
model: str | ChatModels,
user: str | None,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions,
fnc_ctx: llm.FunctionContext | None,
temperature: float | None,
n: int | None,
parallel_tool_calls: bool | None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = client
self._model = model
self._llm: LLM = llm
self._user = user
self._temperature = temperature
self._n = n
self._parallel_tool_calls = parallel_tool_calls
self._tool_choice = tool_choice
async def _run(self) -> None:
if hasattr(self._llm._client, "_refresh_credentials"):
await self._llm._client._refresh_credentials()
# current function call that we're waiting for full completion (args are streamed)
# (defined inside the _run method to make sure the state is reset for each run/attempt)
self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._tool_index: int | None = None
retryable = True
try:
if self._fnc_ctx and len(self._fnc_ctx.ai_functions) > 0:
tools = [
build_oai_function_description(fnc, self._llm._capabilities)
for fnc in self._fnc_ctx.ai_functions.values()
]
else:
tools = None
opts: dict[str, Any] = {
"tools": tools,
"parallel_tool_calls": self._parallel_tool_calls if tools else None,
"tool_choice": (
{"type": "function", "function": {"name": self._tool_choice.name}}
if isinstance(self._tool_choice, ToolChoice)
else self._tool_choice
)
if tools is not None
else None,
"temperature": self._temperature,
"metadata": self._llm._opts.metadata,
"max_tokens": self._llm._opts.max_tokens,
"store": self._llm._opts.store,
"n": self._n,
"stream": True,
"stream_options": {"include_usage": True},
"user": self._user or openai.NOT_GIVEN,
}
# remove None values from the options
opts = _strip_nones(opts)
messages = _build_oai_context(self._chat_ctx, id(self))
stream = await self._client.chat.completions.create(
messages=messages,
model=self._model,
**opts,
)
async with stream:
async for chunk in stream:
for choice in chunk.choices:
chat_chunk = self._parse_choice(chunk.id, choice)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)
if chunk.usage is not None:
usage = chunk.usage
self._event_ch.send_nowait(
llm.ChatChunk(
request_id=chunk.id,
usage=llm.CompletionUsage(
completion_tokens=usage.completion_tokens,
prompt_tokens=usage.prompt_tokens,
total_tokens=usage.total_tokens,
),
)
)
except openai.APITimeoutError:
raise APITimeoutError(retryable=retryable)
except openai.APIStatusError as e:
raise APIStatusError(
e.message,
status_code=e.status_code,
request_id=e.request_id,
body=e.body,
)
except Exception as e:
raise APIConnectionError(retryable=retryable) from e
def _parse_choice(self, id: str, choice: Choice) -> llm.ChatChunk | None:
delta = choice.delta
# https://github.com/livekit/agents/issues/688
# the delta can be None when using Azure OpenAI using content filtering
if delta is None:
return None
if delta.tool_calls:
# check if we have functions to call
for tool in delta.tool_calls:
if not tool.function:
continue # oai may add other tools in the future
call_chunk = None
if self._tool_call_id and tool.id and tool.index != self._tool_index:
call_chunk = self._try_build_function(id, choice)
if tool.function.name:
self._tool_index = tool.index
self._tool_call_id = tool.id
self._fnc_name = tool.function.name
self._fnc_raw_arguments = tool.function.arguments or ""
elif tool.function.arguments:
self._fnc_raw_arguments += tool.function.arguments # type: ignore
if call_chunk is not None:
return call_chunk
if choice.finish_reason in ("tool_calls", "stop") and self._tool_call_id:
# we're done with the tool calls, run the last one
return self._try_build_function(id, choice)
return llm.ChatChunk(
request_id=id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(content=delta.content, role="assistant"),
index=choice.index,
)
],
)
def _try_build_function(self, id: str, choice: Choice) -> llm.ChatChunk | None:
if not self._fnc_ctx:
logger.warning("oai stream tried to run a function without a function context")
return None
if self._tool_call_id is None:
logger.warning(
"oai stream tried to run function but tool_call_id is not set"
)
return None
if self._fnc_name is None or self._fnc_raw_arguments is None:
logger.warning(
"oai stream tried to call a function but raw_arguments and fnc_name are not set"
)
return None
fnc_info = _create_ai_function_info(
self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments
)
self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
self._function_calls_info.append(fnc_info)
return llm.ChatChunk(
request_id=id,
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
role="assistant",
tool_calls=[fnc_info],
content=choice.delta.content,
),
index=choice.index,
)
],
)
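# Illustrative note (not from the source): how streamed tool-call deltas map onto
# the state handled by _parse_choice/_try_build_function above. The first delta of
# a call carries tool.id and function.name; later deltas only append argument
# fragments, and the accumulated call is flushed when a new tool index starts or
# when finish_reason is reached. Hypothetical delta sequence:
#   {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": ""}}
#   {"index": 0, "function": {"arguments": "{\"city\": \"Par"}}
#   {"index": 0, "function": {"arguments": "is\"}"}}
#   -> _fnc_name="get_weather", _fnc_raw_arguments='{"city": "Paris"}'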
def _build_oai_context(
chat_ctx: llm.ChatContext, cache_key: Any
) -> list[ChatCompletionMessageParam]:
return [build_oai_message(msg, cache_key) for msg in chat_ctx.messages] # type: ignore
def _strip_nones(data: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in data.items() if v is not None}
def _get_api_key(env_var: str, key: str | None) -> str:
key = key or os.environ.get(env_var)
if not key:
raise ValueError(
f"{env_var} is required, either as argument or set {env_var} environmental variable"
)
return key
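# Illustrative sketch (hypothetical helper, not part of this module): the same
# pattern LLMStream._run uses above -- collect request options, then drop unset
# entries with _strip_nones before forwarding them to the OpenAI client.
def _example_request_opts(
    temperature: float | None = None, max_tokens: int | None = None
) -> dict[str, Any]:
    opts: dict[str, Any] = {
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    # with both arguments left as None, only the streaming options remain
    return _strip_nones(opts)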
import logging
logger = logging.getLogger("livekit.plugins.openai")
from typing import Literal
from openai.types import AudioModel
STTModels = AudioModel
TTSModels = Literal["tts-1", "tts-1-hd", "gpt-4o-mini-tts"]
TTSVoices = Literal[
"alloy",
"ash",
"ballad",
"coral",
"echo",
"fable",
"onyx",
"nova",
"sage",
"shimmer",
]
DalleModels = Literal["dall-e-2", "dall-e-3"]
ChatModels = Literal[
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-1106-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k-0613",
]
EmbeddingModels = Literal[
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"
]
AssistantTools = Literal["code_interpreter", "file_search", "function"]
# adapters for OpenAI-compatible LLMs
TelnyxChatModels = Literal[
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Meta-Llama-3.1-70B-Instruct",
]
CerebrasChatModels = Literal[
"llama3.1-8b",
"llama3.1-70b",
"llama-3.3-70b",
]
PerplexityChatModels = Literal[
"llama-3.1-sonar-small-128k-online",
"llama-3.1-sonar-small-128k-chat",
"llama-3.1-sonar-large-128k-online",
"llama-3.1-sonar-large-128k-chat",
"llama-3.1-8b-instruct",
"llama-3.1-70b-instruct",
]
GroqChatModels = Literal[
"llama-3.1-405b-reasoning",
"llama-3.1-8b-instant",
"llama-3.3-70b-versatile",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llama-guard-3-8b",
"llama3-70b-8192",
"llama3-8b-8192",
"mixtral-8x7b-32768",
"gemma-7b-it",
"gemma2-9b-it",
]
GroqAudioModels = Literal[
"whisper-large-v3", "distil-whisper-large-v3-en", "whisper-large-v3-turbo"
]
DeepSeekChatModels = Literal[
"deepseek-coder",
"deepseek-chat",
]
VertexModels = Literal[
"google/gemini-2.0-flash-exp",
"google/gemini-1.5-flash",
"google/gemini-1.5-pro",
"google/gemini-1.0-pro-vision",
"google/gemini-1.0-pro-vision-001",
"google/gemini-1.0-pro-002",
"google/gemini-1.0-pro-001",
"google/gemini-1.0-pro",
]
TogetherChatModels = Literal[
"Austism/chronos-hermes-13b",
"Gryphe/MythoMax-L2-13b",
"NousResearch/Nous-Capybara-7B-V1p9",
"NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
"NousResearch/Nous-Hermes-2-Yi-34B",
"NousResearch/Nous-Hermes-Llama2-13b",
"NousResearch/Nous-Hermes-llama-2-7b",
"Open-Orca/Mistral-7B-OpenOrca",
"Qwen/Qwen1.5-0.5B-Chat",
"Qwen/Qwen1.5-1.8B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"Qwen/Qwen1.5-4B-Chat",
"Qwen/Qwen1.5-72B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen2-72B-Instruct",
"Snowflake/snowflake-arctic-instruct",
"Undi95/ReMM-SLERP-L2-13B",
"Undi95/Toppy-M-7B",
"WizardLM/WizardLM-13B-V1.2",
"allenai/OLMo-7B",
"allenai/OLMo-7B-Instruct",
"allenai/OLMo-7B-Twin-2T",
"codellama/CodeLlama-13b-Instruct-hf",
"codellama/CodeLlama-34b-Instruct-hf",
"codellama/CodeLlama-70b-Instruct-hf",
"codellama/CodeLlama-7b-Instruct-hf",
"cognitivecomputations/dolphin-2.5-mixtral-8x7b",
"databricks/dbrx-instruct",
"deepseek-ai/deepseek-coder-33b-instruct",
"deepseek-ai/deepseek-llm-67b-chat",
"garage-bAInd/Platypus2-70B-instruct",
"google/gemma-2-27b-it",
"google/gemma-2-9b-it",
"google/gemma-2b-it",
"google/gemma-7b-it",
"lmsys/vicuna-13b-v1.5",
"lmsys/vicuna-7b-v1.5",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-70b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-3-70b-chat-hf",
"meta-llama/Llama-3-8b-chat-hf",
"meta-llama/Meta-Llama-3-70B-Instruct-Lite",
"meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3-8B-Instruct-Lite",
"meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.2",
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mixtral-8x22B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"openchat/openchat-3.5-1210",
"snorkelai/Snorkel-Mistral-PairRM-DPO",
"teknium/OpenHermes-2-Mistral-7B",
"teknium/OpenHermes-2p5-Mistral-7B",
"togethercomputer/Llama-2-7B-32K-Instruct",
"togethercomputer/RedPajama-INCITE-7B-Chat",
"togethercomputer/RedPajama-INCITE-Chat-3B-v1",
"togethercomputer/StripedHyena-Nous-7B",
"togethercomputer/alpaca-7b",
"upstage/SOLAR-10.7B-Instruct-v1.0",
"zero-one-ai/Yi-34B-Chat",
]
OctoChatModels = Literal[
"meta-llama-3-70b-instruct",
"meta-llama-3.1-405b-instruct",
"meta-llama-3.1-70b-instruct",
"meta-llama-3.1-8b-instruct",
"mistral-7b-instruct",
"mixtral-8x7b-instruct",
"wizardlm-2-8x22bllamaguard-2-7b",
]
XAIChatModels = Literal[
"grok-2",
"grok-2-mini",
"grok-2-mini-public",
"grok-2-public",
]
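# Illustrative sketch (hypothetical helper, not part of this module): the Literal
# aliases above let static type checkers validate model names at call sites.
def _example_groq_model(model: GroqChatModels = "llama-3.3-70b-versatile") -> str:
    # passing a string outside GroqChatModels would be flagged by mypy/pyright
    return model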
from . import api_proto
from .realtime_model import (
DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
DEFAULT_SERVER_VAD_OPTIONS,
InputTranscriptionOptions,
RealtimeContent,
RealtimeError,
RealtimeModel,
RealtimeOutput,
RealtimeResponse,
RealtimeSession,
RealtimeSessionOptions,
RealtimeToolCall,
SemanticVadEagerness,
SemanticVadOptions,
ServerVadOptions,
)
__all__ = [
"RealtimeContent",
"RealtimeOutput",
"RealtimeResponse",
"RealtimeToolCall",
"RealtimeSession",
"RealtimeModel",
"RealtimeError",
"RealtimeSessionOptions",
"ServerVadOptions",
"InputTranscriptionOptions",
"ConversationItemCreated",
"ConversationItemDeleted",
"api_proto",
"DEFAULT_INPUT_AUDIO_TRANSCRIPTION",
"DEFAULT_SERVER_VAD_OPTIONS",
"SemanticVadEagerness",
"SemanticVadOptions",
]
from __future__ import annotations
from typing import Literal, Union
from typing_extensions import NotRequired, TypedDict
SAMPLE_RATE = 24000
NUM_CHANNELS = 1
IN_FRAME_SIZE = 2400 # 100ms
OUT_FRAME_SIZE = 1200 # 50ms
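# Worked check of the constants above (illustration only): at SAMPLE_RATE = 24000 Hz
# mono, 100 ms corresponds to 24000 * 0.100 = 2400 samples and 50 ms to 1200 samples.
assert IN_FRAME_SIZE == int(SAMPLE_RATE * 0.100)
assert OUT_FRAME_SIZE == int(SAMPLE_RATE * 0.050)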
class FunctionToolChoice(TypedDict):
type: Literal["function"]
name: str
Voice = Literal["alloy", "echo", "shimmer", "ash", "ballad", "coral", "sage", "verse"]
ToolChoice = Union[Literal["auto", "none", "required"], FunctionToolChoice]
Role = Literal["system", "assistant", "user", "tool"]
GenerationFinishedReason = Literal["stop", "max_tokens", "content_filter", "interrupt"]
AudioFormat = Literal["pcm16", "g711_ulaw", "g711_alaw"]
InputTranscriptionModel = Literal["whisper-1"]
Modality = Literal["text", "audio"]
ResponseStatus = Literal[
"in_progress", "completed", "incomplete", "cancelled", "failed"
]
# https://platform.openai.com/docs/models/gp#gpt-4o-realtime
OpenAIModel = Literal[
"gpt-4o-realtime-preview",
"gpt-4o-realtime-preview-2024-10-01",
"gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-mini-realtime-preview",
"gpt-4o-mini-realtime-preview-2024-12-17",
]
DefaultOpenAIModel = "gpt-4o-realtime-preview"
class TextContent(TypedDict):
type: Literal["text"]
text: str
class InputTextContent(TypedDict):
type: Literal["input_text"]
text: str
class AudioContent(TypedDict):
type: Literal["audio"]
audio: str # b64
class InputAudioContent(TypedDict):
type: Literal["input_audio"]
audio: str # b64
Content = Union[InputTextContent, TextContent, AudioContent, InputAudioContent]
class ContentPart(TypedDict):
type: Literal["text", "audio"]
audio: NotRequired[str] # b64
transcript: NotRequired[str]
class InputAudioTranscription(TypedDict):
model: InputTranscriptionModel | str
language: NotRequired[str]
prompt: NotRequired[str]
class ServerVad(TypedDict):
type: Literal["server_vad"]
threshold: NotRequired[float]
prefix_padding_ms: NotRequired[int]
silence_duration_ms: NotRequired[int]
create_response: NotRequired[bool]
class SemanticVad(TypedDict):
type: Literal["semantic_vad"]
eagerness: NotRequired[Literal["low", "medium", "high", "auto"]]
create_response: NotRequired[bool]
interrupt_response: NotRequired[bool]
class FunctionTool(TypedDict):
type: Literal["function"]
name: str
description: NotRequired[str | None]
parameters: dict
class SystemItem(TypedDict):
id: str
object: Literal["realtime.item"]
type: Literal["message"]
role: Literal["system"]
content: list[InputTextContent]
class UserItem(TypedDict):
id: str
object: Literal["realtime.item"]
type: Literal["message"]
role: Literal["user"]
content: list[InputTextContent | InputAudioContent]
class AssistantItem(TypedDict):
id: str
object: Literal["realtime.item"]
type: Literal["message"]
role: Literal["assistant"]
content: list[TextContent | AudioContent]
class FunctionCallItem(TypedDict):
id: str
object: Literal["realtime.item"]
type: Literal["function_call"]
call_id: str
name: str
arguments: str
class FunctionCallOutputItem(TypedDict):
id: str
object: Literal["realtime.item"]
type: Literal["function_call_output"]
call_id: str
output: str
class CancelledStatusDetails(TypedDict):
type: Literal["cancelled"]
reason: Literal["turn_detected", "client_cancelled"]
class IncompleteStatusDetails(TypedDict):
type: Literal["incomplete"]
reason: Literal["max_output_tokens", "content_filter"]
class Error(TypedDict):
code: str
message: str
class FailedStatusDetails(TypedDict):
type: Literal["failed"]
error: NotRequired[Error | None]
ResponseStatusDetails = Union[
CancelledStatusDetails, IncompleteStatusDetails, FailedStatusDetails
]
class InputTokenDetails(TypedDict):
cached_tokens: int
text_tokens: int
audio_tokens: int
cached_tokens_details: CachedTokenDetails
class CachedTokenDetails(TypedDict):
text_tokens: int
audio_tokens: int
class OutputTokenDetails(TypedDict):
text_tokens: int
audio_tokens: int
class Usage(TypedDict):
total_tokens: int
input_tokens: int
output_tokens: int
input_token_details: InputTokenDetails
output_token_details: OutputTokenDetails
class Resource:
class Session(TypedDict):
id: str
object: Literal["realtime.session"]
expires_at: int
model: str
modalities: list[Literal["text", "audio"]]
instructions: str
voice: Voice
input_audio_format: AudioFormat
output_audio_format: AudioFormat
input_audio_transcription: InputAudioTranscription | None
turn_detection: Union[ServerVad, SemanticVad, None]
tools: list[FunctionTool]
tool_choice: ToolChoice
temperature: float
max_response_output_tokens: int | Literal["inf"]
class Conversation(TypedDict):
id: str
object: Literal["realtime.conversation"]
Item = Union[SystemItem, UserItem, FunctionCallItem, FunctionCallOutputItem]
class Response(TypedDict):
id: str
object: Literal["realtime.response"]
status: ResponseStatus
status_details: NotRequired[ResponseStatusDetails | None]
output: list[Resource.Item]
usage: NotRequired[Usage | None]
class ClientEvent:
class SessionUpdateData(TypedDict):
modalities: list[Literal["text", "audio"]]
instructions: str
voice: Voice
input_audio_format: AudioFormat
output_audio_format: AudioFormat
input_audio_transcription: InputAudioTranscription | None
turn_detection: Union[ServerVad, SemanticVad, None]
tools: list[FunctionTool]
tool_choice: ToolChoice
temperature: float
# microsoft does not support inf, but accepts None
max_response_output_tokens: int | Literal["inf"] | None
class SessionUpdate(TypedDict):
event_id: NotRequired[str]
type: Literal["session.update"]
session: ClientEvent.SessionUpdateData
class InputAudioBufferAppend(TypedDict):
event_id: NotRequired[str]
type: Literal["input_audio_buffer.append"]
audio: str # b64
class InputAudioBufferCommit(TypedDict):
event_id: NotRequired[str]
type: Literal["input_audio_buffer.commit"]
class InputAudioBufferClear(TypedDict):
event_id: NotRequired[str]
type: Literal["input_audio_buffer.clear"]
class UserItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["user"]
content: list[InputTextContent | InputAudioContent]
class AssistantItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["assistant"]
content: list[TextContent]
class SystemItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["system"]
content: list[InputTextContent]
class FunctionCallOutputItemCreate(TypedDict):
id: str | None
type: Literal["function_call_output"]
call_id: str
output: str
class FunctionCallItemCreate(TypedDict):
id: str | None
type: Literal["function_call"]
call_id: str
name: str
arguments: str
ConversationItemCreateContent = Union[
UserItemCreate,
AssistantItemCreate,
SystemItemCreate,
FunctionCallOutputItemCreate,
FunctionCallItemCreate,
]
class ConversationItemCreate(TypedDict):
event_id: NotRequired[str]
type: Literal["conversation.item.create"]
previous_item_id: NotRequired[str | None]
item: ClientEvent.ConversationItemCreateContent
class ConversationItemTruncate(TypedDict):
event_id: NotRequired[str]
type: Literal["conversation.item.truncate"]
item_id: str
content_index: int
audio_end_ms: int
class ConversationItemDelete(TypedDict):
event_id: NotRequired[str]
type: Literal["conversation.item.delete"]
item_id: str
class ResponseCreateData(TypedDict, total=False):
modalities: list[Literal["text", "audio"]]
instructions: str
voice: Voice
output_audio_format: AudioFormat
tools: list[FunctionTool]
tool_choice: ToolChoice
temperature: float
conversation: Literal["auto", "none"]
metadata: NotRequired[dict[str, str] | None]
max_output_tokens: int | Literal["inf"]
class ResponseCreate(TypedDict):
event_id: NotRequired[str]
type: Literal["response.create"]
response: NotRequired[ClientEvent.ResponseCreateData]
class ResponseCancel(TypedDict):
event_id: NotRequired[str]
type: Literal["response.cancel"]
class ServerEvent:
class ErrorContent(TypedDict):
type: str
code: NotRequired[str]
message: str
param: NotRequired[str]
event_id: NotRequired[str]
class Error(TypedDict):
event_id: str
type: Literal["error"]
error: ServerEvent.ErrorContent
class SessionCreated(TypedDict):
event_id: str
type: Literal["session.created"]
session: Resource.Session
class SessionUpdated(TypedDict):
event_id: str
type: Literal["session.updated"]
session: Resource.Session
class ConversationCreated(TypedDict):
event_id: str
type: Literal["conversation.created"]
conversation: Resource.Conversation
class InputAudioBufferCommitted(TypedDict):
event_id: str
type: Literal["input_audio_buffer.committed"]
item_id: str
class InputAudioBufferCleared(TypedDict):
event_id: str
type: Literal["input_audio_buffer.cleared"]
class InputAudioBufferSpeechStarted(TypedDict):
event_id: str
type: Literal["input_audio_buffer.speech_started"]
item_id: str
audio_start_ms: int
class InputAudioBufferSpeechStopped(TypedDict):
event_id: str
type: Literal["input_audio_buffer.speech_stopped"]
item_id: str
audio_end_ms: int
class ConversationItemCreated(TypedDict):
event_id: str
type: Literal["conversation.item.created"]
previous_item_id: str | None
item: Resource.Item
class ConversationItemInputAudioTranscriptionCompleted(TypedDict):
event_id: str
type: Literal["conversation.item.input_audio_transcription.completed"]
item_id: str
content_index: int
transcript: str
class InputAudioTranscriptionError(TypedDict):
type: str
code: NotRequired[str]
message: str
param: NotRequired[str]
class ConversationItemInputAudioTranscriptionFailed(TypedDict):
event_id: str
type: Literal["conversation.item.input_audio_transcription.failed"]
item_id: str
content_index: int
error: ServerEvent.InputAudioTranscriptionError
class ConversationItemTruncated(TypedDict):
event_id: str
type: Literal["conversation.item.truncated"]
item_id: str
content_index: int
audio_end_ms: int
class ConversationItemDeleted(TypedDict):
event_id: str
type: Literal["conversation.item.deleted"]
item_id: str
class ResponseCreated(TypedDict):
event_id: str
type: Literal["response.created"]
response: Resource.Response
class ResponseDone(TypedDict):
event_id: str
type: Literal["response.done"]
response: Resource.Response
class ResponseOutputItemAdded(TypedDict):
event_id: str
type: Literal["response.output_item.added"]
response_id: str
output_index: int
item: Resource.Item
class ResponseOutputItemDone(TypedDict):
event_id: str
type: Literal["response.output.done"]
response_id: str
output_index: int
item: Resource.Item
class ResponseContentPartAdded(TypedDict):
event_id: str
type: Literal["response.content_part.added"]
item_id: str
response_id: str
output_index: int
content_index: int
part: ContentPart
class ResponseContentPartDone(TypedDict):
event_id: str
type: Literal["response.content.done"]
response_id: str
output_index: int
content_index: int
part: ContentPart
class ResponseTextDeltaAdded(TypedDict):
event_id: str
type: Literal["response.text.delta"]
response_id: str
output_index: int
content_index: int
delta: str
class ResponseTextDone(TypedDict):
event_id: str
type: Literal["response.text.done"]
response_id: str
output_index: int
content_index: int
text: str
class ResponseAudioTranscriptDelta(TypedDict):
event_id: str
type: Literal["response.audio_transcript.delta"]
response_id: str
output_index: int
content_index: int
delta: str
class ResponseAudioTranscriptDone(TypedDict):
event_id: str
type: Literal["response.audio_transcript.done"]
response_id: str
output_index: int
content_index: int
transcript: str
class ResponseAudioDelta(TypedDict):
event_id: str
type: Literal["response.audio.delta"]
response_id: str
output_index: int
content_index: int
delta: str # b64
class ResponseAudioDone(TypedDict):
event_id: str
type: Literal["response.audio.done"]
response_id: str
output_index: int
content_index: int
class ResponseFunctionCallArgumentsDelta(TypedDict):
event_id: str
type: Literal["response.function_call_arguments.delta"]
response_id: str
output_index: int
delta: str
class ResponseFunctionCallArgumentsDone(TypedDict):
event_id: str
type: Literal["response.function_call_arguments.done"]
response_id: str
output_index: int
arguments: str
class RateLimitsData(TypedDict):
name: Literal["requests", "tokens", "input_tokens", "output_tokens"]
limit: int
remaining: int
reset_seconds: float
class RateLimitsUpdated(TypedDict):
event_id: str
type: Literal["rate_limits.updated"]
limits: list[ServerEvent.RateLimitsData]
ClientEvents = Union[
ClientEvent.SessionUpdate,
ClientEvent.InputAudioBufferAppend,
ClientEvent.InputAudioBufferCommit,
ClientEvent.InputAudioBufferClear,
ClientEvent.ConversationItemCreate,
ClientEvent.ConversationItemTruncate,
ClientEvent.ConversationItemDelete,
ClientEvent.ResponseCreate,
ClientEvent.ResponseCancel,
]
ServerEvents = Union[
ServerEvent.Error,
ServerEvent.SessionCreated,
ServerEvent.SessionUpdated,
ServerEvent.ConversationCreated,
ServerEvent.InputAudioBufferCommitted,
ServerEvent.InputAudioBufferCleared,
ServerEvent.InputAudioBufferSpeechStarted,
ServerEvent.InputAudioBufferSpeechStopped,
ServerEvent.ConversationItemCreated,
ServerEvent.ConversationItemInputAudioTranscriptionCompleted,
ServerEvent.ConversationItemInputAudioTranscriptionFailed,
ServerEvent.ConversationItemTruncated,
ServerEvent.ConversationItemDeleted,
ServerEvent.ResponseCreated,
ServerEvent.ResponseDone,
ServerEvent.ResponseOutputItemAdded,
ServerEvent.ResponseOutputItemDone,
ServerEvent.ResponseContentPartAdded,
ServerEvent.ResponseContentPartDone,
ServerEvent.ResponseTextDeltaAdded,
ServerEvent.ResponseTextDone,
ServerEvent.ResponseAudioTranscriptDelta,
ServerEvent.ResponseAudioTranscriptDone,
ServerEvent.ResponseAudioDelta,
ServerEvent.ResponseAudioDone,
ServerEvent.ResponseFunctionCallArgumentsDelta,
ServerEvent.ResponseFunctionCallArgumentsDone,
ServerEvent.RateLimitsUpdated,
]
ClientEventType = Literal[
"session.update",
"input_audio_buffer.append",
"input_audio_buffer.commit",
"input_audio_buffer.clear",
"conversation.item.create",
"conversation.item.truncate",
"conversation.item.delete",
"response.create",
"response.cancel",
]
ServerEventType = Literal[
"error",
"session.created",
"session.updated",
"conversation.created",
"input_audio_buffer.committed",
"input_audio_buffer.cleared",
"input_audio_buffer.speech_started",
"input_audio_buffer.speech_stopped",
"conversation.item.created",
"conversation.item.input_audio_transcription.completed",
"conversation.item.input_audio_transcription.failed",
"conversation.item.truncated",
"conversation.item.deleted",
"response.created",
"response.done",
"response.output_item.added",
"response.output_item.done",
"response.content_part.added",
"response.content_part.done",
"response.text.delta",
"response.text.done",
"response.audio_transcript.delta",
"response.audio_transcript.done",
"response.audio.delta",
"response.audio.done",
"response.function_call_arguments.delta",
"response.function_call_arguments.done",
"rate_limits.updated",
]
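# Illustrative sketch (hypothetical helper, not part of the protocol module): the
# TypedDicts above describe the wire format, so client events can be built as plain
# dicts that type checkers validate against the protocol.
def _example_append_event(b64_audio: str) -> ClientEvent.InputAudioBufferAppend:
    return {
        "type": "input_audio_buffer.append",
        "audio": b64_audio,  # base64-encoded pcm16 at SAMPLE_RATE / NUM_CHANNELS
    }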
import logging
logger = logging.getLogger("livekit.plugins.openai.realtime")
from __future__ import annotations
import asyncio
import base64
import os
import time
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import AsyncIterable, Literal, Optional, Union, cast, overload
from urllib.parse import urlencode
import aiohttp
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm.function_context import _create_ai_function_info
from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics
from livekit.agents.types import NOT_GIVEN, NotGivenOr
from typing_extensions import TypedDict
from .._oai_api import build_oai_function_description
from . import api_proto, remote_items
from .log import logger
EventTypes = Literal[
"start_session",
"session_updated",
"error",
"input_speech_started",
"input_speech_stopped",
"input_speech_committed",
"input_speech_transcription_completed",
"input_speech_transcription_failed",
"response_created",
"response_output_added", # message & assistant
"response_content_added", # message type (audio/text)
"response_content_done",
"response_output_done",
"response_done",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]
@dataclass
class InputTranscriptionCompleted:
item_id: str
"""id of the item"""
transcript: str
"""transcript of the input audio"""
@dataclass
class InputTranscriptionFailed:
item_id: str
"""id of the item"""
message: str
"""error message"""
@dataclass
class RealtimeResponse:
id: str
"""id of the message"""
status: api_proto.ResponseStatus
"""status of the response"""
status_details: api_proto.ResponseStatusDetails | None
"""details of the status (only with "incomplete, cancelled and failed")"""
output: list[RealtimeOutput]
"""list of outputs"""
usage: api_proto.Usage | None
"""usage of the response"""
done_fut: asyncio.Future[None]
"""future that will be set when the response is completed"""
_created_timestamp: float
"""timestamp when the response was created"""
_first_token_timestamp: float | None = None
"""timestamp when the first token was received"""
metadata: dict[str, str] | None = None
"""developer-provided string key-value pairs"""
@dataclass
class RealtimeOutput:
response_id: str
"""id of the response"""
item_id: str
"""id of the item"""
output_index: int
"""index of the output"""
role: api_proto.Role
"""role of the message"""
type: Literal["message", "function_call"]
"""type of the output"""
content: list[RealtimeContent]
"""list of content"""
done_fut: asyncio.Future[None]
"""future that will be set when the output is completed"""
@dataclass
class RealtimeToolCall:
name: str
"""name of the function"""
arguments: str
"""accumulated arguments"""
tool_call_id: str
"""id of the tool call"""
@dataclass
class Capabilities:
supports_truncate: bool
input_audio_sample_rate: int | None = None
@dataclass
class RealtimeContent:
response_id: str
"""id of the response"""
item_id: str
"""id of the item"""
output_index: int
"""index of the output"""
content_index: int
"""index of the content"""
text: str
"""accumulated text content"""
audio: list[rtc.AudioFrame]
"""accumulated audio content"""
text_stream: AsyncIterable[str]
"""stream of text content"""
audio_stream: AsyncIterable[rtc.AudioFrame]
"""stream of audio content"""
tool_calls: list[RealtimeToolCall]
"""pending tool calls"""
content_type: api_proto.Modality
"""type of the content"""
@dataclass
class ServerVadOptions:
threshold: float
prefix_padding_ms: int
silence_duration_ms: int
create_response: bool = True
class SemanticVadEagerness(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
AUTO = "auto"
@dataclass
class SemanticVadOptions:
eagerness: SemanticVadEagerness = SemanticVadEagerness.AUTO
create_response: bool = True
interrupt_response: bool = True
@dataclass
class InputTranscriptionOptions:
model: api_proto.InputTranscriptionModel | str
language: str | None = None
prompt: str | None = None
@dataclass
class RealtimeError:
event_id: str
type: str
message: str
code: Optional[str]
param: Optional[str]
@dataclass
class RealtimeSessionOptions:
model: api_proto.OpenAIModel | str
modalities: list[api_proto.Modality]
instructions: str
voice: api_proto.Voice
input_audio_format: api_proto.AudioFormat
output_audio_format: api_proto.AudioFormat
input_audio_transcription: InputTranscriptionOptions | None
turn_detection: Union[ServerVadOptions, SemanticVadOptions, None]
tool_choice: api_proto.ToolChoice
temperature: float
max_response_output_tokens: int | Literal["inf"]
@dataclass
class _ModelOptions(RealtimeSessionOptions):
api_key: str | None
base_url: str
entra_token: str | None
azure_deployment: str | None
is_azure: bool
api_version: str | None
class _ContentPtr(TypedDict):
response_id: str
output_index: int
content_index: int
DEFAULT_SERVER_VAD_OPTIONS = ServerVadOptions(
threshold=0.5,
prefix_padding_ms=300,
silence_duration_ms=500,
create_response=True,
)
DEFAULT_SEMANTIC_VAD_OPTIONS = SemanticVadOptions(
eagerness=SemanticVadEagerness.AUTO,
create_response=True,
interrupt_response=True,
)
DEFAULT_INPUT_AUDIO_TRANSCRIPTION = InputTranscriptionOptions(model="whisper-1")
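# Illustrative sketch (not part of this module): turn detection can be tuned by
# passing a custom options object instead of DEFAULT_SERVER_VAD_OPTIONS, e.g. a
# more sensitive server VAD that waits for less silence before ending a turn.
_EXAMPLE_SENSITIVE_VAD = ServerVadOptions(
    threshold=0.3,            # lower threshold -> speech detected earlier
    prefix_padding_ms=200,
    silence_duration_ms=300,  # shorter silence before the turn is considered done
)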
class RealtimeModel:
@overload
def __init__(
self,
*,
instructions: str = "",
modalities: list[api_proto.Modality] = ["text", "audio"],
model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel,
voice: api_proto.Voice = "alloy",
input_audio_format: api_proto.AudioFormat = "pcm16",
output_audio_format: api_proto.AudioFormat = "pcm16",
input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
turn_detection: Optional[ServerVadOptions] = DEFAULT_SERVER_VAD_OPTIONS,
tool_choice: api_proto.ToolChoice = "auto",
temperature: float = 0.8,
max_response_output_tokens: int | Literal["inf"] = "inf",
api_key: str | None = None,
base_url: str | None = None,
http_session: aiohttp.ClientSession | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None: ...
@overload
def __init__(
self,
*,
azure_deployment: str | None = None,
entra_token: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
base_url: str | None = None,
instructions: str = "",
modalities: list[api_proto.Modality] = ["text", "audio"],
voice: api_proto.Voice = "alloy",
input_audio_format: api_proto.AudioFormat = "pcm16",
output_audio_format: api_proto.AudioFormat = "pcm16",
input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
turn_detection: Optional[ServerVadOptions] = DEFAULT_SERVER_VAD_OPTIONS,
tool_choice: api_proto.ToolChoice = "auto",
temperature: float = 0.8,
max_response_output_tokens: int | Literal["inf"] = "inf",
http_session: aiohttp.ClientSession | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None: ...
def __init__(
self,
*,
instructions: str = "",
modalities: list[api_proto.Modality] = ["text", "audio"],
model: api_proto.OpenAIModel | str = api_proto.DefaultOpenAIModel,
voice: api_proto.Voice = "alloy",
input_audio_format: api_proto.AudioFormat = "pcm16",
output_audio_format: api_proto.AudioFormat = "pcm16",
input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
turn_detection: Optional[ServerVadOptions] = DEFAULT_SERVER_VAD_OPTIONS,
tool_choice: api_proto.ToolChoice = "auto",
temperature: float = 0.8,
max_response_output_tokens: int | Literal["inf"] = "inf",
base_url: str | None = None,
http_session: aiohttp.ClientSession | None = None,
loop: asyncio.AbstractEventLoop | None = None,
# azure specific parameters
azure_deployment: str | None = None,
entra_token: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
) -> None:
"""
Initializes a RealtimeModel instance for interacting with OpenAI's Realtime API.
Args:
instructions (str, optional): Initial system instructions for the model. Defaults to "".
api_key (str or None, optional): OpenAI API key. If None, will attempt to read from the environment variable OPENAI_API_KEY
modalities (list[api_proto.Modality], optional): Modalities to use, such as ["text", "audio"]. Defaults to ["text", "audio"].
model (str or None, optional): The name of the model to use. Defaults to "gpt-4o-realtime-preview".
voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "alloy".
input_audio_format (api_proto.AudioFormat, optional): Format of input audio data. Defaults to "pcm16".
output_audio_format (api_proto.AudioFormat, optional): Format of output audio data. Defaults to "pcm16".
input_audio_transcription (InputTranscriptionOptions, optional): Options for transcribing input audio. Defaults to DEFAULT_INPUT_AUDIO_TRANSCRIPTION.
turn_detection (ServerVadOptions, optional): Options for server-based voice activity detection (VAD). Defaults to DEFAULT_SERVER_VAD_OPTIONS.
tool_choice (api_proto.ToolChoice, optional): Tool choice for the model, such as "auto". Defaults to "auto".
temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
max_response_output_tokens (int or Literal["inf"], optional): Maximum number of tokens in the response. Defaults to "inf".
base_url (str or None, optional): Base URL for the API endpoint. If None, defaults to OpenAI's default API URL.
http_session (aiohttp.ClientSession or None, optional): Async HTTP session to use for requests. If None, a new session will be created.
loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
Raises:
ValueError: If the API key is not provided and cannot be found in environment variables.
"""
super().__init__()
self._capabilities = Capabilities(
supports_truncate=True,
)
self._base_url = base_url
is_azure = (
api_version is not None
or entra_token is not None
or azure_deployment is not None
)
api_key = api_key or os.environ.get("OPENAI_API_KEY")
if api_key is None and not is_azure:
raise ValueError(
"OpenAI API key is required, either using the argument or by setting the OPENAI_API_KEY environmental variable"
)
if not base_url:
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
self._default_opts = _ModelOptions(
model=model,
modalities=modalities,
instructions=instructions,
voice=voice,
input_audio_format=input_audio_format,
output_audio_format=output_audio_format,
input_audio_transcription=input_audio_transcription,
turn_detection=turn_detection,
temperature=temperature,
tool_choice=tool_choice,
max_response_output_tokens=max_response_output_tokens,
api_key=api_key,
base_url=base_url,
azure_deployment=azure_deployment,
entra_token=entra_token,
is_azure=is_azure,
api_version=api_version,
)
self._loop = loop or asyncio.get_event_loop()
self._rt_sessions: list[RealtimeSession] = []
self._http_session = http_session
@classmethod
def with_azure(
cls,
*,
azure_deployment: str,
azure_endpoint: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
entra_token: str | None = None,
base_url: str | None = None,
instructions: str = "",
modalities: list[api_proto.Modality] = ["text", "audio"],
voice: api_proto.Voice = "alloy",
input_audio_format: api_proto.AudioFormat = "pcm16",
output_audio_format: api_proto.AudioFormat = "pcm16",
input_audio_transcription: InputTranscriptionOptions = DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
turn_detection: Optional[ServerVadOptions] = DEFAULT_SERVER_VAD_OPTIONS,
tool_choice: api_proto.ToolChoice = "auto",
temperature: float = 0.8,
max_response_output_tokens: int | Literal["inf"] = "inf",
http_session: aiohttp.ClientSession | None = None,
loop: asyncio.AbstractEventLoop | None = None,
):
"""
Create a RealtimeModel instance configured for Azure OpenAI Service.
Args:
azure_deployment (str): The name of your Azure OpenAI deployment.
azure_endpoint (str or None, optional): The endpoint URL for your Azure OpenAI resource. If None, will attempt to read from the environment variable AZURE_OPENAI_ENDPOINT.
api_version (str or None, optional): API version to use with Azure OpenAI Service. If None, will attempt to read from the environment variable OPENAI_API_VERSION.
api_key (str or None, optional): Azure OpenAI API key. If None, will attempt to read from the environment variable AZURE_OPENAI_API_KEY.
entra_token (str or None, optional): Azure Entra authentication token. Required if not using API key authentication.
base_url (str or None, optional): Base URL for the API endpoint. If None, constructed from the azure_endpoint.
instructions (str, optional): Initial system instructions for the model. Defaults to "".
modalities (list[api_proto.Modality], optional): Modalities to use, such as ["text", "audio"]. Defaults to ["text", "audio"].
voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "alloy".
input_audio_format (api_proto.AudioFormat, optional): Format of input audio data. Defaults to "pcm16".
output_audio_format (api_proto.AudioFormat, optional): Format of output audio data. Defaults to "pcm16".
input_audio_transcription (InputTranscriptionOptions, optional): Options for transcribing input audio. Defaults to DEFAULT_INPUT_AUDIO_TRANSCRIPTION.
turn_detection (ServerVadOptions, optional): Options for server-based voice activity detection (VAD). Defaults to DEFAULT_SERVER_VAD_OPTIONS.
tool_choice (api_proto.ToolChoice, optional): Tool choice for the model, such as "auto". Defaults to "auto".
temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
max_response_output_tokens (int or Literal["inf"], optional): Maximum number of tokens in the response. Defaults to "inf".
http_session (aiohttp.ClientSession or None, optional): Async HTTP session to use for requests. If None, a new session will be created.
loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
Returns:
RealtimeModel: An instance of RealtimeModel configured for Azure OpenAI Service.
Raises:
ValueError: If required Azure parameters are missing or invalid.
"""
api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
if api_key is None and entra_token is None:
raise ValueError(
"Missing credentials. Please pass one of `api_key`, `entra_token`, or the `AZURE_OPENAI_API_KEY` environment variable."
)
api_version = api_version or os.getenv("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)
if base_url is None:
azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"Missing Azure endpoint. Please pass the `azure_endpoint` parameter or set the `AZURE_OPENAI_ENDPOINT` environment variable."
)
base_url = f"{azure_endpoint.rstrip('/')}/openai"
elif azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
return cls(
instructions=instructions,
modalities=modalities,
voice=voice,
input_audio_format=input_audio_format,
output_audio_format=output_audio_format,
input_audio_transcription=input_audio_transcription,
turn_detection=turn_detection,
tool_choice=tool_choice,
temperature=temperature,
max_response_output_tokens=max_response_output_tokens,
api_key=api_key,
http_session=http_session,
loop=loop,
azure_deployment=azure_deployment,
api_version=api_version,
entra_token=entra_token,
base_url=base_url,
)
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._http_session:
self._http_session = utils.http_context.http_session()
return self._http_session
@property
def sessions(self) -> list[RealtimeSession]:
return self._rt_sessions
@property
def capabilities(self) -> Capabilities:
return self._capabilities
def session(
self,
*,
chat_ctx: llm.ChatContext | None = None,
fnc_ctx: llm.FunctionContext | None = None,
modalities: list[api_proto.Modality] | None = None,
instructions: str | None = None,
voice: api_proto.Voice | None = None,
input_audio_format: api_proto.AudioFormat | None = None,
output_audio_format: api_proto.AudioFormat | None = None,
tool_choice: api_proto.ToolChoice | None = None,
input_audio_transcription: NotGivenOr[
InputTranscriptionOptions | None
] = NOT_GIVEN,
turn_detection: NotGivenOr[
Union[ServerVadOptions, SemanticVadOptions, None]
] = NOT_GIVEN,
temperature: float | None = None,
max_response_output_tokens: int | Literal["inf"] | None = None,
) -> RealtimeSession:
opts = deepcopy(self._default_opts)
if modalities is not None:
opts.modalities = modalities
if instructions is not None:
opts.instructions = instructions
if voice is not None:
opts.voice = voice
if input_audio_format is not None:
opts.input_audio_format = input_audio_format
if output_audio_format is not None:
opts.output_audio_format = output_audio_format
if tool_choice is not None:
opts.tool_choice = tool_choice
if utils.is_given(input_audio_transcription):
opts.input_audio_transcription = input_audio_transcription
if utils.is_given(turn_detection):
opts.turn_detection = cast(
Union[ServerVadOptions, SemanticVadOptions, None], turn_detection
)
if temperature is not None:
opts.temperature = temperature
if max_response_output_tokens is not None:
opts.max_response_output_tokens = max_response_output_tokens
new_session = RealtimeSession(
chat_ctx=chat_ctx or llm.ChatContext(),
fnc_ctx=fnc_ctx,
opts=opts,
http_session=self._ensure_session(),
loop=self._loop,
)
self._rt_sessions.append(new_session)
return new_session
async def aclose(self) -> None:
for session in self._rt_sessions:
await session.aclose()
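# Illustrative usage sketch (hypothetical helper, not part of the plugin):
# constructing a RealtimeModel and opening a session. Assumes OPENAI_API_KEY is set
# in the environment and that an http session / event loop is available from the
# agents framework.
async def _example_realtime_usage() -> None:
    model = RealtimeModel(voice="alloy", temperature=0.8)
    session = model.session(
        chat_ctx=llm.ChatContext(),
        turn_detection=DEFAULT_SERVER_VAD_OPTIONS,
    )
    session.session_update(instructions="You are a helpful voice assistant.")
    # ... stream audio via session.input_audio_buffer / session.response ...
    await model.aclose()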
class RealtimeSession(utils.EventEmitter[EventTypes]):
class InputAudioBuffer:
def __init__(self, sess: RealtimeSession) -> None:
self._sess = sess
def append(self, frame: rtc.AudioFrame) -> None:
self._sess._queue_msg(
{
"type": "input_audio_buffer.append",
"audio": base64.b64encode(frame.data).decode("utf-8"),
}
)
def clear(self) -> None:
self._sess._queue_msg({"type": "input_audio_buffer.clear"})
def commit(self) -> None:
self._sess._queue_msg({"type": "input_audio_buffer.commit"})
class ConversationItem:
def __init__(self, sess: RealtimeSession) -> None:
self._sess = sess
def create(
self, message: llm.ChatMessage, previous_item_id: str | None = None
) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()
message_content = message.content
tool_call_id = message.tool_call_id
event: api_proto.ClientEvent.ConversationItemCreate | None = None
if tool_call_id:
if message.role == "tool":
# function_call_output
assert isinstance(message_content, str)
event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "function_call_output",
"call_id": tool_call_id,
"output": message_content,
},
}
else:
# function_call
if not message.tool_calls or message.name is None:
logger.warning(
"function call message has no name or tool calls: %s",
message,
extra=self._sess.logging_extra(),
)
fut.set_result(False)
return fut
if len(message.tool_calls) > 1:
logger.warning(
"function call message has multiple tool calls, "
"only the first one will be used",
extra=self._sess.logging_extra(),
)
event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "function_call",
"call_id": tool_call_id,
"name": message.name,
"arguments": message.tool_calls[0].raw_arguments,
},
}
else:
if message_content is None:
logger.warning(
"message content is None, skipping: %s",
message,
extra=self._sess.logging_extra(),
)
fut.set_result(False)
return fut
if not isinstance(message_content, list):
message_content = [message_content]
if message.role == "user":
user_contents: list[
api_proto.InputTextContent | api_proto.InputAudioContent
] = []
for cnt in message_content:
if isinstance(cnt, str):
user_contents.append(
{
"type": "input_text",
"text": cnt,
}
)
elif isinstance(cnt, llm.ChatAudio):
user_contents.append(
{
"type": "input_audio",
"audio": base64.b64encode(
utils.merge_frames(cnt.frame).data
).decode("utf-8"),
}
)
event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "message",
"role": "user",
"content": user_contents,
},
}
elif message.role == "assistant":
assistant_contents: list[api_proto.TextContent] = []
for cnt in message_content:
if isinstance(cnt, str):
assistant_contents.append(
{
"type": "text",
"text": cnt,
}
)
elif isinstance(cnt, llm.ChatAudio):
logger.warning(
"audio content in assistant message is not supported"
)
event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "message",
"role": "assistant",
"content": assistant_contents,
},
}
elif message.role == "system":
system_contents: list[api_proto.InputTextContent] = []
for cnt in message_content:
if isinstance(cnt, str):
system_contents.append({"type": "input_text", "text": cnt})
elif isinstance(cnt, llm.ChatAudio):
logger.warning(
"audio content in system message is not supported"
)
event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "message",
"role": "system",
"content": system_contents,
},
}
if event is None:
logger.warning(
"chat message is not supported inside the realtime API %s",
message,
extra=self._sess.logging_extra(),
)
fut.set_result(False)
return fut
self._sess._item_created_futs[message.id] = fut
self._sess._queue_msg(event)
return fut
def truncate(
self, *, item_id: str, content_index: int, audio_end_ms: int
) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()
self._sess._item_truncated_futs[item_id] = fut
self._sess._queue_msg(
{
"type": "conversation.item.truncate",
"item_id": item_id,
"content_index": content_index,
"audio_end_ms": audio_end_ms,
}
)
return fut
def delete(self, *, item_id: str) -> asyncio.Future[bool]:
fut = asyncio.Future[bool]()
self._sess._item_deleted_futs[item_id] = fut
self._sess._queue_msg(
{
"type": "conversation.item.delete",
"item_id": item_id,
}
)
return fut
class Conversation:
def __init__(self, sess: RealtimeSession) -> None:
self._sess = sess
@property
def item(self) -> RealtimeSession.ConversationItem:
return RealtimeSession.ConversationItem(self._sess)
class Response:
def __init__(self, sess: RealtimeSession) -> None:
self._sess = sess
def create(
self,
*,
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "keep_both",
instructions: Optional[str] = None,
modalities: Optional[list[api_proto.Modality]] = None,
conversation: Literal["auto", "none"] = "auto",
metadata: Optional[dict[str, str]] = None,
) -> asyncio.Future[bool]:
"""Creates a new response.
Args:
on_duplicate: How to handle when there is an existing response in progress:
- "cancel_existing": Cancel the existing response before creating new one
- "cancel_new": Skip creating new response if one is in progress
- "keep_both": Wait for the existing response to be done and then create a new one
instructions: explicit prompt used for out-of-band events
modalities: set of modalities that the model can respond in, defaults to audio
conversation: specifies whether response is out-of-band
- "auto": Contents of the response will be added to the default conversation
- "none": Creates an out-of-band response which will not add items to default conversation
metadata: set of key-value pairs that can be used for storing additional information
Returns:
Future that resolves when the response create request is queued
"""
if on_duplicate not in ("cancel_existing", "cancel_new", "keep_both"):
raise ValueError(
"invalid on_duplicate value, must be one of: "
"cancel_existing, cancel_new, keep_both"
)
# check if there is a pending response creation request sent
pending_create_fut = self._sess._response_create_fut
if pending_create_fut is not None:
if on_duplicate == "cancel_new":
logger.warning(
"skip new response creation due to previous pending response creation",
extra=self._sess.logging_extra(),
)
_fut = asyncio.Future[bool]()
_fut.set_result(False)
return _fut
active_resp_id = self._sess._active_response_id
_logging_extra = {
"existing_response_id": active_resp_id,
**self._sess.logging_extra(),
}
response_request: api_proto.ClientEvent.ResponseCreateData = {
"conversation": conversation
}
if instructions is not None:
response_request["instructions"] = instructions
if modalities is not None:
response_request["modalities"] = modalities
if metadata is not None:
response_request["metadata"] = metadata
if (
not active_resp_id
or self._sess._pending_responses[active_resp_id].done_fut.done()
):
# no active response in progress, create a new one
self._sess._queue_msg(
{
"type": "response.create",
"response": response_request,
}
)
_fut = asyncio.Future[bool]()
_fut.set_result(True)
return _fut
# there is an active response in progress
if on_duplicate == "cancel_new":
logger.warning(
"skip new response creation due to active response in progress",
extra=_logging_extra,
)
_fut = asyncio.Future[bool]()
_fut.set_result(False)
return _fut
if on_duplicate == "cancel_existing":
self.cancel()
logger.warning(
"cancelling in-progress response to create a new one",
extra=_logging_extra,
)
elif on_duplicate == "keep_both":
logger.warning(
"waiting for in-progress response to be done "
"before creating a new one",
extra=_logging_extra,
)
# create a task that waits for the previous response and then creates a new one
async def wait_and_create() -> bool:
await self._sess._pending_responses[active_resp_id].done_fut
logger.info(
"in-progress response is done, creating a new one",
extra=_logging_extra,
)
new_create_fut = asyncio.Future[None]()
self._sess._response_create_fut = new_create_fut
self._sess._queue_msg(
{
"type": "response.create",
"response": response_request,
}
)
return True
return asyncio.create_task(wait_and_create())
def cancel(self) -> None:
self._sess._queue_msg({"type": "response.cancel"})
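# Usage sketch (illustrative only; `session` is a hypothetical RealtimeSession
# instance): requesting an out-of-band response that does not add items to the
# default conversation, cancelling any response already in flight:
#   fut = session.response.create(
#       on_duplicate="cancel_existing",
#       conversation="none",
#       instructions="Summarize the conversation so far.",
#   )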
def __init__(
self,
*,
opts: _ModelOptions,
http_session: aiohttp.ClientSession,
chat_ctx: llm.ChatContext,
fnc_ctx: llm.FunctionContext | None,
loop: asyncio.AbstractEventLoop,
) -> None:
super().__init__()
self._label = f"{type(self).__module__}.{type(self).__name__}"
self._main_atask = asyncio.create_task(
self._main_task(), name="openai-realtime-session"
)
# manage conversation items internally
self._remote_conversation_items = remote_items._RemoteConversationItems()
# wait for the item to be created or deleted
self._item_created_futs: dict[str, asyncio.Future[bool]] = {}
self._item_deleted_futs: dict[str, asyncio.Future[bool]] = {}
self._item_truncated_futs: dict[str, asyncio.Future[bool]] = {}
self._fnc_ctx = fnc_ctx
self._loop = loop
self._opts = opts
self._send_ch = utils.aio.Chan[api_proto.ClientEvents]()
self._http_session = http_session
self._pending_responses: dict[str, RealtimeResponse] = {}
self._active_response_id: str | None = None
self._response_create_fut: asyncio.Future[None] | None = None
self._playout_complete = asyncio.Event()
self._playout_complete.set()
self._session_id = "not-connected"
self.session_update() # initial session init
# sync the chat context to the session
self._init_sync_task = asyncio.create_task(self.set_chat_ctx(chat_ctx))
self._fnc_tasks = utils.aio.TaskSet()
async def aclose(self) -> None:
if self._send_ch.closed:
return
self._send_ch.close()
await self._main_atask
@property
def playout_complete(self) -> asyncio.Event:
return self._playout_complete
@property
def fnc_ctx(self) -> llm.FunctionContext | None:
return self._fnc_ctx
@fnc_ctx.setter
def fnc_ctx(self, fnc_ctx: llm.FunctionContext | None) -> None:
self._fnc_ctx = fnc_ctx
@property
def conversation(self) -> Conversation:
return RealtimeSession.Conversation(self)
@property
def input_audio_buffer(self) -> InputAudioBuffer:
return RealtimeSession.InputAudioBuffer(self)
def _push_audio(self, frame: rtc.AudioFrame) -> None:
self.input_audio_buffer.append(frame)
@property
def response(self) -> Response:
return RealtimeSession.Response(self)
def session_update(
self,
*,
modalities: list[api_proto.Modality] | None = None,
instructions: str | None = None,
voice: api_proto.Voice | None = None,
input_audio_format: api_proto.AudioFormat | None = None,
output_audio_format: api_proto.AudioFormat | None = None,
input_audio_transcription: NotGivenOr[
InputTranscriptionOptions | None
] = NOT_GIVEN,
turn_detection: NotGivenOr[
Union[ServerVadOptions, SemanticVadOptions, None]
] = NOT_GIVEN,
tool_choice: api_proto.ToolChoice | None = None,
temperature: float | None = None,
max_response_output_tokens: int | Literal["inf"] | None = None,
) -> None:
self._opts = deepcopy(self._opts)
if modalities is not None:
self._opts.modalities = modalities
if instructions is not None:
self._opts.instructions = instructions
if voice is not None:
self._opts.voice = voice
if input_audio_format is not None:
self._opts.input_audio_format = input_audio_format
if output_audio_format is not None:
self._opts.output_audio_format = output_audio_format
if utils.is_given(input_audio_transcription):
self._opts.input_audio_transcription = input_audio_transcription
if utils.is_given(turn_detection):
self._opts.turn_detection = cast(
Union[ServerVadOptions, SemanticVadOptions, None], turn_detection
)
if tool_choice is not None:
self._opts.tool_choice = tool_choice
if temperature is not None:
self._opts.temperature = temperature
if max_response_output_tokens is not None:
self._opts.max_response_output_tokens = max_response_output_tokens
tools = []
if self._fnc_ctx is not None:
for fnc in self._fnc_ctx.ai_functions.values():
# the realtime API uses internally-tagged polymorphism, while
# build_oai_function_description was built for the ChatCompletion API
function_data = build_oai_function_description(fnc)["function"]
function_data["type"] = "function"
tools.append(function_data)
server_vad_opts: Union[api_proto.ServerVad, api_proto.SemanticVad, None] = None
if self._opts.turn_detection is not None:
if isinstance(self._opts.turn_detection, ServerVadOptions):
server_vad_opts = {
"type": "server_vad",
"threshold": self._opts.turn_detection.threshold,
"prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms,
"silence_duration_ms": self._opts.turn_detection.silence_duration_ms,
"create_response": self._opts.turn_detection.create_response,
}
elif isinstance(self._opts.turn_detection, SemanticVadOptions):
server_vad_opts = {
"type": "semantic_vad",
"eagerness": self._opts.turn_detection.eagerness.value,
"create_response": self._opts.turn_detection.create_response,
"interrupt_response": self._opts.turn_detection.interrupt_response,
}
input_audio_transcription_opts: api_proto.InputAudioTranscription | None = None
if self._opts.input_audio_transcription is not None:
input_audio_transcription_opts = {
"model": self._opts.input_audio_transcription.model,
}
if self._opts.input_audio_transcription.language is not None:
input_audio_transcription_opts["language"] = (
self._opts.input_audio_transcription.language
)
if self._opts.input_audio_transcription.prompt is not None:
input_audio_transcription_opts["prompt"] = (
self._opts.input_audio_transcription.prompt
)
session_data: api_proto.ClientEvent.SessionUpdateData = {
"modalities": self._opts.modalities,
"instructions": self._opts.instructions,
"voice": self._opts.voice,
"input_audio_format": self._opts.input_audio_format,
"output_audio_format": self._opts.output_audio_format,
"input_audio_transcription": input_audio_transcription_opts,
"turn_detection": server_vad_opts,
"tools": tools,
"tool_choice": self._opts.tool_choice,
"temperature": self._opts.temperature,
"max_response_output_tokens": None,
}
# azure doesn't support inf for max_response_output_tokens
if not self._opts.is_azure or isinstance(
self._opts.max_response_output_tokens, int
):
session_data["max_response_output_tokens"] = (
self._opts.max_response_output_tokens
)
else:
del session_data["max_response_output_tokens"] # type: ignore
self._queue_msg(
{
"type": "session.update",
"session": session_data,
}
)
def chat_ctx_copy(self) -> llm.ChatContext:
return self._remote_conversation_items.to_chat_context()
async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None:
"""Sync the chat context with the agent's chat context.
Compute the minimum number of insertions and deletions to transform the old
chat context messages to the new chat context messages.
"""
original_ctx = self._remote_conversation_items.to_chat_context()
def _validate_message(msg: llm.ChatMessage) -> bool:
# already exists in the remote conversation items
# or is a function call or has content
return (
self._remote_conversation_items.get(msg.id) is not None
or msg.tool_call_id is not None
or msg.content is not None
)
filtered_messages = list(filter(_validate_message, new_ctx.messages))
changes = utils._compute_changes(
original_ctx.messages, filtered_messages, key_fnc=lambda x: x.id
)
logger.debug(
"sync chat context",
extra={
"to_delete": [msg.id for msg in changes.to_delete],
"to_add": [
(prev.id if prev else None, msg.id) for prev, msg in changes.to_add
],
},
)
# append an empty audio message if all new messages are text
if changes.to_add and not any(
isinstance(msg.content, llm.ChatAudio) for _, msg in changes.to_add
):
# workaround: append an empty user audio message so the API stays in audio mode
changes.to_add.append((None, self._create_empty_user_audio_message(1.0)))
_futs = [
self.conversation.item.delete(item_id=msg.id) for msg in changes.to_delete
] + [
self.conversation.item.create(msg, prev.id if prev else None)
for prev, msg in changes.to_add
]
# wait for all the futures to complete
await asyncio.gather(*_futs)
def cancel_response(self) -> None:
if self._active_response_id:
self.response.cancel()
def create_response(
self,
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "keep_both",
) -> None:
self.response.create(on_duplicate=on_duplicate)
def commit_audio_buffer(self) -> None:
self.input_audio_buffer.commit()
@property
def server_vad_enabled(self) -> bool:
return self._opts.turn_detection is not None
def _create_empty_user_audio_message(self, duration: float) -> llm.ChatMessage:
"""Create an empty audio message with the given duration."""
samples = int(duration * api_proto.SAMPLE_RATE)
return llm.ChatMessage(
role="user",
content=llm.ChatAudio(
frame=rtc.AudioFrame(
data=b"\x00\x00" * (samples * api_proto.NUM_CHANNELS),
sample_rate=api_proto.SAMPLE_RATE,
num_channels=api_proto.NUM_CHANNELS,
samples_per_channel=samples,
)
),
)
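# Worked example (illustration, not from the source): with duration=1.0 s this
# produces 24000 samples of silence at 24 kHz mono, i.e. 24000 * 2 bytes = 48000
# bytes of pcm16 zeros.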
def _recover_from_text_response(self, item_id: str | None = None) -> None:
"""Try to recover from a text response to audio mode.
Sometimes the OpenAI Realtime API returns text instead of audio responses.
This method tries to recover from this by requesting a new response after
deleting the text response and creating an empty user audio message.
"""
if item_id:
# remove the text response if needed
self.conversation.item.delete(item_id=item_id)
self.conversation.item.create(self._create_empty_user_audio_message(1.0))
self.response.create(on_duplicate="keep_both")
def _truncate_conversation_item(
self, item_id: str, content_index: int, audio_end_ms: int
) -> None:
self.conversation.item.truncate(
item_id=item_id,
content_index=content_index,
audio_end_ms=audio_end_ms,
)
def _update_conversation_item_content(
self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None
) -> None:
item = self._remote_conversation_items.get(item_id)
if item is None:
logger.warning(
"conversation item not found, skipping update",
extra={"item_id": item_id, "content": str(content)},
)
return
item.content = content
def _queue_msg(self, msg: api_proto.ClientEvents) -> None:
self._send_ch.send_nowait(msg)
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
try:
headers = {"User-Agent": "LiveKit Agents"}
query_params: dict[str, str] = {}
base_url = self._opts.base_url
if self._opts.is_azure:
if self._opts.entra_token:
headers["Authorization"] = f"Bearer {self._opts.entra_token}"
if self._opts.api_key:
headers["api-key"] = self._opts.api_key
if self._opts.api_version:
query_params["api-version"] = self._opts.api_version
if self._opts.azure_deployment:
query_params["deployment"] = self._opts.azure_deployment
else:
# OAI endpoint
headers["Authorization"] = f"Bearer {self._opts.api_key}"
headers["OpenAI-Beta"] = "realtime=v1"
if self._opts.model:
query_params["model"] = self._opts.model
url = f"{base_url.rstrip('/')}/realtime?{urlencode(query_params)}"
if url.startswith("http"):
url = url.replace("http", "ws", 1)
ws_conn = await self._http_session.ws_connect(
url,
headers=headers,
)
except Exception:
logger.exception("failed to connect to OpenAI API S2S")
return
closing = False
@utils.log_exceptions(logger=logger)
async def _send_task():
nonlocal closing
async for msg in self._send_ch:
await ws_conn.send_json(msg)
closing = True
await ws_conn.close()
@utils.log_exceptions(logger=logger)
async def _recv_task():
while True:
msg = await ws_conn.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing:
return
raise Exception("OpenAI S2S connection closed unexpectedly")
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning(
"unexpected OpenAI S2S message type %s",
msg.type,
extra=self.logging_extra(),
)
continue
try:
data = msg.json()
event: api_proto.ServerEventType = data["type"]
if event == "session.created":
self._handle_session_created(data)
if event == "session.updated":
self._handle_session_updated(data)
elif event == "error":
self._handle_error(data)
elif event == "input_audio_buffer.speech_started":
self._handle_input_audio_buffer_speech_started(data)
elif event == "input_audio_buffer.speech_stopped":
self._handle_input_audio_buffer_speech_stopped(data)
elif event == "input_audio_buffer.committed":
self._handle_input_audio_buffer_speech_committed(data)
elif (
event == "conversation.item.input_audio_transcription.completed"
):
self._handle_conversation_item_input_audio_transcription_completed(
data
)
elif event == "conversation.item.input_audio_transcription.failed":
self._handle_conversation_item_input_audio_transcription_failed(
data
)
elif event == "conversation.item.created":
self._handle_conversation_item_created(data)
elif event == "conversation.item.deleted":
self._handle_conversation_item_deleted(data)
elif event == "conversation.item.truncated":
self._handle_conversation_item_truncated(data)
elif event == "response.created":
self._handle_response_created(data)
elif event == "response.output_item.added":
self._handle_response_output_item_added(data)
elif event == "response.content_part.added":
self._handle_response_content_part_added(data)
elif event == "response.audio.delta":
self._handle_response_audio_delta(data)
elif event == "response.audio_transcript.delta":
self._handle_response_audio_transcript_delta(data)
elif event == "response.audio.done":
self._handle_response_audio_done(data)
elif event == "response.text.done":
self._handle_response_text_done(data)
elif event == "response.audio_transcript.done":
self._handle_response_audio_transcript_done(data)
elif event == "response.content_part.done":
self._handle_response_content_part_done(data)
elif event == "response.output_item.done":
self._handle_response_output_item_done(data)
elif event == "response.done":
self._handle_response_done(data)
except Exception:
logger.exception(
"failed to handle OpenAI S2S message",
extra={"websocket_message": msg, **self.logging_extra()},
)
tasks = [
asyncio.create_task(_send_task(), name="openai-realtime-send"),
asyncio.create_task(_recv_task(), name="openai-realtime-recv"),
]
try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
def _handle_session_created(
self, session_created: api_proto.ServerEvent.SessionCreated
):
self._session_id = session_created["session"]["id"]
def _handle_session_updated(
self, session_updated: api_proto.ServerEvent.SessionUpdated
):
session = session_updated["session"]
session_turn_detection = session["turn_detection"]
turn_detection: Union[ServerVadOptions, SemanticVadOptions, None] = None
if session_turn_detection is not None:
turn_detection_type = session_turn_detection.get("type")
if turn_detection_type == "server_vad":
session_turn_detection_opts = cast(
api_proto.ServerVad, session_turn_detection
)
turn_detection = ServerVadOptions(
threshold=session_turn_detection_opts["threshold"],
prefix_padding_ms=session_turn_detection_opts["prefix_padding_ms"],
silence_duration_ms=session_turn_detection_opts[
"silence_duration_ms"
],
create_response=session_turn_detection_opts["create_response"],
)
elif turn_detection_type == "semantic_vad":
session_turn_detection_semantic = cast(
api_proto.SemanticVad, session_turn_detection
)
eagerness_value = session_turn_detection_semantic.get(
"eagerness", "auto"
)
create_response_value = session_turn_detection_semantic.get(
"create_response", True
)
interrupt_response_value = session_turn_detection_semantic.get(
"interrupt_response", True
)
turn_detection = SemanticVadOptions(
eagerness=SemanticVadEagerness(eagerness_value),
create_response=bool(create_response_value),
interrupt_response=bool(interrupt_response_value),
)
else:
turn_detection = None
if session["input_audio_transcription"] is None:
input_audio_transcription = None
else:
input_audio_transcription = InputTranscriptionOptions(
model=session["input_audio_transcription"]["model"],
language=session["input_audio_transcription"].get("language"),
prompt=session["input_audio_transcription"].get("prompt"),
)
self.emit(
"session_updated",
RealtimeSessionOptions(
model=session["model"],
modalities=session["modalities"],
instructions=session["instructions"],
voice=session["voice"],
input_audio_format=session["input_audio_format"],
output_audio_format=session["output_audio_format"],
input_audio_transcription=input_audio_transcription,
turn_detection=turn_detection,
tool_choice=session["tool_choice"],
temperature=session["temperature"],
max_response_output_tokens=session["max_response_output_tokens"],
),
)
def _handle_error(self, error: api_proto.ServerEvent.Error):
logger.error(
"OpenAI S2S error %s",
error,
extra=self.logging_extra(),
)
error_content = error["error"]
self.emit(
"error",
RealtimeError(
event_id=error["event_id"],
type=error_content["type"],
message=error_content["message"],
code=error_content.get("code"),
param=error_content.get("param"),
),
)
def _handle_input_audio_buffer_speech_started(
self, speech_started: api_proto.ServerEvent.InputAudioBufferSpeechStarted
):
self.emit("input_speech_started")
def _handle_input_audio_buffer_speech_stopped(
self, speech_stopped: api_proto.ServerEvent.InputAudioBufferSpeechStopped
):
self.emit("input_speech_stopped")
def _handle_input_audio_buffer_speech_committed(
self, speech_committed: api_proto.ServerEvent.InputAudioBufferCommitted
):
self.emit("input_speech_committed")
def _handle_conversation_item_input_audio_transcription_completed(
self,
transcription_completed: api_proto.ServerEvent.ConversationItemInputAudioTranscriptionCompleted,
):
transcript = transcription_completed["transcript"]
self.emit(
"input_speech_transcription_completed",
InputTranscriptionCompleted(
item_id=transcription_completed["item_id"],
transcript=transcript,
),
)
def _handle_conversation_item_input_audio_transcription_failed(
self,
transcription_failed: api_proto.ServerEvent.ConversationItemInputAudioTranscriptionFailed,
):
error = transcription_failed["error"]
logger.error(
"OAI S2S failed to transcribe input audio: %s",
error["message"],
extra=self.logging_extra(),
)
self.emit(
"input_speech_transcription_failed",
InputTranscriptionFailed(
item_id=transcription_failed["item_id"],
message=error["message"],
),
)
def _handle_conversation_item_created(
self, item_created: api_proto.ServerEvent.ConversationItemCreated
):
previous_item_id = item_created["previous_item_id"]
item = item_created["item"]
item_type = item["type"]
item_id = item["id"]
# Create message based on item type
# Leave the content empty and fill it in later from the content parts
if item_type == "message":
# Handle message items (system/user/assistant)
item = cast(Union[api_proto.SystemItem, api_proto.UserItem], item)
role = item["role"]
message = llm.ChatMessage(id=item_id, role=role)
if item.get("content"):
content = item["content"][0]
if content["type"] in ("text", "input_text"):
content = cast(api_proto.InputTextContent, content)
message.content = content["text"]
elif content["type"] == "input_audio" and content.get("audio"):
audio_data = base64.b64decode(content["audio"])
message.content = llm.ChatAudio(
frame=rtc.AudioFrame(
data=audio_data,
sample_rate=api_proto.SAMPLE_RATE,
num_channels=api_proto.NUM_CHANNELS,
samples_per_channel=len(audio_data) // 2,
)
)
elif item_type == "function_call":
# Handle function call items
item = cast(api_proto.FunctionCallItem, item)
message = llm.ChatMessage(
id=item_id,
role="assistant",
name=item["name"],
tool_call_id=item["call_id"],
)
elif item_type == "function_call_output":
# Handle function call output items
item = cast(api_proto.FunctionCallOutputItem, item)
message = llm.ChatMessage(
id=item_id,
role="tool",
tool_call_id=item["call_id"],
content=item["output"],
)
else:
logger.error(
f"unknown conversation item type {item_type}",
extra=self.logging_extra(),
)
return
# Insert into conversation items
self._remote_conversation_items.insert_after(previous_item_id, message)
if item_id in self._item_created_futs:
self._item_created_futs[item_id].set_result(True)
del self._item_created_futs[item_id]
logger.debug("conversation item created", extra=item_created)
def _handle_conversation_item_deleted(
self, item_deleted: api_proto.ServerEvent.ConversationItemDeleted
):
# Delete from conversation items
item_id = item_deleted["item_id"]
self._remote_conversation_items.delete(item_id)
if item_id in self._item_deleted_futs:
self._item_deleted_futs[item_id].set_result(True)
del self._item_deleted_futs[item_id]
logger.debug("conversation item deleted", extra=item_deleted)
def _handle_conversation_item_truncated(
self, item_truncated: api_proto.ServerEvent.ConversationItemTruncated
):
item_id = item_truncated["item_id"]
if item_id in self._item_truncated_futs:
self._item_truncated_futs[item_id].set_result(True)
del self._item_truncated_futs[item_id]
def _handle_response_created(
self, response_created: api_proto.ServerEvent.ResponseCreated
):
response = response_created["response"]
done_fut = self._loop.create_future()
status_details = response.get("status_details")
metadata = cast(dict, response.get("metadata"))
new_response = RealtimeResponse(
id=response["id"],
status=response["status"],
status_details=status_details,
output=[],
metadata=metadata,
usage=response.get("usage"),
done_fut=done_fut,
_created_timestamp=time.time(),
)
self._pending_responses[new_response.id] = new_response
self._active_response_id = new_response.id
# complete the create future if it exists
if self._response_create_fut is not None:
self._response_create_fut.set_result(None)
self._response_create_fut = None
self.emit("response_created", new_response)
def _handle_response_output_item_added(
self, response_output_added: api_proto.ServerEvent.ResponseOutputItemAdded
):
response_id = response_output_added["response_id"]
response = self._pending_responses[response_id]
done_fut = self._loop.create_future()
item_data = response_output_added["item"]
item_type: Literal["message", "function_call"] = item_data["type"] # type: ignore
assert item_type in ("message", "function_call")
# function_call doesn't have a role field, defaulting it to assistant
item_role: api_proto.Role = item_data.get("role") or "assistant" # type: ignore
new_output = RealtimeOutput(
response_id=response_id,
item_id=item_data["id"],
output_index=response_output_added["output_index"],
type=item_type,
role=item_role,
content=[],
done_fut=done_fut,
)
response.output.append(new_output)
self.emit("response_output_added", new_output)
def _handle_response_content_part_added(
self, response_content_added: api_proto.ServerEvent.ResponseContentPartAdded
):
response_id = response_content_added["response_id"]
response = self._pending_responses[response_id]
output_index = response_content_added["output_index"]
output = response.output[output_index]
content_type = response_content_added["part"]["type"]
text_ch = utils.aio.Chan[str]()
audio_ch = utils.aio.Chan[rtc.AudioFrame]()
new_content = RealtimeContent(
response_id=response_id,
item_id=response_content_added["item_id"],
output_index=output_index,
content_index=response_content_added["content_index"],
text="",
audio=[],
text_stream=text_ch,
audio_stream=audio_ch,
tool_calls=[],
content_type=content_type,
)
output.content.append(new_content)
response._first_token_timestamp = time.time()
self.emit("response_content_added", new_content)
def _handle_response_audio_delta(
self, response_audio_delta: api_proto.ServerEvent.ResponseAudioDelta
):
content = self._get_content(response_audio_delta)
data = base64.b64decode(response_audio_delta["delta"])
audio = rtc.AudioFrame(
data=data,
sample_rate=api_proto.SAMPLE_RATE,
num_channels=api_proto.NUM_CHANNELS,
samples_per_channel=len(data) // 2,
)
content.audio.append(audio)
assert isinstance(content.audio_stream, utils.aio.Chan)
content.audio_stream.send_nowait(audio)
def _handle_response_audio_transcript_delta(
self,
response_audio_transcript_delta: api_proto.ServerEvent.ResponseAudioTranscriptDelta,
):
content = self._get_content(response_audio_transcript_delta)
transcript = response_audio_transcript_delta["delta"]
content.text += transcript
assert isinstance(content.text_stream, utils.aio.Chan)
content.text_stream.send_nowait(transcript)
def _handle_response_audio_done(
self, response_audio_done: api_proto.ServerEvent.ResponseAudioDone
):
content = self._get_content(response_audio_done)
assert isinstance(content.audio_stream, utils.aio.Chan)
content.audio_stream.close()
def _handle_response_text_done(
self, response_text_done: api_proto.ServerEvent.ResponseTextDone
):
content = self._get_content(response_text_done)
content.text = response_text_done["text"]
def _handle_response_audio_transcript_done(
self,
response_audio_transcript_done: api_proto.ServerEvent.ResponseAudioTranscriptDone,
):
content = self._get_content(response_audio_transcript_done)
assert isinstance(content.text_stream, utils.aio.Chan)
content.text_stream.close()
def _handle_response_content_part_done(
self, response_content_done: api_proto.ServerEvent.ResponseContentPartDone
):
content = self._get_content(response_content_done)
self.emit("response_content_done", content)
def _handle_response_output_item_done(
self, response_output_done: api_proto.ServerEvent.ResponseOutputItemDone
):
response_id = response_output_done["response_id"]
response = self._pending_responses[response_id]
output_index = response_output_done["output_index"]
output = response.output[output_index]
if output.type == "function_call":
if self._fnc_ctx is None:
logger.error(
"function call received but no fnc_ctx is available",
extra=self.logging_extra(),
)
return
# parse the arguments and call the function inside the fnc_ctx
item = response_output_done["item"]
assert item["type"] == "function_call"
fnc_call_info = _create_ai_function_info(
self._fnc_ctx,
item["call_id"],
item["name"],
item["arguments"],
)
msg = self._remote_conversation_items.get(output.item_id)
if msg is not None:
# update the content of the message
assert msg.tool_call_id == item["call_id"]
assert msg.role == "assistant"
msg.name = item["name"]
msg.tool_calls = [fnc_call_info]
self.emit("function_calls_collected", [fnc_call_info])
self._fnc_tasks.create_task(
self._run_fnc_task(fnc_call_info, output.item_id)
)
output.done_fut.set_result(None)
self.emit("response_output_done", output)
def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDone):
response_data = response_done["response"]
response_id = response_data["id"]
response = self._pending_responses[response_id]
self._active_response_id = None
response.done_fut.set_result(None)
response.status = response_data["status"]
response.status_details = response_data.get("status_details")
response.metadata = cast(dict, response_data.get("metadata"))
response.output = cast(list[RealtimeOutput], response_data.get("output"))
response.usage = response_data.get("usage")
metrics_error = None
cancelled = False
if response.status == "failed":
assert response.status_details is not None
error = response.status_details.get("error", {})
code: str | None = error.get("code") # type: ignore
message: str | None = error.get("message") # type: ignore
metrics_error = MultimodalLLMError(
type=response.status_details.get("type"),
code=code,
message=message,
)
logger.error(
"response generation failed",
extra={"code": code, "error": message, **self.logging_extra()},
)
elif response.status == "incomplete":
assert response.status_details is not None
reason = response.status_details.get("reason")
metrics_error = MultimodalLLMError(
type=response.status_details.get("type"),
reason=reason, # type: ignore
)
logger.warning(
"response generation incomplete",
extra={"reason": reason, **self.logging_extra()},
)
elif response.status == "cancelled":
cancelled = True
self.emit("response_done", response)
# calculate metrics
ttft = -1.0
if response._first_token_timestamp is not None:
ttft = response._first_token_timestamp - response._created_timestamp
duration = time.time() - response._created_timestamp
usage = response.usage or {} # type: ignore
input_token_details = usage.get("input_token_details", {})
metrics = MultimodalLLMMetrics(
timestamp=response._created_timestamp,
request_id=response.id,
ttft=ttft,
duration=duration,
cancelled=cancelled,
label=self._label,
completion_tokens=usage.get("output_tokens", 0),
prompt_tokens=usage.get("input_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
tokens_per_second=usage.get("output_tokens", 0) / duration,
error=metrics_error,
input_token_details=MultimodalLLMMetrics.InputTokenDetails(
cached_tokens=input_token_details.get("cached_tokens", 0),
text_tokens=usage.get("input_token_details", {}).get("text_tokens", 0),
audio_tokens=usage.get("input_token_details", {}).get(
"audio_tokens", 0
),
cached_tokens_details=MultimodalLLMMetrics.CachedTokenDetails(
text_tokens=input_token_details.get(
"cached_tokens_details", {}
).get("text_tokens", 0),
audio_tokens=input_token_details.get(
"cached_tokens_details", {}
).get("audio_tokens", 0),
),
),
output_token_details=MultimodalLLMMetrics.OutputTokenDetails(
text_tokens=usage.get("output_token_details", {}).get("text_tokens", 0),
audio_tokens=usage.get("output_token_details", {}).get(
"audio_tokens", 0
),
),
)
self.emit("metrics_collected", metrics)
def _get_content(self, ptr: _ContentPtr) -> RealtimeContent:
response = self._pending_responses[ptr["response_id"]]
output = response.output[ptr["output_index"]]
content = output.content[ptr["content_index"]]
return content
@utils.log_exceptions(logger=logger)
async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
logger.debug(
"executing ai function",
extra={
"function": fnc_call_info.function_info.name,
},
)
called_fnc = fnc_call_info.execute()
await called_fnc.task
# wait for the audio to be played before creating the response
await self._playout_complete.wait()
tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
logger.info(
"creating response for tool call",
extra={
"function": fnc_call_info.function_info.name,
},
)
if tool_call.content is not None:
create_fut = self.conversation.item.create(
tool_call,
previous_item_id=item_id,
)
await self.response.create(on_duplicate="keep_both")
await create_fut
# update the message with the tool call result
msg = self._remote_conversation_items.get(tool_call.id)
if msg is not None:
assert msg.tool_call_id == tool_call.tool_call_id
assert msg.role == "tool"
msg.name = tool_call.name
msg.content = tool_call.content
msg.tool_exception = tool_call.tool_exception
self.emit("function_calls_finished", [called_fnc])
def logging_extra(self) -> dict:
return {"session_id": self._session_id}
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional
from livekit.agents import llm
from .log import logger
@dataclass
class _ConversationItem:
"""A node in the conversation linked list"""
message: llm.ChatMessage
_prev: Optional[_ConversationItem] = field(default=None, repr=False)
_next: Optional[_ConversationItem] = field(default=None, repr=False)
class _RemoteConversationItems:
"""Manages conversation items in a doubly-linked list"""
def __init__(self) -> None:
self._head: Optional[_ConversationItem] = None
self._tail: Optional[_ConversationItem] = None
self._id_to_item: OrderedDict[str, _ConversationItem] = OrderedDict()
@classmethod
def from_chat_context(cls, chat_ctx: llm.ChatContext) -> _RemoteConversationItems:
"""Create ConversationItems from a ChatContext"""
items = cls()
for msg in chat_ctx.messages:
items.append(msg)
return items
def to_chat_context(self) -> llm.ChatContext:
"""Export to a ChatContext"""
chat_ctx = llm.ChatContext()
current = self._head
while current:
chat_ctx.messages.append(current.message.copy())
current = current._next
return chat_ctx
def append(self, message: llm.ChatMessage) -> None:
"""Add a message to the end of the conversation"""
if message.id is None:
raise ValueError("Message must have an id")
if message.id in self._id_to_item:
raise ValueError(f"Message with id {message.id} already exists")
item = _ConversationItem(message=message)
item._prev = self._tail
item._next = None
if self._tail:
self._tail._next = item
self._tail = item
if not self._head:
self._head = item
self._id_to_item[message.id] = item
def insert_after(self, prev_item_id: str | None, message: llm.ChatMessage) -> None:
"""Insert a message after the specified message ID.
If prev_item_id is None, append to the end."""
if message.id is None:
raise ValueError("Message must have an id")
if message.id in self._id_to_item:
raise ValueError(f"Message with id {message.id} already exists")
if prev_item_id is None:
# Append to end instead of inserting at head
self.append(message)
return
prev_item = self._id_to_item.get(prev_item_id)
if not prev_item:
logger.error(
f"Previous message with id {prev_item_id} not found, ignore it"
)
return
new_item = _ConversationItem(message=message)
new_item._prev = prev_item
new_item._next = prev_item._next
prev_item._next = new_item
if new_item._next:
new_item._next._prev = new_item
else:
self._tail = new_item
self._id_to_item[message.id] = new_item
def delete(self, item_id: str) -> None:
"""Delete a message by its ID"""
item = self._id_to_item.get(item_id)
if not item:
logger.error(f"Message with id {item_id} not found for deletion")
return
if item._prev:
item._prev._next = item._next
else:
self._head = item._next
if item._next:
item._next._prev = item._prev
else:
self._tail = item._prev
del self._id_to_item[item_id]
def get(self, item_id: str) -> llm.ChatMessage | None:
"""Get a message by its ID"""
item = self._id_to_item.get(item_id)
return item.message if item else None
@property
def messages(self) -> list[llm.ChatMessage]:
"""Return all messages in order"""
return [item.message for item in self._id_to_item.values()]
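# Illustrative sketch (not part of the original module) of how the realtime
# session mirrors server-side conversation items with this store; the ids
# below are made up for the example.
#
#   items = _RemoteConversationItems()
#   items.append(llm.ChatMessage(id="item_1", role="user", content="hello"))
#   items.insert_after("item_1", llm.ChatMessage(id="item_2", role="assistant", content="hi!"))
#   items.delete("item_1")
#   assert [m.id for m in items.messages] == ["item_2"]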
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import base64
import dataclasses
import json
import os
import time
import weakref
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlencode
import aiohttp
import httpx
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
stt,
utils,
)
from livekit.agents.utils import AudioBuffer
import openai
from openai.types.audio import TranscriptionVerbose
from openai.types.beta.realtime.transcription_session_update_param import (
SessionTurnDetection,
)
from .log import logger
from .models import GroqAudioModels, STTModels
# The OpenAI Realtime API times out sessions after 15 minutes; we restart the
# session before that timeout is reached
_max_session_duration = 10 * 60
# emit interim transcriptions every 0.5 seconds
_delta_transcript_interval = 0.5
SAMPLE_RATE = 24000
NUM_CHANNELS = 1
@dataclass
class _STTOptions:
model: STTModels | str
language: str
detect_language: bool
turn_detection: SessionTurnDetection
prompt: str | None = None
noise_reduction_type: str | None = None
class STT(stt.STT):
def __init__(
self,
*,
language: str = "en",
detect_language: bool = False,
model: STTModels | str = "gpt-4o-transcribe",
prompt: str | None = None,
turn_detection: SessionTurnDetection | None = None,
noise_reduction_type: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
client: openai.AsyncClient | None = None,
use_realtime: bool = False,
):
"""
Create a new instance of OpenAI STT.
Args:
language: The language code to use for transcription (e.g., "en" for English).
detect_language: Whether to automatically detect the language.
model: The OpenAI model to use for transcription.
prompt: Optional text prompt to guide the transcription. Only supported for whisper-1.
turn_detection: When using realtime transcription, this controls how the model detects that the user is done speaking.
Final transcripts are generated only after the turn is over. See: https://platform.openai.com/docs/guides/realtime-vad
noise_reduction_type: Type of noise reduction to apply ("near_field" or "far_field").
This isn't needed when using LiveKit's noise cancellation.
base_url: Custom base URL for OpenAI API.
api_key: Your OpenAI API key. If not provided, will use the OPENAI_API_KEY environment variable.
client: Optional pre-configured OpenAI AsyncClient instance.
use_realtime: Whether to use the realtime transcription API. (default: False)
"""
super().__init__(
capabilities=stt.STTCapabilities(
streaming=use_realtime, interim_results=use_realtime
)
)
if detect_language:
language = ""
if turn_detection is None:
turn_detection = {
"type": "server_vad",
"threshold": 0.5,
"prefix_padding_ms": 600,
"silence_duration_ms": 350,
}
self._opts = _STTOptions(
language=language,
detect_language=detect_language,
model=model,
prompt=prompt,
turn_detection=turn_detection,
)
if noise_reduction_type is not None:
self._opts.noise_reduction_type = noise_reduction_type
self._client = client or openai.AsyncClient(
max_retries=0,
api_key=api_key,
base_url=base_url,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
max_keepalive_connections=50,
keepalive_expiry=120,
),
),
)
self._streams = weakref.WeakSet[SpeechStream]()
self._session: aiohttp.ClientSession | None = None
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
max_session_duration=_max_session_duration,
connect_cb=self._connect_ws,
close_cb=self._close_ws,
)
@staticmethod
def with_groq(
*,
model: GroqAudioModels | str = "whisper-large-v3-turbo",
api_key: str | None = None,
base_url: str | None = "https://api.groq.com/openai/v1",
client: openai.AsyncClient | None = None,
language: str = "en",
prompt: str | None = None,
detect_language: bool = False,
) -> STT:
"""
Create a new instance of Groq STT.
``api_key`` must be set to your Groq API key, either using the argument or by setting
the ``GROQ_API_KEY`` environment variable.
"""
api_key = api_key or os.environ.get("GROQ_API_KEY")
if api_key is None:
raise ValueError("Groq API key is required")
return STT(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
language=language,
detect_language=detect_language,
prompt=prompt,
use_realtime=False,
)
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = self._sanitize_options(language=language)
stream = SpeechStream(
stt=self,
opts=config,
pool=self._pool,
)
self._streams.add(stream)
return stream
def update_options(
self,
*,
model: STTModels | GroqAudioModels | str | None = None,
language: str | None = None,
prompt: str | None = None,
turn_detection: SessionTurnDetection | None = None,
noise_reduction_type: str | None = None,
) -> None:
self._opts.model = model or self._opts.model
self._opts.language = language or self._opts.language
self._opts.prompt = prompt or self._opts.prompt
self._opts.noise_reduction_type = (
noise_reduction_type or self._opts.noise_reduction_type
)
self._opts.turn_detection = turn_detection or self._opts.turn_detection
for stream in self._streams:
stream.update_options(language=language or self._opts.language)
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
realtime_config: dict[str, Any] = {
"type": "transcription_session.update",
"session": {
"input_audio_format": "pcm16",
"input_audio_transcription": {
"model": self._opts.model,
"prompt": self._opts.prompt or "",
},
"turn_detection": self._opts.turn_detection,
},
}
if self._opts.language:
realtime_config["session"]["input_audio_transcription"]["language"] = (
self._opts.language
)
if self._opts.noise_reduction_type:
realtime_config["session"]["input_audio_noise_reduction"] = {
"type": self._opts.noise_reduction_type
}
query_params: dict[str, str] = {
"intent": "transcription",
}
headers = {
"User-Agent": "LiveKit Agents",
"Authorization": f"Bearer {self._client.api_key}",
"OpenAI-Beta": "realtime=v1",
}
url = f"{str(self._client.base_url).rstrip('/')}/realtime?{urlencode(query_params)}"
if url.startswith("http"):
url = url.replace("http", "ws", 1)
session = self._ensure_session()
ws = await asyncio.wait_for(
session.ws_connect(url, headers=headers),
DEFAULT_API_CONNECT_OPTIONS.timeout,
)
await ws.send_json(realtime_config)
return ws
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
await ws.close()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def _sanitize_options(self, *, language: str | None = None) -> _STTOptions:
config = dataclasses.replace(self._opts)
config.language = language or config.language
return config
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
try:
config = self._sanitize_options(language=language)
data = rtc.combine_audio_frames(buffer).to_wav_bytes()
prompt = (
self._opts.prompt if self._opts.prompt is not None else openai.NOT_GIVEN
)
format = "json"
if self._opts.model == "whisper-1":
# verbose_json returns language and other details, only supported for whisper-1
format = "verbose_json"
resp = await self._client.audio.transcriptions.create(
file=(
"file.wav",
data,
"audio/wav",
),
model=self._opts.model, # type: ignore
language=config.language,
prompt=prompt,
response_format=format,
timeout=httpx.Timeout(30, connect=conn_options.timeout),
)
sd = stt.SpeechData(text=resp.text, language=config.language)
if isinstance(resp, TranscriptionVerbose) and resp.language:
sd.language = resp.language
return stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[sd],
)
except openai.APITimeoutError:
raise APITimeoutError()
except openai.APIStatusError as e:
raise APIStatusError(
e.message,
status_code=e.status_code,
request_id=e.request_id,
body=e.body,
)
except Exception as e:
raise APIConnectionError() from e
class SpeechStream(stt.SpeechStream):
def __init__(
self,
*,
stt: STT,
opts: _STTOptions,
pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
) -> None:
super().__init__(
stt=stt, conn_options=DEFAULT_API_CONNECT_OPTIONS, sample_rate=SAMPLE_RATE
)
self._pool = pool
self._opts = opts
self._request_id = ""
self._reconnect_event = asyncio.Event()
def update_options(
self,
*,
language: str | None = None,
):
"""
Update the options for the speech stream. Most options are updated at the
connection level. SpeechStreams will be recreated when options are updated.
Args:
language: The language to transcribe in.
"""
self._opts.language = language or self._opts.language
self._reconnect_event.set()
@utils.log_exceptions(logger=logger)
async def _run(self) -> None:
closing_ws = False
@utils.log_exceptions(logger=logger)
async def send_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
# forward audio to OAI in chunks of 50ms
audio_bstream = utils.audio.AudioByteStream(
sample_rate=SAMPLE_RATE,
num_channels=NUM_CHANNELS,
samples_per_channel=SAMPLE_RATE // 20,
)
async for data in self._input_ch:
frames: list[rtc.AudioFrame] = []
if isinstance(data, rtc.AudioFrame):
frames.extend(audio_bstream.write(data.data.tobytes()))
elif isinstance(data, self._FlushSentinel):
frames.extend(audio_bstream.flush())
for frame in frames:
encoded_frame = {
"type": "input_audio_buffer.append",
"audio": base64.b64encode(frame.data.tobytes()).decode("utf-8"),
}
await ws.send_json(encoded_frame)
closing_ws = True
@utils.log_exceptions(logger=logger)
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
current_text = ""
last_interim_at: float = 0
connected_at = time.time()
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws: # close is expected, see SpeechStream.aclose
return
# this will trigger a reconnection, see the _run loop
raise APIStatusError(
message="OpenAI Realtime STT connection closed unexpectedly"
)
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected OpenAI message type %s", msg.type)
continue
try:
data = json.loads(msg.data)
msg_type = data.get("type")
if msg_type == "conversation.item.input_audio_transcription.delta":
delta = data.get("delta", "")
if delta:
current_text += delta
if (
time.time() - last_interim_at
> _delta_transcript_interval
):
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[
stt.SpeechData(
text=current_text,
language=self._opts.language,
)
],
)
)
last_interim_at = time.time()
elif (
msg_type
== "conversation.item.input_audio_transcription.completed"
):
current_text = ""
transcript = data.get("transcript", "")
if transcript:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
stt.SpeechData(
text=transcript,
language=self._opts.language,
)
],
)
)
# restart session if needed
if time.time() - connected_at > _max_session_duration:
logger.info("resetting Realtime STT session due to timeout")
self._pool.remove(ws)
self._reconnect_event.set()
return
except Exception:
logger.exception("failed to process OpenAI message")
while True:
async with self._pool.connection() as ws:
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
]
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
done, _ = await asyncio.wait(
[asyncio.gather(*tasks), wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
) # type: ignore
# propagate exceptions from completed tasks
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
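# Illustrative sketch (not part of the original module): driving the realtime
# transcription path, assuming audio frames arrive from a LiveKit track elsewhere.
#
#   stt_impl = STT(use_realtime=True, language="en")
#   stream = stt_impl.stream()
#   # push each incoming rtc.AudioFrame: stream.push_frame(frame)
#   # async for event in stream:
#   #     if event.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
#   #         print(event.alternatives[0].text)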
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Optional
import httpx
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
import openai
from .log import logger
from .models import TTSModels, TTSVoices
from .utils import AsyncAzureADTokenProvider
OPENAI_TTS_SAMPLE_RATE = 48000
OPENAI_TTS_CHANNELS = 1
DEFAULT_MODEL = "gpt-4o-mini-tts"
DEFAULT_VOICE = "ash"
@dataclass
class _TTSOptions:
model: TTSModels | str
voice: TTSVoices | str
speed: float
instructions: Optional[str] = None
class TTS(tts.TTS):
def __init__(
self,
*,
model: TTSModels | str = DEFAULT_MODEL,
voice: TTSVoices | str = DEFAULT_VOICE,
speed: float = 1.0,
instructions: Optional[str] = None,
base_url: str | None = None,
api_key: str | None = None,
client: openai.AsyncClient | None = None,
) -> None:
"""
Create a new instance of OpenAI TTS.
``api_key`` must be set to your OpenAI API key, either using the argument or by setting the
``OPENAI_API_KEY`` environment variable.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=OPENAI_TTS_SAMPLE_RATE,
num_channels=OPENAI_TTS_CHANNELS,
)
self._opts = _TTSOptions(
model=model,
voice=voice,
speed=speed,
instructions=instructions,
)
self._client = client or openai.AsyncClient(
max_retries=0,
api_key=api_key,
base_url=base_url,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
max_keepalive_connections=50,
keepalive_expiry=120,
),
),
)
def update_options(
self,
*,
model: TTSModels | str | None,
voice: TTSVoices | str | None,
speed: float | None,
instructions: Optional[str] = None,
) -> None:
self._opts.model = model or self._opts.model
self._opts.voice = voice or self._opts.voice
self._opts.speed = speed or self._opts.speed
self._opts.instructions = instructions or self._opts.instructions
@staticmethod
def create_azure_client(
*,
model: TTSModels | str = DEFAULT_MODEL,
voice: TTSVoices | str = DEFAULT_VOICE,
speed: float = 1.0,
instructions: str | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
) -> TTS:
"""
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
"""
azure_client = openai.AsyncAzureOpenAI(
max_retries=0,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
project=project,
base_url=base_url,
) # type: ignore
return TTS(
model=model,
voice=voice,
speed=speed,
instructions=instructions,
client=azure_client,
)
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
client=self._client,
)
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
input_text: str,
conn_options: Optional[APIConnectOptions] = None,
opts: _TTSOptions,
client: openai.AsyncClient,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._client = client
self._opts = opts
async def _run(self):
oai_stream = self._client.audio.speech.with_streaming_response.create(
input=self.input_text,
model=self._opts.model,
voice=self._opts.voice,
response_format="opus",
speed=self._opts.speed,
instructions=self._opts.instructions
if self._opts.instructions
else openai.NOT_GIVEN,
timeout=httpx.Timeout(30, connect=self._conn_options.timeout),
)
request_id = utils.shortuuid()
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=OPENAI_TTS_SAMPLE_RATE,
num_channels=OPENAI_TTS_CHANNELS,
)
@utils.log_exceptions(logger=logger)
async def _decode_loop():
try:
async with oai_stream as stream:
async for data in stream.iter_bytes():
decoder.push(data)
finally:
decoder.end_input()
decode_task = asyncio.create_task(_decode_loop())
try:
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except openai.APITimeoutError:
raise APITimeoutError()
except openai.APIStatusError as e:
raise APIStatusError(
e.message,
status_code=e.status_code,
request_id=e.request_id,
body=e.body,
)
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
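# Illustrative sketch (not part of the original module): one-shot synthesis with
# the chunked stream above; the emitted frames are 48 kHz mono rtc.AudioFrames.
#
#   tts_impl = TTS()  # gpt-4o-mini-tts with the "ash" voice by default
#   async for ev in tts_impl.synthesize("Hello from LiveKit"):
#       ...  # ev.frame holds the decoded audio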
from __future__ import annotations
import base64
import json
import os
from typing import Any, Awaitable, Callable, Optional, Union
from livekit import rtc
from livekit.agents import llm, utils
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
def get_base_url(base_url: Optional[str]) -> str:
if not base_url:
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
return base_url
def build_oai_message(msg: llm.ChatMessage, cache_key: Any):
oai_msg: dict[str, Any] = {"role": msg.role}
if msg.name:
oai_msg["name"] = msg.name
# add content if provided
if isinstance(msg.content, str):
oai_msg["content"] = msg.content
elif isinstance(msg.content, dict):
oai_msg["content"] = json.dumps(msg.content)
elif isinstance(msg.content, list):
oai_content: list[dict[str, Any]] = []
for cnt in msg.content:
if isinstance(cnt, str):
oai_content.append({"type": "text", "text": cnt})
elif isinstance(cnt, llm.ChatImage):
oai_content.append(_build_oai_image_content(cnt, cache_key))
oai_msg["content"] = oai_content
# include tool_calls (with their raw_arguments) when functions were called
# earlier in this context
if msg.tool_calls is not None:
tool_calls: list[dict[str, Any]] = []
oai_msg["tool_calls"] = tool_calls
for fnc in msg.tool_calls:
tool_calls.append(
{
"id": fnc.tool_call_id,
"type": "function",
"function": {
"name": fnc.function_info.name,
"arguments": fnc.raw_arguments,
},
}
)
# tool_call_id is set when the message is a response/result to a function call
# (content is a string in this case)
if msg.tool_call_id:
oai_msg["tool_call_id"] = msg.tool_call_id
return oai_msg
def _build_oai_image_content(image: llm.ChatImage, cache_key: Any):
if isinstance(image.image, str): # image url
return {
"type": "image_url",
"image_url": {"url": image.image, "detail": image.inference_detail},
}
elif isinstance(image.image, rtc.VideoFrame): # VideoFrame
if cache_key not in image._cache:
# our internal implementation allows attaching extra metadata to each
# ChatImage, so we avoid re-encoding it on every chat completion request
opts = utils.images.EncodeOptions()
if image.inference_width and image.inference_height:
opts.resize_options = utils.images.ResizeOptions(
width=image.inference_width,
height=image.inference_height,
strategy="scale_aspect_fit",
)
encoded_data = utils.images.encode(image.image, opts)
image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image._cache[cache_key]}",
"detail": image.inference_detail,
},
}
raise ValueError(
"LiveKit OpenAI Plugin: ChatImage must be an rtc.VideoFrame or a URL"
)
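# Illustrative sketch (not part of the original module): the rough shape that
# build_oai_message produces for a plain text user message.
#
#   build_oai_message(llm.ChatMessage(role="user", content="hi"), cache_key=object())
#   # -> {"role": "user", "content": "hi"}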
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.12.3"
{
"name": "livekit-plugins-openai",
"private": true,
"version": "0.12.3"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "openai", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-openai",
version=about["__version__"],
description="Agent Framework plugin for services from OpenAI",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents[codecs, images]>=0.12.16,<1.0.0",
"openai>=1.68.2",
],
extras_require={
"vertex": ["google-auth>=2.0.0"],
},
package_data={"livekit.plugins.openai": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-playht
## 1.0.9
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 1.0.8
### Patch Changes
- remove update options from tts synthesis stream - [#1546](https://github.com/livekit/agents/pull/1546) ([@jayeshp19](https://github.com/jayeshp19))
## 1.0.7
### Patch Changes
- PlayAI plugin: bump Python SDK version (fix websockets interrupt handling) - [#1427](https://github.com/livekit/agents/pull/1427) ([@bryananderson](https://github.com/bryananderson))
- improved TTFB metrics for streaming TTS - [#1431](https://github.com/livekit/agents/pull/1431) ([@davidzhao](https://github.com/davidzhao))
## 1.0.6
### Patch Changes
- fix: Avoid websocket reconnections for each request - [#1387](https://github.com/livekit/agents/pull/1387) ([@jayeshp19](https://github.com/jayeshp19))
## 1.0.5
### Patch Changes
- playai: enable streaming TTS - [#1340](https://github.com/livekit/agents/pull/1340) ([@davidzhao](https://github.com/davidzhao))
## 1.0.4
### Patch Changes
- Support PlayAI TTS engine. - [#1174](https://github.com/livekit/agents/pull/1174) ([@jayeshp19](https://github.com/jayeshp19))
## 1.0.3
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 1.0.2
### Patch Changes
- fix(playht): add sample_rate parameter to JSON payload - [#1141](https://github.com/livekit/agents/pull/1141) ([@imsakg](https://github.com/imsakg))
- feat: tts retry & tts.FallbackAdapter - [#1074](https://github.com/livekit/agents/pull/1074) ([@theomonnom](https://github.com/theomonnom))
- feat(playht): add Play3.0-mini engine support - [#1140](https://github.com/livekit/agents/pull/1140) ([@imsakg](https://github.com/imsakg))
## 1.0.1
### Patch Changes
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
- expose usage metrics - [#984](https://github.com/livekit/agents/pull/984) ([@theomonnom](https://github.com/theomonnom))
# LiveKit Plugins PlayAI/PlayHT
Agent Framework plugin for voice synthesis with the [PlayAI](https://play.ai/) API.
## Installation
```bash
pip install livekit-plugins-playai
```
You'll need a PlayHT User ID and API secret key, available from your PlayHT account. Set them as the environment variables `PLAYHT_USER_ID` and `PLAYHT_API_KEY`.
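A minimal usage sketch (assuming the credentials above are set in your environment):

```python
from livekit.plugins import playai

tts = playai.TTS(model="Play3.0-mini")
stream = tts.stream()
stream.push_text("Hello from LiveKit!")
stream.flush()
```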
## livekit-plugins/livekit-plugins-playai/livekit/plugins/playai/__init__.py
from .tts import TTS
from .version import __version__
__all__ = [
"TTS",
"__version__",
]
from livekit.agents import Plugin
class PlayAIPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(PlayAIPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.playai")
# suppress verbose websocket logs
logging.getLogger("websockets.client").setLevel(logging.INFO)
from typing import Literal
from pyht.client import Format # type: ignore
TTSModel = Literal["Play3.0-mini", "PlayDialog"]
FORMAT = Literal["mp3"]
format_mapping = {
"mp3": Format.FORMAT_MP3,
}
from __future__ import annotations
import asyncio
import os
import weakref
from dataclasses import dataclass, fields
from typing import Optional
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
tokenize,
tts,
utils,
)
from pyht import AsyncClient as PlayHTAsyncClient # type: ignore
from pyht.client import Format, Language, TTSOptions # type: ignore
from .log import logger
from .models import TTSModel
NUM_CHANNELS = 1
@dataclass
class _Options:
model: TTSModel | str
tts_options: TTSOptions
word_tokenizer: tokenize.WordTokenizer
class TTS(tts.TTS):
def __init__(
self,
*,
api_key: str | None = None,
user_id: str | None = None,
voice: str = "s3://voice-cloning-zero-shot/d9ff78ba-d016-47f6-b0ef-dd630f59414e/female-cs/manifest.json",
language: str = "english",
sample_rate: int = 24000,
model: TTSModel | str = "Play3.0-mini",
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False
),
**kwargs,
) -> None:
"""
Initialize the PlayAI TTS engine.
Args:
api_key (str): PlayAI API key.
user_id (str): PlayAI user ID.
voice (str): Voice manifest URL.
model (TTSModel): TTS model, defaults to "Play3.0-mini".
language (str): language, defaults to "english".
sample_rate (int): sample rate in Hz; must be between 8000 and 48000 inclusive.
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
**kwargs: Additional options.
"""
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=True,
),
sample_rate=sample_rate,
num_channels=1,
)
api_key = api_key or os.environ.get("PLAYHT_API_KEY")
user_id = user_id or os.environ.get("PLAYHT_USER_ID")
if not api_key or not user_id:
raise ValueError(
"PlayHT API key and user ID are required. Set environment variables PLAYHT_API_KEY and PLAYHT_USER_ID or pass them explicitly."
)
_validate_kwargs(kwargs)
self._config = TTSOptions(
voice=voice,
format=Format.FORMAT_OGG, # Using OGG format for AudioDecoder
sample_rate=sample_rate,
language=Language(language),
**kwargs,
)
self._opts = _Options(
model=model,
tts_options=self._config,
word_tokenizer=word_tokenizer,
)
self._client = PlayHTAsyncClient(
user_id=user_id,
api_key=api_key,
)
self._streams = weakref.WeakSet[SynthesizeStream]()
def update_options(
self,
*,
voice: str | None = None,
model: TTSModel | str | None = None,
language: str | None = None,
**kwargs,
) -> None:
"""
Update the TTS options.
"""
updates = {}
if voice is not None:
updates["voice"] = voice
if language is not None:
updates["language"] = Language(language)
updates.update(kwargs)
_validate_kwargs(updates)
for key, value in updates.items():
if value is not None:
setattr(self._config, key, value)
if model is not None:
self._opts.model = model
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> "SynthesizeStream":
stream = SynthesizeStream(
tts=self,
conn_options=conn_options,
opts=self._opts,
)
self._streams.add(stream)
return stream
class ChunkedStream(tts.ChunkedStream):
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _Options,
conn_options: Optional[APIConnectOptions] = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._client = tts._client
self._opts = opts
self._config = self._opts.tts_options
async def _run(self) -> None:
request_id = utils.shortuuid()
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._config.sample_rate,
num_channels=NUM_CHANNELS,
)
decode_task: Optional[asyncio.Task] = None
try:
# Create a task to push data to the decoder
async def _decode_loop():
try:
async for chunk in self._client.tts(
text=self._input_text,
options=self._config,
voice_engine=self._opts.model,
protocol="http",
streaming=True,
):
decoder.push(chunk)
finally:
decoder.end_input()
decode_task = asyncio.create_task(_decode_loop())
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except Exception as e:
raise APIConnectionError() from e
finally:
if decode_task:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
class SynthesizeStream(tts.SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
opts: _Options,
conn_options: Optional[APIConnectOptions] = None,
):
super().__init__(tts=tts, conn_options=conn_options)
self._client = tts._client
self._opts = opts
self._config = self._opts.tts_options
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
async def _run(self) -> None:
request_id = utils.shortuuid()
segment_id = utils.shortuuid()
input_task = asyncio.create_task(self._tokenize_input())
try:
text_stream = await self._create_text_stream()
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._config.sample_rate,
num_channels=NUM_CHANNELS,
)
# Create tasks for pushing data to decoder and generating events
async def decode_loop():
try:
async for chunk in self._client.stream_tts_input(
text_stream=text_stream,
options=self._config,
voice_engine=self._opts.model,
protocol="ws",
):
decoder.push(chunk)
finally:
decoder.end_input()
decode_task = asyncio.create_task(decode_loop())
try:
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
finally:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(input_task)
@utils.log_exceptions(logger=logger)
async def _tokenize_input(self):
# Converts incoming text into WordStreams and sends them into _segments_ch
word_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if word_stream is None:
word_stream = self._opts.word_tokenizer.stream()
self._segments_ch.send_nowait(word_stream)
word_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if word_stream:
word_stream.end_input()
word_stream = None
self._segments_ch.close()
@utils.log_exceptions(logger=logger)
async def _create_text_stream(self):
async def text_stream():
async for word_stream in self._segments_ch:
async for word in word_stream:
self._mark_started()
yield word.token
return text_stream()
def _validate_kwargs(kwargs: dict) -> None:
valid_keys = {field.name for field in fields(TTSOptions)}
invalid_keys = set(kwargs.keys()) - valid_keys
if invalid_keys:
raise ValueError(
f"Invalid parameters: {invalid_keys}. Allowed parameters: {valid_keys}"
)
__version__ = "1.0.9"
{
"name": "livekit-plugins-playai",
"private": true,
"version": "1.0.9"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "playai", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-playai",
version=about["__version__"],
description="Agent Framework plugin for voice synthesis with PlayAI's API.",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "livekit", "playHT", "playAI"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents[codecs]>=0.12.16,<1.0.0",
"pyht>=0.1.12",
"aiohttp",
"livekit",
],
package_data={"livekit.plugins.playai": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-rag
## 0.2.4
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.3
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.2.2
### Patch Changes
- rag: fix backward compatibility - [#629](https://github.com/livekit/agents/pull/629) ([@afigar](https://github.com/afigar))
## 0.2.1
### Patch Changes
- rag: add missing logger file - [#571](https://github.com/livekit/agents/pull/571) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.2.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.1.1-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins RAG
Agent Framework plugin for RAG utilities.
## Installation
```bash
pip install livekit-plugins-rag
```
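A minimal end-to-end sketch of how the utilities below fit together: chunk a document with `SentenceChunker`, embed each chunk, store the vectors with `IndexBuilder`, then load and query the index at runtime. `embed()` is a placeholder for whatever embedding model you use, and the dimension `1536` is only an example value:

```python
from livekit.plugins import rag

# 1. Split the source text into overlapping chunks
chunker = rag.SentenceChunker(max_chunk_size=120, chunk_overlap=30)
chunks = chunker.chunk(text=open("raw_data.txt").read())

# 2. Embed each chunk and add it to an Annoy index
builder = rag.annoy.IndexBuilder(f=1536, metric="angular")  # f = embedding dimension
for chunk in chunks:
    vector = embed(chunk)  # placeholder: call your embedding model here
    builder.add_item(vector, chunk)
builder.build()
builder.save("vdb_data")

# 3. Later, load the index and retrieve the closest chunks for a query embedding
index = rag.annoy.AnnoyIndex.load("vdb_data")
for result in index.query(embed("What is LiveKit?"), n=3):
    print(result.distance, result.userdata)
```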
## livekit-plugins/livekit-plugins-rag/livekit/plugins/rag/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import annoy
from .chunking import SentenceChunker
from .version import __version__
__all__ = ["SentenceChunker", "annoy", "__version__"]
from livekit.agents import Plugin
from .log import logger
class RAGPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__, logger)
def download_files(self) -> None:
pass
Plugin.register_plugin(RAGPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import pathlib
import pickle
from dataclasses import dataclass
from typing import Any, Iterable, Literal
import annoy
# https://github.com/spotify/annoy
__all__ = ["AnnoyIndex", "IndexBuilder", "Item", "Metric"]
Metric = Literal["angular", "euclidean", "manhattan", "hamming", "dot"]
ANNOY_FILE = "index.annoy"
METADATA_FILE = "metadata.pkl"
@dataclass
class Item:
i: int
userdata: Any
vector: list[float]
@dataclass
class _FileData:
f: int
metric: Metric
userdata: dict[int, Any]
@dataclass
class QueryResult:
userdata: Any
distance: float
class AnnoyIndex:
def __init__(self, index: annoy.AnnoyIndex, filedata: _FileData) -> None:
self._index = index
self._filedata = filedata
@classmethod
def load(cls, path: str) -> "AnnoyIndex":
p = pathlib.Path(path)
index_path = p / ANNOY_FILE
metadata_path = p / METADATA_FILE
with open(metadata_path, "rb") as f:
metadata: _FileData = pickle.load(f)
index = annoy.AnnoyIndex(metadata.f, metadata.metric)
index.load(str(index_path))
return cls(index, metadata)
@property
def size(self) -> int:
return self._index.get_n_items()
def items(self) -> Iterable[Item]:
for i in range(self._index.get_n_items()):
item = Item(
i=i,
userdata=self._filedata.userdata[i],
vector=self._index.get_item_vector(i),
)
yield item
def query(
self, vector: list[float], n: int, search_k: int = -1
) -> list[QueryResult]:
ids = self._index.get_nns_by_vector(
vector, n, search_k=search_k, include_distances=True
)
return [
QueryResult(userdata=self._filedata.userdata[i], distance=distance)
for i, distance in zip(*ids)
]
class IndexBuilder:
def __init__(self, f: int, metric: Metric) -> None:
self._index = annoy.AnnoyIndex(f, metric)
self._filedata = _FileData(f=f, metric=metric, userdata={})
self._i = 0
def save(self, path: str) -> None:
p = pathlib.Path(path)
p.mkdir(parents=True, exist_ok=True)
index_path = p / ANNOY_FILE
metadata_path = p / METADATA_FILE
self._index.save(str(index_path))
with open(metadata_path, "wb") as f:
pickle.dump(self._filedata, f)
def build(self, trees: int = 50, jobs: int = -1) -> AnnoyIndex:
# n_jobs=-1 means use all available cores
self._index.build(n_trees=trees, n_jobs=jobs)
return AnnoyIndex(self._index, self._filedata)
def add_item(self, vector: list[float], userdata: Any) -> None:
self._index.add_item(self._i, vector)
self._filedata.userdata[self._i] = userdata
self._i += 1
from typing import Callable
from livekit.agents import tokenize
class SentenceChunker:
def __init__(
self,
*,
max_chunk_size: int = 120,
chunk_overlap: int = 30,
paragraph_tokenizer: Callable[
[str], list[str]
] = tokenize.basic.tokenize_paragraphs,
sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer(),
word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
ignore_punctuation=False
),
) -> None:
self._max_chunk_size = max_chunk_size
self._chunk_overlap = chunk_overlap
self._paragraph_tokenizer = paragraph_tokenizer
self._sentence_tokenizer = sentence_tokenizer
self._word_tokenizer = word_tokenizer
def chunk(self, *, text: str) -> list[str]:
chunks = []
buf_words: list[str] = []
for paragraph in self._paragraph_tokenizer(text):
last_buf_words: list[str] = []
for sentence in self._sentence_tokenizer.tokenize(text=paragraph):
for word in self._word_tokenizer.tokenize(text=sentence):
reconstructed = self._word_tokenizer.format_words(
buf_words + [word]
)
if len(reconstructed) > self._max_chunk_size:
while (
len(self._word_tokenizer.format_words(last_buf_words))
> self._chunk_overlap
):
last_buf_words = last_buf_words[1:]
new_chunk = self._word_tokenizer.format_words(
last_buf_words + buf_words
)
chunks.append(new_chunk)
last_buf_words = buf_words
buf_words = []
buf_words.append(word)
if buf_words:
while (
len(self._word_tokenizer.format_words(last_buf_words))
> self._chunk_overlap
):
last_buf_words = last_buf_words[1:]
new_chunk = self._word_tokenizer.format_words(
last_buf_words + buf_words
)
chunks.append(new_chunk)
buf_words = []
return chunks
import logging
logger = logging.getLogger("livekit.plugins.rag")
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.4"
{
"name": "livekit-plugins-rag",
"private": true,
"version": "0.2.4"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "rag", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-rag",
version=about["__version__"],
description="Agent Framework plugin for RAG",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents>=0.12.16,<1.0.0", "annoy>=1.17"],
package_data={"livekit.plugins.rag": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-resemble
## 0.1.1
### Patch Changes
- release Resemble.ai TTS - [#1833](https://github.com/livekit/agents/pull/1833) ([@davidzhao](https://github.com/davidzhao))
# LiveKit Plugins Resemble
Agent Framework plugin for voice synthesis with the [Resemble AI](https://www.resemble.ai/) API, using both their REST API and WebSocket streaming interface.
## Installation
```bash
pip install livekit-plugins-resemble
```

You’ll need an API key from Resemble AI. It can be set as an environment variable: RESEMBLE_API_KEY

Additionally, you’ll need the voice UUID from your Resemble AI account.

```python
import asyncio
from livekit.plugins.resemble import TTS
async def run_tts_example():
# Use TTS with async context manager for automatic resource cleanup
async with TTS(
api_key="your_api_key", # or set RESEMBLE_API_KEY environment variable
voice_uuid="your_voice_uuid",
# Optional parameters
sample_rate=44100, # Sample rate in Hz (default: 44100)
precision="PCM_16", # Audio precision (PCM_32, PCM_24, PCM_16, MULAW)
output_format="wav", # Output format (wav or mp3)
) as tts:
# One-off synthesis (uses REST API)
audio_stream = tts.synthesize("Hello, world!")
# Process chunks as they arrive
async for chunk in audio_stream:
# Audio data is in the 'frame.data' attribute of SynthesizedAudio objects
audio_data = chunk.frame.data
print(f"Received chunk: {len(audio_data)} bytes")
# Alternative: collect all audio at once into a single AudioFrame
audio_stream = tts.synthesize("Another example sentence.")
audio_frame = await audio_stream.collect()
print(f"Collected complete audio: {len(audio_frame.data)} bytes")
# Real-time streaming synthesis (uses WebSocket API)
# Only available for Business plan users in Resemble AI
stream = tts.stream()
await stream.synthesize_text("Hello, world!")
# Run the example
asyncio.run(run_tts_example())
```

If you prefer to manage resources manually, make sure to properly clean up:

```python
import asyncio
from livekit.plugins.resemble import TTS
async def run_tts_example():
# Initialize TTS with your credentials
tts = TTS(
api_key="your_api_key",
voice_uuid="your_voice_uuid",
)
try:
# TTS operations
audio_stream = tts.synthesize("Hello, world!")
async for chunk in audio_stream:
# Access audio data correctly
process_audio(chunk.frame.data)
finally:
# Always clean up resources when done
await tts.aclose()
# Run the example
asyncio.run(run_tts_example())
```

When using this plugin outside of the LiveKit agent framework, it’s important to properly manage the TTS instance lifecycle: either use the async context manager (`async with TTS(...) as tts:`) or call `await tts.aclose()` in a `finally` block.

You can also supply your own aiohttp session via the `http_session` parameter:

```python
import aiohttp

async def with_custom_session():
    async with aiohttp.ClientSession() as session:
        async with TTS(
            api_key="your_api_key",
            voice_uuid="your_voice_uuid",
            http_session=session,
        ) as tts:
            ...  # Use TTS...
            # No need to manually close anything - context managers handle it all
```

This plugin uses two different approaches to generate speech: one-off synthesis through the REST API (`synthesize()`) and real-time streaming synthesis through the WebSocket API (`stream()`). The WebSocket streaming API is only available for Resemble AI Business plan users.
## livekit-plugins/livekit-plugins-resemble/livekit/plugins/resemble/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tts import TTS, ChunkedStream, SynthesizeStream
from .version import __version__
__all__ = ["TTS", "ChunkedStream", "SynthesizeStream", "__version__"]
from livekit.agents import Plugin
class ResemblePlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(ResemblePlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.resemble")
from enum import Enum
class Precision(str, Enum):
PCM_16 = "PCM_16"
# Copyright 2025 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import base64
import json
import os
import weakref
from dataclasses import dataclass
from typing import Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tokenize,
tts,
utils,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS
from .log import logger
RESEMBLE_WEBSOCKET_URL = "wss://websocket.cluster.resemble.ai/stream"
RESEMBLE_REST_API_URL = "https://f.cluster.resemble.ai/synthesize"
NUM_CHANNELS = 1
DEFAULT_VOICE_UUID = "55592656"
BUFFERED_WORDS_COUNT = 3
@dataclass
class _TTSOptions:
voice_uuid: str
sample_rate: int
tokenizer: tokenize.SentenceTokenizer
class TTS(tts.TTS):
def __init__(
self,
*,
api_key: str | None = None,
voice_uuid: str | None = None,
tokenizer: tokenize.SentenceTokenizer | None = None,
sample_rate: int = 44100,
http_session: aiohttp.ClientSession | None = None,
use_streaming: bool = True,
) -> None:
"""
Create a new instance of the Resemble TTS.
See https://docs.app.resemble.ai/docs/text_to_speech/ for more documentation on all of these options.
Args:
voice_uuid (str, optional): The voice UUID for the desired voice. Defaults to None.
sample_rate (int, optional): The audio sample rate in Hz. Defaults to 44100.
api_key (str | None, optional): The Resemble API key. If not provided, it will be read from the RESEMBLE_API_KEY environment variable.
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
tokenizer (tokenize.SentenceTokenizer, optional): The tokenizer to use. Defaults to tokenize.SentenceTokenizer().
use_streaming (bool, optional): Whether to use streaming or not. Defaults to True.
""" # noqa: E501
super().__init__(
capabilities=tts.TTSCapabilities(streaming=use_streaming),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
api_key = api_key or os.environ.get("RESEMBLE_API_KEY")
if not api_key:
raise ValueError(
"Resemble API key is required, either as argument or set RESEMBLE_API_KEY environment variable"
)
self._api_key = api_key
if tokenizer is None:
tokenizer = tokenize.basic.SentenceTokenizer(
min_sentence_len=BUFFERED_WORDS_COUNT
)
if voice_uuid is None:
voice_uuid = DEFAULT_VOICE_UUID
self._opts = _TTSOptions(
voice_uuid=voice_uuid,
sample_rate=sample_rate,
tokenizer=tokenizer,
)
self._session = http_session
self._streams = weakref.WeakSet[SynthesizeStream]()
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
)
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
session = self._ensure_session()
return await asyncio.wait_for(
session.ws_connect(
RESEMBLE_WEBSOCKET_URL,
headers={"Authorization": f"Bearer {self._api_key}"},
),
self._conn_options.timeout,
)
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
await ws.close()
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def prewarm(self) -> None:
self._pool.prewarm()
def update_options(
self,
*,
voice_uuid: str | None = None,
sample_rate: int | None = None,
) -> None:
"""
Update the Text-to-Speech (TTS) configuration options.
Args:
voice_uuid (str, optional): The voice UUID for the desired voice.
sample_rate (int, optional): The audio sample rate in Hz.
""" # noqa: E501
self._opts.voice_uuid = voice_uuid or self._opts.voice_uuid
self._opts.sample_rate = sample_rate or self._opts.sample_rate
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
) -> ChunkedStream:
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options or DEFAULT_API_CONNECT_OPTIONS,
opts=self._opts,
api_key=self._api_key,
session=self._ensure_session(),
)
def stream(
self, *, conn_options: Optional[APIConnectOptions] = None
) -> SynthesizeStream:
stream = SynthesizeStream(
tts=self,
pool=self._pool,
opts=self._opts,
api_key=self._api_key,
)
self._streams.add(stream)
return stream
async def aclose(self) -> None:
for stream in list(self._streams):
await stream.aclose()
self._streams.clear()
await self._pool.aclose()
await super().aclose()
class ChunkedStream(tts.ChunkedStream):
"""Synthesize text into speech in one go using Resemble AI's REST API."""
def __init__(
self,
*,
tts: TTS,
input_text: str,
opts: _TTSOptions,
conn_options: APIConnectOptions,
api_key: str,
session: aiohttp.ClientSession,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts, self._session, self._api_key = opts, session, api_key
async def _run(self) -> None:
request_id = utils.shortuuid()
# Create request headers
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
"Accept": "application/json", # Expect JSON response
}
# Create request payload
payload = {
"voice_uuid": self._opts.voice_uuid,
"data": self._input_text,
"sample_rate": self._opts.sample_rate,
"precision": "PCM_16",
}
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
try:
async with self._session.post(
RESEMBLE_REST_API_URL,
headers=headers,
json=payload,
timeout=aiohttp.ClientTimeout(
total=30,
sock_connect=self._conn_options.timeout,
),
) as response:
response.raise_for_status()
response_json = await response.json()
# Check for success
if not response_json.get("success", False):
issues = response_json.get("issues", ["Unknown error"])
error_msg = "; ".join(issues)
raise APIStatusError(
message=f"Resemble API returned failure: {error_msg}",
status_code=response.status,
request_id=request_id,
body=json.dumps(response_json),
)
# Extract base64-encoded audio content
audio_content_b64 = response_json.get("audio_content")
if not audio_content_b64:
raise APIStatusError(
message="No audio content in response",
status_code=response.status,
request_id=request_id,
body=json.dumps(response_json),
)
# Decode base64 to get raw audio bytes
audio_bytes = base64.b64decode(audio_content_b64)
decoder.push(audio_bytes)
decoder.end_input()
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=f"resemble api error: {str(e)}",
) from e
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientError as e:
raise APIConnectionError(
message=f"Resemble API connection error: {str(e)}",
) from e
except Exception as e:
raise APIConnectionError(f"Error during synthesis: {str(e)}") from e
finally:
await decoder.aclose()
class SynthesizeStream(tts.SynthesizeStream):
"""Stream-based text-to-speech synthesis using Resemble AI WebSocket API.
This implementation connects to Resemble's WebSocket API for real-time streaming
synthesis. Note that this requires a Business plan subscription with Resemble AI.
"""
def __init__(
self,
*,
tts: TTS,
opts: _TTSOptions,
pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
api_key: str,
):
super().__init__(tts=tts)
self._opts, self._pool, self._api_key = opts, pool, api_key
async def _run(self) -> None:
request_id = utils.shortuuid()
self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()
@utils.log_exceptions(logger=logger)
async def _tokenize_input():
"""tokenize text from the input_ch to words"""
input_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if input_stream is None:
# new segment (after flush for e.g)
input_stream = self._opts.tokenizer.stream()
self._segments_ch.send_nowait(input_stream)
input_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if input_stream is not None:
input_stream.end_input()
input_stream = None
if input_stream is not None:
input_stream.end_input()
self._segments_ch.close()
@utils.log_exceptions(logger=logger)
async def _process_segments():
async for input_stream in self._segments_ch:
await self._run_ws(input_stream)
tasks = [
asyncio.create_task(_tokenize_input()),
asyncio.create_task(_process_segments()),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
async def _run_ws(
self,
input_stream: tokenize.SentenceStream,
) -> None:
async with self._pool.connection() as ws:
segment_id = utils.shortuuid()
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
index_lock = asyncio.Lock()
current_index = 0
pending_requests = set()
@utils.log_exceptions(logger=logger)
async def _send_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal current_index
index = 0
async for data in input_stream:
payload = {
"voice_uuid": self._opts.voice_uuid,
"data": data.token,
"request_id": index,
"sample_rate": self._opts.sample_rate,
"precision": "PCM_16",
"output_format": "mp3",
}
async with index_lock:
pending_requests.add(index)
index += 1
current_index = index
await ws.send_str(json.dumps(payload))
@utils.log_exceptions(logger=logger)
async def _emit_task():
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=str(current_index),
segment_id=segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
@utils.log_exceptions(logger=logger)
async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise APIStatusError(
"Resemble connection closed unexpectedly",
request_id=str(current_index),
)
if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("Unexpected Resemble message type %s", msg.type)
continue
data = json.loads(msg.data)
if data.get("type") == "audio":
if data.get("audio_content", None):
b64data = base64.b64decode(data["audio_content"])
decoder.push(b64data)
elif data.get("type") == "audio_end":
async with index_lock:
index = data["request_id"]
pending_requests.remove(index)
if not pending_requests:
decoder.end_input()
break # we are not going to receive any more audio
else:
logger.error("Unexpected Resemble message %s", data)
tasks = [
asyncio.create_task(_send_task(ws)),
asyncio.create_task(_recv_task(ws)),
asyncio.create_task(_emit_task()),
]
try:
await asyncio.gather(*tasks)
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=str(current_index),
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
await utils.aio.gracefully_cancel(*tasks)
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.1"
{
"name": "livekit-plugins-resemble",
"private": true,
"version": "0.1.1"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "resemble", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-resemble",
version=about["__version__"],
description="LiveKit Agents Plugin for Resemble AI",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit", "resemble", "tts"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents[codecs]>=0.12.10"],
package_data={"livekit.plugins.resemble": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-rime
## 0.2.2
### Patch Changes
- Add string type support to model parameter - [#1657](https://github.com/livekit/agents/pull/1657) ([@jayeshp19](https://github.com/jayeshp19))
## 0.2.1
### Patch Changes
- use streaming AudioDecoder to handle compressed encoding - [#1584](https://github.com/livekit/agents/pull/1584) ([@davidzhao](https://github.com/davidzhao))
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.2.0
### Minor Changes
- initial release - [#1377](https://github.com/livekit/agents/pull/1377) ([@jayeshp19](https://github.com/jayeshp19))
# LiveKit Plugins Rime
Agent Framework plugin for voice synthesis with the [Rime](https://rime.ai/) API ([documentation](https://rimelabs.mintlify.app/api-reference/quickstart)).
## Installation
```bash
pip install livekit-plugins-rime
```

You’ll need an API key from Rime. It can be set as an environment variable: RIME_API_KEY
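A minimal usage sketch, assuming `RIME_API_KEY` is set in the environment; `play_frame()` is a placeholder for however you consume the synthesized audio frames:

```python
import asyncio

import aiohttp

from livekit.plugins import rime

async def main():
    async with aiohttp.ClientSession() as session:
        tts = rime.TTS(model="mist", speaker="lagoon", http_session=session)
        # synthesize() returns a ChunkedStream that yields SynthesizedAudio events
        async for audio in tts.synthesize("Hello from Rime!"):
            play_frame(audio.frame)  # placeholder: send the frame to your audio sink

asyncio.run(main())
```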
## livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tts import TTS, ChunkedStream
from .version import __version__
__all__ = ["TTS", "ChunkedStream", "__version__"]
from livekit.agents import Plugin
class RimePlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(RimePlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.rime")
from typing import Literal
TTSModels = Literal["mist"]
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
from dataclasses import dataclass
from typing import Optional
import aiohttp
from livekit.agents import (
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
from .log import logger
from .models import TTSModels
@dataclass
class _TTSOptions:
model: TTSModels | str
speaker: str
sample_rate: int
speed_alpha: float
reduce_latency: bool
pause_between_brackets: bool
phonemize_between_brackets: bool
DEFAULT_API_URL = "https://users.rime.ai/v1/rime-tts"
NUM_CHANNELS = 1
class TTS(tts.TTS):
def __init__(
self,
*,
model: TTSModels | str = "mist",
speaker: str = "lagoon",
sample_rate: int = 22050,
speed_alpha: float = 1.0,
reduce_latency: bool = False,
pause_between_brackets: bool = False,
phonemize_between_brackets: bool = False,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
self._api_key = api_key or os.environ.get("RIME_API_KEY")
if not self._api_key:
raise ValueError(
"Rime API key is required, either as argument or set RIME_API_KEY environmental variable"
)
self._opts = _TTSOptions(
model=model,
speaker=speaker,
sample_rate=sample_rate,
speed_alpha=speed_alpha,
reduce_latency=reduce_latency,
pause_between_brackets=pause_between_brackets,
phonemize_between_brackets=phonemize_between_brackets,
)
self._session = http_session
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
def synthesize(
self,
text: str,
*,
conn_options: Optional[APIConnectOptions] = None,
segment_id: str | None = None,
) -> "ChunkedStream":
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
opts=self._opts,
session=self._ensure_session(),
segment_id=segment_id,
api_key=self._api_key,
)
def update_options(
self,
*,
model: TTSModels | str | None,
speaker: str | None,
) -> None:
self._opts.model = model or self._opts.model
self._opts.speaker = speaker or self._opts.speaker
class ChunkedStream(tts.ChunkedStream):
"""Synthesize using the chunked api endpoint"""
def __init__(
self,
tts: TTS,
input_text: str,
opts: _TTSOptions,
session: aiohttp.ClientSession,
conn_options: Optional[APIConnectOptions] = None,
segment_id: str | None = None,
api_key: str | None = None,
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._opts = opts
self._session = session
self._segment_id = segment_id or utils.shortuuid()
self._api_key = api_key
async def _run(self) -> None:
request_id = utils.shortuuid()
headers = {
"accept": "audio/mp3",
"Authorization": f"Bearer {self._api_key}",
"content-type": "application/json",
}
payload = {
"speaker": self._opts.speaker,
"text": self._input_text,
"modelId": self._opts.model,
"samplingRate": self._opts.sample_rate,
"speedAlpha": self._opts.speed_alpha,
"reduceLatency": self._opts.reduce_latency,
"pauseBetweenBrackets": self._opts.pause_between_brackets,
"phonemizeBetweenBrackets": self._opts.phonemize_between_brackets,
}
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
)
decode_task: Optional[asyncio.Task] = None
try:
async with self._session.post(
DEFAULT_API_URL, headers=headers, json=payload
) as response:
if not response.content_type.startswith("audio"):
content = await response.text()
logger.error("Rime returned non-audio data: %s", content)
return
async def _decode_loop():
try:
async for bytes_data, _ in response.content.iter_chunks():
decoder.push(bytes_data)
finally:
decoder.end_input()
decode_task = asyncio.create_task(_decode_loop())
emitter = tts.SynthesizedAudioEmitter(
event_ch=self._event_ch,
request_id=request_id,
segment_id=self._segment_id,
)
async for frame in decoder:
emitter.push(frame)
emitter.flush()
except asyncio.TimeoutError as e:
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
finally:
if decode_task:
await utils.aio.gracefully_cancel(decode_task)
await decoder.aclose()
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.2"
{
"name": "livekit-plugins-rime",
"private": true,
"version": "0.2.2"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "rime", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-rime",
version=about["__version__"],
description="LiveKit Agents Plugin for Rime",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit", "rime"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=["livekit-agents[codecs]>=0.12.16,<1.0.0"],
package_data={"livekit.plugins.rime": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-silero
## 0.7.5
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.7.4
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.7.3
### Patch Changes
- pipelineagent: expose timing metrics & api errors wip - [#957](https://github.com/livekit/agents/pull/957) ([@theomonnom](https://github.com/theomonnom))
## 0.7.2
### Patch Changes
- silero: add update_options - [#899](https://github.com/livekit/agents/pull/899) ([@theomonnom](https://github.com/theomonnom))
- silero: fix speech_buffer for END_OF_SPEECH - [#898](https://github.com/livekit/agents/pull/898) ([@theomonnom](https://github.com/theomonnom))
## 0.7.1
### Patch Changes
- Fix CI x LFS issue for silero plugin - [#818](https://github.com/livekit/agents/pull/818) ([@keepingitneil](https://github.com/keepingitneil))
## 0.7.0
### Minor Changes
- silero: support any sample rate - [#805](https://github.com/livekit/agents/pull/805) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- silero: add prefix_padding_duration #801 - [#805](https://github.com/livekit/agents/pull/805) ([@theomonnom](https://github.com/theomonnom))
## 0.6.4
### Patch Changes
- silero: adjust vad activation threshold - [#639](https://github.com/livekit/agents/pull/639) ([@theomonnom](https://github.com/theomonnom))
- silero: fix vad padding & static audio - [#631](https://github.com/livekit/agents/pull/631) ([@theomonnom](https://github.com/theomonnom))
## 0.6.3
### Patch Changes
- silero: fix high cpu usage - [#569](https://github.com/livekit/agents/pull/569) ([@theomonnom](https://github.com/theomonnom))
## 0.6.2
### Patch Changes
- silero: tiny tweaks - [#565](https://github.com/livekit/agents/pull/565) ([@theomonnom](https://github.com/theomonnom))
- silero: optimize numpy input buffers - [#550](https://github.com/livekit/agents/pull/550) ([@theomonnom](https://github.com/theomonnom))
- silero: bring back expfilter - [#562](https://github.com/livekit/agents/pull/562) ([@theomonnom](https://github.com/theomonnom))
## 0.6.1
### Patch Changes
- fix end_input not flushing & unhandled flush messages - [#528](https://github.com/livekit/agents/pull/528) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
- release v0.8.0 - [`6e74aa714c2dfaa8212db4528d7b59d095b6c660`](https://github.com/livekit/agents/commit/6e74aa714c2dfaa8212db4528d7b59d095b6c660) ([@theomonnom](https://github.com/theomonnom))
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.7
### Patch Changes
- pull: '--rebase --autostash ...' - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.6
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.5
### Patch Changes
- test release - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.4
### Patch Changes
- fix changesets release CI - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.3
### Patch Changes
- bump versions to update dependencies - [#510](https://github.com/livekit/agents/pull/510) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.2
### Patch Changes
- dev fixes - multiprocessing & voiceassistant - [#493](https://github.com/livekit/agents/pull/493) ([@theomonnom](https://github.com/theomonnom))
## 0.6.0-dev.1
### Minor Changes
- dev prerelease - [#435](https://github.com/livekit/agents/pull/435) ([@theomonnom](https://github.com/theomonnom))
## 0.5.2-dev.0
### Patch Changes
- Default loglevel to warn - [#472](https://github.com/livekit/agents/pull/472) ([@lukasIO](https://github.com/lukasIO))
# LiveKit Plugins Silero
Agent Framework Plugin for Silero. Currently supports Voice Activity Detection.
## Installation
```bash
pip install livekit-plugins-silero
```

This plugin contains model files that need to be downloaded prior to use.
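A rough sketch of the recommended setup, mirroring the docstring example in `vad.py` below: load the model once per process during worker prewarm (loading is blocking), then reuse it from the job entrypoint:

```python
from livekit.agents import JobContext, JobProcess, WorkerOptions, cli
from livekit.plugins import silero

def prewarm(proc: JobProcess):
    # VAD.load() is blocking, so do it once per process before jobs are assigned
    proc.userdata["vad"] = silero.VAD.load(activation_threshold=0.5)

async def entrypoint(ctx: JobContext):
    vad = ctx.proc.userdata["vad"]
    stream = vad.stream()  # push rtc.AudioFrame objects and consume VAD events
    # ... your agent logic ...

if __name__ == "__main__":
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm))
```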
## livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vad import VAD, VADStream
from .version import __version__
__all__ = ["VAD", "VADStream", "__version__"]
from livekit.agents import Plugin
from .log import logger
class SileroPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
Plugin.register_plugin(SileroPlugin())
# Cleanup docs of unexported modules
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]
__pdoc__ = {}
for n in NOT_IN_ALL:
__pdoc__[n] = False
import logging
logger = logging.getLogger("livekit.plugins.silero")
import atexit
import importlib.resources
from contextlib import ExitStack
import numpy as np
import onnxruntime # type: ignore
_resource_files = ExitStack()
atexit.register(_resource_files.close)
SUPPORTED_SAMPLE_RATES = [8000, 16000]
def new_inference_session(force_cpu: bool) -> onnxruntime.InferenceSession:
res = (
importlib.resources.files("livekit.plugins.silero.resources")
/ "silero_vad.onnx"
)
ctx = importlib.resources.as_file(res)
path = str(_resource_files.enter_context(ctx))
opts = onnxruntime.SessionOptions()
opts.add_session_config_entry("session.intra_op.allow_spinning", "0")
opts.add_session_config_entry("session.inter_op.allow_spinning", "0")
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
if force_cpu and "CPUExecutionProvider" in onnxruntime.get_available_providers():
session = onnxruntime.InferenceSession(
path, providers=["CPUExecutionProvider"], sess_options=opts
)
else:
session = onnxruntime.InferenceSession(path, sess_options=opts)
return session
class OnnxModel:
def __init__(
self, *, onnx_session: onnxruntime.InferenceSession, sample_rate: int
) -> None:
self._sess = onnx_session
self._sample_rate = sample_rate
if sample_rate not in SUPPORTED_SAMPLE_RATES:
raise ValueError("Silero VAD only supports 8KHz and 16KHz sample rates")
if sample_rate == 8000:
self._window_size_samples = 256
self._context_size = 32
elif sample_rate == 16000:
self._window_size_samples = 512
self._context_size = 64
self._sample_rate_nd = np.array(sample_rate, dtype=np.int64)
self._context = np.zeros((1, self._context_size), dtype=np.float32)
self._rnn_state = np.zeros((2, 1, 128), dtype=np.float32)
self._input_buffer = np.zeros(
(1, self._context_size + self._window_size_samples), dtype=np.float32
)
@property
def sample_rate(self) -> int:
return self._sample_rate
@property
def window_size_samples(self) -> int:
return self._window_size_samples
@property
def context_size(self) -> int:
return self._context_size
def __call__(self, x: np.ndarray) -> float:
self._input_buffer[:, : self._context_size] = self._context
self._input_buffer[:, self._context_size :] = x
ort_inputs = {
"input": self._input_buffer,
"state": self._rnn_state,
"sr": self._sample_rate_nd,
}
out, self._rnn_state = self._sess.run(None, ort_inputs)  # carry the RNN state into the next window
self._context = self._input_buffer[:, -self._context_size :]
return out.item()
"""Used by importlib.resources and setuptools"""
version https://git-lfs.github.com/spec/v1
oid sha256:6b99cbfd39246b6706f98ec13c7c50c6b299181f2474fa05cbc8046acc274396
size 2313101
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations, print_function
import asyncio
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Literal
import numpy as np
import onnxruntime # type: ignore
from livekit import agents, rtc
from livekit.agents import utils
from . import onnx_model
from .log import logger
SLOW_INFERENCE_THRESHOLD = 0.2 # late by 200ms
@dataclass
class _VADOptions:
min_speech_duration: float
min_silence_duration: float
prefix_padding_duration: float
max_buffered_speech: float
activation_threshold: float
sample_rate: int
class VAD(agents.vad.VAD):
"""
Silero Voice Activity Detection (VAD) class.
This class provides functionality to detect speech segments within audio data using the Silero VAD model.
"""
@classmethod
def load(
cls,
*,
min_speech_duration: float = 0.05,
min_silence_duration: float = 0.55,
prefix_padding_duration: float = 0.5,
max_buffered_speech: float = 60.0,
activation_threshold: float = 0.5,
sample_rate: Literal[8000, 16000] = 16000,
force_cpu: bool = True,
# deprecated
padding_duration: float | None = None,
) -> "VAD":
"""
Load and initialize the Silero VAD model.
This method loads the ONNX model and prepares it for inference. When options are not provided,
sane defaults are used.
**Note:**
This method is blocking and may take time to load the model into memory.
It is recommended to call this method inside your prewarm mechanism.
**Example:**
```python
def prewarm(proc: JobProcess):
proc.userdata["vad"] = silero.VAD.load()
async def entrypoint(ctx: JobContext):
vad = ctx.proc.userdata["vad"]
# your agent logic...
if __name__ == "__main__":
cli.run_app(
WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)
)
```
Args:
min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
activation_threshold (float): Threshold to consider a frame as speech.
sample_rate (Literal[8000, 16000]): Sample rate for the inference (only 8KHz and 16KHz are supported).
force_cpu (bool): Force the use of CPU for inference.
padding_duration (float | None): **Deprecated**. Use `prefix_padding_duration` instead.
Returns:
VAD: An instance of the VAD class ready for streaming.
Raises:
ValueError: If an unsupported sample rate is provided.
"""
if sample_rate not in onnx_model.SUPPORTED_SAMPLE_RATES:
raise ValueError("Silero VAD only supports 8KHz and 16KHz sample rates")
if padding_duration is not None:
logger.warning(
"padding_duration is deprecated and will be removed in 1.5.0, use prefix_padding_duration instead",
)
prefix_padding_duration = padding_duration
session = onnx_model.new_inference_session(force_cpu)
opts = _VADOptions(
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
prefix_padding_duration=prefix_padding_duration,
max_buffered_speech=max_buffered_speech,
activation_threshold=activation_threshold,
sample_rate=sample_rate,
)
return cls(session=session, opts=opts)
def __init__(
self,
*,
session: onnxruntime.InferenceSession,
opts: _VADOptions,
) -> None:
super().__init__(capabilities=agents.vad.VADCapabilities(update_interval=0.032))
self._onnx_session = session
self._opts = opts
self._streams = weakref.WeakSet[VADStream]()
def stream(self) -> "VADStream":
"""
Create a new VADStream for processing audio data.
Returns:
VADStream: A stream object for processing audio input and detecting speech.
"""
stream = VADStream(
self,
self._opts,
onnx_model.OnnxModel(
onnx_session=self._onnx_session, sample_rate=self._opts.sample_rate
),
)
self._streams.add(stream)
return stream
def update_options(
self,
*,
min_speech_duration: float | None = None,
min_silence_duration: float | None = None,
prefix_padding_duration: float | None = None,
max_buffered_speech: float | None = None,
activation_threshold: float | None = None,
) -> None:
"""
Update the VAD options.
This method allows you to update the VAD options after the VAD object has been created.
Args:
min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
activation_threshold (float): Threshold to consider a frame as speech.
"""
self._opts = _VADOptions(
min_speech_duration=min_speech_duration or self._opts.min_speech_duration,
min_silence_duration=min_silence_duration
or self._opts.min_silence_duration,
prefix_padding_duration=prefix_padding_duration
or self._opts.prefix_padding_duration,
max_buffered_speech=max_buffered_speech or self._opts.max_buffered_speech,
activation_threshold=activation_threshold
or self._opts.activation_threshold,
sample_rate=self._opts.sample_rate,
)
for stream in self._streams:
stream.update_options(
min_speech_duration=min_speech_duration,
min_silence_duration=min_silence_duration,
prefix_padding_duration=prefix_padding_duration,
max_buffered_speech=max_buffered_speech,
activation_threshold=activation_threshold,
)
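# Illustrative runtime tuning (a sketch, not part of the original file): the new
# options are applied to this VAD and propagated to every active stream above.
#   vad = silero.VAD.load(activation_threshold=0.5)
#   vad.update_options(activation_threshold=0.6, min_silence_duration=0.8)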
class VADStream(agents.vad.VADStream):
def __init__(
self, vad: VAD, opts: _VADOptions, model: onnx_model.OnnxModel
) -> None:
super().__init__(vad)
self._opts, self._model = opts, model
self._loop = asyncio.get_event_loop()
self._executor = ThreadPoolExecutor(max_workers=1)
self._task.add_done_callback(lambda _: self._executor.shutdown(wait=False))
self._exp_filter = utils.ExpFilter(alpha=0.35)
self._input_sample_rate = 0
self._speech_buffer: np.ndarray | None = None
self._speech_buffer_max_reached = False
self._prefix_padding_samples = 0 # (input_sample_rate)
def update_options(
self,
*,
min_speech_duration: float | None = None,
min_silence_duration: float | None = None,
prefix_padding_duration: float | None = None,
max_buffered_speech: float | None = None,
activation_threshold: float | None = None,
) -> None:
"""
Update the VAD options.
This method allows you to update the VAD options after the VAD object has been created.
Args:
min_speech_duration (float): Minimum duration of speech to start a new speech chunk.
min_silence_duration (float): At the end of each speech, wait this duration before ending the speech.
prefix_padding_duration (float): Duration of padding to add to the beginning of each speech chunk.
max_buffered_speech (float): Maximum duration of speech to keep in the buffer (in seconds).
activation_threshold (float): Threshold to consider a frame as speech.
"""
old_max_buffered_speech = self._opts.max_buffered_speech
self._opts = _VADOptions(
min_speech_duration=min_speech_duration or self._opts.min_speech_duration,
min_silence_duration=min_silence_duration
or self._opts.min_silence_duration,
prefix_padding_duration=prefix_padding_duration
or self._opts.prefix_padding_duration,
max_buffered_speech=max_buffered_speech or self._opts.max_buffered_speech,
activation_threshold=activation_threshold
or self._opts.activation_threshold,
sample_rate=self._opts.sample_rate,
)
if self._input_sample_rate:
assert self._speech_buffer is not None
self._prefix_padding_samples = int(
self._opts.prefix_padding_duration * self._input_sample_rate
)
self._speech_buffer.resize(
int(self._opts.max_buffered_speech * self._input_sample_rate)
+ self._prefix_padding_samples
)
if self._opts.max_buffered_speech > old_max_buffered_speech:
self._speech_buffer_max_reached = False
@agents.utils.log_exceptions(logger=logger)
async def _main_task(self):
inference_f32_data = np.empty(self._model.window_size_samples, dtype=np.float32)
speech_buffer_index: int = 0
# "pub_" means public, these values are exposed to the users through events
pub_speaking = False
pub_speech_duration = 0.0
pub_silence_duration = 0.0
pub_current_sample = 0
pub_timestamp = 0.0
speech_threshold_duration = 0.0
silence_threshold_duration = 0.0
input_frames = []
inference_frames = []
resampler: rtc.AudioResampler | None = None
# used to avoid drift when the sample_rate ratio is not an integer
input_copy_remaining_fract = 0.0
extra_inference_time = 0.0
async for input_frame in self._input_ch:
if not isinstance(input_frame, rtc.AudioFrame):
continue # ignore flush sentinel for now
if not self._input_sample_rate:
self._input_sample_rate = input_frame.sample_rate
# alloc the buffers now that we know the input sample rate
self._prefix_padding_samples = int(
self._opts.prefix_padding_duration * self._input_sample_rate
)
self._speech_buffer = np.empty(
int(self._opts.max_buffered_speech * self._input_sample_rate)
+ self._prefix_padding_samples,
dtype=np.int16,
)
if self._input_sample_rate != self._opts.sample_rate:
# resampling needed: the input sample rate isn't the same as the model's
# sample rate used for inference
resampler = rtc.AudioResampler(
input_rate=self._input_sample_rate,
output_rate=self._opts.sample_rate,
quality=rtc.AudioResamplerQuality.QUICK, # VAD doesn't need high quality
)
elif self._input_sample_rate != input_frame.sample_rate:
logger.error("a frame with another sample rate was already pushed")
continue
assert self._speech_buffer is not None
input_frames.append(input_frame)
if resampler is not None:
# the resampler may have a bit of latency, but it is OK to ignore since it should be
# negligible
inference_frames.extend(resampler.push(input_frame))
else:
inference_frames.append(input_frame)
while True:
start_time = time.perf_counter()
available_inference_samples = sum(
[frame.samples_per_channel for frame in inference_frames]
)
if available_inference_samples < self._model.window_size_samples:
break # not enough samples to run inference
input_frame = utils.combine_frames(input_frames)
inference_frame = utils.combine_frames(inference_frames)
# convert data to f32
np.divide(
inference_frame.data[: self._model.window_size_samples],
np.iinfo(np.int16).max,
out=inference_f32_data,
dtype=np.float32,
)
# run the inference
p = await self._loop.run_in_executor(
self._executor, self._model, inference_f32_data
)
p = self._exp_filter.apply(exp=1.0, sample=p)
window_duration = (
self._model.window_size_samples / self._opts.sample_rate
)
pub_current_sample += self._model.window_size_samples
pub_timestamp += window_duration
resampling_ratio = self._input_sample_rate / self._model.sample_rate
to_copy = (
self._model.window_size_samples * resampling_ratio
+ input_copy_remaining_fract
)
to_copy_int = int(to_copy)
input_copy_remaining_fract = to_copy - to_copy_int
# copy the inference window to the speech buffer
available_space = len(self._speech_buffer) - speech_buffer_index
to_copy_buffer = min(to_copy_int, available_space)
if to_copy_buffer > 0:
self._speech_buffer[
speech_buffer_index : speech_buffer_index + to_copy_buffer
] = input_frame.data[:to_copy_buffer]
speech_buffer_index += to_copy_buffer
elif not self._speech_buffer_max_reached:
# reached self._opts.max_buffered_speech (padding is included)
self._speech_buffer_max_reached = True
logger.warning(
"max_buffered_speech reached, ignoring further data for the current speech input"
)
inference_duration = time.perf_counter() - start_time
extra_inference_time = max(
0.0,
extra_inference_time + inference_duration - window_duration,
)
if inference_duration > SLOW_INFERENCE_THRESHOLD:
logger.warning(
"inference is slower than realtime",
extra={"delay": extra_inference_time},
)
def _reset_write_cursor():
nonlocal speech_buffer_index
assert self._speech_buffer is not None
if speech_buffer_index <= self._prefix_padding_samples:
return
padding_data = self._speech_buffer[
speech_buffer_index
- self._prefix_padding_samples : speech_buffer_index
]
self._speech_buffer_max_reached = False
self._speech_buffer[: self._prefix_padding_samples] = padding_data
speech_buffer_index = self._prefix_padding_samples
def _copy_speech_buffer() -> rtc.AudioFrame:
# copy the data from speech_buffer
assert self._speech_buffer is not None
speech_data = self._speech_buffer[:speech_buffer_index].tobytes()
return rtc.AudioFrame(
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=speech_buffer_index,
data=speech_data,
)
if pub_speaking:
pub_speech_duration += window_duration
else:
pub_silence_duration += window_duration
self._event_ch.send_nowait(
agents.vad.VADEvent(
type=agents.vad.VADEventType.INFERENCE_DONE,
samples_index=pub_current_sample,
timestamp=pub_timestamp,
silence_duration=pub_silence_duration,
speech_duration=pub_speech_duration,
probability=p,
inference_duration=inference_duration,
frames=[
rtc.AudioFrame(
data=input_frame.data[:to_copy_int].tobytes(),
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=to_copy_int,
)
],
speaking=pub_speaking,
raw_accumulated_silence=silence_threshold_duration,
raw_accumulated_speech=speech_threshold_duration,
)
)
if p >= self._opts.activation_threshold:
speech_threshold_duration += window_duration
silence_threshold_duration = 0.0
if not pub_speaking:
if speech_threshold_duration >= self._opts.min_speech_duration:
pub_speaking = True
pub_silence_duration = 0.0
pub_speech_duration = speech_threshold_duration
self._event_ch.send_nowait(
agents.vad.VADEvent(
type=agents.vad.VADEventType.START_OF_SPEECH,
samples_index=pub_current_sample,
timestamp=pub_timestamp,
silence_duration=pub_silence_duration,
speech_duration=pub_speech_duration,
frames=[_copy_speech_buffer()],
speaking=True,
)
)
else:
silence_threshold_duration += window_duration
speech_threshold_duration = 0.0
if not pub_speaking:
_reset_write_cursor()
if (
pub_speaking
and silence_threshold_duration
>= self._opts.min_silence_duration
):
pub_speaking = False
pub_speech_duration = 0.0
pub_silence_duration = silence_threshold_duration
self._event_ch.send_nowait(
agents.vad.VADEvent(
type=agents.vad.VADEventType.END_OF_SPEECH,
samples_index=pub_current_sample,
timestamp=pub_timestamp,
silence_duration=pub_silence_duration,
speech_duration=pub_speech_duration,
frames=[_copy_speech_buffer()],
speaking=False,
)
)
_reset_write_cursor()
# remove the frames that were used for inference from the input and inference frames
input_frames = []
inference_frames = []
# add the remaining data
if len(input_frame.data) - to_copy_int > 0:
data = input_frame.data[to_copy_int:].tobytes()
input_frames.append(
rtc.AudioFrame(
data=data,
sample_rate=self._input_sample_rate,
num_channels=1,
samples_per_channel=len(data) // 2,
)
)
if len(inference_frame.data) - self._model.window_size_samples > 0:
data = inference_frame.data[
self._model.window_size_samples :
].tobytes()
inference_frames.append(
rtc.AudioFrame(
data=data,
sample_rate=self._opts.sample_rate,
num_channels=1,
samples_per_channel=len(data) // 2,
)
)
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.7.5"
{
"name": "livekit-plugins-silero",
"private": true,
"version": "0.7.5"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(os.path.join(here, "livekit", "plugins", "silero", "version.py"), "r") as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-silero",
version=about["__version__"],
description="Agent Framework Plugin for Silero",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
"onnxruntime>=1.18",
"numpy>=1.26",
],
package_data={
"livekit.plugins.silero.resources": ["silero_vad.onnx"],
"livekit.plugins.silero": ["py.typed"],
},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-speechmatics
## 0.0.2
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
## 0.0.1
### Minor Changes
- Add speechmatics plugin [#1510](https://github.com/livekit/agents/pull/1510) ([@dumitrugutu](https://github.com/dumitrugutu))
# LiveKit Plugins Speechmatics
Agent Framework plugin for Speechmatics.
## Installation
```bash
pip install livekit-plugins-speechmatics
```

Usage:

```python
agent = VoicePipelineAgent(
    stt=speechmatics.STT(),
    turn_detector=turn_detector.EOUModel(),
    min_endpointing_delay=0.5,
    max_endpointing_delay=5.0,
    ...
)
```

Note: the plugin was built with LiveKit's end-of-turn detection feature in mind and doesn't implement phrase endpointing. `AddTranscript` and `AddPartialTranscript` events are emitted as soon as they're received from the Speechmatics STT engine. For the best user experience, we recommend running the agent with end-of-turn detection enabled (see example).

You'll need to specify a Speechmatics API key. It can be set as the `SPEECHMATICS_API_KEY` environment variable or in a `.env.local` file.
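If you prefer to configure the connection and transcription explicitly instead of relying on the defaults and the environment variable, here is a minimal sketch. The parameter values mirror the defaults in `STT.__init__` below; the `types` import path is assumed:

```python
from livekit.plugins import speechmatics
from livekit.plugins.speechmatics.types import ConnectionSettings, TranscriptionConfig

stt = speechmatics.STT(
    transcription_config=TranscriptionConfig(
        language="en",
        operating_point="enhanced",
        enable_partials=True,
        max_delay=0.7,
    ),
    connection_settings=ConnectionSettings(
        url="wss://eu2.rt.speechmatics.com/v2",
        api_key="<your-speechmatics-api-key>",  # placeholder; omit to use SPEECHMATICS_API_KEY
    ),
)
```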
## livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/__init__.py
```py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .log import logger
from .stt import STT, SpeechStream
from .version import __version__
__all__ = [
"STT",
"SpeechStream",
"logger",
"__version__",
]
from livekit.agents import Plugin
class SpeechmaticsPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__)
Plugin.register_plugin(SpeechmaticsPlugin())
import logging
logger = logging.getLogger("livekit.plugins.speechmatics")
# Copyright 2025 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import dataclasses
import json
import os
import weakref
from typing import Dict, List, Optional
import aiohttp
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
APIStatusError,
stt,
utils,
)
from livekit.agents.utils import AudioBuffer
from .log import logger
from .types import (
AudioSettings,
ClientMessageType,
ConnectionSettings,
ServerMessageType,
TranscriptionConfig,
)
from .utils import get_access_token, sanitize_url
class STT(stt.STT):
def __init__(
self,
*,
transcription_config: TranscriptionConfig = TranscriptionConfig(
language="en",
operating_point="enhanced",
enable_partials=True,
max_delay=0.7,
),
connection_settings: ConnectionSettings = ConnectionSettings(
url="wss://eu2.rt.speechmatics.com/v2",
),
audio_settings: AudioSettings = AudioSettings(),
http_session: Optional[aiohttp.ClientSession] = None,
extra_headers: Optional[Dict] = None,
):
super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
),
)
self._transcription_config = transcription_config
self._audio_settings = audio_settings
self._connection_settings = connection_settings
self._extra_headers = extra_headers or {}
self._session = http_session
self._streams = weakref.WeakSet[SpeechStream]()
@property
def session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
return self._session
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
raise NotImplementedError("Not implemented")
def stream(
self,
*,
language: Optional[str] = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = dataclasses.replace(self._audio_settings)
stream = SpeechStream(
stt=self,
transcription_config=self._transcription_config,
audio_settings=config,
connection_settings=self._connection_settings,
conn_options=conn_options,
http_session=self.session,
extra_headers=self._extra_headers,
)
self._streams.add(stream)
return stream
class SpeechStream(stt.SpeechStream):
def __init__(
self,
*,
stt: STT,
transcription_config: TranscriptionConfig,
audio_settings: AudioSettings,
connection_settings: ConnectionSettings,
conn_options: APIConnectOptions,
http_session: aiohttp.ClientSession,
extra_headers: Optional[Dict] = None,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=audio_settings.sample_rate
)
self._transcription_config = transcription_config
self._audio_settings = audio_settings
self._connection_settings = connection_settings
self._session = http_session
self._extra_headers = extra_headers or {}
self._speech_duration: float = 0
self._reconnect_event = asyncio.Event()
self._recognition_started = asyncio.Event()
self._seq_no = 0
async def _run(self):
closing_ws = False
async def send_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
start_recognition_msg = {
"message": ClientMessageType.StartRecognition,
"audio_format": self._audio_settings.asdict(),
"transcription_config": self._transcription_config.asdict(),
}
await ws.send_str(json.dumps(start_recognition_msg))
await self._recognition_started.wait()
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._audio_settings.sample_rate,
num_channels=1,
)
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
frames = audio_bstream.flush()
else:
frames = audio_bstream.write(data.data.tobytes())
for frame in frames:
self._seq_no += 1
self._speech_duration += frame.duration
await ws.send_bytes(frame.data.tobytes())
closing_ws = True
await ws.send_str(
json.dumps(
{
"message": ClientMessageType.EndOfStream,
"last_seq_no": self._seq_no,
}
)
)
async def recv_task(ws: aiohttp.ClientWebSocketResponse):
nonlocal closing_ws
while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws: # close is expected, see SpeechStream.aclose
return
# this will trigger a reconnection, see the _run loop
raise APIStatusError(
message="Speechmatics connection closed unexpectedly"
)
try:
data = json.loads(msg.data)
self._process_stream_event(data, closing_ws)
except Exception:
logger.exception("failed to process Speechmatics message")
ws: aiohttp.ClientWebSocketResponse | None = None
while True:
try:
ws = await self._connect_ws()
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
]
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
done, _ = await asyncio.wait(
[asyncio.gather(*tasks), wait_reconnect_task],
return_when=asyncio.FIRST_COMPLETED,
) # type: ignore
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task)
finally:
if ws is not None:
await ws.close()
async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
api_key = self._connection_settings.api_key or os.environ.get(
"SPEECHMATICS_API_KEY"
)
if api_key is None:
raise ValueError(
"Speechmatics API key is required. "
"Pass one in via ConnectionSettings.api_key parameter, "
"or set `SPEECHMATICS_API_KEY` environment variable"
)
if self._connection_settings.get_access_token:
api_key = await get_access_token(api_key)
headers = {
"Authorization": f"Bearer {api_key}",
**self._extra_headers,
}
url = sanitize_url(
self._connection_settings.url, self._transcription_config.language
)
return await self._session.ws_connect(
url,
ssl=self._connection_settings.ssl_context,
headers=headers,
)
def _process_stream_event(self, data: dict, closing_ws: bool) -> None:
message_type = data["message"]
if message_type == ServerMessageType.RecognitionStarted:
self._recognition_started.set()
elif message_type == ServerMessageType.AddPartialTranscript:
alts = live_transcription_to_speech_data(data)
if len(alts) > 0 and alts[0].text:
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=alts,
)
self._event_ch.send_nowait(interim_event)
elif message_type == ServerMessageType.AddTranscript:
alts = live_transcription_to_speech_data(data)
if len(alts) > 0 and alts[0].text:
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=alts,
)
self._event_ch.send_nowait(final_event)
if self._speech_duration > 0:
usage_event = stt.SpeechEvent(
type=stt.SpeechEventType.RECOGNITION_USAGE,
alternatives=[],
recognition_usage=stt.RecognitionUsage(
audio_duration=self._speech_duration
),
)
self._event_ch.send_nowait(usage_event)
self._speech_duration = 0
elif message_type == ServerMessageType.EndOfTranscript:
if closing_ws:
pass
else:
raise Exception("Speechmatics connection closed unexpectedly")
def live_transcription_to_speech_data(data: dict) -> List[stt.SpeechData]:
speech_data: List[stt.SpeechData] = []
for result in data.get("results", []):
start_time, end_time, is_eos = (
result.get("start_time", 0),
result.get("end_time", 0),
result.get("is_eos", False),
)
for alt in result.get("alternatives", []):
content, confidence, language = (
alt.get("content", "").strip(),
alt.get("confidence", 1.0),
alt.get("language", "en"),
)
if not content:
continue
# append punctuation to the previous result
if is_eos and speech_data:
speech_data[-1].text += content
elif speech_data and start_time == speech_data[-1].end_time:
speech_data[-1].text += " " + content
else:
speech_data.append(
stt.SpeechData(language, content, start_time, end_time, confidence)
)
return speech_data
import ssl
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
@dataclass
class TranscriptionConfig:
"""Real-time: Defines transcription parameters."""
language: str = "en"
"""ISO 639-1 language code. eg. `en`"""
operating_point: Optional[str] = None
"""Specifies which acoustic model to use."""
output_locale: Optional[str] = None
"""RFC-5646 language code for transcript output. eg. `en-AU`"""
diarization: Optional[str] = None
"""Indicates type of diarization to use, if any."""
additional_vocab: Optional[Dict] = None
"""Additional vocabulary that is not part of the standard language."""
punctuation_overrides: Optional[Dict] = None
"""Permitted puctuation marks for advanced punctuation."""
enable_entities: Optional[bool] = None
"""Indicates if inverse text normalization entity output is enabled."""
max_delay: Optional[float] = None
"""Maximum acceptable delay."""
max_delay_mode: Optional[str] = None
"""Determines whether the threshold specified in max_delay can be exceeded
if a potential entity is detected. Flexible means if a potential entity
is detected, then the max_delay can be overridden until the end of that
entity. Fixed means that max_delay specified ignores any potential
entity that would not be completed within that threshold."""
streaming_mode: Optional[bool] = None
"""Indicates if we run the engine in streaming mode, or regular RT mode."""
enable_partials: Optional[bool] = None
"""Indicates if partials for transcription, where words are produced
immediately, is enabled."""
def asdict(self) -> Dict[Any, Any]:
"""Returns model as a dict while excluding None values recursively."""
return asdict(
self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None}
)
@dataclass
class AudioSettings:
"""Real-time: Defines audio parameters."""
encoding: str = "pcm_s16le"
"""Encoding format when raw audio is used. Allowed values are
`pcm_f32le`, `pcm_s16le` and `mulaw`."""
sample_rate: int = 16000
"""Sampling rate in hertz."""
def asdict(self):
return {
"type": "raw",
"encoding": self.encoding,
"sample_rate": self.sample_rate,
}
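# For the defaults above, asdict() produces the "audio_format" payload used in the
# StartRecognition message (illustrative, not part of the original file):
#   {"type": "raw", "encoding": "pcm_s16le", "sample_rate": 16000}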
@dataclass
class ConnectionSettings:
"""Defines connection parameters."""
url: str
"""Websocket server endpoint."""
ssl_context: ssl.SSLContext = field(default_factory=ssl.create_default_context)
"""SSL context."""
api_key: Optional[str] = None
"""api key to authenticate a customer."""
get_access_token: Optional[bool] = True
"""Automatically generate a temporary token for authentication."""
class ClientMessageType(str, Enum):
# pylint: disable=invalid-name
"""Real-time: Defines various messages sent from client to server."""
StartRecognition = "StartRecognition"
"""Initiates a recognition job based on configuration set previously."""
AddAudio = "AddAudio"
"""Adds more audio data to the recognition job. The server confirms
receipt by sending an :py:attr:`ServerMessageType.AudioAdded` message."""
EndOfStream = "EndOfStream"
"""Indicates that the client has no more audio to send."""
SetRecognitionConfig = "SetRecognitionConfig"
"""Allows the client to re-configure the recognition session."""
class ServerMessageType(str, Enum):
"""Real-time: Defines various message types sent from server to client."""
RecognitionStarted = "RecognitionStarted"
"""Server response to :py:attr:`ClientMessageType.StartRecognition`,
acknowledging that a recognition session has started."""
AudioAdded = "AudioAdded"
"""Server response to :py:attr:`ClientMessageType.AddAudio`, indicating
that audio has been added successfully."""
AddPartialTranscript = "AddPartialTranscript"
"""Indicates a partial transcript, which is an incomplete transcript that
is immediately produced and may change as more context becomes available.
"""
AddTranscript = "AddTranscript"
"""Indicates the final transcript of a part of the audio."""
EndOfTranscript = "EndOfTranscript"
"""Server response to :py:attr:`ClientMessageType.EndOfStream`,
after the server has finished sending all :py:attr:`AddTranscript`
messages."""
Info = "Info"
"""Indicates a generic info message."""
Warning = "Warning"
"""Indicates a generic warning message."""
Error = "Error"
"""Indicates n generic error message."""
import importlib.metadata
import os
import aiohttp
async def get_access_token(api_key: str) -> str:
mp_api_url = os.getenv(
"SPEECHMATICS_MANAGEMENT_PLATFORM_URL", "https://mp.speechmatics.com"
)
endpoint = f"{mp_api_url}/v1/api_keys"
params = {"type": "rt", "sm-sdk": get_sdk_version()}
json_body = {"ttl": 60}
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.post(
endpoint, params=params, json=json_body, headers=headers
) as resp:
if resp.status == 201:
try:
data = await resp.json()
return data["key_value"]
except (ValueError, KeyError) as e:
raise Exception(
f"Failed to parse Speechmatics access token response: {e}"
)
else:
error_message = await resp.text()
raise Exception(
f"Failed to get Speechmatics access token. "
f"Status: {resp.status}, Error: {error_message}"
)
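# Illustrative flow (not part of the original file): the long-lived API key is
# exchanged via POST {mp_api_url}/v1/api_keys?type=rt with body {"ttl": 60}; the
# short-lived realtime token comes back in the "key_value" field of the 201 response.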
def get_sdk_version():
version = importlib.metadata.version("livekit-plugins-speechmatics")
return f"livekit-plugins-{version}"
def sanitize_url(url, language):
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
parsed_url = urlparse(url)
query_params = dict(parse_qsl(parsed_url.query))
query_params["sm-sdk"] = get_sdk_version()
updated_query = urlencode(query_params)
url_path = parsed_url.path
if not url_path.endswith(language):
if url_path.endswith("/"):
url_path += language
else:
url_path += f"/{language}"
return urlunparse(parsed_url._replace(path=url_path, query=updated_query))
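# Illustrative result of the rewriting above (not part of the original file):
#   sanitize_url("wss://eu2.rt.speechmatics.com/v2", "en")
#   -> "wss://eu2.rt.speechmatics.com/v2/en?sm-sdk=livekit-plugins-<version>"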
# Copyright 2025 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.0.2"
{
"name": "livekit-plugins-speechmatics",
"private": true,
"version": "0.0.2"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "speechmatics", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-speechmatics",
version=about["__version__"],
description="Agent Framework plugin for Speechmatics",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
],
package_data={},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
# livekit-plugins-eou
## 0.4.4
### Patch Changes
- added a multilingual turn detector option - [#1736](https://github.com/livekit/agents/pull/1736) ([@jeradf](https://github.com/jeradf))
## 0.4.3
### Patch Changes
- updated livekit-agent reference to <1.0 - [#1607](https://github.com/livekit/agents/pull/1607) ([@davidzhao](https://github.com/davidzhao))
- retrained to be robust to missing terminal punctuation - [#1565](https://github.com/livekit/agents/pull/1565) ([@jeradf](https://github.com/jeradf))
## 0.4.2
### Patch Changes
- log from job process instead of inference - [#1506](https://github.com/livekit/agents/pull/1506) ([@davidzhao](https://github.com/davidzhao))
## 0.4.1
### Patch Changes
- fix incorrect dtype on windows - [#1452](https://github.com/livekit/agents/pull/1452) ([@jeradf](https://github.com/jeradf))
- adjust default probability cutoff - [#1465](https://github.com/livekit/agents/pull/1465) ([@jeradf](https://github.com/jeradf))
## 0.4.0
### Minor Changes
- more accurate, smaller, faster model - [#1426](https://github.com/livekit/agents/pull/1426) ([@jeradf](https://github.com/jeradf))
## 0.3.6
### Patch Changes
- prevent arbitrarily long inputs being passed to turn detector - [#1345](https://github.com/livekit/agents/pull/1345) ([@jeradf](https://github.com/jeradf))
- add timeout for EOU inference requests made to the inference process - [#1315](https://github.com/livekit/agents/pull/1315) ([@theomonnom](https://github.com/theomonnom))
## 0.3.5
### Patch Changes
- fix int32/64 errors on Windows - [#1285](https://github.com/livekit/agents/pull/1285) ([@nbsp](https://github.com/nbsp))
## 0.3.4
### Patch Changes
- add jinja2 dependency to turn detector - [#1277](https://github.com/livekit/agents/pull/1277) ([@davidzhao](https://github.com/davidzhao))
## 0.3.3
### Patch Changes
- use quantized onnx version of turn detector model - [#1231](https://github.com/livekit/agents/pull/1231) ([@jeradf](https://github.com/jeradf))
- use onnxruntime for turn detection and remove pytorch dependency - [#1257](https://github.com/livekit/agents/pull/1257) ([@jeradf](https://github.com/jeradf))
## 0.3.2
### Patch Changes
- improvements to endpointing latency - [#1212](https://github.com/livekit/agents/pull/1212) ([@davidzhao](https://github.com/davidzhao))
- Improvements to end of turn plugin, ensure STT language settings. - [#1195](https://github.com/livekit/agents/pull/1195) ([@davidzhao](https://github.com/davidzhao))
## 0.3.1
### Patch Changes
- fix release - [#1176](https://github.com/livekit/agents/pull/1176) ([@theomonnom](https://github.com/theomonnom))
## 0.3.0
### Minor Changes
- feat: inference process & end of utterance plugin - [#1133](https://github.com/livekit/agents/pull/1133) ([@theomonnom](https://github.com/theomonnom))
# LiveKit Plugins Turn Detector
This plugin introduces end-of-turn detection for LiveKit Agents using a custom open-weight model to determine when a user has finished speaking.
Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking.
By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method for detecting end-of-turns. The current version supports English only and should not be used when targeting other languages.
## Installation
```bash
pip install livekit-plugins-turn-detector
```

This plugin is designed to be used with the `VoicePipelineAgent`:

```python
from livekit.plugins import turn_detector

agent = VoicePipelineAgent(
    ...
    turn_detector=turn_detector.EOUModel(),
)
```

This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files:

```bash
python my_agent.py download-files
```
The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents. On a 4-core server instance, it completes inference in ~50ms with minimal CPU usage.
The model requires 1.5GB of RAM and runs within a shared inference server, supporting multiple concurrent sessions.
We are working to reduce the CPU and memory requirements in future releases.
The plugin source code is licensed under the Apache-2.0 license.
The end-of-turn model is licensed under the LiveKit Model License.
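The 0.4.4 release also adds a multilingual turn detector option (see the changelog above). A minimal sketch, assuming the model is exposed as `MultilingualModel` from a `multilingual` submodule, as the class definition further down in this package suggests:

```python
from livekit.plugins.turn_detector.multilingual import MultilingualModel

agent = VoicePipelineAgent(
    ...
    turn_detector=MultilingualModel(),
)
```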
## livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/__init__.py
```py
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from livekit.agents import Plugin
from .english import EnglishModel
from .log import logger
from .version import __version__
__all__ = ["EOUModel", "english", "multilingual", "__version__"]
class EOUPlugin(Plugin):
def __init__(self):
super().__init__(__name__, __version__, __package__, logger)
def download_files(self) -> None:
from transformers import AutoTokenizer
from .base import _download_from_hf_hub
from .models import HG_MODEL, MODEL_REVISIONS, ONNX_FILENAME
for revision in MODEL_REVISIONS.values():
AutoTokenizer.from_pretrained(HG_MODEL, revision=revision)
_download_from_hf_hub(
HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=revision
)
_download_from_hf_hub(HG_MODEL, "languages.json", revision=revision)
Plugin.register_plugin(EOUPlugin())
EOUModel = EnglishModel
from __future__ import annotations
import asyncio
import json
import time
from abc import ABC, abstractmethod
from livekit.agents import llm
from livekit.agents.inference_runner import _InferenceRunner
from livekit.agents.ipc.inference_executor import InferenceExecutor
from livekit.agents.job import get_current_job_context
from .log import logger
from .models import HG_MODEL, MODEL_REVISIONS, ONNX_FILENAME, EOUModelType
MAX_HISTORY_TOKENS = 512
MAX_HISTORY_TURNS = 6
def _download_from_hf_hub(repo_id, filename, **kwargs):
from huggingface_hub import hf_hub_download
local_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
return local_path
class _EUORunnerBase(_InferenceRunner):
def __init__(self, model_type: EOUModelType):
super().__init__()
self._model_revision = MODEL_REVISIONS[model_type]
def _format_chat_ctx(self, chat_ctx: dict):
new_chat_ctx = []
for msg in chat_ctx:
content = msg["content"]
if not content:
continue
msg["content"] = content
new_chat_ctx.append(msg)
convo_text = self._tokenizer.apply_chat_template(
new_chat_ctx,
add_generation_prompt=False,
add_special_tokens=False,
tokenize=False,
)
# remove the EOU token from current utterance
ix = convo_text.rfind("<|im_end|>")
text = convo_text[:ix]
return text
def initialize(self) -> None:
import onnxruntime as ort
from huggingface_hub import errors
from transformers import AutoTokenizer
try:
local_path_onnx = _download_from_hf_hub(
HG_MODEL,
ONNX_FILENAME,
subfolder="onnx",
revision=self._model_revision,
local_files_only=True,
)
self._session = ort.InferenceSession(
local_path_onnx, providers=["CPUExecutionProvider"]
)
self._tokenizer = AutoTokenizer.from_pretrained(
HG_MODEL,
revision=self._model_revision,
local_files_only=True,
truncation_side="left",
)
except (errors.LocalEntryNotFoundError, OSError):
logger.error(
(
f"Could not find model {HG_MODEL} with revision {self._model_revision}. Make sure you have downloaded the model before running the agent. "
"Use `python3 your_agent.py download-files` to download the models."
)
)
raise RuntimeError(
f"livekit-plugins-turn-detector initialization failed. Could not find model {HG_MODEL} with revision {self._model_revision}."
) from None
def run(self, data: bytes) -> bytes | None:
data_json = json.loads(data)
chat_ctx = data_json.get("chat_ctx", None)
if not chat_ctx:
raise ValueError("chat_ctx is required on the inference input data")
start_time = time.perf_counter()
text = self._format_chat_ctx(chat_ctx)
inputs = self._tokenizer(
text,
add_special_tokens=False,
return_tensors="np",
max_length=MAX_HISTORY_TOKENS,
truncation=True,
)
# Run inference
outputs = self._session.run(
None, {"input_ids": inputs["input_ids"].astype("int64")}
)
eou_probability = outputs[0][0]
end_time = time.perf_counter()
data = {
"eou_probability": float(eou_probability),
"input": text,
"duration": round(end_time - start_time, 3),
}
return json.dumps(data).encode()
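# Illustrative input/output of run() (placeholder values, not part of the original file):
#   input : b'{"chat_ctx": [{"role": "user", "content": "how are you"}]}'
#   output: b'{"eou_probability": 0.87, "input": "<formatted convo text>", "duration": 0.021}'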
class EOUModelBase(ABC):
def __init__(
self,
model_type: EOUModelType = "en", # default to smaller, english-only model
inference_executor: InferenceExecutor | None = None,
) -> None:
self._model_type = model_type
self._executor = (
inference_executor or get_current_job_context().inference_executor
)
config_fname = _download_from_hf_hub(
HG_MODEL,
"languages.json",
revision=MODEL_REVISIONS[self._model_type],
local_files_only=True,
)
with open(config_fname, "r") as f:
self._languages = json.load(f)
@abstractmethod
def _inference_method(self): ...
def unlikely_threshold(self, language: str | None) -> float | None:
if language is None:
return None
lang = language.lower()
if lang in self._languages:
return self._languages[lang]["threshold"]
if "-" in lang:
part = lang.split("-")[0]
if part in self._languages:
return self._languages[part]["threshold"]
logger.warning(f"Language {language} not supported by EOU model")
return None
def supports_language(self, language: str | None) -> bool:
return self.unlikely_threshold(language) is not None
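# Illustrative fallback (not part of the original file): a regional code such as
# "en-US" misses the exact lookup, then falls back to its base language "en", so
# supports_language("en-US") is True whenever "en" appears in languages.json.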
async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
return await self.predict_end_of_turn(chat_ctx)
# our EOU model inference should be fast, 3 seconds is more than enough
async def predict_end_of_turn(
self, chat_ctx: llm.ChatContext, *, timeout: float | None = 3
) -> float:
messages = []
for msg in chat_ctx.messages:
if msg.role not in ("user", "assistant"):
continue
if isinstance(msg.content, str):
messages.append(
{
"role": msg.role,
"content": msg.content,
}
)
elif isinstance(msg.content, list):
for cnt in msg.content:
if isinstance(cnt, str):
messages.append(
{
"role": msg.role,
"content": cnt,
}
)
break
messages = messages[-MAX_HISTORY_TURNS:]
json_data = json.dumps({"chat_ctx": messages}).encode()
result = await asyncio.wait_for(
self._executor.do_inference(self._inference_method(), json_data),
timeout=timeout,
)
assert result is not None, (
"end_of_utterance prediction should always returns a result"
)
result_json = json.loads(result.decode())
logger.debug(
"eou prediction",
extra=result_json,
)
return result_json["eou_probability"]
from livekit.agents.inference_runner import _InferenceRunner
from .base import EOUModelBase, _EUORunnerBase
class _EUORunnerEn(_EUORunnerBase):
INFERENCE_METHOD = "lk_end_of_utterance_en"
def __init__(self):
super().__init__("en")
class EnglishModel(EOUModelBase):
def __init__(self):
super().__init__(model_type="en")
def _inference_method(self) -> str:
return _EUORunnerEn.INFERENCE_METHOD
_InferenceRunner.register_runner(_EUORunnerEn)
import logging
logger = logging.getLogger("livekit.plugins.turn_detector")
from typing import Literal
EOUModelType = Literal["en", "multilingual"]
MODEL_REVISIONS: dict[EOUModelType, str] = {
"en": "v1.2.2-en",
"multilingual": "v0.1.0-intl",
}
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_q8.onnx"
from livekit.agents.inference_runner import _InferenceRunner
from .base import EOUModelBase, _EUORunnerBase
class _EUORunnerMultilingual(_EUORunnerBase):
INFERENCE_METHOD = "lk_end_of_utterance_multilingual"
def __init__(self):
super().__init__("multilingual")
class MultilingualModel(EOUModelBase):
def __init__(self):
super().__init__(model_type="multilingual")
def _inference_method(self) -> str:
return _EUORunnerMultilingual.INFERENCE_METHOD
_InferenceRunner.register_runner(_EUORunnerMultilingual)
# Copyright 2023 LiveKit, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.4.4"
{
"name": "livekit-plugins-turn-detector",
"private": true,
"version": "0.4.4"
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import setuptools
import setuptools.command.build_py
here = pathlib.Path(__file__).parent.resolve()
about = {}
with open(
os.path.join(here, "livekit", "plugins", "turn_detector", "version.py"), "r"
) as f:
exec(f.read(), about)
setuptools.setup(
name="livekit-plugins-turn-detector",
version=about["__version__"],
description="End of utterance detection for LiveKit Agents",
long_description=(here / "README.md").read_text(encoding="utf-8"),
long_description_content_type="text/markdown",
url="https://github.com/livekit/agents",
cmdclass={},
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
],
keywords=["webrtc", "realtime", "audio", "video", "livekit"],
license="Apache-2.0",
packages=setuptools.find_namespace_packages(include=["livekit.*"]),
python_requires=">=3.9.0",
install_requires=[
"livekit-agents>=0.12.16,<1.0.0",
"transformers>=4.47.1",
"numpy>=1.26",
"onnxruntime>=1.18",
"jinja2",
],
package_data={"livekit.plugins.turn_detector": ["py.typed"]},
project_urls={
"Documentation": "https://docs.livekit.io",
"Website": "https://livekit.io/",
"Source": "https://github.com/livekit/agents",
},
)
[mypy]
[mypy-google.genai.*]
follow_imports = normal
follow_untyped_imports = True
[mypy-aiobotocore.*]
follow_untyped_imports = True
[mypy-boto3.*]
follow_untyped_imports = True
{
"name": "@livekit/agents-py-monorepo",
"private": true,
"type": "module",
"scripts": {
"changeset": "changeset",
"ci:publish": "changeset publish",
"ci:version": "changeset version && python3 .github/update_versions.py"
},
"devDependencies": {
"@changesets/cli": "^2.27.1",
"@livekit/changesets-changelog-github": "^0.0.4"
},
"engines": {
"node": ">= 18"
},
"packageManager": "pnpm@9.2.0",
"dependencies": {}
}
lockfileVersion: '9.0'
settings:
autoInstallPeers: true
excludeLinksFromLockfile: false
importers:
.:
devDependencies:
'@changesets/cli':
specifier: ^2.27.1
version: 2.27.7
'@livekit/changesets-changelog-github':
specifier: ^0.0.4
version: 0.0.4
livekit-agents: {}
livekit-plugins/livekit-plugins-anthropic: {}
livekit-plugins/livekit-plugins-azure: {}
livekit-plugins/livekit-plugins-browser: {}
livekit-plugins/livekit-plugins-cartesia: {}
livekit-plugins/livekit-plugins-deepgram: {}
livekit-plugins/livekit-plugins-elevenlabs: {}
livekit-plugins/livekit-plugins-google: {}
livekit-plugins/livekit-plugins-minimal: {}
livekit-plugins/livekit-plugins-nltk: {}
livekit-plugins/livekit-plugins-openai: {}
livekit-plugins/livekit-plugins-rag: {}
livekit-plugins/livekit-plugins-silero: {}
packages:
'@babel/runtime@7.24.8':
resolution: {integrity: sha512-5F7SDGs1T72ZczbRwbGO9lQi0NLjQxzl6i4lJxLxfW9U5UluCSyEJeniWvnhl3/euNiqQVbo8zruhsDfid0esA==}
engines: {node: '>=6.9.0'}
'@changesets/apply-release-plan@7.0.4':
resolution: {integrity: sha512-HLFwhKWayKinWAul0Vj+76jVx1Pc2v55MGPVjZ924Y/ROeSsBMFutv9heHmCUj48lJyRfOTJG5+ar+29FUky/A==}
'@changesets/assemble-release-plan@6.0.3':
resolution: {integrity: sha512-bLNh9/Lgl1VwkjWZTq8JmRqH+hj7/Yzfz0jsQ/zJJ+FTmVqmqPj3szeKOri8O/hEM8JmHW019vh2gTO9iq5Cuw==}
'@changesets/changelog-git@0.2.0':
resolution: {integrity: sha512-bHOx97iFI4OClIT35Lok3sJAwM31VbUM++gnMBV16fdbtBhgYu4dxsphBF/0AZZsyAHMrnM0yFcj5gZM1py6uQ==}
'@changesets/cli@2.27.7':
resolution: {integrity: sha512-6lr8JltiiXPIjDeYg4iM2MeePP6VN/JkmqBsVA5XRiy01hGS3y629LtSDvKcycj/w/5Eur1rEwby/MjcYS+e2A==}
hasBin: true
'@changesets/config@3.0.2':
resolution: {integrity: sha512-cdEhS4t8woKCX2M8AotcV2BOWnBp09sqICxKapgLHf9m5KdENpWjyrFNMjkLqGJtUys9U+w93OxWT0czorVDfw==}
'@changesets/errors@0.2.0':
resolution: {integrity: sha512-6BLOQUscTpZeGljvyQXlWOItQyU71kCdGz7Pi8H8zdw6BI0g3m43iL4xKUVPWtG+qrrL9DTjpdn8eYuCQSRpow==}
'@changesets/get-dependents-graph@2.1.1':
resolution: {integrity: sha512-LRFjjvigBSzfnPU2n/AhFsuWR5DK++1x47aq6qZ8dzYsPtS/I5mNhIGAS68IAxh1xjO9BTtz55FwefhANZ+FCA==}
'@changesets/get-github-info@0.5.2':
resolution: {integrity: sha512-JppheLu7S114aEs157fOZDjFqUDpm7eHdq5E8SSR0gUBTEK0cNSHsrSR5a66xs0z3RWuo46QvA3vawp8BxDHvg==}
'@changesets/get-release-plan@4.0.3':
resolution: {integrity: sha512-6PLgvOIwTSdJPTtpdcr3sLtGatT+Jr22+cQwEBJBy6wP0rjB4yJ9lv583J9fVpn1bfQlBkDa8JxbS2g/n9lIyA==}
'@changesets/get-version-range-type@0.4.0':
resolution: {integrity: sha512-hwawtob9DryoGTpixy1D3ZXbGgJu1Rhr+ySH2PvTLHvkZuQ7sRT4oQwMh0hbqZH1weAooedEjRsbrWcGLCeyVQ==}
'@changesets/git@3.0.0':
resolution: {integrity: sha512-vvhnZDHe2eiBNRFHEgMiGd2CT+164dfYyrJDhwwxTVD/OW0FUD6G7+4DIx1dNwkwjHyzisxGAU96q0sVNBns0w==}
'@changesets/logger@0.1.0':
resolution: {integrity: sha512-pBrJm4CQm9VqFVwWnSqKEfsS2ESnwqwH+xR7jETxIErZcfd1u2zBSqrHbRHR7xjhSgep9x2PSKFKY//FAshA3g==}
'@changesets/parse@0.4.0':
resolution: {integrity: sha512-TS/9KG2CdGXS27S+QxbZXgr8uPsP4yNJYb4BC2/NeFUj80Rni3TeD2qwWmabymxmrLo7JEsytXH1FbpKTbvivw==}
'@changesets/pre@2.0.0':
resolution: {integrity: sha512-HLTNYX/A4jZxc+Sq8D1AMBsv+1qD6rmmJtjsCJa/9MSRybdxh0mjbTvE6JYZQ/ZiQ0mMlDOlGPXTm9KLTU3jyw==}
'@changesets/read@0.6.0':
resolution: {integrity: sha512-ZypqX8+/im1Fm98K4YcZtmLKgjs1kDQ5zHpc2U1qdtNBmZZfo/IBiG162RoP0CUF05tvp2y4IspH11PLnPxuuw==}
'@changesets/should-skip-package@0.1.0':
resolution: {integrity: sha512-FxG6Mhjw7yFStlSM7Z0Gmg3RiyQ98d/9VpQAZ3Fzr59dCOM9G6ZdYbjiSAt0XtFr9JR5U2tBaJWPjrkGGc618g==}
'@changesets/types@4.1.0':
resolution: {integrity: sha512-LDQvVDv5Kb50ny2s25Fhm3d9QSZimsoUGBsUioj6MC3qbMUCuC8GPIvk/M6IvXx3lYhAs0lwWUQLb+VIEUCECw==}
'@changesets/types@5.2.1':
resolution: {integrity: sha512-myLfHbVOqaq9UtUKqR/nZA/OY7xFjQMdfgfqeZIBK4d0hA6pgxArvdv8M+6NUzzBsjWLOtvApv8YHr4qM+Kpfg==}
'@changesets/types@6.0.0':
resolution: {integrity: sha512-b1UkfNulgKoWfqyHtzKS5fOZYSJO+77adgL7DLRDr+/7jhChN+QcHnbjiQVOz/U+Ts3PGNySq7diAItzDgugfQ==}
'@changesets/write@0.3.1':
resolution: {integrity: sha512-SyGtMXzH3qFqlHKcvFY2eX+6b0NGiFcNav8AFsYwy5l8hejOeoeTDemu5Yjmke2V5jpzY+pBvM0vCCQ3gdZpfw==}
'@livekit/changesets-changelog-github@0.0.4':
resolution: {integrity: sha512-MXaiLYwgkYciZb8G2wkVtZ1pJJzZmVx5cM30Q+ClslrIYyAqQhRbPmZDM79/5CGxb1MTemR/tfOM25tgJgAK0g==}
'@manypkg/find-root@1.1.0':
resolution: {integrity: sha512-mki5uBvhHzO8kYYix/WRy2WX8S3B5wdVSc9D6KcU5lQNglP2yt58/VfLuAK49glRXChosY8ap2oJ1qgma3GUVA==}
'@manypkg/get-packages@1.1.3':
resolution: {integrity: sha512-fo+QhuU3qE/2TQMQmbVMqaQ6EWbMhi4ABWP+O4AM1NqPBuy0OrApV5LO6BrrgnhtAHS2NH6RrVk9OL181tTi8A==}
'@nodelib/fs.scandir@2.1.5':
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
engines: {node: '>= 8'}
'@nodelib/fs.stat@2.0.5':
resolution: {integrity: sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==}
engines: {node: '>= 8'}
'@nodelib/fs.walk@1.2.8':
resolution: {integrity: sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==}
engines: {node: '>= 8'}
'@types/node@12.20.55':
resolution: {integrity: sha512-J8xLz7q2OFulZ2cyGTLE1TbbZcjpno7FaN6zdJNrgAdrJ+DZzh/uFR6YrTb4C+nXakvud8Q4+rbhoIWlYQbUFQ==}
'@types/semver@7.5.8':
resolution: {integrity: sha512-I8EUhyrgfLrcTkzV3TSsGyl1tSuPrEDzr0yd5m90UgNxQkyDXULk3b6MlQqTCpZpNtWe1K0hzclnZkTcLBe2UQ==}
ansi-colors@4.1.3:
resolution: {integrity: sha512-/6w/C21Pm1A7aZitlI5Ni/2J6FFQN8i1Cvz3kHABAAbw93v/NlvKdVOqz7CCWz/3iv/JplRSEEZ83XION15ovw==}
engines: {node: '>=6'}
ansi-regex@5.0.1:
resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==}
engines: {node: '>=8'}
ansi-styles@3.2.1:
resolution: {integrity: sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==}
engines: {node: '>=4'}
argparse@1.0.10:
resolution: {integrity: sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==}
array-union@2.1.0:
resolution: {integrity: sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==}
engines: {node: '>=8'}
better-path-resolve@1.0.0:
resolution: {integrity: sha512-pbnl5XzGBdrFU/wT4jqmJVPn2B6UHPBOhzMQkY/SPUPB6QtUXtmBHBIwCbXJol93mOpGMnQyP/+BB19q04xj7g==}
engines: {node: '>=4'}
braces@3.0.3:
resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==}
engines: {node: '>=8'}
chalk@2.4.2:
resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==}
engines: {node: '>=4'}
chardet@0.7.0:
resolution: {integrity: sha512-mT8iDcrh03qDGRRmoA2hmBJnxpllMR+0/0qlzjqZES6NdiWDcZkCNAk4rPFZ9Q85r27unkiNNg8ZOiwZXBHwcA==}
ci-info@3.9.0:
resolution: {integrity: sha512-NIxF55hv4nSqQswkAeiOi1r83xy8JldOFDTWiug55KBu9Jnblncd2U6ViHmYgHf01TPZS77NJBhBMKdWj9HQMQ==}
engines: {node: '>=8'}
color-convert@1.9.3:
resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==}
color-name@1.1.3:
resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==}
cross-spawn@5.1.0:
resolution: {integrity: sha512-pTgQJ5KC0d2hcY8eyL1IzlBPYjTkyH72XRZPnLyKus2mBfNjQs3klqbJU2VILqZryAZUt9JOb3h/mWMy23/f5A==}
dataloader@1.4.0:
resolution: {integrity: sha512-68s5jYdlvasItOJnCuI2Q9s4q98g0pCyL3HrcKJu8KNugUl8ahgmZYg38ysLTgQjjXX3H8CJLkAvWrclWfcalw==}
detect-indent@6.1.0:
resolution: {integrity: sha512-reYkTUJAZb9gUuZ2RvVCNhVHdg62RHnJ7WJl8ftMi4diZ6NWlciOzQN88pUhSELEwflJht4oQDv0F0BMlwaYtA==}
engines: {node: '>=8'}
dir-glob@3.0.1:
resolution: {integrity: sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==}
engines: {node: '>=8'}
dotenv@8.6.0:
resolution: {integrity: sha512-IrPdXQsk2BbzvCBGBOTmmSH5SodmqZNt4ERAZDmW4CT+tL8VtvinqywuANaFu4bOMWki16nqf0e4oC0QIaDr/g==}
engines: {node: '>=10'}
enquirer@2.4.1:
resolution: {integrity: sha512-rRqJg/6gd538VHvR3PSrdRBb/1Vy2YfzHqzvbhGIQpDRKIa4FgV/54b5Q1xYSxOOwKvjXweS26E0Q+nAMwp2pQ==}
engines: {node: '>=8.6'}
escape-string-regexp@1.0.5:
resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==}
engines: {node: '>=0.8.0'}
esprima@4.0.1:
resolution: {integrity: sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==}
engines: {node: '>=4'}
hasBin: true
extendable-error@0.1.7:
resolution: {integrity: sha512-UOiS2in6/Q0FK0R0q6UY9vYpQ21mr/Qn1KOnte7vsACuNJf514WvCCUHSRCPcgjPT2bAhNIJdlE6bVap1GKmeg==}
external-editor@3.1.0:
resolution: {integrity: sha512-hMQ4CX1p1izmuLYyZqLMO/qGNw10wSv9QDCPfzXfyFrOaCSSoRfqE1Kf1s5an66J5JZC62NewG+mK49jOCtQew==}
engines: {node: '>=4'}
fast-glob@3.3.2:
resolution: {integrity: sha512-oX2ruAFQwf/Orj8m737Y5adxDQO0LAB7/S5MnxCdTNDd4p6BsyIVsv9JQsATbTSq8KHRpLwIHbVlUNatxd+1Ow==}
engines: {node: '>=8.6.0'}
fastq@1.17.1:
resolution: {integrity: sha512-sRVD3lWVIXWg6By68ZN7vho9a1pQcN/WBFaAAsDDFzlJjvoGx0P8z7V1t72grFJfJhu3YPZBuu25f7Kaw2jN1w==}
fill-range@7.1.1:
resolution: {integrity: sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==}
engines: {node: '>=8'}
find-up@4.1.0:
resolution: {integrity: sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==}
engines: {node: '>=8'}
find-up@5.0.0:
resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==}
engines: {node: '>=10'}
find-yarn-workspace-root2@1.2.16:
resolution: {integrity: sha512-hr6hb1w8ePMpPVUK39S4RlwJzi+xPLuVuG8XlwXU3KD5Yn3qgBWVfy3AzNlDhWvE1EORCE65/Qm26rFQt3VLVA==}
fs-extra@7.0.1:
resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==}
engines: {node: '>=6 <7 || >=8'}
fs-extra@8.1.0:
resolution: {integrity: sha512-yhlQgA6mnOJUKOsRUFsgJdQCvkKhcz8tlZG5HBQfReYZy46OwLcY+Zia0mtdHsOo9y/hP+CxMN0TU9QxoOtG4g==}
engines: {node: '>=6 <7 || >=8'}
glob-parent@5.1.2:
resolution: {integrity: sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==}
engines: {node: '>= 6'}
globby@11.1.0:
resolution: {integrity: sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==}
engines: {node: '>=10'}
graceful-fs@4.2.11:
resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==}
has-flag@3.0.0:
resolution: {integrity: sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==}
engines: {node: '>=4'}
human-id@1.0.2:
resolution: {integrity: sha512-UNopramDEhHJD+VR+ehk8rOslwSfByxPIZyJRfV739NDhN5LF1fa1MqnzKm2lGTQRjNrjK19Q5fhkgIfjlVUKw==}
iconv-lite@0.4.24:
resolution: {integrity: sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==}
engines: {node: '>=0.10.0'}
ignore@5.3.1:
resolution: {integrity: sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==}
engines: {node: '>= 4'}
is-extglob@2.1.1:
resolution: {integrity: sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==}
engines: {node: '>=0.10.0'}
is-glob@4.0.3:
resolution: {integrity: sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==}
engines: {node: '>=0.10.0'}
is-number@7.0.0:
resolution: {integrity: sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==}
engines: {node: '>=0.12.0'}
is-subdir@1.2.0:
resolution: {integrity: sha512-2AT6j+gXe/1ueqbW6fLZJiIw3F8iXGJtt0yDrZaBhAZEG1raiTxKWU+IPqMCzQAXOUCKdA4UDMgacKH25XG2Cw==}
engines: {node: '>=4'}
is-windows@1.0.2:
resolution: {integrity: sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA==}
engines: {node: '>=0.10.0'}
isexe@2.0.0:
resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==}
js-yaml@3.14.1:
resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==}
hasBin: true
jsonfile@4.0.0:
resolution: {integrity: sha512-m6F1R3z8jjlf2imQHS2Qez5sjKWQzbuuhuJ/FKYFRZvPE3PuHcSMVZzfsLhGVOkfd20obL5SWEBew5ShlquNxg==}
load-yaml-file@0.2.0:
resolution: {integrity: sha512-OfCBkGEw4nN6JLtgRidPX6QxjBQGQf72q3si2uvqyFEMbycSFFHwAZeXx6cJgFM9wmLrf9zBwCP3Ivqa+LLZPw==}
engines: {node: '>=6'}
locate-path@5.0.0:
resolution: {integrity: sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==}
engines: {node: '>=8'}
locate-path@6.0.0:
resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==}
engines: {node: '>=10'}
lodash.startcase@4.4.0:
resolution: {integrity: sha512-+WKqsK294HMSc2jEbNgpHpd0JfIBhp7rEV4aqXWqFr6AlXov+SlcgB1Fv01y2kGe3Gc8nMW7VA0SrGuSkRfIEg==}
lru-cache@4.1.5:
resolution: {integrity: sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==}
merge2@1.4.1:
resolution: {integrity: sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==}
engines: {node: '>= 8'}
micromatch@4.0.7:
resolution: {integrity: sha512-LPP/3KorzCwBxfeUuZmaR6bG2kdeHSbe0P2tY3FLRU4vYrjYz5hI4QZwV0njUx3jeuKe67YukQ1LSPZBKDqO/Q==}
engines: {node: '>=8.6'}
mri@1.2.0:
resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==}
engines: {node: '>=4'}
node-fetch@2.7.0:
resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==}
engines: {node: 4.x || >=6.0.0}
peerDependencies:
encoding: ^0.1.0
peerDependenciesMeta:
encoding:
optional: true
os-tmpdir@1.0.2:
resolution: {integrity: sha512-D2FR03Vir7FIu45XBY20mTb+/ZSWB00sjU9jdQXt83gDrI4Ztz5Fs7/yy74g2N5SVQY4xY1qDr4rNddwYRVX0g==}
engines: {node: '>=0.10.0'}
outdent@0.5.0:
resolution: {integrity: sha512-/jHxFIzoMXdqPzTaCpFzAAWhpkSjZPF4Vsn6jAfNpmbH/ymsmd7Qc6VE9BGn0L6YMj6uwpQLxCECpus4ukKS9Q==}
p-filter@2.1.0:
resolution: {integrity: sha512-ZBxxZ5sL2HghephhpGAQdoskxplTwr7ICaehZwLIlfL6acuVgZPm8yBNuRAFBGEqtD/hmUeq9eqLg2ys9Xr/yw==}
engines: {node: '>=8'}
p-limit@2.3.0:
resolution: {integrity: sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==}
engines: {node: '>=6'}
p-limit@3.1.0:
resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==}
engines: {node: '>=10'}
p-locate@4.1.0:
resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==}
engines: {node: '>=8'}
p-locate@5.0.0:
resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==}
engines: {node: '>=10'}
p-map@2.1.0:
resolution: {integrity: sha512-y3b8Kpd8OAN444hxfBbFfj1FY/RjtTd8tzYwhUqNYXx0fXx2iX4maP4Qr6qhIKbQXI02wTLAda4fYUbDagTUFw==}
engines: {node: '>=6'}
p-try@2.2.0:
resolution: {integrity: sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==}
engines: {node: '>=6'}
path-exists@4.0.0:
resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==}
engines: {node: '>=8'}
path-type@4.0.0:
resolution: {integrity: sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==}
engines: {node: '>=8'}
picomatch@2.3.1:
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
engines: {node: '>=8.6'}
pify@4.0.1:
resolution: {integrity: sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g==}
engines: {node: '>=6'}
pkg-dir@4.2.0:
resolution: {integrity: sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==}
engines: {node: '>=8'}
preferred-pm@3.1.4:
resolution: {integrity: sha512-lEHd+yEm22jXdCphDrkvIJQU66EuLojPPtvZkpKIkiD+l0DMThF/niqZKJSoU8Vl7iuvtmzyMhir9LdVy5WMnA==}
engines: {node: '>=10'}
prettier@2.8.8:
resolution: {integrity: sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==}
engines: {node: '>=10.13.0'}
hasBin: true
pseudomap@1.0.2:
resolution: {integrity: sha512-b/YwNhb8lk1Zz2+bXXpS/LK9OisiZZ1SNsSLxN1x2OXVEhW2Ckr/7mWE5vrC1ZTiJlD9g19jWszTmJsB+oEpFQ==}
queue-microtask@1.2.3:
resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==}
read-yaml-file@1.1.0:
resolution: {integrity: sha512-VIMnQi/Z4HT2Fxuwg5KrY174U1VdUIASQVWXXyqtNRtxSr9IYkn1rsI6Tb6HsrHCmB7gVpNwX6JxPTHcH6IoTA==}
engines: {node: '>=6'}
regenerator-runtime@0.14.1:
resolution: {integrity: sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==}
resolve-from@5.0.0:
resolution: {integrity: sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==}
engines: {node: '>=8'}
reusify@1.0.4:
resolution: {integrity: sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==}
engines: {iojs: '>=1.0.0', node: '>=0.10.0'}
run-parallel@1.2.0:
resolution: {integrity: sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==}
safer-buffer@2.1.2:
resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==}
semver@7.6.3:
resolution: {integrity: sha512-oVekP1cKtI+CTDvHWYFUcMtsK/00wmAEfyqKfNdARm8u1wNVhSgaX7A8d4UuIlUI5e84iEwOhs7ZPYRmzU9U6A==}
engines: {node: '>=10'}
hasBin: true
shebang-command@1.2.0:
resolution: {integrity: sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==}
engines: {node: '>=0.10.0'}
shebang-regex@1.0.0:
resolution: {integrity: sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==}
engines: {node: '>=0.10.0'}
signal-exit@3.0.7:
resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==}
slash@3.0.0:
resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==}
engines: {node: '>=8'}
spawndamnit@2.0.0:
resolution: {integrity: sha512-j4JKEcncSjFlqIwU5L/rp2N5SIPsdxaRsIv678+TZxZ0SRDJTm8JrxJMjE/XuiEZNEir3S8l0Fa3Ke339WI4qA==}
sprintf-js@1.0.3:
resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==}
strip-ansi@6.0.1:
resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==}
engines: {node: '>=8'}
strip-bom@3.0.0:
resolution: {integrity: sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==}
engines: {node: '>=4'}
supports-color@5.5.0:
resolution: {integrity: sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==}
engines: {node: '>=4'}
term-size@2.2.1:
resolution: {integrity: sha512-wK0Ri4fOGjv/XPy8SBHZChl8CM7uMc5VML7SqiQ0zG7+J5Vr+RMQDoHa2CNT6KHUnTGIXH34UDMkPzAUyapBZg==}
engines: {node: '>=8'}
tmp@0.0.33:
resolution: {integrity: sha512-jRCJlojKnZ3addtTOjdIqoRuPEKBvNXcGYqzO6zWZX8KfKEpnGY5jfggJQ3EjKuu8D4bJRr0y+cYJFmYbImXGw==}
engines: {node: '>=0.6.0'}
to-regex-range@5.0.1:
resolution: {integrity: sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==}
engines: {node: '>=8.0'}
tr46@0.0.3:
resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==}
universalify@0.1.2:
resolution: {integrity: sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg==}
engines: {node: '>= 4.0.0'}
webidl-conversions@3.0.1:
resolution: {integrity: sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==}
whatwg-url@5.0.0:
resolution: {integrity: sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==}
which-pm@2.2.0:
resolution: {integrity: sha512-MOiaDbA5ZZgUjkeMWM5EkJp4loW5ZRoa5bc3/aeMox/PJelMhE6t7S/mLuiY43DBupyxH+S0U1bTui9kWUlmsw==}
engines: {node: '>=8.15'}
which@1.3.1:
resolution: {integrity: sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==}
hasBin: true
yallist@2.1.2:
resolution: {integrity: sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A==}
yocto-queue@0.1.0:
resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==}
engines: {node: '>=10'}
snapshots:
'@babel/runtime@7.24.8':
dependencies:
regenerator-runtime: 0.14.1
'@changesets/apply-release-plan@7.0.4':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/config': 3.0.2
'@changesets/get-version-range-type': 0.4.0
'@changesets/git': 3.0.0
'@changesets/should-skip-package': 0.1.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
detect-indent: 6.1.0
fs-extra: 7.0.1
lodash.startcase: 4.4.0
outdent: 0.5.0
prettier: 2.8.8
resolve-from: 5.0.0
semver: 7.6.3
'@changesets/assemble-release-plan@6.0.3':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/errors': 0.2.0
'@changesets/get-dependents-graph': 2.1.1
'@changesets/should-skip-package': 0.1.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
semver: 7.6.3
'@changesets/changelog-git@0.2.0':
dependencies:
'@changesets/types': 6.0.0
'@changesets/cli@2.27.7':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/apply-release-plan': 7.0.4
'@changesets/assemble-release-plan': 6.0.3
'@changesets/changelog-git': 0.2.0
'@changesets/config': 3.0.2
'@changesets/errors': 0.2.0
'@changesets/get-dependents-graph': 2.1.1
'@changesets/get-release-plan': 4.0.3
'@changesets/git': 3.0.0
'@changesets/logger': 0.1.0
'@changesets/pre': 2.0.0
'@changesets/read': 0.6.0
'@changesets/should-skip-package': 0.1.0
'@changesets/types': 6.0.0
'@changesets/write': 0.3.1
'@manypkg/get-packages': 1.1.3
'@types/semver': 7.5.8
ansi-colors: 4.1.3
chalk: 2.4.2
ci-info: 3.9.0
enquirer: 2.4.1
external-editor: 3.1.0
fs-extra: 7.0.1
human-id: 1.0.2
mri: 1.2.0
outdent: 0.5.0
p-limit: 2.3.0
preferred-pm: 3.1.4
resolve-from: 5.0.0
semver: 7.6.3
spawndamnit: 2.0.0
term-size: 2.2.1
'@changesets/config@3.0.2':
dependencies:
'@changesets/errors': 0.2.0
'@changesets/get-dependents-graph': 2.1.1
'@changesets/logger': 0.1.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
fs-extra: 7.0.1
micromatch: 4.0.7
'@changesets/errors@0.2.0':
dependencies:
extendable-error: 0.1.7
'@changesets/get-dependents-graph@2.1.1':
dependencies:
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
chalk: 2.4.2
fs-extra: 7.0.1
semver: 7.6.3
'@changesets/get-github-info@0.5.2':
dependencies:
dataloader: 1.4.0
node-fetch: 2.7.0
transitivePeerDependencies:
- encoding
'@changesets/get-release-plan@4.0.3':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/assemble-release-plan': 6.0.3
'@changesets/config': 3.0.2
'@changesets/pre': 2.0.0
'@changesets/read': 0.6.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
'@changesets/get-version-range-type@0.4.0': {}
'@changesets/git@3.0.0':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/errors': 0.2.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
is-subdir: 1.2.0
micromatch: 4.0.7
spawndamnit: 2.0.0
'@changesets/logger@0.1.0':
dependencies:
chalk: 2.4.2
'@changesets/parse@0.4.0':
dependencies:
'@changesets/types': 6.0.0
js-yaml: 3.14.1
'@changesets/pre@2.0.0':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/errors': 0.2.0
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
fs-extra: 7.0.1
'@changesets/read@0.6.0':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/git': 3.0.0
'@changesets/logger': 0.1.0
'@changesets/parse': 0.4.0
'@changesets/types': 6.0.0
chalk: 2.4.2
fs-extra: 7.0.1
p-filter: 2.1.0
'@changesets/should-skip-package@0.1.0':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/types': 6.0.0
'@manypkg/get-packages': 1.1.3
'@changesets/types@4.1.0': {}
'@changesets/types@5.2.1': {}
'@changesets/types@6.0.0': {}
'@changesets/write@0.3.1':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/types': 6.0.0
fs-extra: 7.0.1
human-id: 1.0.2
prettier: 2.8.8
'@livekit/changesets-changelog-github@0.0.4':
dependencies:
'@changesets/get-github-info': 0.5.2
'@changesets/types': 5.2.1
dotenv: 8.6.0
transitivePeerDependencies:
- encoding
'@manypkg/find-root@1.1.0':
dependencies:
'@babel/runtime': 7.24.8
'@types/node': 12.20.55
find-up: 4.1.0
fs-extra: 8.1.0
'@manypkg/get-packages@1.1.3':
dependencies:
'@babel/runtime': 7.24.8
'@changesets/types': 4.1.0
'@manypkg/find-root': 1.1.0
fs-extra: 8.1.0
globby: 11.1.0
read-yaml-file: 1.1.0
'@nodelib/fs.scandir@2.1.5':
dependencies:
'@nodelib/fs.stat': 2.0.5
run-parallel: 1.2.0
'@nodelib/fs.stat@2.0.5': {}
'@nodelib/fs.walk@1.2.8':
dependencies:
'@nodelib/fs.scandir': 2.1.5
fastq: 1.17.1
'@types/node@12.20.55': {}
'@types/semver@7.5.8': {}
ansi-colors@4.1.3: {}
ansi-regex@5.0.1: {}
ansi-styles@3.2.1:
dependencies:
color-convert: 1.9.3
argparse@1.0.10:
dependencies:
sprintf-js: 1.0.3
array-union@2.1.0: {}
better-path-resolve@1.0.0:
dependencies:
is-windows: 1.0.2
braces@3.0.3:
dependencies:
fill-range: 7.1.1
chalk@2.4.2:
dependencies:
ansi-styles: 3.2.1
escape-string-regexp: 1.0.5
supports-color: 5.5.0
chardet@0.7.0: {}
ci-info@3.9.0: {}
color-convert@1.9.3:
dependencies:
color-name: 1.1.3
color-name@1.1.3: {}
cross-spawn@5.1.0:
dependencies:
lru-cache: 4.1.5
shebang-command: 1.2.0
which: 1.3.1
dataloader@1.4.0: {}
detect-indent@6.1.0: {}
dir-glob@3.0.1:
dependencies:
path-type: 4.0.0
dotenv@8.6.0: {}
enquirer@2.4.1:
dependencies:
ansi-colors: 4.1.3
strip-ansi: 6.0.1
escape-string-regexp@1.0.5: {}
esprima@4.0.1: {}
extendable-error@0.1.7: {}
external-editor@3.1.0:
dependencies:
chardet: 0.7.0
iconv-lite: 0.4.24
tmp: 0.0.33
fast-glob@3.3.2:
dependencies:
'@nodelib/fs.stat': 2.0.5
'@nodelib/fs.walk': 1.2.8
glob-parent: 5.1.2
merge2: 1.4.1
micromatch: 4.0.7
fastq@1.17.1:
dependencies:
reusify: 1.0.4
fill-range@7.1.1:
dependencies:
to-regex-range: 5.0.1
find-up@4.1.0:
dependencies:
locate-path: 5.0.0
path-exists: 4.0.0
find-up@5.0.0:
dependencies:
locate-path: 6.0.0
path-exists: 4.0.0
find-yarn-workspace-root2@1.2.16:
dependencies:
micromatch: 4.0.7
pkg-dir: 4.2.0
fs-extra@7.0.1:
dependencies:
graceful-fs: 4.2.11
jsonfile: 4.0.0
universalify: 0.1.2
fs-extra@8.1.0:
dependencies:
graceful-fs: 4.2.11
jsonfile: 4.0.0
universalify: 0.1.2
glob-parent@5.1.2:
dependencies:
is-glob: 4.0.3
globby@11.1.0:
dependencies:
array-union: 2.1.0
dir-glob: 3.0.1
fast-glob: 3.3.2
ignore: 5.3.1
merge2: 1.4.1
slash: 3.0.0
graceful-fs@4.2.11: {}
has-flag@3.0.0: {}
human-id@1.0.2: {}
iconv-lite@0.4.24:
dependencies:
safer-buffer: 2.1.2
ignore@5.3.1: {}
is-extglob@2.1.1: {}
is-glob@4.0.3:
dependencies:
is-extglob: 2.1.1
is-number@7.0.0: {}
is-subdir@1.2.0:
dependencies:
better-path-resolve: 1.0.0
is-windows@1.0.2: {}
isexe@2.0.0: {}
js-yaml@3.14.1:
dependencies:
argparse: 1.0.10
esprima: 4.0.1
jsonfile@4.0.0:
optionalDependencies:
graceful-fs: 4.2.11
load-yaml-file@0.2.0:
dependencies:
graceful-fs: 4.2.11
js-yaml: 3.14.1
pify: 4.0.1
strip-bom: 3.0.0
locate-path@5.0.0:
dependencies:
p-locate: 4.1.0
locate-path@6.0.0:
dependencies:
p-locate: 5.0.0
lodash.startcase@4.4.0: {}
lru-cache@4.1.5:
dependencies:
pseudomap: 1.0.2
yallist: 2.1.2
merge2@1.4.1: {}
micromatch@4.0.7:
dependencies:
braces: 3.0.3
picomatch: 2.3.1
mri@1.2.0: {}
node-fetch@2.7.0:
dependencies:
whatwg-url: 5.0.0
os-tmpdir@1.0.2: {}
outdent@0.5.0: {}
p-filter@2.1.0:
dependencies:
p-map: 2.1.0
p-limit@2.3.0:
dependencies:
p-try: 2.2.0
p-limit@3.1.0:
dependencies:
yocto-queue: 0.1.0
p-locate@4.1.0:
dependencies:
p-limit: 2.3.0
p-locate@5.0.0:
dependencies:
p-limit: 3.1.0
p-map@2.1.0: {}
p-try@2.2.0: {}
path-exists@4.0.0: {}
path-type@4.0.0: {}
picomatch@2.3.1: {}
pify@4.0.1: {}
pkg-dir@4.2.0:
dependencies:
find-up: 4.1.0
preferred-pm@3.1.4:
dependencies:
find-up: 5.0.0
find-yarn-workspace-root2: 1.2.16
path-exists: 4.0.0
which-pm: 2.2.0
prettier@2.8.8: {}
pseudomap@1.0.2: {}
queue-microtask@1.2.3: {}
read-yaml-file@1.1.0:
dependencies:
graceful-fs: 4.2.11
js-yaml: 3.14.1
pify: 4.0.1
strip-bom: 3.0.0
regenerator-runtime@0.14.1: {}
resolve-from@5.0.0: {}
reusify@1.0.4: {}
run-parallel@1.2.0:
dependencies:
queue-microtask: 1.2.3
safer-buffer@2.1.2: {}
semver@7.6.3: {}
shebang-command@1.2.0:
dependencies:
shebang-regex: 1.0.0
shebang-regex@1.0.0: {}
signal-exit@3.0.7: {}
slash@3.0.0: {}
spawndamnit@2.0.0:
dependencies:
cross-spawn: 5.1.0
signal-exit: 3.0.7
sprintf-js@1.0.3: {}
strip-ansi@6.0.1:
dependencies:
ansi-regex: 5.0.1
strip-bom@3.0.0: {}
supports-color@5.5.0:
dependencies:
has-flag: 3.0.0
term-size@2.2.1: {}
tmp@0.0.33:
dependencies:
os-tmpdir: 1.0.2
to-regex-range@5.0.1:
dependencies:
is-number: 7.0.0
tr46@0.0.3: {}
universalify@0.1.2: {}
webidl-conversions@3.0.1: {}
whatwg-url@5.0.0:
dependencies:
tr46: 0.0.3
webidl-conversions: 3.0.1
which-pm@2.2.0:
dependencies:
load-yaml-file: 0.2.0
path-exists: 4.0.0
which@1.3.1:
dependencies:
isexe: 2.0.0
yallist@2.1.2: {}
yocto-queue@0.1.0: {}
packages:
- "livekit-agents"
- "livekit-plugins/*"
line-length = 88
indent-width = 4
target-version = "py39"
[lint]
extend-select = ["I"]
[lint.pydocstyle]
convention = "numpy"
[format]
docstring-code-format = true
import dataclasses
import logging
import pytest
from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, utils
from livekit.agents.cli import log
TEST_CONNECT_OPTIONS = dataclasses.replace(
DEFAULT_API_CONNECT_OPTIONS, retry_interval=0.0
)
@pytest.fixture
def job_process(event_loop):
utils.http_context._new_session_ctx()
yield
event_loop.run_until_complete(utils.http_context._close_http_ctx())
@pytest.fixture(autouse=True)
def configure_test():
log._silence_noisy_loggers()
@pytest.fixture()
def logger():
logger = logging.getLogger("livekit.tests")
logger.setLevel(logging.DEBUG)
return logger
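# TEST_CONNECT_OPTIONS replaces the default retry interval with 0.0 so tests
# that exercise retry/failure paths don't sleep between attempts; the intent
# (an assumption based on its name and value) is to pass it as `conn_options=`
# to STT/TTS/LLM calls, see the fakes defined below.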
from __future__ import annotations
import asyncio
from livekit.agents import NOT_GIVEN, NotGivenOr, utils
from livekit.agents.stt import (
STT,
RecognizeStream,
SpeechData,
SpeechEvent,
SpeechEventType,
STTCapabilities,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from livekit.agents.utils.audio import AudioBuffer
class RecognizeSentinel: ...
class FakeSTT(STT):
def __init__(
self,
*,
fake_exception: Exception | None = None,
fake_transcript: str | None = None,
fake_timeout: float | None = None,
) -> None:
super().__init__(
capabilities=STTCapabilities(streaming=True, interim_results=False),
)
self._fake_exception = fake_exception
self._fake_transcript = fake_transcript
self._fake_timeout = fake_timeout
self._recognize_ch = utils.aio.Chan[RecognizeSentinel]()
self._stream_ch = utils.aio.Chan[FakeRecognizeStream]()
def update_options(
self,
*,
fake_exception: NotGivenOr[Exception | None] = NOT_GIVEN,
fake_transcript: NotGivenOr[str | None] = NOT_GIVEN,
fake_timeout: NotGivenOr[float | None] = NOT_GIVEN,
) -> None:
if utils.is_given(fake_exception):
self._fake_exception = fake_exception
if utils.is_given(fake_transcript):
self._fake_transcript = fake_transcript
if utils.is_given(fake_timeout):
self._fake_timeout = fake_timeout
@property
def recognize_ch(self) -> utils.aio.ChanReceiver[RecognizeSentinel]:
return self._recognize_ch
@property
def stream_ch(self) -> utils.aio.ChanReceiver["FakeRecognizeStream"]:
return self._stream_ch
async def _recognize_impl(
self,
buffer: AudioBuffer,
*,
language: str | None,
conn_options: APIConnectOptions,
) -> SpeechEvent:
if self._fake_timeout is not None:
await asyncio.sleep(self._fake_timeout)
if self._fake_exception is not None:
raise self._fake_exception
return SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
SpeechData(text=self._fake_transcript or "", language=language or "")
],
)
async def recognize(
self,
buffer: AudioBuffer,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
):
self._recognize_ch.send_nowait(RecognizeSentinel())
return await super().recognize(
buffer, language=language, conn_options=conn_options
)
def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "FakeRecognizeStream":
stream = FakeRecognizeStream(
stt=self,
conn_options=conn_options,
)
self._stream_ch.send_nowait(stream)
return stream
class FakeRecognizeStream(RecognizeStream):
def __init__(
self,
*,
stt: STT,
conn_options: APIConnectOptions,
):
super().__init__(stt=stt, conn_options=conn_options)
self._attempt = 0
@property
def attempt(self) -> int:
return self._attempt
def send_fake_transcript(self, transcript: str) -> None:
self._event_ch.send_nowait(
SpeechEvent(
type=SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[SpeechData(text=transcript, language="")],
)
)
async def _run(self) -> None:
self._attempt += 1
assert isinstance(self._stt, FakeSTT)
if self._stt._fake_timeout is not None:
await asyncio.sleep(self._stt._fake_timeout)
if self._stt._fake_transcript is not None:
self.send_fake_transcript(self._stt._fake_transcript)
async for _ in self._input_ch:
pass
if self._stt._fake_exception is not None:
raise self._stt._fake_exception
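# Usage sketch (not part of the original fake module): how a test could drive
# the FakeSTT defined above. Run it under pytest-asyncio or asyncio.run().
async def _example_fake_stt_usage() -> None:
    fake = FakeSTT(fake_transcript="hello world")
    # an empty frame list is an acceptable (silent) AudioBuffer for the fake
    event = await fake.recognize(buffer=[])
    assert event.alternatives[0].text == "hello world"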
from __future__ import annotations
import asyncio
from livekit import rtc
from livekit.agents import NOT_GIVEN, NotGivenOr, utils
from livekit.agents.tts import (
TTS,
ChunkedStream,
SynthesizedAudio,
SynthesizeStream,
TTSCapabilities,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
class FakeTTS(TTS):
def __init__(
self,
*,
sample_rate: int = 24000,
num_channels: int = 1,
fake_timeout: float | None = None,
fake_audio_duration: float | None = None,
fake_exception: Exception | None = None,
) -> None:
super().__init__(
capabilities=TTSCapabilities(streaming=True),
sample_rate=sample_rate,
num_channels=num_channels,
)
self._fake_timeout = fake_timeout
self._fake_audio_duration = fake_audio_duration
self._fake_exception = fake_exception
self._synthesize_ch = utils.aio.Chan[FakeChunkedStream]()
self._stream_ch = utils.aio.Chan[FakeSynthesizeStream]()
def update_options(
self,
*,
fake_timeout: NotGivenOr[float | None] = NOT_GIVEN,
fake_audio_duration: NotGivenOr[float | None] = NOT_GIVEN,
fake_exception: NotGivenOr[Exception | None] = NOT_GIVEN,
) -> None:
if utils.is_given(fake_timeout):
self._fake_timeout = fake_timeout
if utils.is_given(fake_audio_duration):
self._fake_audio_duration = fake_audio_duration
if utils.is_given(fake_exception):
self._fake_exception = fake_exception
@property
def synthesize_ch(self) -> utils.aio.ChanReceiver["FakeChunkedStream"]:
return self._synthesize_ch
@property
def stream_ch(self) -> utils.aio.ChanReceiver["FakeSynthesizeStream"]:
return self._stream_ch
def synthesize(
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "FakeChunkedStream":
stream = FakeChunkedStream(tts=self, input_text=text, conn_options=conn_options)
self._synthesize_ch.send_nowait(stream)
return stream
def stream(
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> "FakeSynthesizeStream":
stream = FakeSynthesizeStream(
tts=self,
conn_options=conn_options,
)
self._stream_ch.send_nowait(stream)
return stream
class FakeChunkedStream(ChunkedStream):
def __init__(
self, *, tts: FakeTTS, input_text: str, conn_options: APIConnectOptions
) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._attempt = 0
@property
def attempt(self) -> int:
return self._attempt
async def _run(self) -> None:
self._attempt += 1
assert isinstance(self._tts, FakeTTS)
request_id = utils.shortuuid("fake_tts_")
if self._tts._fake_timeout is not None:
await asyncio.sleep(self._tts._fake_timeout)
if self._tts._fake_audio_duration is not None:
pushed_samples = 0
max_samples = (
int(self._tts.sample_rate * self._tts._fake_audio_duration + 0.5)
* self._tts.num_channels
)
while pushed_samples < max_samples:
num_samples = min(
self._tts.sample_rate // 100, max_samples - pushed_samples
)
self._event_ch.send_nowait(
SynthesizedAudio(
request_id=request_id,
frame=rtc.AudioFrame(
data=b"\x00\x00" * num_samples,
samples_per_channel=num_samples // self._tts.num_channels,
sample_rate=self._tts.sample_rate,
num_channels=self._tts.num_channels,
),
)
)
pushed_samples += num_samples
if self._tts._fake_exception is not None:
raise self._tts._fake_exception
class FakeSynthesizeStream(SynthesizeStream):
def __init__(
self,
*,
tts: TTS,
conn_options: APIConnectOptions,
):
super().__init__(tts=tts, conn_options=conn_options)
self._attempt = 0
@property
def attempt(self) -> int:
return self._attempt
async def _run(self) -> None:
self._attempt += 1
assert isinstance(self._tts, FakeTTS)
if self._tts._fake_timeout is not None:
await asyncio.sleep(self._tts._fake_timeout)
has_data = False
async for data in self._input_ch:
if isinstance(data, str):
has_data = True
continue
elif isinstance(data, SynthesizeStream._FlushSentinel) and not has_data:
continue
has_data = False
if self._tts._fake_audio_duration is None:
continue
request_id = utils.shortuuid("fake_tts_")
segment_id = utils.shortuuid("fake_segment_")
pushed_samples = 0
max_samples = (
int(self._tts.sample_rate * self._tts._fake_audio_duration + 0.5)
* self._tts.num_channels
)
while pushed_samples < max_samples:
num_samples = min(
self._tts.sample_rate // 100, max_samples - pushed_samples
)
self._event_ch.send_nowait(
SynthesizedAudio(
request_id=request_id,
segment_id=segment_id,
is_final=(pushed_samples + num_samples >= max_samples),
frame=rtc.AudioFrame(
data=b"\x00\x00" * num_samples,
samples_per_channel=num_samples // self._tts.num_channels,
sample_rate=self._tts.sample_rate,
num_channels=self._tts.num_channels,
),
)
)
pushed_samples += num_samples
if self._tts._fake_exception is not None:
raise self._tts._fake_exception
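# Usage sketch (not part of the original fake module): collect the frames the
# FakeTTS defined above produces and check the synthesized length.
async def _example_fake_tts_usage() -> None:
    fake = FakeTTS(sample_rate=24000, num_channels=1, fake_audio_duration=0.5)
    samples = 0
    async for audio in fake.synthesize("hello"):
        samples += audio.frame.samples_per_channel
    # 0.5s at 24kHz mono, rounded the same way FakeChunkedStream rounds
    assert samples == int(24000 * 0.5 + 0.5)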
version https://git-lfs.github.com/spec/v1
oid sha256:06087a10c1864e6644d16a6e508852e678ad1a96e4d99bd8056bb7f60ab765cc
size 1048576
version https://git-lfs.github.com/spec/v1
oid sha256:a420326dbf4f37675bf14ae260fff776aee428ced887fc97e0936c36b96589f6
size 559968
The people who are crazy enough to think they can change the world are the ones who do.
The reasonable man adapts himself to the world; the unreasonable one persists in trying to adapt the world to himself. Therefore all progress depends on the unreasonable man.
Never doubt that a small group of thoughtful, committed citizens can change the world; indeed, it's the only thing that ever has.
Do not go where the path may lead, go instead where there is no path and leave a trail.
It could not have been ten seconds, and yet it seemed a long time that their hands were clasped together.
He had time to learn every detail of her hand.
He explored the long fingers, the shapely nails, the work-hardened palm with its row of callouses, the smooth flesh under the wrist.
Merely from feeling it he would have known it by sight.
In the same instant it occurred to him that he did not know what colour the girl's eyes were.
They were probably brown, but people with dark hair sometimes had blue eyes.
To turn his head and look at her would have been inconceivable folly.
With hands locked together, invisible among the press of bodies,
they stared steadily in front of them, and instead of the eyes of the girl, the eyes of the aged prisoner gazed mournfully at Winston out of nests of hair.
[pytest]
asyncio_mode = auto
timeout = 120
asyncio_default_fixture_loop_scope = "function"
log_cli = 1
log_cli_level = INFO
jiwer==3.0.4
import asyncio
from livekit.agents.utils import aio
async def test_channel():
tx = rx = aio.Chan[int]()
sum = 0
async def test_task():
nonlocal sum
while True:
try:
sum = sum + await rx.recv()
except aio.ChanClosed:
break
t = asyncio.create_task(test_task())
for _ in range(10):
await tx.send(1)
tx.close()
await t
assert sum == 10
async def test_interval():
interval = aio.interval(0.1)
_ = asyncio.get_event_loop()
async for i in interval:
if i == 3:
break
async def test_sleep():
await aio.sleep(0)
sleep = aio.sleep(5)
sleep.reset(0.1)
await sleep
import sys
from inspect import _empty
from typing import List, Optional, Union
import pytest
from livekit.agents.llm import FunctionArgInfo, FunctionInfo
from livekit.agents.llm.function_context import _is_optional_type
from livekit.plugins.openai import _oai_api
def test_typing():
assert _is_optional_type(Optional[int]) == (True, int)
assert _is_optional_type(Union[str, None]) == (True, str)
if sys.version_info >= (3, 10):
assert _is_optional_type(float | None) == (True, float)
assert _is_optional_type(Union[str, int]) == (False, None)
@pytest.mark.parametrize(
("arg_typ", "oai_type"),
[
pytest.param(int, "number", id="int"),
pytest.param(Optional[int], "number", id="optional[int]"),
pytest.param(Union[None, int], "number", id="union[none, int]"),
pytest.param(Union[str, None], "string", id="union[str, none]"),
pytest.param(List[int], "array", id="list[int]"),
pytest.param(Optional[List[int]], "array", id="optional[list[int]]"),
],
)
def test_description_building(arg_typ: type, oai_type: str):
fi = FunctionInfo(
name="foo",
description="foo",
auto_retry=False,
callable=lambda: None,
arguments={
"arg": FunctionArgInfo(
name="foo",
description="foo",
type=arg_typ,
default=_empty,
choices=(),
),
},
)
assert (
_oai_api.build_oai_function_description(fi)["function"]["parameters"][
"properties"
]["foo"]["type"]
== oai_type
)
import time
import pytest
from livekit.agents.utils import ConnectionPool
# A simple dummy connection object.
class DummyConnection:
def __init__(self, id):
self.id = id
def __repr__(self):
return f"DummyConnection({self.id})"
# Factory to produce a dummy async connect callback that returns unique DummyConnection objects.
def dummy_connect_factory():
counter = 0
async def dummy_connect():
nonlocal counter
counter += 1
return DummyConnection(counter)
return dummy_connect
@pytest.mark.asyncio
async def test_get_reuses_connection():
"""
Test that when a connection is returned to the pool via put(),
the subsequent call to get() reuses the same connection if it hasn't expired.
"""
dummy_connect = dummy_connect_factory()
pool = ConnectionPool(max_session_duration=60, connect_cb=dummy_connect)
conn1 = await pool.get()
# Return the connection to the pool
pool.put(conn1)
async with pool.connection() as conn:
assert conn is conn1, "Expected conn to be the same connection as conn1"
conn2 = await pool.get()
assert conn1 is conn2, (
"Expected the same connection to be reused when it hasn't expired."
)
@pytest.mark.asyncio
async def test_get_creates_new_connection_when_none_available():
"""
Test that get() creates a new connection when there are no available connections.
"""
dummy_connect = dummy_connect_factory()
pool = ConnectionPool(max_session_duration=60, connect_cb=dummy_connect)
conn1 = await pool.get()
# Not putting conn1 back means the available pool is empty,
# so calling get() again should create a new connection.
conn2 = await pool.get()
assert conn1 is not conn2, (
"Expected a new connection when no available connection exists."
)
@pytest.mark.asyncio
async def test_remove_connection():
"""
Test that after removing a connection, the connection is not reused.
"""
dummy_connect = dummy_connect_factory()
pool = ConnectionPool(max_session_duration=60, connect_cb=dummy_connect)
conn = await pool.get()
pool.put(conn)
# Remove the connection so it is no longer tracked by the pool.
pool.remove(conn)
# Even if we try to put it back, it won't be added because it's not tracked anymore.
pool.put(conn)
new_conn = await pool.get()
assert new_conn is not conn, "Expected a removed connection to not be reused."
@pytest.mark.asyncio
async def test_get_expired():
"""
Test that get() returns a new connection if the previous connection has expired.
"""
# Use a short max duration to simulate expiration.
dummy_connect = dummy_connect_factory()
pool = ConnectionPool(max_session_duration=1, connect_cb=dummy_connect)
conn = await pool.get()
pool.put(conn)
# Artificially set the connection's timestamp in the past to simulate expiration.
pool._connections[conn] = (
time.time() - 2
) # 2 seconds ago (max_session_duration is 1)
conn2 = await pool.get()
assert conn2 is not conn, "Expected a new connection to be returned."
import enum
from inspect import _empty
from typing import Annotated, List, Optional
import pytest
from livekit.agents import llm
from livekit.plugins.openai import _oai_api
def test_func_basic():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(name="test_function", description="A simple test function")
def test_fn(
self, param: Annotated[str, llm.TypeInfo(description="A string parameter")]
):
pass
fnc_ctx = TestFunctionContext()
assert "test_function" in fnc_ctx.ai_functions, (
"Function should be registered in ai_functions"
)
fnc_info = fnc_ctx.ai_functions["test_function"]
build_info = _oai_api.build_oai_function_description(fnc_info)
assert fnc_info.name == build_info["function"]["name"]
assert fnc_info.description == build_info["function"]["description"]
assert not fnc_info.auto_retry
assert "param" in fnc_info.arguments
assert "param" in build_info["function"]["parameters"]["properties"]
assert "param" in build_info["function"]["parameters"]["required"]
arg_info = fnc_info.arguments["param"]
build_arg_info = build_info["function"]["parameters"]["properties"]["param"]
assert arg_info.name == "param"
assert arg_info.description == "A string parameter"
assert arg_info.type is str
assert arg_info.default is _empty
assert arg_info.choices == ()
assert build_arg_info["description"] == arg_info.description
assert build_arg_info["type"] == "string"
def test_func_duplicate():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(
name="duplicate_function", description="A simple test function"
)
def fn1(self):
pass
@llm.ai_callable(
name="duplicate_function", description="A simple test function"
)
def fn2(self):
pass
with pytest.raises(
ValueError, match="duplicate ai_callable name: duplicate_function"
):
TestFunctionContext()
def test_func_with_docstring():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable()
def test_fn(self):
"""A simple test function"""
pass
fnc_ctx = TestFunctionContext()
assert "test_fn" in fnc_ctx.ai_functions, (
"Function should be registered in ai_functions"
)
assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function"
def test_func_with_optional_parameter():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(
name="optional_function", description="Function with optional parameter"
)
def optional_fn(
self,
param: Annotated[
Optional[int], llm.TypeInfo(description="An optional integer parameter")
] = None,
param2: Optional[List[str]] = None,
param3: str = "A string",
):
pass
fnc_ctx = TestFunctionContext()
assert "optional_function" in fnc_ctx.ai_functions, (
"Function should be registered in ai_functions"
)
fnc_info = fnc_ctx.ai_functions["optional_function"]
build_info = _oai_api.build_oai_function_description(fnc_info)
print(build_info)
assert fnc_info.name == build_info["function"]["name"]
assert fnc_info.description == build_info["function"]["description"]
assert "param" in fnc_info.arguments
assert "param2" in fnc_info.arguments
assert "param3" in fnc_info.arguments
assert "param" in build_info["function"]["parameters"]["properties"]
assert "param2" in build_info["function"]["parameters"]["properties"]
assert "param3" in build_info["function"]["parameters"]["properties"]
assert "param" not in build_info["function"]["parameters"]["required"]
assert "param2" not in build_info["function"]["parameters"]["required"]
assert "param3" not in build_info["function"]["parameters"]["required"]
# Check 'param'
arg_info = fnc_info.arguments["param"]
build_arg_info = build_info["function"]["parameters"]["properties"]["param"]
assert arg_info.name == "param"
assert arg_info.description == "An optional integer parameter"
assert arg_info.type == Optional[int]
assert arg_info.default is None
assert arg_info.choices == ()
assert build_arg_info["description"] == arg_info.description
assert build_arg_info["type"] == "number"
# Check 'param2'
arg_info = fnc_info.arguments["param2"]
build_arg_info = build_info["function"]["parameters"]["properties"]["param2"]
assert arg_info.name == "param2"
assert arg_info.description == ""
assert arg_info.type == Optional[List[str]]
assert arg_info.default is None
assert arg_info.choices == ()
assert build_arg_info["type"] == "array"
assert build_arg_info["items"]["type"] == "string"
# check 'param3'
arg_info = fnc_info.arguments["param3"]
build_arg_info = build_info["function"]["parameters"]["properties"]["param3"]
assert arg_info.name == "param3"
assert arg_info.description == ""
assert arg_info.type is str
assert arg_info.default == "A string"
assert arg_info.choices == ()
assert build_arg_info["type"] == "string"
def test_func_with_list_parameter():
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(
name="list_function", description="Function with list parameter"
)
def list_fn(
self,
items: Annotated[List[str], llm.TypeInfo(description="A list of strings")],
):
pass
fnc_ctx = TestFunctionContext()
assert "list_function" in fnc_ctx.ai_functions, (
"Function should be registered in ai_functions"
)
fnc_info = fnc_ctx.ai_functions["list_function"]
build_info = _oai_api.build_oai_function_description(fnc_info)
assert fnc_info.name == build_info["function"]["name"]
assert fnc_info.description == build_info["function"]["description"]
assert not fnc_info.auto_retry
assert "items" in fnc_info.arguments
assert "items" in build_info["function"]["parameters"]["properties"]
assert "items" in build_info["function"]["parameters"]["required"]
arg_info = fnc_info.arguments["items"]
build_arg_info = build_info["function"]["parameters"]["properties"]["items"]
assert arg_info.name == "items"
assert arg_info.description == "A list of strings"
assert arg_info.type is List[str]
assert arg_info.default is _empty
assert arg_info.choices == ()
assert build_arg_info["description"] == arg_info.description
assert build_arg_info["type"] == "array"
assert build_arg_info["items"]["type"] == "string"
def test_func_with_enum_parameter():
class Status(enum.Enum):
ACTIVE = "active"
INACTIVE = "inactive"
PENDING = "pending"
class TestFunctionContext(llm.FunctionContext):
@llm.ai_callable(
name="enum_function", description="Function with enum parameter"
)
def enum_fn(
self,
status: Annotated[Status, llm.TypeInfo(description="Status of the entity")],
):
pass
fnc_ctx = TestFunctionContext()
assert "enum_function" in fnc_ctx.ai_functions, (
"Function should be registered in ai_functions"
)
fnc_info = fnc_ctx.ai_functions["enum_function"]
build_info = _oai_api.build_oai_function_description(fnc_info)
assert fnc_info.name == build_info["function"]["name"]
assert fnc_info.description == build_info["function"]["description"]
assert not fnc_info.auto_retry
assert "status" in fnc_info.arguments
assert "status" in build_info["function"]["parameters"]["properties"]
assert "status" in build_info["function"]["parameters"]["required"]
arg_info = fnc_info.arguments["status"]
build_arg_info = build_info["function"]["parameters"]["properties"]["status"]
assert arg_info.name == "status"
assert arg_info.description == "Status of the entity"
assert arg_info.type is str # Enum values are converted to their underlying type
assert arg_info.default is _empty
assert arg_info.choices == ("active", "inactive", "pending")
assert build_arg_info["description"] == arg_info.description
assert build_arg_info["type"] == "string"
assert build_arg_info["enum"] == arg_info.choices
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
import aiohttp
import pytest
from livekit.agents.stt import SpeechEventType
from livekit.agents.utils.codecs import AudioStreamDecoder, StreamBuffer
from livekit.plugins import deepgram
from .utils import wer
TEST_AUDIO_FILEPATH = os.path.join(os.path.dirname(__file__), "change-sophie.opus")
@pytest.mark.asyncio
async def test_decode_and_transcribe():
# Skip if test file doesn't exist
if not os.path.exists(TEST_AUDIO_FILEPATH):
pytest.skip(f"Test file not found: {TEST_AUDIO_FILEPATH}")
decoder = AudioStreamDecoder()
with open(TEST_AUDIO_FILEPATH, "rb") as f:
opus_data = f.read()
decoder.push(opus_data)
decoder.end_input()
session = aiohttp.ClientSession()
stt = deepgram.STT(http_session=session)
stream = stt.stream()
# Push frames to STT
async for frame in decoder:
stream.push_frame(frame)
# Mark end of input
stream.end_input()
# Collect results
final_text = ""
async for event in stream:
if event.type == SpeechEventType.FINAL_TRANSCRIPT:
if event.alternatives:
if final_text:
final_text += " "
final_text += event.alternatives[0].text
await decoder.aclose()
await stream.aclose()
await session.close()
# Verify the transcription
expected_text = "the people that are crazy enough to think they can change the world are the ones who do"
assert wer(final_text, expected_text) < 0.2
def test_stream_buffer():
buffer = StreamBuffer()
data_chunks = [b"hello", b"world", b"test", b"data"]
received_data = bytearray()
write_completed = threading.Event()
def writer():
for chunk in data_chunks:
buffer.write(chunk)
time.sleep(0.01) # Simulate some processing time
buffer.end_input()
write_completed.set()
def reader():
while True:
data = buffer.read(4) # Read in small chunks
if not data: # EOF
break
received_data.extend(data)
# Run writer and reader in separate threads
with ThreadPoolExecutor(max_workers=2) as executor:
reader_future = executor.submit(reader)
writer_future = executor.submit(writer)
# Wait for both threads to complete
writer_future.result()
reader_future.result()
# Verify that all data was received correctly
expected_data = b"".join(data_chunks)
assert bytes(received_data) == expected_data
def test_stream_buffer_large_chunks():
import hashlib
buffer = StreamBuffer()
large_chunk = os.urandom(1024 * 1024) # 1MB of random bytes
num_chunks = 5
total_size = 0
write_completed = threading.Event()
input_hasher = hashlib.sha256()
def writer():
nonlocal total_size
for _ in range(num_chunks):
buffer.write(large_chunk)
total_size += len(large_chunk)
input_hasher.update(large_chunk)
buffer.end_input()
write_completed.set()
received_size = 0
output_hasher = hashlib.sha256()
def reader():
nonlocal received_size
# allow writer to start first
time.sleep(1)
while True:
chunk = buffer.read(8192) # Read in 8KB chunks
if not chunk:
break
received_size += len(chunk)
output_hasher.update(chunk)
# Run writer and reader in separate threads
with ThreadPoolExecutor(max_workers=2) as executor:
reader_future = executor.submit(reader)
writer_future = executor.submit(writer)
# Wait for both threads to complete
writer_future.result()
reader_future.result()
assert received_size == total_size
assert total_size == num_chunks * len(large_chunk)
assert input_hasher.hexdigest() == output_hasher.hexdigest()
def test_stream_buffer_early_close():
buffer = StreamBuffer()
# Write some data
buffer.write(b"test data")
# Close the buffer
buffer.close()
# Reading from closed buffer should return empty bytes
assert buffer.read() == b""
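# The `wer` helper imported at the top of this module lives in a tests utility
# file that is not included here. A plausible sketch of it, using the pinned
# jiwer dependency (named differently here to avoid shadowing the real import):
import jiwer


def _wer_sketch(hypothesis: str, expected: str) -> float:
    transform = jiwer.Compose(
        [
            jiwer.ToLowerCase(),
            jiwer.RemovePunctuation(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    )
    # word error rate between the recognized text and the expected transcript
    return jiwer.wer(
        expected,
        hypothesis,
        reference_transform=transform,
        hypothesis_transform=transform,
    )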
from __future__ import annotations
import asyncio
import ctypes
import io
import multiprocessing as mp
import socket
import time
import uuid
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import ClassVar
import psutil
from livekit.agents import JobContext, JobProcess, ipc, job, utils
from livekit.protocol import agent
@dataclass
class EmptyMessage:
MSG_ID: ClassVar[int] = 0
@dataclass
class SomeDataMessage:
MSG_ID: ClassVar[int] = 1
string: str = ""
number: int = 0
double: float = 0.0
data: bytes = b""
def write(self, b: io.BytesIO) -> None:
ipc.channel.write_string(b, self.string)
ipc.channel.write_int(b, self.number)
ipc.channel.write_double(b, self.double)
ipc.channel.write_bytes(b, self.data)
def read(self, b: io.BytesIO) -> None:
self.string = ipc.channel.read_string(b)
self.number = ipc.channel.read_int(b)
self.double = ipc.channel.read_double(b)
self.data = ipc.channel.read_bytes(b)
IPC_MESSAGES = {
EmptyMessage.MSG_ID: EmptyMessage,
SomeDataMessage.MSG_ID: SomeDataMessage,
}
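# Round-trip sketch (an addition, not one of the original tests): serialize a
# SomeDataMessage into a BytesIO with the helpers above and read it back.
def _example_message_roundtrip() -> None:
    msg = SomeDataMessage(string="hi", number=7, double=1.5, data=b"\x01\x02")
    buf = io.BytesIO()
    msg.write(buf)
    buf.seek(0)
    decoded = SomeDataMessage()
    decoded.read(buf)
    assert decoded == msg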
def _echo_main(mp_cch):
async def _pong():
cch = await utils.aio.duplex_unix._AsyncDuplex.open(mp_cch)
while True:
try:
msg = await ipc.channel.arecv_message(cch, IPC_MESSAGES)
await ipc.channel.asend_message(cch, msg)
except utils.aio.duplex_unix.DuplexClosed:
print("_echo_main, duplex closed..")
break
asyncio.run(_pong())
async def test_async_channel():
mp_pch, mp_cch = socket.socketpair()
pch = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch)
proc = mp.get_context("spawn").Process(target=_echo_main, args=(mp_cch,))
proc.start()
mp_cch.close()
await ipc.channel.asend_message(pch, EmptyMessage())
assert await ipc.channel.arecv_message(pch, IPC_MESSAGES) == EmptyMessage()
await ipc.channel.asend_message(
pch, SomeDataMessage(string="hello", number=42, double=3.14, data=b"world")
)
assert await ipc.channel.arecv_message(pch, IPC_MESSAGES) == SomeDataMessage(
string="hello", number=42, double=3.14, data=b"world"
)
await pch.aclose()
await asyncio.sleep(0.5)
proc.terminate()
proc.join()
def test_sync_channel():
mp_pch, mp_cch = socket.socketpair()
pch = utils.aio.duplex_unix._Duplex.open(mp_pch)
proc = mp.get_context("spawn").Process(target=_echo_main, args=(mp_cch,))
proc.start()
mp_cch.close()
ipc.channel.send_message(pch, EmptyMessage())
assert ipc.channel.recv_message(pch, IPC_MESSAGES) == EmptyMessage()
ipc.channel.send_message(
pch, SomeDataMessage(string="hello", number=42, double=3.14, data=b"world")
)
assert ipc.channel.recv_message(pch, IPC_MESSAGES) == SomeDataMessage(
string="hello", number=42, double=3.14, data=b"world"
)
pch.close()
def _generate_fake_job() -> job.RunningJobInfo:
return job.RunningJobInfo(
job=agent.Job(
id="fake_job_" + str(uuid.uuid4().hex), type=agent.JobType.JT_ROOM
),
url="fake_url",
token="fake_token",
accept_arguments=job.JobAcceptArguments(name="", identity="", metadata=""),
worker_id="fake_id",
)
@dataclass
class _StartArgs:
initialize_counter: mp.Value
entrypoint_counter: mp.Value
shutdown_counter: mp.Value
initialize_simulate_work_time: float
entrypoint_simulate_work_time: float
shutdown_simulate_work_time: float
update_ev: mp.Condition
def _new_start_args(mp_ctx: BaseContext) -> _StartArgs:
return _StartArgs(
initialize_counter=mp_ctx.Value(ctypes.c_uint),
entrypoint_counter=mp_ctx.Value(ctypes.c_uint),
shutdown_counter=mp_ctx.Value(ctypes.c_uint),
initialize_simulate_work_time=0.0,
entrypoint_simulate_work_time=0.0,
shutdown_simulate_work_time=0.0,
update_ev=mp_ctx.Condition(),
)
def _initialize_proc(proc: JobProcess) -> None:
start_args: _StartArgs = proc.user_arguments
# incrementing a multiprocessing Value isn't atomic, so take its lock explicitly (the lock is reentrant by default)
with start_args.initialize_counter.get_lock():
start_args.initialize_counter.value += 1
time.sleep(start_args.initialize_simulate_work_time)
with start_args.update_ev:
start_args.update_ev.notify()
async def _job_entrypoint(job_ctx: JobContext) -> None:
start_args: _StartArgs = job_ctx.proc.user_arguments
async def _job_shutdown() -> None:
with start_args.shutdown_counter.get_lock():
start_args.shutdown_counter.value += 1
await asyncio.sleep(start_args.shutdown_simulate_work_time)
with start_args.update_ev:
start_args.update_ev.notify()
job_ctx.add_shutdown_callback(_job_shutdown)
with start_args.entrypoint_counter.get_lock():
start_args.entrypoint_counter.value += 1
await asyncio.sleep(start_args.entrypoint_simulate_work_time)
job_ctx.shutdown(
"calling shutdown inside the test to avoid a warning when neither shutdown nor connect is called."
)
with start_args.update_ev:
start_args.update_ev.notify()
async def _wait_for_elements(q: asyncio.Queue, num_elements: int) -> None:
for _ in range(num_elements):
await q.get()
async def test_proc_pool():
mp_ctx = mp.get_context("spawn")
loop = asyncio.get_running_loop()
num_idle_processes = 3
pool = ipc.proc_pool.ProcPool(
initialize_process_fnc=_initialize_proc,
job_entrypoint_fnc=_job_entrypoint,
num_idle_processes=num_idle_processes,
job_executor_type=job.JobExecutorType.PROCESS,
initialize_timeout=20.0,
close_timeout=20.0,
inference_executor=None,
memory_warn_mb=0,
memory_limit_mb=0,
mp_ctx=mp_ctx,
loop=loop,
)
start_args = _new_start_args(mp_ctx)
created_q = asyncio.Queue()
start_q = asyncio.Queue()
ready_q = asyncio.Queue()
close_q = asyncio.Queue()
pids = []
exitcodes = []
@pool.on("process_created")
def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor):
created_q.put_nowait(None)
proc.user_arguments = start_args
@pool.on("process_started")
def _process_started(proc: ipc.job_proc_executor.ProcJobExecutor):
start_q.put_nowait(None)
pids.append(proc.pid)
@pool.on("process_ready")
def _process_ready(proc: ipc.job_proc_executor.ProcJobExecutor):
ready_q.put_nowait(None)
@pool.on("process_closed")
def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor):
close_q.put_nowait(None)
exitcodes.append(proc.exitcode)
pool.start()
await _wait_for_elements(created_q, num_idle_processes)
await _wait_for_elements(start_q, num_idle_processes)
await _wait_for_elements(ready_q, num_idle_processes)
assert start_args.initialize_counter.value == num_idle_processes
jobs_to_start = 2
for _ in range(jobs_to_start):
await pool.launch_job(_generate_fake_job())
await _wait_for_elements(created_q, jobs_to_start)
await _wait_for_elements(start_q, jobs_to_start)
await _wait_for_elements(ready_q, jobs_to_start)
await pool.aclose()
assert start_args.entrypoint_counter.value == jobs_to_start
assert start_args.shutdown_counter.value == jobs_to_start
await _wait_for_elements(close_q, num_idle_processes + jobs_to_start)
# checking that a process no longer exists via its pid isn't fully reliable (pids can be recycled)
for pid in pids:
assert not psutil.pid_exists(pid)
for exitcode in exitcodes:
# this test expects a graceful shutdown; kill behaviour is covered by another test
assert exitcode == 0, f"process did not exit cleanly: {exitcode}"
async def test_slow_initialization():
mp_ctx = mp.get_context("spawn")
loop = asyncio.get_running_loop()
num_idle_processes = 2
pool = ipc.proc_pool.ProcPool(
job_executor_type=job.JobExecutorType.PROCESS,
initialize_process_fnc=_initialize_proc,
job_entrypoint_fnc=_job_entrypoint,
num_idle_processes=num_idle_processes,
initialize_timeout=1.0,
close_timeout=20.0,
inference_executor=None,
memory_warn_mb=0,
memory_limit_mb=0,
mp_ctx=mp_ctx,
loop=loop,
)
start_args = _new_start_args(mp_ctx)
start_args.initialize_simulate_work_time = 2.0
start_q = asyncio.Queue()
close_q = asyncio.Queue()
pids = []
exitcodes = []
@pool.on("process_created")
def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor):
proc.user_arguments = start_args
start_q.put_nowait(None)
@pool.on("process_closed")
def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor):
close_q.put_nowait(None)
pids.append(proc.pid)
exitcodes.append(proc.exitcode)
pool.start()
await _wait_for_elements(start_q, num_idle_processes)
await _wait_for_elements(close_q, num_idle_processes)
# after initialization failure, warmup should be retried
await _wait_for_elements(start_q, num_idle_processes)
await pool.aclose()
for pid in pids:
assert not psutil.pid_exists(pid)
for exitcode in exitcodes:
assert exitcode != 0, "process should have been killed"
def _create_proc(
*,
close_timeout: float,
mp_ctx: BaseContext,
initialize_timeout: float = 20.0,
) -> tuple[ipc.job_proc_executor.ProcJobExecutor, _StartArgs]:
start_args = _new_start_args(mp_ctx)
loop = asyncio.get_running_loop()
proc = ipc.job_proc_executor.ProcJobExecutor(
initialize_process_fnc=_initialize_proc,
job_entrypoint_fnc=_job_entrypoint,
initialize_timeout=initialize_timeout,
close_timeout=close_timeout,
memory_warn_mb=0,
memory_limit_mb=0,
ping_interval=2.5,
ping_timeout=10.0,
high_ping_threshold=1.0,
inference_executor=None,
mp_ctx=mp_ctx,
loop=loop,
)
proc.user_arguments = start_args
return proc, start_args
async def test_shutdown_no_job():
mp_ctx = mp.get_context("spawn")
proc, start_args = _create_proc(close_timeout=10.0, mp_ctx=mp_ctx)
await proc.start()
await proc.initialize()
await asyncio.sleep(1.0)
await proc.aclose()
assert proc.exitcode == 0
assert not proc.killed
assert start_args.shutdown_counter.value == 0, (
"shutdown_cb isn't called when there is no job"
)
async def test_job_slow_shutdown():
mp_ctx = mp.get_context("spawn")
proc, start_args = _create_proc(close_timeout=1.0, mp_ctx=mp_ctx)
start_args.shutdown_simulate_work_time = 10.0
await proc.start()
await proc.initialize()
await asyncio.sleep(1.0)
fake_job = _generate_fake_job()
await proc.launch_job(fake_job)
await asyncio.sleep(1.0)
await proc.aclose()
# the process is killed when a job's shutdown exceeds the close timeout
assert proc.exitcode != 0, "process should have been killed"
assert proc.killed
async def test_job_graceful_shutdown():
mp_ctx = mp.get_context("spawn")
proc, start_args = _create_proc(close_timeout=10.0, mp_ctx=mp_ctx)
start_args.shutdown_simulate_work_time = 1.0
await proc.start()
await proc.initialize()
await asyncio.sleep(1.0)
fake_job = _generate_fake_job()
await proc.launch_job(fake_job)
await asyncio.sleep(1.0)
await proc.aclose()
assert proc.exitcode == 0, "process should have exited cleanly"
assert not proc.killed
assert start_args.shutdown_counter.value == 1
from __future__ import annotations
import asyncio
import base64
from enum import Enum
from pathlib import Path
from typing import Annotated, Callable, Literal, Optional, Union
import pytest
from livekit.agents import APIConnectionError, llm
from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable
from livekit.plugins import anthropic, aws, google, openai
from livekit.rtc import VideoBufferType, VideoFrame
class Unit(Enum):
FAHRENHEIT = "fahrenheit"
CELSIUS = "celsius"
class FncCtx(FunctionContext):
@ai_callable(
description="Get the current weather in a given location", auto_retry=True
)
def get_weather(
self,
location: Annotated[
str, TypeInfo(description="The city and state, e.g. San Francisco, CA")
],
unit: Annotated[
Unit, TypeInfo(description="The temperature unit to use.")
] = Unit.CELSIUS,
) -> None: ...
@ai_callable(description="Play a music")
def play_music(
self,
name: Annotated[str, TypeInfo(description="the name of the Artist")],
) -> None: ...
# test for cancelled calls
@ai_callable(description="Turn on/off the lights in a room")
async def toggle_light(
self,
room: Annotated[str, TypeInfo(description="The room to control")],
on: bool = True,
) -> None:
await asyncio.sleep(60)
# used to test arrays as arguments
@ai_callable(description="Schedule recurring events on selected days")
def schedule_meeting(
self,
meeting_days: Annotated[
list[str],
TypeInfo(
description="The days of the week on which meetings will occur",
choices=[
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
],
),
],
) -> None: ...
@ai_callable(description="Update user info")
def update_user_info(
self,
email: Annotated[
Optional[str], TypeInfo(description="The user address email")
] = None,
name: Annotated[Optional[str], TypeInfo(description="The user name")] = None,
address: Optional[
Annotated[str, TypeInfo(description="The user address")]
] = None,
) -> None: ...
def test_hashable_typeinfo():
typeinfo = TypeInfo(description="testing", choices=[1, 2, 3])
# TypeInfo must be hashable when used in combination with typing.Annotated
hash(typeinfo)
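# Each entry below is a factory so the provider client is only constructed when the
# parametrized test actually runs.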
LLMS: list[Callable[[], llm.LLM]] = [
pytest.param(lambda: openai.LLM(), id="openai"),
# lambda: openai.beta.AssistantLLM(
# assistant_opts=openai.beta.AssistantOptions(
# create_options=openai.beta.AssistantCreateOptions(
# name=f"test-{uuid.uuid4()}",
# instructions="You are a basic assistant",
# model="gpt-4o",
# )
# )
# ),
pytest.param(lambda: anthropic.LLM(), id="anthropic"),
pytest.param(lambda: google.LLM(), id="google"),
pytest.param(lambda: google.LLM(vertexai=True), id="google-vertexai"),
pytest.param(lambda: aws.LLM(), id="aws"),
]
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_chat(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
chat_ctx = ChatContext().append(
text='You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.'
)
# Anthropic and Vertex require at least one message (system messages don't count)
chat_ctx.append(
text="Hello",
role="user",
)
stream = input_llm.chat(chat_ctx=chat_ctx)
text = ""
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content:
text += content
assert len(text) > 0
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_llm_chat_with_consecutive_messages(
llm_factory: Callable[[], llm.LLM],
):
input_llm = llm_factory()
chat_ctx = ChatContext()
chat_ctx.append(
text="Hello, How can I help you today?",
role="assistant",
)
chat_ctx.append(text="I see that you have a busy day ahead.", role="assistant")
chat_ctx.append(
text="Actually, I need some help with my recent order.", role="user"
)
chat_ctx.append(text="I want to cancel my order.", role="user")
stream = input_llm.chat(chat_ctx=chat_ctx)
collected_text = ""
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content:
collected_text += content
assert len(collected_text) > 0, "Expected a non-empty response from the LLM chat"
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_basic_fnc_calls(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
stream = await _request_fnc_call(
input_llm,
"What's the weather in San Francisco and what's the weather Paris?",
fnc_ctx,
)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()
assert len(calls) == 2, "get_weather should be called twice"
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_function_call_exception_handling(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
@fnc_ctx.ai_callable(description="Simulate a failure")
async def failing_function():
raise RuntimeError("Simulated failure")
stream = await _request_fnc_call(input_llm, "Call the failing function", fnc_ctx)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls], return_exceptions=True)
await stream.aclose()
assert len(calls) == 1
assert isinstance(calls[0].exception, RuntimeError)
assert str(calls[0].exception) == "Simulated failure"
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_runtime_addition(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
called_msg = ""
@fnc_ctx.ai_callable(description="Show a message on the screen")
async def show_message(
message: Annotated[str, TypeInfo(description="The message to show")],
):
nonlocal called_msg
called_msg = message
stream = await _request_fnc_call(
input_llm, "Can you show 'Hello LiveKit!' on the screen?", fnc_ctx
)
fns = stream.execute_functions()
await asyncio.gather(*[f.task for f in fns])
await stream.aclose()
assert called_msg == "Hello LiveKit!", "show_message should be called"
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_cancelled_calls(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
stream = await _request_fnc_call(
input_llm, "Turn off the lights in the bedroom", fnc_ctx
)
calls = stream.execute_functions()
await asyncio.sleep(0.2) # wait for the loop executor to start the task
# close directly without waiting for the function results (this should cancel the ongoing calls)
await stream.aclose()
assert len(calls) == 1
assert isinstance(calls[0].exception, asyncio.CancelledError), (
"toggle_light should have been cancelled"
)
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
stream = await _request_fnc_call(
input_llm,
"can you schedule a meeting on monday and wednesday?",
fnc_ctx,
temperature=0.2,
)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()
assert len(calls) == 1, "schedule_meeting should have been called only once"
call = calls[0]
meeting_days = call.call_info.arguments["meeting_days"]
assert len(meeting_days) == 2, "schedule_meeting should have 2 days"
assert "monday" in meeting_days and "wednesday" in meeting_days, (
"meeting_days should have monday, wednesday"
)
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_calls_choices(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
# test choices on int
@fnc_ctx.ai_callable(description="Change the volume")
def change_volume(
volume: Annotated[
int, TypeInfo(description="The volume level", choices=[0, 11, 30, 83, 99])
],
) -> None: ...
if not input_llm.capabilities.supports_choices_on_int:
with pytest.raises(APIConnectionError):
stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
else:
stream = await _request_fnc_call(input_llm, "Set the volume to 30", fnc_ctx)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()
assert len(calls) == 1, "change_volume should have been called only once"
call = calls[0]
volume = call.call_info.arguments["volume"]
assert volume == 30, "change_volume should have been called with volume 30"
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_optional_args(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
fnc_ctx = FncCtx()
stream = await _request_fnc_call(
input_llm, "Using a tool call update the user info to name Theo", fnc_ctx
)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls])
await stream.aclose()
assert len(calls) == 1, "update_user_info should have been called only once"
call = calls[0]
name = call.call_info.arguments.get("name", None)
email = call.call_info.arguments.get("email", None)
address = call.call_info.arguments.get("address", None)
assert name == "Theo", "update_user_info should have been called with name 'Theo'"
assert email is None, "update_user_info should have been called with email None"
assert address is None, "update_user_info should have been called with address None"
test_tool_choice_cases = [
pytest.param(
"Default tool_choice (auto)",
"Get the weather for New York and play some music from the artist 'The Beatles'.",
None,
{"get_weather", "play_music"},
id="Default tool_choice (auto)",
),
pytest.param(
"Tool_choice set to 'required'",
"Get the weather for Chicago and play some music from the artist 'Eminem'.",
"required",
{"get_weather", "play_music"},
id="Tool_choice set to 'required'",
),
pytest.param(
"Tool_choice set to a specific tool ('get_weather')",
"Get the weather for Miami.",
llm.ToolChoice(type="function", name="get_weather"),
{"get_weather"},
id="Tool_choice set to a specific tool ('get_weather')",
),
pytest.param(
"Tool_choice set to 'none'",
"Get the weather for Seattle and play some music from the artist 'Frank Sinatra'.",
"none",
set(), # No tool calls expected
id="Tool_choice set to 'none'",
),
]
@pytest.mark.parametrize(
"description, user_request, tool_choice, expected_calls", test_tool_choice_cases
)
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_tool_choice_options(
description: str,
user_request: str,
tool_choice: Union[llm.ToolChoice, str, None],
expected_calls: set,
llm_factory: Callable[[], llm.LLM],
):
input_llm = llm_factory()
fnc_ctx = FncCtx()
stream = await _request_fnc_call(
input_llm,
user_request,
fnc_ctx,
tool_choice=tool_choice,
parallel_tool_calls=True,
)
calls = stream.execute_functions()
await asyncio.gather(*[f.task for f in calls], return_exceptions=True)
await stream.aclose()
print(calls)
call_names = {call.call_info.function_info.name for call in calls}
if tool_choice == "none":
assert call_names == expected_calls, (
f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}"
)
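# Helper used by the function-calling tests above: it sends the request with the given
# FunctionContext, drains the stream, and returns it so callers can collect the
# accumulated tool calls via stream.execute_functions().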
async def _request_fnc_call(
model: llm.LLM,
request: str,
fnc_ctx: FncCtx,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[llm.ToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> llm.LLMStream:
stream = model.chat(
chat_ctx=ChatContext()
.append(
text="You are an helpful assistant. Follow the instructions provided by the user. You can use multiple tool calls at once.",
role="system",
)
.append(text=request, role="user"),
fnc_ctx=fnc_ctx,
temperature=temperature,
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
)
async for _ in stream:
pass
return stream
_HEARTS_RGBA_PATH = Path(__file__).parent / "hearts.rgba"
with open(_HEARTS_RGBA_PATH, "rb") as f:
image_data = f.read()
_HEARTS_IMAGE_VIDEO_FRAME = VideoFrame(
width=512, height=512, type=VideoBufferType.RGBA, data=image_data
)
_HEARTS_JPEG_PATH = Path(__file__).parent / "hearts.jpg"
with open(_HEARTS_JPEG_PATH, "rb") as f:
_HEARTS_IMAGE_DATA_URL = (
f"data:image/jpeg;base64,{base64.b64encode(f.read()).decode()}"
)
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_chat_with_image_data_url(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
chat_ctx = (
ChatContext()
.append(
text="You are an AI assistant that describes images in detail upon request.",
role="system",
)
.append(
text="Describe this image",
images=[
llm.ChatImage(image=_HEARTS_IMAGE_DATA_URL, inference_detail="low")
],
role="user",
)
)
stream = input_llm.chat(chat_ctx=chat_ctx)
text = ""
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content:
text += content
assert "heart" in text.lower()
@pytest.mark.parametrize("llm_factory", LLMS)
async def test_chat_with_image_frame(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
chat_ctx = (
ChatContext()
.append(
text="You are an AI assistant that describes images in detail upon request.",
role="system",
)
.append(
text="Describe this image",
images=[
llm.ChatImage(image=_HEARTS_IMAGE_VIDEO_FRAME, inference_detail="low")
],
role="user",
)
)
stream = input_llm.chat(chat_ctx=chat_ctx)
text = ""
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content:
text += content
assert "heart" in text.lower()
import pytest
from livekit.agents.llm import ChatMessage
from livekit.agents.utils._message_change import (
_check_order_preserved,
_compute_list_changes,
_find_longest_increasing_subsequence,
compute_changes,
)
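# These tests cover the helpers behind compute_changes. As the "Must keep first" cases
# show, the longest-increasing-subsequence search is constrained to always include the
# first element of the input.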
@pytest.mark.parametrize(
"indices,expected_seq,desc",
[
# Basic cases
([0, 1, 2], [0, 1, 2], "Already sorted"),
([2, 1, 0], [2], "Must keep first (2)"),
([2, 0, 1], [2], "Must keep first (2)"),
([2, 1, 0, 3], [2, 3], "Keep first and what can follow"),
([3, 0, 1, 2], [3], "Only first when nothing can follow"),
([2, 1, 0, 3, 4], [2, 3, 4], "Keep first and increasing suffix"),
([4, 1, 2, 3], [4], "Only first when better sequence exists"),
([0, 1, 4, 2], [0, 1, 4], "Keep longest increasing with first"),
# Edge cases
([], [], "Empty list"),
([0], [0], "Single element"),
([1], [1], "Single element not zero"),
([2, 1], [2], "Two elements, keep first"),
],
)
def test_find_longest_increasing_subsequence(indices, expected_seq, desc):
"""Test the LIS algorithm with various cases"""
result = _find_longest_increasing_subsequence(indices)
result_seq = [indices[i] for i in result] if result else []
# Verify sequence is increasing
if result_seq:
assert all(
result_seq[i] < result_seq[i + 1] for i in range(len(result_seq) - 1)
), f"Not increasing in {desc}"
# Verify first element is included
if result:
assert result[0] == 0, f"First index not included in {desc}"
# Verify sequence matches expected
assert result_seq == expected_seq, (
f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}"
)
@pytest.mark.parametrize(
"indices,expected",
[
([], True),
([0], True),
([0, 1, 2], True),
([0, 2, 1], False),
([1, 1, 2], False),
],
)
def test_check_order_preserved(indices, expected):
assert _check_order_preserved(indices) is expected
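# For intuition, the cases above amount to a strict monotonicity check. An equivalent
# standalone sketch (illustrative only, not the library implementation):
def _is_strictly_increasing_sketch(indices: list[int]) -> bool:
    # every element must be strictly smaller than the one that follows it
    return all(a < b for a, b in zip(indices, indices[1:]))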
@pytest.mark.parametrize(
"old,new,expected_delete,expected_add",
[
# Empty lists
([], [], [], []),
(["a"], [], ["a"], []),
([], ["a"], [], [(None, "a")]),
# Simple append/delete
(["a", "b", "c"], ["a", "b", "c", "d"], [], [("c", "d")]),
(["a", "b", "c", "d"], ["a", "b", "c"], ["d"], []),
# Delete first item
(["a", "b", "c", "d"], ["b", "c", "d", "e"], ["a"], [("d", "e")]),
(["x", "y", "b", "c"], ["b", "c", "d"], ["x", "y"], [("c", "d")]),
# New first item - must delete all
(
["a", "b", "c", "d"],
["e", "a", "b", "c"],
["a", "b", "c", "d"],
[(None, "e"), ("e", "a"), ("a", "b"), ("b", "c")],
),
# First item exists but order changes
(["a", "b", "c", "d"], ["b", "a", "c", "d"], ["a"], [("b", "a")]),
(["x", "y", "b", "c"], ["b", "d", "c"], ["x", "y"], [("b", "d")]),
# Complex reordering
(
["a", "b", "c", "d"],
["a", "b", "d", "e", "c"],
["d"],
[("b", "d"), ("d", "e")],
),
(
["a", "b", "c", "d"],
["a", "d", "c", "b"],
["c", "d"],
[("a", "d"), ("d", "c")],
),
],
)
def test_compute_list_changes(old, new, expected_delete, expected_add):
changes = _compute_list_changes(old, new)
assert changes.to_delete == expected_delete
assert changes.to_add == expected_add
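# As the apply loops below illustrate, each to_add entry is a (previous_item, new_item)
# pair: new_item is inserted right after previous_item, or appended when previous_item
# is None.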
@pytest.mark.parametrize(
"old_ids,new_ids",
[
(["a", "b", "c", "d"], ["b", "c", "d", "e"]),
(["a", "b", "c", "d"], ["e", "a", "b", "c"]),
(["a", "b", "c", "d"], ["a", "b", "d", "e", "c"]),
],
)
def test_compute_changes(old_ids, new_ids):
"""Test computing changes with ChatMessage objects"""
def create_msg(id: str) -> ChatMessage:
return ChatMessage(role="test", id=id)
old = [create_msg(id) for id in old_ids]
new = [create_msg(id) for id in new_ids]
changes = compute_changes(old, new, lambda msg: msg.id)
# Apply changes and verify result
result = [msg for msg in old if msg not in changes.to_delete]
for prev, msg in changes.to_add:
if prev is None:
result.append(msg)
else:
idx = result.index(prev) + 1
result.insert(idx, msg)
assert [msg.id for msg in result] == new_ids
@pytest.mark.parametrize(
"old,new",
[
(["a", "b", "c", "d"], ["b", "c", "d", "e"]),
(["a", "b", "c", "d"], ["e", "a", "b", "c"]),
(["a", "b", "c", "d"], ["a", "b", "d", "e", "c"]),
(["a", "b", "c", "d"], ["b", "a", "c", "d"]),
(["x", "y", "b", "c"], ["b", "d", "c"]),
(["a", "b", "c", "d"], ["a", "d", "c", "b"]),
],
)
def test_changes_maintain_list_integrity(old, new):
"""Test that applying changes maintains list integrity"""
def apply_changes(old: list[str], changes):
# Apply deletions
result = [x for x in old if x not in changes.to_delete]
# Apply insertions
for prev, item in changes.to_add:
if prev is None:
result.append(item)
else:
idx = result.index(prev) + 1
result.insert(idx, item)
return result
changes = _compute_list_changes(old, new)
result = apply_changes(old, changes)
assert result == new, f"Failed to transform {old} into {new}, got {result}"
"""
Do speech recognition on a long audio file and compare the result with the expected transcript
"""
import asyncio
import time
from typing import Callable
import pytest
from livekit import agents
from livekit.agents import stt
from livekit.plugins import (
assemblyai,
aws,
azure,
deepgram,
fal,
google,
openai,
silero,
speechmatics,
)
from .utils import make_test_speech, wer
SAMPLE_RATES = [24000, 44100] # test multiple input sample rates
WER_THRESHOLD = 0.25
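# i.e. at most 25% word errors against the reference transcript (see wer() in .utils)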
RECOGNIZE_STT: list[Callable[[], stt.STT]] = [
pytest.param(lambda: deepgram.STT(), id="deepgram"),
# pytest.param(lambda: google.STT(), id="google"),
# pytest.param(
# lambda: google.STT(
# languages=["en-AU"],
# model="chirp_2",
# spoken_punctuation=False,
# location="us-central1",
# ),
# id="google.chirp_2",
# ),
pytest.param(lambda: openai.STT(), id="openai"),
pytest.param(lambda: fal.WizperSTT(), id="fal"),
]
@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize("stt_factory", RECOGNIZE_STT)
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_recognize(stt_factory, sample_rate):
async with stt_factory() as stt:
frames, transcript = await make_test_speech(sample_rate=sample_rate)
start_time = time.time()
event = await stt.recognize(buffer=frames)
text = event.alternatives[0].text
dt = time.time() - start_time
print(f"WER: {wer(text, transcript)} for {stt} in {dt:.2f}s")
assert wer(text, transcript) <= WER_THRESHOLD
assert event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT
STREAM_VAD = silero.VAD.load(min_silence_duration=0.75)
STREAM_STT: list[Callable[[], stt.STT]] = [
pytest.param(lambda: aws.STT(), id="aws"),
pytest.param(lambda: assemblyai.STT(), id="assemblyai"),
pytest.param(lambda: deepgram.STT(), id="deepgram"),
pytest.param(lambda: google.STT(), id="google"),
pytest.param(
lambda: agents.stt.StreamAdapter(stt=openai.STT(), vad=STREAM_VAD),
id="openai.stream",
),
pytest.param(
lambda: agents.stt.StreamAdapter(stt=openai.STT.with_groq(), vad=STREAM_VAD),
id="openai.with_groq.stream",
),
pytest.param(
lambda: google.STT(
languages=["en-US"],
model="chirp_2",
spoken_punctuation=False,
location="us-central1",
),
id="google.chirp_2",
),
pytest.param(lambda: azure.STT(), id="azure"),
pytest.param(lambda: speechmatics.STT(), id="speechmatics"),
]
@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize("stt_factory", STREAM_STT)
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_stream(stt_factory, sample_rate):
stt = stt_factory()
frames, transcript = await make_test_speech(
chunk_duration_ms=10, sample_rate=sample_rate
)
stream = stt.stream()
async def _stream_input():
for frame in frames:
stream.push_frame(frame)
await asyncio.sleep(0.005)
stream.end_input()
async def _stream_output():
text = ""
# make sure the events are sent in the right order
recv_start, recv_end = False, True
start_time = time.time()
async for event in stream:
if event.type == agents.stt.SpeechEventType.START_OF_SPEECH:
assert recv_end, (
"START_OF_SPEECH recv but no END_OF_SPEECH has been sent before"
)
assert not recv_start
recv_end = False
recv_start = True
continue
if event.type == agents.stt.SpeechEventType.FINAL_TRANSCRIPT:
if text != "":
text += " "
text += event.alternatives[0].text
# ensure STT is tagging languages correctly
language = event.alternatives[0].language
assert language is not None
assert language.lower().startswith("en")
if event.type == agents.stt.SpeechEventType.END_OF_SPEECH:
recv_start = False
recv_end = True
dt = time.time() - start_time
print(f"WER: {wer(text, transcript)} for streamed {stt} in {dt:.2f}s")
assert wer(text, transcript) <= WER_THRESHOLD
await asyncio.wait_for(
asyncio.gather(_stream_input(), _stream_output()), timeout=120
)
await stream.aclose()
from __future__ import annotations
import asyncio
import pytest
from livekit.agents import APIConnectionError, utils
from livekit.agents.stt import STT, AvailabilityChangedEvent, FallbackAdapter
from livekit.agents.utils.aio.channel import ChanEmpty
from .fake_stt import FakeSTT
class FallbackAdapterTester(FallbackAdapter):
def __init__(
self,
stt: list[STT],
*,
attempt_timeout: float = 10.0,
max_retry_per_stt: int = 1,
retry_interval: float = 5,
) -> None:
super().__init__(
stt,
attempt_timeout=attempt_timeout,
max_retry_per_stt=max_retry_per_stt,
retry_interval=retry_interval,
)
self.on("stt_availability_changed", self._on_stt_availability_changed)
self._availability_changed_ch: dict[
int, utils.aio.Chan[AvailabilityChangedEvent]
] = {id(t): utils.aio.Chan[AvailabilityChangedEvent]() for t in stt}
def _on_stt_availability_changed(self, ev: AvailabilityChangedEvent) -> None:
self._availability_changed_ch[id(ev.stt)].send_nowait(ev)
def availability_changed_ch(
self,
stt: STT,
) -> utils.aio.ChanReceiver[AvailabilityChangedEvent]:
return self._availability_changed_ch[id(stt)]
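# Each wrapped STT gets its own channel keyed by id(), so tests can assert
# per-instance availability transitions without cross-talk between instances.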
async def test_stt_fallback() -> None:
fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeSTT(fake_transcript="hello world")
fallback_adapter = FallbackAdapterTester([fake1, fake2])
ev = await fallback_adapter.recognize([])
assert ev.alternatives[0].text == "hello world"
assert fake1.recognize_ch.recv_nowait()
assert fake2.recognize_ch.recv_nowait()
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
fake2.update_options(fake_exception=APIConnectionError("fake2 failed"))
with pytest.raises(APIConnectionError):
await fallback_adapter.recognize([])
assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available
await fallback_adapter.aclose()
# stream
fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeSTT(fake_transcript="hello world")
fallback_adapter = FallbackAdapterTester([fake1, fake2])
async with fallback_adapter.stream() as stream:
stream.end_input()
last_alt = ""
async for ev in stream:
last_alt = ev.alternatives[0].text
assert last_alt == "hello world"
await fallback_adapter.aclose()
async def test_stt_stream_fallback() -> None:
fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeSTT(fake_transcript="hello world")
fallback_adapter = FallbackAdapterTester([fake1, fake2])
async with fallback_adapter.stream() as stream:
stream.end_input()
async for _ in stream:
pass
assert fake1.stream_ch.recv_nowait()
assert fake2.stream_ch.recv_nowait()
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
await fallback_adapter.aclose()
async def test_stt_recover() -> None:
fake1 = FakeSTT(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeSTT(fake_exception=APIConnectionError("fake2 failed"), fake_timeout=0.5)
fallback_adapter = FallbackAdapterTester([fake1, fake2])
with pytest.raises(APIConnectionError):
await fallback_adapter.recognize([])
fake2.update_options(fake_exception=None, fake_transcript="hello world")
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available
assert (
await asyncio.wait_for(
fallback_adapter.availability_changed_ch(fake2).recv(), 1.0
)
).available, "fake2 should have recovered"
await fallback_adapter.recognize([])
assert fake1.recognize_ch.recv_nowait()
assert fake2.recognize_ch.recv_nowait()
with pytest.raises(ChanEmpty):
fallback_adapter.availability_changed_ch(fake1).recv_nowait()
with pytest.raises(ChanEmpty):
fallback_adapter.availability_changed_ch(fake2).recv_nowait()
await fallback_adapter.aclose()
import pytest
from livekit.agents import tokenize
from livekit.agents.tokenize import basic
from livekit.agents.tokenize._basic_paragraph import split_paragraphs
from livekit.plugins import nltk
# Download the punkt tokenizer (only downloads if not already present)
nltk.NltkPlugin().download_files()
TEXT = (
"Hi! "
"LiveKit is a platform for live audio and video applications and services. "
"R.T.C stands for Real-Time Communication... again R.T.C. "
"Mr. Theo is testing the sentence tokenizer. "
"This is a test. Another test. "
"A short sentence. "
"A longer sentence that is longer than the previous sentence. "
"f(x) = x * 2.54 + 42. "
"Hey! Hi! Hello! "
)
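# Expected segmentation when adjacent short sentences are merged until each segment
# reaches min_sentence_len=20 (see SENT_TOKENIZERS below).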
EXPECTED_MIN_20 = [
"Hi! LiveKit is a platform for live audio and video applications and services.",
"R.T.C stands for Real-Time Communication... again R.T.C.",
"Mr. Theo is testing the sentence tokenizer.",
"This is a test. Another test.",
"A short sentence. A longer sentence that is longer than the previous sentence.",
"f(x) = x * 2.54 + 42.",
"Hey! Hi! Hello!",
]
SENT_TOKENIZERS = [
nltk.SentenceTokenizer(min_sentence_len=20),
basic.SentenceTokenizer(min_sentence_len=20),
]
@pytest.mark.parametrize("tokenizer", SENT_TOKENIZERS)
def test_sent_tokenizer(tokenizer: tokenize.SentenceTokenizer):
segmented = tokenizer.tokenize(text=TEXT)
for i, segment in enumerate(EXPECTED_MIN_20):
assert segment == segmented[i]
@pytest.mark.parametrize("tokenizer", SENT_TOKENIZERS)
async def test_streamed_sent_tokenizer(tokenizer: tokenize.SentenceTokenizer):
# split the text into chunks of varying length (1, 2, or 4 characters)
pattern = [1, 2, 4]
text = TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
stream = tokenizer.stream()
for chunk in chunks:
stream.push_text(chunk)
stream.end_input()
for i in range(len(EXPECTED_MIN_20)):
ev = await stream.__anext__()
assert ev.token == EXPECTED_MIN_20[i]
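# The chunk-splitting loop above recurs in several streamed tests below; a reusable
# equivalent might look like this (illustrative sketch only, not used by the tests):
def _chunk_text_sketch(text: str, pattern: tuple[int, ...] = (1, 2, 4)) -> list[str]:
    # cycle through the pattern, slicing off 1, 2 or 4 characters at a time
    chunks: list[str] = []
    i = 0
    while text:
        size = pattern[i % len(pattern)]
        chunks.append(text[:size])
        text = text[size:]
        i += 1
    return chunks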
WORDS_TEXT = (
"This is a test. Blabla another test! multiple consecutive spaces: done"
)
WORDS_EXPECTED = [
"This",
"is",
"a",
"test",
"Blabla",
"another",
"test",
"multiple",
"consecutive",
"spaces",
"done",
]
WORD_TOKENIZERS = [basic.WordTokenizer()]
@pytest.mark.parametrize("tokenizer", WORD_TOKENIZERS)
def test_word_tokenizer(tokenizer: tokenize.WordTokenizer):
tokens = tokenizer.tokenize(text=WORDS_TEXT)
for i, token in enumerate(WORDS_EXPECTED):
assert token == tokens[i]
@pytest.mark.parametrize("tokenizer", WORD_TOKENIZERS)
async def test_streamed_word_tokenizer(tokenizer: tokenize.WordTokenizer):
# split the text into chunks of varying length (1, 2, or 4 characters)
pattern = [1, 2, 4]
text = WORDS_TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
stream = tokenizer.stream()
for chunk in chunks:
stream.push_text(chunk)
stream.end_input()
for i in range(len(WORDS_EXPECTED)):
ev = await stream.__anext__()
assert ev.token == WORDS_EXPECTED[i]
WORDS_PUNCT_TEXT = 'This is <phoneme alphabet="cmu-arpabet" ph="AE K CH UW AH L IY">actually</phoneme> tricky to handle.'
WORDS_PUNCT_EXPECTED = [
"This",
"is",
"<phoneme",
'alphabet="cmu-arpabet"',
'ph="AE',
"K",
"CH",
"UW",
"AH",
"L",
'IY">actually</phoneme>',
"tricky",
"to",
"handle.",
]
WORD_PUNCT_TOKENIZERS = [basic.WordTokenizer(ignore_punctuation=False)]
@pytest.mark.parametrize("tokenizer", WORD_PUNCT_TOKENIZERS)
def test_punct_word_tokenizer(tokenizer: tokenize.WordTokenizer):
tokens = tokenizer.tokenize(text=WORDS_PUNCT_TEXT)
for i, token in enumerate(WORDS_PUNCT_EXPECTED):
assert token == tokens[i]
@pytest.mark.parametrize("tokenizer", WORD_PUNCT_TOKENIZERS)
async def test_streamed_punct_word_tokenizer(tokenizer: tokenize.WordTokenizer):
# split the text into chunks of varying length (1, 2, or 4 characters)
pattern = [1, 2, 4]
text = WORDS_PUNCT_TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
stream = tokenizer.stream()
for chunk in chunks:
stream.push_text(chunk)
stream.end_input()
for i in range(len(WORDS_PUNCT_EXPECTED)):
ev = await stream.__anext__()
assert ev.token == WORDS_PUNCT_EXPECTED[i]
HYPHENATOR_TEXT = [
"Segment",
"expected",
"communication",
"window",
"welcome",
"bedroom",
]
HYPHENATOR_EXPECTED = [
["Seg", "ment"],
["ex", "pect", "ed"],
["com", "mu", "ni", "ca", "tion"],
["win", "dow"],
["wel", "come"],
["bed", "room"],
]
def test_hyphenate_word():
for i, word in enumerate(HYPHENATOR_TEXT):
hyphenated = basic.hyphenate_word(word)
assert hyphenated == HYPHENATOR_EXPECTED[i]
REPLACE_TEXT = (
"This is a test. Hello world, I'm creating this agents.. framework. Once again "
"framework. A.B.C"
)
REPLACE_EXPECTED = (
"This is a test. Hello universe, I'm creating this assistants.. library. twice again "
"library. A.B.C.D"
)
REPLACE_REPLACEMENTS = {
"world": "universe",
"framework": "library",
"a.b.c": "A.B.C.D",
"once": "twice",
"agents": "assistants",
}
def test_replace_words():
replaced = tokenize.utils.replace_words(
text=REPLACE_TEXT, replacements=REPLACE_REPLACEMENTS
)
assert replaced == REPLACE_EXPECTED
async def test_replace_words_async():
pattern = [1, 2, 4]
text = REPLACE_TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
async def _replace_words_async():
for chunk in chunks:
yield chunk
replaced_chunks = []
async for chunk in tokenize.utils.replace_words(
text=_replace_words_async(), replacements=REPLACE_REPLACEMENTS
):
replaced_chunks.append(chunk)
replaced = "".join(replaced_chunks)
assert replaced == REPLACE_EXPECTED
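# Each expected tuple is (paragraph_text, start, end), where start/end are character
# offsets into the input string and end is exclusive.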
PARAGRAPH_TEST_CASES = [
("Single paragraph.", [("Single paragraph.", 0, 17)]),
(
"Paragraph 1.\n\nParagraph 2.",
[("Paragraph 1.", 0, 12), ("Paragraph 2.", 14, 26)],
),
(
"Para 1.\n\nPara 2.\n\nPara 3.",
[("Para 1.", 0, 7), ("Para 2.", 9, 16), ("Para 3.", 18, 25)],
),
(
"\n\nParagraph with leading newlines.",
[("Paragraph with leading newlines.", 2, 34)],
),
(
"Paragraph with trailing newlines.\n\n",
[("Paragraph with trailing newlines.", 0, 33)],
),
(
"\n\n Paragraph with leading and trailing spaces. \n\n",
[("Paragraph with leading and trailing spaces.", 4, 47)],
),
(
"Para 1.\n\n\n\nPara 2.", # Multiple newlines between paragraphs
[("Para 1.", 0, 7), ("Para 2.", 11, 18)],
),
(
"Para 1.\n \n \nPara 2.", # Newlines with spaces between paragraphs
[("Para 1.", 0, 7), ("Para 2.", 12, 19)],
),
(
"", # Empty string
[],
),
(
"\n\n\n", # Only newlines
[],
),
(
"Line 1\nLine 2\nLine 3", # Single paragraph with newlines
[("Line 1\nLine 2\nLine 3", 0, 20)],
),
]
@pytest.mark.parametrize(
"test_case",
PARAGRAPH_TEST_CASES,
)
def test_split_paragraphs(test_case):
input_text, expected_output = test_case
result = split_paragraphs(input_text)
assert result == expected_output, f"Failed for input: {input_text}"
"""
Check that all text-to-speech plugins produce valid audio.
We verify the content using a reliable STT model.
"""
import dataclasses
from typing import Callable
import pytest
from livekit import agents
from livekit.agents import APIConnectionError, tokenize, tts
from livekit.agents.utils import AudioBuffer, merge_frames
from livekit.plugins import (
aws,
azure,
cartesia,
deepgram,
elevenlabs,
google,
neuphonic,
openai,
playai,
resemble,
rime,
)
from .conftest import TEST_CONNECT_OPTIONS
from .fake_tts import FakeTTS
from .utils import make_test_synthesize, wer
WER_THRESHOLD = 0.2
async def _assert_valid_synthesized_audio(
frames: AudioBuffer, tts: agents.tts.TTS, text: str, threshold: float
):
# use Whisper as the source of truth to verify the synthesized speech (lowest WER)
whisper_stt = openai.STT(model="whisper-1")
res = await whisper_stt.recognize(buffer=frames)
assert wer(res.alternatives[0].text, text) <= threshold
merged_frame = merge_frames(frames)
assert merged_frame.sample_rate == tts.sample_rate, "sample rate should be the same"
assert merged_frame.num_channels == tts.num_channels, (
"num channels should be the same"
)
SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [
pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"),
pytest.param(lambda: openai.TTS(), id="openai"),
pytest.param(lambda: google.TTS(), id="google"),
pytest.param(lambda: azure.TTS(), id="azure"),
pytest.param(lambda: aws.TTS(), id="aws"),
pytest.param(lambda: cartesia.TTS(), id="cartesia"),
pytest.param(lambda: deepgram.TTS(), id="deepgram"),
pytest.param(lambda: playai.TTS(), id="playai"),
pytest.param(lambda: rime.TTS(), id="rime"),
pytest.param(lambda: neuphonic.TTS(), id="neuphonic"),
pytest.param(lambda: resemble.TTS(), id="resemble"),
]
@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize("tts_factory", SYNTHESIZE_TTS)
async def test_synthesize(tts_factory):
tts = tts_factory()
synthesize_transcript = make_test_synthesize()
frames = []
async for audio in tts.synthesize(text=synthesize_transcript):
frames.append(audio.frame)
await _assert_valid_synthesized_audio(
frames, tts, synthesize_transcript, WER_THRESHOLD
)
STREAM_SENT_TOKENIZER = tokenize.basic.SentenceTokenizer(min_sentence_len=20)
STREAM_TTS: list[Callable[[], tts.TTS]] = [
pytest.param(lambda: elevenlabs.TTS(), id="elevenlabs"),
pytest.param(lambda: cartesia.TTS(), id="cartesia"),
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=openai.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="openai.stream",
),
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=google.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="google.stream",
),
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=azure.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="azure.stream",
),
pytest.param(lambda: deepgram.TTS(), id="deepgram"),
pytest.param(lambda: playai.TTS(), id="playai"),
pytest.param(
lambda: agents.tts.StreamAdapter(
tts=aws.TTS(), sentence_tokenizer=STREAM_SENT_TOKENIZER
),
id="aws.stream",
),
pytest.param(lambda: neuphonic.TTS(), id="neuphonic"),
pytest.param(lambda: resemble.TTS(), id="resemble"),
]
@pytest.mark.usefixtures("job_process")
@pytest.mark.parametrize("tts_factory", STREAM_TTS)
async def test_stream(tts_factory):
tts: agents.tts.TTS = tts_factory()
synthesize_transcript = make_test_synthesize()
pattern = [1, 2, 4]
text = synthesize_transcript
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))
for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]
stream = tts.stream()
segments = set()
# for i in range(2): # TODO(theomonnom): we should test 2 segments
for chunk in chunks:
stream.push_text(chunk)
stream.flush()
# if i == 1:
stream.end_input()
frames = []
is_final = False
async for audio in stream:
is_final = audio.is_final
segments.add(audio.segment_id)
frames.append(audio.frame)
assert is_final, "the last audio frame should be marked as final"
await _assert_valid_synthesized_audio(
frames, tts, synthesize_transcript, WER_THRESHOLD
)
# assert len(segments) == 2
await stream.aclose()
async def test_retry():
fake_tts = FakeTTS(fake_exception=APIConnectionError("fake exception"))
retry_options = dataclasses.replace(TEST_CONNECT_OPTIONS, max_retry=3)
stream = fake_tts.synthesize("testing", conn_options=retry_options)
with pytest.raises(APIConnectionError):
async for _ in stream:
pass
assert fake_tts.synthesize_ch.recv_nowait()
assert stream.attempt == 4  # 1 initial attempt + 3 retries (max_retry=3)
async def test_close():
fake_tts = FakeTTS(fake_timeout=5.0)
retry_options = dataclasses.replace(TEST_CONNECT_OPTIONS, max_retry=0)
stream = fake_tts.synthesize("testing", conn_options=retry_options)
await stream.aclose()
async for _ in stream:
pass
from __future__ import annotations
import asyncio
import contextlib
import pytest
from livekit import rtc
from livekit.agents import APIConnectionError, utils
from livekit.agents.tts import TTS, AvailabilityChangedEvent, FallbackAdapter
from livekit.agents.tts.tts import SynthesizeStream
from livekit.agents.utils.aio.channel import ChanEmpty
from .fake_tts import FakeTTS
class FallbackAdapterTester(FallbackAdapter):
def __init__(
self,
tts: list[TTS],
*,
attempt_timeout: float = 10.0,
max_retry_per_tts: int = 1, # only retry once by default
no_fallback_after_audio_duration: float | None = 3.0,
sample_rate: int | None = None,
) -> None:
super().__init__(
tts,
attempt_timeout=attempt_timeout,
max_retry_per_tts=max_retry_per_tts,
no_fallback_after_audio_duration=no_fallback_after_audio_duration,
sample_rate=sample_rate,
)
self.on("tts_availability_changed", self._on_tts_availability_changed)
self._availability_changed_ch: dict[
int, utils.aio.Chan[AvailabilityChangedEvent]
] = {id(t): utils.aio.Chan[AvailabilityChangedEvent]() for t in tts}
def _on_tts_availability_changed(self, ev: AvailabilityChangedEvent) -> None:
self._availability_changed_ch[id(ev.tts)].send_nowait(ev)
def availability_changed_ch(
self,
tts: TTS,
) -> utils.aio.ChanReceiver[AvailabilityChangedEvent]:
return self._availability_changed_ch[id(tts)]
async def test_tts_fallback() -> None:
fake1 = FakeTTS(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeTTS(fake_audio_duration=5.0, sample_rate=48000)
fallback_adapter = FallbackAdapterTester([fake1, fake2])
async with fallback_adapter.synthesize("hello test") as stream:
frames = []
async for data in stream:
frames.append(data.frame)
assert fake1.synthesize_ch.recv_nowait()
assert fake2.synthesize_ch.recv_nowait()
assert rtc.combine_audio_frames(frames).duration == 5.0
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
fake2.update_options(fake_audio_duration=0.0)
with pytest.raises(APIConnectionError):
async with fallback_adapter.synthesize("hello test") as stream:
async for _ in stream:
pass
assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available
await fallback_adapter.aclose()
async def test_no_audio() -> None:
fake1 = FakeTTS(fake_audio_duration=0.0)
fallback_adapter = FallbackAdapterTester([fake1])
with pytest.raises(APIConnectionError):
async with fallback_adapter.synthesize("hello test") as stream:
async for _ in stream:
pass
# stream
fake1.update_options(fake_audio_duration=5.0)
async def _input_task(stream: SynthesizeStream):
with contextlib.suppress(RuntimeError):
stream.push_text("hello test")
stream.flush()
await asyncio.sleep(1.0)
fake1.update_options(fake_timeout=0.5, fake_audio_duration=None)
stream.push_text("hello test")
stream.end_input()
with pytest.raises(APIConnectionError):
async with fallback_adapter.stream() as stream:
input_task = asyncio.create_task(_input_task(stream))
segments = set()
try:
async for ev in stream:
segments.add(ev.segment_id)
finally:
await input_task
assert len(segments) == 1
await fallback_adapter.aclose()
async def test_tts_stream_fallback() -> None:
fake1 = FakeTTS(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeTTS(fake_audio_duration=5.0)
fallback_adapter = FallbackAdapterTester([fake1, fake2])
async with fallback_adapter.stream() as stream:
stream.push_text("hello test")
stream.end_input()
async for _ in stream:
pass
assert fake1.stream_ch.recv_nowait()
assert fake2.stream_ch.recv_nowait()
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
await fallback_adapter.aclose()
async def test_tts_recover() -> None:
fake1 = FakeTTS(fake_exception=APIConnectionError("fake1 failed"))
fake2 = FakeTTS(fake_exception=APIConnectionError("fake2 failed"), fake_timeout=0.5)
fallback_adapter = FallbackAdapterTester([fake1, fake2])
with pytest.raises(APIConnectionError):
async for _ in fallback_adapter.synthesize("hello test"):
pass
assert fake1.synthesize_ch.recv_nowait()
assert fake2.synthesize_ch.recv_nowait()
fake2.update_options(fake_exception=None, fake_audio_duration=5.0)
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available
assert (
await asyncio.wait_for(
fallback_adapter.availability_changed_ch(fake2).recv(), 1.0
)
).available, "fake2 should have recovered"
async for _ in fallback_adapter.synthesize("hello test"):
pass
assert fake1.synthesize_ch.recv_nowait()
assert fake2.synthesize_ch.recv_nowait()
with pytest.raises(ChanEmpty):
fallback_adapter.availability_changed_ch(fake1).recv_nowait()
with pytest.raises(ChanEmpty):
fallback_adapter.availability_changed_ch(fake2).recv_nowait()
await fallback_adapter.aclose()
async def test_audio_resampled() -> None:
fake1 = FakeTTS(
sample_rate=48000, fake_exception=APIConnectionError("fake1 failed")
)
fake2 = FakeTTS(fake_audio_duration=5.0, sample_rate=16000)
fallback_adapter = FallbackAdapterTester([fake1, fake2])
async with fallback_adapter.synthesize("hello test") as stream:
frames = []
async for data in stream:
frames.append(data.frame)
assert fake1.synthesize_ch.recv_nowait()
assert fake2.synthesize_ch.recv_nowait()
assert (
not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
)
combined_frame = rtc.combine_audio_frames(frames)
assert combined_frame.duration == 5.0
assert combined_frame.sample_rate == 48000
assert await asyncio.wait_for(fake1.synthesize_ch.recv(), 1.0)
async with fallback_adapter.stream() as stream:
stream.push_text("hello test")
stream.end_input()
frames = []
async for data in stream:
frames.append(data.frame)
print(frames)
assert fake2.stream_ch.recv_nowait()
combined_frame = rtc.combine_audio_frames(frames)
assert combined_frame.duration == 5.0
assert combined_frame.sample_rate == 48000
await fallback_adapter.aclose()
async def test_timeout():
fake1 = FakeTTS(fake_timeout=0.5, sample_rate=48000)
fake2 = FakeTTS(fake_timeout=0.5, sample_rate=48000)
fallback_adapter = FallbackAdapterTester([fake1, fake2], attempt_timeout=0.1)
with pytest.raises(APIConnectionError):
async for _ in fallback_adapter.synthesize("hello test"):
pass
assert fake1.synthesize_ch.recv_nowait()
assert fake2.synthesize_ch.recv_nowait()
assert not fallback_adapter.availability_changed_ch(fake1).recv_nowait().available
assert not fallback_adapter.availability_changed_ch(fake2).recv_nowait().available
assert await asyncio.wait_for(fake1.synthesize_ch.recv(), 1.0)
assert await asyncio.wait_for(fake2.synthesize_ch.recv(), 1.0)
# stream
with pytest.raises(APIConnectionError):
async with fallback_adapter.stream() as stream:
stream.end_input()
async for _ in stream:
pass
assert fake1.stream_ch.recv_nowait()
assert fake2.stream_ch.recv_nowait()
assert await asyncio.wait_for(fake1.stream_ch.recv(), 1.0)
assert await asyncio.wait_for(fake2.stream_ch.recv(), 1.0)
await fallback_adapter.aclose()
# consecutive pushes must not time out
fake1.update_options(fake_timeout=None, fake_audio_duration=5.0)
fallback_adapter = FallbackAdapterTester([fake1], attempt_timeout=0.25)
async def _input_task1(stream: SynthesizeStream):
stream.push_text("hello world")
stream.flush()
await asyncio.sleep(1.0)
stream.push_text("bye world")
stream.end_input()
async with fallback_adapter.stream() as stream:
input_task = asyncio.create_task(_input_task1(stream))
segments = set()
final_count = 0
async for ev in stream:
segments.add(ev.segment_id)
if ev.is_final:
final_count += 1
assert len(segments) == 2
assert final_count == 2
await input_task
async def _input_task2(stream: SynthesizeStream):
with contextlib.suppress(RuntimeError):
stream.push_text("hello test")
stream.flush()
await asyncio.sleep(1.0)
fake1.update_options(fake_timeout=0.5, fake_audio_duration=None)
stream.push_text("hello test")
stream.flush()
await asyncio.sleep(1.0)
stream.end_input()
with pytest.raises(APIConnectionError):
async with fallback_adapter.stream() as stream:
input_task = asyncio.create_task(_input_task2(stream))
try:
async for ev in stream:
pass
finally:
await input_task
await fallback_adapter.aclose()
import pytest
from livekit.agents import vad
from livekit.plugins import silero
from . import utils
SAMPLE_RATES = [16000, 44100] # test multiple input sample rates
VAD = silero.VAD.load(
min_speech_duration=0.5,
min_silence_duration=0.6,
)
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_chunks_vad(sample_rate) -> None:
frames, _ = await utils.make_test_speech(
chunk_duration_ms=10, sample_rate=sample_rate
)
assert len(frames) > 1, "frames aren't chunked"
stream = VAD.stream()
for frame in frames:
stream.push_frame(frame)
stream.end_input()
start_of_speech_i = 0
end_of_speech_i = 0
inference_frames = []
async for ev in stream:
if ev.type == vad.VADEventType.START_OF_SPEECH:
with open(
f"test_vad.{sample_rate}.start_of_speech_frames_{start_of_speech_i}.wav",
"wb",
) as f:
f.write(utils.make_wav_file(ev.frames))
start_of_speech_i += 1
if ev.type == vad.VADEventType.INFERENCE_DONE:
inference_frames.extend(ev.frames)
if ev.type == vad.VADEventType.END_OF_SPEECH:
with open(
f"test_vad.{sample_rate}.end_of_speech_frames_{end_of_speech_i}.wav",
"wb",
) as f:
f.write(utils.make_wav_file(ev.frames))
end_of_speech_i += 1
assert start_of_speech_i > 0, "no start of speech detected"
assert start_of_speech_i == end_of_speech_i, "start and end of speech mismatch"
with open("test_vad.{sample_rate}.inference_frames.wav", "wb") as f:
f.write(utils.make_wav_file(inference_frames))
@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_file_vad(sample_rate):
frames, _ = await utils.make_test_speech(sample_rate=sample_rate)
assert len(frames) == 1, "one frame should be the whole audio"
stream = VAD.stream()
for frame in frames:
stream.push_frame(frame)
stream.end_input()
start_of_speech_i = 0
end_of_speech_i = 0
async for ev in stream:
if ev.type == vad.VADEventType.START_OF_SPEECH:
start_of_speech_i += 1
if ev.type == vad.VADEventType.END_OF_SPEECH:
end_of_speech_i += 1
assert start_of_speech_i > 0, "no start of speech detected"
assert start_of_speech_i == end_of_speech_i, "start and end of speech mismatch"
from __future__ import annotations
import io
import os
import pathlib
import wave
from typing import Tuple
import jiwer as tr
from livekit import rtc
from livekit.agents import utils
TEST_AUDIO_FILEPATH = os.path.join(os.path.dirname(__file__), "long.mp3")
TEST_AUDIO_TRANSCRIPT = pathlib.Path(
os.path.dirname(__file__), "long_transcript.txt"
).read_text()
TEST_AUDIO_SYNTHESIZE = pathlib.Path(
os.path.dirname(__file__), "long_synthesize.txt"
).read_text()
def wer(hypothesis: str, reference: str) -> float:
wer_standardize_contiguous = tr.Compose(
[
tr.ToLowerCase(),
tr.ExpandCommonEnglishContractions(),
tr.RemoveKaldiNonWords(),
tr.RemoveWhiteSpace(replace_by_space=True),
tr.RemoveMultipleSpaces(),
tr.Strip(),
tr.ReduceToSingleSentence(),
tr.ReduceToListOfListOfWords(),
]
)
return tr.wer(
reference,
hypothesis,
reference_transform=wer_standardize_contiguous,
hypothesis_transform=wer_standardize_contiguous,
)
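# Illustrative example: wer("hello world", "hello there world") is roughly 0.33,
# i.e. one word deleted out of a three-word reference.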
async def read_mp3_file(path) -> rtc.AudioFrame:
decoder = utils.codecs.AudioStreamDecoder(
sample_rate=48000,
num_channels=1,
)
frames: list[rtc.AudioFrame] = []
with open(path, "rb") as file:
while True:
chunk = file.read(4096)
if not chunk:
break
decoder.push(chunk)
decoder.end_input()
async for frame in decoder:
frames.append(frame)
return rtc.combine_audio_frames(frames) # merging just for ease of use
async def make_test_speech(
*,
chunk_duration_ms: int | None = None,
sample_rate: int | None = None, # resample if not None
) -> Tuple[list[rtc.AudioFrame], str]:
input_audio = await read_mp3_file(TEST_AUDIO_FILEPATH)
if sample_rate is not None and input_audio.sample_rate != sample_rate:
resampler = rtc.AudioResampler(
input_rate=input_audio.sample_rate,
output_rate=sample_rate,
num_channels=input_audio.num_channels,
)
frames = []
if resampler:
frames = resampler.push(input_audio)
frames.extend(resampler.flush())
input_audio = rtc.combine_audio_frames(frames)
if not chunk_duration_ms:
return [input_audio], TEST_AUDIO_TRANSCRIPT
chunk_size = int(input_audio.sample_rate / (1000 / chunk_duration_ms))  # samples per chunk (= sample_rate * chunk_duration_ms / 1000)
bstream = utils.audio.AudioByteStream(
sample_rate=input_audio.sample_rate,
num_channels=input_audio.num_channels,
samples_per_channel=chunk_size,
)
frames = bstream.write(input_audio.data.tobytes())
frames.extend(bstream.flush())
return frames, TEST_AUDIO_TRANSCRIPT
def make_test_synthesize() -> str:
return TEST_AUDIO_SYNTHESIZE
def make_wav_file(frames: list[rtc.AudioFrame]) -> bytes:
buffer = utils.merge_frames(frames)
io_buffer = io.BytesIO()
with wave.open(io_buffer, "wb") as wav:
wav.setnchannels(buffer.num_channels)
wav.setsampwidth(2) # 16-bit
wav.setframerate(buffer.sample_rate)
wav.writeframes(buffer.data)
return io_buffer.getvalue()