diff --git a/main_service.py b/main_service.py
index d184e86..91a38ed 100644
--- a/main_service.py
+++ b/main_service.py
@@ -47,6 +47,8 @@ def create_vana():
         "api_key": config('CHAT_MODEL_API_KEY', default=''),
         "api_base": config('CHAT_MODEL_BASE_URL', default=''),
         "model": config('CHAT_MODEL_NAME', default=''),
+        "temperature": config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
+        "max_tokens": config('CHAT_MODEL_MAX_TOKEN', default=20000, cast=int),
     },
 )
 
diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py
index 56a96c0..0de0f99 100644
--- a/service/cus_vanna_srevice.py
+++ b/service/cus_vanna_srevice.py
@@ -26,7 +26,7 @@ class OpenAICompatibleLLM(VannaBase):
     def __init__(self, client=None, config_file=None):
         VannaBase.__init__(self, config=config_file)
         # default parameters - can be overridden using config
-        self.temperature = 0.5
+        self.temperature = 0.6
         self.max_tokens = 5000
 
         if "temperature" in config_file:
@@ -175,6 +175,7 @@ class OpenAICompatibleLLM(VannaBase):
         print(
             f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
         )
+        print(self.config)
         response = self.client.chat.completions.create(
             model=self.config["model"],
             messages=prompt,
@@ -208,7 +209,6 @@ class OpenAICompatibleLLM(VannaBase):
             stop=None,
             temperature=self.temperature,
         )
-
         for choice in response.choices:
             if "text" in choice:
                 return choice.text
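
For reference, a minimal sketch of how these settings flow from the environment into the model wrapper, assuming python-decouple's config() (the same helper main_service.py already uses). Environment variables always arrive as strings, so the numeric settings need an explicit cast; the snippet mirrors the names in the patch but is illustrative, not part of it:

from decouple import config  # python-decouple, as used in main_service.py

from service.cus_vanna_srevice import OpenAICompatibleLLM

# Build the chat-model config from the environment. Without cast=float,
# CHAT_MODEL_TEMPERATURE="0.7" would reach the client as the string "0.7";
# cast=int likewise keeps CHAT_MODEL_MAX_TOKEN numeric when it is set.
chat_config = {
    "api_key": config('CHAT_MODEL_API_KEY', default=''),
    "api_base": config('CHAT_MODEL_BASE_URL', default=''),
    "model": config('CHAT_MODEL_NAME', default=''),
    "temperature": config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
    "max_tokens": config('CHAT_MODEL_MAX_TOKEN', default=20000, cast=int),
}

# The constructor keeps its own defaults (temperature=0.6, max_tokens=5000)
# and only overrides them when the keys are present in the supplied config.
llm = OpenAICompatibleLLM(config_file=chat_config)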