gauthamnairy commited on
Commit
ced14f2
·
verified ·
1 Parent(s): 385769a

Update llm_config.py

Browse files
Files changed (1) hide show
  1. llm_config.py +41 -39
llm_config.py CHANGED
@@ -1,39 +1,41 @@
1
- import os
2
- from openai import OpenAI
3
-
4
- def get_llm_client(provider="nvidia"):
5
- """
6
- Returns an OpenAI client configured for the specified provider.
7
-
8
- Args:
9
- provider (str): "nvidia" or "mistral"
10
-
11
- Returns:
12
- OpenAI: The configured client
13
- """
14
- if provider == "nvidia":
15
- # Llama 4 Maverick via NVIDIA NIM
16
- api_key = os.getenv("NVIDIA_API_KEY")
17
- if not api_key:
18
- print("Warning: NVIDIA_API_KEY not found in environment variables.")
19
-
20
- return OpenAI(
21
- base_url="https://integrate.api.nvidia.com/v1",
22
- api_key=api_key
23
- )
24
- else:
25
- # Mistral Large 3 via Mistral API
26
- api_key = os.getenv("MISTRAL_API_KEY")
27
- if not api_key:
28
- print("Warning: MISTRAL_API_KEY not found in environment variables.")
29
-
30
- return OpenAI(
31
- base_url="https://api.mistral.ai/v1",
32
- api_key=api_key
33
- )
34
-
35
- def get_model_name(provider="nvidia"):
36
- if provider == "nvidia":
37
- return "meta/llama-4-maverick"
38
- else:
39
- return "mistral-large-latest"
 
 
 
1
+ import os
2
+ from openai import OpenAI
3
+
4
+ def get_llm_client(provider="nvidia"):
5
+ """
6
+ Returns an OpenAI client configured for the specified provider.
7
+
8
+ Args:
9
+ provider (str): "nvidia" or "mistral"
10
+
11
+ Returns:
12
+ OpenAI: The configured client
13
+ """
14
+ if provider == "nvidia":
15
+ # Llama 4 Maverick via NVIDIA NIM
16
+ api_key = os.getenv("NVIDIA_API_KEY")
17
+ if not api_key:
18
+ print("Warning: NVIDIA_API_KEY not found in environment variables.")
19
+
20
+ return OpenAI(
21
+ base_url="https://integrate.api.nvidia.com/v1",
22
+ api_key=api_key
23
+ )
24
+ else:
25
+ # Mistral Large 3 via Mistral API
26
+ api_key = os.getenv("MISTRAL_API_KEY")
27
+ if not api_key:
28
+ print("Warning: MISTRAL_API_KEY not found in environment variables.")
29
+
30
+ return OpenAI(
31
+ base_url="https://api.mistral.ai/v1",
32
+ api_key=api_key
33
+ )
34
+
35
+ def get_model_name(provider="nvidia"):
36
+ if provider == "nvidia":
37
+ # Meta Llama 3.1 70B is stable and common on NVIDIA NIM
38
+ return "meta/llama-3.1-70b-instruct"
39
+ else:
40
+ # Mistral Large latest ID
41
+ return "mistral-large-latest"