# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Model wrapper for Ollama models."""
from datetime import datetime
from typing import (
    Any,
    TYPE_CHECKING,
    List,
    AsyncGenerator,
    AsyncIterator,
    Literal,
)
from collections import OrderedDict

from . import ChatResponse
from ._model_base import ChatModelBase
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ..tracing import trace_llm


if TYPE_CHECKING:
    from ollama._types import ChatResponse as OllamaChatResponse
else:
    OllamaChatResponse = "ollama._types.ChatResponse"
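# At runtime the name above is only a string placeholder, so `ollama`
# remains an optional dependency until `OllamaChatModel.__init__` actually
# imports the package.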


class OllamaChatModel(ChatModelBase):
    """The Ollama chat model class in agentscope."""

    def __init__(
        self,
        model_name: str,
        stream: bool = False,
        options: dict | None = None,
        keep_alive: str = "5m",
        enable_thinking: bool | None = None,
        host: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Ollama chat model.

        Args:
            model_name (`str`):
                The name of the model.
            stream (`bool`, default `False`):
                Streaming mode or not.
            options (`dict | None`, default `None`):
                Additional parameters to pass to the Ollama API, e.g.
                temperature.
            keep_alive (`str`, default `"5m"`):
                Duration to keep the model loaded in memory. The format is
                a number followed by a unit suffix (s for seconds, m for
                minutes, h for hours).
            enable_thinking (`bool | None`, default `None`):
                Whether to enable thinking, only for models such as qwen3,
                deepseek-r1, etc. For more details, please refer to
                https://ollama.com/search?c=thinking
            host (`str | None`, default `None`):
                The host address of the Ollama server. If None, uses the
                default address (typically http://localhost:11434).
            **kwargs (`Any`):
                Additional keyword arguments to pass to the underlying
                `ollama.AsyncClient` constructor.
        """
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "The package ollama is not found. Please install it by "
                'running command `pip install "ollama>=0.1.7"`',
            ) from e

        super().__init__(model_name, stream)

        self.client = ollama.AsyncClient(
            host=host,
            **kwargs,
        )
        self.options = options
        self.keep_alive = keep_alive
        self.think = enable_thinking
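
    # Example (illustrative): constructing the model against a local
    # Ollama server. The model name and option values below are
    # placeholders, not defaults of this class:
    #
    #     model = OllamaChatModel(
    #         model_name="qwen3:8b",
    #         stream=True,
    #         options={"temperature": 0.7},
    #         enable_thinking=True,
    #     )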

    @trace_llm
    async def __call__(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict] | None = None,
        tool_choice: Literal["auto", "none", "any", "required"]
        | str
        | None = None,
        **kwargs: Any,
    ) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
        """Get the response from the Ollama chat completions API by the
        given arguments.

        Args:
            messages (`list[dict]`):
                A list of dictionaries, where the `role` and `content`
                fields are required, and the `name` field is optional.
            tools (`list[dict]`, default `None`):
                The tools JSON schemas that the model can use.
            tool_choice (`Literal["auto", "none", "any", "required"] | str \
            | None`, default `None`):
                Controls which (if any) tool is called by the model. Can be
                "auto", "none", "any", "required", or a specific tool name.
            **kwargs (`Any`):
                The keyword arguments for the Ollama chat completions API,
                e.g. `think`. Please refer to the Ollama API documentation
                for more details.

        Returns:
            `ChatResponse | AsyncGenerator[ChatResponse, None]`:
                The response from the Ollama chat completions API.
        """
        kwargs = {
            "model": self.model_name,
            "messages": messages,
            "stream": self.stream,
            "options": self.options,
            "keep_alive": self.keep_alive,
            **kwargs,
        }

        if self.think is not None and "think" not in kwargs:
            kwargs["think"] = self.think

        if tools:
            kwargs["tools"] = self._format_tools_json_schemas(tools)

        if tool_choice:
            logger.warning(
                "Ollama does not support tool_choice yet, ignored.",
            )

        start_datetime = datetime.now()
        response = await self.client.chat(**kwargs)

        if self.stream:
            return self._parse_ollama_stream_completion_response(
                start_datetime,
                response,
            )

        parsed_response = await self._parse_ollama_completion_response(
            start_datetime,
            response,
        )

        return parsed_response
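
    # Example (illustrative): calling the model and consuming the result.
    # Assumes `model` was constructed as sketched above:
    #
    #     res = await model(
    #         messages=[{"role": "user", "content": "Hello!"}],
    #     )
    #     if model.stream:
    #         async for chunk in res:
    #             ...  # each chunk is a cumulative ChatResponse
    #     else:
    #         print(res.content)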

    async def _parse_ollama_stream_completion_response(
        self,
        start_datetime: datetime,
        response: AsyncIterator[Any],
    ) -> AsyncGenerator[ChatResponse, None]:
        """Parse the Ollama chat completion response stream into an async
        generator of `ChatResponse` objects."""
        accumulated_text = ""
        acc_thinking_content = ""
        tool_calls = OrderedDict()  # Store tool calls

        async for chunk in response:
            has_new_content = False
            has_new_thinking = False

            # Handle text content
            if hasattr(chunk, "message"):
                msg = chunk.message
                if getattr(msg, "thinking", None):
                    acc_thinking_content += msg.thinking
                    has_new_thinking = True
                if getattr(msg, "content", None):
                    accumulated_text += msg.content
                    has_new_content = True

                # Handle tool calls
                if getattr(msg, "tool_calls", None):
                    has_new_content = True
                    for idx, tool_call in enumerate(msg.tool_calls):
                        function_name = (
                            getattr(tool_call, "function", None)
                            and tool_call.function.name
                            or "tool"
                        )
                        tool_id = getattr(
                            tool_call,
                            "id",
                            f"{function_name}_{idx}",
                        )
                        if hasattr(tool_call, "function"):
                            function = tool_call.function
                            tool_calls[tool_id] = {
                                "type": "tool_use",
                                "id": tool_id,
                                "name": function.name,
                                "input": function.arguments,
                            }

            # Calculate usage statistics
            current_time = (datetime.now() - start_datetime).total_seconds()
            usage = ChatUsage(
                input_tokens=getattr(chunk, "prompt_eval_count", 0) or 0,
                output_tokens=getattr(chunk, "eval_count", 0) or 0,
                time=current_time,
            )

            # Create content blocks
            contents: list = []
            if acc_thinking_content:
                contents.append(
                    ThinkingBlock(
                        type="thinking",
                        thinking=acc_thinking_content,
                    ),
                )
            if accumulated_text:
                contents.append(TextBlock(type="text", text=accumulated_text))

            # Add tool call blocks
            if tool_calls:
                for tool_call in tool_calls.values():
                    try:
                        input_data = tool_call["input"]
                        if isinstance(input_data, str):
                            input_data = _json_loads_with_repair(input_data)
                        contents.append(
                            ToolUseBlock(
                                type=tool_call["type"],
                                id=tool_call["id"],
                                name=tool_call["name"],
                                input=input_data,
                            ),
                        )
                    except Exception as e:
                        logger.warning(
                            "Error parsing tool call input: %s",
                            e,
                        )

            # Generate response when there's new content or at final chunk
            is_final = getattr(chunk, "done", False)
            if (has_new_thinking or has_new_content or is_final) and contents:
                res = ChatResponse(content=contents, usage=usage)
                yield res

    async def _parse_ollama_completion_response(
        self,
        start_datetime: datetime,
        response: OllamaChatResponse,
    ) -> ChatResponse:
        """Parse the Ollama chat completion response into a `ChatResponse`
        object.

        Args:
            start_datetime (`datetime`):
                The start datetime of the response generation.
            response (`OllamaChatResponse`):
                The Ollama chat response object to parse.

        Returns:
            `ChatResponse`:
                The content blocks and usage information extracted from
                the response.
        """
        content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []

        if response.message.thinking:
            content_blocks.append(
                ThinkingBlock(
                    type="thinking",
                    thinking=response.message.thinking,
                ),
            )

        if response.message.content:
            content_blocks.append(
                TextBlock(
                    type="text",
                    text=response.message.content,
                ),
            )

        if response.message.tool_calls:
            for tool_call in response.message.tool_calls:
                content_blocks.append(
                    ToolUseBlock(
                        type="tool_use",
                        # Ollama does not return tool call IDs, so the
                        # function name doubles as the block ID here.
                        id=tool_call.function.name,
                        name=tool_call.function.name,
                        input=tool_call.function.arguments,
                    ),
                )

        usage = None
        if "prompt_eval_count" in response and "eval_count" in response:
            usage = ChatUsage(
                input_tokens=response.get("prompt_eval_count", 0),
                output_tokens=response.get("eval_count", 0),
                time=(datetime.now() - start_datetime).total_seconds(),
            )

        parsed_response = ChatResponse(
            content=content_blocks,
            usage=usage,
        )

        return parsed_response

    def _format_tools_json_schemas(
        self,
        schemas: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Format the tools JSON schemas to the Ollama format."""
        return schemas