Coverage for src/appl/core/message.py: 89%
205 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1from __future__ import annotations
3from abc import ABC
4from dataclasses import dataclass
5from typing import Any, Dict, Iterator, List, Optional, TypeVar
7from loguru import logger
8from pydantic import BaseModel, Field, model_validator
9from termcolor import COLORS, colored
11from .config import configs
12from .tool import ToolCall
13from .types import (
14 ASSISTANT_ROLE,
15 SYSTEM_ROLE,
16 TOOL_ROLE,
17 USER_ROLE,
18 ContentList,
19 FutureValue,
20 Image,
21 MessageRole,
22 MessageRoleType,
23 StrOrImg,
24)
def get_role_color(role: MessageRole) -> Optional[str]:
    """Look up the display color configured for ``role``'s type.

    The ``settings.messages.colors`` mapping is read from the global configs
    (falling back to ``{}``, presumably when the setting is absent) and the
    entry keyed by ``role.type`` is returned, or ``None`` if there is none.
    """
    colors_by_role = configs.getattrs("settings.messages.colors", {})
    return colors_by_role.get(role.type)
def get_colored_role_text(role: Optional[MessageRole], content: str) -> str:
    """Colorize ``content`` with the color configured for ``role``.

    ``content`` is returned unchanged when no role is given, or when the
    configured color is missing or not one of termcolor's ``COLORS``.
    """
    if not role:
        return content
    color = get_role_color(role)
    if color not in COLORS:
        return content
    return colored(content, color)  # type: ignore
class BaseMessage(BaseModel, ABC):
    """The base class for messages.

    A message holds arbitrary ``content`` (a plain value, a ``ContentList`` or a
    ``FutureValue``), an optional ``role`` and a free-form ``info`` dict.
    """

    content: Any = Field(default=..., description="The content of the message")
    role: Optional[MessageRole] = Field(
        default=None, description="The role of the messages owner"
    )
    info: Optional[Dict] = Field(
        default={}, description="Additional information for the message"
    )

    def __init__(self, content: Any = None, *args: Any, **kwargs: Any) -> None:
        """Create a message, allowing ``content`` to be passed positionally.

        All remaining arguments are forwarded unchanged to the pydantic
        constructor.
        """
        super().__init__(*args, content=content, **kwargs)

    @property
    def is_system(self) -> bool:
        """True if the message has a role and that role is a system role."""
        role = self.role
        return False if role is None else role.is_system

    @property
    def is_user(self) -> bool:
        """True if the message has a role and that role is a user role."""
        role = self.role
        return False if role is None else role.is_user

    @property
    def is_ai(self) -> bool:
        """True if the message has a role and that role is an assistant role."""
        role = self.role
        return False if role is None else role.is_assistant

    @property
    def is_tool(self) -> bool:
        """True if the message has a role and that role is a tool role."""
        role = self.role
        return False if role is None else role.is_tool

    def validate_role(self, target_role: MessageRole) -> None:
        """Check this message's role against ``target_role``, filling gaps.

        - No role: ``target_role`` is adopted as-is.
        - Role without a type: the type of ``target_role`` is filled in while
          the existing role name is kept.
        - Role with a different type: rejected.

        Raises:
            ValueError: If ``target_role`` has no type, or this message's role
                type differs from it.
        """
        expected_type = target_role.type
        if expected_type is None:
            raise ValueError("Target role type must be provided.")
        current = self.role
        if current is None:
            self.role = target_role
            return
        if current.type is None:
            self.role = MessageRole(type=expected_type, name=current.name)
            return
        if current.type != expected_type:
            raise ValueError(f"Invalid role for {expected_type} message: {current}")

    def should_merge(self, other: "BaseMessage") -> bool:
        """Decide whether ``other`` can be merged into this message.

        Tool messages and messages without content are never merged; otherwise
        the two roles must compare equal.
        """
        pair = (self, other)
        if any(msg.is_tool for msg in pair):
            return False
        if any(msg.content is None for msg in pair):
            return False
        return self.role == other.role

    def get_content(self, as_str: bool = False) -> Any:
        """Return the message content in a materialized form.

        ``None`` is returned as-is. A ``ContentList`` yields
        ``get_contents()`` (``as_str`` does not apply to it). A
        ``FutureValue`` is resolved through its ``val`` attribute. Otherwise,
        the value is converted with ``str()`` when ``as_str`` is true.
        """
        value = self.content
        if value is None:
            return None
        if isinstance(value, ContentList):
            return value.get_contents()
        if isinstance(value, FutureValue):
            value = value.val
        return str(value) if as_str else value

    # TODO: implement classmethod: from dict
    def get_dict(self, default_role: Optional[MessageRole] = None) -> Dict[str, Any]:
        """Build a dict with the stringified content plus the role's fields.

        The message's own role wins over ``default_role``. If the chosen role
        lacks a type, the type of ``default_role`` is borrowed, keeping the
        role name.

        Raises:
            ValueError: If no role is available, or no role type can be found.
        """
        role = self.role or default_role
        if role is None:
            raise ValueError("Role or default role must be provided.")
        if role.type is None:
            fallback_type = default_role.type if default_role else None
            if not fallback_type:
                raise ValueError("Role type must be provided.")
            role = MessageRole(type=fallback_type, name=role.name)
        return {"content": self.get_content(as_str=True), **role.get_dict()}

    def merge(self: "Message", other: "BaseMessage") -> Optional["Message"]:
        """Return a copy of this message with ``other``'s content appended.

        Returns ``None`` when ``should_merge`` rejects the pair. If ``other``
        holds a ``ContentList`` but this message does not, this content is
        first wrapped in a ``ContentList`` so the two can be concatenated.
        """
        if not self.should_merge(other):
            return None
        merged = self.model_copy()
        needs_wrapping = isinstance(other.content, ContentList) and not isinstance(
            merged.content, ContentList
        )
        if needs_wrapping:
            merged.content = ContentList(contents=[merged.content])
        merged.content += other.content
        return merged

    def str_with_default_role(self, default_role: Optional[MessageRole] = None) -> str:
        """Render the message, using ``default_role`` if it has no own role."""
        effective_role = self.role or default_role
        return self._get_colored_content(effective_role)

    def _get_serialized_content(self, role: Optional[MessageRole] = None) -> str:
        # "<role>: <content>" when a role is given, otherwise just the content.
        prefix = "" if role is None else f"{role}: "
        return f"{prefix}{self.content}"

    def _get_colored_content(self, role: Optional[MessageRole] = None) -> str:
        serialized = self._get_serialized_content(role)
        return get_colored_role_text(role, serialized)

    def __str__(self) -> str:
        return self._get_colored_content(self.role)

    def __repr__(self) -> str:
        return "Message(role={!r}, content={!r})".format(self.role, self.content)


# Type variable bound to ``BaseMessage``; lets helpers such as ``merge`` and
# ``collapse_messages`` keep the concrete message type in their signatures.
Message = TypeVar("Message", bound=BaseMessage)
class ChatMessage(BaseMessage):
    """A message in the chat conversation.

    Unlike the role-specific message classes, it performs no role validation:
    any role (or none) is stored as given.
    """

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Create a chat message from content, an optional role and extra fields."""
        super().__init__(content, role=role, **kwargs)
class SystemMessage(BaseMessage):
    """A system message in the conversation."""

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Create a system message from content, an optional role and extra fields.

        A missing role (or a role without a type) is completed from
        ``SYSTEM_ROLE``.

        Raises:
            ValueError: If ``role`` carries a non-system type.
        """
        super().__init__(content, role=role, **kwargs)
        self.validate_role(SYSTEM_ROLE)
class UserMessage(BaseMessage):
    """A user message in the conversation."""

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Create a user message from content, an optional role and extra fields.

        A missing role (or a role without a type) is completed from
        ``USER_ROLE``.

        Raises:
            ValueError: If ``role`` carries a non-user type.
        """
        super().__init__(content, role=role, **kwargs)
        self.validate_role(USER_ROLE)
class AIMessage(BaseMessage):
    """An assistant message in the conversation, optionally with tool calls."""

    tool_calls: List[ToolCall] = Field(
        [], description="The tool calls generated by the model."
    )

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        tool_calls: Optional[List[ToolCall]] = None,
        **kwargs: Any,
    ) -> None:
        """Create an assistant message with content and extra arguments.

        Args:
            content: The message content (may be ``None``, e.g. when the
                message only carries tool calls).
            role: Optional role; completed from / validated against
                ``ASSISTANT_ROLE``.
            tool_calls: Tool calls generated by the model; defaults to none.
            **kwargs: Extra fields forwarded to the model constructor.

        Raises:
            ValueError: If ``role`` carries a non-assistant type.
        """
        if tool_calls is None:
            tool_calls = []
        super().__init__(content=content, role=role, tool_calls=tool_calls, **kwargs)
        self.validate_role(ASSISTANT_ROLE)

    def merge(self, other: BaseMessage) -> Optional["AIMessage"]:
        """Merge ``other`` into a copy of this message, keeping all tool calls.

        The base implementation only combines content; this override also
        appends ``other``'s tool calls (if it has any) so they are not lost when
        consecutive assistant messages are collapsed.

        Returns:
            The merged copy, or ``None`` if the messages should not be merged.
        """
        merged = super().merge(other)
        if merged is None:
            return None
        # ``other`` may be a plain ChatMessage with an assistant role, which has
        # no ``tool_calls`` attribute.
        other_calls = getattr(other, "tool_calls", None)
        if other_calls:
            # model_copy() is shallow: build a new list instead of extending the
            # one still shared with ``self``.
            merged.tool_calls = [*self.tool_calls, *other_calls]
        return merged

    def get_dict(self, default_role: Optional[MessageRole] = None) -> Dict[str, Any]:
        """Return a dict representation of the message.

        Adds a ``tool_calls`` entry (one ``get_dict()`` per call) only when the
        message has tool calls. ``content`` is ``None`` when the message has no
        content.
        """
        data = super().get_dict(default_role)
        if self.tool_calls:
            data["tool_calls"] = [call.get_dict() for call in self.tool_calls]
        return data

    def _get_serialized_content(self, role: Optional[MessageRole] = None) -> str:
        # Callers pass the message's own role: validate_role() in __init__
        # guarantees self.role is set, so ``self.role or default`` is self.role.
        assert role == self.role, "Role must be the same as the message role."
        s = f"{self.role}:"
        if self.content is not None:
            s += f" {self.content}"
        if self.tool_calls:
            s += f" {self.tool_calls}"
        return s

    def __repr__(self) -> str:
        return (
            f"AIMessage(role={self.role!r}, content={self.content!r}, "
            f"tool_calls={self.tool_calls!r})"
        )
class ToolMessage(BaseMessage):
    """A tool message in the conversation, answering a specific tool call."""

    tool_call_id: str = Field(
        default=..., description="Tool call that this message is responding to."
    )
    has_error: bool = Field(
        default=False, description="Whether the message is an error message."
    )

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        tool_call_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Create a tool message responding to the tool call ``tool_call_id``.

        A missing role (or a role without a type) is completed from
        ``TOOL_ROLE``.

        Raises:
            ValueError: If ``role`` carries a non-tool type.
        """
        super().__init__(content, role=role, tool_call_id=tool_call_id, **kwargs)
        self.validate_role(TOOL_ROLE)

    def get_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return the base dict representation plus the ``tool_call_id``."""
        return {**super().get_dict(*args, **kwargs), "tool_call_id": self.tool_call_id}
# Role type -> message class used by ``as_message``. The ``None`` key covers
# messages without a role (or whose role has no type) and maps to the
# unvalidated ``ChatMessage``.
MESSAGE_CLASS_DICT = {
    None: ChatMessage,
    MessageRoleType.SYSTEM: SystemMessage,
    MessageRoleType.USER: UserMessage,
    MessageRoleType.ASSISTANT: AIMessage,
    MessageRoleType.TOOL: ToolMessage,
}


def as_message(
    role: Optional[MessageRole],
    content: StrOrImg,
    *args: Any,
    **kwargs: Any,
) -> BaseMessage:
    """Create a message whose class matches ``role``'s type.

    Args:
        role: The message role. ``None``, or a role whose ``type`` is ``None``,
            produces a ``ChatMessage``.
        content: The content; a bare ``Image`` is wrapped in a ``ContentList``.
        *args: Forwarded to the message class constructor.
        **kwargs: Forwarded to the message class constructor.

    Returns:
        An instance of the class registered in ``MESSAGE_CLASS_DICT``.

    Raises:
        ValueError: If the role type is not a valid ``MessageRoleType`` or has
            no registered message class.
    """
    if not role or role.type is None:
        # A typeless role is presumably not a valid MessageRoleType value;
        # treat it like a missing role so it maps to ChatMessage.
        role_type = None
    else:
        try:
            role_type = MessageRoleType(role.type)
        except ValueError as err:
            raise ValueError(f"Unknown role: {role}") from err
    if role_type not in MESSAGE_CLASS_DICT:
        raise ValueError(f"Unknown role: {role}")
    cls = MESSAGE_CLASS_DICT[role_type]
    if isinstance(content, Image):
        content = ContentList(contents=[content])  # type: ignore
    return cls(content=content, role=role, *args, **kwargs)
def collapse_messages(messages: List[Message]) -> List[Message]:
    """Collapse a list of messages by merging consecutive mergeable ones.

    Each message is offered to the most recently kept message via ``merge``.
    If ``merge`` returns a message, that result replaces the kept one.
    If it returns ``None``, the message starts a new entry. The input list
    itself is not modified; a new list is returned.
    """
    collapsed: List[Message] = []
    for current in messages:
        if collapsed:
            combined = collapsed[-1].merge(current)
            if combined is not None:
                collapsed[-1] = combined
                continue
        collapsed.append(current)
    return collapsed
class Conversation(BaseModel):
    """A conversation containing messages.

    System messages are stored separately from the other messages; iteration,
    indexing and ``len()`` only cover the non-system ``messages``.
    """

    system_messages: List[SystemMessage] = Field([], description="The system messages")
    messages: List[BaseMessage] = Field(
        [], description="The messages in the conversation"
    )

    @property
    def has_message_role(self) -> bool:
        """Whether at least one message (system or not) has a role set."""
        return any(m.role is not None for m in self.system_messages + self.messages)

    def collapse(self) -> "Conversation":
        """Collapse the messages in the conversation in place.

        Consecutive mergeable messages are merged via ``collapse_messages``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If the system messages do not collapse into at most one.
        """
        self.system_messages = collapse_messages(self.system_messages)
        if len(self.system_messages) > 1:
            raise ValueError("System messages cannot be fully collapsed.")
        self.messages = collapse_messages(self.messages)
        return self

    def materialize(self) -> None:
        """Materialize the messages in the conversation.

        Renders the conversation with ``str`` and discards the result. This
        collapses the conversation and formats every message's content,
        which presumably forces pending (future) contents to resolve.
        """
        str(self)

    def set_system_messages(self, messages: List[SystemMessage]) -> None:
        """Replace the system messages, warning if some were already set."""
        if self.system_messages:
            logger.warning("Overwriting system message.")
        self.system_messages = messages

    def append(self, message: Message) -> None:
        """Append a message; system messages go to ``system_messages``."""
        if message.is_system:
            if self.messages:
                # Appending a system message after other messages is allowed
                # (it used to raise), but it is logged as a warning.
                logger.warning(
                    "Modifying system message after other types of messages."
                )
            self.system_messages.append(message)  # type: ignore
        else:
            self.messages.append(message)

    def extend(self, other: "Conversation") -> None:
        """Append every system and regular message of ``other`` to this one."""
        for sys_m in other.system_messages:
            self.append(sys_m)
        for m in other.messages:
            self.append(m)

    def pop(self) -> BaseMessage:
        """Remove and return the last non-system message.

        Raises:
            IndexError: If there are no non-system messages.
        """
        return self.messages.pop()

    # TODO: implement classmethod: from list of dict
    def as_list(
        self, default_role: Optional[MessageRole] = USER_ROLE
    ) -> List[Dict[str, Any]]:
        """Return the conversation as a list of message dicts.

        The conversation is collapsed in place first. System messages are
        serialized with their own roles. Other messages fall back to
        ``default_role`` when they lack a role or a role type.

        Returns:
            One dict per message, system messages first. Values are not
            limited to ``str``: ``content`` may be ``None``, and assistant
            messages may carry a ``tool_calls`` list.

        Raises:
            ValueError: If the system messages cannot be collapsed into one,
                or a message has no usable role.
        """
        self.collapse()
        res = [m.get_dict() for m in self.system_messages]
        res += [m.get_dict(default_role) for m in self.messages]
        return res

    def make_copy(self) -> "Conversation":
        """Make a shallow copy of the conversation.

        The message lists are new, but the message objects are shared with
        this conversation.
        """
        return Conversation(
            system_messages=self.system_messages.copy(),
            messages=self.messages.copy(),
        )

    def __repr__(self) -> str:
        return f"Conversation({self.system_messages!r}, {self.messages!r})"

    def __str__(self) -> str:
        # NOTE: collapses the conversation in place before rendering.
        self.collapse()
        # Role-less messages are shown as user messages, but only if some
        # message in the conversation has a role.
        role = USER_ROLE if self.has_message_role else None
        contents = [m.str_with_default_role() for m in self.system_messages]
        contents += [m.str_with_default_role(role) for m in self.messages]
        return "\n".join(contents)

    def __iter__(self) -> Iterator[BaseMessage]:  # type: ignore
        """Iterate over messages, excluding the system message."""
        return iter(self.messages)

    def __getitem__(self, index: int) -> BaseMessage:
        """Get message by index, excluding the system message."""
        return self.messages[index]

    def __len__(self) -> int:
        """Length of the conversation, excluding the system message."""
        return len(self.messages)