Coverage for src/appl/core/message.py: 89%

205 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1from __future__ import annotations 

2 

3from abc import ABC 

4from dataclasses import dataclass 

5from typing import Any, Dict, Iterator, List, Optional, TypeVar 

6 

7from loguru import logger 

8from pydantic import BaseModel, Field, model_validator 

9from termcolor import COLORS, colored 

10 

11from .config import configs 

12from .tool import ToolCall 

13from .types import ( 

14 ASSISTANT_ROLE, 

15 SYSTEM_ROLE, 

16 TOOL_ROLE, 

17 USER_ROLE, 

18 ContentList, 

19 FutureValue, 

20 Image, 

21 MessageRole, 

22 MessageRoleType, 

23 StrOrImg, 

24) 

25 

26 

def get_role_color(role: MessageRole) -> Optional[str]:
    """Look up the display color configured for a message role.

    The ``settings.messages.colors`` entry is read from the global configs
    (an empty dict is used when it is missing) and indexed by the role's
    type. Returns ``None`` when no color is registered for that type.
    """
    return configs.getattrs("settings.messages.colors", {}).get(role.type, None)

31 

32 

def get_colored_role_text(role: Optional[MessageRole], content: str) -> str:
    """Wrap ``content`` in the color assigned to ``role``, if any.

    The text is returned unchanged when no role is given or when the color
    configured for the role is not one of termcolor's known ``COLORS``.
    """
    if not role:
        return content
    role_color = get_role_color(role)
    if role_color not in COLORS:
        # Covers both "no color configured" (None) and unknown color names.
        return content
    return colored(content, role_color)  # type: ignore

40 

41 

class BaseMessage(BaseModel, ABC):
    """Shared foundation for every message type.

    Stores the message ``content``, an optional ``role`` and a free-form
    ``info`` dict, and builds role checks, merging, and dict/string
    rendering on top of them.
    """

    content: Any = Field(..., description="The content of the message")
    role: Optional[MessageRole] = Field(
        None, description="The role of the messages owner"
    )
    info: Optional[Dict] = Field(
        {}, description="Additional information for the message"
    )

    def __init__(self, content: Any = None, *args: Any, **kwargs: Any) -> None:
        """Build a message, allowing ``content`` to be given positionally.

        All remaining arguments are forwarded unchanged to the parent
        constructor.
        """
        super().__init__(*args, content=content, **kwargs)

    @property
    def is_system(self) -> bool:
        """True when a role is set and that role is a system role."""
        role = self.role
        return role is not None and role.is_system

    @property
    def is_user(self) -> bool:
        """True when a role is set and that role is a user role."""
        role = self.role
        return role is not None and role.is_user

    @property
    def is_ai(self) -> bool:
        """True when a role is set and that role is an assistant role."""
        role = self.role
        return role is not None and role.is_assistant

    @property
    def is_tool(self) -> bool:
        """True when a role is set and that role is a tool role."""
        role = self.role
        return role is not None and role.is_tool

    def validate_role(self, target_role: MessageRole) -> None:
        """Ensure the message's role matches ``target_role``.

        A missing role is replaced by ``target_role``; a role without a type
        keeps its name and receives the target type.

        Raises:
            ValueError: If ``target_role`` has no type, or the message
                already has a role of a different type.
        """
        expected_type = target_role.type
        if expected_type is None:
            raise ValueError("Target role type must be provided.")

        current = self.role
        if current is None:
            self.role = target_role
            return
        if current.type is None:
            # Keep the custom name, only fill in the missing type.
            self.role = MessageRole(type=expected_type, name=current.name)
            return
        if current.type != expected_type:
            raise ValueError(f"Invalid role for {expected_type} message: {current}")

    def should_merge(self, other: "BaseMessage") -> bool:
        """Decide whether ``other`` can be folded into this message.

        Tool messages and messages without content are never merged;
        otherwise two messages merge exactly when their roles are equal.
        """
        if self.is_tool or other.is_tool:
            return False
        both_have_content = self.content is not None and other.content is not None
        return both_have_content and self.role == other.role

    def get_content(self, as_str: bool = False) -> Any:
        """Return the message content in a materialized form.

        A ``ContentList`` is returned as the result of its
        ``get_contents()`` (``as_str`` is ignored for it). A ``FutureValue``
        is resolved through its ``val`` attribute. With ``as_str`` set, any
        other non-``None`` content is converted with ``str()``.
        """
        value = self.content
        if value is None:
            return None
        if isinstance(value, ContentList):
            return value.get_contents()
        if isinstance(value, FutureValue):
            value = value.val
        return str(value) if as_str else value

    # TODO: implement classmethod: from dict
    def get_dict(self, default_role: Optional[MessageRole] = None) -> Dict[str, Any]:
        """Render the message as a dict of content plus role fields.

        ``default_role`` is used when the message has no role, and supplies
        the type when the message's role has a name but no type.

        Raises:
            ValueError: If neither a role nor a default role is available,
                or no role type can be determined.
        """
        role = self.role or default_role
        if role is None:
            raise ValueError("Role or default role must be provided.")
        if role.type is None:
            if not (default_role and default_role.type):
                raise ValueError("Role type must be provided.")
            role = MessageRole(type=default_role.type, name=role.name)
        # Content is stringified here (ContentList excepted, see get_content).
        return {"content": self.get_content(as_str=True), **role.get_dict()}

    def merge(self: "Message", other: "BaseMessage") -> Optional["Message"]:
        """Return a copy of this message with ``other``'s content appended.

        Returns ``None`` when ``should_merge`` rejects the pair.
        """
        if not self.should_merge(other):
            return None

        merged = self.model_copy()
        # NOTE(review): model_copy() is called without deep=True, so the copy
        # presumably shares the content object with ``self``; whether the
        # ``+=`` below can affect ``self.content`` depends on the content
        # type's in-place add -- confirm.
        other_is_list = isinstance(other.content, ContentList)
        if other_is_list and not isinstance(merged.content, ContentList):
            # Promote plain content so both sides are ContentList before adding.
            merged.content = ContentList(contents=[merged.content])
        merged.content += other.content
        return merged

    def str_with_default_role(self, default_role: Optional[MessageRole] = None) -> str:
        """Render the message as text, using ``default_role`` if it has no role."""
        effective_role = self.role or default_role
        return self._get_colored_content(effective_role)

    def _get_serialized_content(self, role: Optional[MessageRole] = None) -> str:
        # Plain content without a role, otherwise "<role>: <content>".
        return f"{self.content}" if role is None else f"{role}: {self.content}"

    def _get_colored_content(self, role: Optional[MessageRole] = None) -> str:
        # Serialize first, then apply the role's color (if one is configured).
        serialized = self._get_serialized_content(role)
        return get_colored_role_text(role, serialized)

    def __str__(self) -> str:
        return self._get_colored_content(self.role)

    def __repr__(self) -> str:
        return f"Message(role={self.role!r}, content={self.content!r})"

163 

164 

# Type variable standing for "some BaseMessage subclass"; used in signatures
# so that a function or method can return the same message type it was given.
Message = TypeVar("Message", bound=BaseMessage)

166 

167 

class ChatMessage(BaseMessage):
    """A generic conversation message whose role is not checked."""

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Build a chat message from content, an optional role and extras.

        No role validation is performed; the role is stored as given.
        """
        super().__init__(content, role=role, **kwargs)

180 

181 

class SystemMessage(BaseMessage):
    """A conversation message that always carries the system role."""

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Build a system message from content, an optional role and extras.

        After construction the role is validated against ``SYSTEM_ROLE``:
        a missing role or role type is filled in, and a role of a different
        type raises ``ValueError``.
        """
        super().__init__(content, role=role, **kwargs)
        self.validate_role(SYSTEM_ROLE)

195 

196 

class UserMessage(BaseMessage):
    """A conversation message that always carries the user role."""

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        **kwargs: Any,
    ) -> None:
        """Build a user message from content, an optional role and extras.

        After construction the role is validated against ``USER_ROLE``:
        a missing role or role type is filled in, and a role of a different
        type raises ``ValueError``.
        """
        super().__init__(content, role=role, **kwargs)
        self.validate_role(USER_ROLE)

210 

211 

class AIMessage(BaseMessage):
    """A conversation message with the assistant role and optional tool calls."""

    tool_calls: List[ToolCall] = Field(
        [], description="The tool calls generated by the model."
    )

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        tool_calls: Optional[List[ToolCall]] = None,
        **kwargs: Any,
    ) -> None:
        """Build an assistant message from content, role, tool calls and extras.

        ``tool_calls=None`` is treated as an empty list. After construction
        the role is validated against ``ASSISTANT_ROLE``.
        """
        calls = [] if tool_calls is None else tool_calls
        super().__init__(content, role=role, tool_calls=calls, **kwargs)
        self.validate_role(ASSISTANT_ROLE)

    def get_dict(self, default_role: Optional[MessageRole] = None) -> Dict[str, Any]:
        """Render the message as a dict, adding ``tool_calls`` when present.

        The ``tool_calls`` key is only emitted when at least one call exists.
        """
        payload = super().get_dict(default_role)
        if self.tool_calls:
            payload["tool_calls"] = [tc.get_dict() for tc in self.tool_calls]
        return payload

    def _get_serialized_content(self, role: Optional[MessageRole] = None) -> str:
        # Renders "<role>:" followed by the content and the tool calls, each
        # preceded by a single space and each omitted when absent/empty.
        assert role == self.role, "Role must be the same as the message role."
        pieces = [f"{self.role}:"]
        if self.content is not None:
            pieces.append(f"{self.content}")
        if self.tool_calls:
            pieces.append(f"{self.tool_calls}")
        return " ".join(pieces)

    def __repr__(self) -> str:
        return (
            f"AIMessage(role={self.role!r}, content={self.content!r}, "
            f"tool_calls={self.tool_calls!r})"
        )

254 

255 

class ToolMessage(BaseMessage):
    """A conversation message carrying the result of a tool call."""

    tool_call_id: str = Field(
        ..., description="Tool call that this message is responding to."
    )
    has_error: bool = Field(
        False, description="Whether the message is an error message."
    )

    def __init__(
        self,
        content: Any = None,
        *,
        role: Optional[MessageRole] = None,
        tool_call_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Build a tool message from content, role, tool call id and extras.

        After construction the role is validated against ``TOOL_ROLE``.
        """
        super().__init__(content, role=role, tool_call_id=tool_call_id, **kwargs)
        self.validate_role(TOOL_ROLE)

    def get_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Render the message as a dict that also includes ``tool_call_id``.

        All arguments are forwarded to the base implementation.
        """
        base = super().get_dict(*args, **kwargs)
        return {**base, "tool_call_id": self.tool_call_id}

285 

286 

# Maps a role type to the message class used to represent it; the ``None``
# key selects the generic ChatMessage for messages without a role type.
MESSAGE_CLASS_DICT = {
    None: ChatMessage,
    MessageRoleType.SYSTEM: SystemMessage,
    MessageRoleType.USER: UserMessage,
    MessageRoleType.ASSISTANT: AIMessage,
    MessageRoleType.TOOL: ToolMessage,
}

294 

295 

def as_message(
    role: Optional[MessageRole],
    content: StrOrImg,
    *args: Any,
    **kwargs: Any,
) -> BaseMessage:
    """Create a message with role, content and extra arguments.

    The message class is chosen from ``MESSAGE_CLASS_DICT`` by the role's
    type. A missing role, or a role without a type, selects the entry
    registered under ``None``. An ``Image`` passed as content is wrapped in a
    single-element ``ContentList``.

    Args:
        role: The role of the message, or ``None``.
        content: The message content.
        *args: Extra positional arguments forwarded to the message class.
        **kwargs: Extra keyword arguments forwarded to the message class.

    Returns:
        An instance of the message class selected for the role.

    Raises:
        ValueError: If the role's type is not a valid ``MessageRoleType`` or
            has no registered message class.
    """
    role_type: Optional[MessageRoleType] = None
    if role is not None and role.type is not None:
        try:
            role_type = MessageRoleType(role.type)
        except ValueError as err:
            # Without this, the enum's own error would surface and the
            # "Unknown role" message below could never be reached.
            raise ValueError(f"Unknown role: {role}") from err

    cls = MESSAGE_CLASS_DICT.get(role_type)
    if cls is None:
        raise ValueError(f"Unknown role: {role}")

    if isinstance(content, Image):
        content = ContentList(contents=[content])  # type: ignore
    return cls(*args, content=content, role=role, **kwargs)

310 

311 

def collapse_messages(messages: List[Message]) -> List[Message]:
    """Merge runs of adjacent messages that agree to be merged.

    Walks the list once, keeping a "pending" message. Each following message
    is offered to ``pending.merge``; a non-``None`` result replaces the
    pending message, while ``None`` flushes the pending message to the output
    and starts a new run with the current one.
    """
    collapsed = []
    pending: Optional[Message] = None
    for current in messages:
        if pending is None:
            pending = current
            continue
        combined = pending.merge(current)
        if combined is None:
            # Could not merge: the pending run is complete.
            collapsed.append(pending)
            pending = current
        else:
            pending = combined
    if pending is not None:
        collapsed.append(pending)
    return collapsed

331 

332 

class Conversation(BaseModel):
    """A conversation containing messages.

    System messages are stored separately from the other messages.
    Iteration, indexing and ``len()`` only cover the non-system messages.
    """

    system_messages: List[SystemMessage] = Field([], description="The system messages")
    messages: List[BaseMessage] = Field(
        [], description="The messages in the conversation"
    )

    @property
    def has_message_role(self) -> bool:
        """Whether any message (system or not) has a role set."""
        return any(m.role is not None for m in self.system_messages + self.messages)

    def collapse(self) -> "Conversation":
        """Collapse the messages in the conversation in place.

        Returns:
            This conversation, to allow chaining.

        Raises:
            ValueError: If more than one system message remains after
                collapsing.
        """
        self.system_messages = collapse_messages(self.system_messages)
        if len(self.system_messages) > 1:
            raise ValueError("System messages cannot be fully collapsed.")
        self.messages = collapse_messages(self.messages)
        return self

    def materialize(self) -> None:
        """Materialize the messages in the conversation.

        Done by rendering the conversation to a string and discarding the
        result; note that rendering also collapses the conversation.
        """
        str(self)

    def set_system_messages(self, messages: List[SystemMessage]) -> None:
        """Replace the system messages, warning if any were already set."""
        if self.system_messages:
            logger.warning("Overwriting system message.")
        self.system_messages = messages

    def append(self, message: Message) -> None:
        """Append a message, routing system messages to ``system_messages``."""
        if not message.is_system:
            self.messages.append(message)
            return
        if self.messages:
            # Appending a system message after other messages is allowed,
            # but worth a warning rather than an error.
            logger.warning(
                "Modifying system message after other types of messages."
            )
        self.system_messages.append(message)  # type: ignore

    def extend(self, other: "Conversation") -> None:
        """Append every message of ``other``, system messages first."""
        # Snapshot before appending: when ``other`` is this conversation the
        # lists would otherwise grow while being iterated and never finish.
        incoming = list(other.system_messages) + list(other.messages)
        for message in incoming:
            self.append(message)

    def pop(self) -> BaseMessage:
        """Remove and return the last non-system message.

        Raises:
            IndexError: If there are no non-system messages.
        """
        return self.messages.pop()

    # TODO: implement classmethod: from list of dict
    def as_list(
        self, default_role: Optional[MessageRole] = USER_ROLE
    ) -> List[Dict[str, Any]]:
        """Return the conversation as a list of message dicts.

        The conversation is collapsed first. ``default_role`` is applied to
        non-system messages only.
        """
        self.collapse()
        rendered = [m.get_dict() for m in self.system_messages]
        rendered += [m.get_dict(default_role) for m in self.messages]
        return rendered

    def make_copy(self) -> "Conversation":
        """Return a new conversation built from copies of both message lists.

        Only the lists are copied; the message objects are passed on as-is.
        """
        return Conversation(
            system_messages=self.system_messages.copy(),
            messages=self.messages.copy(),
        )

    def __repr__(self) -> str:
        return f"Conversation({self.system_messages!r}, {self.messages!r})"

    def __str__(self) -> str:
        # Rendering collapses the conversation as a side effect.
        self.collapse()
        # Only fall back to the user role when roles are in use at all.
        role = USER_ROLE if self.has_message_role else None
        lines = [m.str_with_default_role() for m in self.system_messages]
        lines += [m.str_with_default_role(role) for m in self.messages]
        return "\n".join(lines)

    def __iter__(self) -> Iterator[BaseMessage]:  # type: ignore
        """Iterate over messages, excluding the system message."""
        return iter(self.messages)

    def __getitem__(self, index: int) -> BaseMessage:
        """Get message by index, excluding the system message."""
        return self.messages[index]

    def __len__(self) -> int:
        """Length of the conversation, excluding the system message."""
        return len(self.messages)