File tree 3 files changed +19
-1
lines changed
src/khoj/processor/conversation
3 files changed +19
-1
lines changed Original file line number Diff line number Diff line change 8
8
from khoj .database .models import Agent , ChatModel , KhojUser
9
9
from khoj .processor .conversation import prompts
10
10
from khoj .processor .conversation .openai .utils import (
11
+ ai_api_supports_json_enforcement ,
11
12
chat_completion_with_backoff ,
12
13
completion_with_backoff ,
13
14
)
14
15
from khoj .processor .conversation .utils import (
16
+ JsonSupport ,
15
17
clean_json ,
16
18
construct_structured_message ,
17
19
generate_chatml_messages_with_context ,
@@ -126,13 +128,14 @@ def send_message_to_model(
126
128
"""
127
129
128
130
# Get Response from GPT
131
+ json_support = ai_api_supports_json_enforcement (model , api_base_url )
129
132
return completion_with_backoff (
130
133
messages = messages ,
131
134
model_name = model ,
132
135
openai_api_key = api_key ,
133
136
temperature = temperature ,
134
137
api_base_url = api_base_url ,
135
- model_kwargs = {"response_format" : {"type" : response_type }},
138
+ model_kwargs = {"response_format" : {"type" : response_type }} if json_support >= JsonSupport . OBJECT else {} ,
136
139
tracer = tracer ,
137
140
)
138
141
Original file line number Diff line number Diff line change 16
16
)
17
17
18
18
from khoj .processor .conversation .utils import (
19
+ JsonSupport ,
19
20
ThreadedGenerator ,
20
21
commit_conversation_trace ,
21
22
)
@@ -245,3 +246,11 @@ def llm_thread(
245
246
logger .error (f"Error in llm_thread: { e } " , exc_info = True )
246
247
finally :
247
248
g .close ()
249
+
250
+
251
def ai_api_supports_json_enforcement(model_name: str, api_base_url: str = None) -> JsonSupport:
    """Determine the level of JSON response enforcement supported by the AI model API.

    Args:
        model_name: Name of the chat model (e.g. "gpt-4o", "deepseek-reasoner").
        api_base_url: Base URL of the OpenAI-compatible API endpoint. May be None
            when using the default OpenAI endpoint.

    Returns:
        JsonSupport.NONE    - API rejects any response_format constraint.
        JsonSupport.OBJECT  - API supports {"type": "json_object"} only.
        JsonSupport.SCHEMA  - API supports full JSON schema enforcement.
    """
    # Deepseek's reasoner models do not accept a response_format constraint at all.
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE
    # Guard against api_base_url=None (the default) before the substring test;
    # `in None` would raise TypeError. Azure AI endpoints only support json_object mode.
    if api_base_url and ".ai.azure.com" in api_base_url:
        return JsonSupport.OBJECT
    return JsonSupport.SCHEMA
Original file line number Diff line number Diff line change @@ -878,3 +878,9 @@ def safe_serialize(content: Any) -> str:
878
878
return str (content )
879
879
880
880
return "\n " .join ([f"{ json .dumps (safe_serialize (message .content ))[:max_length ]} ..." for message in messages ])
881
+
882
+
883
class JsonSupport(int, Enum):
    """Ordered levels of JSON response enforcement an AI model API can support.

    Subclasses int so levels are directly comparable, e.g.
    `json_support >= JsonSupport.OBJECT`.
    """

    NONE = 0  # No response_format constraint supported
    OBJECT = 1  # Supports {"type": "json_object"} only
    SCHEMA = 2  # Supports full JSON schema enforcement
You can’t perform that action at this time.
0 commit comments