Coverage for src/appl/servers/manager.py: 77%
66 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import copy
2import threading
3from typing import Any, Dict, Optional
5from litellm import model_list, provider_list
6from loguru import logger
8from ..core.config import configs
9from ..core.server import BaseServer, DummyServer
def _init_server(
    model: str,
    provider: Optional[str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> BaseServer:
    """Create the server object that serves ``model``.

    Args:
        model: Model name. The special name ``"_dummy"`` selects a
            ``DummyServer``.
        provider: Provider name. When omitted, the part of ``model`` before
            its first ``"/"`` is used. A provider of ``"custom"`` or
            ``"custom/<backend>"`` marks the model as a custom LLM.
        base_url: Optional address forwarded to the API server.
        api_key: Optional API key forwarded to the API server.
        **kwargs: Extra keyword arguments forwarded to the API server.

    Returns:
        A ``DummyServer`` for ``"_dummy"``, otherwise an ``APIServer``.

    Raises:
        ValueError: If the model is not custom, is not in litellm's
            ``model_list`` and its provider is not in ``provider_list``.
    """
    effective_provider = model.split("/")[0] if provider is None else provider

    # Split "custom/<backend>" into its marker and the backend that follows.
    marker, _, backend = effective_provider.partition("/")
    is_custom = marker == "custom"
    # "openai" is the default provider for custom models
    custom_backend = (backend or "openai") if is_custom else None

    if model == "_dummy":
        return DummyServer()  # for testing purposes

    is_known = is_custom or model in model_list or effective_provider in provider_list
    if not is_known:
        raise ValueError(f"Unknown model {model}")

    # Imported here rather than at module level, as in the original code.
    from .api import APIServer

    address_note = f" with address {base_url}" if base_url is not None else ""
    logger.info(f"Initializing APIserver for model {model}{address_note}")
    return APIServer(
        model,
        base_url=base_url,
        api_key=api_key,
        custom_llm_provider=custom_backend,
        **kwargs,
    )
def _get_server_configs(name: str) -> dict:
    """Return the configuration for server ``name`` with templates resolved.

    A name that is missing from ``configs.servers`` is treated as a model
    name: a warning is logged and ``{"model": name}`` is returned.

    Otherwise the entry's ``"template"`` key, if present, names another
    entry whose values act as defaults; the entry's own values override
    them. Templates may themselves name templates, followed for at most
    100 steps.

    Raises:
        ValueError: If a referenced template does not exist, or a
            ``"template"`` key is still present after 100 steps.
    """
    if name not in configs.get("servers", {}):
        logger.warning(
            f"Server {name} not found in configs, using the server name as model name"
        )
        return {"model": name}

    resolved = configs.servers[name]
    steps_left = 100  # prevent infinite loop (max 100 templates)
    while steps_left > 0 and "template" in resolved:
        steps_left -= 1
        # Copy before popping so the stored config entry is never mutated.
        resolved = copy.deepcopy(resolved)
        parent_name = resolved.pop("template")
        if parent_name not in configs.servers:
            raise ValueError(f"Server config template {parent_name} not found")
        # Template values are defaults; the current values override them.
        resolved = {**configs.servers[parent_name], **resolved}

    if "template" in resolved:
        raise ValueError(f"Template loop detected in server config {name}")
    return resolved
class ServerManager:
    """The manager for all servers.

    Servers are kept in a name -> server registry. ``get_server`` builds a
    missing server on first request and reuses it afterwards. Every method
    that reads or mutates the registry holds the same re-entrant lock, so
    registering, closing and fetching from different threads cannot
    interleave.
    """

    def __init__(self) -> None:
        """Initialize the server manager with an empty registry."""
        # Re-entrant: get_server() calls register_server() while it already
        # holds the lock, which would deadlock with a plain Lock.
        self._lock = threading.RLock()
        self._servers: Dict[str, BaseServer] = {}

    def register_server(self, name: str, server: BaseServer) -> None:
        """Register a server with a name, replacing any existing entry.

        A replaced server is not closed by this method.
        """
        with self._lock:
            self._servers[name] = server

    def close_server(self, name: str) -> None:
        """Close a server by name and remove it; unknown names are ignored.

        If ``close()`` raises, the exception propagates and the server
        stays registered.
        """
        # Hold the lock across close + delete so get_server() can never
        # return a server that is in the middle of being closed.
        with self._lock:
            if name in self._servers:
                self._servers[name].close()
                del self._servers[name]

    def get_server(self, name: Optional[str]) -> BaseServer:
        """Get a server by name, creating it on first use.

        Args:
            name: The server name. If None, the name stored at
                ``servers.default`` in the configs is used.

        Returns:
            The server registered under the resolved name.
        """
        if name is None:
            name = configs.getattrs("servers.default")

        with self._lock:
            if name not in self._servers:
                server_configs = _get_server_configs(name)
                server = _init_server(**server_configs)
                self.register_server(name, server)
            return self._servers[name]

    @property
    def default_server(self) -> BaseServer:
        """The server registered under the literal name ``"default"``.

        Raises:
            ValueError: If no server is registered under ``"default"``.
        """
        # NOTE(review): get_server(None) registers the server under the name
        # read from ``servers.default``, not under "default" -- confirm
        # whether this property should delegate to get_server(None).
        with self._lock:
            if "default" not in self._servers:
                raise ValueError("Default server not found")
            return self._servers["default"]
# Module-level singleton: the one ServerManager instance created at import time.
server_manager: ServerManager = ServerManager()