import html
import io
import json
import os
import tempfile
import time
import uuid
from datetime import datetime

# Third-party: AWS SDK, document parsers, and the UI framework
import boto3
import docx
import PyPDF2
import streamlit as st

# Page-level settings: browser tab title, wide layout, sidebar open on load,
# and a custom entry for the "About" item of the app menu.
page_settings = {
    "page_title": "Document Analysis with Bedrock Agent",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
    "menu_items": {'About': "Document Analysis powered by Amazon Bedrock"},
}
st.set_page_config(**page_settings)

# Stylesheet for the CSS classes referenced by the HTML snippets this script
# renders via st.markdown(..., unsafe_allow_html=True).
custom_css = """
<style>
    .main-header {
        font-size: 2.5rem;
        color: #232F3E;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.5rem;
        color: #232F3E;
        margin-top: 2rem;
    }
    .file-info {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
    }
    .user-message {
        background-color: #E9F5FD;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 0.5rem;
    }
    .assistant-message {
        background-color: #F7F7F7;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 0.5rem;
    }
    .footer {
        margin-top: 3rem;
        text-align: center;
        color: #666;
    }
    .stButton>button {
        background-color: #FF9900;
        color: white;
    }
    .config-section {
        background-color: #f9f9f9;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 1.5rem;
        border-left: 3px solid #FF9900;
    }
    .chat-container {
        margin-bottom: 5rem;
    }
    .truncation-notice {
        background-color: #FFECB3;
        padding: 0.5rem;
        border-radius: 0.3rem;
        font-size: 0.8rem;
        margin-top: 0.5rem;
    }
    .production-note {
        background-color: #E1F5FE;
        padding: 0.8rem;
        border-radius: 0.3rem;
        border-left: 3px solid #039BE5;
        margin: 1rem 0;
        font-size: 0.9rem;
    }
    .file-size-error {
        background-color: #FFCDD2;
        padding: 0.8rem;
        border-radius: 0.3rem;
        border-left: 3px solid #D32F2F;
        margin: 1rem 0;
    }
</style>
"""
st.markdown(custom_css, unsafe_allow_html=True)

# Seed st.session_state with defaults. Keys that already exist (i.e. on reruns
# within the same browser session) are left untouched.
if 'session_id' not in st.session_state:
    # Handled separately so a new UUID is only generated when it is needed
    st.session_state['session_id'] = str(uuid.uuid4())

_initial_state = {
    # Conversation and uploaded-document tracking
    'messages': [],
    'document_uploaded': False,
    'document_content': "",
    'document_name': "",
    'analysis_time': None,
    'is_truncated': False,
    'original_length': 0,
    # Default AWS region and agent configuration
    'aws_region': "us-east-1",
    'agent_id': "D4LXEMGLL8",
    'agent_alias_id': "UIZTXNYTCP",
    'config_saved': False,
}
for _state_key, _state_default in _initial_state.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Limits enforced by this demo
MAX_FILE_SIZE_MB = 5  # uploads larger than this many megabytes are rejected
MAX_CHARS = 10_000  # only this many leading characters of a document are sent to the agent

def extract_text_from_pdf(file_content):
    """Return the text of every page of a PDF, one page per newline-terminated segment.

    Args:
        file_content: The raw bytes of the PDF file.

    Returns:
        A string made of each page's extracted text followed by "\\n", in page
        order. A PDF with no pages yields an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
    # Iterate the pages directly and join once instead of growing a string with
    # `+=`. The `or ''` is defensive: a falsy extraction result (for example a
    # page with no text layer) contributes an empty segment instead of raising.
    return "".join(f"{page.extract_text() or ''}\n" for page in pdf_reader.pages)

def extract_text_from_docx(file_content):
    """Return the paragraph text of a DOCX document, one paragraph per line.

    The document is parsed straight from memory. python-docx accepts a
    file-like object, so the upload never has to be written to a temporary
    file on disk (which previously could be left behind if the write failed).

    Args:
        file_content: The raw bytes of the .docx file.

    Returns:
        The text of every paragraph in ``Document.paragraphs`` joined with "\\n".
        NOTE(review): only body paragraphs are read here; text held in tables,
        headers or footers is presumably not included - confirm if that matters.
    """
    document = docx.Document(io.BytesIO(file_content))
    return "\n".join(paragraph.text for paragraph in document.paragraphs)

def _decode_text(raw):
    """Decode the bytes of a plain-text upload into a string.

    UTF-8 is tried first (a leading byte-order mark is stripped), UTF-16 is
    used when the data starts with a UTF-16 byte-order mark, and Latin-1 is the
    last resort because it can decode any byte sequence.
    """
    encoding = "utf-16" if raw.startswith((b"\xff\xfe", b"\xfe\xff")) else "utf-8-sig"
    try:
        return raw.decode(encoding)
    except UnicodeDecodeError:
        return raw.decode("latin-1")


def extract_text_from_file(uploaded_file):
    """Extract text from an uploaded file, dispatching on its extension.

    Args:
        uploaded_file: An object exposing ``name`` (the file name) and
            ``getvalue()`` (the file's bytes), such as the value returned by
            ``st.file_uploader``.

    Returns:
        The extracted text as a string.

    Raises:
        ValueError: If the extension is not pdf, docx, doc or txt.
    """
    # The text after the last dot, compared case-insensitively
    file_type = uploaded_file.name.rsplit('.', 1)[-1].lower()
    file_content = uploaded_file.getvalue()

    if file_type == 'pdf':
        return extract_text_from_pdf(file_content)
    if file_type in ('docx', 'doc'):
        # NOTE(review): python-docx presumably only reads the .docx format, so a
        # genuine legacy .doc file would fail inside the parser - confirm.
        return extract_text_from_docx(file_content)
    if file_type == 'txt':
        # Tolerant decoding: a BOM or a non-UTF-8 file no longer breaks the upload
        return _decode_text(file_content)
    raise ValueError(f"Unsupported file type: {file_type}")

@st.cache_resource(show_spinner=False)
def get_bedrock_agent_client(region):
    """Return a boto3 'bedrock-agent-runtime' client bound to *region*.

    Wrapped in ``st.cache_resource`` (spinner disabled) so Streamlit can reuse
    the client it built for a given region instead of constructing a new one
    on every script rerun.
    """
    runtime_client = boto3.client(
        service_name='bedrock-agent-runtime',
        region_name=region,
    )
    return runtime_client

def invoke_agent(input_text):
    """Send *input_text* to the configured Bedrock agent and return its reply.

    The region, agent ID, agent alias ID and session ID are read from
    ``st.session_state``. A "Processing..." spinner is shown while the call
    runs and the streamed completion is consumed.

    Args:
        input_text: The prompt to send to the agent.

    Returns:
        The parsed object when the complete reply is valid JSON, otherwise the
        reply as a string. If the response carries no iterable ``completion``,
        the string "No completion found in response" is returned. Any exception
        is reported as the string ``"Error: <message>"`` instead of being raised.
    """
    try:
        # Get the client with current region
        bedrock_agent = get_bedrock_agent_client(st.session_state.aws_region)

        with st.spinner("Processing..."):
            response = bedrock_agent.invoke_agent(
                agentId=st.session_state.agent_id,
                agentAliasId=st.session_state.agent_alias_id,
                sessionId=st.session_state.session_id,
                inputText=input_text,
            )

            if 'completion' not in response or not hasattr(response['completion'], '__iter__'):
                return "No completion found in response"

            # Collect the raw bytes of every chunk event and decode once at the
            # end: decoding chunk by chunk fails if a multi-byte UTF-8
            # character happens to be split across two chunks.
            raw_chunks = [
                event['chunk']['bytes']
                for event in response['completion']
                if isinstance(event, dict) and 'chunk' in event and 'bytes' in event['chunk']
            ]
            full_response = b"".join(raw_chunks).decode('utf-8')

            # Try to parse as JSON for nicer display; plain text is returned as-is
            try:
                return json.loads(full_response)
            except ValueError:
                return full_response

    except Exception as e:
        # Top-level boundary: surface the failure in the chat rather than crash the app
        return f"Error: {str(e)}"

def save_configuration():
    """Flag the agent configuration as saved and begin a fresh conversation.

    Changing the configuration invalidates the current exchange, so a new
    session ID is generated and the message history is emptied.
    """
    state = st.session_state
    state['config_saved'] = True
    state['session_id'] = str(uuid.uuid4())
    state['messages'] = []

def reset_session():
    """Discard the uploaded document and conversation and issue a new session ID.

    The saved agent configuration (region, agent ID, alias ID) is not touched.
    """
    cleared_values = {
        'document_uploaded': False,
        'document_content': "",
        'document_name': "",
        'messages': [],
        'session_id': str(uuid.uuid4()),
        'analysis_time': None,
        'is_truncated': False,
        'original_length': 0,
    }
    for key, value in cleared_values.items():
        st.session_state[key] = value

# Sidebar: agent configuration editor, saved-settings summary, and session info
with st.sidebar:
    st.image("https://a0.awsstatic.com/libra-css/images/logos/aws_logo_smile_1200x630.png", width=250)

    # Configuration section
    st.markdown("### Agent Configuration")

    # Keep the editor open until a configuration has been saved
    config_expanded = not st.session_state.config_saved
    with st.expander("Edit Configuration", expanded=config_expanded):
        st.markdown('<div class="config-section">', unsafe_allow_html=True)

        # Supported regions: display name -> region code.
        # The dict order is the order shown in the select box.
        region_mapping = {
            "US East (N. Virginia)": "us-east-1",
            "US East (Ohio)": "us-east-2",
            "US West (Oregon)": "us-west-2",
        }
        region_options = list(region_mapping)

        # Pre-select the entry matching the stored region code (first entry if unknown)
        current_region_display = next(
            (display for display, code in region_mapping.items() if code == st.session_state.aws_region),
            region_options[0]
        )

        selected_region_display = st.selectbox(
            "AWS Region",
            options=region_options,
            index=region_options.index(current_region_display)
        )

        # Convert display name back to region code
        selected_region = region_mapping[selected_region_display]

        # Agent ID input
        agent_id = st.text_input("Agent ID", value=st.session_state.agent_id)

        # Agent Alias ID input
        agent_alias_id = st.text_input("Agent Alias ID", value=st.session_state.agent_alias_id)

        # Information about limits (taken from the constants so the text cannot drift)
        st.info(
            f"Document analysis is limited to {MAX_CHARS:,} characters "
            f"and file size is limited to {MAX_FILE_SIZE_MB}MB."
        )

        # Production recommendation note
        st.markdown(
            '<div class="production-note">'
            '<strong>Production Recommendation:</strong> This demo limits document size to prevent timeouts and ensure reliable performance. '
            'For production applications, implement document chunking with context management to process larger documents efficiently.'
            '</div>',
            unsafe_allow_html=True
        )

        # Save button
        if st.button("Save Configuration"):
            # Surrounding whitespace (common when pasting IDs) is dropped, and
            # blank IDs are refused so that "validated" is actually true.
            cleaned_agent_id = agent_id.strip()
            cleaned_alias_id = agent_alias_id.strip()

            if not cleaned_agent_id or not cleaned_alias_id:
                st.error("Agent ID and Agent Alias ID are both required.")
            else:
                st.session_state.aws_region = selected_region
                st.session_state.agent_id = cleaned_agent_id
                st.session_state.agent_alias_id = cleaned_alias_id

                save_configuration()
                st.success("Configuration validated and saved!")
                # Brief pause so the success message is visible before the rerun
                time.sleep(1)
                st.rerun()

        st.markdown('</div>', unsafe_allow_html=True)

    # Display current configuration
    if st.session_state.config_saved:
        st.markdown("#### Current Settings")
        # Friendly region name. Every selectable region has an entry here
        # (previously us-east-2 was shown as "Oregon"); an unknown code is
        # displayed as-is.
        region_short_names = {
            "us-east-1": "N. Virginia",
            "us-east-2": "Ohio",
            "us-west-2": "Oregon",
        }
        region_display = region_short_names.get(
            st.session_state.aws_region, st.session_state.aws_region
        )
        st.markdown(f"**Region:** {region_display}")
        st.markdown(f"**Agent ID:** {st.session_state.agent_id[:6]}...")
        st.markdown(f"**Agent Alias ID:** {st.session_state.agent_alias_id[:6]}...")
        st.markdown(f"**Character Limit:** {MAX_CHARS:,} (fixed)")
        st.markdown(f"**File Size Limit:** {MAX_FILE_SIZE_MB}MB")

    st.markdown("---")

    st.markdown("### Session Info")
    st.markdown(f"**Session ID:** {st.session_state.session_id[:8]}...")

    if st.session_state.document_uploaded:
        st.markdown(f"**Document:** {st.session_state.document_name}")
        if st.session_state.analysis_time:
            st.markdown(f"**Analyzed at:** {st.session_state.analysis_time.strftime('%H:%M:%S')}")

        # Word count of document
        word_count = len(st.session_state.document_content.split())
        st.markdown(f"**Document size:** {word_count} words")

        # Show truncation info if applicable (the length check avoids dividing by zero)
        original_chars = st.session_state.original_length
        if st.session_state.is_truncated and original_chars:
            percent_used = (MAX_CHARS / original_chars) * 100
            st.markdown(f"**Content used:** {percent_used:.1f}% ({MAX_CHARS:,}/{original_chars:,} chars)")

        # Add export conversation option
        if st.button("Export Conversation"):
            conversation = "".join(
                f"{msg['role'].upper()}: {msg['content']}\n\n"
                for msg in st.session_state.messages
            )

            st.download_button(
                label="Download Conversation",
                data=conversation,
                file_name=f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
                mime="text/plain"
            )

# Main content area: page title followed by a note on the demo's limitations
page_title_html = '<h1 class="main-header">Document Analysis with Amazon Bedrock</h1>'
st.markdown(page_title_html, unsafe_allow_html=True)

# Explains that the demo truncates documents and what production code should do instead
production_note_html = (
    '<div class="production-note">'
    '<strong>📝 Production Implementation Note:</strong> This demo uses simple truncation to limit document size and prevent timeouts. '
    'For production applications, implement document chunking with semantic segmentation and context management to process '
    'larger documents efficiently while maintaining context across the entire document.'
    '</div>'
)
st.markdown(production_note_html, unsafe_allow_html=True)

# Main flow: require a saved configuration, then show either the upload form
# or (once a document has been analyzed) the chat and preview tabs.
if not st.session_state.config_saved:
    st.warning("⚙️ Please configure your Bedrock Agent settings in the sidebar first.")
else:
    # Create two columns for upload section
    if not st.session_state.document_uploaded:
        col1, col2 = st.columns([2, 1])

        st.warning(
            f"⚠️ **IMPORTANT:** While the uploader shows 200MB limit, "
            f"files over {MAX_FILE_SIZE_MB}MB will be rejected by the code."
        )

        with col1:
            st.markdown("### Upload a document to analyze")
            uploaded_file = st.file_uploader(
                "Choose a document file",
                type=["txt", "pdf", "docx"],
                help=f"Supported formats: TXT, PDF, DOCX. Maximum file size: {MAX_FILE_SIZE_MB}MB"
            )

            # Add a clear note about file size limit
            st.caption(f"📌 Note: Files larger than {MAX_FILE_SIZE_MB}MB will be rejected by the code")

        with col2:
            st.markdown("### About")
            st.info("This application uses Amazon Bedrock to analyze documents and answer questions about their content.")

        # Display file information if uploaded
        if uploaded_file:
            # Check file size
            file_size_mb = len(uploaded_file.getvalue()) / (1024 * 1024)
            if file_size_mb > MAX_FILE_SIZE_MB:
                st.markdown(
                    f'<div class="file-size-error">'
                    f'⚠️ <strong>File too large:</strong> The uploaded file is {file_size_mb:.2f}MB, which exceeds the {MAX_FILE_SIZE_MB}MB limit. '
                    f'Please upload a smaller file.'
                    f'</div>',
                    unsafe_allow_html=True
                )
            else:
                # File is within size limit, display file info
                st.markdown('<div class="file-info">', unsafe_allow_html=True)
                st.write(f"**Filename:** {uploaded_file.name}")
                st.write(f"**Size:** {file_size_mb:.2f} MB")
                file_type = uploaded_file.name.split('.')[-1].upper()
                st.write(f"**Type:** {file_type}")
                st.markdown('</div>', unsafe_allow_html=True)

                # Process the document
                if st.button("Analyze Document", key="analyze_btn"):
                    with st.spinner(f"Processing {file_type} document... This may take a moment."):
                        try:
                            # Extract text based on file type
                            file_content = extract_text_from_file(uploaded_file)

                            # Nothing to analyze (e.g. a scanned PDF without a text
                            # layer): report it instead of sending an empty prompt.
                            if not file_content.strip():
                                raise ValueError("No readable text could be extracted from this document.")

                            # Check if content needs truncation
                            original_length = len(file_content)
                            is_truncated = original_length > MAX_CHARS

                            # Store truncation info
                            st.session_state.is_truncated = is_truncated
                            st.session_state.original_length = original_length

                            if is_truncated:
                                # Truncate content for API but keep full content for display
                                truncated_content = file_content[:MAX_CHARS]
                                prompt = f"Please analyze this {uploaded_file.name} document (truncated to first {MAX_CHARS:,} characters of {original_length:,} total):\n\n{truncated_content}"
                            else:
                                prompt = f"Please analyze this {uploaded_file.name} document:\n\n{file_content}"

                            # Save full document content to session state
                            st.session_state.document_content = file_content
                            st.session_state.document_name = uploaded_file.name
                            st.session_state.document_uploaded = True
                            st.session_state.analysis_time = datetime.now()

                            # Send to agent
                            response = invoke_agent(prompt)

                            # Add to conversation history (history entries are always strings)
                            st.session_state.messages.append({"role": "user", "content": f"Please analyze this document: {uploaded_file.name}"})
                            st.session_state.messages.append({"role": "assistant", "content": response if isinstance(response, str) else json.dumps(response)})

                            # Force a rerun to show the chat interface
                            st.rerun()

                        except Exception as e:
                            st.error(f"An error occurred: {str(e)}")

    # Chat interface once document is uploaded
    if st.session_state.document_uploaded:
        # Create tabs for different views
        tab1, tab2 = st.tabs(["Chat", "Document Preview"])

        # The length check keeps the percentage below from dividing by zero
        show_truncation = bool(st.session_state.is_truncated and st.session_state.original_length)

        with tab1:
            # The file name is user-supplied and lands inside raw HTML, so it is escaped
            safe_document_name = html.escape(st.session_state.document_name, quote=False)
            st.markdown(f'<h2 class="sub-header">Conversation about: {safe_document_name}</h2>', unsafe_allow_html=True)

            # Show truncation notice if applicable
            if show_truncation:
                st.markdown(
                    f'<div class="truncation-notice">⚠️ Note: The document was truncated to {MAX_CHARS:,} characters for analysis '
                    f'(out of {st.session_state.original_length:,} total characters). The AI has analyzed approximately '
                    f'{(MAX_CHARS / st.session_state.original_length) * 100:.1f}% of the document.</div>',
                    unsafe_allow_html=True
                )

                # Add chunking recommendation for large documents
                st.markdown(
                    '<div class="production-note">'
                    '<strong>💡 Tip:</strong> For better analysis of this large document in a production environment, '
                    'consider implementing document chunking with semantic segmentation to process the entire content '
                    'while maintaining context across sections.'
                    '</div>',
                    unsafe_allow_html=True
                )

            # Create a container for chat messages that appears BEFORE the input
            chat_container = st.container()

            # Place the chat input AFTER the message container but store it in a variable
            user_input = st.chat_input("Ask a question about the document...")

            # Now display all messages in the container
            with chat_container:
                for message in st.session_state.messages:
                    content = message["content"]
                    # User input and agent output are untrusted text rendered with
                    # unsafe_allow_html=True; escaping &, < and > stops them from
                    # injecting markup into the page.
                    safe_content = html.escape(str(content), quote=False)

                    if message["role"] == "user":
                        with st.chat_message("user", avatar="👤"):
                            st.markdown(f'<div class="user-message">{safe_content}</div>', unsafe_allow_html=True)
                    else:
                        with st.chat_message("assistant", avatar="🤖"):
                            # Replies that look like a JSON object get the JSON
                            # viewer; anything that fails to parse is shown as text.
                            parsed = None
                            if isinstance(content, str) and content.startswith("{"):
                                try:
                                    parsed = json.loads(content)
                                except ValueError:
                                    parsed = None

                            if parsed is not None:
                                st.json(parsed)
                            else:
                                st.markdown(f'<div class="assistant-message">{safe_content}</div>', unsafe_allow_html=True)

            # Handle user input after displaying messages
            if user_input:
                # Add user message to chat history
                st.session_state.messages.append({"role": "user", "content": user_input})

                # Get response from agent
                with st.spinner("Getting response..."):
                    response = invoke_agent(user_input)

                # Add assistant response to chat history
                st.session_state.messages.append({"role": "assistant", "content": response if isinstance(response, str) else json.dumps(response)})

                # Rerun to update the chat display with both new messages
                st.rerun()

        with tab2:
            st.markdown('<h2 class="sub-header">Document Content</h2>', unsafe_allow_html=True)

            # Show truncation notice in preview tab too
            if show_truncation:
                st.markdown(
                    f'<div class="truncation-notice">⚠️ Note: For AI analysis, only the first {MAX_CHARS:,} characters '
                    f'of this document were used (out of {st.session_state.original_length:,} total).</div>',
                    unsafe_allow_html=True
                )

            # PDF and Word previews get a note that the text shown was extracted;
            # plain-text files are shown without one.
            file_extension = st.session_state.document_name.split('.')[-1].lower()
            preview_notes = {
                'pdf': "Displaying extracted text from PDF document",
                'docx': "Displaying extracted text from Word document",
                'doc': "Displaying extracted text from Word document",
            }
            preview_note = preview_notes.get(file_extension)
            if preview_note:
                st.info(preview_note)

            st.text_area("Document Text", st.session_state.document_content, height=400, disabled=True)

    # Reset button to clear session and start over
    if st.session_state.document_uploaded:
        col1, col2 = st.columns([1, 4])
        with col1:
            if st.button("Start Over"):
                reset_session()
                st.rerun()
        with col2:
            if st.button("Edit Agent Configuration"):
                st.session_state.config_saved = False
                st.rerun()

# Footer rendered at the bottom of every page state
footer_html = '<div class="footer">Built with ❤️ using Amazon Bedrock and Streamlit</div>'
st.markdown(footer_html, unsafe_allow_html=True)