Skip to content

Commit

Permalink
add debug output
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Jan 31, 2025
1 parent a7d6dd6 commit 629f41d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/shelloracle/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import AsyncIterator

from openai import APIError, AsyncOpenAI
Expand All @@ -9,13 +10,14 @@ class OpenAI(Provider):
name = "OpenAI"

api_key = Setting(default="")
model = Setting(default="gpt-3.5-turbo")
base_url = Setting(default="https://api.openai.com/v1")
model = Setting(default="gpt-4o")

def __init__(self):
if not self.api_key:
msg = "No API key provided"
raise ProviderError(msg)
self.client = AsyncOpenAI(api_key=self.api_key)
self.client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
Expand All @@ -25,8 +27,12 @@ async def generate(self, prompt: str) -> AsyncIterator[str]:
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
logging.getLogger(__name__).info(chunk)
try:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except:
...
except APIError as e:
msg = f"Something went wrong while querying OpenAI: {e}"
raise ProviderError(msg) from e

0 comments on commit 629f41d

Please sign in to comment.