# -*- coding: utf-8 -*-
""" Common utils."""
import base64
import contextlib
import datetime
import hashlib
import json
import os
import random
import re
import secrets
import signal
import socket
import string
import sys
import tempfile
import threading
from typing import Any, Generator, Optional, Union, Tuple, Literal, List
from urllib.parse import urlparse
import psutil
import requests
[docs]
@contextlib.contextmanager
def timer(seconds: Optional[Union[int, float]] = None) -> Generator:
    """
    A context manager that limits the execution time of a code block to a
    given number of seconds.

    The implementation of this contextmanager is borrowed from
    https://github.com/openai/human-eval/blob/master/human_eval/execution.py

    Note:
        This function only works in Unix and in the main thread, since
        `signal.setitimer` is only available in Unix and signal handlers
        can only be installed from the main thread. In every other case
        (or when `seconds` is `None`) the block runs without a time limit.

    Args:
        seconds (`Optional[Union[int, float]]`, defaults to `None`):
            The time limit in seconds. `None` disables the limit.

    Raises:
        `TimeoutError`: If the code block runs longer than `seconds`.
    """
    if (
        seconds is None
        or sys.platform == "win32"
        or threading.current_thread() is not threading.main_thread()
    ):
        yield
        return

    def signal_handler(*args: Any, **kwargs: Any) -> None:
        raise TimeoutError("timed out")

    # Install the handler BEFORE arming the timer so that a very short
    # timeout can never fire while the default SIGALRM action (process
    # termination) is still active.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        # Enter the context and execute the code block.
        yield
    finally:
        # Disarm the timer first, then give SIGALRM back to whoever owned it.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # `signal.signal` returns None when the previous handler was not
        # installed from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
[docs]
@contextlib.contextmanager
def create_tempdir() -> Generator:
    """
    Create a temporary directory and make it the current working directory
    for the duration of the `with` block.

    The directory path is yielded to the caller. On exit the previous
    working directory is restored first and the temporary directory is
    removed afterwards.

    The implementation of this contextmanager is borrowed from
    https://github.com/openai/human-eval/blob/master/human_eval/execution.py
    """
    # The managers are entered left to right and exited right to left, so
    # the working directory is switched back before the directory is deleted.
    with tempfile.TemporaryDirectory() as dirname, _chdir(dirname):
        yield dirname
@contextlib.contextmanager
def _chdir(path: str) -> Generator:
"""
A context manager that changes the current working directory to the
given path.
The implementation of this contextmanager are borrowed from
https://github.com/openai/human-eval/blob/master/human_eval/execution.py
"""
if path == ".":
yield
return
cwd = os.getcwd()
os.chdir(path)
try:
yield
except BaseException as exc:
raise exc
finally:
os.chdir(cwd)
def _requests_get(
    url: str,
    params: dict,
    headers: Optional[dict] = None,
) -> Union[dict, str]:
    """
    Sends a GET request to the specified URL with the provided query
    parameters and headers, and returns the parsed JSON response.

    Request failures, HTTP error statuses and bodies that are not valid JSON
    are all reported through the return value instead of an exception.

    Args:
        url (`str`):
            The URL to which the GET request is sent.
        params (`dict`):
            A dictionary containing query parameters to be included in the
            request.
        headers (`Optional[dict]`, defaults to `None`):
            An optional dictionary of HTTP headers to send with the request.

    Returns:
        `Union[dict, str]`:
            If the request is successful, returns the parsed JSON data.
            If the request fails or the body cannot be parsed as JSON,
            returns the error string.
    """
    try:
        # An empty/None `headers` is sent as "no extra headers".
        response = requests.get(url, params=params, headers=headers or None)
        # This will raise an exception for HTTP error codes
        response.raise_for_status()
        # Parsing stays inside the `try`: a non-JSON body raises a
        # `ValueError` subclass (only newer versions of requests make it a
        # `RequestException` as well), and it must not escape the function.
        return response.json()
    except (requests.RequestException, ValueError) as e:
        return str(e)
def _if_change_database(sql_query: str) -> bool:
"""Check whether SQL query only contains SELECT query"""
# Compile the regex pattern outside the function for better performance
pattern_unsafe_sql = re.compile(
r"\b(INSERT|UPDATE|DELETE|REPLACE|CREATE|ALTER|DROP|TRUNCATE|USE)\b",
re.IGNORECASE,
)
# Remove SQL comments
sql_query = re.sub(r"--.*?$", "", sql_query, flags=re.MULTILINE)
# Remove /* */ comments
sql_query = re.sub(r"/\*.*?\*/", "", sql_query, flags=re.DOTALL)
# Matching non-SELECT statements with regular expressions
if pattern_unsafe_sql.search(sql_query):
return False
return True
def _get_timestamp(
format_: str = "%Y-%m-%d %H:%M:%S",
time: datetime.datetime = None,
) -> str:
"""Get current timestamp."""
if time is None:
return datetime.datetime.now().strftime(format_)
else:
return time.strftime(format_)
[docs]
def to_openai_dict(item: dict) -> dict:
    """Convert a message dictionary to the `dict` layout used for the OpenAI
    API.

    Only the `name`, `role` and `content` fields are kept. A missing `role`
    defaults to `"assistant"`, and the content is converted to a string.

    Args:
        item (`dict`):
            The message to convert.

    Returns:
        `dict`: The cleaned dictionary.

    Raises:
        `ValueError`: If `item` has no `content` field.
    """
    # Guard clause: nothing can be produced without content.
    if "content" not in item:
        raise ValueError("The content of the message is missing.")

    openai_msg = {}
    if "name" in item:
        openai_msg["name"] = item["name"]
    openai_msg["role"] = item["role"] if "role" in item else "assistant"
    openai_msg["content"] = _convert_to_str(item["content"])
    return openai_msg
def _find_available_port() -> int:
"""Get an unoccupied socket port number."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def _check_port(port: Optional[int] = None) -> int:
    """Check if the port is available.

    Args:
        port (`Optional[int]`, defaults to `None`):
            The port number being checked. If `None`, an available port is
            selected automatically.

    Returns:
        `int`: The port number that passed the check. If the port is found
        to be occupied, or the check itself fails, an available port number
        is returned instead.
    """
    if port is None:
        return _find_available_port()

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        try:
            # A successful connection means something is already listening.
            unusable = probe.connect_ex(("localhost", port)) == 0
        except Exception:
            # Any failure of the probe is treated like an occupied port.
            unusable = True
        if unusable:
            return _find_available_port()
    return port
def _guess_type_by_extension(
url: str,
) -> Literal["image", "audio", "video", "file"]:
"""Guess the type of the file by its extension."""
extension = url.split(".")[-1].lower()
if extension in [
"bmp",
"dib",
"icns",
"ico",
"jfif",
"jpe",
"jpeg",
"jpg",
"j2c",
"j2k",
"jp2",
"jpc",
"jpf",
"jpx",
"apng",
"png",
"bw",
"rgb",
"rgba",
"sgi",
"tif",
"tiff",
"webp",
]:
return "image"
elif extension in [
"amr",
"wav",
"3gp",
"3gpp",
"aac",
"mp3",
"flac",
"ogg",
]:
return "audio"
elif extension in [
"mp4",
"webm",
"mkv",
"flv",
"avi",
"mov",
"wmv",
"rmvb",
]:
return "video"
else:
return "file"
def _to_openai_image_url(url: str) -> str:
"""Convert an image url to openai format. If the given url is a local
file, it will be converted to base64 format. Otherwise, it will be
returned directly.
Args:
url (`str`):
The local or public url of the image.
"""
# See https://platform.openai.com/docs/guides/vision for details of
# support image extensions.
support_image_extensions = (
".png",
".jpg",
".jpeg",
".gif",
".webp",
)
parsed_url = urlparse(url)
lower_url = url.lower()
# Web url
if not os.path.exists(url) and parsed_url.scheme != "":
if any(lower_url.endswith(_) for _ in support_image_extensions):
return url
# Check if it is a local file
elif os.path.exists(url) and os.path.isfile(url):
if any(lower_url.endswith(_) for _ in support_image_extensions):
with open(url, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode(
"utf-8",
)
extension = parsed_url.path.lower().split(".")[-1]
mime_type = f"image/{extension}"
return f"data:{mime_type};base64,{base64_image}"
raise TypeError(f"{url} should be end with {support_image_extensions}.")
def _download_file(url: str, path_file: str, max_retries: int = 3) -> bool:
    """Download file from the given url and save it to the given path.

    Args:
        url (`str`):
            The url of the file.
        path_file (`str`):
            The path to save the file.
        max_retries (`int`, defaults to `3`):
            The maximum number of attempts to download the file.

    Returns:
        `bool`: `True` if the file is downloaded successfully. `False` is
        only returned when `max_retries` is not positive, i.e. no attempt
        is made at all.

    Raises:
        `RuntimeError`: If the last attempt still ends with a non-OK HTTP
        status code.
    """
    for n_retry in range(1, max_retries + 1):
        # The response is used as a context manager so that the streamed
        # connection is released on every path.
        with requests.get(url, stream=True) as response:
            if response.status_code == requests.codes.ok:
                with open(path_file, "wb") as file:
                    for chunk in response.iter_content(1024):
                        file.write(chunk)
                return True
            status_code = response.status_code

        # Only give up after the last attempt; earlier failures fall through
        # to the next iteration so that `max_retries` is actually honoured.
        if n_retry == max_retries:
            raise RuntimeError(
                f"Failed to download file from {url} (status code: "
                f"{status_code}). Retry {n_retry}/{max_retries}.",
            )
    return False
def _generate_random_code(
length: int = 6,
uppercase: bool = True,
lowercase: bool = True,
digits: bool = True,
) -> str:
"""Get random code."""
characters = ""
if uppercase:
characters += string.ascii_uppercase
if lowercase:
characters += string.ascii_lowercase
if digits:
characters += string.digits
return "".join(secrets.choice(characters) for i in range(length))
def _generate_id_from_seed(seed: str, length: int = 8) -> str:
"""Generate random id from seed str.
Args:
seed (`str`): seed string.
length (`int`): generated id length.
"""
hasher = hashlib.sha256()
hasher.update(seed.encode("utf-8"))
hash_digest = hasher.hexdigest()
random.seed(hash_digest)
id_chars = [
random.choice(string.ascii_letters + string.digits)
for _ in range(length)
]
return "".join(id_chars)
def _is_web_url(url: str) -> bool:
"""Whether the url is accessible from the Web.
Args:
url (`str`):
The url to check.
Note:
This function is not perfect, it only checks if the URL starts with
common web protocols, e.g., http, https, ftp, oss.
"""
parsed_url = urlparse(url)
return parsed_url.scheme in ["http", "https", "ftp", "oss"]
def _get_base64_from_image_path(image_path: str) -> str:
"""Get the base64 string from the image url.
Args:
image_path (`str`):
The local path of the image.
"""
with open(image_path, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode(
"utf-8",
)
extension = image_path.lower().split(".")[-1]
mime_type = f"image/{extension}"
return f"data:{mime_type};base64,{base64_image}"
def _is_json_serializable(obj: Any) -> bool:
"""Check if the given object is json serializable."""
try:
json.dumps(obj)
return True
except TypeError:
return False
def _convert_to_str(content: Any) -> str:
"""Convert the content to string.
Note:
For prompt engineering, simply calling `str(content)` or
`json.dumps(content)` is not enough.
- For `str(content)`, if `content` is a dictionary, it will turn double
quotes to single quotes. When this string is fed into prompt, the LLMs
may learn to use single quotes instead of double quotes (which
cannot be loaded by `json.loads` API).
- For `json.dumps(content)`, if `content` is a string, it will add
double quotes to the string. LLMs may learn to use double quotes to
wrap strings, which leads to the same issue as `str(content)`.
To avoid these issues, we use this function to safely convert the
content to a string used in prompt.
Args:
content (`Any`):
The content to be converted.
Returns:
`str`: The converted string.
"""
if isinstance(content, str):
return content
elif isinstance(content, (dict, list, int, float, bool, tuple)):
return json.dumps(content, ensure_ascii=False)
else:
return str(content)
def _join_str_with_comma_and(elements: List[str]) -> str:
"""Return the JSON string with comma, and use " and " between the last two
elements."""
if len(elements) == 0:
return ""
elif len(elements) == 1:
return elements[0]
elif len(elements) == 2:
return " and ".join(elements)
else:
return ", ".join(elements[:-1]) + f", and {elements[-1]}"
[docs]
class ImportErrorReporter:
    """Used as a placeholder for missing packages.

    When the placeholder is called, indexed, or has an attribute accessed,
    an ImportError will be raised, prompting the user to install the
    specified extras requirement.
    """

    def __init__(
        self,
        error: ImportError,
        extras_require: Optional[str] = None,
    ) -> None:
        """Init the ImportErrorReporter.

        Args:
            error (`ImportError`):
                The original ImportError.
            extras_require (`Optional[str]`, defaults to `None`):
                The extras requirement to suggest in the error message.
        """
        self.error = error
        self.extras_require = extras_require

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self._raise_import_error()

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes that are not found normally, i.e.
        # anything the missing package was expected to provide.
        return self._raise_import_error()

    def __getitem__(self, __key: Any) -> Any:
        return self._raise_import_error()

    def _raise_import_error(self) -> Any:
        """Raise the ImportError.

        Raises:
            `ImportError`: Always; it wraps the original error message and
            keeps the original error as its cause.
        """
        err_msg = f"ImportError occurred: [{self.error.msg}]."
        if self.extras_require is not None:
            err_msg += (
                f" Please install [{self.extras_require}] version"
                " of agentscope."
            )
        # Chain the original error so its traceback is not lost.
        raise ImportError(err_msg) from self.error
def _hash_string(
data: str,
hash_method: Literal["sha256", "md5", "sha1"],
) -> str:
"""Hash the string data."""
hash_func = getattr(hashlib, hash_method)()
hash_func.update(data.encode())
return hash_func.hexdigest()
def _get_process_creation_time() -> datetime.datetime:
    """Get the creation time of the current process.

    Returns:
        `datetime.datetime`: The creation time reported by `psutil` for the
        current pid, converted from a timestamp to a datetime object.
    """
    this_process = psutil.Process(os.getpid())
    started_at = this_process.create_time()
    return datetime.datetime.fromtimestamp(started_at)
def _is_process_alive(
    pid: int,
    create_time_str: str,
    create_time_format: str = "%Y-%m-%d %H:%M:%S",
    tolerance_seconds: int = 10,
) -> bool:
    """Check if the process is alive by comparing the actual creation time of
    the process with the given creation time.

    Args:
        pid (`int`):
            The process id.
        create_time_str (`str`):
            The given creation time string.
        create_time_format (`str`, defaults to `"%Y-%m-%d %H:%M:%S"`):
            The format of the given creation time string.
        tolerance_seconds (`int`, defaults to `10`):
            The tolerance seconds for comparing the actual creation time with
            the given creation time.

    Returns:
        `bool`: True if a process with this pid exists and its creation time
        is within the tolerance of the given one, False otherwise.
    """
    try:
        # Actual creation time (timestamp) of the process with this pid.
        actual_timestamp = psutil.Process(pid).create_time()
        # Creation time the caller expects, parsed into a timestamp.
        expected_timestamp = datetime.datetime.strptime(
            create_time_str,
            create_time_format,
        ).timestamp()
        # Alive only when both creation times agree within the tolerance.
        drift = abs(actual_timestamp - expected_timestamp)
        return drift <= tolerance_seconds
    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
        # The process is not found, access is denied, or the process is a
        # zombie process.
        return False
def _is_windows() -> bool:
"""Check if the system is Windows."""
return os.name == "nt"
def _map_string_to_color_mark(
target_str: str,
) -> Tuple[str, str]:
"""Map a string into an index within a given length.
Args:
target_str (`str`):
The string to be mapped.
Returns:
`Tuple[str, str]`: A color marker tuple
"""
color_marks = [
("\033[90m", "\033[0m"),
("\033[91m", "\033[0m"),
("\033[92m", "\033[0m"),
("\033[93m", "\033[0m"),
("\033[94m", "\033[0m"),
("\033[95m", "\033[0m"),
("\033[96m", "\033[0m"),
("\033[97m", "\033[0m"),
]
hash_value = int(hashlib.sha256(target_str.encode()).hexdigest(), 16)
index = hash_value % len(color_marks)
return color_marks[index]
def _generate_new_runtime_id() -> str:
    """Generate a new random runtime id.

    Returns:
        `str`: An id of the form `run_<date>-<time>_<random code>`, where the
        random code contains no uppercase letters.
    """
    # The `{}` placeholder passes through the timestamp formatting untouched
    # and is filled with the random code afterwards.
    id_template = _get_timestamp("run_%Y%m%d-%H%M%S_{}")
    random_suffix = _generate_random_code(uppercase=False)
    return id_template.format(random_suffix)