Coverage for src/appl/utils.py: 61%
88 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import functools
2import os
3import sys
4import time
5from typing import Any, Callable, Dict, Optional, Type, TypeVar
7import tiktoken
8from dotenv.main import _walk_to_root
9from loguru import logger
11from .core.config import configs
try:
    from langsmith import traceable as _langsmith_traceable  # type: ignore
except Exception:
    # Any import failure (langsmith missing or broken) falls back to a no-op
    # stand-in so decorated code keeps working without tracing.
    F = TypeVar("F", bound=Callable)

    def _langsmith_traceable(*trace_args: Any, **trace_kwargs: Any) -> Callable[[F], F]:
        """No-op replacement for ``langsmith.traceable``.

        Supports both decorator spellings:

        * ``@_langsmith_traceable(name="x", ...)`` -- the tracing arguments are
          accepted and ignored, and a decorator is returned.
        * ``@_langsmith_traceable`` -- the function is passed directly as the
          only positional argument and is wrapped immediately.

        In both cases the wrapped function simply forwards its arguments to the
        original function and returns its result.
        """

        def decorator(func: F) -> F:
            @functools.wraps(func)
            def inner(*args: Any, **kwargs: Any) -> Any:
                return func(*args, **kwargs)

            return inner  # type: ignore

        # Bare usage (no parentheses): the sole positional argument is the
        # function to decorate, so wrap it right away instead of handing back
        # the decorator itself.
        if len(trace_args) == 1 and callable(trace_args[0]) and not trace_kwargs:
            return decorator(trace_args[0])  # type: ignore

        return decorator
def _is_interactive():
    """Report whether the process runs in a REPL or an IPython notebook.

    The check is whether the ``__main__`` module lacks a ``__file__``
    attribute; if ``__main__`` cannot be imported the answer is ``False``.
    """
    try:
        entry_module = __import__("__main__", None, None, fromlist=["__file__"])
    except ModuleNotFoundError:
        interactive = False
    else:
        interactive = not hasattr(entry_module, "__file__")
    return interactive
def get_folder(
    current_folder: Optional[str] = None,
    usecwd: bool = False,
) -> str:
    """Resolve the folder from which a file search should start.

    Args:
        current_folder: Explicit folder to use when the working directory is
            not being forced.
        usecwd: Force the use of the current working directory.

    Returns:
        The current working directory when ``usecwd`` is set, the session is
        interactive, or ``sys.frozen`` is truthy; otherwise ``current_folder``
        when it is given; otherwise the directory of the first file on the
        call stack that is not this module.
    """
    # No reliable __file__ in a REPL / notebook / frozen app: use the cwd.
    force_cwd = usecwd or _is_interactive() or getattr(sys, "frozen", False)
    if force_cwd:
        return os.getcwd()

    # Caller supplied an explicit starting folder.
    if current_folder is not None:
        return current_folder

    # Walk up the call stack until we leave this module (works for .py files).
    this_file = __file__
    caller = sys._getframe()
    while caller.f_code.co_filename == this_file:
        assert caller.f_back is not None
        caller = caller.f_back
    caller_path = os.path.abspath(caller.f_code.co_filename)
    return os.path.dirname(caller_path)
def find_files(folder: str, filenames: list[str]) -> list[str]:
    """Collect matching files from ``folder`` and each of its parent folders.

    For every directory yielded by ``_walk_to_root(folder)``, at most one path
    is collected: the first entry of ``filenames`` (in the given order) that
    exists as a regular file in that directory.
    """
    found: list[str] = []
    for directory in _walk_to_root(folder):
        candidates = (os.path.join(directory, name) for name in filenames)
        # Only the first existing candidate per directory is kept.
        first_hit = next((path for path in candidates if os.path.isfile(path)), None)
        if first_hit is not None:
            found.append(first_hit)
    return found
# Rewrite of find_dotenv; the original is at https://github.com/theskumar/python-dotenv/blob/main/src/dotenv/main.py
def find_dotenv(
    filename: str = ".env",
    raise_error_if_not_found: bool = False,
    current_folder: Optional[str] = None,
    usecwd: bool = False,
) -> str:
    """Look for ``filename`` in the starting folder and then in each parent.

    The starting folder is chosen by ``get_folder(current_folder, usecwd)``,
    which is what allows a custom starting folder to be supplied.

    Returns:
        The path of the first match, or an empty string when nothing is found.

    Raises:
        IOError: If nothing is found and ``raise_error_if_not_found`` is true.
    """
    start_dir = get_folder(current_folder, usecwd)
    matches = find_files(start_dir, [filename])
    if not matches:
        if raise_error_if_not_found:
            raise IOError("File not found")
        return ""
    return matches[0]
def get_num_tokens(prompt: str, encoding: str = "cl100k_base") -> int:
    """Count the tokens ``prompt`` encodes to under the named tiktoken encoding."""
    tokenizer = tiktoken.get_encoding(encoding)
    token_ids = tokenizer.encode(prompt)
    return len(token_ids)
def get_meta_file(trace_file: str) -> str:
    """Return the path of the metadata file that accompanies ``trace_file``.

    The extension of ``trace_file`` is dropped and ``_meta.json`` is appended,
    e.g. ``run.pkl`` -> ``run_meta.json``.
    """
    stem, _extension = os.path.splitext(trace_file)
    return f"{stem}_meta.json"
def timeit(func: Callable) -> Callable:
    """Decorator that logs how long each call to ``func`` takes.

    The elapsed time is measured with ``time.perf_counter`` (a monotonic,
    high-resolution clock, unaffected by system clock adjustments) and logged
    at INFO level after ``func`` returns. Nothing is logged when ``func``
    raises; the exception propagates unchanged.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its result.
    """
    # Callables such as functools.partial objects have no __name__; fall back
    # to their repr so logging never fails after the call has succeeded.
    func_name = getattr(func, "__name__", repr(func))

    @functools.wraps(func)
    def timer(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        logger.info(f"{func_name} executed in {elapsed:.2f} seconds.")
        return result

    return timer
class LoguruFormatter:
    """Custom formatter for loguru logger.

    Produces a format template for each record and, when a maximum length is
    configured, substitutes a truncated copy of over-long messages.

    NOTE(review): loguru does not append ``{exception}`` automatically when the
    format is a callable -- confirm the configured format string contains it
    if tracebacks are expected in the output.
    """

    def __init__(
        self,
        fmt: Optional[str] = None,
        max_length: Optional[int] = None,
        suffix_length: int = 0,
    ):
        """Initialize the formatter with the format string and max length of the message.

        Args:
            fmt: The format string for the log message. When ``None``, the
                value of ``settings.logging.format`` from ``configs`` is used.
                Trailing whitespace is stripped.
            max_length: The maximum length of the message, truncate if longer.
                ``None`` disables truncation; negative values are treated as 0.
            suffix_length: The length of the suffix to keep when truncating.
                It is limited to the range ``[0, max_length]``.
        """
        if fmt is None:
            fmt = configs.getattrs("settings.logging.format")
        self.fmt = fmt.rstrip()
        self.max_length = max_length
        self.suffix_length = suffix_length

    def loguru_format(self, record: Dict) -> str:
        """Format the log message with the record.

        If the message is longer than the configured maximum, a shortened
        version (leading part, a ``...(snipped N chars)`` marker, and an
        optional trailing part) is stored in ``record["trunc_message"]`` and
        the template refers to ``{trunc_message}`` instead of ``{message}``.

        Args:
            record: The log record; ``record["message"]`` is read and
                ``record["trunc_message"]`` may be added.

        Returns:
            The format template terminated by a newline.
        """
        msg = record["message"]
        fmt = self.fmt
        if self.max_length is not None:
            # Clamp both limits so negative settings cannot make the kept
            # prefix longer than the limit or skew the snipped-char count.
            limit = max(0, self.max_length)
            if len(msg) > limit:
                suffix_len = max(0, min(limit, self.suffix_length))
                # Prefix and suffix together keep exactly `limit` characters.
                truncated = msg[: limit - suffix_len]
                truncated += f"...(snipped {len(msg) - limit} chars)"
                if suffix_len > 0:
                    truncated += "..." + msg[-suffix_len:]
                record["trunc_message"] = truncated
                fmt = fmt.replace("{message}", "{trunc_message}")
        return fmt + "\n"