init
This commit is contained in:
0
backend/app/routers/__init__.py
Normal file
0
backend/app/routers/__init__.py
Normal file
62
backend/app/routers/auth.py
Normal file
62
backend/app/routers/auth.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from ..database import get_db
|
||||
from ..models.user import User
|
||||
from ..schemas.user import UserCreate, UserLogin, UserResponse, Token
|
||||
from ..utils.security import get_password_hash, verify_password, create_access_token
|
||||
from ..services.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=Token)
|
||||
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
# Check if email exists
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered",
|
||||
)
|
||||
|
||||
# Check if username exists
|
||||
result = await db.execute(select(User).where(User.username == user_data.username))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken",
|
||||
)
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=get_password_hash(user_data.password),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
# Create token
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(user_data.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
38
backend/app/routers/messages.py
Normal file
38
backend/app/routers/messages.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from ..database import get_db
|
||||
from ..models.message import Message
|
||||
from ..schemas.message import MessageResponse
|
||||
|
||||
router = APIRouter(prefix="/api/rooms", tags=["messages"])
|
||||
|
||||
|
||||
@router.get("/{room_id}/messages", response_model=list[MessageResponse])
|
||||
async def get_messages(
|
||||
room_id: UUID,
|
||||
limit: int = 50,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.options(selectinload(Message.user))
|
||||
.where(Message.room_id == room_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
messages = result.scalars().all()
|
||||
|
||||
return [
|
||||
MessageResponse(
|
||||
id=msg.id,
|
||||
room_id=msg.room_id,
|
||||
user_id=msg.user_id,
|
||||
username=msg.user.username,
|
||||
text=msg.text,
|
||||
created_at=msg.created_at,
|
||||
)
|
||||
for msg in reversed(messages)
|
||||
]
|
||||
248
backend/app/routers/rooms.py
Normal file
248
backend/app/routers/rooms.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
from ..database import get_db
|
||||
from ..models.user import User
|
||||
from ..models.room import Room, RoomParticipant
|
||||
from ..models.track import RoomQueue
|
||||
from ..schemas.room import RoomCreate, RoomResponse, RoomDetailResponse, QueueAdd
|
||||
from ..schemas.track import TrackResponse
|
||||
from ..schemas.user import UserResponse
|
||||
from ..services.auth import get_current_user
|
||||
from ..services.sync import manager
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
router = APIRouter(prefix="/api/rooms", tags=["rooms"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[RoomResponse])
|
||||
async def get_rooms(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(Room, func.count(RoomParticipant.user_id).label("participants_count"))
|
||||
.outerjoin(RoomParticipant)
|
||||
.group_by(Room.id)
|
||||
.order_by(Room.created_at.desc())
|
||||
)
|
||||
rooms = []
|
||||
for room, count in result.all():
|
||||
room_dict = {
|
||||
"id": room.id,
|
||||
"name": room.name,
|
||||
"owner_id": room.owner_id,
|
||||
"current_track_id": room.current_track_id,
|
||||
"playback_position": room.playback_position,
|
||||
"is_playing": room.is_playing,
|
||||
"created_at": room.created_at,
|
||||
"participants_count": count,
|
||||
}
|
||||
rooms.append(RoomResponse(**room_dict))
|
||||
return rooms
|
||||
|
||||
|
||||
@router.post("", response_model=RoomResponse)
|
||||
async def create_room(
|
||||
room_data: RoomCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
room = Room(name=room_data.name, owner_id=current_user.id)
|
||||
db.add(room)
|
||||
await db.flush()
|
||||
return RoomResponse(
|
||||
id=room.id,
|
||||
name=room.name,
|
||||
owner_id=room.owner_id,
|
||||
current_track_id=room.current_track_id,
|
||||
playback_position=room.playback_position,
|
||||
is_playing=room.is_playing,
|
||||
created_at=room.created_at,
|
||||
participants_count=0,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{room_id}", response_model=RoomDetailResponse)
|
||||
async def get_room(room_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(Room)
|
||||
.options(
|
||||
selectinload(Room.owner),
|
||||
selectinload(Room.current_track),
|
||||
selectinload(Room.participants).selectinload(RoomParticipant.user),
|
||||
)
|
||||
.where(Room.id == room_id)
|
||||
)
|
||||
room = result.scalar_one_or_none()
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
return RoomDetailResponse(
|
||||
id=room.id,
|
||||
name=room.name,
|
||||
owner=UserResponse.model_validate(room.owner),
|
||||
current_track=TrackResponse.model_validate(room.current_track) if room.current_track else None,
|
||||
playback_position=room.playback_position,
|
||||
is_playing=room.is_playing,
|
||||
created_at=room.created_at,
|
||||
participants=[UserResponse.model_validate(p.user) for p in room.participants],
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{room_id}")
|
||||
async def delete_room(
|
||||
room_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(select(Room).where(Room.id == room_id))
|
||||
room = result.scalar_one_or_none()
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
if room.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not room owner")
|
||||
|
||||
await db.delete(room)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/{room_id}/join")
|
||||
async def join_room(
|
||||
room_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(select(Room).where(Room.id == room_id))
|
||||
room = result.scalar_one_or_none()
|
||||
|
||||
if not room:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")
|
||||
|
||||
# Check participant limit
|
||||
result = await db.execute(
|
||||
select(func.count(RoomParticipant.user_id)).where(RoomParticipant.room_id == room_id)
|
||||
)
|
||||
count = result.scalar()
|
||||
if count >= settings.max_room_participants:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Room is full")
|
||||
|
||||
# Check if already joined
|
||||
result = await db.execute(
|
||||
select(RoomParticipant).where(
|
||||
RoomParticipant.room_id == room_id,
|
||||
RoomParticipant.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
return {"status": "already joined"}
|
||||
|
||||
participant = RoomParticipant(room_id=room_id, user_id=current_user.id)
|
||||
db.add(participant)
|
||||
|
||||
# Notify others
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "user_joined", "user": {"id": str(current_user.id), "username": current_user.username}},
|
||||
)
|
||||
|
||||
return {"status": "joined"}
|
||||
|
||||
|
||||
@router.post("/{room_id}/leave")
|
||||
async def leave_room(
|
||||
room_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(RoomParticipant).where(
|
||||
RoomParticipant.room_id == room_id,
|
||||
RoomParticipant.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
participant = result.scalar_one_or_none()
|
||||
|
||||
if participant:
|
||||
await db.delete(participant)
|
||||
|
||||
# Notify others
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "user_left", "user_id": str(current_user.id)},
|
||||
)
|
||||
|
||||
return {"status": "left"}
|
||||
|
||||
|
||||
@router.get("/{room_id}/queue", response_model=list[TrackResponse])
|
||||
async def get_queue(room_id: UUID, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(RoomQueue)
|
||||
.options(selectinload(RoomQueue.track))
|
||||
.where(RoomQueue.room_id == room_id)
|
||||
.order_by(RoomQueue.position)
|
||||
)
|
||||
queue_items = result.scalars().all()
|
||||
return [TrackResponse.model_validate(item.track) for item in queue_items]
|
||||
|
||||
|
||||
@router.post("/{room_id}/queue")
|
||||
async def add_to_queue(
|
||||
room_id: UUID,
|
||||
data: QueueAdd,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
# Get max position
|
||||
result = await db.execute(
|
||||
select(func.max(RoomQueue.position)).where(RoomQueue.room_id == room_id)
|
||||
)
|
||||
max_pos = result.scalar() or 0
|
||||
|
||||
queue_item = RoomQueue(
|
||||
room_id=room_id,
|
||||
track_id=data.track_id,
|
||||
position=max_pos + 1,
|
||||
added_by=current_user.id,
|
||||
)
|
||||
db.add(queue_item)
|
||||
await db.flush()
|
||||
|
||||
# Notify others
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "queue_updated"},
|
||||
)
|
||||
|
||||
return {"status": "added"}
|
||||
|
||||
|
||||
@router.delete("/{room_id}/queue/{track_id}")
|
||||
async def remove_from_queue(
|
||||
room_id: UUID,
|
||||
track_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(RoomQueue).where(
|
||||
RoomQueue.room_id == room_id,
|
||||
RoomQueue.track_id == track_id,
|
||||
)
|
||||
)
|
||||
queue_item = result.scalar_one_or_none()
|
||||
|
||||
if queue_item:
|
||||
await db.delete(queue_item)
|
||||
|
||||
# Notify others
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "queue_updated"},
|
||||
)
|
||||
|
||||
return {"status": "removed"}
|
||||
222
backend/app/routers/tracks.py
Normal file
222
backend/app/routers/tracks.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import uuid
|
||||
from urllib.parse import quote
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from mutagen.mp3 import MP3
|
||||
from io import BytesIO
|
||||
from ..database import get_db
|
||||
from ..models.user import User
|
||||
from ..models.track import Track
|
||||
from ..schemas.track import TrackResponse, TrackWithUrl
|
||||
from ..services.auth import get_current_user
|
||||
from ..services.s3 import upload_file, delete_file, generate_presigned_url, can_upload_file, get_file_content
|
||||
from ..config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
router = APIRouter(prefix="/api/tracks", tags=["tracks"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TrackResponse])
|
||||
async def get_tracks(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Track).order_by(Track.created_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/upload", response_model=TrackResponse)
|
||||
async def upload_track(
|
||||
file: UploadFile = File(...),
|
||||
title: str = Form(None),
|
||||
artist: str = Form(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
# Check file type
|
||||
if not file.content_type or not file.content_type.startswith("audio/"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="File must be an audio file",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Check file size
|
||||
max_size = settings.max_file_size_mb * 1024 * 1024
|
||||
if file_size > max_size:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File size exceeds {settings.max_file_size_mb}MB limit",
|
||||
)
|
||||
|
||||
# Check storage limit
|
||||
if not await can_upload_file(file_size):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Storage limit exceeded",
|
||||
)
|
||||
|
||||
# Get duration and metadata from MP3
|
||||
try:
|
||||
audio = MP3(BytesIO(content))
|
||||
duration = int(audio.info.length * 1000) # Convert to milliseconds
|
||||
|
||||
# Extract ID3 tags if title/artist not provided
|
||||
if not title or not artist:
|
||||
tags = audio.tags
|
||||
if tags:
|
||||
# TIT2 = Title, TPE1 = Artist
|
||||
if not title and tags.get("TIT2"):
|
||||
title = str(tags.get("TIT2"))
|
||||
if not artist and tags.get("TPE1"):
|
||||
artist = str(tags.get("TPE1"))
|
||||
|
||||
# Fallback to filename if still no title
|
||||
if not title:
|
||||
title = file.filename.rsplit(".", 1)[0] if file.filename else "Unknown"
|
||||
if not artist:
|
||||
artist = "Unknown"
|
||||
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Could not read audio file",
|
||||
)
|
||||
|
||||
# Upload to S3
|
||||
s3_key = f"tracks/{uuid.uuid4()}.mp3"
|
||||
await upload_file(content, s3_key)
|
||||
|
||||
# Create track record
|
||||
track = Track(
|
||||
title=title,
|
||||
artist=artist,
|
||||
duration=duration,
|
||||
s3_key=s3_key,
|
||||
file_size=file_size,
|
||||
uploaded_by=current_user.id,
|
||||
)
|
||||
db.add(track)
|
||||
await db.flush()
|
||||
|
||||
return track
|
||||
|
||||
|
||||
@router.get("/{track_id}", response_model=TrackWithUrl)
|
||||
async def get_track(track_id: uuid.UUID, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Track).where(Track.id == track_id))
|
||||
track = result.scalar_one_or_none()
|
||||
|
||||
if not track:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Track not found")
|
||||
|
||||
url = generate_presigned_url(track.s3_key)
|
||||
return TrackWithUrl(
|
||||
id=track.id,
|
||||
title=track.title,
|
||||
artist=track.artist,
|
||||
duration=track.duration,
|
||||
file_size=track.file_size,
|
||||
uploaded_by=track.uploaded_by,
|
||||
created_at=track.created_at,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{track_id}")
|
||||
async def delete_track(
|
||||
track_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(select(Track).where(Track.id == track_id))
|
||||
track = result.scalar_one_or_none()
|
||||
|
||||
if not track:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Track not found")
|
||||
|
||||
if track.uploaded_by != current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not track owner")
|
||||
|
||||
# Delete from S3
|
||||
await delete_file(track.s3_key)
|
||||
|
||||
# Delete from DB
|
||||
await db.delete(track)
|
||||
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.get("/storage/info")
|
||||
async def get_storage_info(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(func.sum(Track.file_size)))
|
||||
total_size = result.scalar() or 0
|
||||
max_size = settings.max_storage_gb * 1024 * 1024 * 1024
|
||||
|
||||
return {
|
||||
"used_bytes": total_size,
|
||||
"max_bytes": max_size,
|
||||
"used_gb": round(total_size / (1024 * 1024 * 1024), 2),
|
||||
"max_gb": settings.max_storage_gb,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{track_id}/stream")
|
||||
async def stream_track(track_id: uuid.UUID, request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""Stream audio file through backend with Range support (bypasses S3 SSL issues)"""
|
||||
result = await db.execute(select(Track).where(Track.id == track_id))
|
||||
track = result.scalar_one_or_none()
|
||||
|
||||
if not track:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Track not found")
|
||||
|
||||
# Get full file content
|
||||
content = get_file_content(track.s3_key)
|
||||
file_size = len(content)
|
||||
|
||||
# Parse Range header
|
||||
range_header = request.headers.get("range")
|
||||
|
||||
if range_header:
|
||||
# Parse "bytes=start-end"
|
||||
range_match = range_header.replace("bytes=", "").split("-")
|
||||
start = int(range_match[0]) if range_match[0] else 0
|
||||
end = int(range_match[1]) if range_match[1] else file_size - 1
|
||||
|
||||
# Ensure valid range
|
||||
if start >= file_size:
|
||||
raise HTTPException(status_code=416, detail="Range not satisfiable")
|
||||
|
||||
end = min(end, file_size - 1)
|
||||
content_length = end - start + 1
|
||||
|
||||
# Encode filename for non-ASCII characters
|
||||
encoded_filename = quote(f"{track.title}.mp3")
|
||||
|
||||
return Response(
|
||||
content=content[start:end + 1],
|
||||
status_code=206,
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{file_size}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(content_length),
|
||||
"Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
)
|
||||
|
||||
# Encode filename for non-ASCII characters
|
||||
encoded_filename = quote(f"{track.title}.mp3")
|
||||
|
||||
# No range - return full file
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(file_size),
|
||||
"Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}",
|
||||
}
|
||||
)
|
||||
234
backend/app/routers/websocket.py
Normal file
234
backend/app/routers/websocket.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ..database import get_db, async_session
|
||||
from ..models.room import Room, RoomParticipant
|
||||
from ..models.track import RoomQueue
|
||||
from ..models.message import Message
|
||||
from ..models.user import User
|
||||
from ..services.sync import manager
|
||||
from ..utils.security import decode_token
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> User | None:
|
||||
payload = decode_token(token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == UUID(user_id)))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@router.websocket("/ws/rooms/{room_id}")
|
||||
async def room_websocket(websocket: WebSocket, room_id: UUID):
|
||||
# Get token from query params
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="No token provided")
|
||||
return
|
||||
|
||||
user = await get_user_from_token(token)
|
||||
if not user:
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return
|
||||
|
||||
await manager.connect(websocket, room_id, user.id)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
async with async_session() as db:
|
||||
if message["type"] == "player_action":
|
||||
await handle_player_action(db, room_id, user, message)
|
||||
elif message["type"] == "chat_message":
|
||||
await handle_chat_message(db, room_id, user, message)
|
||||
elif message["type"] == "sync_request":
|
||||
await handle_sync_request(db, room_id, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, room_id, user.id)
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "user_left", "user_id": str(user.id)},
|
||||
)
|
||||
|
||||
|
||||
async def handle_player_action(db: AsyncSession, room_id: UUID, user: User, message: dict):
|
||||
action = message.get("action")
|
||||
result = await db.execute(select(Room).where(Room.id == room_id))
|
||||
room = result.scalar_one_or_none()
|
||||
|
||||
if not room:
|
||||
return
|
||||
|
||||
if action == "play":
|
||||
room.is_playing = True
|
||||
room.playback_position = message.get("position", room.playback_position or 0)
|
||||
room.playback_started_at = datetime.utcnow()
|
||||
elif action == "pause":
|
||||
room.is_playing = False
|
||||
room.playback_position = message.get("position", room.playback_position or 0)
|
||||
room.playback_started_at = None
|
||||
elif action == "seek":
|
||||
room.playback_position = message.get("position", 0)
|
||||
if room.is_playing:
|
||||
room.playback_started_at = datetime.utcnow()
|
||||
elif action == "next":
|
||||
await play_next_track(db, room)
|
||||
elif action == "prev":
|
||||
await play_prev_track(db, room)
|
||||
elif action == "set_track":
|
||||
track_id = message.get("track_id")
|
||||
if track_id:
|
||||
room.current_track_id = UUID(track_id)
|
||||
room.playback_position = 0
|
||||
room.is_playing = True
|
||||
room.playback_started_at = datetime.utcnow()
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Get current track URL - use streaming endpoint to bypass S3 SSL issues
|
||||
track_url = None
|
||||
if room.current_track_id:
|
||||
track_url = f"/api/tracks/{room.current_track_id}/stream"
|
||||
|
||||
# Calculate current position based on when playback started
|
||||
current_position = room.playback_position or 0
|
||||
if room.is_playing and room.playback_started_at:
|
||||
elapsed = (datetime.utcnow() - room.playback_started_at).total_seconds() * 1000
|
||||
current_position = int((room.playback_position or 0) + elapsed)
|
||||
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "player_state",
|
||||
"is_playing": room.is_playing,
|
||||
"position": current_position,
|
||||
"current_track_id": str(room.current_track_id) if room.current_track_id else None,
|
||||
"track_url": track_url,
|
||||
"server_time": datetime.utcnow().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def play_next_track(db: AsyncSession, room: Room):
|
||||
result = await db.execute(
|
||||
select(RoomQueue)
|
||||
.where(RoomQueue.room_id == room.id)
|
||||
.order_by(RoomQueue.position)
|
||||
)
|
||||
queue = result.scalars().all()
|
||||
|
||||
if not queue:
|
||||
room.current_track_id = None
|
||||
room.is_playing = False
|
||||
room.playback_started_at = None
|
||||
return
|
||||
|
||||
# Find current track in queue
|
||||
current_index = -1
|
||||
for i, item in enumerate(queue):
|
||||
if item.track_id == room.current_track_id:
|
||||
current_index = i
|
||||
break
|
||||
|
||||
# Play next or first
|
||||
next_index = (current_index + 1) % len(queue)
|
||||
room.current_track_id = queue[next_index].track_id
|
||||
room.playback_position = 0
|
||||
room.is_playing = True
|
||||
room.playback_started_at = datetime.utcnow()
|
||||
|
||||
|
||||
async def play_prev_track(db: AsyncSession, room: Room):
|
||||
result = await db.execute(
|
||||
select(RoomQueue)
|
||||
.where(RoomQueue.room_id == room.id)
|
||||
.order_by(RoomQueue.position)
|
||||
)
|
||||
queue = result.scalars().all()
|
||||
|
||||
if not queue:
|
||||
room.current_track_id = None
|
||||
room.is_playing = False
|
||||
room.playback_started_at = None
|
||||
return
|
||||
|
||||
# Find current track in queue
|
||||
current_index = 0
|
||||
for i, item in enumerate(queue):
|
||||
if item.track_id == room.current_track_id:
|
||||
current_index = i
|
||||
break
|
||||
|
||||
# Play prev or last
|
||||
prev_index = (current_index - 1) % len(queue)
|
||||
room.current_track_id = queue[prev_index].track_id
|
||||
room.playback_position = 0
|
||||
room.is_playing = True
|
||||
room.playback_started_at = datetime.utcnow()
|
||||
|
||||
|
||||
async def handle_chat_message(db: AsyncSession, room_id: UUID, user: User, message: dict):
|
||||
text = message.get("text", "").strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
msg = Message(room_id=room_id, user_id=user.id, text=text)
|
||||
db.add(msg)
|
||||
await db.commit()
|
||||
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "chat_message",
|
||||
"id": str(msg.id),
|
||||
"user_id": str(user.id),
|
||||
"username": user.username,
|
||||
"text": text,
|
||||
"created_at": msg.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_sync_request(db: AsyncSession, room_id: UUID, websocket: WebSocket):
|
||||
result = await db.execute(
|
||||
select(Room).options(selectinload(Room.current_track)).where(Room.id == room_id)
|
||||
)
|
||||
room = result.scalar_one_or_none()
|
||||
|
||||
if not room:
|
||||
return
|
||||
|
||||
track_url = None
|
||||
if room.current_track_id:
|
||||
track_url = f"/api/tracks/{room.current_track_id}/stream"
|
||||
|
||||
# Calculate current position based on when playback started
|
||||
current_position = room.playback_position or 0
|
||||
if room.is_playing and room.playback_started_at:
|
||||
elapsed = (datetime.utcnow() - room.playback_started_at).total_seconds() * 1000
|
||||
current_position = int((room.playback_position or 0) + elapsed)
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "sync_state",
|
||||
"is_playing": room.is_playing,
|
||||
"position": current_position,
|
||||
"current_track_id": str(room.current_track_id) if room.current_track_id else None,
|
||||
"track_url": track_url,
|
||||
"server_time": datetime.utcnow().isoformat(),
|
||||
})
|
||||
Reference in New Issue
Block a user