fix
This commit is contained in:
307
backend/app/api/coaching.py
Normal file
307
backend/app/api/coaching.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Coaching API endpoints — onboarding, chat, plan, compliance."""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.app.core.auth import get_current_rider
|
||||
from backend.app.core.database import get_session
|
||||
from backend.app.models.coaching import CoachingChat
|
||||
from backend.app.models.rider import Rider
|
||||
from backend.app.models.training import TrainingPlan
|
||||
from backend.app.services.coaching import (
|
||||
process_chat_message,
|
||||
generate_plan,
|
||||
get_today_workout,
|
||||
calculate_compliance,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
chat_id: str
|
||||
chat_type: str
|
||||
status: str
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class PlanResponse(BaseModel):
|
||||
id: str
|
||||
goal: str
|
||||
start_date: str
|
||||
end_date: str
|
||||
phase: str | None
|
||||
description: str | None
|
||||
status: str
|
||||
weeks: list[dict]
|
||||
|
||||
|
||||
# --- Onboarding ---
|
||||
|
||||
@router.post("/onboarding/start")
|
||||
async def start_onboarding(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Start or resume onboarding chat."""
|
||||
# Check for existing active onboarding
|
||||
query = (
|
||||
select(CoachingChat)
|
||||
.where(CoachingChat.rider_id == rider.id)
|
||||
.where(CoachingChat.chat_type == "onboarding")
|
||||
.where(CoachingChat.status == "active")
|
||||
.order_by(CoachingChat.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
chat = result.scalar_one_or_none()
|
||||
|
||||
if chat:
|
||||
return {
|
||||
"chat_id": str(chat.id),
|
||||
"status": chat.status,
|
||||
"messages": chat.messages_json or [],
|
||||
"onboarding_completed": rider.onboarding_completed,
|
||||
}
|
||||
|
||||
# Create new onboarding chat
|
||||
chat = CoachingChat(
|
||||
rider_id=rider.id,
|
||||
chat_type="onboarding",
|
||||
status="active",
|
||||
messages_json=[],
|
||||
)
|
||||
session.add(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
|
||||
# Send initial greeting by processing an empty-ish message
|
||||
response = await process_chat_message(rider, chat.id, "Привет! Я хочу начать тренировки.", session)
|
||||
|
||||
return {
|
||||
"chat_id": str(chat.id),
|
||||
"status": chat.status,
|
||||
"messages": chat.messages_json or [],
|
||||
"onboarding_completed": rider.onboarding_completed,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/onboarding/status")
|
||||
async def get_onboarding_status(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Check onboarding status."""
|
||||
return {
|
||||
"onboarding_completed": rider.onboarding_completed,
|
||||
"coaching_profile": rider.coaching_profile,
|
||||
}
|
||||
|
||||
|
||||
# --- Chat ---
|
||||
|
||||
@router.post("/chat/new")
|
||||
async def create_chat(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Create a new general coaching chat."""
|
||||
chat = CoachingChat(
|
||||
rider_id=rider.id,
|
||||
chat_type="general",
|
||||
status="active",
|
||||
messages_json=[],
|
||||
)
|
||||
session.add(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return {"chat_id": str(chat.id), "status": "active", "messages": []}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def list_chats(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all coaching chats."""
|
||||
query = (
|
||||
select(CoachingChat)
|
||||
.where(CoachingChat.rider_id == rider.id)
|
||||
.order_by(CoachingChat.updated_at.desc())
|
||||
.limit(20)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
chats = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": str(c.id),
|
||||
"chat_type": c.chat_type,
|
||||
"status": c.status,
|
||||
"message_count": len(c.messages_json or []),
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
"updated_at": c.updated_at.isoformat() if c.updated_at else None,
|
||||
"last_message": (c.messages_json[-1]["text"][:100] if c.messages_json else None),
|
||||
}
|
||||
for c in chats
|
||||
]
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}")
|
||||
async def get_chat(
|
||||
chat_id: uuid.UUID,
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get chat with messages."""
|
||||
chat = await session.get(CoachingChat, chat_id)
|
||||
if not chat or chat.rider_id != rider.id:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
return {
|
||||
"chat_id": str(chat.id),
|
||||
"chat_type": chat.chat_type,
|
||||
"status": chat.status,
|
||||
"messages": chat.messages_json or [],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/chat/{chat_id}/message")
|
||||
async def send_message(
|
||||
chat_id: uuid.UUID,
|
||||
body: MessageRequest,
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Send a message to the coaching chat."""
|
||||
response = await process_chat_message(rider, chat_id, body.message, session)
|
||||
|
||||
chat = await session.get(CoachingChat, chat_id)
|
||||
return {
|
||||
"response": response,
|
||||
"chat_id": str(chat_id),
|
||||
"status": chat.status if chat else "active",
|
||||
"messages": chat.messages_json if chat else [],
|
||||
"onboarding_completed": rider.onboarding_completed,
|
||||
}
|
||||
|
||||
|
||||
# --- Training Plan ---
|
||||
|
||||
@router.post("/plan/generate")
|
||||
async def create_plan(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Generate a new AI training plan."""
|
||||
plan = await generate_plan(rider, session)
|
||||
return _plan_to_dict(plan)
|
||||
|
||||
|
||||
@router.get("/plan/active")
|
||||
async def get_active_plan(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get the active training plan."""
|
||||
query = (
|
||||
select(TrainingPlan)
|
||||
.where(TrainingPlan.rider_id == rider.id)
|
||||
.where(TrainingPlan.status == "active")
|
||||
.order_by(TrainingPlan.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
plan = result.scalar_one_or_none()
|
||||
if not plan:
|
||||
return None
|
||||
return _plan_to_dict(plan)
|
||||
|
||||
|
||||
@router.get("/plan/{plan_id}/compliance")
|
||||
async def get_plan_compliance(
|
||||
plan_id: uuid.UUID,
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get compliance data for a plan."""
|
||||
plan = await session.get(TrainingPlan, plan_id)
|
||||
if not plan or plan.rider_id != rider.id:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
compliance = await calculate_compliance(plan, session)
|
||||
return compliance
|
||||
|
||||
|
||||
@router.get("/today")
|
||||
async def get_today(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get today's planned workout."""
|
||||
workout = await get_today_workout(rider, session)
|
||||
return workout
|
||||
|
||||
|
||||
# --- Adjustment chat ---
|
||||
|
||||
@router.post("/plan/adjust")
|
||||
async def start_plan_adjustment(
|
||||
rider: Rider = Depends(get_current_rider),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Start an adjustment chat for the active plan."""
|
||||
# Check active plan exists
|
||||
plan_query = (
|
||||
select(TrainingPlan)
|
||||
.where(TrainingPlan.rider_id == rider.id)
|
||||
.where(TrainingPlan.status == "active")
|
||||
.limit(1)
|
||||
)
|
||||
plan_result = await session.execute(plan_query)
|
||||
plan = plan_result.scalar_one_or_none()
|
||||
if not plan:
|
||||
raise HTTPException(status_code=400, detail="No active plan to adjust")
|
||||
|
||||
chat = CoachingChat(
|
||||
rider_id=rider.id,
|
||||
chat_type="adjustment",
|
||||
status="active",
|
||||
messages_json=[],
|
||||
)
|
||||
session.add(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
|
||||
# Start with context about what needs adjustment
|
||||
response = await process_chat_message(
|
||||
rider, chat.id,
|
||||
"Мне нужно скорректировать мой текущий план тренировок. Посмотри на мои последние данные и предложи изменения.",
|
||||
session,
|
||||
)
|
||||
|
||||
return {
|
||||
"chat_id": str(chat.id),
|
||||
"status": chat.status,
|
||||
"messages": chat.messages_json or [],
|
||||
"response": response,
|
||||
}
|
||||
|
||||
|
||||
def _plan_to_dict(plan: TrainingPlan) -> dict:
|
||||
weeks = plan.weeks_json.get("weeks", []) if plan.weeks_json else []
|
||||
return {
|
||||
"id": str(plan.id),
|
||||
"goal": plan.goal,
|
||||
"start_date": plan.start_date.isoformat(),
|
||||
"end_date": plan.end_date.isoformat(),
|
||||
"phase": plan.phase,
|
||||
"description": plan.description,
|
||||
"status": plan.status,
|
||||
"weeks": weeks,
|
||||
}
|
||||
Reference in New Issue
Block a user