import logging
from datetime import datetime,timezone
from typing import Optional

from fastapi import HTTPException, status
from sqlalchemy import select, update, func, case, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload

from src.models.member.models import Member
from src.models.module.models import Module
from src.models.project.models import Project
from src.models.users.models import HrmsEmployeeProfile
from src.models.task.models import Task, TaskAssignedMember
from src.request.task.TaskRequest import TaskCreateRequest, TaskUpdateRequest, ProjectModelMemeberTaskResponse, TaskStatus

class TaskRepo:
    def __init__(self,session:AsyncSession) -> None:
        self.session = session

    async def create(self, task_data: TaskCreateRequest, user_id: int) -> Task:
        """Insert a task and its assigned members, then return the stored task.

        The task row and one ``TaskAssignedMember`` row per entry of
        ``task_data.assigned_members`` are written in the same transaction.

        Args:
            task_data: Request payload. Only fields that were explicitly set
                are copied onto the task; ``assigned_members`` is handled
                separately.
            user_id: Stored as ``created_by`` on the task and as
                ``created_by`` / ``updated_by`` on each assignment.

        Returns:
            The new task, re-selected with ``assigned_members`` joined-loaded.

        Raises:
            HTTPException: 500 when any step fails; the session is rolled
                back before the exception is raised.
        """
        try:
            # 1. Create the task record.
            task_fields = task_data.model_dump(
                exclude={"assigned_members"}, exclude_unset=True
            )
            record = Task(**task_fields)
            record.created_by = user_id
            self.session.add(record)
            await self.session.flush()  # assigns the new task id

            # Read the id now: commit() may expire loaded attributes, and
            # touching an expired attribute afterwards would trigger implicit
            # I/O, which fails with MissingGreenlet on an AsyncSession.
            task_id = record.id

            # 2. Add the assigned members (tolerates a missing/None list).
            self.session.add_all(
                [
                    TaskAssignedMember(
                        task_id=task_id,
                        hrms_user_id=member_data.hrms_user_id,
                        created_by=user_id,
                        updated_by=user_id,
                    )
                    for member_data in (task_data.assigned_members or [])
                ]
            )

            await self.session.commit()

            # 3. Re-select with the collection eagerly loaded so callers can
            # read assigned_members without a lazy load.
            result = await self.session.execute(
                select(Task)
                .options(joinedload(Task.assigned_members))
                .where(Task.id == task_id)
            )
            return result.scalars().first()

        except Exception as e:
            await self.session.rollback()
            # logger.exception records the full traceback.
            logging.getLogger(__name__).exception("Error during task creation")
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred during task creation: {e}"
            ) from e


    # This helper method is crucial and should be in your repo
    async def find_with_members(self, id: int) -> Optional[Task]:
        """Fetch the task whose ``id`` column equals *id*.

        ``assigned_members`` is joined-loaded in the same query. Returns
        ``None`` when the query yields no row.
        """
        outcome = await self.session.execute(
            select(Task)
            .where(Task.id == id)
            .options(joinedload(Task.assigned_members))
        )
        return outcome.scalars().first()

    async def update(self, id: int, update_data: TaskUpdateRequest, user_id: int) -> Task:
        """Apply a partial update to a task and optionally replace its members.

        Only fields explicitly set on ``update_data`` are copied onto the
        task. When ``update_data.assigned_members`` is not ``None``, every
        currently active assignment of the task is flagged inactive and one
        new assignment row is inserted per supplied member.

        Args:
            id: Id of the task to update.
            update_data: Request payload with the fields to change.
            user_id: Stored as ``updated_by`` on the task and on touched
                assignments, and as ``created_by`` on new assignments.

        Returns:
            The task re-read after commit, with ``assigned_members`` holding
            only the active assignments.

        Raises:
            HTTPException: 404 when no task has the given id; 500 for any
                other failure. The session is rolled back in both cases.
        """
        try:
            # Step 1: Find the record to update.
            record = await self.find_with_members(id)
            if record is None:
                raise HTTPException(status_code=404, detail=f"Record not found for id {id}")

            # One timestamp for the task and its assignments.
            now = datetime.now(timezone.utc)

            # Step 2: Update task attributes.
            task_updates = update_data.model_dump(
                exclude={"assigned_members"}, exclude_unset=True
            )
            task_updates["updated_by"] = user_id
            task_updates["updated_at"] = now
            for key, value in task_updates.items():
                setattr(record, key, value)

            # Step 3: Replace assigned members when a list was supplied.
            if update_data.assigned_members is not None:
                await self.session.execute(
                    update(TaskAssignedMember)
                    .where(TaskAssignedMember.task_id == id)
                    .where(TaskAssignedMember.is_active == True)
                    .values(is_active=False, updated_at=now, updated_by=user_id)
                )
                self.session.add_all(
                    [
                        TaskAssignedMember(
                            task_id=id,
                            hrms_user_id=member_data.hrms_user_id,
                            is_active=member_data.is_active,
                            created_by=user_id,
                            updated_by=user_id,
                        )
                        for member_data in update_data.assigned_members
                    ]
                )

            # Step 4: Commit the whole change set.
            await self.session.commit()

            # Step 5: Re-read the task with only its active assignments.
            # The filter lives in the loader option rather than in a Python
            # reassignment of record.assigned_members: assigning a shorter
            # list to a relationship on a persistent object registers the
            # dropped rows as removals for the next flush. populate_existing
            # forces the already-loaded collection to be replaced.
            result = await self.session.execute(
                select(Task)
                .options(
                    joinedload(
                        Task.assigned_members.and_(
                            TaskAssignedMember.is_active.is_(True)
                        )
                    )
                )
                .where(Task.id == id)
                .execution_options(populate_existing=True)
            )
            return result.unique().scalars().first()

        except HTTPException:
            # Deliberate HTTP errors (the 404 above) must reach the caller
            # unchanged instead of being rewritten as a 500.
            await self.session.rollback()
            raise
        except Exception as e:
            await self.session.rollback()
            logging.getLogger(__name__).exception("Error during task update")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"An error occurred during task update: {e}"
            ) from e
            
            
    async def find(self, id: int) -> Task | None:
        """Look up a single task by ``id`` with ``assigned_members`` joined-loaded.

        Returns ``None`` when the query yields no row.
        """
        query = select(Task).options(joinedload(Task.assigned_members))
        query = query.where(Task.id == id)
        outcome = await self.session.execute(query)
        return outcome.scalars().first()
    
    async def readAll(self) -> list[Task]:
        """Return every task with ``assigned_members`` joined-loaded.

        No ``is_active`` filter is applied. ``unique()`` collapses the
        repeated parent rows that the joined load produces.
        """
        outcome = await self.session.execute(
            select(Task).options(joinedload(Task.assigned_members))
        )
        tasks = outcome.scalars().unique()
        return tasks.all()
    
    async def get_task_counts(self):
        """Count active tasks overall and per status.

        Runs four scalar count queries, one after another, and returns a dict
        with the keys ``Total_Task``, ``Completed``, ``InProgress`` and ``New``.
        """
        active_count = select(func.count()).filter(Task.is_active == True)
        status_by_key = {
            "Completed": TaskStatus.COMPLETED,
            "InProgress": TaskStatus.INPROGRESS,
            "New": TaskStatus.NEW,
        }

        # The unfiltered total goes first, then one query per status.
        counts = {"Total_Task": await self.session.scalar(active_count)}
        for key, wanted_status in status_by_key.items():
            counts[key] = await self.session.scalar(
                active_count.filter(Task.status == wanted_status)
            )
        return counts
    
    async def get_user_wise_task_counts(self):
        """Return per-user task counts for active tasks and active assignments.

        One grouped query joins assignments to tasks and employee profiles.
        Each result entry has ``hrms_user_id``, ``full_name`` (trimmed
        first + last name, NULLs treated as empty) and the counters
        ``Total_Task``, ``Completed``, ``InProgress`` and ``New``.
        """
        def tally(condition, name):
            # COUNT ignores the NULL that CASE yields for non-matching rows.
            return func.count(case((condition, 1))).label(name)

        display_name = func.trim(
            func.coalesce(HrmsEmployeeProfile.first_name, '')
            + " "
            + func.coalesce(HrmsEmployeeProfile.last_name, '')
        ).label("full_name")

        only_active = (Task.is_active == True) & (TaskAssignedMember.is_active == True)

        query = (
            select(
                TaskAssignedMember.hrms_user_id.label('hrms_user_id'),
                display_name,
                tally(Task.is_active == True, "total"),
                tally(Task.status == TaskStatus.COMPLETED, "completed"),
                tally(Task.status == TaskStatus.INPROGRESS, "inprogress"),
                tally(Task.status == TaskStatus.NEW, "new"),
            )
            .join(Task, TaskAssignedMember.task_id == Task.id)
            .join(
                HrmsEmployeeProfile,
                HrmsEmployeeProfile.user_id == TaskAssignedMember.hrms_user_id,
            )
            .where(only_active)
            .group_by(
                TaskAssignedMember.hrms_user_id,
                HrmsEmployeeProfile.first_name,
                HrmsEmployeeProfile.last_name,
            )
        )

        outcome = await self.session.execute(query)
        records = outcome.mappings().all()

        # (output key, query label) pairs, in output order.
        key_map = (
            ("hrms_user_id", "hrms_user_id"),
            ("full_name", "full_name"),
            ("Total_Task", "total"),
            ("Completed", "completed"),
            ("InProgress", "inprogress"),
            ("New", "new"),
        )
        return [
            {out_key: record[label] for out_key, label in key_map}
            for record in records
        ]

    async def get_project_wise_task_counts(self):
        """Return per-project task counts for active tasks.

        Only active tasks that have at least one active assignment are
        considered. Each result entry has ``project_id``, ``project_name``
        and the counters ``Total_Task``, ``Completed``, ``InProgress`` and
        ``New``.
        """
        def distinct_tasks(condition, name):
            # The join to assignments yields one row per active assignee, so
            # a plain COUNT would count a task once per assignee. Counting
            # DISTINCT task ids counts each task once; CASE yields NULL for
            # non-matching rows, which COUNT ignores.
            return func.count(case((condition, Task.id)).distinct()).label(name)

        statement = (
            select(
                Project.id.label('project_id'),
                Project.name.label('project_name'),
                distinct_tasks(Task.is_active == True, "total"),
                distinct_tasks(Task.status == TaskStatus.COMPLETED, "completed"),
                distinct_tasks(Task.status == TaskStatus.INPROGRESS, "inprogress"),
                distinct_tasks(Task.status == TaskStatus.NEW, "new"),
            )
            .join(Module, Task.module_id == Module.id)
            .join(Project, Module.project_id == Project.id)
            .join(TaskAssignedMember, TaskAssignedMember.task_id == Task.id)
            .where(
                (Task.is_active == True)
                & (TaskAssignedMember.is_active == True)
            )
            .group_by(Project.id, Project.name)
        )

        result = await self.session.execute(statement)
        rows = result.mappings().all()

        return [
            {
                "project_id": row["project_id"],
                "project_name": row["project_name"],
                "Total_Task": row["total"],
                "Completed": row["completed"],
                "InProgress": row["inprogress"],
                "New": row["new"],
            }
            for row in rows
        ]
    
    async def task_list_by_project(self, project_id: int) -> dict:
        """Return one project's modules, tasks and assigned members as a nested dict.

        A single query joins project, module, task, active assignments and
        employee profiles; the flat rows are then grouped in Python. Because
        all joins are inner joins, modules without tasks and tasks without
        an active assignment do not appear.

        Args:
            project_id: Id of the project to report on.

        Returns:
            ``{"project_id", "project_name", "modules": [...]}`` where each
            module holds a ``tasks`` list and each task an
            ``assigned_members`` list. When the query yields no rows,
            ``{"message": "No data found for the given project ID."}`` is
            returned instead.
        """
        member_name = func.trim(
            func.coalesce(HrmsEmployeeProfile.first_name, '')
            + " "
            + func.coalesce(HrmsEmployeeProfile.last_name, '')
        ).label("am_name")

        stmt = (
            select(
                Project.id.label("project_id"),
                Project.name.label("project_name"),
                Module.id.label("module_id"),
                Module.title.label("module_name"),
                Task.id.label("task_id"),
                Task.title.label("task_name"),
                Task.description.label("task_description"),
                Task.priority.label("task_priority"),
                Task.status.label("task_status"),
                Task.start_date.label("task_start_date"),
                Task.end_date.label("task_end_date"),
                TaskAssignedMember.id.label("am_id"),
                TaskAssignedMember.hrms_user_id.label("am_user_id"),
                member_name,
            )
            .join(Module, Module.project_id == Project.id)
            .join(Task, Task.module_id == Module.id)
            .join(TaskAssignedMember, TaskAssignedMember.task_id == Task.id)
            .join(
                HrmsEmployeeProfile,
                HrmsEmployeeProfile.user_id == TaskAssignedMember.hrms_user_id,
            )
            .where(
                (Project.id == project_id)
                & (TaskAssignedMember.is_active == True)
                & (Task.is_active == True)
            )
            .order_by(Module.id, Task.id)
        )

        result = await self.session.execute(stmt)
        rows = result.mappings().all()

        if not rows:
            return {"message": "No data found for the given project ID."}

        # Group the flat rows: module id -> module, task id -> task.
        # Dicts keep insertion order, so the query's ordering is preserved.
        modules: dict[int, dict] = {}
        for row in rows:
            module = modules.setdefault(
                row["module_id"],
                {
                    "module_id": row["module_id"],
                    "module_name": row["module_name"],
                    "tasks": {},
                },
            )

            # Defensive: with the current inner joins these ids are never
            # NULL, but the guards keep the grouping valid for outer joins.
            task_id = row["task_id"]
            if task_id is None:
                continue

            task = module["tasks"].setdefault(
                task_id,
                {
                    "task_id": task_id,
                    "task_name": row["task_name"],
                    "task_description": row["task_description"],
                    "task_priority": row["task_priority"],
                    "task_status": row["task_status"],
                    "task_start_date": row["task_start_date"],
                    "task_end_date": row["task_end_date"],
                    "assigned_members": [],
                },
            )

            if row["am_id"] is not None:
                task["assigned_members"].append({
                    "am_id": row["am_id"],
                    "am_user_id": row["am_user_id"],
                    "am_name": row["am_name"],
                })

        # Convert the id-keyed task dicts to lists for the response.
        for module in modules.values():
            module["tasks"] = list(module["tasks"].values())

        return {
            "project_id": project_id,
            "project_name": rows[0]["project_name"],
            "modules": list(modules.values()),
        }
    
    
    async def task_status_generic(self, project_id: Optional[int] = None, user_id: Optional[int] = None, task_status: Optional[str] = None) -> list[dict]:
        """Return projects with their modules, tasks and assigned members, optionally filtered.

        Only active tasks with active assignments are included. Each filter
        argument is applied only when it is not ``None``.

        Args:
            project_id: Restrict the result to this project.
            user_id: Restrict the result to assignments of this HRMS user.
            task_status: Restrict the result to tasks whose ``status``
                equals this value.

        Returns:
            A list with one dict per matching project
            (``project_id``, ``project_name``, ``modules``); each module
            holds a ``tasks`` list and each task an ``assigned_members``
            list. An empty list when nothing matches.
        """
        member_name = func.trim(
            func.coalesce(HrmsEmployeeProfile.first_name, '')
            + " "
            + func.coalesce(HrmsEmployeeProfile.last_name, '')
        ).label("am_name")

        # Fixed criteria first, then the optional filters; all are ANDed.
        conditions = [
            Task.is_active.is_(True),
            TaskAssignedMember.is_active.is_(True),
        ]
        if project_id is not None:
            conditions.append(Project.id == project_id)
        if user_id is not None:
            conditions.append(TaskAssignedMember.hrms_user_id == user_id)
        if task_status is not None:
            conditions.append(Task.status == task_status)

        stmt = (
            select(
                Project.id.label("project_id"),
                Project.name.label("project_name"),
                Module.id.label("module_id"),
                Module.title.label("module_name"),
                Task.id.label("task_id"),
                Task.title.label("task_name"),
                Task.description.label("task_description"),
                Task.priority.label("task_priority"),
                Task.status.label("task_status"),
                Task.start_date.label("task_start_date"),
                Task.end_date.label("task_end_date"),
                TaskAssignedMember.id.label("am_id"),
                TaskAssignedMember.hrms_user_id.label("am_user_id"),
                member_name,
            )
            .join(Module, Module.project_id == Project.id)
            .join(Task, Task.module_id == Module.id)
            .join(TaskAssignedMember, TaskAssignedMember.task_id == Task.id)
            .join(
                HrmsEmployeeProfile,
                HrmsEmployeeProfile.user_id == TaskAssignedMember.hrms_user_id,
            )
            .where(*conditions)
            .order_by(Project.id, Module.id, Task.id)
        )

        result = await self.session.execute(stmt)
        rows = result.mappings().all()

        if not rows:
            return []

        # Group the flat rows: project id -> module id -> task id.
        # Dicts keep insertion order, so the query's ordering is preserved.
        projects: dict[int, dict] = {}
        for row in rows:
            project = projects.setdefault(
                row["project_id"],
                {
                    "project_id": row["project_id"],
                    "project_name": row["project_name"],
                    "modules": {},
                },
            )
            module = project["modules"].setdefault(
                row["module_id"],
                {
                    "module_id": row["module_id"],
                    "module_name": row["module_name"],
                    "tasks": {},
                },
            )

            # Defensive: with the current inner joins these ids are never
            # NULL, but the guards keep the grouping valid for outer joins.
            task_id = row["task_id"]
            if task_id is None:
                continue

            task = module["tasks"].setdefault(
                task_id,
                {
                    "task_id": task_id,
                    "task_name": row["task_name"],
                    "task_description": row["task_description"],
                    "task_priority": row["task_priority"],
                    "task_status": row["task_status"],
                    "task_start_date": row["task_start_date"],
                    "task_end_date": row["task_end_date"],
                    "assigned_members": [],
                },
            )

            if row["am_id"] is not None:
                task["assigned_members"].append({
                    "am_id": row["am_id"],
                    "am_user_id": row["am_user_id"],
                    "am_name": row["am_name"],
                })

        # Convert the id-keyed dicts to lists for the response.
        for project in projects.values():
            for module in project["modules"].values():
                module["tasks"] = list(module["tasks"].values())
            project["modules"] = list(project["modules"].values())

        # Always a list, even when a single project matches.
        return list(projects.values())
    