"""WebSocket endpoint for rooms: presence notifications, playback sync and chat."""

import json
import math
from datetime import datetime
from uuid import UUID

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from ..database import get_db, async_session
from ..models.room import Room, RoomParticipant
from ..models.track import RoomQueue
from ..models.message import Message
from ..models.user import User
from ..services.sync import manager
from ..utils.security import decode_token

router = APIRouter(tags=["websocket"])

# Client message types whose handlers need a database session.
# A tuple (not a set) so membership tests never hash client-supplied values,
# which may be unhashable JSON lists/objects.
_DB_MESSAGE_TYPES = ("player_action", "chat_message", "sync_request")


async def get_user_from_token(token: str) -> User | None:
    """Resolve the user referenced by an access token.

    Returns None when the token cannot be decoded, has no ``sub`` claim,
    the ``sub`` claim is not a valid UUID, or no matching user exists.
    """
    payload = decode_token(token)
    if not payload:
        return None
    user_id = payload.get("sub")
    if not user_id:
        return None
    try:
        user_uuid = UUID(str(user_id))
    except ValueError:
        # A malformed subject means an invalid token, not a server error.
        return None
    async with async_session() as db:
        result = await db.execute(select(User).where(User.id == user_uuid))
        return result.scalar_one_or_none()


@router.websocket("/ws/rooms/{room_id}")
async def room_websocket(websocket: WebSocket, room_id: UUID):
    """Serve one user's realtime connection to a room.

    The access token is read from the ``token`` query parameter; a missing or
    invalid token closes the socket with code 4001. Incoming frames are JSON
    objects with a ``type`` field: ``ping`` is answered with ``pong``;
    ``player_action``, ``chat_message`` and ``sync_request`` are dispatched to
    their handlers. Frames that are not JSON objects, or have an unknown type,
    are ignored. Other room members are told when the user joins and leaves.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="No token provided")
        return

    user = await get_user_from_token(token)
    if not user:
        await websocket.close(code=4001, reason="Invalid token")
        return

    await manager.connect(websocket, room_id, user.id)
    try:
        # Notify others that user joined
        await manager.broadcast_to_room(
            room_id,
            {"type": "user_joined", "user": {"id": str(user.id), "username": user.username}},
            exclude_user=user.id,
        )
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                continue  # ignore malformed frames instead of dropping the connection
            if not isinstance(message, dict):
                continue

            message_type = message.get("type")
            # Handle ping/pong for keepalive (no DB session needed)
            if message_type == "ping":
                await websocket.send_json({"type": "pong"})
                continue
            if message_type not in _DB_MESSAGE_TYPES:
                continue

            async with async_session() as db:
                if message_type == "player_action":
                    await handle_player_action(db, room_id, user, message)
                elif message_type == "chat_message":
                    await handle_chat_message(db, room_id, user, message)
                else:  # "sync_request"
                    await handle_sync_request(db, room_id, websocket)
    except WebSocketDisconnect:
        pass
    finally:
        # Runs on clean disconnects and on unexpected errors alike, so the
        # manager never keeps a dead socket registered for this room.
        manager.disconnect(websocket, room_id, user.id)
        await manager.broadcast_to_room(
            room_id,
            {"type": "user_left", "user_id": str(user.id)},
        )


def _coerce_position(value: object, default: int) -> int:
    """Return *value* as a non-negative int position, or *default* if it is not a finite number.

    Client payloads are untrusted, and ``json.loads`` accepts NaN/Infinity.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return default
    if not math.isfinite(value):
        return default
    return max(0, int(value))


def _current_position(room: Room) -> int:
    """Return the position (ms) the room's playback has reached right now.

    While playing, the time elapsed since ``playback_started_at`` is added
    to the stored ``playback_position``.
    """
    position = room.playback_position or 0
    if room.is_playing and room.playback_started_at:
        elapsed_ms = (datetime.utcnow() - room.playback_started_at).total_seconds() * 1000
        return int(position + elapsed_ms)
    return position


def _player_state(room: Room, message_type: str) -> dict:
    """Build the playback-state payload sent to clients under the given ``type``."""
    # Use the streaming endpoint rather than a direct S3 URL to bypass S3 SSL issues.
    track_url = f"/api/tracks/{room.current_track_id}/stream" if room.current_track_id else None
    return {
        "type": message_type,
        "is_playing": room.is_playing,
        "position": _current_position(room),
        "current_track_id": str(room.current_track_id) if room.current_track_id else None,
        "track_url": track_url,
        "server_time": datetime.utcnow().isoformat(),
    }


def _start_track(room: Room, track_id: UUID) -> None:
    """Make *track_id* the room's current track and start it from the beginning."""
    room.current_track_id = track_id
    room.playback_position = 0
    room.is_playing = True
    room.playback_started_at = datetime.utcnow()


def _stop_playback(room: Room) -> None:
    """Clear the room's current track and mark it as not playing."""
    room.current_track_id = None
    room.is_playing = False
    room.playback_started_at = None


async def _load_queue(db: AsyncSession, room: Room) -> list[RoomQueue]:
    """Return the room's queue entries ordered by their ``position``."""
    result = await db.execute(
        select(RoomQueue)
        .where(RoomQueue.room_id == room.id)
        .order_by(RoomQueue.position)
    )
    return list(result.scalars().all())


def _queue_index(queue: list[RoomQueue], track_id) -> int | None:
    """Return the index of the first queue entry for *track_id*, or None if absent."""
    return next((i for i, item in enumerate(queue) if item.track_id == track_id), None)


async def handle_player_action(db: AsyncSession, room_id: UUID, user: User, message: dict):
    """Apply a playback control action to the room, commit, and broadcast the new state.

    Supported ``action`` values are ``play``, ``pause``, ``seek``, ``next``,
    ``prev`` and ``set_track``. An unknown action changes nothing, but the
    current state is still committed and broadcast. Does nothing if the room
    does not exist, or if ``set_track`` carries a non-UUID ``track_id``.
    ``user`` is currently unused.
    """
    action = message.get("action")
    result = await db.execute(select(Room).where(Room.id == room_id))
    room = result.scalar_one_or_none()
    if not room:
        return

    if action == "play":
        # Resume from the client's position if given, else from where the
        # room is now (computed before the playback anchor is reset).
        room.playback_position = _coerce_position(message.get("position"), _current_position(room))
        room.is_playing = True
        room.playback_started_at = datetime.utcnow()
    elif action == "pause":
        # Computed while still playing so the elapsed play time is kept.
        room.playback_position = _coerce_position(message.get("position"), _current_position(room))
        room.is_playing = False
        room.playback_started_at = None
    elif action == "seek":
        room.playback_position = _coerce_position(message.get("position"), 0)
        if room.is_playing:
            room.playback_started_at = datetime.utcnow()
    elif action == "next":
        await play_next_track(db, room)
    elif action == "prev":
        await play_prev_track(db, room)
    elif action == "set_track":
        track_id = message.get("track_id")
        if track_id:
            try:
                track_uuid = UUID(str(track_id))
            except ValueError:
                return  # untrusted input: ignore an invalid track id
            _start_track(room, track_uuid)

    await db.commit()
    await manager.broadcast_to_room(room_id, _player_state(room, "player_state"))


async def play_next_track(db: AsyncSession, room: Room):
    """Advance the room to the next queued track, wrapping to the first.

    If the current track is not in the queue, the first track is started.
    With an empty queue, playback is stopped and the current track cleared.
    Does not commit.
    """
    queue = await _load_queue(db, room)
    if not queue:
        _stop_playback(room)
        return
    current = _queue_index(queue, room.current_track_id)
    next_index = 0 if current is None else (current + 1) % len(queue)
    _start_track(room, queue[next_index].track_id)


async def play_prev_track(db: AsyncSession, room: Room):
    """Move the room to the previous queued track, wrapping to the last.

    If the current track is not in the queue, the last track is started.
    With an empty queue, playback is stopped and the current track cleared.
    Does not commit.
    """
    queue = await _load_queue(db, room)
    if not queue:
        _stop_playback(room)
        return
    current = _queue_index(queue, room.current_track_id)
    prev_index = len(queue) - 1 if current is None else (current - 1) % len(queue)
    _start_track(room, queue[prev_index].track_id)


async def handle_chat_message(db: AsyncSession, room_id: UUID, user: User, message: dict):
    """Persist a chat message from *user* and broadcast it to the room.

    Messages whose ``text`` is missing, not a string, or only whitespace are
    ignored.
    """
    text = message.get("text")
    if not isinstance(text, str):
        return
    text = text.strip()
    if not text:
        return

    msg = Message(room_id=room_id, user_id=user.id, text=text)
    db.add(msg)
    await db.commit()
    # Reload DB-generated columns before reading them. Attribute access
    # cannot lazy-load under an AsyncSession; this matters if id/created_at
    # are server defaults or the session expires on commit (config not
    # visible here).
    await db.refresh(msg)

    await manager.broadcast_to_room(
        room_id,
        {
            "type": "chat_message",
            "id": str(msg.id),
            "user_id": str(user.id),
            "username": user.username,
            "text": text,
            "created_at": msg.created_at.isoformat(),
        },
    )


async def handle_sync_request(db: AsyncSession, room_id: UUID, websocket: WebSocket):
    """Send the room's current playback state to a single client as ``sync_state``."""
    result = await db.execute(select(Room).where(Room.id == room_id))
    room = result.scalar_one_or_none()
    if not room:
        return
    await websocket.send_json(_player_state(room, "sync_state"))