462 lines
16 KiB
Python
462 lines
16 KiB
Python
"""AI Coaching service — onboarding, plan generation, chat."""
|
|
|
|
import json
|
|
import re
|
|
from datetime import date, datetime, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from backend.app.models.activity import Activity, ActivityMetrics
|
|
from backend.app.models.coaching import CoachingChat
|
|
from backend.app.models.fitness import FitnessHistory, PowerCurve
|
|
from backend.app.models.rider import Rider
|
|
from backend.app.models.training import TrainingPlan
|
|
from backend.app.services.gemini_client import chat_async
|
|
|
|
|
|
ONBOARDING_SYSTEM = """You are VeloBrain AI Coach — a professional cycling coach.
|
|
You are conducting an onboarding interview with a new athlete.
|
|
|
|
Ask questions ONE AT A TIME, in a friendly conversational tone.
|
|
Keep responses short (2-3 sentences + question).
|
|
|
|
Questions to cover (in rough order):
|
|
1. Main cycling goal (fitness, racing, gran fondo, weight loss, etc.)
|
|
2. Target event/race (if any) and its date
|
|
3. Current weekly training volume (hours/week)
|
|
4. How many days per week they can train
|
|
5. Which days are available for training
|
|
6. Indoor trainer availability (smart trainer, basic, none)
|
|
7. Power meter availability
|
|
8. Any injuries or health concerns
|
|
9. Previous coaching or structured training experience
|
|
10. What they enjoy most about cycling
|
|
|
|
When ALL questions are answered, respond with your summary and then output EXACTLY this marker on a new line:
|
|
[ONBOARDING_COMPLETE]
|
|
Followed by a JSON block with the structured data:
|
|
```json
|
|
{
|
|
"goal": "...",
|
|
"target_event": "...",
|
|
"target_event_date": "...",
|
|
"hours_per_week": N,
|
|
"days_per_week": N,
|
|
"available_days": ["monday", ...],
|
|
"has_indoor_trainer": true/false,
|
|
"trainer_type": "smart/basic/none",
|
|
"has_power_meter": true/false,
|
|
"injuries": "...",
|
|
"coaching_experience": "...",
|
|
"enjoys": "..."
|
|
}
|
|
```
|
|
|
|
Respond in Russian."""
|
|
|
|
|
|
PLAN_GENERATION_SYSTEM = """You are VeloBrain AI Coach generating a structured training plan.
|
|
Based on the rider's profile, current fitness, and goals, create a detailed multi-week training plan.
|
|
|
|
Output ONLY a valid JSON block with this structure:
|
|
```json
|
|
{
|
|
"goal": "short goal description",
|
|
"description": "plan overview in 2-3 sentences",
|
|
"phase": "base/build/peak/recovery",
|
|
"duration_weeks": N,
|
|
"weeks": [
|
|
{
|
|
"week_number": 1,
|
|
"focus": "week focus description",
|
|
"target_tss": 300,
|
|
"target_hours": 8,
|
|
"days": [
|
|
{
|
|
"day": "monday",
|
|
"workout_type": "rest",
|
|
"title": "Rest Day",
|
|
"description": "",
|
|
"duration_minutes": 0,
|
|
"target_tss": 0,
|
|
"target_if": 0
|
|
}
|
|
]
|
|
}
|
|
]
|
|
}
|
|
```
|
|
|
|
workout_type options: rest, endurance, tempo, sweetspot, threshold, vo2max, sprint, recovery, race
|
|
Plan duration: 4-8 weeks based on goal.
|
|
Include progressive overload with recovery weeks every 3-4 weeks.
|
|
Adjust intensity based on rider's FTP and experience level.
|
|
All descriptions in Russian."""
|
|
|
|
|
|
# System prompt for "adjustment" chats. process_chat_message looks for the
# [PLAN_ADJUSTED] marker and reads the "weeks" key of the JSON object in that
# reply, so the prompt spells out that exact shape (and the lowercase English
# day names / week_number values the rest of the service matches against)
# instead of the ambiguous "updated weeks JSON".
ADJUSTMENT_SYSTEM = """You are VeloBrain AI Coach reviewing a training plan.
The rider's plan needs adjustment based on their recent performance, compliance, and fatigue.

Analyze the data provided and suggest specific changes to upcoming weeks.
When you've decided on adjustments, output:
[PLAN_ADJUSTED]
Followed by a JSON block with the updated weeks in exactly this structure:
```json
{
  "weeks": [
    {
      "week_number": 1,
      "focus": "week focus description",
      "target_tss": 300,
      "target_hours": 8,
      "days": [
        {
          "day": "monday",
          "workout_type": "rest",
          "title": "Rest Day",
          "description": "",
          "duration_minutes": 0,
          "target_tss": 0,
          "target_if": 0
        }
      ]
    }
  ]
}
```
Keep each week's original week_number. Use lowercase English day names (monday ... sunday).
workout_type options: rest, endurance, tempo, sweetspot, threshold, vo2max, sprint, recovery, race

Respond in Russian."""
|
|
|
|
|
|
GENERAL_CHAT_SYSTEM = """You are VeloBrain AI Coach — a knowledgeable and supportive cycling coach.
|
|
You have access to the rider's training data and can answer questions about:
|
|
- Training methodology and periodization
|
|
- Nutrition and recovery
|
|
- Equipment and bike fit
|
|
- Race strategy and pacing
|
|
- Interpreting their power/HR data
|
|
|
|
Be concise, specific, and actionable. Use the rider's actual data when relevant.
|
|
Respond in Russian."""
|
|
|
|
|
|
def _format_pr_duration(seconds: int) -> str:
    """Return a compact, unambiguous label for a power-curve duration in seconds.

    Whole hours render as ``"Nh"``, whole minutes as ``"Nm"`` and everything
    else as ``"Ns"``, so 5400 s becomes ``"90m"`` (not a second ``"1h"``) and
    90 s stays ``"90s"`` (not ``"1m"``).
    """
    if seconds >= 3600 and seconds % 3600 == 0:
        return f"{seconds // 3600}h"
    if seconds >= 60 and seconds % 60 == 0:
        return f"{seconds // 60}m"
    return f"{seconds}s"


async def build_rider_context(rider: Rider, session: AsyncSession) -> str:
    """Build a concise text snapshot of the rider's current state for LLM prompts.

    The snapshot covers:
      * profile fields (FTP, weight, LTHR, experience, goals) and W/kg;
      * the stored onboarding ``coaching_profile``;
      * the most recent FitnessHistory row (CTL/ATL/TSB);
      * activity count, hours and TSS over the last four weeks;
      * best power per duration across all of the rider's power curves;
      * the newest active TrainingPlan, if any.

    Args:
        rider: Rider to describe.
        session: Async DB session used for the read-only queries.

    Returns:
        Newline-separated summary text.
    """
    lines = [
        f"Rider: {rider.name}",
        f"FTP: {rider.ftp or 'not set'} W",
        f"Weight: {rider.weight or 'not set'} kg",
        f"LTHR: {rider.lthr or 'not set'} bpm",
        f"Experience: {rider.experience_level or 'not set'}",
        f"Goals: {rider.goals or 'not set'}",
    ]

    if rider.ftp and rider.weight:
        lines.append(f"W/kg: {rider.ftp / rider.weight:.2f}")

    # Coaching profile collected during onboarding (key/value pairs).
    if rider.coaching_profile:
        lines.append("\nCoaching Profile:")
        for key, value in rider.coaching_profile.items():
            lines.append(f" {key}: {value}")

    # Fitness: latest CTL/ATL/TSB row.
    fh_result = await session.execute(
        select(FitnessHistory)
        .where(FitnessHistory.rider_id == rider.id)
        .order_by(FitnessHistory.date.desc())
        .limit(1)
    )
    fh = fh_result.scalar_one_or_none()
    if fh:
        lines.append(f"\nFitness: CTL={fh.ctl:.0f} ATL={fh.atl:.0f} TSB={fh.tsb:.0f}")

    # Volume over the last four weeks.
    four_weeks_ago = date.today() - timedelta(weeks=4)
    vol_result = await session.execute(
        select(Activity.date, Activity.duration, ActivityMetrics.tss)
        .outerjoin(ActivityMetrics, ActivityMetrics.activity_id == Activity.id)
        .where(Activity.rider_id == rider.id)
        .where(Activity.date >= four_weeks_ago)
        .order_by(Activity.date.desc())
    )
    rides = list(vol_result.all())
    if rides:
        # Missing duration/TSS counts as zero rather than aborting the context.
        total_hours = sum(r.duration or 0 for r in rides) / 3600
        total_tss = sum(float(r.tss or 0) for r in rides)
        lines.append(f"\nLast 4 weeks: {len(rides)} rides, {total_hours:.1f}h, TSS={total_tss:.0f}")
        lines.append(f"Avg/week: {total_hours / 4:.1f}h, TSS={total_tss / 4:.0f}")

    # Personal records: best power per duration across all stored power curves.
    pc_result = await session.execute(
        select(PowerCurve.curve_data)
        .join(Activity, Activity.id == PowerCurve.activity_id)
        .where(Activity.rider_id == rider.id)
    )
    best: dict[int, int] = {}
    for row in pc_result:
        # curve_data keys are converted with int(), i.e. presumably
        # stringified durations in seconds (JSON object keys).
        for dur_str, power in (row.curve_data or {}).items():
            if power is None:
                continue
            dur = int(dur_str)
            if dur not in best or power > best[dur]:
                best[dur] = power
    if best:
        prs = ", ".join(f"{_format_pr_duration(dur)}={best[dur]}W" for dur in sorted(best))
        lines.append(f"\nPower PRs: {prs}")

    # Active plan status (newest active plan only).
    plan_result = await session.execute(
        select(TrainingPlan)
        .where(TrainingPlan.rider_id == rider.id)
        .where(TrainingPlan.status == "active")
        .order_by(TrainingPlan.created_at.desc())
        .limit(1)
    )
    plan = plan_result.scalar_one_or_none()
    if plan:
        lines.append(f"\nActive plan: '{plan.goal}' ({plan.start_date} to {plan.end_date}), phase: {plan.phase}")

    return "\n".join(lines)
|
|
|
|
|
|
def _merge_weeks(existing: list, updated: list) -> list:
    """Overlay ``updated`` plan weeks onto ``existing`` ones, matched by ``week_number``.

    Weeks absent from ``updated`` are kept, because an adjustment typically
    rewrites only upcoming weeks. Updated weeks with a number the plan did not
    have are appended. If any updated week lacks an integer ``week_number`` it
    cannot be aligned, so ``updated`` replaces ``existing`` wholesale.
    """
    if not all(isinstance(w, dict) and isinstance(w.get("week_number"), int) for w in updated):
        return list(updated)
    replacements = {w["week_number"]: w for w in updated}
    merged = []
    for week in existing:
        number = week.get("week_number") if isinstance(week, dict) else None
        merged.append(replacements.pop(number, week) if isinstance(number, int) else week)
    # Weeks that did not exist in the plan yet (e.g. the plan was extended).
    merged.extend(replacements.values())
    return merged


def _apply_onboarding_profile(rider: Rider, response: str) -> bool:
    """Store the structured onboarding answers from a completed interview on ``rider``.

    The JSON payload is looked for after the ``[ONBOARDING_COMPLETE]`` marker
    first, so an earlier code block in the summary is not mistaken for it.
    If none is found there, the whole reply is searched.

    Returns:
        True if a JSON object was found and applied.
    """
    _, _, tail = response.partition("[ONBOARDING_COMPLETE]")
    profile = _extract_json(tail) or _extract_json(response)
    if not isinstance(profile, dict):
        return False
    rider.coaching_profile = profile
    rider.onboarding_completed = True
    if profile.get("goal"):
        rider.goals = profile["goal"]
    return True


async def _apply_plan_adjustment(rider: Rider, response: str, session: AsyncSession) -> bool:
    """Write the weeks from a ``[PLAN_ADJUSTED]`` reply into the rider's newest active plan.

    Accepts either ``{"weeks": [...]}`` or a bare list of weeks. Returned weeks
    are merged into the plan by ``week_number`` (see ``_merge_weeks``).

    Returns:
        True if an active plan was found and updated.
    """
    _, _, tail = response.partition("[PLAN_ADJUSTED]")
    payload = _extract_json(tail) or _extract_json(response)
    weeks = payload.get("weeks") if isinstance(payload, dict) else payload
    if not isinstance(weeks, list) or not weeks:
        return False

    plan_result = await session.execute(
        select(TrainingPlan)
        .where(TrainingPlan.rider_id == rider.id)
        .where(TrainingPlan.status == "active")
        .order_by(TrainingPlan.created_at.desc())
        .limit(1)
    )
    plan = plan_result.scalar_one_or_none()
    if plan is None:
        return False

    # Build a NEW dict instead of mutating plan.weeks_json in place. SQLAlchemy
    # compares the assigned value with the committed one. Re-assigning the same
    # (already mutated) object compares equal, so no UPDATE would be emitted
    # unless the column happens to be wrapped in MutableDict.
    current = dict(plan.weeks_json or {})
    existing = current.get("weeks")
    current["weeks"] = _merge_weeks(existing if isinstance(existing, list) else [], weeks)
    plan.weeks_json = current
    return True


async def process_chat_message(
    rider: Rider,
    chat_id,
    user_message: str,
    session: AsyncSession,
) -> str:
    """Send a user message to the coaching chat's AI and persist the exchange.

    The system prompt is chosen from ``chat.chat_type`` (``"onboarding"``,
    ``"general"``, ``"adjustment"``; unknown types use the general prompt) and
    extended with the rider context from ``build_rider_context``.
    Onboarding replies containing ``[ONBOARDING_COMPLETE]`` update the rider's
    coaching profile. Adjustment replies containing ``[PLAN_ADJUSTED]`` update
    the active plan. In both cases the chat is marked ``"completed"`` only once
    the structured payload was actually applied.

    Args:
        rider: Owner of the chat.
        chat_id: Primary key of the CoachingChat.
        user_message: Text typed by the user.
        session: Async DB session; committed before returning.

    Returns:
        The model's reply text.

    Raises:
        ValueError: If the chat does not exist or belongs to another rider.
    """
    from datetime import timezone

    chat = await session.get(CoachingChat, chat_id)
    if not chat or chat.rider_id != rider.id:
        raise ValueError("Chat not found")

    rider_context = await build_rider_context(rider, session)
    base_prompt = {
        "onboarding": ONBOARDING_SYSTEM,
        "general": GENERAL_CHAT_SYSTEM,
        "adjustment": ADJUSTMENT_SYSTEM,
    }.get(chat.chat_type, GENERAL_CHAT_SYSTEM)
    system = f"{base_prompt}\n\n--- Rider Data ---\n{rider_context}"

    history = list(chat.messages_json or [])
    # The model only needs role/text; stored timestamps are stripped.
    gemini_messages = [{"role": m["role"], "text": m["text"]} for m in history]
    gemini_messages.append({"role": "user", "text": user_message})
    response = await chat_async(gemini_messages, system_instruction=system, temperature=0.7)

    # Same naive-UTC ISO format as datetime.utcnow().isoformat(), without the
    # deprecated call.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    # Assign a new list so the JSON column change is detected.
    chat.messages_json = [
        *history,
        {"role": "user", "text": user_message, "timestamp": now},
        {"role": "model", "text": response, "timestamp": now},
    ]

    if chat.chat_type == "onboarding" and "[ONBOARDING_COMPLETE]" in response:
        if _apply_onboarding_profile(rider, response):
            chat.status = "completed"

    if chat.chat_type == "adjustment" and "[PLAN_ADJUSTED]" in response:
        if await _apply_plan_adjustment(rider, response, session):
            chat.status = "completed"

    await session.commit()
    return response
|
|
|
|
|
|
async def generate_plan(rider: Rider, session: AsyncSession) -> TrainingPlan:
    """Generate a new AI training plan and make it the rider's only active plan.

    The model is prompted with ``PLAN_GENERATION_SYSTEM`` plus the rider
    context. Its JSON reply is stored verbatim as ``weeks_json``. Any
    previously active plans are marked ``"cancelled"`` in the same commit.
    The plan starts today if today is Monday, otherwise on the coming Monday.

    Args:
        rider: Rider the plan is generated for.
        session: Async DB session; committed before returning.

    Returns:
        The persisted (and refreshed) TrainingPlan.

    Raises:
        ValueError: If the reply contains no JSON object with a non-empty
            ``"weeks"`` list.
    """
    rider_context = await build_rider_context(rider, session)
    response = await chat_async(
        [{"role": "user", "text": f"Generate a training plan for this rider.\n\n{rider_context}"}],
        system_instruction=PLAN_GENERATION_SYSTEM,
        temperature=0.5,
    )

    plan_data = _extract_json(response)
    weeks = plan_data.get("weeks") if isinstance(plan_data, dict) else None
    if not isinstance(weeks, list) or not weeks:
        raise ValueError("Failed to parse plan from AI response")

    # Cancel existing active plans.
    existing_result = await session.execute(
        select(TrainingPlan)
        .where(TrainingPlan.rider_id == rider.id)
        .where(TrainingPlan.status == "active")
    )
    for old_plan in existing_result.scalars().all():
        old_plan.status = "cancelled"

    # The model may return null, a string or a non-positive number here.
    # Fall back to the number of weeks it actually produced.
    try:
        duration_weeks = int(plan_data.get("duration_weeks") or 0)
    except (TypeError, ValueError):
        duration_weeks = 0
    if duration_weeks <= 0:
        duration_weeks = len(weeks)

    # Weeks run Monday-Sunday: start today if it is Monday, otherwise on the
    # coming Monday.
    today = date.today()
    start = today + timedelta(days=(7 - today.weekday()) % 7)

    plan = TrainingPlan(
        rider_id=rider.id,
        goal=plan_data.get("goal") or rider.goals or "General fitness",
        start_date=start,
        end_date=start + timedelta(weeks=duration_weeks),
        phase=plan_data.get("phase") or "base",
        weeks_json=plan_data,
        description=plan_data.get("description") or "",
        status="active",
        onboarding_data=rider.coaching_profile,
    )
    session.add(plan)
    await session.commit()
    await session.refresh(plan)
    return plan
|
|
|
|
|
|
async def get_today_workout(rider: Rider, session: AsyncSession) -> dict | None:
    """Return today's planned workout from the rider's newest active plan.

    Args:
        rider: Rider whose plan is consulted.
        session: Async DB session (read-only use).

    Returns:
        The matching day entry from ``weeks_json`` merged with ``plan_id``,
        ``plan_goal``, ``week_number`` and ``week_focus``. Keys from the day
        entry win on collision. Returns ``None`` when there is no active plan,
        today is outside ``start_date..end_date``, or no day entry matches.
    """
    result = await session.execute(
        select(TrainingPlan)
        .where(TrainingPlan.rider_id == rider.id)
        .where(TrainingPlan.status == "active")
        .order_by(TrainingPlan.created_at.desc())
        .limit(1)
    )
    plan = result.scalar_one_or_none()
    if not plan or not plan.weeks_json:
        return None

    today = date.today()
    if not (plan.start_date <= today <= plan.end_date):
        return None

    week_num = (today - plan.start_date).days // 7 + 1
    # weekday() instead of strftime("%A"): the latter follows the process
    # locale and would not match the English day names stored in the plan.
    day_name = (
        "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday",
    )[today.weekday()]

    for week in plan.weeks_json.get("weeks", []):
        # LLM output may encode the number as a string ("2"); compare numerically.
        try:
            number = int(week.get("week_number"))
        except (TypeError, ValueError):
            continue
        if number != week_num:
            continue
        for day in week.get("days", []):
            # Tolerate "Monday" / " monday " from the model.
            if str(day.get("day", "")).strip().lower() == day_name:
                return {
                    "plan_id": str(plan.id),
                    "plan_goal": plan.goal,
                    "week_number": week_num,
                    "week_focus": week.get("focus", ""),
                    **day,
                }

    return None
|
|
|
|
|
|
async def calculate_compliance(plan: TrainingPlan, session: AsyncSession) -> list[dict]:
    """Compare planned vs. actual training for every week of ``plan``.

    Week N covers ``[start_date + (N-1) weeks, start_date + N weeks)``.
    Activities in that window are counted whether or not they were planned.
    Weeks that have not started yet are reported with zero actuals and status
    ``"upcoming"``. The week containing today is ``"current"`` and earlier
    weeks are ``"completed"``.

    Args:
        plan: Plan whose ``weeks_json["weeks"]`` defines the targets.
        session: Async DB session (one activity query per started week).

    Returns:
        One dict per plan week with planned/actual TSS, hours and ride counts,
        ``adherence_pct`` (actual/planned non-rest rides, capped at 100) and
        ``status``.
    """
    if not plan.weeks_json:
        return []

    today = date.today()
    results = []

    for index, week in enumerate(plan.weeks_json.get("weeks", []), start=1):
        # Fall back to the list position when week_number is missing or not
        # numeric, instead of mapping the week to before the plan start.
        try:
            week_num = int(week.get("week_number") or index)
        except (TypeError, ValueError):
            week_num = index
        week_start = plan.start_date + timedelta(weeks=week_num - 1)
        week_end = week_start + timedelta(days=7)
        planned_rides = sum(1 for d in week.get("days", []) if d.get("workout_type") != "rest")

        row = {
            "week_number": week_num,
            "focus": week.get("focus", ""),
            "planned_tss": week.get("target_tss", 0),
            "actual_tss": 0,
            "planned_hours": week.get("target_hours", 0),
            "actual_hours": 0,
            "planned_rides": planned_rides,
            "actual_rides": 0,
            "adherence_pct": 0,
            "status": "upcoming",
        }

        # Future weeks: nothing to measure yet.
        if week_start > today:
            results.append(row)
            continue

        act_result = await session.execute(
            select(Activity, ActivityMetrics.tss)
            .outerjoin(ActivityMetrics, ActivityMetrics.activity_id == Activity.id)
            .where(Activity.rider_id == plan.rider_id)
            .where(Activity.date >= week_start)
            .where(Activity.date < week_end)
        )
        acts = list(act_result.all())

        actual_tss = sum(float(r.tss or 0) for r in acts)
        # Missing durations count as zero instead of raising.
        actual_hours = sum(r[0].duration or 0 for r in acts) / 3600
        actual_rides = len(acts)
        adherence = min(100, round(actual_rides / planned_rides * 100)) if planned_rides > 0 else 0

        row.update(
            actual_tss=round(actual_tss, 0),
            actual_hours=round(actual_hours, 1),
            actual_rides=actual_rides,
            adherence_pct=adherence,
            # week_start <= today is known here, so only the end bound matters.
            status="current" if today < week_end else "completed",
        )
        results.append(row)

    return results
|
|
|
|
|
|
def _extract_json(text: str) -> dict | None:
|
|
"""Extract JSON from AI response text."""
|
|
# Try to find JSON in code blocks
|
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL)
|
|
if match:
|
|
try:
|
|
return json.loads(match.group(1))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Try to find raw JSON object
|
|
match = re.search(r"\{[\s\S]*\}", text)
|
|
if match:
|
|
try:
|
|
return json.loads(match.group(0))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
return None
|