from langchain_openai import AzureOpenAIEmbeddings, AzureOpenAI, AzureChatOpenAI
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_core.messages import HumanMessage, AIMessage
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
import os
import json
import transport

# from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

import cms
import uuid
from multiprocessing import Process 
import pandas as pd
import copy

# from gtts import gTTS
import io
import requests

# USE_OPENAI = os.environ.get('USE_OPENAI',0)

class ILLM:
    """
    Thin wrapper around an LLM backend (Ollama or Azure OpenAI) handling
    embeddings, document retrieval and retrieval-augmented answers
    """
    def __init__(self, **_args):
        self.USE_OPENAI = 'openai' in _args
        #
        # The 'openai' or 'ollama' entry points to a JSON file holding the
        # backend configuration (embedding and completion parameters)
        _path = _args['openai'] if self.USE_OPENAI else _args['ollama']
        with open(_path) as f:
            self._args = json.loads(f.read())
        self._prompt = _args['prompt'] if 'prompt' in _args else {}
        self._token = _args['token']
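    #
    # A minimal sketch of what the constructor expects (illustrative only; the
    # actual keys, paths and file layout depend on the deployment):
    #
    #   _config = {
    #       'ollama': '/path/to/ollama.json',   # or 'openai': '/path/to/openai.json'
    #       'prompt': {'template': '...', 'input_variables': ['context', 'question', 'chat_history']},
    #       'token':  '<conversation-token>'
    #   }
    #
    # The referenced JSON file is assumed to look like:
    #   {"embedding": { ...kwargs for the embeddings class... },
    #    "completion": { ...kwargs for the chat/LLM class... }}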
    def embed(self, _question):
        """
        Turn a question into an embedding vector using the configured backend
        """
        _pointer = AzureOpenAIEmbeddings if self.USE_OPENAI else OllamaEmbeddings
        _kwargs = self._args if 'embedding' not in self._args else self._args['embedding']
        _handler = _pointer(**_kwargs)
        return _handler.embed_query(_question)

    def answer(self, _question, _context):
        """
        Answer a question against the configured LLM backend (Ollama or Azure OpenAI),
        using the retrieved documents as context and prior exchanges as chat history
        """
        _pointer = AzureChatOpenAI if self.USE_OPENAI else OllamaLLM
        _kwargs = self._args if 'completion' not in self._args else self._args['completion']

        _prompt = PromptTemplate(**self._prompt)
        _llm = _pointer(**_kwargs)
        _memory = ConversationBufferMemory(memory_key=f'{self._token}', return_messages=True)
        #
        # Rebuild the conversation history for this token from the log table,
        # so previous questions/answers are submitted along with the new question
        _schema = 'openai' if self.USE_OPENAI else 'ollama'
        pgr = transport.get.reader(label='llm', schema=_schema)
        _sql = f"select question, answer::JSON ->>'summary' as answer from llm_logs where token = '{self._token}' ORDER BY _date DESC LIMIT 10"
        _ldf = pgr.read(sql=_sql)
        pgr.close()
        _ldf.apply(lambda row: [_memory.chat_memory.add_user_message(row.question), _memory.chat_memory.add_ai_message(row.answer)], axis=1)

        chain = (
                RunnablePassthrough.assign(
                    context=lambda _x: _context,
                    chat_history=lambda _x: _memory.chat_memory.messages if _memory.chat_memory.messages else []
                )
                | _prompt
                | _llm
                | StrOutputParser()
            )
        resp = chain.invoke({'question': _question})

        return {'text': resp}
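    #
    # The chain above expects the prompt template to expose 'context', 'question'
    # and 'chat_history' variables. A minimal sketch of a compatible 'prompt'
    # entry (illustrative only, not the production template):
    #
    #   _prompt_config = {
    #       'input_variables': ['context', 'question', 'chat_history'],
    #       'template': (
    #           "Use the following context to answer the question.\n"
    #           "Context: {context}\n"
    #           "History: {chat_history}\n"
    #           "Question: {question}\n"
    #       )
    #   }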
    def schema(self):
        return 'openai' if self.USE_OPENAI else 'public'
    def documents(self, _vector):
        """
        Return the 5 documents closest to the question embedding, using the
        pgvector distance operator (<->) against the {schema}.documents table
        """
        _schema = 'openai' if self.USE_OPENAI else 'ollama'
        pgr = transport.get.reader(label='llm', schema=_schema)
        sql = f"""SELECT file, name, page, content, embeddings <-> '{json.dumps(_vector)}' similarity FROM {_schema}.documents
        ORDER BY similarity ASC
        LIMIT 5
        """
        _df = pgr.read(sql=sql)
        pgr.close()

        return _df
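    #
    # A sketch of the table the similarity query assumes (illustrative; the
    # actual DDL, schema name and vector dimension depend on the deployment):
    #
    #   CREATE TABLE ollama.documents (
    #       file        TEXT,
    #       name        TEXT,
    #       page        INTEGER,
    #       content     TEXT,
    #       embeddings  VECTOR(768)   -- pgvector column queried with the <-> operator
    #   );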
    def lookup(self, index: int, token):
        """
        Return the answer logged at the given position (0-based) for a conversation token,
        or None when no matching row is found
        """
        _schema = 'openai' if self.USE_OPENAI else 'ollama'
        pgr = transport.get.reader(label='llm', schema=_schema)
        index = int(index) + 1  # row_number() is 1-based
        _sql = f"SELECT * FROM (select row_number() over(partition by token) as row_index, answer from llm_logs where token='{token}') as _x where row_index = {index}"
        _df = pgr.read(sql=_sql)
        pgr.close()
        return _df.answer[0] if _df.shape[0] > 0 else None
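#
# A minimal usage sketch of the ILLM class outside the CMS plugins (illustrative;
# the configuration path, prompt and token are placeholders):
#
#   _llm = ILLM(ollama='/path/to/ollama.json',
#               prompt={'input_variables': ['context', 'question', 'chat_history'], 'template': '...'},
#               token=str(uuid.uuid4()))
#   _vector = _llm.embed('What are the side effects?')
#   _df     = _llm.documents(_vector)                      # top-5 similar pages
#   _resp   = _llm.answer('What are the side effects?', _df.content.tolist())
#   print(_resp['text'])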
@cms.Plugin(mimetype="application/json",method="POST")
def answer (**_args):
    
    _request = _args['request']
    _config = _args['config']['system']['source']['llm']

    _question = _request.json['question']
    token = str(uuid.uuid4()) if 'token' not in _request.json else _request.json['token']
    _index = _request.json['index'] if 'index' in _request.json else 0 
    _config['token'] = token
    _llmproc = ILLM(**_config)
    

    #
    # Turn the question into a vector and retrieve the closest documents
    #
    _vector = _llmproc.embed(_question)
    _df = _llmproc.documents(_vector)
    _pages = _df.content.tolist()

    #
    # Submit the question to the llm-server along with the retrieved pages;
    # previous questions/answers for this token are pulled inside ILLM.answer
    #
    resp = _llmproc.answer(_question, _pages)
    _out = {"token": token, "openai": _llmproc.USE_OPENAI, "answer": resp["text"], "documents": _df[["name", "page"]].to_dict(orient='records')}

    try:
        def _logger():
            #
            # Log the question, answer and source documents to llm_logs
            _log = pd.DataFrame([dict(_out, **{'question': _question})])
            _log['documents'] = _df[["name", "page"]].to_json(orient='records')
            pgw = transport.get.writer(label='llm', table='llm_logs')
            pgw.write(_log)
            pgw.close()
        #
        # Log in a separate process so the response is not delayed
        pthread = Process(target=_logger)
        pthread.start()
        #
        # The answer is expected to be a JSON document (e.g. with a 'summary' key);
        # parse it so the client receives an object rather than a string
        _out['answer'] = json.loads(_out['answer'])
    except Exception as e:
        print(e)
    _context = _args['config']['system']['context'].strip()
    _out['stream'] = f'{_context}/api/medix/audio?token={token}&index={_index}'
    return _out
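#
# A sketch of the request/response contract for the answer plugin (illustrative;
# the route path is assumed, the actual URL comes from the CMS plugin registration):
#
#   POST <context>/api/medix/answer
#   {"question": "What are the side effects?", "token": "<optional-uuid>", "index": 0}
#
#   Response:
#   {"token": "<uuid>", "openai": false, "answer": { ...parsed model output... },
#    "documents": [{"name": "...", "page": 1}],
#    "stream": "<context>/api/medix/audio?token=<uuid>&index=0"}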

@cms.Plugin(mimetype="text/plain",method="POST")
def info (**_args):
    _config = _args['config']
    return 'openai' if 'openai' in _config['system']['source']['llm'] else 'ollama'
@cms.Plugin(mimetype="audio/mpeg",method="GET")
def audio (**_args):
    _request = _args['request']
    _config = _args['config']['system']['source']['llm']
    _index = _request.args['index']
    _token = _request.args['token']
    _config['token'] = _token

    _llmproc = ILLM(**_config)
    _stream = _llmproc.lookup(_index,_token)
    if _stream.strip().startswith('{') :
        _text = json.loads(_stream)['summary']
    else:
        _text = _stream.strip()
    _text = _text.replace('\n',' ').replace('\r',' ').replace('  ',' ').strip()
    
    r = requests.post(f"http://localhost:5002/api/tts",headers={"text":f"""{_text}""","Content-Type":"application/json"},stream=True)
    # r = requests.get(f"""http://localhost:5000/api/tts?text={_text}""")
    # f = open('/home/steve/tmp/out.wav','w')
    # f.write(r.content)
    # f.close()
    return r.content
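#
# A sketch of how a client consumes the audio endpoint above (illustrative; the
# route prefix comes from the site's 'context' configuration):
#
#   GET <context>/api/medix/audio?token=<uuid>&index=0   ->   audio/mpeg bytes
#   (token and index are the values returned by the answer plugin)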
   