from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session
import psycopg2
from psycopg2.extras import RealDictCursor
from datetime import datetime
import config
import os
from werkzeug.utils import secure_filename
import bcrypt
import sys
from decimal import Decimal, ROUND_HALF_UP, getcontext

# Print startup diagnostics: a framed banner on stderr. DATABASE_URL and
# SECRET_KEY are reported only as present/absent, never by value.
_STARTUP_RULE = "=" * 60
print(_STARTUP_RULE, file=sys.stderr)
print("STARTING APPLICATION", file=sys.stderr)
print(f"DATABASE_URL present: {bool(config.DATABASE_URL)}", file=sys.stderr)
print(f"USE_CLOUDINARY: {config.USE_CLOUDINARY}", file=sys.stderr)
print(f"SECRET_KEY present: {bool(config.SECRET_KEY)}", file=sys.stderr)
print(_STARTUP_RULE, file=sys.stderr)

# Pagination configuration: number of products shown per page on the public
# outlet (/outlet).
PRODUCTS_PER_PAGE = 6

# Preferred size ordering for display: letter sizes first, then numeric sizes,
# in the listed order. _size_sort_key places any size not in this list after
# all listed ones.
SIZE_ORDER = ['XS', 'S', 'M', 'L', 'XL', 'XXL', '28', '30', '32', '34', '36']

def _clean_text(text):
    """Remove photo number annotations like '(Photo 2)' from text"""
    import re
    if not text:
        return text
    # Remove patterns like "(Photo 1)", "(Photo 2)", etc.
    cleaned = re.sub(r'\s*\(Photo\s+\d+\)\s*', '', str(text))
    return cleaned.strip()

def _size_sort_key(size: str):
    """Build a sort key that orders sizes by SIZE_ORDER.

    The size is stripped and cleaned of '(Photo N)' annotations first. Sizes
    listed in SIZE_ORDER map to ``(0, position)``. Anything else maps to
    ``(1, text)``, so unknown sizes sort alphabetically after every known one.
    None is treated as the empty string.
    """
    label = _clean_text((size or '').strip())
    if label in SIZE_ORDER:
        return (0, SIZE_ORDER.index(label))
    return (1, label)

app = Flask(__name__)
# Load every upper-case setting from the config module, then pin SECRET_KEY
# explicitly so the session-signing key always comes from config.
app.config.from_object(config)
app.config['SECRET_KEY'] = config.SECRET_KEY

# Configure upload folder
import os as _os
# Anchor UPLOAD_FOLDER at the application root so local uploads go to an
# absolute path independent of the process working directory (works well on
# cPanel). The request-size limit is copied through unchanged.
app.config.update(
    UPLOAD_FOLDER=_os.path.join(app.root_path, config.UPLOAD_FOLDER),
    MAX_CONTENT_LENGTH=config.MAX_CONTENT_LENGTH,
)

# Template filter for image optimization
@app.template_filter('optimize_image')
def optimize_image_filter(url, width=800):
    """Jinja2 ``optimize_image`` filter, e.g. ``{{ path|optimize_image(400) }}``.

    Passes *url* and *width* (default 800) to config.get_optimized_image_url
    and returns whatever URL that helper produces.
    """
    optimized_url = config.get_optimized_image_url(url, width=width)
    return optimized_url

# Template filter to clean size/color text (remove photo annotations)
@app.template_filter('clean_text')
def clean_text_filter(text):
    """Jinja2 ``clean_text`` filter: remove '(Photo N)' annotations from text.

    Delegates to _clean_text so templates and Python code share a single
    implementation of the cleaning rule instead of two copies that can drift.
    Falsy input is returned unchanged; anything else comes back as a stripped
    string.
    """
    return _clean_text(text)

# Money helpers: currency values are Decimals with 28 significant digits of
# working precision, rounded to whole cents using half-up rounding.
getcontext().prec = 28
CENT = Decimal('0.01')

def money(value):
    """Return *value* as a Decimal rounded half-up to 2 decimal places.

    Accepts str/float/int/Decimal. The value goes through str() first, so a
    float such as 1.005 rounds as written (to 1.01) rather than by its binary
    approximation. None becomes Decimal('0.00').
    """
    amount = Decimal('0') if value is None else Decimal(str(value))
    return amount.quantize(CENT, rounding=ROUND_HALF_UP)

# Initialize Cloudinary if configured.
# These imports run at module level, so on success `cloudinary` becomes a
# module-global name that upload_image() uses. If the import or configuration
# raises, the flag is switched off and upload_image() takes its local-storage
# branch instead.
if config.USE_CLOUDINARY:
    try:
        import cloudinary
        import cloudinary.uploader
        cloudinary.config(
            cloud_name=config.CLOUDINARY_CLOUD_NAME,
            api_key=config.CLOUDINARY_API_KEY,
            api_secret=config.CLOUDINARY_API_SECRET
        )
        print("Cloudinary configured successfully", file=sys.stderr)
    except Exception as e:
        print(f"Cloudinary setup failed: {e}", file=sys.stderr)
        # Mutates the shared config module so later reads of
        # config.USE_CLOUDINARY (upload_image, /health) see it as disabled.
        config.USE_CLOUDINARY = False

# Database connection helper function
def get_db_connection():
    """Open a new PostgreSQL connection (supports cPanel and external databases).

    When config.DATABASE_URL is set it is used directly, after rewriting a
    leading ``postgres://`` scheme to ``postgresql://``. Otherwise the
    connection is built from config.DB_HOST / DB_PORT / DB_NAME / DB_USER /
    DB_PASSWORD. Both paths use a 10-second connect timeout.

    Returns:
        A new psycopg2 connection. The caller is responsible for closing it.

    Raises:
        psycopg2.OperationalError: the server is unreachable or rejects the
            credentials (logged, then re-raised).
        Exception: any other failure while connecting is logged and re-raised
            unchanged.
    """
    try:
        if config.DATABASE_URL:
            db_url = config.DATABASE_URL
            # Normalise the legacy 'postgres://' scheme that some hosts issue.
            if db_url.startswith('postgres://'):
                db_url = db_url.replace('postgres://', 'postgresql://', 1)
            print("Connecting to database via DATABASE_URL", file=sys.stderr)
            conn = psycopg2.connect(db_url, connect_timeout=10)
        else:
            # Use individual configuration
            print(f"Connecting to database: {config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}", file=sys.stderr)
            conn = psycopg2.connect(
                host=config.DB_HOST,
                user=config.DB_USER,
                password=config.DB_PASSWORD,
                database=config.DB_NAME,
                port=int(config.DB_PORT),  # DB_PORT may be a string (e.g. from the environment)
                connect_timeout=10,
            )

        print("Database connection successful", file=sys.stderr)
        return conn
    except psycopg2.OperationalError as e:
        print(f"Database connection error (check credentials/host/port): {e}", file=sys.stderr)
        raise
    except Exception as e:
        print(f"Database connection error: {e}", file=sys.stderr)
        raise

def _column_exists(cursor, table, column):
    """Return True if information_schema lists *column* on a table named *table*.

    The lookup is not restricted to a schema: a same-named table in any
    visible schema counts.
    """
    cursor.execute("""
        SELECT column_name FROM information_schema.columns
        WHERE table_name = %s AND column_name = %s
    """, (table, column))
    return cursor.fetchone() is not None


def ensure_database_schema():
    """Ensure the ``products`` table is up to date with the in-app migrations.

    Each migration first checks information_schema, so repeated runs are
    idempotent:

    1. rename the legacy ``product_code`` column to ``lot_details``;
    2. add ``is_deleted BOOLEAN DEFAULT FALSE`` if it is missing;
    3. add ``is_visible BOOLEAN DEFAULT TRUE`` if it is missing.

    Each applied migration is committed separately. The function is
    best-effort: any error, including failure to connect, is logged and the
    open transaction rolled back, and nothing is raised. A database outage
    therefore cannot abort module import. The connection is always closed.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cursor:
            # Migration 1: Rename product_code to lot_details
            if _column_exists(cursor, 'products', 'product_code'):
                print("Migrating: Renaming product_code to lot_details...", file=sys.stderr)
                cursor.execute("ALTER TABLE products RENAME COLUMN product_code TO lot_details")
                conn.commit()
                print("Migration complete: product_code renamed to lot_details", file=sys.stderr)

            # Migration 2: Add is_deleted column if not exists
            if not _column_exists(cursor, 'products', 'is_deleted'):
                print("Migrating: Adding is_deleted column...", file=sys.stderr)
                cursor.execute("ALTER TABLE products ADD COLUMN is_deleted BOOLEAN DEFAULT FALSE")
                conn.commit()
                print("Migration complete: is_deleted column added", file=sys.stderr)

            # Migration 3: Add is_visible column if not exists
            if not _column_exists(cursor, 'products', 'is_visible'):
                print("Migrating: Adding is_visible column...", file=sys.stderr)
                cursor.execute("ALTER TABLE products ADD COLUMN is_visible BOOLEAN DEFAULT TRUE")
                conn.commit()
                print("Migration complete: is_visible column added", file=sys.stderr)

        print("Database schema is up to date", file=sys.stderr)
    except Exception as e:
        print(f"Database migration error: {e}", file=sys.stderr)
        # conn stays None when get_db_connection() itself failed; the old code
        # hit an unbound local here and crashed the import.
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # A broken connection can also fail to roll back; stay best-effort.
                print(f"Database migration rollback failed: {rollback_error}", file=sys.stderr)
    finally:
        if conn is not None:
            conn.close()

# Run migrations on startup. This executes once at import time, before any
# request is served, and uses its own database connection.
ensure_database_schema()

# Error handlers for cPanel compatibility
@app.errorhandler(404)
def page_not_found(error):
    """Handle 404 errors.

    Renders templates/404.html (with ``error`` in the context) when that file
    exists. Otherwise it returns a JSON body shaped like the other error
    handlers. The previous code passed ``None`` to render_template in that
    case, which raises and turns every 404 into a 500.
    """
    template_path = os.path.join(app.root_path, 'templates', '404.html')
    if os.path.exists(template_path):
        return render_template('404.html', error=str(error)), 404
    return jsonify({
        'status': 'error',
        'message': 'Page not found'
    }), 404

@app.errorhandler(500)
def internal_server_error(error):
    """Log the error to stderr and reply with a generic JSON 500 body.

    The error text goes only to the log, never to the client.
    """
    print(f"Internal Server Error: {error}", file=sys.stderr)
    payload = {
        'status': 'error',
        'message': 'Internal server error. Please check the server logs.',
    }
    return jsonify(payload), 500

@app.errorhandler(403)
def forbidden(error):
    """Reply to 403 errors with a fixed JSON body; *error* is not echoed back."""
    payload = {'status': 'error', 'message': 'Access forbidden'}
    return jsonify(payload), 403

@app.errorhandler(400)
def bad_request(error):
    """Reply to 400 errors with a fixed JSON body; *error* is not echoed back."""
    payload = {'status': 'error', 'message': 'Bad request'}
    return jsonify(payload), 400

@app.before_request
def before_request():
    """Probe the database before each request and warn the user if it is down.

    Opens and immediately closes a connection via get_db_connection(). On
    failure the error is logged and, for every path except /health and
    /login, a 'danger' flash message is queued. The request itself is never
    aborted here.

    Requests for Flask's ``static`` endpoint are skipped. Serving a file never
    needs the database, and probing there opened one extra connection per
    asset. While the database was down it also queued one duplicate flash
    message per asset.
    """
    if request.endpoint == 'static':
        return None

    try:
        # Test database connection at start of each request
        conn = get_db_connection()
        conn.close()
    except Exception as e:
        print(f"Database connection failed: {e}", file=sys.stderr)
        # Don't fail on health check or login
        if request.path not in ('/health', '/login'):
            flash('Database connection error. Please contact administrator.', 'danger')
    return None

def allowed_file(filename):
    """Return True if *filename*'s extension is in config.ALLOWED_EXTENSIONS.

    Only the text after the last dot is checked, lower-cased. Names with no
    dot at all are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1].lower()
    return extension in config.ALLOWED_EXTENSIONS

def upload_image(file):
    """Upload image to Cloudinary or local storage (cPanel friendly local storage).

    Args:
        file: an uploaded file object, presumably a werkzeug FileStorage from
            ``request.files`` (it must expose ``.filename`` and ``.save``).

    Returns:
        With config.USE_CLOUDINARY set, Cloudinary's ``secure_url``.
        Otherwise a relative path ``uploads/<YYYYmmdd_HHMMSS_ffffff>_<name>``.
        Returns None when no file was given, the extension is not allowed,
        or the upload raised (the error is logged).
    """
    if not file or file.filename == '':
        return None

    if not allowed_file(file.filename):
        return None

    try:
        if config.USE_CLOUDINARY:
            # `cloudinary` is imported at module level when Cloudinary is enabled.
            result = cloudinary.uploader.upload(file, folder="garments_products")
            return result['secure_url']

        # allowed_file() guaranteed a dot, so this is the validated extension.
        extension = file.filename.rsplit('.', 1)[1].lower()
        filename = secure_filename(file.filename)
        # secure_filename drops non-ASCII characters and leading dots, which can
        # reduce a name to just "jpg" or even "". Re-attach the validated
        # extension so the stored file keeps a usable suffix.
        if not filename.lower().endswith('.' + extension):
            filename = f"{filename or 'upload'}.{extension}"

        # Microseconds keep names unique when several files with the same
        # original name (e.g. "image.jpg") arrive within one second. With
        # second resolution those uploads silently overwrote each other.
        unique_filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}_{filename}"

        # Ensure upload directory exists
        upload_dir = app.config['UPLOAD_FOLDER']
        os.makedirs(upload_dir, exist_ok=True)

        file_path = os.path.join(upload_dir, unique_filename)
        file.save(file_path)

        # Returned relative to the web root. NOTE(review): this assumes
        # UPLOAD_FOLDER is the static folder's "uploads" subdirectory
        # (callers pass image paths to url_for('static', ...)) — confirm.
        return f"uploads/{unique_filename}"
    except Exception as e:
        print(f"Error uploading image: {e}", file=sys.stderr)
        return None

# Login required decorator for admin routes
def login_required(f):
    """Decorator restricting a view to a logged-in admin session.

    Sessions without the 'loggedin' key, or whose user_type is not 'admin',
    get a 'warning' flash and are redirected to the login page. Otherwise the
    wrapped view runs unchanged.
    """
    from functools import wraps

    @wraps(f)
    def decorated_function(*args, **kwargs):
        is_admin = 'loggedin' in session and session.get('user_type') == 'admin'
        if not is_admin:
            flash('Please log in as admin to access this page.', 'warning')
            return redirect(url_for('login'))
        return f(*args, **kwargs)
    return decorated_function

# Login required decorator for salesman and admin routes
def sales_access_required(f):
    """Decorator restricting a view to logged-in 'admin' or 'salesman' users.

    Anonymous visitors get a 'warning' flash. Logged-in users with any other
    user_type get a 'danger' flash. Both are redirected to the login page.
    """
    from functools import wraps

    permitted_roles = ('admin', 'salesman')

    @wraps(f)
    def decorated_function(*args, **kwargs):
        if 'loggedin' in session:
            if session.get('user_type') in permitted_roles:
                return f(*args, **kwargs)
            flash('You do not have permission to access this page.', 'danger')
        else:
            flash('Please log in to access this page.', 'warning')
        return redirect(url_for('login'))
    return decorated_function

# Public routes (no login required)
@app.route('/')
def index():
    """Send visitors from the site root to the public product outlet."""
    outlet_url = url_for('product_outlet')
    return redirect(outlet_url)

@app.route('/health')
def health_check():
    """Health check endpoint for monitoring.

    Opens a database connection and runs ``SELECT 1``.

    Returns:
        On success, 200 with JSON ``{"status": "healthy", "database":
        "connected", "cloudinary": "enabled" | "disabled"}``. On any
        failure, 500 with ``{"status": "unhealthy", "error": <text>}``. The
        connection is closed on both paths; previously a failing query
        leaked it.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1")

        return jsonify({
            'status': 'healthy',
            'database': 'connected',
            'cloudinary': 'enabled' if config.USE_CLOUDINARY else 'disabled'
        }), 200
    except Exception as e:
        print(f"Health check failed: {e}", file=sys.stderr)
        # NOTE(review): this endpoint is public and echoes the raw exception
        # text, which for connection errors may include host/port details.
        return jsonify({
            'status': 'unhealthy',
            'error': str(e)
        }), 500
    finally:
        if conn is not None:
            conn.close()

@app.route('/outlet')
def product_outlet():
    """Public product outlet - no login required.

    Query params:
        page: 1-based page number (default 1). Values below 1 are clamped to
            1 so the SQL OFFSET can never be negative.

    Renders outlet.html with one page (PRODUCTS_PER_PAGE items, highest id
    first) of products that are visible and not soft-deleted. Each product
    row is extended with:

    * variant_count, total_stock and a quantity-weighted avg_price;
    * de-duplicated ``sizes`` (ordered by _size_sort_key) and ``colors``
      lists with '(Photo N)' annotations removed;
    * its full ``variants`` list, with image-bearing variants first.

    total_products / total_pages use the same visibility filter as the
    listing, so page counts match what can actually be shown.
    """
    page = max(request.args.get('page', 1, type=int), 1)
    per_page = PRODUCTS_PER_PAGE
    offset = (page - 1) * per_page

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Same filter as the listing query below (including is_deleted).
        cursor.execute("""
            SELECT COUNT(*) as total
            FROM products p
            WHERE p.is_visible = TRUE AND p.is_deleted = FALSE
        """)
        total_products = cursor.fetchone()['total']
        total_pages = max(1, (total_products + per_page - 1) // per_page)

        # Get visible products with their variant aggregates (paginated)
        cursor.execute("""
            SELECT 
                p.*,
                COUNT(pv.id) as variant_count,
                COALESCE(SUM(pv.quantity), 0) as total_stock,
                COALESCE(SUM(pv.selling_price * COALESCE(pv.quantity, 0)) / NULLIF(SUM(pv.quantity), 0), 0) as avg_price,
                STRING_AGG(DISTINCT pv.size, ',') as sizes,
                '' as colors
            FROM products p
            LEFT JOIN product_variants pv ON p.id = pv.product_id
            WHERE p.is_visible = TRUE AND p.is_deleted = FALSE
            GROUP BY p.id
            ORDER BY p.id DESC
            LIMIT %s OFFSET %s
        """, (per_page, offset))

        product_rows = cursor.fetchall()

        for product in product_rows:
            # Clean photo annotations, then de-duplicate while preserving order
            # (dict keys keep first-insertion order).
            raw_colors = [_clean_text(color.strip()) for color in product['colors'].split(',')] if product['colors'] else []
            raw_sizes = [_clean_text(size.strip()) for size in product['sizes'].split(',')] if product['sizes'] else []

            product['colors'] = list(dict.fromkeys(c for c in raw_colors if c))
            product['sizes'] = sorted(dict.fromkeys(s for s in raw_sizes if s), key=_size_sort_key)

            # Get all variants for this product (for image gallery navigation)
            cursor.execute("""
                SELECT pv.*
                FROM product_variants pv
                WHERE pv.product_id = %s
                ORDER BY
                         CASE WHEN pv.image_path IS NOT NULL THEN 0 ELSE 1 END,
                         pv.id ASC
            """, (product['id'],))
            product['variants'] = cursor.fetchall()

        return render_template('outlet.html',
                               products=product_rows,
                               page=page,
                               total_pages=total_pages,
                               total_products=total_products)
    finally:
        cursor.close()
        # Always release the connection; only closing the cursor leaked one
        # connection per page view.
        conn.close()

# --- Product Outlet Search API ---
@app.route('/api/outlet_search')
def outlet_search():
    """API endpoint to search outlet products by name or lot details.

    Query params:
        q: case-insensitive substring matched against product_name and
           lot_details. '%', '_' and '\\' in q are matched literally. An
           empty q returns the newest products.

    Returns JSON ``{"products": [...]}`` with up to 100 visible, non-deleted
    products, highest id first. Each product is extended with:

    * cleaned, de-duplicated ``sizes`` / ``colors`` lists;
    * a ``variants`` list of ``{id, color, size, image_path}`` dicts;
    * ``image_path`` / ``image_url`` taken from the first variant with an
      image, falling back to the static logo.

    ``color`` is always '' because variant colors are no longer selected;
    reading it used to raise KeyError and fail every search.
    """
    query = request.args.get('q', '').strip().lower()
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)
    try:
        sql = """
            SELECT p.*, STRING_AGG(DISTINCT pv.size, ',') as sizes, '' as colors
            FROM products p
            LEFT JOIN product_variants pv ON p.id = pv.product_id
            WHERE p.is_visible = TRUE AND p.is_deleted = FALSE
        """
        params = []
        if query:
            # Escape LIKE metacharacters so user input matches literally
            # (backslash is PostgreSQL's default LIKE escape character).
            escaped = query.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
            pattern = f'%{escaped}%'
            sql += " AND (LOWER(p.product_name) LIKE %s OR LOWER(p.lot_details) LIKE %s)"
            params.extend([pattern, pattern])
        sql += " GROUP BY p.id ORDER BY p.id DESC LIMIT 100"
        cursor.execute(sql, params)
        products = cursor.fetchall()

        for product in products:
            # Clean photo annotations, then de-duplicate while preserving order.
            raw_colors = [_clean_text(color.strip()) for color in product['colors'].split(',')] if product['colors'] else []
            raw_sizes = [_clean_text(size.strip()) for size in product['sizes'].split(',')] if product['sizes'] else []

            product['colors'] = list(dict.fromkeys(c for c in raw_colors if c))
            product['sizes'] = sorted(dict.fromkeys(s for s in raw_sizes if s), key=_size_sort_key)

            # Fetch all variants with images for gallery support
            cursor.execute("""
                SELECT pv.id, pv.size, pv.image_path
                FROM product_variants pv
                WHERE pv.product_id = %s
                ORDER BY
                         CASE WHEN pv.image_path IS NOT NULL THEN 0 ELSE 1 END,
                         pv.id ASC
            """, (product['id'],))
            variants = cursor.fetchall()

            # Plain dicts for JSON serialization. Color is not selected above,
            # so it is emitted as '' (matching the product-level '' colors).
            product['variants'] = [
                {
                    'id': v['id'],
                    'color': '',
                    'size': _clean_text(v['size']),
                    'image_path': v['image_path'],
                }
                for v in variants
            ]

            # Set main image from first variant with image
            main_variant = next((v for v in variants if v['image_path']), None)
            if main_variant:
                img_path = main_variant['image_path']
                product['image_path'] = img_path
                product['image_url'] = img_path if img_path.startswith('http') else url_for('static', filename=img_path)
            else:
                product['image_path'] = None
                product['image_url'] = url_for('static', filename='images/logo.png')

        return {"products": products}
    finally:
        cursor.close()
        conn.close()

@app.route('/api/get_unique_lots')
@login_required
def get_unique_lots():
    """Return the non-empty lot_details values for a dropdown (admin only).

    Success JSON::

        {"success": true,
         "lots": [{"lot_details": ..., "product_count": ...,
                   "product_names": "Name A, Name B"}, ...]}

    Soft-deleted products are ignored and lots are ordered by lot_details.
    Any database error yields ``{"success": false, "message": <text>}``
    with HTTP 500.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)
    try:
        cursor.execute("""
            SELECT DISTINCT p.lot_details,
                   COUNT(p.id) as product_count,
                   STRING_AGG(DISTINCT p.product_name, ', ') as product_names
            FROM products p
            WHERE p.is_deleted = FALSE AND p.lot_details IS NOT NULL AND p.lot_details != ''
            GROUP BY p.lot_details
            ORDER BY p.lot_details
        """)
        exported_fields = ('lot_details', 'product_count', 'product_names')
        lot_entries = [
            {field: row[field] for field in exported_fields}
            for row in cursor.fetchall()
        ]
        return jsonify({'success': True, 'lots': lot_entries})
    except Exception as e:
        return jsonify({'success': False, 'message': str(e)}), 500
    finally:
        cursor.close()
        conn.close()

@app.route('/lot_summary/<lot_details>')
@login_required
def lot_summary(lot_details):
    """Return the inventory summary for one lot_details value (admin only).

    Only non-deleted products whose lot_details equals the URL segment are
    included. Variant stock and sales are each pre-aggregated per product
    before joining, so products with many variants or sales are not
    double-counted. The aggregates are restricted to this lot's products.

    Success JSON (200)::

        {"success": true,
         "lot_total": {lot_details, total_products, total_remaining_qty,
                       total_sold_qty, total_arrived, total_sales_value,
                       last_sale_date ('YYYY-MM-DD' or null)},
         "products": [{id, product_name, category, colors, sizes,
                       remaining_qty, sold_qty, total_arrived, sales_value,
                       selling_price, buying_price}, ...]}

    ``total_arrived`` is remaining + sold quantity. selling_price and
    buying_price are the maximum across the product's variants. Returns 404
    when the lot has no matching products, and 500 with the error text on any
    other failure.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # lot_products is the single definition of "products in this lot". The
        # aggregate CTEs are limited to it instead of scanning every product.
        cursor.execute("""
            WITH lot_products AS (
                SELECT id, lot_details
                FROM products
                WHERE lot_details = %s AND is_deleted = FALSE
            ),
            variant_totals AS (
                SELECT pv.product_id, COALESCE(SUM(pv.quantity), 0) AS remaining_qty
                FROM product_variants pv
                WHERE pv.product_id IN (SELECT id FROM lot_products)
                GROUP BY pv.product_id
            ),
            sales_totals AS (
                SELECT pv.product_id,
                       COALESCE(SUM(s.quantity), 0) AS sold_qty,
                       COALESCE(SUM(s.total_price), 0) AS sales_value,
                       MAX(s.sale_date) AS last_sale_date
                FROM sales s
                JOIN product_variants pv ON s.product_variant_id = pv.id
                WHERE pv.product_id IN (SELECT id FROM lot_products)
                GROUP BY pv.product_id
            )
            SELECT lp.lot_details,
                   COUNT(lp.id) AS total_products,
                   COALESCE(SUM(vt.remaining_qty), 0) AS total_remaining_qty,
                   COALESCE(SUM(st.sold_qty), 0) AS total_sold_qty,
                   COALESCE(SUM(st.sales_value), 0) AS total_sales_value,
                   MAX(st.last_sale_date) AS last_sale_date
            FROM lot_products lp
            LEFT JOIN variant_totals vt ON vt.product_id = lp.id
            LEFT JOIN sales_totals st ON st.product_id = lp.id
            GROUP BY lp.lot_details
        """, (lot_details,))
        lot_total = cursor.fetchone()

        # GROUP BY over zero matching products yields no row at all.
        if not lot_total:
            return jsonify({'success': False, 'message': 'Lot not found'}), 404

        total_arrived = float(lot_total['total_remaining_qty'] or 0) + float(lot_total['total_sold_qty'] or 0)

        # Product-wise breakdown, using the same per-product pre-aggregation.
        cursor.execute("""
            WITH lot_products AS (
                SELECT id, product_name, category
                FROM products
                WHERE lot_details = %s AND is_deleted = FALSE
            ),
            variant_totals AS (
                SELECT pv.product_id,
                       COALESCE(SUM(pv.quantity), 0) AS remaining_qty,
                       MAX(pv.selling_price) AS selling_price,
                       MAX(pv.buying_price) AS buying_price,
                       STRING_AGG(DISTINCT pv.size, ', ') AS sizes,
                       '' AS colors
                FROM product_variants pv
                WHERE pv.product_id IN (SELECT id FROM lot_products)
                GROUP BY pv.product_id
            ),
            sales_totals AS (
                SELECT pv.product_id,
                       COALESCE(SUM(s.quantity), 0) AS sold_qty,
                       COALESCE(SUM(s.total_price), 0) AS sales_value
                FROM sales s
                JOIN product_variants pv ON s.product_variant_id = pv.id
                WHERE pv.product_id IN (SELECT id FROM lot_products)
                GROUP BY pv.product_id
            )
            SELECT lp.id, lp.product_name, lp.category,
                   COALESCE(vt.remaining_qty, 0) AS remaining_qty,
                   COALESCE(st.sold_qty, 0) AS sold_qty,
                   COALESCE(st.sales_value, 0) AS sales_value,
                   COALESCE(vt.selling_price, 0) AS selling_price,
                   COALESCE(vt.buying_price, 0) AS buying_price,
                   COALESCE(vt.colors, '') AS colors,
                   COALESCE(vt.sizes, '') AS sizes
            FROM lot_products lp
            LEFT JOIN variant_totals vt ON vt.product_id = lp.id
            LEFT JOIN sales_totals st ON st.product_id = lp.id
            ORDER BY lp.product_name
        """, (lot_details,))
        product_rows = cursor.fetchall()

        return jsonify({
            'success': True,
            'lot_total': {
                'lot_details': lot_total['lot_details'],
                'total_products': lot_total['total_products'],
                'total_remaining_qty': float(lot_total['total_remaining_qty'] or 0),
                'total_sold_qty': float(lot_total['total_sold_qty'] or 0),
                'total_arrived': total_arrived,
                'total_sales_value': float(lot_total['total_sales_value'] or 0),
                'last_sale_date': lot_total['last_sale_date'].strftime('%Y-%m-%d') if lot_total['last_sale_date'] else None
            },
            'products': [
                {
                    'id': p['id'],
                    'product_name': p['product_name'],
                    'category': p['category'],
                    'colors': p['colors'],
                    'sizes': p['sizes'],
                    'remaining_qty': float(p['remaining_qty'] or 0),
                    'sold_qty': float(p['sold_qty'] or 0),
                    'total_arrived': float(p['remaining_qty'] or 0) + float(p['sold_qty'] or 0),
                    'sales_value': float(p['sales_value'] or 0),
                    'selling_price': float(p['selling_price'] or 0),
                    'buying_price': float(p['buying_price'] or 0)
                }
                for p in product_rows
            ]
        })
    except Exception as e:
        return jsonify({'success': False, 'message': str(e)}), 500
    finally:
        cursor.close()
        conn.close()

# Authentication routes
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login page for admins and salesmen.

    POST checks the form's ``username``/``password`` in two steps:

    1. Against the bcrypt hash in the ``users`` table. If a row with a
       non-empty hash exists for the username, that hash is authoritative.
    2. Only when no such row exists (or the database could not be queried),
       against config.ADMIN_USERNAME / config.ADMIN_PASSWORD_HASH.

    This mirrors change_password. Without the guard in step 2, the
    environment password kept working after /change_password had stored a
    new hash for that admin.

    On success the session gets loggedin/id/username/user_type. Salesmen are
    redirected to the products page, everyone else to the dashboard. On
    failure an error is flashed and the form is re-rendered.
    """
    if request.method == 'POST':
        username = request.form.get('username', '').strip()
        password = request.form.get('password', '').encode('utf-8')

        conn = None
        authenticated = False
        # True once the users table holds a password hash for this username.
        user_from_db = False

        try:
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            cursor.execute("SELECT * FROM users WHERE username = %s", (username,))
            user = cursor.fetchone()

            if user and user['password_hash']:
                user_from_db = True
                if bcrypt.checkpw(password, user['password_hash'].encode('utf-8')):
                    authenticated = True
                    session['loggedin'] = True
                    session['id'] = user['id']
                    session['username'] = user['username']
                    session['user_type'] = user.get('user_type', 'admin')
        except Exception as e:
            print(f"Database authentication error: {e}", file=sys.stderr)
        finally:
            if conn:
                conn.close()

        # Fall back to environment credentials only when the database holds no
        # password for this username.
        if (not authenticated and not user_from_db
                and username == config.ADMIN_USERNAME and config.ADMIN_PASSWORD_HASH):
            try:
                if bcrypt.checkpw(password, config.ADMIN_PASSWORD_HASH.encode('utf-8')):
                    authenticated = True
                    session['loggedin'] = True
                    session['id'] = 1
                    session['username'] = username
                    session['user_type'] = 'admin'
            except Exception as e:
                print(f"Environment auth error: {e}", file=sys.stderr)

        if authenticated:
            flash('Login successful!', 'success')
            # Redirect based on user type
            if session.get('user_type') == 'salesman':
                return redirect(url_for('products'))
            return redirect(url_for('dashboard'))
        flash('Invalid username or password!', 'danger')

    return render_template('login.html')

@app.route('/change_password', methods=['GET', 'POST'])
def change_password():
    """Change a password by proving the current one (reachable from login).

    Non-POST requests just render change_password.html. A POST reads
    username, current_password, new_password and confirm_password from the
    form, then:

    * re-renders the form with a 'danger' flash if any field is empty, the
      new passwords differ, or the new password is shorter than 6 characters;
    * verifies current_password against the users-table bcrypt hash, or,
      only when that username has no stored hash, against
      config.ADMIN_PASSWORD_HASH for config.ADMIN_USERNAME;
    * stores a fresh bcrypt hash, updating the existing users row or
      inserting a new 'admin' row for an environment-only account, then
      redirects to the login page.

    Errors are logged, the transaction is rolled back and a generic flash is
    shown. The connection is always closed.
    """
    form_page = 'change_password.html'
    if request.method != 'POST':
        return render_template(form_page)

    form = request.form
    username = form.get('username', '').strip()
    current_password = form.get('current_password', '').encode('utf-8')
    new_password = form.get('new_password', '')
    confirm_password = form.get('confirm_password', '')

    validation_error = None
    if not all((username, current_password, new_password, confirm_password)):
        validation_error = 'All fields are required!'
    elif new_password != confirm_password:
        validation_error = 'New passwords do not match!'
    elif len(new_password) < 6:
        validation_error = 'Password must be at least 6 characters long!'
    if validation_error:
        flash(validation_error, 'danger')
        return render_template(form_page)

    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        cursor.execute("SELECT * FROM users WHERE username = %s", (username,))
        user = cursor.fetchone()

        # Pick the hash to verify against: the DB one wins; the environment
        # admin hash is consulted only when the DB has none for this user.
        if user and user['password_hash']:
            stored_hash = user['password_hash']
        elif username == config.ADMIN_USERNAME and config.ADMIN_PASSWORD_HASH:
            stored_hash = config.ADMIN_PASSWORD_HASH
        else:
            stored_hash = None

        if stored_hash is None or not bcrypt.checkpw(current_password, stored_hash.encode('utf-8')):
            flash('Invalid username or current password!', 'danger')
            return render_template(form_page)

        new_password_hash = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')

        if user:
            cursor.execute(
                "UPDATE users SET password_hash = %s, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
                (new_password_hash, user['id'])
            )
        else:
            # Environment-only admin: persist it as a real users row.
            cursor.execute(
                "INSERT INTO users (username, password_hash, user_type) VALUES (%s, %s, %s)",
                (username, new_password_hash, 'admin')
            )

        conn.commit()
        flash('Password changed successfully! Please login with your new password.', 'success')
        return redirect(url_for('login'))
    except Exception as e:
        print(f"Change password error: {e}", file=sys.stderr)
        if conn:
            conn.rollback()
        flash('An error occurred. Please try again.', 'danger')
        return render_template(form_page)
    finally:
        if conn:
            conn.close()

@app.route('/logout')
def logout():
    """Clear the entire session, flash a confirmation and return to the outlet."""
    session.clear()
    flash('You have been logged out successfully.', 'info')
    destination = url_for('product_outlet')
    return redirect(destination)

@app.route('/api/create-salesman', methods=['POST'])
@login_required
def create_salesman():
    """Admin-only endpoint that provisions the fixed 'salesman1' account.

    Returns a (JSON body, status) pair:
        403 if the session user is not an admin;
        200 with status 'exists' if salesman1 is already present;
        201 with status 'created' after inserting it;
        500 with the error text on failure (the transaction is rolled back).
    """
    # Defensive re-check; login_required already limits this route to admins.
    if session.get('user_type') != 'admin':
        return {'error': 'Only admins can create accounts'}, 403

    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        cursor.execute("SELECT id FROM users WHERE username = %s", ('salesman1',))
        if cursor.fetchone():
            return {'status': 'exists', 'message': 'salesman1 account already exists'}, 200

        # NOTE(review): the initial password is hard-coded and committed to
        # source control; consider reading it from config/environment and
        # forcing a change on first login.
        initial_password = 'user1shtrading'
        password_hash = bcrypt.hashpw(initial_password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
        cursor.execute(
            "INSERT INTO users (username, password_hash, user_type) VALUES (%s, %s, %s)",
            ('salesman1', password_hash, 'salesman')
        )
        conn.commit()

        return {'status': 'created', 'message': 'Salesman account created successfully'}, 201
    except Exception as e:
        print(f"Error creating salesman: {e}", file=sys.stderr)
        if conn:
            conn.rollback()
        return {'error': str(e)}, 500
    finally:
        if conn:
            conn.close()

# Admin routes (protected)
@app.route('/dashboard')
@login_required
def dashboard():
    """Render the admin dashboard with headline figures and the latest sales.

    Template context:
        total_products: number of rows in products.
        total_sales: sum of sales.total_price, or 0 when there are no sales.
        active_shipments: shipments whose status is anything but 'arrived'.
        recent_sales: the 5 sales with the highest id, each with its
            product_name and variant size.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    def first_value(column):
        # Read one named column from the single row the last query produced.
        return cursor.fetchone()[column]

    try:
        cursor.execute("SELECT COUNT(*) as total FROM products")
        total_products = first_value('total')

        cursor.execute("SELECT SUM(total_price) as total_sales FROM sales")
        total_sales = first_value('total_sales') or 0

        # "Active" means any shipment status other than 'arrived'.
        cursor.execute("""
            SELECT COUNT(*) as active_shipments
            FROM shipments
            WHERE status != 'arrived'
        """)
        active_shipments = first_value('active_shipments')

        # Highest id first (newest, assuming ids increase), capped at 5 rows.
        cursor.execute("""
            SELECT s.*, p.product_name, pv.size
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            JOIN products p ON pv.product_id = p.id
            ORDER BY s.id DESC LIMIT 5
        """)
        recent_sales = cursor.fetchall()

        context = {
            'total_products': total_products,
            'total_sales': total_sales,
            'active_shipments': active_shipments,
            'recent_sales': recent_sales,
        }
        return render_template('dashboard.html', **context)
    finally:
        cursor.close()
        conn.close()

@app.route('/products')
@sales_access_required
def products():
    """Render the product catalogue with every non-soft-deleted product.

    All products are loaded at once (filtering happens client-side). Each
    product row gets:

    * ``sizes``: a de-duplicated list with "(Photo N)" annotations stripped,
      sorted by ``SIZE_ORDER``;
    * ``colors``: a de-duplicated list (the query currently always yields '');
    * ``variants``: the product's variant rows plus ``sales_count``, those
      with an image first and then by id.

    ``total_products`` is the number of products actually rendered.
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute("""
                SELECT 
                    p.*,
                    COUNT(pv.id) as variant_count,
                    COALESCE(SUM(pv.quantity), 0) as total_stock,
                    COALESCE(SUM(pv.selling_price * COALESCE(pv.quantity, 0)) / NULLIF(SUM(pv.quantity), 0), 0) as avg_price,
                    (
                        SELECT COUNT(*) 
                        FROM sales s 
                        JOIN product_variants pv2 ON s.product_variant_id = pv2.id 
                        WHERE pv2.product_id = p.id
                    ) as total_sales,
                    STRING_AGG(DISTINCT pv.size, ',') as sizes,
                    '' as colors
                FROM products p
                LEFT JOIN product_variants pv ON p.id = pv.product_id
                WHERE p.is_deleted = FALSE
                GROUP BY p.id
                ORDER BY p.id DESC
            """)
            product_rows = cursor.fetchall()

            for product in product_rows:
                raw_colors = [_clean_text(c.strip()) for c in product['colors'].split(',')] if product['colors'] else []
                raw_sizes = [_clean_text(s.strip()) for s in product['sizes'].split(',')] if product['sizes'] else []
                # dict.fromkeys de-duplicates while preserving first-seen order.
                product['colors'] = list(dict.fromkeys(c for c in raw_colors if c))
                product['sizes'] = sorted(dict.fromkeys(s for s in raw_sizes if s), key=_size_sort_key)
                product['variants'] = []

            # Fetch every product's variants in ONE query instead of one query
            # per product; rows come back grouped by product in display order.
            if product_rows:
                rows_by_id = {product['id']: product for product in product_rows}
                cursor.execute("""
                    SELECT pv.*,
                           (SELECT COUNT(*) FROM sales s WHERE s.product_variant_id = pv.id) as sales_count
                    FROM product_variants pv
                    WHERE pv.product_id = ANY(%s)
                    ORDER BY pv.product_id,
                             CASE WHEN pv.image_path IS NOT NULL THEN 0 ELSE 1 END,
                             pv.id ASC
                """, (list(rows_by_id),))
                for variant in cursor.fetchall():
                    rows_by_id[variant['product_id']]['variants'].append(variant)

        # Count what is displayed; soft-deleted products are excluded above.
        return render_template('products.html',
                               products=product_rows,
                               total_products=len(product_rows))
    finally:
        conn.close()

@app.route('/get_product_details/<int:product_id>')
@sales_access_required
def get_product_details(product_id):
    """Return one non-deleted product and its variants as JSON.

    Success: ``{'success': True, 'product': {...}}``, where ``product``
    carries ``total_stock``, a quantity-weighted ``avg_price``,
    ``total_sales`` and a ``variants`` list (each with ``sales_count``).
    Missing product or any error: ``{'success': False, 'message': ...}``.
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute("""
                SELECT 
                    p.*,
                    COALESCE(SUM(pv.quantity), 0) as total_stock,
                    COALESCE(SUM(pv.selling_price * COALESCE(pv.quantity, 0)) / NULLIF(SUM(pv.quantity), 0), 0) as avg_price,
                    (
                        SELECT COUNT(*) 
                        FROM sales s 
                        JOIN product_variants pv2 ON s.product_variant_id = pv2.id 
                        WHERE pv2.product_id = p.id
                    ) as total_sales
                FROM products p
                LEFT JOIN product_variants pv ON p.id = pv.product_id
                WHERE p.id = %s AND p.is_deleted = FALSE
                GROUP BY p.id
            """, (product_id,))
            product = cursor.fetchone()

            if not product:
                return jsonify({'success': False, 'message': 'Product not found'})

            # Variants with images first, then by the preferred size order
            # (unknown sizes sort last), then alphabetically by size.
            cursor.execute("""
                SELECT pv.*,
                       (SELECT COUNT(*) FROM sales s WHERE s.product_variant_id = pv.id) as sales_count
                FROM product_variants pv
                WHERE pv.product_id = %s
                ORDER BY 
                    CASE WHEN pv.image_path IS NOT NULL THEN 0 ELSE 1 END,
                    array_position(ARRAY['XS','S','M','L','XL','XXL','28','30','32','34','36']::text[], pv.size),
                    pv.size
            """, (product_id,))
            product['variants'] = cursor.fetchall()

        return jsonify({'success': True, 'product': product})
    except Exception as e:
        print(f"Error fetching product details: {e}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(e)})
    finally:
        # The connection used to leak on every call.
        conn.close()

@app.route('/add_product', methods=['POST'])
@sales_access_required
def add_product():
    """Create a wholesale product from the submitted form.

    A product row is inserted, followed by one "master" variant that holds
    the quantity and prices and the first uploaded image. Each extra image
    becomes its own variant with quantity 0, so stock is not double counted.
    Blank numeric fields count as 0. Unparsable numbers are reported with a
    flash message instead of failing the request. The handler always
    redirects to the products page.
    """
    product_name = request.form['product_name']
    lot_details = request.form['lot_details']  # Lot Details (can be repeated)
    category = request.form.get('category', '')
    description = request.form.get('description', '')

    # Simplified wholesale inputs
    colors_text = request.form.get('colors_text', '')
    sizes_text = request.form.get('sizes_text', '')
    try:
        # `or 0` turns a submitted-but-empty field into 0 instead of int('').
        quantity = int(request.form.get('quantity') or 0)
        buying_price = float(request.form.get('buying_price') or 0)
        selling_price = float(request.form.get('selling_price') or 0)
    except ValueError:
        flash('Quantity and prices must be valid numbers.', 'danger')
        return redirect(url_for('products'))

    # Upload only after validation, so bad input does not leave orphaned images.
    image_paths = []
    for image_file in request.files.getlist('images'):
        if image_file and image_file.filename:
            uploaded_path = upload_image(image_file)
            if uploaded_path:
                image_paths.append(uploaded_path)

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            INSERT INTO products (product_name, lot_details, category, description, is_visible, is_deleted)
            VALUES (%s, %s, %s, %s, TRUE, FALSE)
            RETURNING id
        """, (product_name, lot_details, category, description))
        product_id = cursor.fetchone()[0]

        insert_variant_sql = """
            INSERT INTO product_variants 
            (product_id, size, color, quantity, buying_price, selling_price, image_path)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        # Master variant: real quantity plus the first image (if any).
        primary_image = image_paths[0] if image_paths else None
        cursor.execute(insert_variant_sql, (product_id, sizes_text, colors_text, quantity,
                                            buying_price, selling_price, primary_image))

        # Extra images become zero-quantity variants used only for the gallery.
        for image_path in image_paths[1:]:
            cursor.execute(insert_variant_sql, (product_id, sizes_text, colors_text, 0,
                                                buying_price, selling_price, image_path))

        conn.commit()
        flash('Wholesale product added successfully!', 'success')
    except Exception as e:
        conn.rollback()
        print(f"Error adding product: {str(e)}", file=sys.stderr)
        flash(f'Error adding product: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('products'))




@app.route('/add_variant', methods=['POST'])
@sales_access_required
def add_variant():
    """Add a size/color variant to an existing product.

    Reads ``product_id``, ``size``, ``color``, ``quantity``, ``buying_price``
    and ``selling_price`` from the form; blank numeric fields become 0. An
    optional ``image`` file with an allowed extension is saved under
    ``UPLOAD_FOLDER`` with a timestamp prefix. A variant that duplicates an
    existing product/size/color combination is rejected. The handler always
    redirects to the products page with a flash message.
    """
    product_id = request.form['product_id']
    size = request.form['size']
    color = request.form['color']
    quantity = request.form['quantity'] or 0
    buying_price = request.form['buying_price'] or 0
    selling_price = request.form['selling_price'] or 0

    # NOTE(review): unlike add_product this saves locally instead of calling
    # upload_image(), and it saves before the duplicate check below, so a
    # rejected or failed insert leaves an orphaned file on disk.
    image_path = None
    if 'image' in request.files:
        file = request.files['image']
        if file and file.filename != '' and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            # A timestamp prefix keeps repeated uploads of the same name apart.
            unique_filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{filename}"
            file.save(os.path.join(app.config['UPLOAD_FOLDER'], unique_filename))
            image_path = f"uploads/{unique_filename}"

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            SELECT id FROM product_variants 
            WHERE product_id = %s AND size = %s AND color = %s
        """, (product_id, size, color))

        if cursor.fetchone():
            flash('A variant with this size and color already exists!', 'danger')
        else:
            cursor.execute("""
                INSERT INTO product_variants 
                (product_id, size, color, quantity, buying_price, selling_price, image_path)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
            """, (product_id, size, color, quantity, buying_price, selling_price, image_path))
            conn.commit()
            flash('Product variant added successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error adding variant: {str(e)}', 'danger')
    finally:
        cursor.close()
        # The connection used to leak on every call.
        conn.close()

    return redirect(url_for('products'))

@app.route('/edit_product/<int:product_id>', methods=['POST'])
@sales_access_required
def edit_product(product_id):
    """Update a wholesale product and its master variant from the form.

    The product's name, lot details, category and description are updated.
    If ``variant_id`` names a variant that belongs to this product, that
    master variant's size/color/quantity/prices are updated and its image is
    left unchanged. Otherwise a new master variant is created using the
    first uploaded image. Remaining uploaded images are added as
    zero-quantity variants. Blank numeric fields count as 0; unparsable
    numbers are reported with a flash message. The handler always redirects
    to the products page.
    """
    product_name = request.form['product_name']
    lot_details = request.form['lot_details']  # Lot Details (can be repeated)
    category = request.form.get('category', '')
    description = request.form.get('description', '')

    # Simplified wholesale inputs
    colors_text = request.form.get('colors_text', '')
    sizes_text = request.form.get('sizes_text', '')
    variant_id = request.form.get('variant_id', '')
    try:
        quantity = int(request.form.get('quantity') or 0)
        buying_price = float(request.form.get('buying_price') or 0)
        selling_price = float(request.form.get('selling_price') or 0)
    except ValueError:
        flash('Quantity and prices must be valid numbers.', 'danger')
        return redirect(url_for('products'))

    new_image_paths = []
    for image_file in request.files.getlist('images'):
        if image_file and image_file.filename:
            uploaded_path = upload_image(image_file)
            if uploaded_path:
                new_image_paths.append(uploaded_path)

    insert_variant_sql = """
        INSERT INTO product_variants
        (product_id, size, color, quantity, buying_price, selling_price, image_path)
        VALUES (%s, %s, %s, %s, %s, %s, %s)
    """

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            UPDATE products 
            SET product_name = %s, lot_details = %s, category = %s, description = %s
            WHERE id = %s
        """, (product_name, lot_details, category, description, product_id))

        master_updated = False
        if variant_id.isdigit():
            # Scope to this product so a tampered or stale variant_id cannot
            # edit another product's variant. The master keeps its image.
            cursor.execute("""
                UPDATE product_variants
                SET size = %s, color = %s, quantity = %s,
                    buying_price = %s, selling_price = %s
                WHERE id = %s AND product_id = %s
            """, (sizes_text, colors_text, quantity, buying_price, selling_price,
                  int(variant_id), product_id))
            master_updated = cursor.rowcount > 0

        if master_updated:
            # Existing images stay; every new upload becomes a gallery variant.
            gallery_images = new_image_paths
        else:
            # No usable master variant: create one with the first new image.
            primary_image = new_image_paths[0] if new_image_paths else None
            cursor.execute(insert_variant_sql, (product_id, sizes_text, colors_text, quantity,
                                                buying_price, selling_price, primary_image))
            gallery_images = new_image_paths[1:]

        for image_path in gallery_images:
            cursor.execute(insert_variant_sql, (product_id, sizes_text, colors_text, 0,
                                                buying_price, selling_price, image_path))

        conn.commit()
        flash('Product updated successfully!', 'success')
    except Exception as e:
        conn.rollback()
        print(f"Error updating product: {str(e)}", file=sys.stderr)
        flash(f'Error updating product: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('products'))

@app.route('/bulk_delete_products', methods=['POST'])
@sales_access_required
def bulk_delete_products():
    """Bulk delete products, skipping those with sales records.

    Expects a JSON body ``{"product_ids": [...]}``. Products that have any
    sale are skipped. Products without sales are hard-deleted (the comment
    in the original says variants go with them via cascade). Every product
    is processed inside its own savepoint, so one failure does not abort the
    rest of the batch.

    Returns JSON ``{'success': True, 'message', 'deleted', 'skipped'}``,
    where ``skipped`` counts both products with sales and failed deletions,
    or ``{'success': False, 'message'}`` on bad input or error.
    """
    conn = None
    try:
        data = request.get_json(silent=True) or {}
        product_ids = data.get('product_ids', [])

        if not product_ids:
            return jsonify({'success': False, 'message': 'No products selected'})

        conn = get_db_connection()
        deleted_count = 0
        sales_skipped = 0
        failed_count = 0

        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            for product_id in product_ids:
                # In PostgreSQL one failed statement aborts the whole
                # transaction; the savepoint confines the failure to this item.
                cursor.execute("SAVEPOINT bulk_delete_item")
                try:
                    cursor.execute("""
                        SELECT COUNT(*) as sales_count
                        FROM sales s 
                        JOIN product_variants pv ON s.product_variant_id = pv.id 
                        WHERE pv.product_id = %s
                    """, (product_id,))
                    result = cursor.fetchone()

                    if result and result['sales_count'] > 0:
                        sales_skipped += 1
                    else:
                        cursor.execute("DELETE FROM products WHERE id = %s", (product_id,))
                        # rowcount is 0 for ids that do not exist; don't count those.
                        deleted_count += cursor.rowcount
                    cursor.execute("RELEASE SAVEPOINT bulk_delete_item")
                except Exception as e:
                    cursor.execute("ROLLBACK TO SAVEPOINT bulk_delete_item")
                    print(f"Error deleting product {product_id}: {str(e)}", file=sys.stderr)
                    failed_count += 1

        conn.commit()

        message = f'Successfully deleted {deleted_count} product(s).'
        if sales_skipped > 0:
            message += f' Skipped {sales_skipped} product(s) with sales records.'
        if failed_count > 0:
            message += f' Failed to delete {failed_count} product(s).'

        return jsonify({
            'success': True,
            'message': message,
            'deleted': deleted_count,
            'skipped': sales_skipped + failed_count
        })

    except Exception as e:
        if conn:
            conn.rollback()
        print(f"Bulk delete error: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': f'Error: {str(e)}'})
    finally:
        if conn:
            conn.close()

@app.route('/delete_product/<int:product_id>', methods=['POST'])
@sales_access_required
def delete_product(product_id):
    """Soft-delete a product by setting ``is_deleted = TRUE``.

    The row is kept so sales history stays intact. A flash message reports
    success, a missing product, or a database error. The handler always
    redirects to the products page.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("UPDATE products SET is_deleted = TRUE WHERE id = %s", (product_id,))
            found = cursor.rowcount > 0
        conn.commit()
        if found:
            flash('Product deleted successfully!', 'success')
        else:
            flash('Product not found.', 'danger')
    except Exception as e:
        conn.rollback()
        flash(f'Error deleting product: {str(e)}', 'danger')
    finally:
        # The connection used to leak on every call.
        conn.close()

    return redirect(url_for('products'))

@app.route('/delete_variant/<int:variant_id>', methods=['POST'])
@sales_access_required
def delete_variant(variant_id):
    """Hard-delete a product variant, and its product if it was the last one.

    Deletion is allowed even when sales reference the variant; a warning
    reports how many sales records still exist. Flash messages are issued
    only after the commit succeeds. The handler always redirects to the
    products page.
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute("SELECT product_id FROM product_variants WHERE id = %s", (variant_id,))
            variant = cursor.fetchone()
            if not variant:
                flash('Variant not found.', 'danger')
                return redirect(url_for('products'))

            cursor.execute("SELECT COUNT(*) as sales_count FROM sales WHERE product_variant_id = %s", (variant_id,))
            sales_count = cursor.fetchone()['sales_count']

            cursor.execute("DELETE FROM product_variants WHERE id = %s", (variant_id,))

            cursor.execute("SELECT COUNT(*) as variant_count FROM product_variants WHERE product_id = %s",
                           (variant['product_id'],))
            product_emptied = cursor.fetchone()['variant_count'] == 0
            if product_emptied:
                # NOTE(review): this is a hard delete, while delete_product
                # soft-deletes "to preserve history"; confirm which is intended.
                cursor.execute("DELETE FROM products WHERE id = %s", (variant['product_id'],))

        conn.commit()

        # Flash only after commit, so a failed delete never claims success.
        if sales_count > 0:
            flash(f'Variant deleted successfully! Note: {sales_count} sales records still exist for this variant.', 'warning')
        if product_emptied:
            flash('Product and last variant deleted successfully!', 'success')
        else:
            flash('Product variant deleted successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error deleting variant: {str(e)}', 'danger')
    finally:
        conn.close()

    return redirect(url_for('products'))

@app.route('/shipments')
@login_required
def shipments():
    """Render the shipments page: shipment list, status counters and suppliers.

    First ensures that the ``shipping_date`` and ``receiving_date`` columns
    exist; a failure here is logged and does not stop the page. The
    counters are:

    * active: status != 'arrived'
    * arrived: status = 'arrived'
    * in transit: status = 'shipped'
    * delayed: status = 'ordered' and receiving_date is before today
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            # NOTE(review): schema migration inside a request handler; consider
            # moving it to a one-off migration script.
            try:
                for column in ('shipping_date', 'receiving_date'):
                    cursor.execute("""
                        SELECT column_name FROM information_schema.columns
                        WHERE table_name = 'shipments' AND column_name = %s
                    """, (column,))
                    if not cursor.fetchone():
                        # `column` comes from the fixed tuple above, never from user input.
                        cursor.execute(f"ALTER TABLE shipments ADD COLUMN {column} DATE DEFAULT NULL")
                        conn.commit()
            except Exception as e:
                conn.rollback()
                print(f"Error adding columns: {e}", file=sys.stderr)

            cursor.execute("""
                SELECT s.*, 
                       c.company_name,
                       (SELECT STRING_AGG(DISTINCT si.product_name || ' (' || si.quantity || ')', ', ' ORDER BY si.product_name || ' (' || si.quantity || ')')
                        FROM shipment_items si 
                        WHERE si.shipment_id = s.id) as products_list,
                       (SELECT COUNT(*) FROM shipment_items si WHERE si.shipment_id = s.id) as item_count
                FROM shipments s 
                LEFT JOIN companies c ON s.company_id = c.id
                ORDER BY s.ordered_date DESC
            """)
            shipments_list = cursor.fetchall()

            # All four counters in a single scan of the shipments table.
            cursor.execute("""
                SELECT
                    COUNT(*) FILTER (WHERE status != 'arrived') AS active_count,
                    COUNT(*) FILTER (WHERE status = 'arrived') AS arrived_count,
                    COUNT(*) FILTER (WHERE status = 'shipped') AS in_transit_count,
                    COUNT(*) FILTER (WHERE status = 'ordered'
                                       AND receiving_date IS NOT NULL
                                       AND receiving_date < CURRENT_DATE) AS delayed_count
                FROM shipments
            """)
            stats = cursor.fetchone()

            # All companies, for the supplier section and the shipment forms.
            cursor.execute("""
                SELECT id, company_name, contact_person, email, phone, address, 
                       COALESCE(total_due, 0) as total_due,
                       COALESCE(advance_payment, 0) as advance_payment
                FROM companies 
                ORDER BY id DESC
            """)
            companies = cursor.fetchall()
    finally:
        # Previously leaked whenever any query raised.
        conn.close()

    return render_template('shipments.html',
                           shipments=shipments_list,
                           companies=companies,
                           shipments_active=stats['active_count'],
                           shipments_delivered=stats['arrived_count'],
                           shipments_in_transit=stats['in_transit_count'],
                           shipments_delayed=stats['delayed_count'])

@app.route('/add_shipment', methods=['POST'])
@login_required
def add_shipment():
    """Create a shipment with its line items and charge the supplier.

    ``shipping_cost`` is the sum of quantity * unit_price over the submitted
    ``product_name[]`` / ``quantity[]`` / ``unit_price[]`` rows; blank numbers
    count as 0. The supplier's ``advance_payment`` is used first and any
    remainder is added to ``total_due``. Malformed input is reported with a
    flash message. The handler always redirects to the shipments page.
    """
    shipment_number = request.form['shipment_number']
    company_id = request.form['company_id']
    ordered_date = request.form['ordered_date']
    shipping_date = request.form.get('shipping_date') or None
    receiving_date = request.form.get('receiving_date') or None
    status = request.form['status']

    product_names = request.form.getlist('product_name[]')
    quantities = request.form.getlist('quantity[]')
    unit_prices = request.form.getlist('unit_price[]')

    # Parse the line items once, before touching the database.
    if len(quantities) < len(product_names) or len(unit_prices) < len(product_names):
        flash('Error adding shipment: every product needs a quantity and unit price.', 'danger')
        return redirect(url_for('shipments'))
    try:
        line_items = [
            (name, float(qty or 0), float(price or 0))
            for name, qty, price in zip(product_names, quantities, unit_prices)
        ]
    except ValueError:
        flash('Error adding shipment: quantities and unit prices must be numbers.', 'danger')
        return redirect(url_for('shipments'))
    shipping_cost = sum(qty * price for _, qty, price in line_items)

    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute("""
                INSERT INTO shipments 
                (shipment_number, company_id, shipping_cost, ordered_date, shipping_date, receiving_date, status)
                VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING id
            """, (shipment_number, company_id, shipping_cost, ordered_date, shipping_date, receiving_date, status))
            shipment_id = cursor.fetchone()['id']

            for product_name, quantity, unit_price in line_items:
                cursor.execute("""
                    INSERT INTO shipment_items 
                    (shipment_id, product_name, quantity, unit_price, total_price)
                    VALUES (%s, %s, %s, %s, %s)
                """, (shipment_id, product_name, quantity, unit_price, quantity * unit_price))

            # FOR UPDATE locks the supplier row, so two concurrent shipments
            # cannot both consume the same advance (read-modify-write race).
            cursor.execute("""
                SELECT COALESCE(advance_payment, 0) as advance_payment 
                FROM companies 
                WHERE id = %s
                FOR UPDATE
            """, (company_id,))
            company = cursor.fetchone()
            advance = float(company['advance_payment']) if company else 0

            if advance > 0 and advance >= shipping_cost:
                # The advance covers the full shipment cost.
                remaining_advance = advance - shipping_cost
                cursor.execute("""
                    UPDATE companies 
                    SET advance_payment = %s
                    WHERE id = %s
                """, (remaining_advance, company_id))
                message = (f'Shipment added! Cost BDT {shipping_cost:,.2f} deducted from advance. '
                           f'Remaining advance: BDT {remaining_advance:,.2f}')
            elif advance > 0:
                # The advance covers part of the cost; the rest becomes due.
                remaining_due = shipping_cost - advance
                cursor.execute("""
                    UPDATE companies 
                    SET total_due = COALESCE(total_due, 0) + %s,
                        advance_payment = 0
                    WHERE id = %s
                """, (remaining_due, company_id))
                message = (f'Shipment added! BDT {advance:,.2f} advance used. '
                           f'Remaining due: BDT {remaining_due:,.2f}')
            else:
                cursor.execute("""
                    UPDATE companies 
                    SET total_due = COALESCE(total_due, 0) + %s
                    WHERE id = %s
                """, (shipping_cost, company_id))
                message = 'Shipment added successfully!'

        conn.commit()
        # Flash only after a successful commit, never alongside an error.
        flash(message, 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error adding shipment: {str(e)}', 'danger')
        print(f"Error in add_shipment: {str(e)}", file=sys.stderr)
    finally:
        conn.close()

    return redirect(url_for('shipments'))

@app.route('/edit_shipment/<int:shipment_id>', methods=['POST'])
@login_required
def edit_shipment(shipment_id):
    """Update a shipment, replace its line items, and re-balance supplier dues.

    The old ``shipping_cost`` is subtracted from the old supplier's
    ``total_due`` and the new cost is added to the new supplier's. Blank
    numbers count as 0. Malformed input or a missing shipment is reported
    with a flash message. The handler always redirects to the shipments
    page.
    """
    shipment_number = request.form['shipment_number']
    company_id = request.form['company_id']
    ordered_date = request.form['ordered_date']
    shipping_date = request.form.get('shipping_date') or None
    receiving_date = request.form.get('receiving_date') or None
    status = request.form['status']

    product_names = request.form.getlist('product_name[]')
    quantities = request.form.getlist('quantity[]')
    unit_prices = request.form.getlist('unit_price[]')

    # Parse the line items once, before touching the database.
    if len(quantities) < len(product_names) or len(unit_prices) < len(product_names):
        flash('Error updating shipment: every product needs a quantity and unit price.', 'danger')
        return redirect(url_for('shipments'))
    try:
        line_items = [
            (name, float(qty or 0), float(price or 0))
            for name, qty, price in zip(product_names, quantities, unit_prices)
        ]
    except ValueError:
        flash('Error updating shipment: quantities and unit prices must be numbers.', 'danger')
        return redirect(url_for('shipments'))
    shipping_cost = sum(qty * price for _, qty, price in line_items)

    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            # Lock the shipment so a concurrent edit/delete can't double-adjust dues.
            cursor.execute("SELECT shipping_cost, company_id FROM shipments WHERE id = %s FOR UPDATE",
                           (shipment_id,))
            old_shipment = cursor.fetchone()
            if not old_shipment:
                flash('Shipment not found!', 'danger')
                return redirect(url_for('shipments'))

            old_cost = float(old_shipment['shipping_cost'] or 0)
            old_company_id = old_shipment['company_id']

            cursor.execute("""
                UPDATE shipments 
                SET shipment_number = %s, company_id = %s, shipping_cost = %s, 
                    ordered_date = %s, shipping_date = %s, receiving_date = %s, status = %s
                WHERE id = %s
            """, (shipment_number, company_id, shipping_cost, ordered_date, shipping_date,
                  receiving_date, status, shipment_id))

            # Replace all line items with the submitted set.
            cursor.execute("DELETE FROM shipment_items WHERE shipment_id = %s", (shipment_id,))
            for product_name, quantity, unit_price in line_items:
                cursor.execute("""
                    INSERT INTO shipment_items 
                    (shipment_id, product_name, quantity, unit_price, total_price)
                    VALUES (%s, %s, %s, %s, %s)
                """, (shipment_id, product_name, quantity, unit_price, quantity * unit_price))

            # NOTE(review): unlike add_shipment, advance_payment is not
            # considered here, so a shipment originally paid from an advance
            # can push total_due negative. Confirm the intended accounting.
            if old_company_id:
                cursor.execute("""
                    UPDATE companies 
                    SET total_due = COALESCE(total_due, 0) - %s
                    WHERE id = %s
                """, (old_cost, old_company_id))

            cursor.execute("""
                UPDATE companies 
                SET total_due = COALESCE(total_due, 0) + %s
                WHERE id = %s
            """, (shipping_cost, company_id))

        conn.commit()
        flash('Shipment updated successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error updating shipment: {str(e)}', 'danger')
        print(f"Error in edit_shipment: {str(e)}", file=sys.stderr)
    finally:
        conn.close()

    return redirect(url_for('shipments'))

@app.route('/delete_shipment/<int:shipment_id>', methods=['POST'])
@login_required
def delete_shipment(shipment_id):
    """Delete one shipment and reduce its company's ``total_due`` by the shipment cost.

    The company's due amount is clamped so it never drops below zero. The
    outcome is reported via ``flash`` and the user is always redirected to the
    shipments page.

    Args:
        shipment_id: Primary key of the shipment to delete (from the URL).
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        # RETURNING gives us the cost/company of exactly the row that was
        # removed, so the due adjustment cannot drift from what was deleted.
        cursor.execute(
            "DELETE FROM shipments WHERE id = %s RETURNING shipping_cost, company_id",
            (shipment_id,),
        )
        row = cursor.fetchone()
        if row is None:
            conn.rollback()
            flash('Shipment not found!', 'danger')
        else:
            shipping_cost = float(row[0] or 0)
            company_id = row[1]

            # Reduce company's total_due by the shipment cost (never below zero)
            if company_id:
                cursor.execute(
                    """
                    UPDATE companies
                    SET total_due = GREATEST(COALESCE(total_due, 0) - %s, 0)
                    WHERE id = %s
                    """,
                    (shipping_cost, company_id),
                )
            conn.commit()
            flash('Shipment deleted successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error deleting shipment: {str(e)}', 'danger')
        print(f"Error in delete_shipment: {str(e)}", file=sys.stderr)
    finally:
        # Close both handles; the original closed only the cursor and leaked
        # the connection on every request.
        cursor.close()
        conn.close()

    return redirect(url_for('shipments'))

@app.route('/api/shipment/<int:shipment_id>')
@login_required
def get_shipment_details(shipment_id):
    """Get shipment details including all items"""
    # Response contract:
    #   200 -> {'success': True, 'shipment': {...}, 'items': [{...}, ...]}
    #          where the shipment row carries an extra ``company_name`` column
    #          (NULL when no company matches) and items are ordered by id.
    #   404 -> {'success': False, 'message': 'Shipment not found'}
    #   500 -> {'success': False, 'message': <error text>} on any failure.
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Get shipment info
        cur.execute("""
            SELECT s.*, c.company_name
            FROM shipments s
            LEFT JOIN companies c ON s.company_id = c.id
            WHERE s.id = %s
        """, (shipment_id,))
        header = cur.fetchone()
        if not header:
            return jsonify({'success': False, 'message': 'Shipment not found'}), 404

        # Get shipment items
        cur.execute("""
            SELECT product_name, quantity, unit_price, total_price
            FROM shipment_items
            WHERE shipment_id = %s
            ORDER BY id
        """, (shipment_id,))
        line_items = list(map(dict, cur.fetchall()))

        payload = {'success': True, 'shipment': dict(header), 'items': line_items}
        return jsonify(payload)
    except Exception as exc:
        print(f"Error fetching shipment details: {str(exc)}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(exc)}), 500
    finally:
        cur.close()
        conn.close()

@app.route('/bulk_delete_shipments', methods=['POST'])
@login_required
def bulk_delete_shipments():
    """Bulk delete shipments and reduce each company's ``total_due`` accordingly.

    Expects a JSON body ``{"shipment_ids": [1, 2, ...]}``. Each shipment is
    deleted inside its own SAVEPOINT, so a failure on one id is undone and
    logged without poisoning the rest of the transaction. Ids that do not
    exist are skipped and not counted.

    Returns:
        JSON ``{'success': True, 'message': ..., 'deleted': <count>}`` on
        success, or ``{'success': False, 'message': ...}`` on error / empty
        selection.
    """
    conn = None
    cursor = None
    try:
        # silent=True + ``or {}``: a missing/invalid/null JSON body is treated
        # as "nothing selected" instead of crashing on ``None.get``.
        data = request.get_json(silent=True) or {}
        shipment_ids = data.get('shipment_ids', [])

        if not shipment_ids:
            return jsonify({'success': False, 'message': 'No shipments selected'})

        conn = get_db_connection()
        cursor = conn.cursor()

        deleted_count = 0

        for shipment_id in shipment_ids:
            # In PostgreSQL a failed statement aborts the whole transaction;
            # without a savepoint every later statement (and the final COMMIT)
            # would fail or roll back while the endpoint still reported success.
            cursor.execute("SAVEPOINT bulk_delete_shipment")
            try:
                cursor.execute(
                    "DELETE FROM shipments WHERE id = %s RETURNING shipping_cost, company_id",
                    (shipment_id,),
                )
                row = cursor.fetchone()
                if row is not None:
                    shipping_cost = float(row[0] or 0)
                    company_id = row[1]
                    # Reduce company's total_due by the shipment cost (never below zero)
                    if company_id:
                        cursor.execute(
                            "UPDATE companies SET total_due = GREATEST(COALESCE(total_due, 0) - %s, 0) WHERE id = %s",
                            (shipping_cost, company_id),
                        )
                cursor.execute("RELEASE SAVEPOINT bulk_delete_shipment")
                if row is not None:
                    deleted_count += 1
            except Exception as e:
                cursor.execute("ROLLBACK TO SAVEPOINT bulk_delete_shipment")
                print(f"Error deleting shipment {shipment_id}: {str(e)}", file=sys.stderr)

        conn.commit()

        return jsonify({
            'success': True,
            'message': f'Successfully deleted {deleted_count} shipment(s).',
            'deleted': deleted_count
        })

    except Exception as e:
        if conn is not None:
            conn.rollback()
        print(f"Bulk delete error: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': f'Error: {str(e)}'})
    finally:
        # Always release DB resources, including on the error path.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@app.route('/sales')
@sales_access_required
def sales():
    """Render the sales page.

    Shows three things:
      * sales grouped by (buyer_name, sale_date, payment_method) as single
        transactions, each with a display number ``SALE-<first id, 6 digits>``;
      * product variants whose quantity is NULL or positive (for the add form);
      * customers whose summed ``due_amount`` is greater than zero.

    Before querying, the optional ``sales`` columns (phone_number, VAT and
    paid/due amounts) are added if missing. Amount expressions fall back to
    ``0`` for any column that still does not exist.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        def column_exists(table, col):
            """Return True if ``table.col`` is listed in information_schema."""
            cursor.execute(
                """
                SELECT column_name FROM information_schema.columns
                WHERE table_name = %s AND column_name = %s
                """,
                (table, col),
            )
            return cursor.fetchone() is not None

        def ensure_column(table, col, definition):
            """Add ``table.col`` with ``definition`` when missing; failures are rolled back and ignored."""
            # table/col/definition are hard-coded constants below, never user input.
            if not column_exists(table, col):
                try:
                    cursor.execute(f"ALTER TABLE {table} ADD COLUMN {col} {definition}")
                    conn.commit()
                except Exception:
                    conn.rollback()

        # Ensure phone_number, VAT and paid/due columns exist
        ensure_column('sales', 'phone_number', 'VARCHAR(50)')
        ensure_column('sales', 'vat_percentage', 'DECIMAL(5,2) DEFAULT 5.00')
        ensure_column('sales', 'vat_amount', 'DECIMAL(10,2) DEFAULT 0.00')
        ensure_column('sales', 'paid_amount', 'DECIMAL(10,2) DEFAULT 0.00')
        ensure_column('sales', 'due_amount', 'DECIMAL(10,2) DEFAULT 0.00')

        has_vat_amount = column_exists('sales', 'vat_amount')
        has_paid_amount = column_exists('sales', 'paid_amount')
        has_due_amount = column_exists('sales', 'due_amount')

        # SQL fragments chosen from fixed strings only (safe to interpolate).
        vat_expr = "COALESCE(s.vat_amount,0)" if has_vat_amount else "0"
        paid_expr = "COALESCE(s.paid_amount,0)" if has_paid_amount else "0"
        due_expr = "COALESCE(s.due_amount,0)" if has_due_amount else "0"

        # Get all sales grouped by buyer and date to show as single transactions
        cursor.execute(f"""
            SELECT 
                s.buyer_name,
                s.sale_date,
                s.payment_method,
                COALESCE(MAX(s.notes), '') AS notes,
                COALESCE(MAX(s.phone_number), '') AS phone_number,
                COUNT(s.id) AS item_count,
                SUM(s.total_price) AS total_amount,
                SUM({vat_expr}) AS vat_amount,
                SUM(s.total_price + {vat_expr}) AS total_with_vat,
                SUM({paid_expr}) AS paid_amount,
                SUM({due_expr}) AS due_amount,
                STRING_AGG(CONCAT(p.product_name, ' (', pv.size, ')'), '; ') AS product_names,
                STRING_AGG(s.id::text, ',') AS sale_ids,
                MIN(s.id) AS first_sale_id
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            JOIN products p ON pv.product_id = p.id
            GROUP BY s.buyer_name, s.sale_date, s.payment_method
            ORDER BY s.sale_date DESC, MIN(s.id) DESC
        """)
        grouped_sales = cursor.fetchall()

        # Generate sale numbers for display
        for sale in grouped_sales:
            sale['sale_number'] = f"SALE-{sale['first_sale_id']:06d}"

        cursor.execute("""
            SELECT pv.id,
                   p.product_name,
                   p.lot_details,
                   pv.size,
                   COALESCE(pv.quantity, 0) as quantity,
                   COALESCE(pv.selling_price, 0) as selling_price
            FROM product_variants pv
            JOIN products p ON pv.product_id = p.id
            WHERE pv.quantity IS NULL OR pv.quantity > 0
        """)
        available_products = cursor.fetchall()

        # Due customers list
        cursor.execute(
            f"""
            SELECT 
                s.buyer_name,
                COALESCE(MAX(s.phone_number), '') AS phone_number,
                SUM(s.total_price + {vat_expr}) AS total_with_vat,
                SUM({paid_expr}) AS paid_amount,
                SUM({due_expr}) AS due_amount,
                STRING_AGG(s.id::text, ',') AS sale_ids
            FROM sales s
            GROUP BY s.buyer_name
            HAVING SUM({due_expr}) > 0
            ORDER BY due_amount DESC
            """
        )
        due_customers = cursor.fetchall()
    finally:
        # The original closed these only on the happy path, leaking the
        # connection whenever a query raised.
        cursor.close()
        conn.close()

    return render_template(
        'sales.html',
        sales=grouped_sales,
        available_products=available_products,
        due_customers=due_customers,
    )

def _allocate_payment(paid_total, line_totals, grand_total):
    """Split ``paid_total`` across sale lines in proportion to each line total.

    Every line with a positive total except the last gets its proportional
    share rounded half-up to ``CENT`` and capped at the line total. The last
    positive line gets whatever remains, clamped to ``[0, its line total]``,
    so rounding drift does not leak into the sum. Lines whose total is not
    positive get ``Decimal('0.00')``.

    Args:
        paid_total: Amount paid for the whole sale (Decimal).
        line_totals: Per-line totals including VAT (Decimals), in line order.
        grand_total: Total of the sale including VAT (Decimal).

    Returns:
        A list of Decimals, one per entry of ``line_totals``.
    """
    allocations = [Decimal('0.00')] * len(line_totals)
    positive = [i for i, total in enumerate(line_totals) if total > 0]
    if not positive:
        return allocations

    remaining = paid_total
    for idx in positive[:-1]:
        line_total = line_totals[idx]
        share = (line_total / grand_total) if grand_total > 0 else Decimal('0')
        proposed = (paid_total * share).quantize(CENT, rounding=ROUND_HALF_UP)
        allocations[idx] = min(proposed, line_total)
        remaining -= allocations[idx]

    # Last line gets the remainder to guarantee exact sum
    last = positive[-1]
    allocations[last] = max(
        Decimal('0.00'),
        min(remaining.quantize(CENT, rounding=ROUND_HALF_UP), line_totals[last]),
    )
    return allocations


@app.route('/add_sale', methods=['POST'])
@sales_access_required
def add_sale():
    """Record a (possibly multi-line) sale, decrement stock and store VAT/payment data.

    Form fields read from ``request.form``:
      * ``buyer_name`` and ``sale_date`` (required);
      * ``phone_number``, ``payment_method`` (default ``'cash'``), ``notes``;
      * ``vat_percentage`` (default 5);
      * ``partial_payment`` (``'on'`` when only part was paid) and ``paid_amount``;
      * the parallel lists ``product_variant_id[]``, ``quantity[]`` and
        ``unit_price[]``. A row is used only when all three of its fields are
        non-empty.

    Missing optional ``sales`` columns are added on the fly. Any column that
    still cannot be added is left out of the INSERT. ``total_price`` is
    stored pre-VAT; VAT is stored separately, and the paid amount is spread
    across lines by ``_allocate_payment``.

    Always redirects to the sales page; the outcome is reported via ``flash``.
    """
    buyer_name = request.form['buyer_name']
    phone_number = request.form.get('phone_number', '').strip()
    sale_date = request.form['sale_date']
    payment_method = request.form.get('payment_method', 'cash')
    notes = request.form.get('notes', '')
    vat_percentage = money(request.form.get('vat_percentage', 5) or 0)
    partial_payment = request.form.get('partial_payment') == 'on'
    paid_amount_input = money(request.form.get('paid_amount', 0) or 0)

    # Get multiple products data
    product_variant_ids = request.form.getlist('product_variant_id[]')
    quantities = request.form.getlist('quantity[]')
    unit_prices = request.form.getlist('unit_price[]')

    conn = get_db_connection()
    cursor = conn.cursor()

    def ensure_sales_column(col, definition):
        """Return True if ``sales.<col>`` exists, adding it with ``definition`` when missing."""
        cursor.execute(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'sales' AND column_name = %s
            """,
            (col,),
        )
        if cursor.fetchone() is not None:
            return True
        try:
            # col/definition are hard-coded below, never user input.
            cursor.execute(f"ALTER TABLE sales ADD COLUMN {col} {definition}")
            conn.commit()
            return True
        except Exception:
            conn.rollback()
            return False

    try:
        # Ensure sales table has required columns
        has_buying_price = ensure_sales_column('buying_price', 'numeric')
        has_phone_number = ensure_sales_column('phone_number', 'VARCHAR(50)')
        has_vat_percentage = ensure_sales_column('vat_percentage', 'DECIMAL(5,2) DEFAULT 5.00')
        has_vat_amount = ensure_sales_column('vat_amount', 'DECIMAL(10,2) DEFAULT 0.00')
        has_paid_amount = ensure_sales_column('paid_amount', 'DECIMAL(10,2) DEFAULT 0.00')
        has_due_amount = ensure_sales_column('due_amount', 'DECIMAL(10,2) DEFAULT 0.00')

        # Parse the submitted rows once: (variant_id, quantity, unit_price).
        # Short-circuiting keeps blank rows from touching the other lists.
        lines = []
        for i, variant_id in enumerate(product_variant_ids):
            if variant_id and quantities[i] and unit_prices[i]:
                lines.append((variant_id, int(quantities[i]), money(unit_prices[i])))

        if not lines:
            flash('Please add at least one product to the sale.', 'danger')
            return redirect(url_for('sales'))

        # Aggregate per variant so the same product on several lines cannot
        # bypass the stock check; non-positive quantities would add stock back.
        requested = {}
        for variant_id, quantity, _ in lines:
            if quantity <= 0:
                flash(f'Quantity must be at least 1. Requested: {quantity}', 'danger')
                return redirect(url_for('sales'))
            requested[variant_id] = requested.get(variant_id, 0) + quantity

        # Validate all products first (check available stock)
        for variant_id, quantity in requested.items():
            cursor.execute("SELECT quantity FROM product_variants WHERE id = %s", (variant_id,))
            result = cursor.fetchone()
            if result:
                available_qty = result[0] if isinstance(result, tuple) else result.get('quantity')
                if available_qty is not None and available_qty < quantity:
                    flash(f'Insufficient stock for one of the products. Available: {available_qty}, Requested: {quantity}', 'danger')
                    return redirect(url_for('sales'))

        # Compute per-line pre-VAT total, VAT and VAT-inclusive total. VAT is
        # based on the same rounded total_price that gets stored.
        vat_rate = vat_percentage / Decimal('100')
        line_totals = []
        line_vats = []
        line_totals_with_vat = []
        for _, quantity, unit_price in lines:
            total_price = (money(quantity) * unit_price).quantize(CENT, rounding=ROUND_HALF_UP)
            vat_amount = (total_price * vat_rate).quantize(CENT, rounding=ROUND_HALF_UP)
            line_totals.append(total_price)
            line_vats.append(vat_amount)
            line_totals_with_vat.append((total_price + vat_amount).quantize(CENT, rounding=ROUND_HALF_UP))

        total_sale_amount = sum(line_totals, Decimal('0.00'))
        total_vat_amount = sum(line_vats, Decimal('0.00'))
        total_with_vat = (total_sale_amount + total_vat_amount).quantize(CENT, rounding=ROUND_HALF_UP)
        if partial_payment:
            # Clamp the entered amount into [0, total_with_vat].
            paid_total = max(Decimal('0.00'), min(paid_amount_input, total_with_vat))
        else:
            paid_total = total_with_vat

        line_paid_amounts = _allocate_payment(paid_total, line_totals_with_vat, total_with_vat)
        has_positive_line = any(total > 0 for total in line_totals_with_vat)

        for idx, (variant_id, quantity, unit_price) in enumerate(lines):
            total_price = line_totals[idx]
            vat_amount = line_vats[idx]
            line_total_with_vat = line_totals_with_vat[idx]
            if has_positive_line:
                line_paid = line_paid_amounts[idx]
            elif partial_payment:
                line_paid = Decimal('0.00')
            else:
                line_paid = line_total_with_vat
            line_due = (line_total_with_vat - line_paid).quantize(CENT, rounding=ROUND_HALF_UP)

            # Get current buying price for this variant (to persist historical cost)
            cursor.execute("SELECT COALESCE(buying_price, 0) as buying_price FROM product_variants WHERE id = %s", (variant_id,))
            bp_row = cursor.fetchone()
            current_buying = 0
            if bp_row:
                # bp_row may be tuple or dict depending on cursor
                if isinstance(bp_row, tuple):
                    current_buying = bp_row[0] or 0
                elif isinstance(bp_row, dict):
                    current_buying = bp_row.get('buying_price', 0) or 0

            # Update inventory
            cursor.execute("""
                UPDATE product_variants 
                SET quantity = COALESCE(quantity, 0) - %s 
                WHERE id = %s
            """, (quantity, variant_id))

            # Record sale with whatever optional columns exist. Column names come
            # from this fixed list only, so interpolating them is safe.
            columns = ['product_variant_id', 'buyer_name']
            values = [variant_id, buyer_name]
            if has_phone_number:
                columns.append('phone_number')
                values.append(phone_number)
            columns += ['quantity', 'unit_price', 'total_price', 'sale_date', 'payment_method', 'notes']
            values += [quantity, float(unit_price), float(total_price), sale_date, payment_method, notes]
            if has_buying_price:
                columns.append('buying_price')
                values.append(current_buying)
            if has_vat_percentage and has_vat_amount:
                columns += ['vat_percentage', 'vat_amount']
                values += [float(vat_percentage), float(vat_amount)]
            if has_paid_amount and has_due_amount:
                columns += ['paid_amount', 'due_amount']
                values += [float(line_paid), float(line_due)]
            placeholders = ', '.join(['%s'] * len(columns))
            cursor.execute(
                f"INSERT INTO sales ({', '.join(columns)}) VALUES ({placeholders})",
                values,
            )

        conn.commit()
        flash(f'Sale recorded successfully with {len(lines)} products! Total amount: AED {total_sale_amount:,.2f}', 'success')

    except Exception as e:
        conn.rollback()
        flash(f'Error recording sale: {str(e)}', 'danger')

    finally:
        # Close both handles; the original leaked the connection on every call.
        cursor.close()
        conn.close()

    return redirect(url_for('sales'))

@app.route('/delete_sale_group', methods=['POST'])
@sales_access_required
def delete_sale_group():
    """Delete every sale row for a buyer/date group and return the sold quantities to stock.

    Reads ``buyer_name`` and ``sale_date`` from ``request.form``. The outcome
    is reported via ``flash`` and the user is redirected to the sales page.

    NOTE(review): the sales page groups by (buyer_name, sale_date,
    payment_method), but this deletes by buyer_name and sale_date only. Two
    groups that differ only in payment method are therefore removed together.
    Confirm that this is intended.
    """
    buyer_name = request.form['buyer_name']
    sale_date = request.form['sale_date']

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # DELETE ... RETURNING restores stock for exactly the rows removed,
        # with no window between a SELECT and the DELETE for new rows to slip in.
        cursor.execute(
            """
            DELETE FROM sales
            WHERE buyer_name = %s AND sale_date = %s
            RETURNING product_variant_id, quantity
            """,
            (buyer_name, sale_date),
        )
        deleted_rows = cursor.fetchall()

        # Add each deleted quantity back to inventory
        for row in deleted_rows:
            cursor.execute("""
                UPDATE product_variants 
                SET quantity = COALESCE(quantity, 0) + %s 
                WHERE id = %s
            """, (row['quantity'], row['product_variant_id']))

        conn.commit()

        flash(f'Sale group deleted successfully! {len(deleted_rows)} records removed and inventory restored.', 'success')

    except Exception as e:
        conn.rollback()
        flash(f'Error deleting sale group: {str(e)}', 'danger')

    finally:
        # Close both handles; the original leaked the connection.
        cursor.close()
        conn.close()

    return redirect(url_for('sales'))

@app.route('/bulk_delete_sale_groups', methods=['POST'])
@sales_access_required
def bulk_delete_sale_groups():
    """Bulk delete sale groups and restore inventory.

    Expects a JSON body like
    ``{"sale_groups": [{"buyer_name": ..., "sale_date": ...}, ...]}``.

    Each group is processed inside its own SAVEPOINT. A failure on one group
    is undone and logged without aborting the remaining ones. Groups that
    match no rows are not counted.

    Returns:
        JSON ``{'success': True, 'message': ..., 'deleted': <count>}`` on
        success, or ``{'success': False, 'message': ...}`` on error / empty
        selection.
    """
    conn = None
    cursor = None
    try:
        # A missing/invalid/null JSON body is treated as "nothing selected".
        data = request.get_json(silent=True) or {}
        sale_groups = data.get('sale_groups', [])

        if not sale_groups:
            return jsonify({'success': False, 'message': 'No sales selected'})

        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        deleted_count = 0

        for group in sale_groups:
            # Without a savepoint, one failed statement aborts the whole
            # PostgreSQL transaction and the final COMMIT silently rolls back.
            cursor.execute("SAVEPOINT bulk_delete_sale_group")
            try:
                buyer_name = group.get('buyer_name')
                sale_date = group.get('sale_date')

                # Delete the group and get back exactly what was removed
                cursor.execute(
                    """
                    DELETE FROM sales
                    WHERE buyer_name = %s AND sale_date = %s
                    RETURNING product_variant_id, quantity
                    """,
                    (buyer_name, sale_date),
                )
                deleted_rows = cursor.fetchall()

                # Add each deleted quantity back to inventory
                for row in deleted_rows:
                    cursor.execute("""
                        UPDATE product_variants 
                        SET quantity = COALESCE(quantity, 0) + %s 
                        WHERE id = %s
                    """, (row['quantity'], row['product_variant_id']))

                cursor.execute("RELEASE SAVEPOINT bulk_delete_sale_group")
                if deleted_rows:
                    deleted_count += 1
            except Exception as e:
                cursor.execute("ROLLBACK TO SAVEPOINT bulk_delete_sale_group")
                print(f"Error deleting sale group: {str(e)}", file=sys.stderr)

        conn.commit()

        return jsonify({
            'success': True,
            'message': f'Successfully deleted {deleted_count} sale group(s) and restored inventory.',
            'deleted': deleted_count
        })

    except Exception as e:
        if conn is not None:
            conn.rollback()
        print(f"Bulk delete error: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': f'Error: {str(e)}'})
    finally:
        # Always release DB resources, including on the error path.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@app.route('/get_sale_receipt_data')
@sales_access_required
def get_sale_receipt_data():
    """Return all sale rows of one buyer/date group as JSON for receipt printing.

    Query params: ``buyer`` (buyer name) and ``date`` (sale date).

    Returns:
        ``{'success': True, 'sales': [...], 'buyer': ..., 'sale_date': ...,
        'payment_method': <first row's method>}`` when rows exist. Otherwise,
        or on error, ``{'success': False, 'message': ...}``.
    """
    buyer_name = request.args.get('buyer')
    sale_date = request.args.get('date')

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("""
            SELECT 
                s.*,
                p.product_name,
                p.category,
                pv.size,
                pv.selling_price,
                s.unit_price,
                s.quantity,
                s.total_price
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            JOIN products p ON pv.product_id = p.id
            WHERE s.buyer_name = %s AND s.sale_date = %s
            ORDER BY p.category, p.product_name, pv.size
        """, (buyer_name, sale_date))

        sales = cursor.fetchall()

        if sales:
            return jsonify({
                'success': True,
                'sales': sales,
                'buyer': buyer_name,
                'sale_date': sale_date,
                'payment_method': sales[0]['payment_method']
            })
        else:
            return jsonify({'success': False, 'message': 'No sales found'})

    except Exception as e:
        print(f"Error in get_sale_receipt_data: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(e)})
    finally:
        # Close both handles; the original leaked the connection.
        cursor.close()
        conn.close()

@app.route('/get_sale_group_details')
@sales_access_required
def get_sale_group_details():
    """Return all sale rows of one buyer/date group as JSON (used for editing).

    Query params: ``buyer`` (buyer name) and ``date`` (sale date).

    Returns:
        ``{'success': True, 'sales': [...]}`` when rows exist. Otherwise, or
        on error, ``{'success': False, 'message': ...}``.
    """
    buyer_name = request.args.get('buyer')
    sale_date = request.args.get('date')

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("""
            SELECT 
                s.*,
                p.product_name,
                p.category,
                pv.size,
                pv.selling_price,
                s.unit_price,
                s.quantity,
                s.total_price
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            JOIN products p ON pv.product_id = p.id
            WHERE s.buyer_name = %s AND s.sale_date = %s
            ORDER BY p.category, p.product_name, pv.size
        """, (buyer_name, sale_date))

        sales = cursor.fetchall()

        if sales:
            return jsonify({'success': True, 'sales': sales})
        else:
            return jsonify({'success': False, 'message': 'No sales found'})

    except Exception as e:
        print(f"Error in get_sale_group_details: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(e)})
    finally:
        # Close both handles; the original leaked the connection.
        cursor.close()
        conn.close()

def _sale_group_money(value):
    """Return *value* as a ``Decimal`` rounded half-up to two decimal places.

    Currency is computed in ``Decimal`` rather than ``float`` so that, e.g.,
    2.675 rounds to 2.68 (float ``round`` yields 2.67).
    """
    return Decimal(str(value)).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)


@app.route('/edit_sale_group', methods=['POST'])
@sales_access_required
def edit_sale_group():
    """Edit quantities within a sale group and update inventory and VAT accordingly.

    Expected JSON payload::

        {
          "buyer_name": "...",
          "sale_date": "YYYY-MM-DD",
          "items": [
            {"sale_id": 123, "new_quantity": 3, "unit_price": 12.50, "vat_percentage": 5.0}
          ]
        }

    Optional group-level keys: ``buyer_name_new``, ``sale_date_new``,
    ``phone_number``, ``payment_method``, ``notes``. They are applied to every
    sale row of the group.

    ``product_variant_id`` and ``old_quantity`` may still be sent, but they are
    ignored. The variant and the previous quantity are read from the locked sale
    row, so a stale or tampered client value cannot skew stock levels. Omitted
    ``new_quantity``/``unit_price``/``vat_percentage`` keep the stored value.
    ``unit_price`` is persisted, so ``total_price == quantity * unit_price`` holds.

    Stock availability is checked per variant on the net change across all
    items, and the sale and stock rows are locked (``FOR UPDATE``) until commit.
    Everything runs in one transaction. Any validation failure rolls back all
    changes, including group-level edits.

    Returns JSON ``{"success": bool, "message": str}`` with HTTP status 200,
    400 (bad input or insufficient stock), 404 (sale row not in the group) or
    500 (database error).
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"success": False, "message": "Invalid JSON payload"}), 400

    buyer_name = data.get('buyer_name')
    sale_date = data.get('sale_date')
    # Optional group-level edits
    buyer_name_new = data.get('buyer_name_new')
    sale_date_new = data.get('sale_date_new')
    phone_number = data.get('phone_number')
    payment_method = data.get('payment_method')
    notes = data.get('notes')
    items = data.get('items', [])

    if not buyer_name or not sale_date or not items:
        return jsonify({"success": False, "message": "Missing buyer_name, sale_date or items"}), 400
    if not isinstance(items, list) or not all(isinstance(it, dict) for it in items):
        return jsonify({"success": False, "message": "items must be a list of objects"}), 400

    try:
        sale_ids = [int(it['sale_id']) for it in items]
    except (KeyError, TypeError, ValueError):
        return jsonify({"success": False, "message": "Each item must include sale_id"}), 400
    if len(set(sale_ids)) != len(sale_ids):
        # A repeated id would apply its inventory delta twice.
        return jsonify({"success": False, "message": "Duplicate sale_id in items"}), 400

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    def fail(message, status):
        # Discard everything done so far (and release row locks) before replying.
        conn.rollback()
        return jsonify({"success": False, "message": message}), status

    try:
        # Only rows of the given group are eligible. Lock them so concurrent edits
        # of the same group cannot interleave.
        cursor.execute(
            """
            SELECT id, product_variant_id, quantity, unit_price, vat_percentage, paid_amount
            FROM sales
            WHERE id = ANY(%s) AND buyer_name = %s AND sale_date = %s
            FOR UPDATE
            """,
            (sale_ids, buyer_name, sale_date),
        )
        existing_rows = {row['id']: row for row in cursor.fetchall()}

        # First pass: parse and validate every item without writing anything.
        plans = []
        stock_delta = {}  # product_variant_id -> net extra units sold (negative = returned)
        for item, sid in zip(items, sale_ids):
            row = existing_rows.get(sid)
            if row is None:
                return fail(f"Sale row {sid} not found", 404)

            old_q = int(row['quantity'] or 0)
            try:
                new_q = int(item.get('new_quantity', old_q) or 0)
                unit_price = Decimal(str(item.get('unit_price', row['unit_price'] or 0) or 0))
                vat_pct = Decimal(str(item.get('vat_percentage', row.get('vat_percentage') or 0) or 0))
            except (TypeError, ValueError, ArithmeticError):
                return fail(f"Invalid quantity, unit_price or vat_percentage for sale row {sid}", 400)
            # is_finite() is checked first: ordering comparisons on NaN raise.
            if (not unit_price.is_finite() or not vat_pct.is_finite()
                    or new_q < 0 or unit_price < 0 or vat_pct < 0):
                return fail(f"Invalid quantity, unit_price or vat_percentage for sale row {sid}", 400)

            pv_id = row['product_variant_id']
            stock_delta[pv_id] = stock_delta.get(pv_id, 0) + (new_q - old_q)
            plans.append((sid, row, new_q, unit_price, vat_pct))

        # Check availability on the net change per variant, so several items of the
        # same variant cannot each pass against the full stock.
        for pv_id, delta in stock_delta.items():
            if delta <= 0:
                continue
            cursor.execute(
                "SELECT COALESCE(quantity,0) AS quantity FROM product_variants WHERE id = %s FOR UPDATE",
                (pv_id,),
            )
            pv = cursor.fetchone()
            available = (pv or {}).get('quantity', 0)
            if available < delta:
                return fail(
                    f"Insufficient stock for variant {pv_id}. Available {available}, needed {delta}",
                    400,
                )

        # Group-level edits run only after validation has passed.
        set_parts = []
        params = []
        if buyer_name_new:
            set_parts.append('buyer_name = %s')
            params.append(buyer_name_new)
        if sale_date_new:
            set_parts.append('sale_date = %s')
            params.append(sale_date_new)
        if phone_number is not None:
            set_parts.append('phone_number = %s')
            params.append(phone_number)
        if payment_method is not None:
            set_parts.append('payment_method = %s')
            params.append(payment_method)
        if notes is not None:
            set_parts.append('notes = %s')
            params.append(notes)
        if set_parts:
            params.extend([buyer_name, sale_date])
            # set_parts holds only the fixed fragments above, so interpolation is safe.
            cursor.execute(
                f"UPDATE sales SET {', '.join(set_parts)} WHERE buyer_name = %s AND sale_date = %s",
                tuple(params),
            )

        # Inventory: positive delta = more sold (reduce stock), negative = returned.
        for pv_id, delta in stock_delta.items():
            if delta:
                cursor.execute(
                    """
                    UPDATE product_variants
                    SET quantity = COALESCE(quantity,0) - %s
                    WHERE id = %s
                    """,
                    (delta, pv_id),
                )

        # Second pass: recompute totals for each sale row.
        for sid, row, new_q, unit_price, vat_pct in plans:
            new_total_price = _sale_group_money(new_q * unit_price)
            new_vat_amount = _sale_group_money(new_total_price * vat_pct / 100)
            # Keep the existing paid_amount; the due amount follows the new VAT-inclusive total.
            existing_paid = Decimal(str(row.get('paid_amount') or 0))
            new_due_amount = max(
                _sale_group_money(new_total_price + new_vat_amount - existing_paid),
                Decimal('0.00'),
            )
            # Rows were validated and locked above, so the id alone identifies them
            # (even after a group-level buyer/date rename).
            cursor.execute(
                """
                UPDATE sales
                SET quantity = %s,
                    unit_price = %s,
                    total_price = %s,
                    vat_percentage = %s,
                    vat_amount = %s,
                    due_amount = %s
                WHERE id = %s
                """,
                (new_q, unit_price, new_total_price, vat_pct, new_vat_amount, new_due_amount, sid),
            )

        conn.commit()
        return jsonify({"success": True, "message": "Sale group updated and inventory adjusted."})
    except Exception as e:
        conn.rollback()
        return jsonify({"success": False, "message": str(e)}), 500
    finally:
        cursor.close()
        conn.close()

def _analytics_int_arg(value, lowest, highest):
    """Parse a query-string value as an int within ``[lowest, highest]``.

    Returns None for missing, non-numeric or out-of-range input. A malformed
    ``?year=``/``?month=`` then falls back to "no filter" instead of reaching
    the database and failing there.
    """
    try:
        number = int(value)
    except (TypeError, ValueError):
        return None
    return number if lowest <= number <= highest else None


def _analytics_period_clause(column, year, month):
    """Build a ``WHERE`` clause that restricts *column* to the selected period.

    *column* must be a trusted SQL expression; it is interpolated, not bound.
    Returns ``(clause, params)``. ``clause`` is ``""`` when *year* is None,
    filters on the year alone when *month* is None, and on year and month
    otherwise.
    """
    if year is None:
        return "", ()
    if month is None:
        return f"WHERE EXTRACT(YEAR FROM {column}) = %s", (year,)
    return (
        f"WHERE EXTRACT(YEAR FROM {column}) = %s AND EXTRACT(MONTH FROM {column}) = %s",
        (year, month),
    )


@app.route('/analytics')
@login_required
def analytics():
    """Render the financial analytics dashboard.

    The ``year`` and optional ``month`` query parameters narrow every figure to
    that period. ``month`` is only honoured together with ``year``, and invalid
    values are ignored. With no year selected, the monthly charts cover the last
    12 months while all other figures cover all time.

    Product cost uses the buying price stored on each sale row when
    ``sales.buying_price`` exists, falling back to the variant's buying price
    (or 0). Net profit = revenue - (product cost + expenses).
    """
    ensure_expenses_table()

    selected_year = request.args.get('year', '')
    selected_month = request.args.get('month', '')
    year = _analytics_int_arg(selected_year, 1, 9999)
    month = _analytics_int_arg(selected_month, 1, 12)
    # Do not echo invalid values back to the template's filter controls.
    if year is None:
        selected_year = ''
    if month is None:
        selected_month = ''

    if year is None:
        period_label = "All Time"
    elif month is None:
        period_label = str(year)
    else:
        period_label = f"{year}-{month:02d}"

    sales_period, sales_params = _analytics_period_clause('sale_date', year, month)
    joined_sales_period, joined_sales_params = _analytics_period_clause('s.sale_date', year, month)
    expense_period, expense_params = _analytics_period_clause('expense_date', year, month)
    # The monthly charts default to a rolling 12-month window rather than all time.
    sales_chart_filter = sales_period or "WHERE sale_date >= (CURRENT_DATE - INTERVAL '12 months')"
    expense_chart_filter = expense_period or "WHERE expense_date >= (CURRENT_DATE - INTERVAL '12 months')"

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    def fetch_all(sql, params=()):
        # Pass None when there are no params so psycopg2 skips %-interpolation.
        cursor.execute(sql, params or None)
        return cursor.fetchall()

    def fetch_one(sql, params=()):
        cursor.execute(sql, params or None)
        return cursor.fetchone()

    try:
        available_years = [row['year'] for row in fetch_all("""
            SELECT DISTINCT EXTRACT(YEAR FROM sale_date)::text AS year
            FROM sales
            UNION
            SELECT DISTINCT EXTRACT(YEAR FROM expense_date)::text AS year
            FROM expenses
            ORDER BY year DESC
        """)]

        monthly_sales = fetch_all(f"""
            SELECT TO_CHAR(date_trunc('month', sale_date), 'YYYY-MM') AS month,
                   SUM(total_price) AS revenue
            FROM sales
            {sales_chart_filter}
            GROUP BY date_trunc('month', sale_date)
            ORDER BY date_trunc('month', sale_date)
        """, sales_params)

        monthly_expenses = fetch_all(f"""
            SELECT TO_CHAR(date_trunc('month', expense_date), 'YYYY-MM') AS month,
                   SUM(amount) AS total_expense
            FROM expenses
            {expense_chart_filter}
            GROUP BY date_trunc('month', expense_date)
            ORDER BY date_trunc('month', expense_date)
        """, expense_params)

        # Blank and NULL categories both land in "Uncategorized". Group by the
        # output expression (position 1) so they form a single row.
        expense_breakdown = fetch_all(f"""
            SELECT COALESCE(NULLIF(category, ''), 'Uncategorized') AS category,
                   SUM(amount) AS total_amount
            FROM expenses
            {expense_period}
            GROUP BY 1
            ORDER BY total_amount DESC
        """, expense_params)

        sales_has_buying = fetch_one("""
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'sales' AND column_name = 'buying_price'
        """) is not None
        # Prefer the buying price captured at sale time when the column exists.
        unit_cost = (
            "COALESCE(s.buying_price, pv.buying_price, 0)"
            if sales_has_buying
            else "COALESCE(pv.buying_price, 0)"
        )

        profit_analysis = fetch_all(f"""
            SELECT p.product_name,
                   SUM(s.quantity) AS total_sold,
                   SUM(s.total_price) AS total_revenue,
                   SUM(s.quantity * {unit_cost}) AS total_cost,
                   SUM(s.total_price - (s.quantity * {unit_cost})) AS profit
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            JOIN products p ON pv.product_id = p.id
            {joined_sales_period}
            GROUP BY p.product_name
            ORDER BY profit DESC
        """, joined_sales_params)

        total_revenue = fetch_one(f"""
            SELECT COALESCE(SUM(total_price), 0) AS total_revenue FROM sales
            {sales_period}
        """, sales_params)['total_revenue'] or 0

        cost_row = fetch_one(f"""
            SELECT COALESCE(SUM(s.quantity * {unit_cost}), 0) AS total_cost
            FROM sales s
            JOIN product_variants pv ON s.product_variant_id = pv.id
            {joined_sales_period}
        """, joined_sales_params)
        product_cost = cost_row['total_cost'] if cost_row and cost_row.get('total_cost') is not None else 0

        total_expenses = fetch_one(f"""
            SELECT COALESCE(SUM(amount), 0) AS total_expenses FROM expenses
            {expense_period}
        """, expense_params)['total_expenses'] or 0
    finally:
        cursor.close()
        conn.close()

    # Total cost includes both product cost and expenses.
    total_cost = product_cost + total_expenses
    net_profit = total_revenue - total_cost
    profit_margin = (net_profit / total_revenue * 100) if total_revenue > 0 else 0

    return render_template('analytics.html',
                           monthly_sales=monthly_sales,
                           monthly_expenses=monthly_expenses,
                           expense_breakdown=expense_breakdown,
                           profit_analysis=profit_analysis,
                           total_revenue=total_revenue,
                           product_cost=product_cost,
                           total_expenses=total_expenses,
                           total_cost=total_cost,
                           net_profit=net_profit,
                           profit_margin=round(profit_margin, 1),
                           selected_year=selected_year,
                           selected_month=selected_month,
                           available_years=available_years,
                           period_label=period_label)

@app.route('/companies')
@login_required
def companies():
    """Render the supplier list page.

    Adds the ``companies.advance_payment`` column first if it is missing. It
    then lists all companies ordered by name, reporting a NULL
    ``advance_payment`` as 0.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # Older databases predate advance_payment; add it before selecting it.
        cursor.execute("""
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'companies' AND column_name = 'advance_payment'
        """)
        if cursor.fetchone() is None:
            cursor.execute("ALTER TABLE companies ADD COLUMN IF NOT EXISTS advance_payment DECIMAL(10,2) DEFAULT 0.00")
            conn.commit()

        # The COALESCE alias comes after "*", so in the row dict it overrides the raw column.
        cursor.execute("SELECT *, COALESCE(advance_payment, 0) as advance_payment FROM companies ORDER BY company_name")
        company_rows = cursor.fetchall()
    finally:
        # Always release the connection, even if the migration or query fails.
        cursor.close()
        conn.close()
    return render_template('companies.html', companies=company_rows)

@app.route('/add_company', methods=['POST'])
@login_required
def add_company():
    """Create a supplier from the submitted form and redirect back.

    Form fields ``company_name``, ``contact_person``, ``email``, ``phone`` and
    ``address`` are required. ``redirect_to=shipments`` returns to the shipments
    page; anything else returns to the companies page. The outcome is reported
    via a flash message.
    """
    company_name = request.form['company_name']
    contact_person = request.form['contact_person']
    email = request.form['email']
    phone = request.form['phone']
    address = request.form['address']
    redirect_to = request.form.get('redirect_to', 'companies')

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            INSERT INTO companies (company_name, contact_person, email, phone, address)
            VALUES (%s, %s, %s, %s, %s)
        """, (company_name, contact_person, email, phone, address))
        conn.commit()
        flash('Supplier added successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error adding supplier: {str(e)}', 'danger')
    finally:
        cursor.close()
        # Close the connection too; closing only the cursor leaks it.
        conn.close()

    if redirect_to == 'shipments':
        return redirect(url_for('shipments'))
    return redirect(url_for('companies'))

@app.route('/bulk_delete_companies', methods=['POST'])
@login_required
def bulk_delete_companies():
    """Bulk delete companies, skipping those that still have shipments.

    Expects a JSON body ``{"company_ids": [...]}``. Each company is handled
    inside its own savepoint. A failure on one id (e.g. a foreign-key violation)
    is rolled back on its own; it no longer aborts the whole transaction and
    silently discards the deletes that already succeeded. Failed ids are counted
    as skipped.

    Returns JSON with ``success`` and ``message``, plus ``deleted`` and
    ``skipped`` counts on success.
    """
    conn = None
    cursor = None
    try:
        data = request.get_json(silent=True)
        company_ids = data.get('company_ids', []) if isinstance(data, dict) else []

        if not company_ids:
            return jsonify({'success': False, 'message': 'No companies selected'})
        if not isinstance(company_ids, list):
            return jsonify({'success': False, 'message': 'company_ids must be a list'})

        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        deleted_count = 0
        skipped_count = 0

        for company_id in company_ids:
            # In PostgreSQL any failed statement aborts the whole transaction. The
            # savepoint confines a failure to this company.
            cursor.execute("SAVEPOINT bulk_delete_company")
            try:
                cursor.execute("SELECT COUNT(*) as shipment_count FROM shipments WHERE company_id = %s", (company_id,))
                result = cursor.fetchone()
                has_shipments = bool(result and result['shipment_count'] > 0)
                if not has_shipments:
                    cursor.execute("DELETE FROM companies WHERE id = %s", (company_id,))
                cursor.execute("RELEASE SAVEPOINT bulk_delete_company")
            except Exception as e:
                cursor.execute("ROLLBACK TO SAVEPOINT bulk_delete_company")
                cursor.execute("RELEASE SAVEPOINT bulk_delete_company")
                print(f"Error deleting company {company_id}: {str(e)}")
                skipped_count += 1
                continue

            if has_shipments:
                skipped_count += 1
            else:
                deleted_count += 1

        conn.commit()

        message = f'Successfully deleted {deleted_count} company(ies).'
        if skipped_count > 0:
            message += f' Skipped {skipped_count} company(ies) with shipment records.'

        return jsonify({
            'success': True,
            'message': message,
            'deleted': deleted_count,
            'skipped': skipped_count
        })

    except Exception as e:
        print(f"Bulk delete error: {str(e)}")
        return jsonify({'success': False, 'message': f'Error: {str(e)}'})
    finally:
        # Closing discards any uncommitted work if we got here via an error.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@app.route('/delete_company/<int:company_id>', methods=['POST'])
@login_required
def delete_company(company_id):
    """Delete a single supplier unless it still has shipment records.

    Flashes the outcome. It redirects to the shipments page when the form posts
    ``redirect_to=shipments``, otherwise to the companies page.
    """
    destination = 'shipments' if request.form.get('redirect_to', 'companies') == 'shipments' else 'companies'
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cur.execute("SELECT COUNT(*) as shipment_count FROM shipments WHERE company_id = %s", (company_id,))
        linked_shipments = cur.fetchone()['shipment_count']

        if linked_shipments > 0:
            # Shipments reference the supplier; refuse rather than orphan them.
            flash('Cannot delete supplier because it has shipment records. Please delete the shipments first.', 'danger')
        else:
            cur.execute("DELETE FROM companies WHERE id = %s", (company_id,))
            conn.commit()
            flash('Supplier deleted successfully!', 'success')
    except Exception as exc:
        conn.rollback()
        flash(f'Error deleting supplier: {str(exc)}', 'danger')
    finally:
        cur.close()
        conn.close()

    return redirect(url_for(destination))

@app.route('/edit_company/<int:company_id>', methods=['POST'])
@login_required
def edit_company(company_id):
    """Update a supplier's details from the submitted form.

    Form fields ``company_name``, ``contact_person``, ``email``, ``phone`` and
    ``address`` are required. The optional ``redirect_to`` field works as in
    ``add_company``/``delete_company``: ``"companies"`` returns to the companies
    page. It defaults to ``"shipments"``, so forms that omit it keep returning
    to the shipments page.
    """
    company_name = request.form['company_name']
    contact_person = request.form['contact_person']
    email = request.form['email']
    phone = request.form['phone']
    address = request.form['address']
    redirect_to = request.form.get('redirect_to', 'shipments')

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            UPDATE companies 
            SET company_name = %s, contact_person = %s, email = %s, phone = %s, address = %s
            WHERE id = %s
        """, (company_name, contact_person, email, phone, address, company_id))
        conn.commit()
        flash('Company updated successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error updating company: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    if redirect_to == 'companies':
        return redirect(url_for('companies'))
    return redirect(url_for('shipments'))

@app.route('/toggle_product_visibility/<int:product_id>', methods=['POST'])
@login_required
def toggle_product_visibility(product_id):
    """Flip a product's ``is_visible`` flag and redirect to the products page.

    A NULL flag is treated as hidden, so toggling it makes the product visible.
    Flashes the new state, or an error when the product does not exist or the
    update fails.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Flip in one statement so two concurrent toggles cannot both read the
        # same old value and write the same new one.
        cursor.execute(
            """
            UPDATE products
            SET is_visible = NOT COALESCE(is_visible, FALSE)
            WHERE id = %s
            RETURNING is_visible
            """,
            (product_id,),
        )
        product = cursor.fetchone()

        if product:
            conn.commit()
            status = "visible" if product['is_visible'] else "hidden"
            flash(f'Product is now {status} in the outlet!', 'success')
        else:
            flash('Product not found!', 'danger')

    except Exception as e:
        conn.rollback()
        flash(f'Error updating product visibility: {str(e)}', 'danger')

    finally:
        cursor.close()
        # Close the connection too; closing only the cursor leaks it.
        conn.close()

    return redirect(url_for('products'))

# Expense Management Routes
def ensure_expenses_table():
    """Create the ``expenses`` table if it does not already exist.

    Works on its own connection, which is always closed afterwards. Failures
    are not raised: the transaction is rolled back and the error is printed,
    so callers carry on either way.
    """
    db = get_db_connection()
    cur = db.cursor()
    try:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS expenses (
                id SERIAL PRIMARY KEY,
                expense_name VARCHAR(255) NOT NULL,
                category VARCHAR(100) DEFAULT NULL,
                amount DECIMAL(10,2) NOT NULL DEFAULT 0.00,
                expense_date DATE NOT NULL,
                notes TEXT DEFAULT NULL,
                created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
            )
        """)
        db.commit()
    except Exception as exc:
        db.rollback()
        print(f"Error creating expenses table: {exc}")
    finally:
        # Cursor first, then its connection.
        for resource in (cur, db):
            resource.close()

@app.route('/expenses')
@login_required
def expenses():
    """Render the expenses dashboard.

    Collects, in this order:
    - every expense, newest first;
    - the current-year and current-month expense totals;
    - the supplier list with due and advance balances, adding the
      ``companies.advance_payment`` column first if it is missing;
    - the supplier payment history, including the balance snapshot recorded
      after each payment.
    """
    ensure_expenses_table()
    conn = get_db_connection()
    cur = conn.cursor(cursor_factory=RealDictCursor)

    def rows(sql):
        """Execute *sql* and return all resulting rows."""
        cur.execute(sql)
        return cur.fetchall()

    def total(sql):
        """Execute *sql* and return the ``total`` column of its single row."""
        cur.execute(sql)
        return cur.fetchone()['total']

    try:
        expenses_list = rows("""
            SELECT * FROM expenses 
            ORDER BY expense_date DESC, id DESC
        """)

        current_year_expenses = total("""
            SELECT COALESCE(SUM(amount), 0) as total FROM expenses 
            WHERE EXTRACT(YEAR FROM expense_date) = EXTRACT(YEAR FROM CURRENT_DATE)
        """)

        current_month_expenses = total("""
            SELECT COALESCE(SUM(amount), 0) as total FROM expenses 
            WHERE date_trunc('month', expense_date) = date_trunc('month', CURRENT_DATE)
        """)

        # Older databases predate advance_payment; add it before selecting it.
        cur.execute("""
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'companies' AND column_name = 'advance_payment'
        """)
        if cur.fetchone() is None:
            cur.execute("ALTER TABLE companies ADD COLUMN IF NOT EXISTS advance_payment DECIMAL(10,2) DEFAULT 0.00")
            conn.commit()

        # Suppliers for the payment dropdown.
        company_rows = rows("""
            SELECT id, company_name, 
                   COALESCE(total_due, 0) as total_due,
                   COALESCE(advance_payment, 0) as advance_payment
            FROM companies 
            ORDER BY company_name
        """)

        # Payment history; balances are the snapshots stored after each payment.
        company_payments = rows("""
            SELECT cp.*, 
                   c.company_name,
                   COALESCE(cp.due_after, 0) as current_due,
                   COALESCE(cp.advance_after, 0) as current_advance
            FROM company_payments cp
            JOIN companies c ON cp.company_id = c.id
            ORDER BY cp.payment_date DESC, cp.id DESC
        """)

        return render_template('expenses.html',
                               expenses=expenses_list,
                               current_year_expenses=current_year_expenses,
                               current_month_expenses=current_month_expenses,
                               companies=company_rows,
                               company_payments=company_payments)
    finally:
        cur.close()
        conn.close()

@app.route('/add_expense', methods=['POST'])
@login_required
def add_expense():
    """Insert a new expense from the submitted form and redirect to the expenses page.

    Form fields: ``expense_name`` and ``expense_date`` (required), plus optional
    ``category``, ``amount`` and ``notes``. A missing or blank amount is stored as 0.
    An amount that is not a finite number is rejected with a flash message instead
    of escaping as an unhandled ``ValueError`` (HTTP 500).
    """
    import math

    ensure_expenses_table()
    expense_name = request.form['expense_name']
    category = request.form.get('category', '').strip()
    try:
        amount = float(request.form.get('amount', 0) or 0)
    except ValueError:
        amount = math.nan  # rejected by the isfinite check below
    expense_date = request.form['expense_date']
    notes = request.form.get('notes', '')

    # float() also parses 'nan' and 'inf'; neither is a meaningful expense amount.
    if not math.isfinite(amount):
        flash('Invalid expense amount.', 'danger')
        return redirect(url_for('expenses'))

    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        cursor.execute("""
            INSERT INTO expenses (expense_name, category, amount, expense_date, notes)
            VALUES (%s, %s, %s, %s, %s)
        """, (expense_name, category, amount, expense_date, notes))
        conn.commit()
        flash('Expense added successfully!', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error adding expense: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('expenses'))

@app.route('/edit_expense/<int:expense_id>', methods=['POST'])
@login_required
def edit_expense(expense_id):
    """Overwrite an existing expense with the submitted form values.

    Form fields mirror ``add_expense``: ``expense_name`` and ``expense_date``
    (required), ``category``, ``amount`` (blank -> 0) and ``notes``. A non-finite
    or unparsable amount is rejected with a flash message. If no expense has
    ``expense_id``, the user is told so instead of seeing a success message.
    Always redirects to the expenses page.
    """
    import math

    expense_name = request.form['expense_name']
    category = request.form.get('category', '').strip()
    try:
        amount = float(request.form.get('amount', 0) or 0)
    except ValueError:
        amount = math.nan  # rejected by the isfinite check below
    expense_date = request.form['expense_date']
    notes = request.form.get('notes', '')

    if not math.isfinite(amount):
        flash('Invalid expense amount.', 'danger')
        return redirect(url_for('expenses'))

    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        cursor.execute("""
            UPDATE expenses 
            SET expense_name = %s, category = %s, amount = %s, 
                expense_date = %s, notes = %s
            WHERE id = %s
        """, (expense_name, category, amount, expense_date, notes, expense_id))
        # Capture before commit: rowcount describes the UPDATE just executed.
        updated = cursor.rowcount
        conn.commit()
        if updated:
            flash('Expense updated successfully!', 'success')
        else:
            flash('Expense not found!', 'danger')
    except Exception as e:
        conn.rollback()
        flash(f'Error updating expense: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('expenses'))

@app.route('/delete_expense/<int:expense_id>', methods=['POST'])
@login_required
def delete_expense(expense_id):
    """Delete one expense by id and redirect to the expenses page.

    Flashes a success message only when a row was actually removed; an unknown
    id produces a "not found" message instead.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        cursor.execute("DELETE FROM expenses WHERE id = %s", (expense_id,))
        deleted = cursor.rowcount
        conn.commit()
        if deleted:
            flash('Expense deleted successfully!', 'success')
        else:
            flash('Expense not found!', 'danger')
    except Exception as e:
        conn.rollback()
        flash(f'Error deleting expense: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('expenses'))

@app.route('/bulk_delete_expenses', methods=['POST'])
@login_required
def bulk_delete_expenses():
    """Bulk delete expenses (AJAX).

    Expects a JSON body ``{"expense_ids": [1, 2, ...]}``. All ids are removed in a
    single statement and transaction, so the operation is all-or-nothing. The
    JSON response carries ``success`` and ``message``; on success, ``deleted`` is
    the number of rows actually removed (ids that no longer exist are not counted).
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    raw_ids = data.get('expense_ids', []) if isinstance(data, dict) else []

    if not raw_ids:
        return jsonify({'success': False, 'message': 'No expenses selected'})

    # Require a real list of integer-like ids; a bare string would otherwise be
    # iterated character by character.
    try:
        if not isinstance(raw_ids, list):
            raise TypeError('expense_ids must be a list')
        expense_ids = sorted({int(expense_id) for expense_id in raw_ids})
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Invalid expense id(s)'})

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        # psycopg2 adapts a Python list to a PostgreSQL array for "= ANY(%s)".
        # One statement avoids the aborted-transaction trap of per-id error
        # handling, where later deletes (and the commit) silently fail.
        cursor.execute("DELETE FROM expenses WHERE id = ANY(%s)", (expense_ids,))
        deleted_count = cursor.rowcount
        conn.commit()
    except Exception as e:
        if conn is not None:
            conn.rollback()
        print(f"Bulk delete error: {str(e)}", file=sys.stderr)
        return jsonify({'success': False, 'message': f'Error: {str(e)}'})
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

    return jsonify({
        'success': True, 
        'message': f'Successfully deleted {deleted_count} expense(s).',
        'deleted': deleted_count
    })

# ==================== COMPANY PAYMENT ROUTES ====================

@app.route('/api/companies')
@login_required
def get_all_companies():
    """List every company (for autocomplete) as JSON, sorted by name.

    Monetary fields are returned as floats, with NULL mapped to 0.0. On a
    database error, a ``success: False`` payload is returned with HTTP 500.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("""
            SELECT id, company_name, contact_person, phone, email, 
                   total_due, advance_payment
            FROM companies
            ORDER BY company_name ASC
        """)
        payload = [
            {
                'id': row['id'],
                'company_name': row['company_name'],
                'contact_person': row['contact_person'],
                'phone': row['phone'],
                'email': row['email'],
                'total_due': float(row['total_due'] or 0),
                'advance_payment': float(row['advance_payment'] or 0),
            }
            for row in cursor.fetchall()
        ]
        return jsonify({'success': True, 'companies': payload})
    except Exception as exc:
        print(f"Error fetching companies: {str(exc)}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(exc)}), 500
    finally:
        cursor.close()
        conn.close()

@app.route('/api/company/<int:company_id>')
@login_required
def get_company_info(company_id):
    """Return one company's name, due and advance balance as JSON.

    Responds 404 when the id is unknown and 500 (with the error text) on a
    database failure. NULL balances are reported as 0.0.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("""
            SELECT id, company_name, total_due, advance_payment
            FROM companies
            WHERE id = %s
        """, (company_id,))
        row = cursor.fetchone()

        if not row:
            return jsonify({'success': False, 'message': 'Company not found'}), 404

        return jsonify({
            'success': True,
            'id': row['id'],
            'company_name': row['company_name'],
            'total_due': float(row['total_due'] or 0),
            'advance_payment': float(row['advance_payment'] or 0),
        })
    except Exception as exc:
        print(f"Error fetching company info: {str(exc)}", file=sys.stderr)
        return jsonify({'success': False, 'message': str(exc)}), 500
    finally:
        cursor.close()
        conn.close()

@app.route('/get_company_due/<int:company_id>')
@login_required
def get_company_due(company_id):
    """AJAX helper: return the outstanding ``total_due`` of one company.

    Responds 404 for an unknown id and 500 (with the error text) on failure.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("SELECT total_due FROM companies WHERE id = %s", (company_id,))
        row = cursor.fetchone()

        if not row:
            return jsonify({'success': False, 'message': 'Company not found'}), 404

        due = float(row['total_due'] or 0)
        return jsonify({'success': True, 'total_due': due})
    except Exception as exc:
        return jsonify({'success': False, 'message': str(exc)}), 500
    finally:
        cursor.close()
        conn.close()

@app.route('/add_company_payment', methods=['POST'])
@login_required
def add_company_payment():
    """Record a payment to a company, with support for advance (over-)payments.

    Form fields: ``company_id``, ``amount``, ``payment_date`` and
    ``payment_method`` (required); ``notes`` and ``redirect_to`` (optional).

    The payment reduces the company's ``total_due``; any excess over the due is
    added to ``advance_payment``. The balances before and after the payment are
    stored on the new ``company_payments`` row as a snapshot. Afterwards the user is
    redirected to the shipments page when ``redirect_to == 'shipments'``, otherwise
    to the expenses page.
    """
    import math

    company_id = request.form['company_id']
    try:
        amount = float(request.form['amount'])
    except ValueError:
        amount = math.nan  # rejected by the validation below
    payment_date = request.form['payment_date']
    payment_method = request.form['payment_method']
    notes = request.form.get('notes', '').strip() or None

    # Only 'shipments' is honoured; anything else goes back to expenses.
    redirect_endpoint = (
        'shipments' if request.form.get('redirect_to', 'expenses') == 'shipments' else 'expenses'
    )

    # A zero/negative amount would leave the due unchanged or *increase* it, and
    # float() also accepts 'nan'/'inf'; none of these is a valid payment.
    if not math.isfinite(amount) or amount <= 0:
        flash('Payment amount must be a positive number.', 'danger')
        return redirect(url_for(redirect_endpoint))

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Lazy schema migration for older databases. One information_schema probe
        # covers every column; ALTER TABLE (which locks the table) runs only for
        # columns that are actually missing, and IF NOT EXISTS keeps it race-safe.
        required_columns = [
            ('companies', 'advance_payment',
             "ALTER TABLE companies ADD COLUMN IF NOT EXISTS advance_payment DECIMAL(10,2) DEFAULT 0.00"),
            ('company_payments', 'payment_date',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS payment_date DATE"),
            ('company_payments', 'notes',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS notes TEXT"),
            ('company_payments', 'due_before',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS due_before DECIMAL(10,2)"),
            ('company_payments', 'due_after',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS due_after DECIMAL(10,2)"),
            ('company_payments', 'advance_before',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS advance_before DECIMAL(10,2)"),
            ('company_payments', 'advance_after',
             "ALTER TABLE company_payments ADD COLUMN IF NOT EXISTS advance_after DECIMAL(10,2)"),
        ]
        cursor.execute("""
            SELECT table_name, column_name FROM information_schema.columns
            WHERE table_name IN ('companies', 'company_payments')
        """)
        existing = {(row['table_name'], row['column_name']) for row in cursor.fetchall()}
        missing_ddl = [ddl for table, column, ddl in required_columns
                       if (table, column) not in existing]
        for ddl in missing_ddl:
            cursor.execute(ddl)
        if missing_ddl:
            conn.commit()

        # FOR UPDATE locks the company row until commit, so two concurrent payments
        # cannot both read the same balance and overwrite each other's result.
        cursor.execute("""
            SELECT COALESCE(total_due, 0) as total_due,
                   COALESCE(advance_payment, 0) as advance_payment
            FROM companies
            WHERE id = %s
            FOR UPDATE
        """, (company_id,))
        company = cursor.fetchone()
        if company is None:
            conn.rollback()
            flash('Company not found!', 'danger')
            return redirect(url_for(redirect_endpoint))

        current_due = float(company['total_due'])
        current_adv = float(company['advance_payment'])

        # Generate receipt number
        cursor.execute("SELECT generate_receipt_number() as receipt_number")
        receipt_number = cursor.fetchone()['receipt_number']

        # Anything paid beyond the current due becomes an advance for future shipments.
        if amount > current_due:
            advance_amount = amount - current_due
            due_after = 0
            advance_after = current_adv + advance_amount
        else:
            advance_amount = 0
            due_after = max(current_due - amount, 0)
            advance_after = current_adv

        # Insert payment with its before/after balance snapshot.
        cursor.execute("""
            INSERT INTO company_payments 
            (receipt_number, company_id, payment_date, amount, payment_method, notes,
             due_before, due_after, advance_before, advance_after)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (receipt_number, company_id, payment_date, amount, payment_method, notes,
               current_due, due_after, current_adv, advance_after))

        # Persist new balances on company
        cursor.execute("""
            UPDATE companies 
            SET total_due = %s,
                advance_payment = %s
            WHERE id = %s
        """, (due_after, advance_after, company_id))

        conn.commit()

        # Flash only after a successful commit, so a failed commit never shows a
        # success message next to the error message.
        if advance_amount > 0:
            flash(f'Payment recorded successfully! Receipt: {receipt_number}. Advance of BDT {advance_amount:,.2f} stored for future shipments.', 'success')
        else:
            flash(f'Payment recorded successfully! Receipt: {receipt_number}', 'success')
    except Exception as e:
        conn.rollback()
        flash(f'Error recording payment: {str(e)}', 'danger')
        print(f"Payment error details: {str(e)}", file=sys.stderr)  # Debug logging
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for(redirect_endpoint))

@app.route('/get_payment_details/<int:payment_id>')
@login_required
def get_payment_details(payment_id):
    """Get details for a specific payment (AJAX for view modal).

    Returns the payment row, its company's contact details, the company's
    *current* due/advance, and the before/after balance snapshot stored on the
    payment. Rows created before the ``payment_date``/snapshot columns existed
    may hold NULLs: dates are then returned as ``None`` and amounts as 0.0,
    instead of failing with HTTP 500. Unknown ids give 404.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        cursor.execute("""
            SELECT cp.*, c.company_name, c.contact_person, c.phone, c.email,
                   COALESCE(c.total_due, 0) as current_due,
                   COALESCE(c.advance_payment, 0) as current_advance
            FROM company_payments cp
            JOIN companies c ON cp.company_id = c.id
            WHERE cp.id = %s
        """, (payment_id,))
        payment = cursor.fetchone()

        if not payment:
            return jsonify({'success': False, 'message': 'Payment not found'}), 404

        def _date(key, pattern):
            # NULL (or absent) date/timestamp columns become None instead of crashing.
            value = payment.get(key)
            return value.strftime(pattern) if value is not None else None

        def _money(key):
            # NULL (or absent) numeric columns become 0.0.
            return float(payment.get(key) or 0)

        return jsonify({
            'success': True,
            'payment': {
                'id': payment['id'],
                'receipt_number': payment['receipt_number'],
                'company_name': payment['company_name'],
                'contact_person': payment['contact_person'],
                'phone': payment['phone'],
                'email': payment['email'],
                'payment_date': _date('payment_date', '%Y-%m-%d'),
                'amount': _money('amount'),
                'payment_method': payment['payment_method'],
                'notes': payment.get('notes'),
                'current_due': _money('current_due'),
                'current_advance': _money('current_advance'),
                'due_before': _money('due_before'),
                'due_after': _money('due_after'),
                'advance_before': _money('advance_before'),
                'advance_after': _money('advance_after'),
                'created_at': _date('created_at', '%Y-%m-%d %H:%M:%S')
            }
        })
    except Exception as e:
        return jsonify({'success': False, 'message': str(e)}), 500
    finally:
        cursor.close()
        conn.close()

@app.route('/delete_company_payment/<int:payment_id>', methods=['POST'])
@login_required
def delete_company_payment(payment_id):
    """Delete a company payment and reverse its effect on the company's balances.

    When a payment is recorded, part of it may reduce ``total_due`` and any excess
    is stored as ``advance_payment``; the payment row keeps ``due_before`` and
    ``due_after`` as a snapshot. Deleting the payment reverses that split:
      * the part that reduced the due is added back to ``total_due``;
      * the excess is taken back out of ``advance_payment``. If some of that
        advance is no longer available, the unavailable remainder is added to
        ``total_due`` instead.
    In total, the company's obligations grow by exactly the payment amount.
    Payments without a due snapshot (NULL columns) are treated as fully applied
    to the due, which matches the previous behaviour.
    """
    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # SELECT * rather than named snapshot columns, so databases without those
        # columns still work; .get() then returns None.
        cursor.execute("SELECT * FROM company_payments WHERE id = %s", (payment_id,))
        payment = cursor.fetchone()

        if payment:
            company_id = payment['company_id']
            amount = float(payment['amount'] or 0)
            due_before = payment.get('due_before')
            due_after = payment.get('due_after')

            if due_before is not None and due_after is not None:
                # Portion that actually reduced the due, clamped to [0, amount].
                restore_to_due = round(min(max(float(due_before) - float(due_after), 0.0), amount), 2)
            else:
                restore_to_due = amount
            excess = round(amount - restore_to_due, 2)
            restored_from_advance = 0.0

            # Delete payment
            cursor.execute("DELETE FROM company_payments WHERE id = %s", (payment_id,))

            if excess > 0:
                # Lock the company row so the advance cannot change between the
                # read and the write below.
                cursor.execute("""
                    SELECT COALESCE(advance_payment, 0) AS advance_payment
                    FROM companies
                    WHERE id = %s
                    FOR UPDATE
                """, (company_id,))
                company = cursor.fetchone()
                available = max(float(company['advance_payment']), 0.0) if company else 0.0
                restored_from_advance = min(excess, available)
                # Advance already consumed elsewhere is owed again as due.
                restore_to_due += excess - restored_from_advance
                cursor.execute("""
                    UPDATE companies
                    SET advance_payment = COALESCE(advance_payment, 0) - %s
                    WHERE id = %s
                """, (restored_from_advance, company_id))

            cursor.execute("""
                UPDATE companies 
                SET total_due = COALESCE(total_due, 0) + %s
                WHERE id = %s
            """, (restore_to_due, company_id))

            conn.commit()
            if restored_from_advance > 0:
                flash('Payment deleted successfully; company due and advance balances restored!', 'success')
            else:
                flash('Payment deleted successfully and amount restored to company due!', 'success')
        else:
            flash('Payment not found!', 'danger')
    except Exception as e:
        conn.rollback()
        flash(f'Error deleting payment: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('expenses'))

@app.route('/edit_company_payment/<int:payment_id>', methods=['POST'])
@login_required
def edit_company_payment(payment_id):
    """Edit a company payment and keep the affected companies' ``total_due`` in sync.

    Form fields: ``company_id``, ``amount``, ``payment_date``, ``payment_method``
    and ``notes``. The payment row (including its company) is overwritten, then:
      * same company: its due moves by the difference between the old and new
        amount (paying more lowers the due);
      * reassigned company: the old company gets the old amount added back to its
        due, and the new company's due is reduced by the new amount.
    The amount must be a positive, finite number.

    NOTE(review): ``advance_payment`` and the payment's balance snapshot columns
    (due_before/after, advance_before/after) are not recalculated here — confirm
    whether edits of over-payments need to adjust the advance as well.
    """
    import math

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Get the old payment details
        cursor.execute("""
            SELECT company_id, amount 
            FROM company_payments 
            WHERE id = %s
        """, (payment_id,))
        old_payment = cursor.fetchone()

        if not old_payment:
            flash('Payment not found!', 'danger')
            return redirect(url_for('expenses'))

        old_company_id = old_payment['company_id']
        old_amount = float(old_payment['amount'] or 0)

        # Get new payment details from form
        company_id = int(request.form['company_id'])
        new_amount = float(request.form['amount'])
        payment_date = request.form['payment_date']
        payment_method = request.form['payment_method']
        notes = request.form.get('notes', '')

        # float() accepts 'nan'/'inf'; a non-positive payment would raise the due.
        if not math.isfinite(new_amount) or new_amount <= 0:
            conn.rollback()
            flash('Payment amount must be a positive number.', 'danger')
            return redirect(url_for('expenses'))

        # Update the payment record, including its company, so the ledger row and
        # the balance adjustments below refer to the same company.
        cursor.execute("""
            UPDATE company_payments 
            SET company_id = %s, amount = %s, payment_date = %s,
                payment_method = %s, notes = %s
            WHERE id = %s
        """, (company_id, new_amount, payment_date, payment_method, notes, payment_id))

        if company_id == old_company_id:
            # More paid -> less due; less paid -> more due.
            cursor.execute("""
                UPDATE companies 
                SET total_due = COALESCE(total_due, 0) - %s
                WHERE id = %s
            """, (new_amount - old_amount, company_id))
        else:
            # Payment moved: the old company owes the old amount again ...
            cursor.execute("""
                UPDATE companies 
                SET total_due = COALESCE(total_due, 0) + %s
                WHERE id = %s
            """, (old_amount, old_company_id))
            # ... and the new company's due is reduced by the new amount.
            cursor.execute("""
                UPDATE companies 
                SET total_due = COALESCE(total_due, 0) - %s
                WHERE id = %s
            """, (new_amount, company_id))
            if cursor.rowcount == 0:
                conn.rollback()
                flash('Company not found!', 'danger')
                return redirect(url_for('expenses'))

        conn.commit()
        flash('Payment updated successfully!', 'success')

    except Exception as e:
        conn.rollback()
        flash(f'Error updating payment: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('expenses'))

@app.route('/add_payment', methods=['POST'])
@login_required
def add_payment():
    """Record an additional payment from a customer and spread it over their dues.

    Form fields: ``buyer_name`` and ``payment_amount``. The amount must be a
    positive, finite number no larger than the customer's total outstanding due.
    It is applied to the oldest sale first (FIFO by ``sale_date``, then ``id``).
    Each sale's ``paid_amount`` is raised and its ``due_amount`` lowered until the
    payment is used up. The route always redirects to the sales page with a
    flash message.

    NOTE(review): the form's payment_date / payment_method / payment_notes fields
    are not stored anywhere by this route; persisting them would need a
    payments table.
    """
    import math

    buyer_name = request.form.get('buyer_name', '').strip()
    try:
        payment_amount = float(request.form.get('payment_amount', 0) or 0)
    except ValueError:
        payment_amount = math.nan  # rejected by the validation below

    # float() accepts 'nan'/'inf'; NaN would slip past a plain "<= 0" check.
    if not buyer_name or not math.isfinite(payment_amount) or payment_amount <= 0:
        flash('Invalid payment details', 'danger')
        return redirect(url_for('sales'))

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        # Oldest sales first (FIFO). FOR UPDATE locks these rows so a concurrent
        # payment for the same customer cannot read the same dues and double-apply.
        cursor.execute("""
            SELECT s.id,
                   COALESCE(s.paid_amount, 0) as paid_amount,
                   COALESCE(s.due_amount, 0) as due_amount
            FROM sales s
            WHERE s.buyer_name = %s
            ORDER BY s.sale_date ASC, s.id ASC
            FOR UPDATE
        """, (buyer_name,))
        sales = cursor.fetchall()

        if not sales:
            flash(f'No sales found for {buyer_name}', 'warning')
            return redirect(url_for('sales'))

        # Round to cents so float noise from summing many dues does not reject a
        # payment that exactly settles the balance.
        total_due = round(sum(float(s['due_amount'] or 0) for s in sales), 2)

        if total_due <= 0:
            flash(f'No outstanding balance for {buyer_name}', 'info')
            return redirect(url_for('sales'))

        if round(payment_amount, 2) > total_due:
            flash(f'Payment amount exceeds due balance. Due: AED {total_due:.2f}', 'warning')
            return redirect(url_for('sales'))

        remaining_payment = payment_amount

        for sale in sales:
            if remaining_payment <= 0:
                break

            current_due = float(sale['due_amount'] or 0)
            if current_due <= 0:
                continue

            # Apply as much of the payment as this sale can absorb.
            amount_to_reduce = min(remaining_payment, current_due)
            new_due = max(current_due - amount_to_reduce, 0)
            new_paid = float(sale['paid_amount'] or 0) + amount_to_reduce

            cursor.execute("""
                UPDATE sales
                SET paid_amount = %s, due_amount = %s
                WHERE id = %s
            """, (new_paid, new_due, sale['id']))

            remaining_payment -= amount_to_reduce

        conn.commit()
        flash(f'Payment of AED {payment_amount:.2f} recorded for {buyer_name}', 'success')

    except Exception as e:
        conn.rollback()
        flash(f'Error recording payment: {str(e)}', 'danger')
    finally:
        cursor.close()
        conn.close()

    return redirect(url_for('sales'))

@app.route('/api/vat_summary')
@login_required
def get_vat_summary():
    """Return sales and VAT totals as JSON, optionally limited to a sale-date range.

    Query args ``date_from`` and ``date_to`` are inclusive bounds on
    ``sale_date``; an empty or absent arg leaves that side of the range open.
    Only fixed condition strings are spliced into the SQL text; the date values
    themselves are always sent as bound parameters. Failures give HTTP 500.
    """
    date_from = request.args.get('date_from', '')
    date_to = request.args.get('date_to', '')

    conn = get_db_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)

    try:
        conditions = ["1=1"]
        bind_values = []
        for condition, value in (("s.sale_date >= %s", date_from),
                                 ("s.sale_date <= %s", date_to)):
            if value:
                conditions.append(condition)
                bind_values.append(value)
        where_sql = " AND ".join(conditions)

        sql = f"""
            SELECT 
                COUNT(DISTINCT CONCAT(s.buyer_name, '|', s.sale_date)) as sales_count,
                COALESCE(SUM(s.total_price), 0) as subtotal,
                COALESCE(SUM(s.vat_amount), 0) as total_vat,
                COALESCE(SUM(s.total_price + COALESCE(s.vat_amount, 0)), 0) as grand_total
            FROM sales s
            WHERE {where_sql}
        """
        cursor.execute(sql, tuple(bind_values))
        totals = cursor.fetchone()

        return jsonify({
            'success': True,
            'sales_count': int(totals['sales_count'] or 0),
            'subtotal': float(totals['subtotal'] or 0),
            'total_vat': float(totals['total_vat'] or 0),
            'grand_total': float(totals['grand_total'] or 0),
            'date_from': date_from,
            'date_to': date_to,
        })

    except Exception as exc:
        return jsonify({
            'success': False,
            'message': f'Error calculating VAT summary: {str(exc)}'
        }), 500

    finally:
        cursor.close()
        conn.close()

@app.context_processor
def inject_user():
    """Expose session login details and today's date (YYYY-MM-DD) to every template."""
    logged_in = 'loggedin' in session
    return dict(
        user=session.get('username'),
        user_type=session.get('user_type'),
        is_logged_in=logged_in,
        today=datetime.now().strftime('%Y-%m-%d'),
    )

# Create the upload folder at import time. exist_ok=True makes this a no-op when
# the folder already exists. It also avoids the check-then-create race when several
# worker processes import the app at once.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

if __name__ == '__main__':
    # Development server only. The Werkzeug debugger lets anyone who can reach it
    # execute arbitrary Python, so it must not be on by default while listening on
    # every interface. Opt in locally with FLASK_DEBUG=1; HOST/PORT override the
    # bind address (defaults unchanged: 0.0.0.0:5000).
    _debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=_debug,
            host=os.environ.get('HOST', '0.0.0.0'),
            port=int(os.environ.get('PORT', '5000')))