import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    """Decoding hyperparameters forwarded verbatim to the OpenAI completion endpoint."""

    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False
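# A minimal construction sketch for OpenAIDecodingArguments (the field values
# below are illustrative, not recommendations):
#
#   greedy_args = OpenAIDecodingArguments(temperature=0.0, max_tokens=512)
#   sampled_args = OpenAIDecodingArguments(temperature=0.9, top_p=0.95, n=4)
#
# Because the dataclass's __dict__ is splatted into the request inside
# openai_completion below, every field must be a valid endpoint parameter.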


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with the OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. For chat models, the strings should be formatted as
            explained here: https://github.com/openai/openai-python/blob/main/chatml.md; alternatively, prompts
            can be a dictionary (or a list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate limit is hit.
        batch_size: Number of prompts to send in a single request. Only used with non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains things like logprobs).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions. Depending on `return_text` and `decoding_args.n`, each
        completion can be
            - a string (if `return_text` is True),
            - an openai_object.OpenAIObject (if `return_text` is False),
            - a list of objects of the above types (if `decoding_args.n` > 1).
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    # Split prompts into contiguous batches, e.g. 5 prompts with batch_size=2 -> [[p0, p1], [p2, p3], [p4]].
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # Clone so retries can shrink max_tokens per batch.

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                # Attach the batch-level token count to every choice for downstream bookkeeping.
                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # The prompt plus max_tokens exceeded the context window; shrink the target and retry.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Nest the flat list so each entry holds the decoding_args.n generations for one prompt.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Unwrap the singleton list when a single prompt was passed in.
        (completions,) = completions
    return completions
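# A minimal usage sketch (assumes OPENAI_API_KEY is set in the environment;
# the prompt and argument values are illustrative only):
#
#   decoding_args = OpenAIDecodingArguments(max_tokens=256, temperature=0.7)
#   text = openai_completion(
#       prompts="Write a haiku about autumn.",
#       decoding_args=decoding_args,
#       model_name="text-davinci-003",
#       return_text=True,
#   )
#   print(text)  # a single string, since one prompt and n=1 were used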


def _make_w_io_base(f, mode: str):
    """Return `f` opened for writing if it is a path (creating parent dirs as needed); pass through IO objects."""
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    """Return `f` opened for reading if it is a path; pass through IO objects."""
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f
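# Both helpers accept either a string path or an already-open IO object; an
# open handle is passed through unchanged, e.g.:
#
#   with open("out.json", "w") as fh:
#       assert _make_w_io_base(fh, "w") is fh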


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a string, dict, or list to a file in JSON format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk, or an already-open IO object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
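# A minimal round-trip sketch (the path below is illustrative only):
#
#   jdump({"instruction": "say hi", "output": "hi"}, "outputs/example.json")
#   record = jload("outputs/example.json")
#   assert record["output"] == "hi"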