from sqlalchemy.orm import Session, joinedload
from app.models.order import CartItem, Order, OrderItem, OrderStatus, PaymentStatus
from app.models.product import Product

def get_cart(db: Session, user_id: int):
    """Return all cart items owned by ``user_id``.

    Each item's ``product`` relationship is eager-loaded with a join so
    callers can read ``item.product`` without an extra query per item.
    """
    with_products = db.query(CartItem).options(joinedload(CartItem.product))
    return with_products.filter(CartItem.user_id == user_id).all()

def add_to_cart(db: Session, user_id: int, product_id: int, quantity: int, size: str = None):
    """Add ``quantity`` units of a product to the user's cart and commit.

    A cart line is identified by (user, product, size). If a matching line
    already exists its quantity is increased by ``quantity``; otherwise a new
    line is created.

    Returns:
        The affected ``CartItem``, refreshed from the database after commit.
    """
    line = (
        db.query(CartItem)
        .filter(
            CartItem.user_id == user_id,
            CartItem.product_id == product_id,
            CartItem.size == size,
        )
        .first()
    )
    if not line:
        # No matching line yet: start a new one.
        line = CartItem(user_id=user_id, product_id=product_id, quantity=quantity, size=size)
        db.add(line)
    else:
        line.quantity += quantity
    db.commit()
    db.refresh(line)
    return line

def remove_from_cart(db: Session, user_id: int, cart_item_id: int):
    """Delete one cart item, but only if it belongs to ``user_id``.

    Returns:
        The deleted ``CartItem``, or ``None`` when no item with that id
        exists for the user (in which case nothing is committed).
    """
    target = (
        db.query(CartItem)
        .filter(CartItem.id == cart_item_id, CartItem.user_id == user_id)
        .first()
    )
    if not target:
        return target
    db.delete(target)
    db.commit()
    return target

def clear_cart(db: Session, user_id: int):
    """Bulk-delete every cart item owned by ``user_id`` and commit."""
    owned_items = db.query(CartItem).filter(CartItem.user_id == user_id)
    owned_items.delete()
    db.commit()

def place_order(db: Session, user_id: int, payment_method, shipping_address: str, notes: str = None, stripe_payment_id: str = None):
    """Turn the user's current cart into an order.

    Creates an ``Order`` whose total is the sum of ``price * quantity`` over
    the cart, one ``OrderItem`` per cart line (recording the product's price
    at this moment as ``unit_price``), decrements each product's stock by the
    ordered quantity, and empties the cart. All of this is committed once.

    Args:
        db: Active SQLAlchemy session.
        user_id: Owner of the cart and of the resulting order.
        payment_method: Stored on the order as-is.
        shipping_address: Stored on the order as-is.
        notes: Optional free text stored on the order.
        stripe_payment_id: Optional Stripe identifier stored on the order.

    Returns:
        The refreshed ``Order``, or ``None`` if the cart is empty.

    Raises:
        Any exception raised while building or committing the order is
        re-raised after the session has been rolled back.
    """
    cart = get_cart(db, user_id)
    if not cart:
        return None

    try:
        total = sum(item.product.price * item.quantity for item in cart)
        order = Order(
            user_id=user_id,
            payment_method=payment_method,
            total_amount=total,
            shipping_address=shipping_address,
            notes=notes,
            stripe_payment_id=stripe_payment_id,
        )
        db.add(order)
        db.flush()  # assigns order.id so the items below can reference it

        for cart_item in cart:
            # get_cart() already eager-loaded the product, so reuse it rather
            # than issuing one extra SELECT per cart line.
            product = cart_item.product
            db.add(
                OrderItem(
                    order_id=order.id,
                    product_id=cart_item.product_id,
                    quantity=cart_item.quantity,
                    unit_price=product.price,
                    size=cart_item.size,
                )
            )
            # NOTE(review): stock is decremented without an availability
            # check, so it can go negative — confirm this is validated by
            # the caller or enforced by a database constraint.
            product.stock -= cart_item.quantity

        # Empty the cart inline instead of via clear_cart(), which commits on
        # its own; this keeps the whole order in a single explicit commit.
        db.query(CartItem).filter(CartItem.user_id == user_id).delete()
        db.commit()
    except Exception:
        # Leave the session usable and make sure no partial order persists.
        db.rollback()
        raise

    db.refresh(order)
    return order

def get_orders(db: Session, user_id: int = None, skip: int = 0, limit: int = 20):
    """Return one page of orders, newest first, with items and products loaded.

    Args:
        db: Active SQLAlchemy session.
        user_id: When given, only that user's orders are returned. ``None``
            (the default) means no user filter.
        skip: Number of orders to skip before the page starts.
        limit: Maximum number of orders in the page.

    Returns:
        A dict with ``"items"`` (the list of ``Order`` objects for this page)
        and ``"total"`` (the count of all orders matching the filter, taken
        before ``skip``/``limit`` are applied).
    """
    q = db.query(Order).options(joinedload(Order.items).joinedload(OrderItem.product))
    # Compare against None explicitly: a plain truthiness test would treat
    # user_id=0 as "no filter" and return every user's orders.
    if user_id is not None:
        q = q.filter(Order.user_id == user_id)
    total = q.count()
    # Order.id breaks ties between equal created_at values so that
    # offset/limit pages stay stable from one request to the next.
    page = (
        q.order_by(Order.created_at.desc(), Order.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return {"items": page, "total": total}

def update_order_status(db: Session, order_id: int, status: OrderStatus):
    """Set an order's status and commit the change.

    Moving an order to ``OrderStatus.confirmed`` also marks its payment
    status as ``PaymentStatus.paid``.

    Returns:
        The refreshed ``Order``, or ``None`` if no order has ``order_id``.
    """
    target = db.query(Order).filter(Order.id == order_id).first()
    if not target:
        return None

    target.status = status
    if status == OrderStatus.confirmed:
        target.payment_status = PaymentStatus.paid

    db.commit()
    db.refresh(target)
    return target

def get_order_by_stripe_id(db: Session, stripe_id: str):
    """Return the first order whose ``stripe_payment_id`` equals ``stripe_id``, or ``None``."""
    matching = db.query(Order).filter(Order.stripe_payment_id == stripe_id)
    return matching.first()
