Coverage for src/appl/utils.py: 61%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import functools 

2import os 

3import sys 

4import time 

5from typing import Any, Callable, Dict, Optional, Type, TypeVar 

6 

7import tiktoken 

8from dotenv.main import _walk_to_root 

9from loguru import logger 

10 

11from .core.config import configs 

12 

try:
    from langsmith import traceable as _langsmith_traceable  # type: ignore
except Exception:
    # langsmith could not be imported (e.g. it is not installed): provide a
    # stand-in with the same decorator-factory calling convention.
    F = TypeVar("F", bound=Callable)

    def _langsmith_traceable(*trace_args: Any, **trace_kwargs: Any) -> Callable[[F], F]:
        """Return a decorator that wraps a function without tracing it.

        All positional and keyword arguments are accepted and ignored.
        """

        def _passthrough(target: F) -> F:
            @functools.wraps(target)
            def _call(*call_args: Any, **call_kwargs: Any) -> Any:
                # Forward the call unchanged; no tracing happens here.
                return target(*call_args, **call_kwargs)

            return _call  # type: ignore

        return _passthrough

28 

29 

def _is_interactive():
    """Report whether the ``__main__`` module lacks a ``__file__`` attribute.

    That is the case in a REPL or an IPython notebook. Returns False when
    ``__main__`` cannot be imported at all.
    """
    try:
        entry_module = __import__("__main__", None, None, fromlist=["__file__"])
    except ModuleNotFoundError:
        return False
    else:
        has_file = hasattr(entry_module, "__file__")
        return not has_file

37 

38 

def get_folder(
    current_folder: Optional[str] = None,
    usecwd: bool = False,
) -> str:
    """Pick the folder that a file search should start from.

    Args:
        current_folder: Explicit folder to use. Only honoured when the
            working-directory case below does not apply.
        usecwd: Force the use of the current working directory.

    Returns:
        The current working directory when ``usecwd`` is set, when running
        interactively, or when ``sys.frozen`` is truthy; otherwise
        ``current_folder`` if given; otherwise the directory of the first
        calling frame whose code lives outside this file.
    """
    if usecwd or _is_interactive() or getattr(sys, "frozen", False):
        # Should work without __file__, e.g. in REPL or IPython notebook.
        return os.getcwd()

    if current_folder is not None:  # [ADD] option to specify the folder
        return current_folder

    # will work for .py files: climb the call stack until we leave this file
    this_file = __file__
    caller = sys._getframe()
    while caller.f_code.co_filename == this_file:
        assert caller.f_back is not None
        caller = caller.f_back

    caller_path = os.path.abspath(caller.f_code.co_filename)
    return os.path.dirname(caller_path)

61 

62 

def find_files(folder: str, filenames: list[str]) -> list[str]:
    """Collect matching files from each directory yielded by ``_walk_to_root``.

    For every directory visited, the names in ``filenames`` are tried in
    order and only the first one that exists as a regular file is kept.

    Args:
        folder: Directory where the walk starts.
        filenames: Candidate file names, in priority order.

    Returns:
        Paths of the files found, in the order the directories were visited.
    """
    found = []
    for directory in _walk_to_root(folder):
        candidates = (os.path.join(directory, name) for name in filenames)
        # keep only the first existing file among the filenames
        first_hit = next(
            (path for path in candidates if os.path.isfile(path)), None
        )
        if first_hit is not None:
            found.append(first_hit)
    return found

74 

75 

# Rewritten find_dotenv; original at https://github.com/theskumar/python-dotenv/blob/main/src/dotenv/main.py
def find_dotenv(
    filename: str = ".env",
    raise_error_if_not_found: bool = False,
    current_folder: Optional[str] = None,
    usecwd: bool = False,
) -> str:
    """Locate ``filename`` via ``find_files``, starting at ``get_folder``'s result.

    Args:
        filename: Name of the file to look for.
        raise_error_if_not_found: Raise instead of returning "" on a miss.
        current_folder: Optional folder to start from (forwarded to
            ``get_folder``).
        usecwd: Start from the current working directory (forwarded to
            ``get_folder``).

    Returns:
        Path of the first match, or an empty string when nothing is found.

    Raises:
        IOError: If nothing is found and ``raise_error_if_not_found`` is set.
    """
    # Differs from the original by allowing a custom starting folder.
    start_folder = get_folder(current_folder, usecwd)
    for match in find_files(start_folder, [filename]):
        # the first hit wins
        return match

    if not raise_error_if_not_found:
        return ""
    raise IOError("File not found")

96 

97 

def get_num_tokens(prompt: str, encoding: str = "cl100k_base") -> int:
    """Count the tokens produced by encoding ``prompt``.

    Args:
        prompt: Text to encode.
        encoding: Name of the tiktoken encoding to look up.

    Returns:
        Length of the encoded result.
    """
    tokenizer = tiktoken.get_encoding(encoding)
    token_ids = tokenizer.encode(prompt)
    return len(token_ids)

101 

102 

def get_meta_file(trace_file: str) -> str:
    """Derive the metadata file path that belongs to a trace file.

    The extension of ``trace_file`` is dropped and ``_meta.json`` is
    appended, e.g. ``run.pkl`` -> ``run_meta.json``.
    """
    stem, _extension = os.path.splitext(trace_file)
    return f"{stem}_meta.json"

107 

108 

def timeit(func: Callable) -> Callable:
    """Time the execution of a function as a decorator.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that calls ``func`` with the same arguments, logs the
        elapsed time at INFO level, and returns ``func``'s result. If
        ``func`` raises, the exception propagates and nothing is logged.
    """

    @functools.wraps(func)
    def timer(*args, **kwargs):
        # perf_counter is the clock intended for measuring durations; unlike
        # time.time() it cannot jump when the system clock is adjusted.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        logger.info(f"{func.__name__} executed in {elapsed:.2f} seconds.")
        return result

    return timer

121 

122 

class LoguruFormatter:
    """Custom formatter for loguru logger.

    Produces a format template and, when a message exceeds ``max_length``,
    swaps the ``{message}`` placeholder for a truncated copy of the message.
    """

    def __init__(
        self,
        fmt: Optional[str] = None,
        max_length: Optional[int] = None,
        suffix_length: int = 0,
    ):
        """Initialize the formatter with the format string and max length of the message.

        Args:
            fmt: The format string for the log message. When None, it is read
                from ``configs`` under ``settings.logging.format``.
            max_length: The maximum length of the message, truncate if longer.
                None disables truncation.
            suffix_length: The length of the suffix to keep when truncating.
        """
        if fmt is None:
            fmt = configs.getattrs("settings.logging.format")
        # Trailing whitespace is dropped; loguru_format appends one newline.
        self.fmt = fmt.rstrip()
        self.max_length = max_length
        self.suffix_length = suffix_length

    def loguru_format(self, record: Dict) -> str:
        """Format the log message with the record.

        Args:
            record: The log record; its ``"message"`` entry is read, and a
                ``"trunc_message"`` entry is added when truncation happens.

        Returns:
            The format template (not the rendered text) ending in a newline.
            When the message is truncated, ``{message}`` in the template is
            replaced by ``{trunc_message}``.
        """
        msg = record["message"]
        fmt = self.fmt
        if self.max_length is not None and len(msg) > self.max_length:
            # The suffix is carved out of the max_length budget. Clamp it to
            # [0, max_length]: a negative suffix_length would otherwise widen
            # the prefix slice and make the "snipped" count below wrong.
            suffix_len = max(0, min(self.max_length, self.suffix_length))
            truncated = msg[: self.max_length - suffix_len]
            truncated += f"...(snipped {len(msg) - self.max_length} chars)"
            if suffix_len > 0:
                truncated += "..." + msg[-suffix_len:]
            record["trunc_message"] = truncated
            fmt = fmt.replace("{message}", "{trunc_message}")
        return fmt + "\n"