103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
|
|
from datetime import datetime
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
from sqlalchemy import Column, DateTime, func, JSON, desc
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||
|
|
from sqlalchemy.orm import selectinload
|
||
|
|
from sqlmodel import Field, Relationship, SQLModel, select
|
||
|
|
|
||
|
|
from knowledge_platform.database.database import get_session
|
||
|
|
|
||
|
|
|
||
|
|
class SystemPromptsDao(AsyncAttrs, SQLModel, table=True):
    """A titled system prompt, mapped to the ``system_prompt`` table."""

    __tablename__ = "system_prompt"

    # Primary key; defaults to None on a freshly constructed instance.
    id: Optional[int] = Field(primary_key=True, default=None)
    title: str
    prompt: str
    # The column carries a server-side default of ``now()``.
    created_at: Optional[datetime] = Field(
        sa_column=Column(DateTime(), server_default=func.now()),
    )
|
||
|
|
|
||
|
|
|
||
|
|
class MessageDao(AsyncAttrs, SQLModel, table=True):
    """A single chat message, mapped to the ``message`` table.

    A message references its chat through ``chat_id`` / ``chat`` and may
    reference another message through ``parent_id`` / ``parent``; the reverse
    side of that self-reference is ``replies``.
    """

    __tablename__ = "message"

    id: int | None = Field(default=None, primary_key=True)
    chat_id: Optional[int] = Field(foreign_key="chat.id")
    chat: Optional["ChatDao"] = Relationship(back_populates="messages")
    role: str
    content: str
    # The column carries a server-side default of ``now()``.
    timestamp: datetime | None = Field(
        sa_column=Column(DateTime(), server_default=func.now())
    )
    # Free-form JSON payload. ``default_factory`` builds a new dict for every
    # instance instead of declaring a mutable ``{}`` literal as the default.
    meta: dict[Any, Any] = Field(default_factory=dict, sa_column=Column(JSON))
    parent_id: Optional[int] = Field(
        foreign_key="message.id", default=None, nullable=True
    )
    # ``remote_side`` marks ``id`` as the "one" side of the self-referential
    # relationship, so ``parent`` resolves to a single message.
    parent: Optional["MessageDao"] = Relationship(
        back_populates="replies",
        sa_relationship_kwargs={"remote_side": "MessageDao.id"},
    )
    """The message this message is responding to."""
    replies: list["MessageDao"] = Relationship(back_populates="parent")
    """The replies to this message
    (could be multiple replies e.g. from different models).
    """
    model: str | None
    """The model that wrote this response. (Could switch models mid-chat, possibly)"""
|
||
|
|
|
||
|
|
|
||
|
|
class ChatDao(AsyncAttrs, SQLModel, table=True):
    """A chat conversation, mapped to the ``chat`` table.

    The static helpers below each open their own session through
    ``get_session()`` and therefore return instances that are no longer
    bound to an open session once the helper returns.
    """

    __tablename__ = "chat"

    # Annotated as plain ``int`` although the declared default is None.
    id: int = Field(default=None, primary_key=True)
    model: str
    title: str | None
    # The column carries a server-side default of ``now()``.
    started_at: datetime | None = Field(
        sa_column=Column(DateTime(), server_default=func.now())
    )
    messages: list[MessageDao] = Relationship(back_populates="chat")
    archived: bool = Field(default=False)

    @staticmethod
    async def all() -> list["ChatDao"]:
        """Return the non-archived chats, most recently active first.

        Activity is the latest ``MessageDao.timestamp`` within each chat.
        Because the per-chat maximum is attached with an inner join, chats
        that have no messages at all are not returned. ``messages`` is
        eagerly loaded on every returned chat.
        """
        async with get_session() as session:
            # Create a subquery that finds the maximum
            # (most recent) timestamp for each chat.
            max_timestamp: Any = func.max(MessageDao.timestamp).label("max_timestamp")
            subquery = (
                select(MessageDao.chat_id, max_timestamp)
                .group_by(MessageDao.chat_id)
                .alias("subquery")
            )

            statement = (
                select(ChatDao)
                .join(subquery, subquery.c.chat_id == ChatDao.id)
                .where(ChatDao.archived == False)  # noqa: E712
                .order_by(desc(subquery.c.max_timestamp))
                .options(selectinload(ChatDao.messages))
            )
            results = await session.exec(statement)
            return list(results)

    @staticmethod
    async def from_id(chat_id: int) -> "ChatDao":
        """Load the chat with primary key ``chat_id``, messages included.

        ``chat_id`` is passed through ``int()`` before being compared, so
        numeric strings are accepted. The lookup uses ``.one()``, which
        raises unless exactly one row matches.
        """
        async with get_session() as session:
            statement = (
                select(ChatDao)
                .where(ChatDao.id == int(chat_id))
                .options(selectinload(ChatDao.messages))
            )
            result = await session.exec(statement)
            return result.one()

    @staticmethod
    async def rename_chat(chat_id: int, new_title: str) -> None:
        """Set the title of the chat identified by ``chat_id`` and commit.

        The row is loaded and updated inside a single session, so no second
        session is opened, the chat's messages are not loaded, and no
        instance has to be re-attached. As with ``from_id``, the lookup uses
        ``.one()`` and raises unless exactly one row matches.
        """
        async with get_session() as session:
            statement = select(ChatDao).where(ChatDao.id == int(chat_id))
            result = await session.exec(statement)
            chat = result.one()
            # ``chat`` was loaded by this session, so the attribute change is
            # tracked and flushed by the commit without an explicit ``add``.
            chat.title = new_title
            await session.commit()
|