308 lines
8.6 KiB
Python
308 lines
8.6 KiB
Python
"""Coaching API endpoints — onboarding, chat, plan, compliance."""
|
|
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from backend.app.core.auth import get_current_rider
|
|
from backend.app.core.database import get_session
|
|
from backend.app.models.coaching import CoachingChat
|
|
from backend.app.models.rider import Rider
|
|
from backend.app.models.training import TrainingPlan
|
|
from backend.app.services.coaching import (
|
|
process_chat_message,
|
|
generate_plan,
|
|
get_today_workout,
|
|
calculate_compliance,
|
|
)
|
|
|
|
# Router that every coaching endpoint in this module registers on; presumably
# included into the FastAPI app elsewhere (mount prefix not visible here — confirm).
router = APIRouter()
|
|
|
|
|
|
class MessageRequest(BaseModel):
    """Request body for posting a rider message to a coaching chat."""

    # Raw text typed by the rider; forwarded as-is to process_chat_message.
    message: str
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Serialized view of a coaching chat.

    NOTE(review): the routes visible in this module return plain dicts and do
    not declare this as a ``response_model`` — confirm whether it is used elsewhere.
    """

    # String form of the chat's UUID id.
    chat_id: str

    # Chat kind; this module creates "onboarding", "general" and "adjustment" chats.
    chat_type: str

    # Chat lifecycle status; this module only ever sets "active" (other values
    # presumably come from the coaching service — confirm).
    status: str

    # Stored message dicts; their schema is not defined here (list_chats reads a
    # "text" key from them).
    messages: list[dict]
|
|
|
|
|
|
class PlanResponse(BaseModel):
    """Serialized view of a training plan; fields match the keys built by ``_plan_to_dict``.

    NOTE(review): not declared as a ``response_model`` by the routes visible in
    this module — confirm whether it is used elsewhere.
    """

    # String form of the plan's id.
    id: str

    # The plan's goal, passed through unchanged from the model.
    goal: str

    # Plan start date as an ISO-8601 string (produced via .isoformat()).
    start_date: str

    # Plan end date as an ISO-8601 string (produced via .isoformat()).
    end_date: str

    # Optional current phase of the plan.
    phase: str | None

    # Optional free-text description of the plan.
    description: str | None

    # Plan status, e.g. "active" (the value get_active_plan filters on).
    status: str

    # Week entries taken from the plan's weeks_json["weeks"]; per-week schema is
    # not defined in this module.
    weeks: list[dict]
|
|
|
|
|
|
# --- Onboarding ---
|
|
|
|
@router.post("/onboarding/start")
async def start_onboarding(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Start or resume the rider's onboarding chat.

    If the rider already has an ``active`` onboarding chat, the most recently
    created one is returned as-is. Otherwise a new onboarding chat is created
    and seeded with a fixed opening message sent through
    ``process_chat_message``, so the coach's first reply can be returned
    immediately.

    Returns:
        dict with ``chat_id``, ``status``, ``messages`` (``[]`` when the stored
        list is empty or None) and the rider's ``onboarding_completed`` flag.
    """
    # Resume the newest active onboarding chat, if one exists.
    query = (
        select(CoachingChat)
        .where(CoachingChat.rider_id == rider.id)
        .where(CoachingChat.chat_type == "onboarding")
        .where(CoachingChat.status == "active")
        .order_by(CoachingChat.created_at.desc())
        .limit(1)
    )
    result = await session.execute(query)
    chat = result.scalar_one_or_none()

    if chat is None:
        chat = CoachingChat(
            rider_id=rider.id,
            chat_type="onboarding",
            status="active",
            messages_json=[],
        )
        session.add(chat)
        await session.commit()
        await session.refresh(chat)
        # Capture the id now: if the service commits, every attribute of the
        # instance (including the PK) may be expired afterwards.
        chat_id = chat.id

        # Seed the conversation ("Hi! I want to start training.") so the coach
        # replies first. The return value is not needed; the exchange is
        # presumably stored on the chat by the service.
        await process_chat_message(rider, chat_id, "Привет! Я хочу начать тренировки.", session)

        # Re-read the chat the same way send_message does. session.get() returns
        # the identity-mapped instance, and if it was expired by a commit it is
        # reloaded with an awaited SELECT instead of an implicit lazy load.
        chat = await session.get(CoachingChat, chat_id)

    return {
        "chat_id": str(chat.id),
        "status": chat.status,
        "messages": chat.messages_json or [],
        "onboarding_completed": rider.onboarding_completed,
    }
|
|
|
|
|
|
@router.get("/onboarding/status")
async def get_onboarding_status(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Report whether the rider finished onboarding, plus their coaching profile.

    ``session`` is not read here; it stays in the signature so the endpoint's
    dependency set is unchanged.
    """
    status_payload = dict(
        onboarding_completed=rider.onboarding_completed,
        coaching_profile=rider.coaching_profile,
    )
    return status_payload
|
|
|
|
|
|
# --- Chat ---
|
|
|
|
@router.post("/chat/new")
async def create_chat(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Open a fresh ``general`` coaching chat for the rider.

    The chat is committed and refreshed so its id is available, then returned
    with an ``active`` status and an empty message list.
    """
    new_chat = CoachingChat(
        rider_id=rider.id,
        chat_type="general",
        status="active",
        messages_json=[],
    )
    session.add(new_chat)
    await session.commit()
    await session.refresh(new_chat)

    payload = {"chat_id": str(new_chat.id)}
    payload["status"] = "active"
    payload["messages"] = []
    return payload
|
|
|
|
|
|
@router.get("/chats")
async def list_chats(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """List the rider's 20 most recently updated coaching chats.

    Each entry is a summary: id, type, status, message count, ISO-8601
    timestamps (None when unset), and a preview of at most 100 characters of
    the last message's ``text`` (None when unavailable).
    """
    query = (
        select(CoachingChat)
        .where(CoachingChat.rider_id == rider.id)
        .order_by(CoachingChat.updated_at.desc())
        .limit(20)
    )
    result = await session.execute(query)
    chats = result.scalars().all()
    return [
        {
            "id": str(c.id),
            "chat_type": c.chat_type,
            "status": c.status,
            "message_count": len(c.messages_json or []),
            "created_at": c.created_at.isoformat() if c.created_at else None,
            "updated_at": c.updated_at.isoformat() if c.updated_at else None,
            "last_message": _last_message_preview(c.messages_json),
        }
        for c in chats
    ]


def _last_message_preview(messages: list | None, limit: int = 100) -> str | None:
    """Return the first ``limit`` characters of the last message's ``text``.

    Returns None when there are no messages, or when the last entry is not a
    dict holding a string ``text``. A single malformed message therefore no
    longer raises and fails the whole chat listing.
    """
    if not messages:
        return None
    last = messages[-1]
    text = last.get("text") if isinstance(last, dict) else None
    return text[:limit] if isinstance(text, str) else None
|
|
|
|
|
|
@router.get("/chat/{chat_id}")
async def get_chat(
    chat_id: uuid.UUID,
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Return one of the rider's chats together with its stored messages.

    Raises:
        HTTPException(404): the chat does not exist or belongs to another rider.
    """
    found = await session.get(CoachingChat, chat_id)
    if found and found.rider_id == rider.id:
        return {
            "chat_id": str(found.id),
            "chat_type": found.chat_type,
            "status": found.status,
            "messages": found.messages_json or [],
        }
    # Missing and foreign chats look the same, so chat ids are not disclosed.
    raise HTTPException(status_code=404, detail="Chat not found")
|
|
|
|
|
|
@router.post("/chat/{chat_id}/message")
async def send_message(
    chat_id: uuid.UUID,
    body: MessageRequest,
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Send a rider message to one of their coaching chats.

    Raises:
        HTTPException(404): the chat does not exist or belongs to another rider
            (the same check get_chat applies).

    Returns:
        dict with the service ``response``, ``chat_id``, the chat's ``status``
        and ``messages`` after processing (``[]`` when none), and the rider's
        ``onboarding_completed`` flag.
    """
    # Verify ownership before doing any work. Otherwise a rider could post into
    # another rider's chat and receive that chat's transcript in the reply.
    chat = await session.get(CoachingChat, chat_id)
    if not chat or chat.rider_id != rider.id:
        raise HTTPException(status_code=404, detail="Chat not found")

    response = await process_chat_message(rider, chat_id, body.message, session)

    # Re-read after processing. session.get() reloads the instance if a commit
    # inside the service expired it.
    chat = await session.get(CoachingChat, chat_id)
    return {
        "response": response,
        "chat_id": str(chat_id),
        "status": chat.status if chat else "active",
        "messages": (chat.messages_json or []) if chat else [],
        "onboarding_completed": rider.onboarding_completed,
    }
|
|
|
|
|
|
# --- Training Plan ---
|
|
|
|
@router.post("/plan/generate")
async def create_plan(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Ask the coaching service for a new training plan and return it serialized."""
    return _plan_to_dict(await generate_plan(rider, session))
|
|
|
|
|
|
@router.get("/plan/active")
async def get_active_plan(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Return the rider's newest ``active`` training plan, or None when there is none."""
    newest_active = (
        select(TrainingPlan)
        .where(TrainingPlan.rider_id == rider.id, TrainingPlan.status == "active")
        .order_by(TrainingPlan.created_at.desc())
        .limit(1)
    )
    found = (await session.execute(newest_active)).scalar_one_or_none()
    return _plan_to_dict(found) if found else None
|
|
|
|
|
|
@router.get("/plan/{plan_id}/compliance")
async def get_plan_compliance(
    plan_id: uuid.UUID,
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Return compliance data for one of the rider's plans.

    Raises:
        HTTPException(404): the plan does not exist or belongs to another rider.
    """
    owned_plan = await session.get(TrainingPlan, plan_id)
    if owned_plan and owned_plan.rider_id == rider.id:
        return await calculate_compliance(owned_plan, session)
    raise HTTPException(status_code=404, detail="Plan not found")
|
|
|
|
|
|
@router.get("/today")
async def get_today(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Return whatever the coaching service reports as today's planned workout."""
    return await get_today_workout(rider, session)
|
|
|
|
|
|
# --- Adjustment chat ---
|
|
|
|
@router.post("/plan/adjust")
async def start_plan_adjustment(
    rider: Rider = Depends(get_current_rider),
    session: AsyncSession = Depends(get_session),
):
    """Open an ``adjustment`` chat for the rider's active training plan.

    A new chat is created and seeded with a fixed request sent through
    ``process_chat_message``. The request reads, in Russian: "I need to adjust
    my current training plan. Look at my latest data and suggest changes."

    Raises:
        HTTPException(400): the rider has no plan with status ``active``.

    Returns:
        dict with ``chat_id``, ``status``, ``messages`` (``[]`` when none) and
        the service's ``response`` to the seed message.
    """
    # Only the existence of an active plan matters, so select just its id.
    plan_query = (
        select(TrainingPlan.id)
        .where(TrainingPlan.rider_id == rider.id)
        .where(TrainingPlan.status == "active")
        .limit(1)
    )
    active_plan_id = (await session.execute(plan_query)).scalar_one_or_none()
    if active_plan_id is None:
        raise HTTPException(status_code=400, detail="No active plan to adjust")

    chat = CoachingChat(
        rider_id=rider.id,
        chat_type="adjustment",
        status="active",
        messages_json=[],
    )
    session.add(chat)
    await session.commit()
    await session.refresh(chat)
    # Capture the id now: if the service commits, every attribute of the
    # instance (including the PK) may be expired afterwards.
    chat_id = chat.id

    # Start with context about what needs adjustment.
    response = await process_chat_message(
        rider,
        chat_id,
        "Мне нужно скорректировать мой текущий план тренировок. Посмотри на мои последние данные и предложи изменения.",
        session,
    )

    # Re-read the chat the same way send_message does. session.get() reloads
    # the instance if a commit inside the service expired it.
    chat = await session.get(CoachingChat, chat_id)
    return {
        "chat_id": str(chat.id),
        "status": chat.status,
        "messages": chat.messages_json or [],
        "response": response,
    }
|
|
|
|
|
|
def _plan_to_dict(plan: TrainingPlan) -> dict:
    """Serialize a training plan into a JSON-friendly dict.

    Dates are emitted as ISO-8601 strings. ``weeks`` is the ``"weeks"`` list
    from ``plan.weeks_json``, or ``[]`` when weeks_json is empty/None or lacks
    that key.
    """
    raw_weeks = plan.weeks_json
    week_list = raw_weeks.get("weeks", []) if raw_weeks else []
    return dict(
        id=str(plan.id),
        goal=plan.goal,
        start_date=plan.start_date.isoformat(),
        end_date=plan.end_date.isoformat(),
        phase=plan.phase,
        description=plan.description,
        status=plan.status,
        weeks=week_list,
    )
|