Coverage for src/appl/servers/manager.py: 77%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import copy 

2import threading 

3from typing import Any, Dict, Optional 

4 

5from litellm import model_list, provider_list 

6from loguru import logger 

7 

8from ..core.config import configs 

9from ..core.server import BaseServer, DummyServer 

10 

11 

def _init_server(
    model: str,
    provider: Optional[str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> BaseServer:
    """Build the server object that should serve ``model``.

    Args:
        model: Model identifier. The special value ``"_dummy"`` selects
            ``DummyServer``.
        provider: Provider name. When omitted, the text before the first
            ``/`` in ``model`` is used instead.
        base_url: Optional address forwarded to ``APIServer``.
        api_key: Optional key forwarded to ``APIServer``.
        **kwargs: Extra keyword arguments forwarded to ``APIServer``.

    Returns:
        A ``DummyServer`` for the ``"_dummy"`` model, otherwise an
        ``APIServer``.

    Raises:
        ValueError: If the provider is not a ``custom`` one, the model is not
            in ``model_list`` and the provider is not in ``provider_list``.
    """
    resolved_provider = model.split("/")[0] if provider is None else provider

    # A provider of the form "custom" or "custom/<name>" marks a custom model.
    uses_custom_llm = resolved_provider.split("/")[0] == "custom"
    if uses_custom_llm:
        # "openai" is the default provider for custom models
        custom_llm_provider = resolved_provider[len("custom/") :] or "openai"
    else:
        custom_llm_provider = None

    if model == "_dummy":
        return DummyServer()  # for testing purposes

    is_known = (
        uses_custom_llm
        or model in model_list
        or resolved_provider in provider_list
    )
    if not is_known:
        raise ValueError(f"Unknown model {model}")

    # Imported only on this path, as in the original implementation.
    from .api import APIServer

    address_note = "" if base_url is None else f" with address {base_url}"
    logger.info(f"Initializing APIserver for model {model}" + address_note)
    return APIServer(
        model,
        base_url=base_url,
        api_key=api_key,
        custom_llm_provider=custom_llm_provider,
        **kwargs,
    )

48 

49 

def _get_server_configs(name: str) -> dict:
    """Return the configuration for server ``name`` with templates expanded.

    If ``name`` has no entry under ``configs["servers"]``, a warning is logged
    and the name itself is used as the model name. Otherwise, while the
    configuration contains a ``"template"`` key, the referenced template is
    merged in underneath it: values already present in the server's own
    configuration override the template's values.

    Args:
        name: Name of the server entry to look up.

    Returns:
        The resolved configuration mapping.

    Raises:
        ValueError: If a referenced template does not exist, or if a
            ``"template"`` key is still present after 100 expansions.
    """
    if name not in configs.get("servers", {}):
        logger.warning(
            f"Server {name} not found in configs, using the server name as model name"
        )
        return {"model": name}

    resolved = configs.servers[name]
    expansions_left = 100  # prevent infinite loop (max 100 templates)
    while expansions_left > 0 and "template" in resolved:
        expansions_left -= 1
        # Work on a copy so the stored config is not modified by pop().
        resolved = copy.deepcopy(resolved)
        parent_name = resolved.pop("template")
        if parent_name not in configs.servers:
            raise ValueError(f"Server config template {parent_name} not found")
        # override template config
        resolved = {**configs.servers[parent_name], **resolved}

    if "template" in resolved:
        raise ValueError(f"Template loop detected in server config {name}")
    return resolved

72 

73 

class ServerManager:
    """Registry that maps server names to ``BaseServer`` instances.

    Servers are either registered explicitly with :meth:`register_server` or
    created on first lookup by :meth:`get_server`.
    """

    def __init__(self) -> None:
        """Create an empty registry and the lock used by :meth:`get_server`."""
        self._servers: Dict[str, BaseServer] = {}
        self._lock = threading.Lock()

    def register_server(self, name: str, server: BaseServer) -> None:
        """Store ``server`` under ``name``, replacing any existing entry.

        This method does not acquire the manager's lock itself.
        """
        self._servers[name] = server

    def close_server(self, name: str) -> None:
        """Close the server registered as ``name`` and remove it.

        Does nothing when no server is registered under ``name``.
        """
        if name not in self._servers:
            return
        self._servers[name].close()
        del self._servers[name]

    def get_server(self, name: Optional[str]) -> BaseServer:
        """Return the server for ``name``, creating it on first use.

        Args:
            name: Server name. ``None`` selects the name stored at
                ``servers.default`` in the configs.

        Returns:
            The registered server. If none exists yet, one is built from the
            server configs and registered while holding the lock.
        """
        key = configs.getattrs("servers.default") if name is None else name

        with self._lock:
            if key not in self._servers:
                created = _init_server(**_get_server_configs(key))
                self.register_server(key, created)
            return self._servers[key]

    @property
    def default_server(self) -> BaseServer:
        """The server registered under the literal name ``"default"``.

        Raises:
            ValueError: If no server is registered under ``"default"``.
        """
        if "default" in self._servers:
            return self._servers["default"]
        raise ValueError("Default server not found")

110 

111 

# Module-level singleton: the one ServerManager instance created at import time.
server_manager = ServerManager()