diff --git a/opto/optimizers/__init__.py b/opto/optimizers/__init__.py index 482b1b2d..a41b6d34 100644 --- a/opto/optimizers/__init__.py +++ b/opto/optimizers/__init__.py @@ -4,7 +4,9 @@ from opto.optimizers.opro_v2 import OPROv2 from opto.optimizers.textgrad import TextGrad from opto.optimizers.optoprime_v2 import OptoPrimeV2 +from opto.optimizers.optoprime_v3 import OptoPrimeV3 +from opto.optimizers.opro_v3 import OPROv3 OptoPrime = OptoPrimeV1 -__all__ = ["OPRO", "OptoPrime", "OptoPrimeMulti", "TextGrad", "OptoPrimeV2", "OptoPrimeV1", "OPROv2"] \ No newline at end of file +__all__ = ["OPRO", "OptoPrime", "OptoPrimeMulti", "TextGrad", "OptoPrimeV2", "OptoPrimeV1", "OPROv2", "OptoPrimeV3", "OPROv3"] \ No newline at end of file diff --git a/opto/optimizers/opro_v3.py b/opto/optimizers/opro_v3.py new file mode 100644 index 00000000..20575b88 --- /dev/null +++ b/opto/optimizers/opro_v3.py @@ -0,0 +1,543 @@ +""" +Key difference to v2: +1. Use the new backbone conversation history manager +2. Support multimodal node (both trainable and non-trainable) +3. Break from the OptoPrime style template, support more customizable template from user, for brevity and streamlined usage. +""" + +from textwrap import dedent +from dataclasses import dataclass +from typing import Dict, Optional, List, Union +from opto.trace.nodes import ParameterNode + +from opto.optimizers.optoprime_v3 import OptoPrimeV3, OptimizerPromptSymbolSet +from opto.utils.backbone import ( + ContentBase, ImageContent, ContentBlockList, + DEFAULT_IMAGE_PLACEHOLDER +) + +# Not inheriting from optoprime_v2 because this should have a smaller set +class OPROPromptSymbolSet(OptimizerPromptSymbolSet): + """Prompt symbol set for OPRO optimizer. + + This class defines the tags and symbols used in the OPRO optimizer's prompts + and output parsing. It provides a structured way to format problems and parse + responses from the language model. 
+ + Attributes + ---------- + instruction_section_title : str + Title for the instruction section in prompts. + variables_section_title : str + Title for the variable/solution section in prompts. + feedback_section_title : str + Title for the feedback section in prompts. + node_tag : str + Tag used to identify constant nodes in the computation graph. + variable_tag : str + Tag used to identify variable nodes that can be optimized. + value_tag : str + Tag used to wrap the value of a node. + constraint_tag : str + Tag used to wrap constraint expressions for nodes. + reasoning_tag : str + Tag used to wrap reasoning in the output. + improved_variable_tag : str + Tag used to wrap improved variable values in the output. + name_tag : str + Tag used to wrap variable names. + expect_json : bool + Whether to expect JSON output format (default: False). + + Methods + ------- + default_prompt_symbols + Returns default prompt symbols dictionary. + + Notes + ----- + This class inherits from OptimizerPromptSymbolSet but defines a smaller, + more focused set of symbols specifically for OPRO optimization. 
+ """ + + instruction_section_title = "# Instruction" + variables_section_title = "# Solution" + feedback_section_title = "# Feedback" + context_section_title = "# Context" + + node_tag = "node" # nodes that are constants in the graph + variable_tag = "solution" # nodes that can be changed + value_tag = "value" # inside node, we have value tag + constraint_tag = "constraint" # inside node, we have constraint tag + + # output format + # Note: we currently don't support extracting format's like "```code```" because we assume supplied tag is name-only, i.e., + reasoning_tag = "reasoning" + improved_variable_tag = "variable" + name_tag = "name" + + expect_json = False # this will stop `enforce_json` arguments passed to LLM calls + + @property + def default_prompt_symbols(self) -> Dict[str, str]: + return { + "variables": self.variables_section_title, + "feedback": self.feedback_section_title, + "instruction": self.instruction_section_title, + "context": self.context_section_title + } + +@dataclass +class ProblemInstance: + """Represents a problem instance for OPRO optimization. + + This dataclass encapsulates a complete problem instance including the + instruction, current variables/solution, and feedback received. + + Supports multimodal content - variables can contain images. + + Attributes + ---------- + instruction : str + The instruction describing what needs to be done or the question to answer. + variables : Union[str, List[ContentBase]] + The current proposed solution that can be modified. Can contain images. + feedback : str + Feedback about the current solution. + context: str + Optional context information that might be useful to solve the problem. + + optimizer_prompt_symbol_set : OPROPromptSymbolSet + The symbol set used for formatting the problem. + problem_template : str + Template for formatting the problem instance as a string. + + Methods + ------- + __repr__() + Returns a formatted string representation of the problem instance. 
+ to_content_blocks() + Returns a ContentBlockList for multimodal prompts. + has_images() + Returns True if the problem instance contains images. + + Notes + ----- + The problem instance is formatted using the problem_template which + organizes the instruction, variables, and feedback into a structured format. + """ + instruction: str + variables: Union[str, List[ContentBase]] + feedback: str + context: Optional[ContentBlockList] + + optimizer_prompt_symbol_set: OPROPromptSymbolSet + + problem_template = dedent( + """ + # Instruction + {instruction} + + # Solution + {variables} + + # Feedback + {feedback} + """ + ) + + @staticmethod + def _content_to_text(content: Union[str, List[ContentBase]]) -> str: + """Convert content (str or List[ContentBlock]) to text representation. + + Handles both string content and ContentBlockList/List[ContentBlock]. + Uses ContentBlockList.blocks_to_text for list content. + """ + if isinstance(content, str): + return content + # Use the shared utility from ContentBlockList + return ContentBlockList.blocks_to_text(content, DEFAULT_IMAGE_PLACEHOLDER) + + def __repr__(self) -> str: + """Return text-only representation for backward compatibility.""" + optimization_query = self.problem_template.format( + instruction=self.instruction, + variables=self._content_to_text(self.variables), + feedback=self.feedback, + ) + + context_section = dedent(""" + + # Context + {context} + """) + + if self.context is not None and self.context.to_text().strip() != "": + context_section = context_section.format(context=self.context.to_text()) + optimization_query += context_section + + return optimization_query + + def to_content_blocks(self) -> ContentBlockList: + """Convert the problem instance to a list of ContentBlocks. + + Consecutive TextContent blocks are merged into a single block for efficiency. + Images and other non-text blocks are kept separate. 
+ + Returns: + ContentBlockList: A list containing TextContent and ImageContent blocks + that represent the complete problem instance. + """ + blocks = ContentBlockList() + + # Instruction section + blocks.append(f"# Instruction\n{self.instruction}\n\n# Solution\n") + + # Variables/Solution section (may contain images) + blocks.extend(self.variables) + + # Feedback section + blocks.append(f"\n\n# Feedback\n{self.feedback}") + + # Context section (optional) + if self.context is not None and self.context.to_text().strip() != "": + blocks.append(f"\n\n# Context\n") + blocks.extend(self.context) + + return blocks + + def has_images(self) -> bool: + """Check if this problem instance contains any images. + + Returns: + bool: True if variables field contains ImageContent blocks. + """ + if isinstance(self.variables, list): + for block in self.variables: + if isinstance(block, ImageContent): + return True + return False + +class OPROv3(OptoPrimeV3): + """OPRO (Optimization by PROmpting) optimizer version 2. + + OPRO is an optimization algorithm that leverages large language models to + iteratively improve solutions based on feedback. It treats optimization as + a natural language problem where the LLM proposes improvements to variables + based on instruction and feedback. + + Parameters + ---------- + *args + Variable length argument list passed to parent class. + optimizer_prompt_symbol_set : OptimizerPromptSymbolSet, optional + The symbol set for formatting prompts and parsing outputs. + Defaults to OPROPromptSymbolSet(). + include_example : bool, optional + Whether to include examples in the prompt. Default is False as + the default example in OptoPrimeV2 does not work well with OPRO. + memory_size : int, optional + Number of past optimization steps to remember. Default is 5. + **kwargs + Additional keyword arguments passed to parent class. + + Attributes + ---------- + representation_prompt : str + Template for explaining the problem representation to the LLM. 
+ output_format_prompt_template : str + Template for specifying the expected output format. + user_prompt_template : str + Template for presenting the problem instance to the LLM. + final_prompt : str + Template for requesting the final revised solutions. + default_objective : str + Default objective when none is specified. + + Methods + ------- + problem_instance(summary, mask=None) + Creates a ProblemInstance from an optimization summary. + initialize_prompt() + Initializes and formats the prompt templates. + + Notes + ----- + OPRO differs from OptoPrime by focusing on simpler problem representations + and clearer feedback incorporation. It is particularly effective for + problems where the optimization can be expressed in natural language. + + See Also + -------- + OptoPrimeV2 : Parent class providing core optimization functionality. + OPROPromptSymbolSet : Symbol set used for formatting. + + Examples + -------- + >>> optimizer = OPROv3(memory_size=10) + >>> # Use optimizer to improve solutions based on feedback + """ + representation_prompt = dedent( + """ + You're tasked to change the proposed solution according to feedback. + + Specifically, a problem will be composed of the following parts: + - {instruction_section_title}: the instruction which describes the things you need to do or the question you should answer. + - {variables_section_title}: the proposed solution that you can change/tweak (trainable). + - {feedback_section_title}: the feedback about the solution. + - {context_section_title}: the context information that might be useful to solve the problem. + + If `data_type` is `code`, it means `{value_tag}` is the source code of a python code, which may include docstring and definitions. + """ + ) + + output_format_prompt_template = dedent( + """ + Output_format: Your output should be in the following XML/HTML format: + + ``` + {output_format} + ``` + + In <{reasoning_tag}>, explain the problem: 1. what the {instruction_section_title} means 2. 
what the {feedback_section_title} means to {variables_section_title} considering how {variables_section_title} follow {instruction_section_title}. 3. Reasoning about the suggested changes in {variables_section_title} (if needed) and the expected result. + + If you need to suggest a change in the values of {variables_section_title}, write down the suggested values in <{improved_variable_tag}>. Remember you can change only the values in {variables_section_title}, not others. When `type` of a variable is `code`, you should write the new definition in the format of python code without syntax errors, and you should not change the function name or the function signature. + + If no changes are needed, just output TERMINATE. + """ + ) + + user_prompt_template = dedent( + """ + Now you see problem instance: + + ================================ + {problem_instance} + ================================ + + """ + ) + + context_prompt = dedent( + """ + Here is some additional **context** to solving this problem: + + {context} + """ + ) + + final_prompt = dedent( + """ + What are your revised solutions on {names}? + + Your response: + """ + ) + + # Default Objective becomes instruction for the next block + default_objective = "Propose a new solution that will incorporate the feedback." + + def __init__(self, *args, + optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = None, + include_example=False, # default example in OptoPrimeV2 does not work in OPRO + memory_size=5, + problem_context: Optional[ContentBlockList] = None, + **kwargs): + """Initialize the OPROv2 optimizer. + + Parameters + ---------- + *args + Variable length argument list passed to parent class. + optimizer_prompt_symbol_set : OptimizerPromptSymbolSet, optional + The symbol set for formatting prompts and parsing outputs. + If None, uses OPROPromptSymbolSet(). + include_example : bool, optional + Whether to include examples in the prompt. Default is False. 
+ memory_size : int, optional + Number of past optimization steps to remember. Default is 5. + **kwargs + Additional keyword arguments passed to parent class. + """ + optimizer_prompt_symbol_set = optimizer_prompt_symbol_set or OPROPromptSymbolSet() + super().__init__(*args, optimizer_prompt_symbol_set=optimizer_prompt_symbol_set, + include_example=include_example, memory_size=memory_size, + problem_context=problem_context, + **kwargs) + + def parameter_check(self, parameters: List[ParameterNode]): + """Check if the parameters are valid. + This can be overloaded by subclasses to add more checks. + + Args: + parameters: List[ParameterNode] + The parameters to check. + + Raises: + AssertionError: If more than one parameter contains image data. + + Notes: + OPROv2 supports image parameters, but only one parameter can be + an image at a time since LLMs can only generate one image per inference. + """ + # Count image parameters + image_params = [param for param in parameters if param.is_image] + + if len(image_params) > 1: + param_names = ', '.join([f"'{p.name}'" for p in image_params]) + raise AssertionError( + f"OPROv2 supports at most one image parameter, but found {len(image_params)}: " + f"{param_names}. LLMs can only generate one image at a time." + ) + + def problem_instance(self, summary, mask=None, use_content_blocks=False): + """Create a ProblemInstance from an optimization summary. + + Parameters + ---------- + summary : object + The optimization summary containing variables and feedback. + mask : list, optional + List of sections to mask/hide in the problem instance. + Can include "#Instruction", variable section title, or feedback section title. + use_content_blocks : bool, optional + If True, use content blocks for multimodal support (images). + If False, use text-only representation. + + Returns + ------- + ProblemInstance + A formatted problem instance ready for presentation to the LLM. 
+ + Notes + ----- + The mask parameter allows selective hiding of problem components, + useful for ablation studies or specific optimization strategies. + """ + mask = mask or [] + + if use_content_blocks: + # Use content block representation for multimodal support + variables_content = ( + self.repr_node_value_compact_as_content_blocks( + summary.variables, + node_tag=self.optimizer_prompt_symbol_set.variable_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if self.optimizer_prompt_symbol_set.variables_section_title not in mask + else ContentBlockList() + ) + else: + # Use text-only representation (backward compatible) + variables_content = ( + self.repr_node_value_compact( + summary.variables, + node_tag=self.optimizer_prompt_symbol_set.variable_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if self.optimizer_prompt_symbol_set.variables_section_title not in mask + else "" + ) + + return ProblemInstance( + instruction=self.objective if "#Instruction" not in mask else "", + variables=variables_content, + feedback=summary.user_feedback if self.optimizer_prompt_symbol_set.feedback_section_title not in mask else "", + context=self.problem_context if hasattr(self, 'problem_context') else None, + optimizer_prompt_symbol_set=self.optimizer_prompt_symbol_set + ) + + def repr_node_value_compact_as_content_blocks(self, node_dict, node_tag="node", + value_tag="value", constraint_tag="constraint") -> ContentBlockList: + """Returns a ContentBlockList with compact representation, including images. + + Consecutive TextContent blocks are merged for efficiency. + Non-image values are truncated. Images break the text flow. 
+ """ + from opto.optimizers.optoprime_v3 import value_to_image_content + + blocks = ContentBlockList() + + for k, v in node_dict.items(): + value_data = v[0] + constraint = v[1] + + if "__code" not in k: + # Check if this is an image + image_content = value_to_image_content(value_data) + + if image_content is not None: + # Image node: output XML structure, then image, then closing + type_name = "image" + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag else "" + + xml_text = f"<{node_tag} name=\"{k}\" type=\"{type_name}\">\n<{value_tag}>\n" + blocks.append(xml_text) + blocks.append(image_content) # Image breaks the text flow + + closing_text = f"\n\n{constraint_expr}\n\n" if constraint_expr else f"\n\n\n\n" + blocks.append(closing_text) + else: + # Non-image node: truncated text representation + node_value = self.truncate_expression(value_data, self.initial_var_char_limit) + if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag: + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{node_value}\n\n{constraint_expr}\n\n\n" + ) + else: + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{node_value}\n\n\n\n" + ) + else: + # Code node (never an image) + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + signature = constraint.replace("The code should start with:\n", "") + func_body = value_data.replace(signature, "") + node_value = self.truncate_expression(func_body, self.initial_var_char_limit) + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"code\">\n<{value_tag}>\n{signature}{node_value}\n\n{constraint_expr}\n\n\n" + ) + + return blocks + + def initialize_prompt(self): + """Initialize and format the prompt templates. 
+ + This method formats the representation_prompt and output_format_prompt + templates with the appropriate symbols from the optimizer_prompt_symbol_set. + It prepares the prompts for use in optimization. + + Notes + ----- + This method should be called during initialization to ensure all + prompt templates are properly formatted with the correct tags and symbols. + """ + self.representation_prompt = self.representation_prompt.format( + variable_expression_format=dedent(f""" + <{self.optimizer_prompt_symbol_set.variable_tag} name="variable_name" type="data_type"> + <{self.optimizer_prompt_symbol_set.value_tag}> + value + + <{self.optimizer_prompt_symbol_set.constraint_tag}> + constraint_expression + + + """), + value_tag=self.optimizer_prompt_symbol_set.value_tag, + variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), + instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") + ) + self.output_format_prompt = self.output_format_prompt_template.format( + output_format=self.optimizer_prompt_symbol_set.output_format, + reasoning_tag=self.optimizer_prompt_symbol_set.reasoning_tag, + improved_variable_tag=self.optimizer_prompt_symbol_set.improved_variable_tag, + instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), + variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") + ) diff --git a/opto/optimizers/optoprime_v3.py b/opto/optimizers/optoprime_v3.py new file mode 100644 index 00000000..0bab6bc9 --- 
/dev/null +++ b/opto/optimizers/optoprime_v3.py @@ -0,0 +1,1282 @@ +""" +Key difference to v2: +1. Use the new backbone conversation history manager +2. Support multimodal node (both trainable and non-trainable) +""" + +import re +import json +from typing import List, Union, Tuple, Optional +from dataclasses import dataclass +from opto.optimizers.optoprime import OptoPrime, node_to_function_feedback +from opto.trace.utils import dedent +from opto.optimizers.utils import truncate_expression, extract_xml_like_data, is_bedrock_model +from opto.trace.nodes import ParameterNode, is_image +from opto.trace.propagators import GraphPropagator +from opto.trace.propagators.propagators import Propagator + +from opto.utils.llm import AbstractModel, LLM +from opto.optimizers.buffers import FIFOBuffer +from opto.utils.backbone import ( + Chat, UserTurn, AssistantTurn, PromptTemplate, + TextContent, ImageContent, ContentBlockList, + DEFAULT_IMAGE_PLACEHOLDER, Content +) +import copy +import pickle +from typing import Dict, Any + + +def value_to_image_content(value: Any) -> Optional[ImageContent]: + """Convert a value to ImageContent if it's an image, otherwise return None. + + Uses is_image() from opto.trace.nodes for validation (stricter than ImageContent.build, + e.g., only accepts URLs with image extensions), then delegates to ImageContent.build(). + + Supports (via is_image detection): + - Base64 data URL strings (data:image/...) + - HTTP/HTTPS URLs pointing to images (pattern-based, must have image extension) + - PIL Image objects + - Raw image bytes + """ + if not is_image(value): + return None + return ImageContent.build(value) + + +class OptimizerPromptSymbolSet: + """ + By inheriting this class and pass into the optimizer. 
People can change the optimizer documentation + + This divides into three parts: + - Section titles: the title of each section in the prompt + - Node tags: the tags that capture the graph structure (only tag names are allowed to be changed) + - Output format: the format of the output of the optimizer + """ + + # Titles should be written as markdown titles (space between # and title) + # In text, we automatically remove space in the title, so it will become `#Title` + variables_section_title = "# Variables" + inputs_section_title = "# Inputs" + outputs_section_title = "# Outputs" + others_section_title = "# Others" + feedback_section_title = "# Feedback" + instruction_section_title = "# Instruction" + code_section_title = "# Code" + documentation_section_title = "# Documentation" + context_section_title = "# Context" + + node_tag = "node" # nodes that are constants in the graph + variable_tag = "variable" # nodes that can be changed + value_tag = "value" # inside node, we have value tag + constraint_tag = "constraint" # inside node, we have constraint tag + + # output format + # Note: we currently don't support extracting format's like "```code```" because we assume supplied tag is name-only, i.e., + reasoning_tag = "reasoning" + improved_variable_tag = "variable" + name_tag = "name" + + # only used by JSON format + suggestion_tag = "suggestion" + + expect_json = False # this will stop `enforce_json` arguments passed to LLM calls + + # custom output format + # if this is not None, then the user needs to implement the following functions: + # - output_response_extractor + # - example_output + custom_output_format_instruction = None + + @property + def output_format(self) -> str: + """ + This function defines the input to: + ``` + {output_format} + ``` + In the self.output_format_prompt_template in the OptoPrimeV2 + """ + if self.custom_output_format_instruction is None: + # we use a default XML like format + return dedent(f""" + <{self.reasoning_tag}> + reasoning + + 
<{self.improved_variable_tag}> + <{self.name_tag}>variable_name + <{self.value_tag}> + value + + + """) + else: + return self.custom_output_format_instruction.strip() + + def example_output(self, reasoning, variables): + """ + reasoning: str + variables: format {variable_name, value} + """ + if self.custom_output_format_instruction is not None: + raise NotImplementedError + else: + # Build the output string in the same XML-like format as self.output_format + output = [] + if reasoning != "": + output.append(f"<{self.reasoning_tag}>") + output.append(reasoning) + output.append(f"") + for var_name, value in variables.items(): + output.append(f"<{self.improved_variable_tag}>") + output.append(f"<{self.name_tag}>{var_name}") + output.append(f"<{self.value_tag}>") + output.append(str(value)) + output.append(f"") + output.append(f"") + return "\n".join(output) + + def output_response_extractor(self, response: str) -> Dict[str, Any]: + # the response here should just be plain text + + if self.custom_output_format_instruction is None: + extracted_data = extract_xml_like_data(response, + reasoning_tag=self.reasoning_tag, + improved_variable_tag=self.improved_variable_tag, + name_tag=self.name_tag, + value_tag=self.value_tag) + + # if the suggested value is a code, and the entire code body is empty (i.e., not even function signature is present) + # then we remove such suggestion + keys_to_remove = [] + for key, value in extracted_data['variables'].items(): + if "__code" in key and value.strip() == "": + keys_to_remove.append(key) + + for key in keys_to_remove: + del extracted_data['variables'][key] + + return extracted_data + else: + raise NotImplementedError( + "If you supplied a custom output format prompt template, you need to implement your own response extractor") + + @property + def default_prompt_symbols(self) -> Dict[str, str]: + return { + "variables": self.variables_section_title, + "inputs": self.inputs_section_title, + "outputs": self.outputs_section_title, + 
"others": self.others_section_title, + "feedback": self.feedback_section_title, + "instruction": self.instruction_section_title, + "code": self.code_section_title, + "documentation": self.documentation_section_title, + "context": self.context_section_title, + "reasoning": self.reasoning_tag, + "suggestion": self.suggestion_tag + } + + +class OptimizerPromptSymbolSetJSON(OptimizerPromptSymbolSet): + """We enforce a JSON output format extraction""" + + expect_json = True + + custom_output_format_instruction = dedent(""" + { + "reasoning": , + "suggestion": { + : , + : , + } + } + """) + + def example_output(self, reasoning, variables): + """ + reasoning: str + variables: format {variable_name, value} + """ + + # Build the output string in the same JSON format as described in custom_output_format_instruction + output = { + "reasoning": reasoning, + "suggestion": {var_name: value for var_name, value in variables.items()} + } + return json.dumps(output, indent=2) + + def output_response_extractor(self, response: str) -> Dict[str, Any]: + """ + Extracts reasoning and suggestion variables from the LLM response using OptoPrime's extraction logic. + """ + # Use the centralized extraction logic from OptoPrime + suggestion_tag = self.default_prompt_symbols.get("suggestion", "suggestion") + reasoning_tag = self.default_prompt_symbols.get("reasoning", "reasoning") + + ignore_extraction_error = True + + reasoning = "(Unable to extract, possibly due to parsing failure)" + + if "```" in response: + # First try to extract from ```json ... ``` blocks + json_match = re.findall(r"```json\s*(.*?)```", response, re.DOTALL) + if len(json_match) > 0: + response = json_match[0].strip() + else: + # Fall back to regular ``` ... 
``` blocks + match = re.findall(r"```(.*?)```", response, re.DOTALL) + if len(match) > 0: + # Remove language identifier if present (e.g., "json", "python") + content = match[0].strip() + # Check if first line is a language identifier + lines = content.split('\n', 1) + if len(lines) > 1 and lines[0].strip().isalpha() and len(lines[0].strip()) < 20: + response = lines[1].strip() + else: + response = content + + json_extracted = {} + suggestion = {} + attempt_n = 0 + while attempt_n < 2: + try: + json_extracted = json.loads(response) + if isinstance(json_extracted, dict): # trim all whitespace keys in the json_extracted + json_extracted = {k.strip(): v for k, v in json_extracted.items()} + suggestion = json_extracted.get(suggestion_tag, json_extracted) + reasoning = json_extracted.get(reasoning_tag, "") + break + except json.JSONDecodeError: + response = re.findall(r"{.*}", response, re.DOTALL) + if len(response) > 0: + response = response[0] + attempt_n += 1 + except Exception: + attempt_n += 1 + + if not isinstance(suggestion, dict): + suggestion = json_extracted if isinstance(json_extracted, dict) else {} + + if len(suggestion) == 0: + pattern = rf'"{suggestion_tag}"\s*:\s*\{{(.*?)\}}' + suggestion_match = re.search(pattern, str(response), re.DOTALL) + if suggestion_match: + suggestion = {} + suggestion_content = suggestion_match.group(1) + pair_pattern = r'"([a-zA-Z0-9_]+)"\s*:\s*"(.*)"' + pairs = re.findall(pair_pattern, suggestion_content, re.DOTALL) + for key, value in pairs: + suggestion[key] = value + + if len(suggestion) == 0 and not ignore_extraction_error: + print(f"Cannot extract {suggestion_tag} from LLM's response:\n{response}") + + keys_to_remove = [] + for key, value in suggestion.items(): + if "__code" in key and value.strip() == "": + keys_to_remove.append(key) + for key in keys_to_remove: + del suggestion[key] + + return {"reasoning": reasoning, "variables": suggestion} + + +class OptimizerPromptSymbolSet2(OptimizerPromptSymbolSet): + 
variables_section_title = "# Variables" + inputs_section_title = "# Inputs" + outputs_section_title = "# Outputs" + others_section_title = "# Others" + feedback_section_title = "# Feedback" + instruction_section_title = "# Instruction" + code_section_title = "# Code" + documentation_section_title = "# Documentation" + context_section_title = "# Context" + + node_tag = "const" # nodes that are constants in the graph + variable_tag = "var" # nodes that can be changed + value_tag = "data" # inside node, we have value tag + constraint_tag = "constraint" # inside node, we have constraint tag + + # output format + reasoning_tag = "reason" + improved_variable_tag = "var" + name_tag = "name" + + +@dataclass +class FunctionFeedback: + """Container for structured feedback from function execution traces. + + Used by OptoPrime to organize execution traces into a format suitable + for LLM-based optimization. + + Attributes + ---------- + graph : list[tuple[int, str]] + Topologically sorted function calls with (depth, representation) pairs. + documentation : dict[str, str] + Mapping of function names to their documentation strings. + others : dict[str, Any] + Intermediate variables with (data, description) tuples. + roots : dict[str, Any] + Input/root variables with (data, description) tuples. + output : dict[str, Any] + Output/leaf variables with (data, description) tuples. + user_feedback : Union[str, ContentBlockList] + User-provided feedback about the execution. May include images. + + Notes + ----- + This structure separates the execution trace into logical components + that can be formatted into prompts for LLM-based optimization. + """ + + graph: List[ + Tuple[int, str] + ] # Each item is is a representation of function call. The items are topologically sorted. 
+ documentation: Dict[str, str] # Function name and its documentationstring + others: Dict[str, Any] # Intermediate variable names and their data + roots: Dict[str, Any] # Root variable name and its data + output: Dict[str, Any] # Leaf variable name and its data + user_feedback: Union[str, ContentBlockList] # User feedback at the leaf of the graph (may include images) + + +@dataclass +class ProblemInstance: + """Problem instance with multimodal content support. + + A composite of multiple ContentBlockLists representing different parts + of a problem. Uses ContentBlockList for variables, inputs, others, and + outputs to support both text and image content in a unified way. + + The class provides: + - __repr__: Returns text-only representation for logging + - to_content_blocks(): Returns ContentBlockList for multimodal prompts + - has_images(): Check if any field contains images + """ + instruction: str + code: str + documentation: str + variables: ContentBlockList + inputs: ContentBlockList + others: ContentBlockList + outputs: ContentBlockList + feedback: ContentBlockList # May contain images mixed with text + context: Optional[ContentBlockList] + + optimizer_prompt_symbol_set: OptimizerPromptSymbolSet + + def __post_init__(self): + # Normalize content fields so callers may pass plain strings (or None). + # ContentBlockList.ensure is idempotent for existing ContentBlockLists. 
+ self.variables = ContentBlockList.ensure(self.variables) + self.inputs = ContentBlockList.ensure(self.inputs) + self.others = ContentBlockList.ensure(self.others) + self.outputs = ContentBlockList.ensure(self.outputs) + self.feedback = ContentBlockList.ensure(self.feedback) + if self.context is not None: + self.context = ContentBlockList.ensure(self.context) + + problem_template = dedent( + """ + # Instruction + {instruction} + + # Code + {code} + + # Documentation + {documentation} + + # Variables + {variables} + + # Inputs + {inputs} + + # Others + {others} + + # Outputs + {outputs} + + # Context + {context} + + # Feedback + {feedback} + """ + ) + + def __repr__(self) -> str: + """Return text-only representation for backward compatibility. + + Uses ContentBlockList.to_text() for fields that may contain images. + """ + optimization_query = self.problem_template.format( + instruction=self.instruction, + code=self.code, + documentation=self.documentation, + variables=self.variables.to_text(), + inputs=self.inputs.to_text(), + outputs=self.outputs.to_text(), + others=self.others.to_text(), + context=self.context.to_text() if self.context is not None else "", + feedback=self.feedback.to_text() + ) + + return optimization_query + + def to_content_blocks(self) -> ContentBlockList: + """Convert the problem instance to a list of ContentBlocks. + + Consecutive TextContent blocks are merged into a single block for efficiency. + Images and other non-text blocks are kept separate. + + Returns: + ContentBlockList: A list containing TextContent and ImageContent blocks + that represent the complete problem instance including any images + from variables, inputs, others, or outputs. 
+ """ + blocks = ContentBlockList() + + # Header sections (always text) + header = dedent(f""" + # Instruction + {self.instruction} + + # Code + {self.code} + + # Documentation + {self.documentation} + + # Variables + """) + blocks.append(header) + + # Variables section (may contain images) + blocks.extend(self.variables) + + # Inputs section + blocks.append("\n\n# Inputs\n") + blocks.extend(self.inputs) + + # Others section + blocks.append("\n\n# Others\n") + blocks.extend(self.others) + + # Outputs section + blocks.append("\n\n# Outputs\n") + blocks.extend(self.outputs) + + # Context section (optional) + if self.context is not None and self.context.to_text().strip() != "": + blocks.append(f"\n\n# Context\n") # section name + blocks.extend(self.context) # extend the blocks + + # Feedback section (may contain images) + blocks.append("\n\n# Feedback\n") + blocks.extend(self.feedback) + + return blocks + + def has_images(self) -> bool: + """Check if this problem instance contains any images. + + Efficiently checks each ContentBlockList field directly + without building full content blocks. + + Returns: + bool: True if any field contains ImageContent blocks. + """ + return any( + field.has_images() + for field in [self.variables, self.inputs, self.others, self.outputs, self.feedback] + ) + + + + + +# we provide two aliases for the Content class for semantic convenience +Context = Content +Feedback = Content + +class OptoPrimeV3(OptoPrime): + # This is generic representation prompt, which just explains how to read the problem. + representation_prompt = dedent( + """You're tasked to solve a coding/algorithm problem. You will see the instruction, the code, the documentation of each function used in the code, and the feedback about the execution result. + + Specifically, a problem will be composed of the following parts: + - {instruction_section_title}: the instruction which describes the things you need to do or the question you should answer. 
+ - {code_section_title}: the code defined in the problem. + - {documentation_section_title}: the documentation of each function used in #Code. The explanation might be incomplete and just contain high-level description. You can use the values in #Others to help infer how those functions work. + - {variables_section_title}: the input variables that you can change/tweak (trainable). + - {inputs_section_title}: the values of fixed inputs to the code, which CANNOT be changed (fixed). + - {others_section_title}: the intermediate values created through the code execution. + - {outputs_section_title}: the result of the code output. + - {feedback_section_title}: the feedback about the code's execution result. + - {context_section_title}: the context information that might be useful to solve the problem. + + In `{variables_section_title}`, `{inputs_section_title}`, `{outputs_section_title}`, and `{others_section_title}`, the format is: + + For variables we express as this: + {variable_expression_format} + + If `data_type` is `code`, it means `{value_tag}` is the source code of a python code, which may include docstring and definitions.""" + ) + + # Optimization + default_objective = "You need to change the `{value_tag}` of the variables in {variables_section_title} to improve the output in accordance to {feedback_section_title}." + + output_format_prompt_template = dedent( + """ + Output_format: Your output should be in the following XML or JSON format: + + {output_format} + + In <{reasoning_tag}>, explain the problem: 1. what the {instruction_section_title} means 2. what the {feedback_section_title} on {outputs_section_title} means to {variables_section_title} considering how {variables_section_title} are used in {code_section_title} and other values in {documentation_section_title}, {inputs_section_title}, {others_section_title}. 3. Reasoning about the suggested changes in {variables_section_title} (if needed) and the expected result. 
+ + If you need to suggest a change in the values of {variables_section_title}, write down the suggested values in <{improved_variable_tag}>. Remember you can change only the values in {variables_section_title}, not others. When `type` of a variable is `code`, you should write the new definition in the format of python code without syntax errors, and you should not change the function name or the function signature. + + If no changes are needed, just output TERMINATE. + """ + ) + + example_problem_template = PromptTemplate(dedent( + """ + Here is an example of problem instance and response: + + ================================ + {example_problem} + ================================ + + Your response: + {example_response} + """ + )) + + user_prompt_template = PromptTemplate(dedent( + """ + Now you see problem instance: + + ================================ + {problem_instance} + ================================ + + """ + )) + + final_prompt = dedent( + """ + What are your suggestions on variables {names}? 
+ + Your response: + """ + ) + + def __init__( + self, + parameters: List[ParameterNode], + llm: AbstractModel = None, + *args, + image_llm: AbstractModel = None, + propagator: Propagator = None, + objective: Union[None, str] = None, + ignore_extraction_error: bool = True, + # ignore the type conversion error when extracting updated values from LLM's suggestion + include_example=False, + memory_size=0, # Memory size to store the past feedback + max_tokens=8192, + log=True, + initial_var_char_limit=2000, + optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = OptimizerPromptSymbolSet(), + use_json_object_format=True, # whether to use json object format for the response when calling LLM + truncate_expression=truncate_expression, + problem_context: Optional[ContentBlockList] = None, + **kwargs, + ): + super().__init__(parameters, *args, propagator=propagator, **kwargs) + + self.truncate_expression = truncate_expression + self.problem_context: Optional[ContentBlockList] = problem_context + self.output_contains_image = False + + self.use_json_object_format = use_json_object_format if optimizer_prompt_symbol_set.expect_json and use_json_object_format else False + self.ignore_extraction_error = ignore_extraction_error + self.llm = llm or LLM(mm_beta=True) + self.image_llm = image_llm + + assert self.llm.mm_beta, "OptoPrimeV3 enables multi-modal LLM backbone by default. Please use LLM(model='...', mm_beta=True)." 
+ + self.objective = objective or self.default_objective.format(value_tag=optimizer_prompt_symbol_set.value_tag, + variables_section_title=optimizer_prompt_symbol_set.variables_section_title, + feedback_section_title=optimizer_prompt_symbol_set.feedback_section_title) + self.initial_var_char_limit = initial_var_char_limit + self.optimizer_prompt_symbol_set = optimizer_prompt_symbol_set + + self.example_problem_summary = FunctionFeedback(graph=[(1, 'y = add(x=a,y=b)'), (2, "z = subtract(x=y, y=c)")], + documentation={'add': 'This is an add operator of x and y.', + 'subtract': "subtract y from x"}, + others={'y': (6, None)}, + roots={'a': (5, "a > 0"), + 'b': (1, None), + 'c': (5, None)}, + output={'z': (1, None)}, + user_feedback='The result of the code is not as expected. The result should be 10, but the code returns 1' + ) + self.example_problem_summary.variables = {'a': (5, "a > 0")} + self.example_problem_summary.inputs = {'b': (1, None), 'c': (5, None)} + + self.example_problem = self.problem_instance(self.example_problem_summary) + self.example_response = self.optimizer_prompt_symbol_set.example_output( + reasoning="In this case, the desired response would be to change the value of input a to 14, as that would make the code return 10.", + variables={ + 'a': 10, + } + ) + + self.include_example = include_example + self.max_tokens = max_tokens + self.log = [] if log else None + self.summary_log = [] if log else None + self.memory = FIFOBuffer(memory_size) + self.conversation_history = Chat() + self.conversation_length = memory_size # Number of conversation turns to keep + + self.default_prompt_symbols = self.optimizer_prompt_symbol_set.default_prompt_symbols + + self.prompt_symbols = copy.deepcopy(self.default_prompt_symbols) + self.initialize_instruct_prompt() + + def parameter_check(self, parameters: List[ParameterNode]): + """Check if the parameters are valid. + This can be overloaded by subclasses to add more checks. 
+ + Args: + parameters: List[ParameterNode] + The parameters to check. + + Raises: + AssertionError: If more than one parameter contains image data. + + Notes: + OptoPrimeV3 supports image parameters, but only one parameter can be + an image at a time since LLMs can only generate one image per inference. + """ + # Count image parameters + image_params = [param for param in parameters if param.is_image] + + if len(image_params) > 1: + param_names = ', '.join([f"'{p.name}'" for p in image_params]) + raise AssertionError( + f"OptoPrimeV3 supports at most one image parameter, but found {len(image_params)}: " + f"{param_names}. LLMs can only generate one image at a time." + ) + if len(image_params) == 1: + self.output_contains_image = True + + def add_context(self, *args, images: Optional[List[Any]] = None, format: str = "PNG"): + """Add context to the optimizer, supporting both text and images. + + Two usage patterns are supported: + + **Usage 1: Variadic arguments (alternating text and images)** + + optimizer.add_context("text part 1", image_link, "text part 2", image_file) + + Each argument is either a string (text) or an image source. + + **Usage 2: Template with placeholders** + + optimizer.add_context( + "text part 1 [IMAGE] text part 2 [IMAGE]", + images=[image_link, image_file] + ) + + The text contains `[IMAGE]` placeholders that are replaced by images + from the `images` list in order. The number of placeholders must match + the number of images. + + Args: + *args: Variable arguments. In Usage 1, alternating text and images. + In Usage 2, a single template string with placeholders. + images: Optional list of image sources for Usage 2. Each can be: + - URL string (http/https) + - Local file path + - PIL Image object + - Numpy array + format: Image format for numpy arrays (PNG, JPEG, etc.). Default: PNG + + Raises: + ValueError: If using Usage 2 and the number of placeholders doesn't + match the number of images. 
+ + Examples: + # Usage 1: Alternating text and images + optimizer.add_context("Here's the diagram:", "diagram.png", "And here's another:", "other.png") + + # Usage 2: Template with placeholders + optimizer.add_context("See [IMAGE] and compare with [IMAGE]", images=["a.png", "b.png"]) + + # Text-only context + optimizer.add_context("Important background information") + """ + ctx = Content(*args, images=images, format=format) + + # Store the context + if self.problem_context is None: + self.problem_context = ctx + else: + # Append to existing context with a newline separator + self.problem_context.append("\n\n") + self.problem_context.extend(ctx.to_content_blocks()) + + def initialize_instruct_prompt(self): + self.representation_prompt = self.representation_prompt.format( + variable_expression_format=dedent(f""" + <{self.optimizer_prompt_symbol_set.variable_tag} name="variable_name" type="data_type"> + <{self.optimizer_prompt_symbol_set.value_tag}> + value + + <{self.optimizer_prompt_symbol_set.constraint_tag}> + constraint_expression + + + """), + value_tag=self.optimizer_prompt_symbol_set.value_tag, + variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + inputs_section_title=self.optimizer_prompt_symbol_set.inputs_section_title.replace(" ", ""), + outputs_section_title=self.optimizer_prompt_symbol_set.outputs_section_title.replace(" ", ""), + feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), + instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + code_section_title=self.optimizer_prompt_symbol_set.code_section_title.replace(" ", ""), + documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" 
", "") + ) + self.output_format_prompt = self.output_format_prompt_template.format( + output_format=self.optimizer_prompt_symbol_set.output_format, + reasoning_tag=self.optimizer_prompt_symbol_set.reasoning_tag, + improved_variable_tag=self.optimizer_prompt_symbol_set.improved_variable_tag, + instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), + outputs_section_title=self.optimizer_prompt_symbol_set.outputs_section_title.replace(" ", ""), + code_section_title=self.optimizer_prompt_symbol_set.code_section_title.replace(" ", ""), + documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), + variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + inputs_section_title=self.optimizer_prompt_symbol_set.inputs_section_title.replace(" ", ""), + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + ) + + def repr_node_value(self, node_dict, node_tag="node", + value_tag="value", constraint_tag="constraint") -> str: + """Returns text-only representation of node values (backward compatible).""" + temp_list = [] + for k, v in node_dict.items(): + if "__code" not in k: + # For images, use placeholder text + value_repr = "[IMAGE]" if is_image(v[0]) else str(v[0]) + if v[1] is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag: + constraint_expr = f"<{constraint_tag}>\n{v[1]}\n" + temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"{type(v[0]).__name__}\">\n<{value_tag}>\n{value_repr}\n\n{constraint_expr}\n\n") + else: + temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"{type(v[0]).__name__}\">\n<{value_tag}>\n{value_repr}\n\n\n") + else: + constraint_expr = f"\n{v[1]}\n" + signature = v[1].replace("The code should start with:\n", "") + func_body = v[0].replace(signature, "") 
+ temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"code\">\n<{value_tag}>\n{signature}{func_body}\n\n{constraint_expr}\n\n") + return "\n".join(temp_list) + + def repr_node_value_compact(self, node_dict, node_tag="node", + value_tag="value", constraint_tag="constraint") -> str: + """Returns text-only compact representation of node values (backward compatible).""" + temp_list = [] + for k, v in node_dict.items(): + if "__code" not in k: + # For images, use placeholder text + if is_image(v[0]): + node_value = "[IMAGE]" + else: + node_value = self.truncate_expression(v[0], self.initial_var_char_limit) + if v[1] is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag: + constraint_expr = f"<{constraint_tag}>\n{v[1]}\n" + temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"{type(v[0]).__name__}\">\n<{value_tag}>\n{node_value}\n\n{constraint_expr}\n\n") + else: + temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"{type(v[0]).__name__}\">\n<{value_tag}>\n{node_value}\n\n\n") + else: + constraint_expr = f"<{constraint_tag}>\n{v[1]}\n" + # we only truncate the function body + signature = v[1].replace("The code should start with:\n", "") + func_body = v[0].replace(signature, "") + node_value = self.truncate_expression(func_body, self.initial_var_char_limit) + temp_list.append( + f"<{node_tag} name=\"{k}\" type=\"code\">\n<{value_tag}>\n{signature}{node_value}\n\n{constraint_expr}\n\n") + return "\n".join(temp_list) + + def repr_node_value_as_content_blocks(self, node_dict, node_tag="node", + value_tag="value", constraint_tag="constraint") -> ContentBlockList: + """Returns a ContentBlockList representing node values, including images. + + Consecutive TextContent blocks are merged for efficiency. + For image values, the text before and after the image are separate blocks. 
+ """ + blocks = ContentBlockList() + + for k, v in node_dict.items(): + value_data = v[0] + constraint = v[1] + + if "__code" not in k: + # Check if this is an image + image_content = value_to_image_content(value_data) + + if image_content is not None: + # Image node: output XML structure, then image, then closing + type_name = "image" + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag else "" + + xml_text = f"<{node_tag} name=\"{k}\" type=\"{type_name}\">\n<{value_tag}>\n" + blocks.append(xml_text) + blocks.append(image_content) # Image breaks the text flow + + closing_text = f"\n\n{constraint_expr}\n\n" if constraint_expr else f"\n\n\n\n" + blocks.append(closing_text) + else: + # Non-image node: text representation + if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag: + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{value_data}\n\n{constraint_expr}\n\n\n" + ) + else: + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{value_data}\n\n\n\n" + ) + else: + # Code node (never an image) + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + signature = constraint.replace("The code should start with:\n", "") + func_body = value_data.replace(signature, "") + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"code\">\n<{value_tag}>\n{signature}{func_body}\n\n{constraint_expr}\n\n\n" + ) + + return blocks + + def repr_node_value_compact_as_content_blocks(self, node_dict, node_tag="node", + value_tag="value", constraint_tag="constraint") -> ContentBlockList: + """Returns a ContentBlockList with compact representation, including images. + + Consecutive TextContent blocks are merged for efficiency. + Non-image values are truncated. Images break the text flow. 
+ """ + blocks = ContentBlockList() + + for k, v in node_dict.items(): + value_data = v[0] + constraint = v[1] + + if "__code" not in k: + # Check if this is an image + image_content = value_to_image_content(value_data) + + if image_content is not None: + # Image node: output XML structure, then image, then closing + type_name = "image" + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag else "" + + xml_text = f"<{node_tag} name=\"{k}\" type=\"{type_name}\">\n<{value_tag}>\n" + blocks.append(xml_text) + blocks.append(image_content) # Image breaks the text flow + + closing_text = f"\n\n{constraint_expr}\n\n" if constraint_expr else f"\n\n\n\n" + blocks.append(closing_text) + else: + # Non-image node: truncated text representation + node_value = self.truncate_expression(value_data, self.initial_var_char_limit) + if constraint is not None and node_tag == self.optimizer_prompt_symbol_set.variable_tag: + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{node_value}\n\n{constraint_expr}\n\n\n" + ) + else: + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"{type(value_data).__name__}\">\n<{value_tag}>\n{node_value}\n\n\n\n" + ) + else: + # Code node (never an image) + constraint_expr = f"<{constraint_tag}>\n{constraint}\n" + signature = constraint.replace("The code should start with:\n", "") + func_body = value_data.replace(signature, "") + node_value = self.truncate_expression(func_body, self.initial_var_char_limit) + blocks.append( + f"<{node_tag} name=\"{k}\" type=\"code\">\n<{value_tag}>\n{signature}{node_value}\n\n{constraint_expr}\n\n\n" + ) + + return blocks + + def summarize(self): + """Aggregate feedback from parameters into a structured summary. 
+ + Collects and organizes feedback from all trainable parameters into + a FunctionFeedback structure suitable for problem representation. + + Returns + ------- + FunctionFeedback + Structured feedback containing: + - variables: Trainable parameters with values and descriptions + - inputs: Non-trainable root nodes + - graph: Topologically sorted function calls + - others: Intermediate computation values + - output: Final output values + - documentation: Function documentation strings + - user_feedback: Aggregated user feedback + + Notes + ----- + The method performs several transformations: + 1. Aggregates feedback from all trainable parameters + 2. Converts the trace graph to FunctionFeedback structure + 3. Separates root nodes into variables (trainable) and inputs (non-trainable) + 4. Preserves the computation graph and intermediate values + + Parameters without feedback (disconnected from output) are still + included in the summary but may not receive updates. + """ + # Aggregate feedback from all the parameters + feedbacks = [ + self.propagator.aggregate(node.feedback) + for node in self.parameters + if node.trainable + ] + summary = sum(feedbacks) # TraceGraph + # Construct variables and update others + # Some trainable nodes might not receive feedback, because they might not be connected to the output + summary = node_to_function_feedback(summary) + # Classify the root nodes into variables and others + # summary.variables = {p.py_name: p.data for p in self.parameters if p.trainable and p.py_name in summary.roots} + + trainable_param_dict = {p.py_name: p for p in self.parameters if p.trainable} + summary.variables = { + py_name: data + for py_name, data in summary.roots.items() + if py_name in trainable_param_dict + } + summary.inputs = { + py_name: data + for py_name, data in summary.roots.items() + if py_name not in trainable_param_dict + } # non-variable roots + + return summary + + def construct_prompt(self, summary, mask=None, *args, **kwargs): + 
"""Construct the system and user prompt. + + The prompt for the optimizer agent is rather complex. + There are prompts that are automatically constructed through the Trace frontend (aka the bundle/node API). + However, we also allow the user to provide additional context to the optimizer agent. + + We handle multimodal (MM) conversion implicitly for the automatic part (TraceGraph), + but we handle the user-provided context explicitly. + + Args: + summary: The FunctionFeedback summary containing graph information. + mask: List of section titles to exclude from the problem instance. + + Returns: + Tuple of (system_prompt: str, user_prompt: ContentBlockList) + - system_prompt is always a string + - user_prompt is a ContentBlockList for multimodal support + """ + system_prompt = ( + self.representation_prompt + self.output_format_prompt + ) # generic representation + output rule + + problem_inst = self.problem_instance(summary, mask=mask) + + # Build user prompt as ContentBlockList (auto-merges consecutive text) + user_content_blocks = ContentBlockList() + + # Add example if included + if self.include_example: + example_text = self.example_problem_template.format( + example_problem=str(self.example_problem), # Example is always text + example_response=self.example_response, + ) + user_content_blocks.append(example_text) + + # Add problem instance template + # context is part of the problem instance + user_content_blocks.append(self.user_prompt_template.format( + problem_instance=problem_inst.to_content_blocks(), + )) + + # Add final prompt + var_names = ", ".join(k for k in summary.variables.keys()) + user_content_blocks.append(self.final_prompt.format( + names=var_names, + )) + + return system_prompt, user_content_blocks + + def problem_instance(self, summary: FunctionFeedback, mask=None): + """Create a ProblemInstance from the summary. + + Args: + summary: The FunctionFeedback summary containing graph information. 
+ mask: List of section titles to exclude from the problem instance. + + Returns: + ProblemInstance with content block fields for multimodal support. + """ + mask = mask or [] + + # Use content block representations for multimodal support + variables_content = ( + self.repr_node_value_as_content_blocks( + summary.variables, + node_tag=self.optimizer_prompt_symbol_set.variable_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if self.optimizer_prompt_symbol_set.variables_section_title not in mask + else ContentBlockList() + ) + + # we add a temporary check here to ensure no more than 1 parameter is an image + variable_stats = variables_content.count_blocks() + if 'ImageContent' in variable_stats: + assert variable_stats['ImageContent'] <= 1, "Currently we do not support generating multiple images (more than 1 parameter is an image)" + self.output_contains_image = True + + inputs_content = ( + self.repr_node_value_compact_as_content_blocks( + summary.inputs, + node_tag=self.optimizer_prompt_symbol_set.node_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if self.optimizer_prompt_symbol_set.inputs_section_title not in mask + else ContentBlockList() + ) + outputs_content = ( + self.repr_node_value_compact_as_content_blocks( + summary.output, + node_tag=self.optimizer_prompt_symbol_set.node_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if self.optimizer_prompt_symbol_set.outputs_section_title not in mask + else ContentBlockList() + ) + others_content = ( + self.repr_node_value_compact_as_content_blocks( + summary.others, + node_tag=self.optimizer_prompt_symbol_set.node_tag, + value_tag=self.optimizer_prompt_symbol_set.value_tag, + constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag + ) + if 
self.optimizer_prompt_symbol_set.others_section_title not in mask + else ContentBlockList() + ) + + return ProblemInstance( + instruction=self.objective if "#Instruction" not in mask else "", + code=( + "\n".join([v for k, v in sorted(summary.graph)]) + if self.optimizer_prompt_symbol_set.inputs_section_title not in mask + else "" + ), + documentation=( + "\n".join([f"[{k}] {v}" for k, v in summary.documentation.items()]) + if self.optimizer_prompt_symbol_set.documentation_section_title not in mask + else "" + ), + variables=variables_content, + inputs=inputs_content, + outputs=outputs_content, + others=others_content, + feedback=Content(summary.user_feedback) if self.optimizer_prompt_symbol_set.feedback_section_title not in mask else Content(""), + context=self.problem_context, + optimizer_prompt_symbol_set=self.optimizer_prompt_symbol_set + ) + + def _step( + self, verbose=False, mask=None, *args, **kwargs + ) -> Dict[ParameterNode, Any]: + """Execute one optimization step. + + Args: + verbose: If True, print prompts and responses. + mask: List of section titles to exclude from the problem instance. + + Returns: + Dictionary mapping parameters to their updated values. 
def call_llm(
    self,
    system_prompt: str,
    user_prompt: ContentBlockList,
    verbose: Union[bool, str] = False,
    max_tokens: int = 4096,
) -> AssistantTurn:
    """Send one user turn to the LLM and record the exchange in the history.

    The system prompt on ``self.conversation_history`` is overwritten, the
    user turn is appended, the (possibly truncated) history is rendered with
    ``to_messages`` and passed to ``self.llm``. On success the assistant turn
    is appended to the history as well.

    Args:
        system_prompt: The system prompt (always a string).
        user_prompt: The user prompt as ContentBlockList for multimodal content.
        verbose: If truthy and not "output", print the prompt. If truthy
            (including "output"), print the response.
        max_tokens: Maximum tokens in the response.

    Returns:
        assistant_turn: AssistantTurn object returned by ``self.llm``.

    Raises:
        Exception: Whatever ``self.llm`` raises is propagated unchanged; the
            user turn added for this call is removed from the history first
            so the user/assistant pairing stays intact for a retry.
    """
    if verbose not in (False, "output"):
        # Print text portions, indicate if images present
        text_parts = [block.text for block in user_prompt if isinstance(block, TextContent)]
        has_images = any(isinstance(block, ImageContent) for block in user_prompt)
        suffix = f" [+ {DEFAULT_IMAGE_PLACEHOLDER}]" if has_images else ""
        print("Prompt\n", system_prompt + "".join(text_parts) + suffix)

    # Resolve the model name once. Backends without a ``model_name``
    # attribute are treated as "unknown model" everywhere below instead of
    # being guarded in one place and crashing in another.
    model_name = getattr(self.llm, "model_name", None)
    model_name_lower = model_name.lower() if isinstance(model_name, str) else ""

    # Update system prompt in conversation history
    self.conversation_history.system_prompt = system_prompt

    # Create user turn with content
    user_turn = UserTurn(user_prompt)
    self.conversation_history.add_user_turn(user_turn)

    # Get messages with conversation length control (truncate from start)
    # conversation_length = n historical rounds (user+assistant pairs) to keep
    # The current user turn is automatically included by to_messages()
    messages = self.conversation_history.to_messages(
        n=self.conversation_length if self.conversation_length > 0 else -1,
        truncate_strategy="from_start",
        model_name=model_name,
    )

    # Bedrock doesn't support response_format natively - LiteLLM adds tools which breaks the response
    _is_bedrock = is_bedrock_model(model_name)
    response_format = {"type": "json_object"} if (self.use_json_object_format and not _is_bedrock) else None

    # Prepare common arguments
    llm_kwargs = {"messages": messages, "max_tokens": max_tokens, "response_format": response_format}

    # Add image generation tool only for non-Gemini models when output contains image.
    # The comparison is case-insensitive to agree with Chat.to_messages(), which
    # lower-cases the model name when choosing the Gemini message format.
    if self.output_contains_image and 'gemini' not in model_name_lower:
        llm_kwargs["tools"] = [{"type": "image_generation"}]

    try:
        assistant_turn = self.llm(**llm_kwargs)
    except Exception:
        # Roll back the user turn we just added: the truncation logic counts
        # rounds as user+assistant pairs, so an orphaned user turn would
        # misalign every later call.
        turns = self.conversation_history.turns
        if turns and turns[-1] is user_turn:
            turns.pop()
        raise

    if verbose:
        print("LLM response:\n", assistant_turn)

    self.conversation_history.add_assistant_turn(assistant_turn)

    return assistant_turn
def is_bedrock_model(model_name: str) -> bool:
    """Check whether a model name refers to an AWS Bedrock model.

    Bedrock models in LiteLLM look like ``bedrock/us.anthropic.claude-...`` or
    carry a geography prefix such as ``us.``/``eu.``/``apac.`` (the legacy
    ``ap.`` spelling is still accepted for backward compatibility).

    Args:
        model_name: The model name string to check. ``None`` and any other
            non-string value are treated as "not Bedrock".

    Returns:
        True if the model is a Bedrock model, False otherwise.
    """
    # Callers may pass None (or a backend-specific object) when the LLM has no
    # usable model name; never raise from a pure predicate.
    if not isinstance(model_name, str):
        return False
    if model_name.startswith('bedrock/'):
        return True
    # Cross-region inference profile prefixes. str.startswith accepts a tuple,
    # so this is a single C-level check.
    region_prefixes = ('us.', 'eu.', 'ap.', 'apac.')
    return model_name.startswith(region_prefixes)
def is_image(data) -> bool:
    """Heuristically decide whether ``data`` represents an image.

    The checks are purely local (no network access) and run in this order:

    1. a string starting with ``data:image/`` (base64 data URL);
    2. a ``PIL.Image.Image`` instance (skipped when PIL is not installed);
    3. ``bytes`` that PIL can open as an image;
    4. an ``http``/``https`` URL whose path ends in a known image extension;
    5. any object whose class name contains ``ImageContent`` (matched by name
       so this module does not have to import the container class).

    For network verification of URLs, use :func:`verify_data_is_image_url`.
    Convert numpy arrays to PIL Images first.
    """
    data_is_str = isinstance(data, str)

    # 1. Base64 data URL string
    if data_is_str and data.startswith('data:image/'):
        return True

    # 2. PIL Image object
    try:
        from PIL import Image as _pil_image
    except ImportError:
        _pil_image = None
    if _pil_image is not None and isinstance(data, _pil_image.Image):
        return True

    # 3. Raw image bytes: anything PIL can parse counts as an image
    if isinstance(data, bytes):
        try:
            from PIL import Image
            from io import BytesIO
            Image.open(BytesIO(data))
            return True
        except Exception:
            pass

    # 4. Image URL (pattern-based, no network request)
    if data_is_str:
        known_suffixes = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp',
                          '.svg', '.ico', '.tiff', '.tif', '.heic', '.heif')
        try:
            from urllib.parse import urlparse
            pieces = urlparse(data)
            if pieces.scheme in ('http', 'https') and pieces.path.lower().endswith(known_suffixes):
                return True
        except (ValueError, AttributeError):
            pass

    # 5. Specialized container class (e.g. ImageContent) checked by name to keep
    # nodes.py free of external (opto.utils) dependencies.
    try:
        return 'ImageContent' in data.__class__.__name__
    except AttributeError:
        return False
@dataclass
class Chat:
    """Manages conversation history across multiple turns using LiteLLM unified format.

    Turns are stored in order in ``turns``. A "round" is one user turn followed
    by one assistant turn; the truncation helpers count in rounds and assume
    that pairing.
    """
    turns: List[Union[UserTurn, AssistantTurn]] = field(default_factory=list)
    system_prompt: Optional[str] = None
    protected_rounds: int = 0  # Initial rounds to never truncate (task definition)

    def add_user_turn(self, turn: Union[str, ContentBlockList, 'TextContent', 'ImageContent', 'Content', UserTurn]) -> 'Chat':
        """Add a user turn

        Args:
            turn: Can be:
                - str: Plain text message
                - ContentBlockList: List of content blocks
                - TextContent: Single text content block
                - ImageContent: Single image content block
                - Content: Multi-modal content wrapper
                - UserTurn: Complete user turn object

        Returns:
            Chat: Self for method chaining

        Raises:
            AssertionError: If turn is not one of the accepted types (the check
                is an ``assert``, so it is skipped when Python runs with ``-O``).
        """
        # Accept UserTurn directly
        if isinstance(turn, UserTurn):
            self.turns.append(turn)
            return self

        assert isinstance(
            turn, (str, ContentBlockList, TextContent, ImageContent, Content)
        ), "turn must be a string, ContentBlockList, TextContent, ImageContent, or Content object"
        user_turn = UserTurn(content=turn)
        self.turns.append(user_turn)
        return self

    def add_assistant_turn(self, turn: AssistantTurn) -> 'Chat':
        """Add an assistant turn. AssistantTurn parses the response from the LLM."""
        assert isinstance(turn, AssistantTurn), "turn must be an AssistantTurn object"
        self.turns.append(turn)
        return self

    def get_last_user_turn(self) -> Optional[UserTurn]:
        """Get the most recent user turn, or None if there is none."""
        for turn in reversed(self.turns):
            if isinstance(turn, UserTurn):
                return turn
        return None

    def get_last_assistant_turn(self) -> Optional[AssistantTurn]:
        """Get the most recent assistant turn, or None if there is none."""
        for turn in reversed(self.turns):
            if isinstance(turn, AssistantTurn):
                return turn
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format"""
        return {
            "system_prompt": self.system_prompt,
            "protected_rounds": self.protected_rounds,
            "turns": [turn.to_dict() for turn in self.turns]
        }

    def to_litellm_format(
            self,
            n: int = -1,
            truncate_strategy: Literal["from_start", "from_end"] = "from_start",
            protected_rounds: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Convert to LiteLLM messages format (OpenAI-compatible, works with all providers)

        Args:
            n: Number of historical rounds (user+assistant pairs) to include.
               -1 means all history (default: -1).
               The current (potentially incomplete) round is always included.
            truncate_strategy: How to truncate when n is specified:
                - "from_start": Remove oldest rounds, keep the most recent n rounds (default)
                - "from_end": Remove newest rounds, keep the oldest n rounds
            protected_rounds: Number of initial rounds to never truncate (task definition).
                If None, uses self.protected_rounds. These rounds count towards n, so
                if n=5 and protected_rounds=1, you get 1 protected + 4 truncatable rounds.

        Returns:
            List of message dictionaries in LiteLLM format

        Raises:
            ValueError: If ``truncate_strategy`` is not "from_start" or "from_end"
                (only checked when ``n != -1``).
        """
        # Determine protected rounds
        n_protected = protected_rounds if protected_rounds is not None else self.protected_rounds
        protected_turns = n_protected * 2  # Each round = user + assistant

        # Apply truncation to turns
        if n == -1:
            selected_turns = self.turns
        else:
            # Protected rounds count towards N
            # So if N=5 and protected_rounds=1, we keep 1 protected + 4 from truncatable
            remaining_rounds = max(0, n - n_protected)

            # Split into protected and truncatable turns
            protected_part = self.turns[:protected_turns]
            truncatable_part = self.turns[protected_turns:]

            # A trailing user turn is the current, still unanswered round. It is
            # set aside so that it is kept under BOTH strategies; slicing the
            # head of the list for "from_end" would otherwise drop it (or keep
            # a stale, unpaired user turn in its place).
            has_incomplete_round = len(truncatable_part) > 0 and isinstance(truncatable_part[-1], UserTurn)
            if has_incomplete_round:
                completed_part = truncatable_part[:-1]
                pending_part = truncatable_part[-1:]
            else:
                completed_part = truncatable_part
                pending_part = []

            # Each historical round = 2 turns (user + assistant)
            n_turns = remaining_rounds * 2

            if truncate_strategy == "from_start":
                # Keep the most recent n_turns of completed history
                # (guard n_turns == 0 because [-0:] would return everything)
                truncated_part = completed_part[-n_turns:] if n_turns > 0 else []
            elif truncate_strategy == "from_end":
                # Keep the oldest n_turns of completed history
                truncated_part = completed_part[:n_turns]
            else:
                raise ValueError(f"Unknown truncate_strategy: {truncate_strategy}. Use 'from_start' or 'from_end'")

            # Combine protected + truncated history + current pending user turn
            selected_turns = protected_part + truncated_part + pending_part

        messages = []

        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        for turn in selected_turns:
            messages.append(turn.to_litellm_format())

        return messages

    def to_messages(
            self,
            n: int = -1,
            truncate_strategy: Literal["from_start", "from_end"] = "from_start",
            protected_rounds: Optional[int] = None,
            model_name: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Smart message format conversion that auto-detects the appropriate format.

        This method automatically chooses between Gemini format and LiteLLM format based on
        the model name. Detection priority:
        1. If model_name argument is provided and contains "gemini", uses Gemini format
        2. Otherwise, checks if any AssistantTurn has a model name containing "gemini"
        3. If no Gemini model detected, uses LiteLLM format (default)

        Note: This detection may not work for custom LLM backends with Gemini model names.
        In such cases, call to_gemini_format() or to_litellm_format() explicitly.

        Args:
            n: Number of historical rounds (user+assistant pairs) to include.
               -1 means all history (default: -1).
               The current (potentially incomplete) round is always included.
            truncate_strategy: How to truncate when n is specified:
                - "from_start": Remove oldest rounds, keep the most recent n rounds (default)
                - "from_end": Remove newest rounds, keep the oldest n rounds
            protected_rounds: Number of initial rounds to never truncate (task definition).
                If None, uses self.protected_rounds. Counts towards n.
            model_name: Optional model name to use for format detection. If provided and
                contains "gemini" (case-insensitive), forces Gemini format.

        Returns:
            List of message dictionaries in the appropriate format

        Example:
            # Automatically uses Gemini format if model is Gemini
            history = Chat()
            history.system_prompt = "You are helpful."
            history.add_user_turn(UserTurn().add_text("Hello"))

            # Force Gemini format by providing model name
            messages = history.to_messages(model_name="gemini-2.5-flash")

            # Or be explicit:
            messages = history.to_gemini_format()  # Force Gemini format
            messages = history.to_litellm_format()  # Force LiteLLM format
        """
        # Check if model_name argument indicates Gemini (highest priority)
        use_gemini_format = bool(model_name) and 'gemini' in model_name.lower()
        if not use_gemini_format:
            # Fall back to the models recorded on previous assistant turns
            use_gemini_format = any(
                isinstance(turn, AssistantTurn) and turn.model and 'gemini' in turn.model.lower()
                for turn in self.turns
            )

        # Use the appropriate format
        render = self.to_gemini_format if use_gemini_format else self.to_litellm_format
        return render(
            n=n,
            truncate_strategy=truncate_strategy,
            protected_rounds=protected_rounds
        )

    def to_gemini_format(
            self,
            n: int = -1,
            truncate_strategy: Literal["from_start", "from_end"] = "from_start",
            protected_rounds: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Convert to Google Gemini format (messages with 'model' role instead of 'assistant')

        The conversation is first rendered with :meth:`to_litellm_format` (which
        applies truncation) and each message is then rewritten:
        - 'assistant' role becomes 'model'; 'tool' messages are skipped
        - Content is structured as 'parts' (list of text/image parts)
        - System message (if present) remains as first message with role='system'

        The GoogleGenAILLM class will extract the system message and convert it to
        system_instruction when making the API call.

        Args:
            n: Number of historical rounds (user+assistant pairs) to include.
               -1 means all history (default: -1).
               The current (potentially incomplete) round is always included.
            truncate_strategy: How to truncate when n is specified:
                - "from_start": Remove oldest rounds, keep the most recent n rounds (default)
                - "from_end": Remove newest rounds, keep the oldest n rounds
            protected_rounds: Number of initial rounds to never truncate (task definition).
                If None, uses self.protected_rounds. These rounds count towards n.

        Returns:
            List of message dictionaries in Gemini format with 'role' and 'parts'.
            System message (if present) is included as first message with role='system'.

        Example:
            from opto.utils.llm import LLM
            from opto.utils.backbone import Chat, UserTurn

            # Create conversation
            history = Chat()
            history.system_prompt = "You are a helpful assistant."
            history.add_user_turn(UserTurn().add_text("Hello!"))

            # Convert to Gemini format
            messages = history.to_gemini_format()

            # Use with GoogleGenAILLM
            llm = LLM(model="gemini-2.5-flash")
            response = llm(messages=messages)
        """
        import re  # local import: only needed for data-URL parsing in this method

        # Get the LiteLLM format messages first (handles truncation logic)
        litellm_messages = self.to_litellm_format(
            n=n,
            truncate_strategy=truncate_strategy,
            protected_rounds=protected_rounds
        )

        # Convert messages to Google GenAI format
        gemini_messages = []

        for msg in litellm_messages:
            role = msg.get('role')
            content = msg.get('content')

            # Keep system messages as-is (will be extracted by GoogleGenAILLM)
            if role == 'system':
                gemini_messages.append({'role': 'system', 'content': content})
                continue

            # Map roles: user -> user, assistant -> model
            if role == 'assistant':
                role = 'model'
            elif role == 'tool':
                # Skip tool messages for now - Gemini handles these differently
                # TODO: Handle tool results properly if needed
                continue

            # Handle content (can be string or list of content blocks)
            if isinstance(content, str):
                gemini_messages.append({'role': role, 'parts': [{'text': content}]})
            elif isinstance(content, list):
                # Convert content blocks to parts.
                # NOTE(review): only blocks typed 'text' / 'image' are converted;
                # any other block type is silently dropped. Confirm these are the
                # type names the turns' to_litellm_format() actually emits.
                parts = []
                for block in content:
                    block_type = block.get('type')
                    if block_type == 'text':
                        parts.append({'text': block.get('text', '')})
                    elif block_type == 'image':
                        # Handle image URLs (tolerate an explicit None value)
                        image_url = block.get('image_url') or ''
                        if image_url.startswith('data:'):
                            # Extract base64 data; malformed data URLs are skipped
                            match = re.match(r'data:([^;]+);base64,(.+)', image_url)
                            if match:
                                mime_type, data = match.groups()
                                parts.append({'inline_data': {'mime_type': mime_type, 'data': data}})
                        else:
                            # External URL
                            parts.append({'file_data': {'file_uri': image_url}})
                if parts:
                    gemini_messages.append({'role': role, 'parts': parts})

        return gemini_messages

    def save_to_file(self, filepath: str):
        """Save conversation history to JSON file"""
        with open(filepath, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'Chat':
        """Load conversation settings from a JSON file written by save_to_file.

        Only ``system_prompt`` and ``protected_rounds`` are restored. The saved
        ``turns`` are NOT reconstructed, so the returned Chat has an empty
        history.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)

        # This is a simplified loader - you'd want more robust deserialization
        history = cls(
            system_prompt=data.get('system_prompt'),
            protected_rounds=data.get('protected_rounds', 0)
        )

        # Note: Full deserialization would require reconstructing objects from dicts
        # This is left as an exercise since it depends on your exact needs

        return history

    def clear(self):
        """Clear all turns from history (system prompt and protected_rounds are kept)."""
        self.turns.clear()

    def get_token_count_estimate(self) -> int:
        """Rough estimate of token count (actual count requires tokenizer).

        Only TextContent blocks are counted, at ~4 characters per token; the
        system prompt and images contribute nothing.
        """
        total = 0
        for turn in self.turns:
            if isinstance(turn, (UserTurn, AssistantTurn)):
                for block in turn.content:
                    if isinstance(block, TextContent):
                        # Very rough estimate: ~4 chars per token
                        total += len(block.text) // 4
        return total

    def _repr_html_(self) -> Optional[str]:
        """Rich HTML representation for Jupyter notebooks with glassmorphism design.

        Returns None when the display module cannot be imported.
        """
        try:
            from opto.utils.display.jupyter import render_chat
            return render_chat(self)
        except ImportError:
            # Fallback to text representation if display module unavailable
            return None
class ContentBlockList(list):
    """List of content blocks with automatic type conversion.

    Supports automatic conversion from:
    - str -> [TextContent(text=str)]
    - TextContent -> [TextContent]
    - ImageContent -> [ImageContent]
    - List[ContentBlock] -> ContentBlockList
    - None/empty -> []

    Note: This list can contain mixed types of ContentBlocks (text, images, PDFs, etc.).
    Type annotations like ContentBlockList[TextContent] are used for documentation
    purposes in specialized methods but don't restrict the actual content.

    Note: truthiness is overridden (see ``__bool__``): a list holding only
    empty/whitespace text is falsy even though ``len()`` is non-zero. Code in
    this class that needs "has any element" therefore uses ``len(self)``.
    """

    def __init__(self, content: Union[str, 'ContentBase', List['ContentBase'], None] = None):
        """Initialize ContentBlockList with automatic type conversion.

        Args:
            content: Can be a string (converted to TextContent), a single ContentBlock,
                    a list of ContentBlocks, or None (empty list).
        """
        super().__init__()
        if content is not None:
            # extend() normalizes its argument and merges consecutive text
            self.extend(content)

    @staticmethod
    def _normalize(content: Union[str, 'ContentBase', List['ContentBase'], None]) -> List['ContentBase']:
        """Normalize content to a list of ContentBlocks.

        None and the empty string give ``[]``; a non-empty string gives a single
        TextContent; a list is returned as-is; anything else is wrapped in a
        one-element list.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [TextContent(text=content)] if content else []
        if isinstance(content, list):
            return content
        # Single ContentBlock
        return [content]

    @classmethod
    def ensure(cls, content: Union[str, 'ContentBase', List['ContentBase'], None]) -> 'ContentBlockList':
        """Ensure content is a ContentBlockList with automatic conversion.

        Args:
            content: String, ContentBlock, list of ContentBlocks, or None

        Returns:
            ContentBlockList with the content (the same object if it already is one)
        """
        if isinstance(content, cls):
            return content
        return cls(content)

    def __getitem__(self, key: Union[int, slice]) -> Union['ContentBase', 'ContentBlockList']:
        """Support indexing and slicing.

        Args:
            key: Integer index or slice object

        Returns:
            ContentBlock for single index, ContentBlockList for slices
        """
        if isinstance(key, slice):
            # Return a new ContentBlockList with the sliced items
            return ContentBlockList(list.__getitem__(self, key))
        else:
            # Return the single item for integer index
            return list.__getitem__(self, key)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"type": "list", "blocks": [...]}`` using each block's to_dict()."""
        return {"type": "list", "blocks": [b.to_dict() for b in self]}

    def append(self, item: Union[str, 'ContentBase', 'ContentBlockList']) -> 'ContentBlockList':
        """Append a string or ContentBlock, merging consecutive text.

        Args:
            item: String (auto-converted to TextContent) or ContentBlock.
                  If the last item is TextContent and item is also text,
                  they are merged into a single TextContent joined by " ".
                  A ContentBlockList is added element-wise without merging.

        Returns:
            ContentBlockList: Self, for chaining.
        """
        if isinstance(item, str):
            item = TextContent(text=item)

        if isinstance(item, TextContent):
            # Use len() rather than truthiness: __bool__ is overridden to mean
            # "has meaningful content", which would skip the merge when the
            # existing text is whitespace-only and would scan the whole list
            # on every append.
            if len(self) > 0 and isinstance(self[-1], TextContent):
                # Merge with the previous text block (with a separation mark " ")
                self[-1] = TextContent(text=self[-1].text + " " + item.text)
            else:
                super().append(item)
        elif isinstance(item, ContentBlockList):
            # we silently call extend here
            super().extend(item)
        else:
            # Other ContentBlock types (ImageContent, etc.): just add
            super().append(item)
        return self

    def extend(self, blocks: Union[str, 'ContentBase', List[
        'ContentBase'], 'ContentBlockList', None]) -> 'ContentBlockList':
        """Extend with blocks, merging consecutive TextContent.

        Args:
            blocks: String, ContentBlock, list of ContentBlocks, or None.
                   Strings are auto-converted. Consecutive text is merged.

        Returns:
            ContentBlockList: Self, for chaining.
        """
        for block in self._normalize(blocks):
            self.append(block)
        return self

    def __add__(self, other) -> 'ContentBlockList':
        """Concatenate content block lists with other content block lists or strings.

        Args:
            other: ContentBlockList, List[ContentBlock], or string to concatenate
        """
        if isinstance(other, (ContentBlockList, list)):
            result = ContentBlockList(list(self))
            result.extend(other)
            return result
        elif isinstance(other, str):
            result = ContentBlockList(list(self))
            result.append(TextContent(text=other))
            return result
        else:
            return NotImplemented

    def __radd__(self, other) -> 'ContentBlockList':
        """Right-side concatenation (when string is on the left).
        """
        if isinstance(other, str):
            result = ContentBlockList([TextContent(text=other)])
            result.extend(self)
            return result
        else:
            return NotImplemented

    def is_empty(self) -> bool:
        """Check if the content block list is empty (no blocks, or every block is empty)."""
        if len(self) == 0:
            return True
        return all(block.is_empty() for block in self)

    def has_images(self) -> bool:
        """Check if the content block list contains any images."""
        return any(isinstance(block, ImageContent) for block in self)

    def has_text(self) -> bool:
        """Check if the content block list contains any text."""
        return any(isinstance(block, TextContent) for block in self)

    # --- Multimodal utilities ---
    @staticmethod
    def blocks_to_text(blocks: Iterable['ContentBase'],
                       image_placeholder: str = DEFAULT_IMAGE_PLACEHOLDER) -> str:
        """Convert any iterable of ContentBlocks to text representation.

        This is a utility that can be used by composite classes containing
        multiple ContentBlockLists. Handles nested ContentBlockLists recursively.
        Blocks that are neither text, image nor a nested list are ignored.

        Args:
            blocks: Iterable of ContentBlock objects (may include nested ContentBlockLists)
            image_placeholder: Placeholder string for images
                (default: DEFAULT_IMAGE_PLACEHOLDER)

        Returns:
            str: Text representation where images are replaced with placeholder;
                pieces are joined with a single space.
        """
        text_parts = []
        for block in blocks:
            if isinstance(block, TextContent):
                text_parts.append(block.text)
            elif isinstance(block, ImageContent):
                text_parts.append(image_placeholder)
            elif isinstance(block, ContentBlockList):
                # Recursively handle nested ContentBlockList
                nested_text = ContentBlockList.blocks_to_text(block, image_placeholder)
                if nested_text:
                    text_parts.append(nested_text)
        return " ".join(text_parts)

    def to_text(self, image_placeholder: str = DEFAULT_IMAGE_PLACEHOLDER) -> str:
        """Convert this list to text representation.

        Args:
            image_placeholder: Placeholder string for images
                (default: DEFAULT_IMAGE_PLACEHOLDER)

        Returns:
            str: Text representation where images are replaced with placeholder.
        """
        return self.blocks_to_text(self, image_placeholder)

    def __bool__(self) -> bool:
        """Check if there's any actual content (not just empty text).

        Returns:
            bool: True if content is non-empty (has images or non-whitespace text).
        """
        for block in self:
            if isinstance(block, ImageContent):
                return True
            if isinstance(block, TextContent) and block.text.strip():
                return True
        return False

    def __repr__(self) -> str:
        """Return text-only representation for logging.

        Images are represented by the default image placeholder.

        Returns:
            str: Text representation of the content.
        """
        return self.to_text()

    def _repr_html_(self) -> Optional[str]:
        """Rich HTML representation for Jupyter notebooks.

        Returns None when the display module cannot be imported.
        """
        try:
            from opto.utils.display.jupyter import render_content_block_list
            return render_content_block_list(self)
        except ImportError:
            # Fallback to text representation if display module unavailable
            return None

    def to_content_blocks(self) -> 'ContentBlockList':
        """Return self (for interface compatibility with composites).

        This allows ContentBlockList and classes that inherit from it
        to be used interchangeably with composite classes that have
        a to_content_blocks() method.

        Returns:
            ContentBlockList: Self reference.
        """
        return self

    def count_blocks(self) -> Dict[str, int]:
        """Count blocks by type, including nested structures.

        Recursively traverses the content block structure and counts
        each block type by its class name.

        Returns:
            Dict[str, int]: Dictionary mapping block class names to counts.
                Example: {"TextContent": 3, "ImageContent": 1}
        """
        counts: Dict[str, int] = {}

        def _count_recursive(item: Any) -> None:
            """Recursively count blocks in nested structures."""
            if isinstance(item, ContentBase):
                # Count this block
                class_name = item.__class__.__name__
                counts[class_name] = counts.get(class_name, 0) + 1

                # Check if this block has any attributes that might contain nested blocks
                if hasattr(item, '__dict__'):
                    for attr_value in item.__dict__.values():
                        if isinstance(attr_value, (ContentBlockList, list)):
                            for nested_item in attr_value:
                                _count_recursive(nested_item)
                        elif isinstance(attr_value, ContentBase):
                            _count_recursive(attr_value)
            elif isinstance(item, (ContentBlockList, list)):
                # Recursively count items in lists
                for nested_item in item:
                    _count_recursive(nested_item)

        # Count all blocks in this list
        for block in self:
            _count_recursive(block)

        return counts

    def to_litellm_format(self, role: Optional[str] = None) -> List[Dict[str, Any]]:
        """Convert content blocks to LiteLLM Response API format.

        Empty blocks (``block.is_empty()``) are skipped.

        Args:
            role: Optional role context ("user" or "assistant") to determine the correct type.
                  If not provided, defaults to "user" for backward compatibility.

        Returns:
            List[Dict[str, Any]]: List of content block dictionaries in Response API format
        """
        if role is None:
            role = "user"

        content = []
        for block in self:
            # Skip empty content blocks
            if block.is_empty():
                continue

            # Handle different content block types
            if isinstance(block, TextContent):
                # Pass role context to TextContent for proper type selection
                content.append(block.to_litellm_format(role=role))
            elif isinstance(block, ImageContent):
                # ImageContent always uses input_image for user messages
                content.append(block.to_litellm_format())
            elif hasattr(block, 'to_litellm_format'):
                # Fallback: use block's own to_litellm_format method
                content.append(block.to_litellm_format())
            else:
                # Last resort: use to_dict()
                content.append(block.to_dict())

        return content
class Content(ContentBlockList):
    """Semantic wrapper providing multi-modal content for the optimizer agent.

    Gives users a flexible way to hand mixed text and image content to the
    optimizer agent, e.g. problem context or feedback. All list behaviour is
    inherited from ``ContentBlockList``; this class only adds a constructor
    that accepts several input patterns.

    Creation patterns:
        - Variadic: ``Content("text", image, "more text")``
        - Template: ``Content("See [IMAGE] here", images=[img])``
        - Empty:    ``Content()``

    Examples:
        # Text-only content
        ctx = Content("Important background information")

        # Image content
        ctx = Content(ImageContent.build("diagram.png"))

        # Mixed content (variadic mode)
        ctx = Content(
            "Here's the diagram:",
            "diagram.png",  # an existing image file is loaded as an image
            "And the analysis."
        )

        # Template mode with placeholders
        ctx = Content(
            "Compare [IMAGE] with [IMAGE]:",
            images=[img1, img2]
        )

        # Manual building
        ctx = Content()
        ctx.append("Here's the relevant diagram:")
        ctx.append(ImageContent.build("diagram.png"))
    """

    def __init__(
        self,
        *args,
        images: Optional[List[Any]] = None,
        format: str = "PNG"
    ):
        """Initialize a Content from various input patterns.

        **Mode 1: Variadic (images=None)** - pass any mix of text and image
        sources. Each argument that ``ImageContent.build`` can turn into a
        non-empty image becomes an image block; everything else is appended
        unchanged.

        **Mode 2: Template (images provided)** - pass exactly one template
        string containing image placeholders plus the list of images that
        replace them, in order.

        Args:
            *args: Text strings and/or image sources (Mode 1), or a single
                template string (Mode 2).
            images: Optional list of images for template mode.
            format: Image format used when encoding array/PIL inputs
                (PNG, JPEG, etc.). Applies to both modes. Default: PNG

        Raises:
            ValueError: In template mode, if args is not a single template
                string, if the placeholder count doesn't match the image
                count, or if an image cannot be converted.
        """
        # Start from an empty list, then populate according to the mode.
        super().__init__()

        if images is not None:
            if len(args) != 1 or not isinstance(args[0], str):
                raise ValueError(
                    "Template mode requires exactly one template string as the first argument. "
                    f"Got {len(args)} arguments."
                )
            self._build_from_template(args[0], images=images, format=format)
        elif args:
            self._build_from_variadic(*args, format=format)

    def _build_from_variadic(self, *args, format: str = "PNG") -> None:
        """Populate self from variadic arguments.

        Args:
            *args: Text and image sources, in the order they should appear.
            format: Image format used when encoding array/PIL inputs.
        """
        for arg in args:
            # For future expansion, other special content types could be
            # probed here the same way (build, then check is_empty()).
            image_content = ImageContent.build(arg, format=format)
            if not image_content.is_empty():
                self.append(image_content)
            else:
                self.append(arg)

    def _build_from_template(
        self,
        template: str,
        images: List[Any],
        format: str = "PNG"
    ) -> None:
        """Populate self from a template with image placeholders.

        The template is split on ``DEFAULT_IMAGE_PLACEHOLDER``; the text
        pieces and the images are appended alternately, in order.

        Args:
            template: Template string containing the placeholders.
            images: Image sources to insert at the placeholders.
            format: Image format used when encoding array/PIL inputs.

        Raises:
            ValueError: If the placeholder count doesn't match the number of
                images, or an image source cannot be converted.
        """
        placeholder = DEFAULT_IMAGE_PLACEHOLDER

        placeholder_count = template.count(placeholder)
        if placeholder_count != len(images):
            raise ValueError(
                f"Number of {placeholder} placeholders ({placeholder_count}) "
                f"does not match number of images ({len(images)})"
            )

        # N placeholders yield N + 1 text parts; an image follows every part
        # except the last one.
        parts = template.split(placeholder)

        for i, part in enumerate(parts):
            if part:
                self.append(part)

            if i < len(images):
                image_content = ImageContent.build(images[i], format=format)
                # build() signals an unconvertible source with an empty
                # block rather than None, so both must be checked.
                if image_content is None or image_content.is_empty():
                    raise ValueError(
                        f"Could not convert image at index {i} to ImageContent: {type(images[i])}"
                    )
                self.append(image_content)
@dataclass
class TextContent(ContentBase):
    """Text content block."""
    type: Literal["text"] = "text"
    text: str = ""

    def __init__(self, text: str = ""):
        super().__init__(text=text)

    def is_empty(self) -> bool:
        """Check if the text content is empty."""
        return not self.text

    @classmethod
    def build(cls, value: Any = "", **kwargs) -> 'TextContent':
        """Build a text content block from a value.

        Args:
            value: String, or any value to convert with ``str()``.
            **kwargs: Unused; accepted for compatibility with the base class.

        Returns:
            TextContent: Text content block with the value as text.
        """
        return cls(text=value if isinstance(value, str) else str(value))

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {"type": self.type, "text": self.text}

    def to_litellm_format(self, role: str = "user") -> Dict[str, Any]:
        """Convert to LiteLLM/OpenAI Response API compatible format.

        Args:
            role: Role of the message that owns this block.

        Returns:
            - {"type": "output_text", "text": "..."} for assistant messages
            - {"type": "input_text", "text": "..."} for every other role
        """
        # Only text produced by the model is "output_text"; text supplied to
        # the model (user, system, developer) is always "input_text".
        text_type = "output_text" if role == "assistant" else "input_text"
        return {"type": text_type, "text": self.text}

    def __add__(self, other) -> 'TextContent':
        """Concatenate with a string or another TextContent.

        The two texts are joined with a single space.

        Args:
            other: String or TextContent to append.

        Returns:
            TextContent: New block with the concatenated text.
        """
        if isinstance(other, str):
            return TextContent(text=self.text + " " + other)
        if isinstance(other, TextContent):
            return TextContent(text=self.text + " " + other.text)
        return NotImplemented

    def __radd__(self, other) -> 'TextContent':
        """Right-side concatenation (when a string is on the left).

        Args:
            other: String to prepend, separated by a single space.

        Returns:
            TextContent: New block with the concatenated text.
        """
        if isinstance(other, str):
            return TextContent(text=other + " " + self.text)
        return NotImplemented
@dataclass
class ImageContent(ContentBase):
    """Image content block - supports URLs, base64, file paths, and numpy arrays.

    The image is held in one of three fields: ``image_url`` (a URL passed
    through as-is), ``image_data`` (base64 text, turned into a
    ``data:<media_type>;base64,...`` URL when serialized) or ``image_bytes``
    (raw bytes).

    Ways to create an ImageContent:
    1. Direct instantiation with image_url or image_data keyword arguments
    2. from_file/from_path: Load from local file path
    3. from_url: Create from HTTP/HTTPS URL
    4. from_array / from_pil / from_bytes / from_base64 / from_data_url
    5. build (or a single positional argument): auto-detect the input type
    """
    type: Literal["image"] = "image"
    image_url: Optional[str] = None
    image_data: Optional[str] = None  # base64 encoded
    image_bytes: Optional[bytes] = None
    media_type: str = "image/jpeg"  # image/jpeg, image/png, image/gif, image/webp
    detail: Optional[str] = None  # OpenAI: "auto", "low", "high"

    def __init__(self, value: Any = None, format: str = "PNG", **kwargs):
        """Initialize from explicit fields or from an auto-detected value.

        Args:
            value: Anything accepted by :meth:`autocast`. Ignored when
                explicit field values are passed as keyword arguments.
            format: Image format used when encoding array/PIL inputs.
                Default: PNG
            **kwargs: Direct field values (image_url, image_data,
                image_bytes, media_type, detail).
        """
        if kwargs:
            # Explicit field values win; fill in the two fixed defaults.
            kwargs.setdefault('type', 'image')
            kwargs.setdefault('media_type', 'image/jpeg')
            super().__init__(**kwargs)
        else:
            super().__init__(**self.autocast(value, format=format))

    def __repr__(self) -> str:
        # Truncate image_data and image_bytes for readability
        image_data_str = f"{self.image_data[:10]}..." if self.image_data and len(self.image_data) > 10 else self.image_data
        image_bytes_str = f"{str(self.image_bytes[:10])}..." if self.image_bytes and len(self.image_bytes) > 10 else self.image_bytes
        return f"ImageContent(image_url={self.image_url}, image_data={image_data_str}, image_bytes={image_bytes_str}, media_type={self.media_type})"

    def __str__(self) -> str:
        # Same truncated form as __repr__.
        return self.__repr__()

    def is_empty(self) -> bool:
        """Check if the image content is empty (no URL or data)."""
        return not self.image_url and not self.image_data and not self.image_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization (not LiteLLM format).

        Only the populated image fields are included. For LiteLLM format,
        use to_litellm_format() instead.
        """
        result = {
            "type": self.type,
            "media_type": self.media_type
        }
        if self.image_url:
            result["image_url"] = self.image_url
        if self.image_data:
            result["image_data"] = self.image_data
        if self.image_bytes:
            result["image_bytes"] = self.image_bytes
        if self.detail:
            result["detail"] = self.detail
        return result

    def to_litellm_format(self) -> Dict[str, Any]:
        """Convert to LiteLLM Response API compatible format.

        Returns:
            ``{"type": "input_image", "image_url": "<url or data URL>"}``,
            plus ``"detail"`` when set. An empty image yields an empty
            ``image_url`` string.
        """
        if self.image_url:
            url = self.image_url
        else:
            # image_data is used as-is; image_bytes is base64-encoded.
            b64_data = self.get_base64()
            if b64_data is None:
                return {"type": "input_image", "image_url": ""}
            url = f"data:{self.media_type};base64,{b64_data}"

        result = {
            "type": "input_image",
            "image_url": url
        }

        # Add detail if specified (OpenAI-specific)
        if self.detail:
            result["detail"] = self.detail

        return result

    # ------------------------------------------------------------------
    # Private helpers shared by the constructors and autocast
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_fields() -> Dict[str, Any]:
        """Return the field values of an empty image."""
        return {"image_url": None, "image_data": None, "image_bytes": None, "media_type": "image/jpeg"}

    @staticmethod
    def _data_fields(image_data: str, media_type: str) -> Dict[str, Any]:
        """Return the field values of an image held as base64 data."""
        return {"image_url": None, "image_data": image_data, "image_bytes": None, "media_type": media_type}

    @staticmethod
    def _media_type_for_path(path) -> str:
        """Guess the MIME type from a path's extension (default image/jpeg)."""
        ext_to_type = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        return ext_to_type.get(path.suffix.lower(), 'image/jpeg')

    @staticmethod
    def _split_data_url(data_url: str):
        """Split a data URL into (base64 payload, media type)."""
        try:
            header, b64_data = data_url.split(',', 1)
            media_type = header.split(':')[1].split(';')[0]  # e.g., "image/png"
            return b64_data, media_type
        except (ValueError, IndexError):
            # Fallback: treat everything after the last comma as base64 data
            return data_url.split(',')[-1], "image/jpeg"

    @staticmethod
    def _encode_pil(image: Any, format: str):
        """Encode a PIL image; return (base64 data, media type).

        The image's own format is kept when it has one; otherwise ``format``
        is used.
        """
        import io

        buffer = io.BytesIO()
        img_format = image.format or format.upper()
        image.save(buffer, format=img_format)
        image_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return image_data, f"image/{img_format.lower()}"

    @staticmethod
    def _encode_array(array: Any, format: str):
        """Encode an array-like image; return (base64 data, media type).

        Raises:
            ImportError: If Pillow is not installed.
        """
        import numpy as np
        try:
            from PIL import Image
        except ImportError as err:
            raise ImportError("Pillow is required for array conversion. Install with: pip install Pillow") from err

        if not isinstance(array, np.ndarray):
            array = np.array(array)

        # Bring the data to uint8. Float images whose maximum is <= 1.0 are
        # taken to be in [0, 1] and scaled; all other data is cast directly.
        if np.issubdtype(array.dtype, np.floating):
            if array.max() <= 1.0:
                array = (array * 255).astype(np.uint8)
            else:
                array = array.astype(np.uint8)
        elif array.dtype != np.uint8:
            array = array.astype(np.uint8)

        return ImageContent._encode_pil(Image.fromarray(array), format)

    # ------------------------------------------------------------------
    # Alternate constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_file(cls, filepath: str, media_type: Optional[str] = None):
        """Load image from file path.

        Args:
            filepath: Path of the image file to read.
            media_type: MIME type; guessed from the extension when omitted.
        """
        path = Path(filepath)
        if not media_type:
            media_type = cls._media_type_for_path(path)

        with open(filepath, 'rb') as f:
            image_data = base64.b64encode(f.read()).decode('utf-8')

        return cls(image_data=image_data, media_type=media_type)

    @classmethod
    def from_path(cls, filepath: str, media_type: Optional[str] = None):
        """Load image from file path. Alias for from_file."""
        return cls.from_file(filepath, media_type)

    @classmethod
    def from_url(cls, url: str, media_type: str = "image/jpeg"):
        """Create ImageContent from an HTTP/HTTPS URL.

        Args:
            url: HTTP or HTTPS URL pointing to an image
            media_type: MIME type of the image (default: image/jpeg)
        """
        return cls(image_url=url, media_type=media_type)

    @classmethod
    def from_array(cls, array: Any, format: str = "PNG"):
        """Create ImageContent from a numpy array or array-like RGB image.

        Args:
            array: Image array. Float arrays with a maximum <= 1.0 are scaled
                by 255; other data is cast to uint8.
            format: Image format (PNG, JPEG, etc.). Default: PNG

        Returns:
            ImageContent with base64-encoded image data

        Raises:
            ImportError: If numpy or Pillow is not installed.
        """
        try:
            import numpy as np  # noqa: F401 - availability check only
        except ImportError:
            raise ImportError("numpy is required for from_array. Install with: pip install numpy")

        try:
            from PIL import Image  # noqa: F401 - availability check only
        except ImportError:
            raise ImportError("Pillow is required for from_array. Install with: pip install Pillow")

        image_data, media_type = cls._encode_array(array, format)
        return cls(image_data=image_data, media_type=media_type)

    @classmethod
    def from_pil(cls, image: Any, format: str = "PNG"):
        """Create ImageContent from a PIL Image.

        Args:
            image: PIL Image object
            format: Format used when the image has none of its own.
                Default: PNG

        Returns:
            ImageContent with base64-encoded image data
        """
        image_data, media_type = cls._encode_pil(image, format)
        return cls(image_data=image_data, media_type=media_type)

    @classmethod
    def from_bytes(cls, data: bytes, media_type: str = "image/jpeg"):
        """Create ImageContent from raw image bytes.

        Args:
            data: Raw image bytes
            media_type: MIME type of the image (default: image/jpeg)

        Returns:
            ImageContent with base64-encoded data
        """
        image_data = base64.b64encode(data).decode('utf-8')
        return cls(image_data=image_data, media_type=media_type)

    @classmethod
    def from_base64(cls, b64_data: str, media_type: str = "image/jpeg"):
        """Create ImageContent from base64-encoded string.

        Args:
            b64_data: Base64-encoded image data (without data URL prefix)
            media_type: MIME type of the image (default: image/jpeg)

        Returns:
            ImageContent with the provided base64 data
        """
        return cls(image_data=b64_data, media_type=media_type)

    @classmethod
    def from_data_url(cls, data_url: str):
        """Create ImageContent from a data URL (data:image/...;base64,...).

        Args:
            data_url: Data URL string

        Returns:
            ImageContent with extracted base64 data and media type
        """
        b64_data, media_type = cls._split_data_url(data_url)
        return cls(image_data=b64_data, media_type=media_type)

    @staticmethod
    def autocast(value: Any, format: str = "PNG") -> Dict[str, Any]:
        """Auto-detect value type and return image field values.

        Args:
            value: Can be:
                - URL string (starting with 'http://' or 'https://')
                - Data URL string (starting with 'data:image/')
                - Path of an existing local file (string)
                - Numpy array or array-like RGB image
                - PIL Image object
                - Raw bytes
                - ImageContent (its fields are copied)
                - None (empty image)
            format: Image format used when encoding array/PIL inputs.
                Default: PNG

        Returns:
            Dictionary with keys: image_url, image_data, image_bytes,
            media_type. Unrecognized values yield the fields of an empty
            image.

        Raises:
            ImportError: If an array is given but Pillow is not installed.
        """
        if value is None:
            return ImageContent._empty_fields()

        if isinstance(value, ImageContent):
            return {
                "image_url": value.image_url,
                "image_data": value.image_data,
                "image_bytes": value.image_bytes,
                "media_type": value.media_type
            }

        if isinstance(value, str):
            if not value.strip():
                return ImageContent._empty_fields()

            if value.startswith('data:image/'):
                b64_data, media_type = ImageContent._split_data_url(value)
                return ImageContent._data_fields(b64_data, media_type)

            if value.startswith(('http://', 'https://')):
                return {"image_url": value, "image_data": None, "image_bytes": None, "media_type": "image/jpeg"}

            # File path - only probe strings of reasonable length (< 4096
            # chars); longer strings are clearly not file paths and would
            # cause OS errors.
            if len(value) < 4096:
                path = Path(value)
                try:
                    if path.is_file():
                        with open(value, 'rb') as f:
                            image_data = base64.b64encode(f.read()).decode('utf-8')
                        return ImageContent._data_fields(image_data, ImageContent._media_type_for_path(path))
                except OSError:
                    # Not a readable file; treat the string as "not an image".
                    pass

            return ImageContent._empty_fields()

        # Bytes are stored as base64 for portability
        if isinstance(value, bytes):
            image_data = base64.b64encode(value).decode('utf-8')
            return ImageContent._data_fields(image_data, "image/jpeg")

        # PIL Image (Pillow is optional for this check)
        try:
            from PIL import Image
        except ImportError:
            Image = None
        if Image is not None and isinstance(value, Image.Image):
            image_data, media_type = ImageContent._encode_pil(value, format)
            return ImageContent._data_fields(image_data, media_type)

        # Numpy array or array-like. Only the numpy import is guarded: a
        # missing Pillow must surface as an error instead of silently
        # producing an empty image.
        try:
            import numpy as np
        except ImportError:
            return ImageContent._empty_fields()
        if isinstance(value, np.ndarray) or hasattr(value, '__array__'):
            image_data, media_type = ImageContent._encode_array(value, format)
            return ImageContent._data_fields(image_data, media_type)

        return ImageContent._empty_fields()

    @classmethod
    def build(cls, value: Any, format: str = "PNG") -> 'ImageContent':
        """Auto-detect format and create ImageContent from various input types.

        Args:
            value: Anything accepted by :meth:`autocast`.
            format: Image format used when encoding array/PIL inputs.
                Default: PNG

        Returns:
            ImageContent: ``value`` itself when it already is an instance of
            this class; otherwise a new instance. A value that cannot be
            converted yields an *empty* ImageContent (check ``is_empty()``),
            never ``None``.
        """
        if isinstance(value, cls):
            return value

        value_dict = cls.autocast(value, format=format)
        return cls(**value_dict)

    def set_image(self, image: Any, format: str = "PNG") -> None:
        """Set the image from various input formats (mutates self).

        Args:
            image: Anything accepted by :meth:`autocast`.
            format: Image format used when encoding array/PIL inputs.
                Default: PNG
        """
        result = ImageContent.build(image, format=format)
        # NOTE(review): this relies on the truthiness ContentBase gives its
        # instances - confirm it reflects is_empty().
        if result:
            self.image_url = result.image_url
            self.image_data = result.image_data
            # Always replace image_bytes: get_bytes() prefers it over
            # image_data, so a leftover value would keep returning the
            # previous image.
            self.image_bytes = result.image_bytes
            self.media_type = result.media_type

    def as_image(self) -> Image.Image:
        """Convert the image to a PIL Image.

        Uses the stored bytes/base64 data when present; otherwise fetches
        ``image_url`` over HTTP(S) or opens it as a local path.

        Returns:
            PIL Image object

        Raises:
            ValueError: If no image data is available
            requests.RequestException: If fetching from URL fails
        """
        image_bytes = self.get_bytes()

        if image_bytes:
            return Image.open(io.BytesIO(image_bytes))
        elif self.image_url:
            if self.image_url.startswith(('http://', 'https://')):
                try:
                    import requests
                    response = requests.get(self.image_url, timeout=30)
                    response.raise_for_status()
                    return Image.open(io.BytesIO(response.content))
                except ImportError:
                    # Fallback to urllib if requests is not available
                    from urllib.request import urlopen
                    with urlopen(self.image_url, timeout=30) as response:
                        return Image.open(io.BytesIO(response.read()))
            else:
                # Anything else is treated as a local file path
                return Image.open(self.image_url)
        else:
            raise ValueError("No image data available to convert to PIL Image")

    def show(self) -> Image.Image:
        """A convenience alias for as_image()"""
        return self.as_image()

    def get_bytes(self) -> Optional[bytes]:
        """Get raw image bytes.

        Returns image_bytes if available, otherwise decodes image_data from
        base64.

        Returns:
            Raw image bytes or None if no image data available
        """
        if self.image_bytes:
            return self.image_bytes
        elif self.image_data:
            return base64.b64decode(self.image_data)
        return None

    def get_base64(self) -> Optional[str]:
        """Get base64-encoded image data.

        Returns image_data if available, otherwise encodes image_bytes to
        base64.

        Returns:
            Base64-encoded string or None if no image data available
        """
        if self.image_data:
            return self.image_data
        elif self.image_bytes:
            return base64.b64encode(self.image_bytes).decode('utf-8')
        return None

    def ensure_bytes(self) -> None:
        """Ensure image_bytes is populated (converts from image_data if needed)."""
        if not self.image_bytes and self.image_data:
            self.image_bytes = base64.b64decode(self.image_data)

    def ensure_base64(self) -> None:
        """Ensure image_data is populated (converts from image_bytes if needed)."""
        if not self.image_data and self.image_bytes:
            self.image_data = base64.b64encode(self.image_bytes).decode('utf-8')
class PromptTemplate:
    """Template for building ContentBlockLists with {placeholder} support.

    Similar to str.format(), but supports multimodal content (ContentBlockList).

    Return type depends on values:
    - All strings → returns str (backward compatible)
    - Any non-string value → returns ContentBlockList

    Features:
    - Multiple placeholders: {a}, {b}, {c}
    - Escaping: {{ and }} for literal braces
    - Missing placeholders: left as-is in text
    - Extra kwargs: silently ignored (no error)
    - Nested templates: if value is PromptTemplate, formats it first
    - Mixed values: str, ContentBlockList, or objects with to_content_blocks()
    - Substituted values are never re-scanned: a value containing "{x}" is
      inserted literally

    Examples:
        # Define template (can be class attribute)
        user_prompt_template = PromptTemplate('''
        Now you see problem instance:

        ================================
        {problem_instance}
        ================================
        ''')

        # Format with ContentBlockList (may contain images)
        content = user_prompt_template.format(
            problem_instance=problem.to_content_blocks()
        )

        # Nested templates
        outer = PromptTemplate("Header\\n{body}\\nFooter")
        inner = PromptTemplate("Content: {data}")
        result = outer.format(body=inner, data="some data")  # inner gets same kwargs

        # Escaping braces
        template = PromptTemplate('JSON example: {{"key": "{value}"}}')
        result = template.format(value="hello")  # JSON example: {"key": "hello"}

        # Missing placeholders left as-is
        template = PromptTemplate("Hello {name}, score: {score}")
        result = template.format(name="Alice")  # "Hello Alice, score: {score}"
    """

    # Regex matching a single {placeholder}; retained for _get_pattern().
    _PLACEHOLDER_PATTERN = None  # Lazy compiled
    # Regex matching "{{", "}}" or "{placeholder}" in one left-to-right scan.
    _TOKEN_PATTERN = None  # Lazy compiled

    def __init__(self, template: str):
        """Initialize with a template string.

        Args:
            template: Template string with {placeholder} syntax.
        """
        self.template = template

    @classmethod
    def _get_pattern(cls):
        """Lazily compile the regex that captures a {placeholder} name."""
        if cls._PLACEHOLDER_PATTERN is None:
            import re
            cls._PLACEHOLDER_PATTERN = re.compile(r'\{(\w+)\}')
        return cls._PLACEHOLDER_PATTERN

    @classmethod
    def _get_token_pattern(cls):
        """Lazily compile the tokenizer regex used by format().

        Escaped braces are listed first so that "{{" and "}}" are consumed
        before a placeholder can match; group 1 is set only for placeholders.
        """
        if cls._TOKEN_PATTERN is None:
            import re
            cls._TOKEN_PATTERN = re.compile(r'\{\{|\}\}|\{(\w+)\}')
        return cls._TOKEN_PATTERN

    def format(self, **kwargs) -> Union[str, 'ContentBlockList']:
        """Format the template with the given values.

        The template is scanned exactly once, so the result does not depend
        on the order of the keyword arguments and substituted text is never
        interpreted as a placeholder.

        Args:
            **kwargs: Placeholder values. Each value can be:
                - str: inserted as text
                - ContentBlockList: blocks spliced in at that position
                - PromptTemplate: formatted first (with the same kwargs),
                  then spliced in
                - Object with to_content_blocks(): method called, result spliced
                - Other: converted to str
                Keys that match no placeholder are ignored.

        Returns:
            str: If every value used by the template is a string.
            ContentBlockList: Otherwise.
        """
        template = self.template
        token_pattern = self._get_token_pattern()
        tokens = list(token_pattern.finditer(template))

        # Only values the template actually uses decide the return type.
        used_names = {tok.group(1) for tok in tokens if tok.group(1) is not None}
        used_values = [kwargs[name] for name in used_names if name in kwargs]

        if all(isinstance(value, str) for value in used_values):
            def _substitute(tok):
                name = tok.group(1)
                if name is None:
                    # "{{" -> "{" and "}}" -> "}"
                    return tok.group(0)[0]
                if name in kwargs:
                    return kwargs[name]
                # Missing placeholder: leave as-is
                return tok.group(0)

            return token_pattern.sub(_substitute, template)

        # Multimodal content: build a ContentBlockList piece by piece.
        blocks = ContentBlockList()
        literal = []  # text collected since the last appended block
        cursor = 0

        def _flush():
            text = "".join(literal)
            literal.clear()
            if text:
                blocks.append(text)

        for tok in tokens:
            literal.append(template[cursor:tok.start()])
            cursor = tok.end()
            name = tok.group(1)

            if name is None:
                # Escaped brace: part of the surrounding text
                literal.append(tok.group(0)[0])
                continue

            _flush()
            if name in kwargs:
                blocks.extend(self._convert_value(kwargs[name], kwargs))
            else:
                # Missing placeholder: keep the original {name} text
                blocks.append(tok.group(0))

        literal.append(template[cursor:])
        _flush()
        return blocks

    def _convert_value(self, value, values) -> 'ContentBlockList':
        """Convert a placeholder value to a ContentBlockList.

        Args:
            value: The value to convert.
            values: All placeholder values, as a dict, used to format nested
                templates. Passed as a plain dict so that no placeholder name
                can collide with a parameter name.

        Returns:
            ContentBlockList: The value as content blocks.
        """
        if isinstance(value, ContentBlockList):
            return value
        if isinstance(value, PromptTemplate):
            # A nested template returns a plain str when all of its own
            # values are strings; wrap it so callers always get blocks.
            rendered = value.format(**values)
            if isinstance(rendered, ContentBlockList):
                return rendered
            return ContentBlockList(rendered)
        if hasattr(value, 'to_content_blocks'):
            # Object with to_content_blocks method (e.g., ProblemInstance)
            return value.to_content_blocks()
        if isinstance(value, str):
            return ContentBlockList(value)
        # Fallback: convert to string
        return ContentBlockList(str(value))

    def _value_to_content(self, value, **kwargs) -> 'ContentBlockList':
        """Convert a value to ContentBlockList.

        Kept for backward compatibility; format() calls _convert_value()
        directly.

        Args:
            value: The value to convert
            **kwargs: Values used to format a nested PromptTemplate

        Returns:
            ContentBlockList: The value as content blocks.
        """
        return self._convert_value(value, kwargs)

    def __repr__(self) -> str:
        """Return a preview of the template."""
        preview = self.template[:50] + "..." if len(self.template) > 50 else self.template
        return f"PromptTemplate({preview!r})"
@dataclass
class UserTurn:
    """Represents a user message turn in the conversation"""
    role: str = "user"

    content: ContentBlockList = field(default_factory=ContentBlockList)

    # Provider-specific settings
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None

    # Metadata
    timestamp: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __init__(self, content=None, tools=None, **kwargs):
        """
        Initialize UserTurn with content and optional settings.

        Ways to initialize:
        1. Empty: UserTurn() - creates empty turn with defaults
        2. Copy: UserTurn(existing_turn) - copies an existing UserTurn,
           including its image-generation flag
        3. Content: UserTurn(content) - a ContentBlockList, a list of blocks,
           or a single item
        4. Keyword args: UserTurn(content=..., temperature=...) - explicit fields

        Args:
            content: ContentBlockList, list of content blocks, a single
                content item, UserTurn (for copying), or None
            tools: Accepted but not used by this constructor.
            **kwargs: Additional fields (role, temperature, max_tokens, top_p,
                timestamp, metadata). Other keys are ignored.
        """
        self.output_contains_image = False

        # Handle copy constructor: UserTurn(existing_turn)
        if isinstance(content, UserTurn):
            source = content
            self.role = source.role
            self.content = ContentBlockList(source.content)  # New list object
            self.temperature = source.temperature
            self.max_tokens = source.max_tokens
            self.top_p = source.top_p
            self.timestamp = source.timestamp
            self.metadata = dict(source.metadata)  # Copy the metadata dict
            # A copy must keep the source's image-generation setting.
            self.output_contains_image = getattr(source, 'output_contains_image', False)
            return

        # Normalize content to a ContentBlockList
        if content is None:
            content = ContentBlockList()
        elif not isinstance(content, ContentBlockList):
            # Lists are wrapped directly; a single item becomes a one-element list
            content = ContentBlockList(content) if isinstance(content, list) else ContentBlockList([content])

        # Set all fields
        self.role = kwargs.get('role', "user")
        self.content = content
        self.temperature = kwargs.get('temperature', None)
        self.max_tokens = kwargs.get('max_tokens', None)
        self.top_p = kwargs.get('top_p', None)
        self.timestamp = kwargs.get('timestamp', None)
        self.metadata = kwargs.get('metadata', {})

    def add_text(self, text: str) -> 'UserTurn':
        """Add text content; returns self for chaining."""
        self.content.append(TextContent(text=text))
        return self

    def add_image(self, url: Optional[str] = None, data: Optional[str] = None,
                  media_type: str = "image/jpeg") -> 'UserTurn':
        """Add image content from a URL or base64 data; returns self."""
        self.content.append(ImageContent(
            image_url=url,
            image_data=data,
            media_type=media_type
        ))
        return self

    def add_image_file(self, filepath: str) -> 'UserTurn':
        """Add image from file; returns self for chaining."""
        self.content.append(ImageContent.from_file(filepath))
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format"""
        return {
            "role": "user",
            "content": [c.to_dict() for c in self.content],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "metadata": self.metadata
        }

    def enable_image_generation(self):
        """Mark this turn as expecting image output."""
        self.output_contains_image = True

    def __repr__(self) -> str:
        """Safe string representation that handles missing attributes."""
        # Stringify the content once; it may be large.
        content_text = str(self.content)
        content_preview = content_text[:50] + "..." if len(content_text) > 50 else content_text
        parts = [f"UserTurn(content={content_preview!r}"]

        # Safely add optional fields if they exist
        temperature = getattr(self, 'temperature', None)
        if temperature is not None:
            parts.append(f", temperature={temperature}")

        parts.append(")")
        return "".join(parts)

    def to_litellm_format(self) -> Dict[str, Any]:
        """Convert to LiteLLM Response API format (OpenAI Response API compatible)"""
        return {
            "role": "user",
            "content": self.content.to_litellm_format(role="user")
        }

    def _repr_html_(self) -> Optional[str]:
        """Rich HTML representation for Jupyter notebooks.

        Returns:
            The rendered HTML, or None when the display module cannot be
            imported.
        """
        try:
            from opto.utils.display.jupyter import render_user_turn
            return render_user_turn(self)
        except ImportError:
            # No HTML available when the display module is missing
            return None
[c.to_dict() for c in self.content], + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "metadata": self.metadata + } + + def enable_image_generation(self): + self.output_contains_image = True + + def __repr__(self) -> str: + """Safe string representation that handles missing attributes.""" + content_preview = str(self.content)[:50] + "..." if len(str(self.content)) > 50 else str(self.content) + parts = [f"UserTurn(content={content_preview!r}"] + + # Safely add optional fields if they exist + temperature = getattr(self, 'temperature', None) + if temperature is not None: + parts.append(f", temperature={temperature}") + + parts.append(")") + return "".join(parts) + + def to_litellm_format(self) -> Dict[str, Any]: + """Convert to LiteLLM Response API format (OpenAI Response API compatible)""" + return { + "role": "user", + "content": self.content.to_litellm_format(role="user") + } + + def _repr_html_(self) -> str: + """Rich HTML representation for Jupyter notebooks with glassmorphism design.""" + try: + from opto.utils.display.jupyter import render_user_turn + return render_user_turn(self) + except ImportError: + # Fallback to text representation if display module unavailable + return None + + +@dataclass +class Turn: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +@dataclass +class AssistantTurn(Turn): + """Represents an assistant message turn in the conversation""" + role: str = "assistant" + content: ContentBlockList = field(default_factory=ContentBlockList) + + # Provider-specific features + reasoning: Optional[str] = None # OpenAI reasoning/thinking + finish_reason: Optional[str] = None # "stop", "length", "tool_calls", etc. 
+ + # Token usage + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + + # Metadata + model: Optional[str] = None + timestamp: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __init__(self, *args, **kwargs): + """ + Initialize AssistantTurn from a raw response or with explicit fields. + + Three ways to initialize: + 1. Empty: AssistantTurn() - creates empty turn with defaults + 2. From raw response: AssistantTurn(response) - autocasts the response + 3. With fields: AssistantTurn(role="assistant", content=[...]) - explicit fields + """ + if len(args) == 1 and isinstance(args[0], AssistantTurn): + # Case: Copy constructor - create a copy of another AssistantTurn + other = args[0] + super().__init__( + role=other.role, + content=ContentBlockList(other.content), + reasoning=other.reasoning, + finish_reason=other.finish_reason, + prompt_tokens=other.prompt_tokens, + completion_tokens=other.completion_tokens, + model=other.model, + timestamp=other.timestamp, + metadata=dict(other.metadata) + ) + return + + if len(args) > 0 and len(kwargs) == 0: + # Case 2: Single positional arg - autocast from raw response + value_dict = self.autocast(args[0]) + super().__init__(**value_dict) + elif len(kwargs) > 0: + # Case 3: Keyword arguments - use them directly + super().__init__(**kwargs) + else: + # Case 1: No arguments - initialize with defaults + super().__init__( + role="assistant", + content=ContentBlockList(), + reasoning=None, + finish_reason=None, + prompt_tokens=None, + completion_tokens=None, + model=None, + timestamp=None, + metadata={} + ) + + @staticmethod + def from_google_genai(value: Any) -> Dict[str, Any]: + """Parse a Google GenAI response into a dictionary of AssistantTurn fields. + + Supports both the legacy generate_content API and the new Interactions API. 
+ + Args: + value: Raw response from Google GenAI API + + Returns: + Dict[str, Any]: Dictionary with keys corresponding to AssistantTurn fields + """ + # Initialize the result dictionary with default values + result = { + "role": "assistant", + "content": ContentBlockList(), + "reasoning": None, + "finish_reason": None, + "prompt_tokens": None, + "completion_tokens": None, + "model": None, + "timestamp": None, + "metadata": {} + } + + # Check if this is a normalized response (from our GoogleGenAILLM) + if hasattr(value, 'raw_response'): + raw_response = value.raw_response + else: + raw_response = value + + # Handle Interactions API format (new) + if hasattr(raw_response, 'outputs'): + # This is an Interaction object + interaction = raw_response + + # Extract text from outputs + if interaction.outputs and len(interaction.outputs) > 0: + for output in interaction.outputs: + if hasattr(output, 'text') and output.text: + result["content"].append(TextContent(text=output.text)) + # Handle other output types if they exist + elif hasattr(output, 'content'): + # Content could be a list of parts + if isinstance(output.content, list): + for part in output.content: + if hasattr(part, 'text') and part.text: + result["content"].append(TextContent(text=part.text)) + else: + result["content"].append(TextContent(text=str(output.content))) + + # Extract model info + if hasattr(interaction, 'model'): + result["model"] = interaction.model + + # Extract status as finish_reason + if hasattr(interaction, 'status'): + result["finish_reason"] = interaction.status + + # Extract token usage from Interactions API + if hasattr(interaction, 'usage'): + usage = interaction.usage + if hasattr(usage, 'input_tokens'): + result["prompt_tokens"] = usage.input_tokens + elif hasattr(usage, 'prompt_token_count'): + result["prompt_tokens"] = usage.prompt_token_count + + if hasattr(usage, 'output_tokens'): + result["completion_tokens"] = usage.output_tokens + elif hasattr(usage, 'candidates_token_count'): 
+ result["completion_tokens"] = usage.candidates_token_count + + # Extract interaction ID as metadata + if hasattr(interaction, 'id'): + result["metadata"]['interaction_id'] = interaction.id + + # Handle legacy generate_content API format + else: + # Extract thinking/reasoning (for Gemini 2.5+ models) + if hasattr(raw_response, 'thoughts') and raw_response.thoughts: + # Gemini's thinking budget feature + result["reasoning"] = str(raw_response.thoughts) + + # Extract model info + if hasattr(raw_response, 'model_version'): + result["model"] = raw_response.model_version + + # Extract token usage (if available) + if hasattr(raw_response, 'usage_metadata'): + usage = raw_response.usage_metadata + if hasattr(usage, 'prompt_token_count'): + result["prompt_tokens"] = usage.prompt_token_count + if hasattr(usage, 'candidates_token_count'): + result["completion_tokens"] = usage.candidates_token_count + + # Handle multimodal content from Gemini (candidates with parts) + content_extracted = False + if hasattr(raw_response, 'candidates') and raw_response.candidates: + candidate = raw_response.candidates[0] + + # Extract from parts (supports multimodal responses with text and images) + if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'): + for part in candidate.content.parts: + # Handle text parts + if hasattr(part, 'text') and part.text: + result["content"].append(TextContent(text=part.text)) + content_extracted = True + # Handle inline data (images, generated images, etc.) 
+ elif hasattr(part, 'inline_data'): + # Try to extract image data, preferring direct inline_data access + inline = part.inline_data + image_bytes = None + image_data = None + media_type = 'image/jpeg' + + + # Extract from inline_data Blob (most reliable method) + # Google's Blob.data should be raw bytes + if hasattr(inline, 'data'): + data = inline.data + # Check if it's bytes or string + if isinstance(data, bytes): + # Store raw bytes for Gemini compatibility + # (Gemini prefers raw bytes when sending images) + image_bytes = data + elif isinstance(data, str): + # Already base64-encoded string + image_data = data + # Don't decode to bytes - keep as base64 for portability + + if hasattr(inline, 'mime_type'): + media_type = inline.mime_type + + # If we got the data, create ImageContent + # Store image_bytes only if we got raw bytes from Google + if image_data or image_bytes: + result["content"].append(ImageContent( + image_data=image_data, + image_bytes=image_bytes if isinstance(data, bytes) else None, + media_type=media_type + )) + content_extracted = True + + # Extract finish reason + if hasattr(candidate, 'finish_reason'): + result["finish_reason"] = str(candidate.finish_reason) + + # Fallback: Extract simple text content if no candidates/parts were found + if not content_extracted: + if hasattr(raw_response, 'text'): + result["content"].append(TextContent(text=raw_response.text)) + elif hasattr(value, 'choices'): + # Fallback to normalized format + result["content"].append(TextContent(text=value.choices[0].message.content)) + + return result + + @staticmethod + def from_litellm_openai_response_api(value: Any) -> Dict[str, Any]: + """Parse a LiteLLM/OpenAI-style response into a dictionary of AssistantTurn fields. 
+ + Handles both formats: + - New Responses API: Has 'output' field with ResponseOutputMessage objects + - Legacy Completion API: Has 'choices' field with message objects + + Args: + value: Response from LiteLLM/OpenAI API (Responses API or Completion API) + + Returns: + Dict[str, Any]: Dictionary with keys corresponding to AssistantTurn fields + """ + # Initialize the result dictionary with default values + result = { + "role": "assistant", + "content": ContentBlockList(), + "reasoning": None, + "finish_reason": None, + "prompt_tokens": None, + "completion_tokens": None, + "model": None, + "timestamp": None, + "metadata": {} + } + + # Handle Bedrock Converse API format (has 'output' field with 'message') + # Check both attribute-based and dict-based access for robustness + is_bedrock = False + bedrock_output = None + bedrock_value = value # Keep reference to the original value for later access + + # Try attribute-based access first + if hasattr(value, 'output'): + output_val = value.output + + if hasattr(output_val, 'message'): + is_bedrock = True + bedrock_output = output_val + # Also check dict-based access on the output attribute + elif isinstance(output_val, dict) and 'message' in output_val: + is_bedrock = True + bedrock_output = output_val + + # If not found, try dict-based access on value itself + if not is_bedrock and isinstance(value, dict) and 'output' in value: + output_val = value['output'] + if isinstance(output_val, dict) and 'message' in output_val: + is_bedrock = True + bedrock_output = output_val + bedrock_value = value # Use the dict directly + + if is_bedrock and bedrock_output is not None: + # Bedrock Converse API format detected + # Get message with dict or attr access + message = bedrock_output.get('message') if isinstance(bedrock_output, dict) else (bedrock_output.message if hasattr(bedrock_output, 'message') else None) + + if message: + # Extract role + if isinstance(message, dict): + result["role"] = message.get('role', 'assistant') + elif 
hasattr(message, 'role'): + result["role"] = message.role + + # Extract content + content_list = message.get('content') if isinstance(message, dict) else (message.content if hasattr(message, 'content') else None) + + if content_list: + for content_item in content_list: + # Handle text content (dict or attr) + text_val = None + if isinstance(content_item, dict): + text_val = content_item.get('text') + elif hasattr(content_item, 'text'): + text_val = content_item.text + + if text_val: + result["content"].append(TextContent(text=text_val)) + + # Extract finish reason from stopReason (check both value and bedrock_value) + stop_reason = None + if isinstance(bedrock_value, dict): + stop_reason = bedrock_value.get('stopReason') + elif hasattr(bedrock_value, 'stopReason'): + stop_reason = bedrock_value.stopReason + if stop_reason: + result["finish_reason"] = stop_reason + + # Extract token usage (check both value and bedrock_value) + usage = None + if isinstance(bedrock_value, dict): + usage = bedrock_value.get('usage') + elif hasattr(bedrock_value, 'usage'): + usage = bedrock_value.usage + + if usage: + if isinstance(usage, dict): + result["prompt_tokens"] = usage.get('inputTokens') + result["completion_tokens"] = usage.get('outputTokens') + else: + if hasattr(usage, 'inputTokens'): + result["prompt_tokens"] = usage.inputTokens + if hasattr(usage, 'outputTokens'): + result["completion_tokens"] = usage.outputTokens + + # Handle Responses API format (new format with 'output' field) + # The output field is a list of output items (messages, image generation calls, etc.) 
+ # NOTE: LiteLLM may set value.object to 'chat.completion' or 'response' depending on the provider + elif hasattr(value, 'output') and hasattr(value, 'object'): + # Extract metadata + if hasattr(value, 'id'): + result["metadata"]['response_id'] = value.id + if hasattr(value, 'created_at'): + result["timestamp"] = str(value.created_at) + + # Extract model info + if hasattr(value, 'model'): + result["model"] = value.model + + # Extract status as finish_reason + if hasattr(value, 'status'): + result["finish_reason"] = value.status + + # Extract content from output (list of output items) + if value.output and len(value.output) > 0: + for output_item in value.output: + # Handle ImageGenerationCall + if hasattr(output_item, 'type') and output_item.type == 'image_generation_call': + # Extract generated image + if hasattr(output_item, 'result') and output_item.result: + # Determine media type from output_format + media_type = 'image/jpeg' # default + if hasattr(output_item, 'output_format'): + format_map = { + 'png': 'image/png', + 'jpeg': 'image/jpeg', + 'jpg': 'image/jpeg', + 'webp': 'image/webp', + 'gif': 'image/gif' + } + media_type = format_map.get(output_item.output_format.lower(), 'image/jpeg') + + # Add image to content + result["content"].append(ImageContent( + image_data=output_item.result, + media_type=media_type + )) + + # Store additional metadata about the image generation + if hasattr(output_item, 'revised_prompt') and output_item.revised_prompt: + if 'image_generation' not in result["metadata"]: + result["metadata"]['image_generation'] = [] + result["metadata"]['image_generation'].append({ + 'id': output_item.id if hasattr(output_item, 'id') else None, + 'revised_prompt': output_item.revised_prompt, + 'size': output_item.size if hasattr(output_item, 'size') else None, + 'quality': output_item.quality if hasattr(output_item, 'quality') else None, + 'status': output_item.status if hasattr(output_item, 'status') else None + }) + + # Handle 
ResponseOutputMessage + elif hasattr(output_item, 'type') and output_item.type == 'message': + # Extract role + if hasattr(output_item, 'role'): + result["role"] = output_item.role + + # Extract status for this message + if hasattr(output_item, 'status') and not result["finish_reason"]: + result["finish_reason"] = output_item.status + + # Extract content items + if hasattr(output_item, 'content') and output_item.content: + for content_item in output_item.content: + # Handle text content + if hasattr(content_item, 'type') and content_item.type == 'output_text': + if hasattr(content_item, 'text') and content_item.text: + result["content"].append(TextContent(text=content_item.text)) + # Handle other content types as they become available + elif hasattr(content_item, 'text') and content_item.text: + result["content"].append(TextContent(text=str(content_item.text))) + + # Extract reasoning (for models with reasoning capabilities) + if hasattr(value, 'reasoning'): + reasoning_parts = [] + if isinstance(value.reasoning, dict): + if value.reasoning.get('summary'): + reasoning_parts.append(f"Summary: {value.reasoning['summary']}") + if value.reasoning.get('effort'): + reasoning_parts.append(f"Effort: {value.reasoning['effort']}") + if reasoning_parts: + result["reasoning"] = "\n".join(reasoning_parts) + elif value.reasoning: + result["reasoning"] = str(value.reasoning) + + # Extract token usage (Responses API format) + if hasattr(value, 'usage'): + if hasattr(value.usage, 'input_tokens'): + result["prompt_tokens"] = value.usage.input_tokens + if hasattr(value.usage, 'output_tokens'): + result["completion_tokens"] = value.usage.output_tokens + + # Handle legacy Completion API format (has 'choices' field) + elif hasattr(value, 'choices') and len(value.choices) > 0: + choice = value.choices[0] + message = choice.message if hasattr(choice, 'message') else choice + + # Extract text content + if hasattr(message, 'content') and message.content: + 
result["content"].append(TextContent(text=str(message.content))) + + + # Extract finish reason + if hasattr(choice, 'finish_reason'): + result["finish_reason"] = choice.finish_reason + + # Extract reasoning/thinking (for OpenAI o1/o3 models) + if hasattr(message, 'reasoning') and message.reasoning: + result["reasoning"] = message.reasoning + + # Extract token usage (Completion API format) + if hasattr(value, 'usage'): + if hasattr(value.usage, 'prompt_tokens'): + result["prompt_tokens"] = value.usage.prompt_tokens + if hasattr(value.usage, 'completion_tokens'): + result["completion_tokens"] = value.usage.completion_tokens + + # Extract model info + if hasattr(value, 'model'): + result["model"] = value.model + + return result + + @staticmethod + def autocast(value: Any) -> Dict[str, Any]: + """Automatically parse a response from any API into a dictionary of AssistantTurn fields. + + Automatically detects the response format and uses the appropriate parser: + - Google GenAI (generate_content or Interactions API) + - LiteLLM/OpenAI Responses API (new format with 'output' field) + - LiteLLM/OpenAI Completion API (legacy format with 'choices' field) + + Args: + value: Raw response from any supported API + + Returns: + Dict[str, Any]: Dictionary with keys corresponding to AssistantTurn fields + """ + + # Check if this is a normalized response (from our GoogleGenAILLM) + raw_response = value.raw_response if hasattr(value, 'raw_response') else value + + # Detect Google GenAI format (Interactions API or generate_content) + # Google GenAI has 'outputs' (Google Interactions API) or 'candidates' (generate_content) + # Note: 'outputs' is for Google's Interactions API, 'output' is for LiteLLM Responses API + if hasattr(raw_response, 'outputs') or \ + (hasattr(raw_response, 'candidates') and not hasattr(value, 'choices')) or \ + hasattr(raw_response, 'usage_metadata'): + return AssistantTurn.from_google_genai(value) + + # Detect LiteLLM/OpenAI/Bedrock format (Responses API, 
Completion API, or Bedrock Converse) + # Responses API has 'output' field and object='response' + # Completion API has 'choices' field + # Bedrock Converse API has 'output' field with nested 'message' + # Check both attribute and dict-based access + has_output = hasattr(value, 'output') or (isinstance(value, dict) and 'output' in value) + has_choices = hasattr(value, 'choices') or (isinstance(value, dict) and 'choices' in value) + + if has_output or has_choices: + return AssistantTurn.from_litellm_openai_response_api(value) + + # Fallback: if has 'text' attribute, might be a simple Google response + elif hasattr(raw_response, 'text'): + return AssistantTurn.from_google_genai(value) + + # Default to empty result if format is not recognized + else: + return { + "role": "assistant", + "content": ContentBlockList(), + "tool_calls": [], + "unparsed_tool_calls": [], + "tool_results": [], + "reasoning": None, + "finish_reason": None, + "prompt_tokens": None, + "completion_tokens": None, + "model": None, + "timestamp": None, + "metadata": {} + } + + def add_text(self, text: str) -> 'AssistantTurn': + """Add text content""" + self.content.append(text) + return self + + def add_image(self, url: Optional[str] = None, data: Optional[str] = None, + media_type: str = "image/jpeg") -> 'AssistantTurn': + """Add image content (some models can generate images)""" + self.content.append(ImageContent( + image_url=url, + image_data=data, + media_type=media_type + )) + return self + + def to_text(self) -> str: + """Get all text content concatenated. 
Images will be presented as placeholder text.""" + return self.content.to_text() + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format""" + return { + "role": self.role, + "content": [c.to_dict() for c in self.content], + "reasoning": self.reasoning, + "finish_reason": self.finish_reason, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "model": self.model, + "metadata": self.metadata + } + + def get_text(self) -> ContentBlockList: + """Get all text content blocks. + + Returns: + ContentBlockList: List containing only TextContent blocks + """ + text_blocks = ContentBlockList() + for block in self.content: + if isinstance(block, TextContent): + text_blocks.append(block) + return text_blocks + + def get_images(self) -> ContentBlockList: + """Get all image content blocks. + + Returns: + ContentBlockList: List containing only ImageContent blocks + """ + image_blocks = ContentBlockList() + for block in self.content: + if isinstance(block, ImageContent): + image_blocks.append(block) + return image_blocks + + def __repr__(self) -> str: + """Safe string representation that handles missing attributes.""" + content_preview = str(self.content)[:50] + "..." if len(str(self.content)) > 50 else str(self.content) + parts = [f"AssistantTurn(content={content_preview!r}"] + + # Safely add optional fields if they exist + if hasattr(self, 'model') and self.model: + parts.append(f", model={self.model!r}") + if hasattr(self, 'prompt_tokens') and self.prompt_tokens: + parts.append(f", prompt_tokens={self.prompt_tokens}") + if hasattr(self, 'completion_tokens') and self.completion_tokens: + parts.append(f", completion_tokens={self.completion_tokens}") + + parts.append(")") + return "".join(parts) + + def to_litellm_format(self) -> Dict[str, Any]: + """Convert to LiteLLM Response API format (OpenAI Response API compatible)""" + result = {"role": self.role} + + # Handle content blocks (text, images, etc.) 
- delegate to ContentBlockList + result["content"] = self.content.to_litellm_format(role=self.role) + + return result + + + def _repr_html_(self) -> str: + """Rich HTML representation for Jupyter notebooks with glassmorphism design.""" + try: + from opto.utils.display.jupyter import render_assistant_turn + return render_assistant_turn(self) + except ImportError: + # Fallback to text representation if display module unavailable + return None diff --git a/opto/utils/display/README.md b/opto/utils/display/README.md new file mode 100644 index 00000000..5269568a --- /dev/null +++ b/opto/utils/display/README.md @@ -0,0 +1,129 @@ +# Opto Display Module + +Rendering utilities for visualizing Opto objects in various formats. + +## Architecture + +The display module separates **visualization logic** from **data handling logic**: + +``` +opto/utils/ +├── backbone.py # Core data classes (UserTurn, AssistantTurn, Chat) +└── display/ + ├── __init__.py # Public API + ├── jupyter.py # Jupyter notebook HTML rendering + ├── themes.py # Color schemes and styling + └── README.md # This file +``` + +## Usage + +### Automatic Display in Jupyter + +Objects automatically render with glassmorphism styling: + +```python +from opto.utils.backbone import UserTurn, AssistantTurn, Chat + +user_turn = UserTurn("Hello!") +user_turn # Beautiful display! ✨ +``` + +### Direct Rendering + +Use the display module directly for more control: + +```python +from opto.utils.display import render_user_turn, render_chat +from IPython.display import HTML + +user_turn = UserTurn("Hello!") +html_output = render_user_turn(user_turn) +HTML(html_output) +``` + +## Features + +- Glassmorphism Design +- Multi-modal content (images, files interleaved with text), all rendered as HTML.
+ +### Custom Themes + +You can customize the color scheme: + +```python +from opto.utils.display.themes import set_theme + +custom_theme = { + 'user': { + 'background': 'rgba(255, 240, 245, 0.85)', + 'border': 'rgba(233, 30, 99, 0.3)', + 'text_color': '#C2185B', + 'icon': '👤', + }, + 'assistant': { + 'background': 'rgba(232, 245, 233, 0.85)', + 'border': 'rgba(76, 175, 80, 0.3)', + 'text_color': '#388E3C', + 'icon': '🤖', + }, + # ... other theme properties +} + +set_theme(custom_theme) +``` + +### Custom Renderers + +You can create your own rendering functions: + +```python +def my_custom_renderer(user_turn): + """Custom HTML renderer for UserTurn""" + return f"
{user_turn.content.to_text()}
" + +# Use it directly +from IPython.display import HTML +HTML(my_custom_renderer(user_turn)) +``` + +### Fallback Behavior + +If the display module is unavailable (e.g., import error), classes fall back to their `__repr__()` text representation: + +```python +class UserTurn: + def _repr_html_(self): + try: + from opto.utils.display.jupyter import render_user_turn + return render_user_turn(self) + except ImportError: + return None # Falls back to __repr__ +``` + +## Future Extension + +Potential additions to the display module: + +- **Terminal renderer**: ANSI color codes for CLI display +- **Markdown renderer**: Export conversations as markdown +- **LaTeX renderer**: For academic papers +- **HTML export**: Static HTML page generation + +To add a new rendering format: + +1. Create `opto/utils/display/my_format.py` +2. Implement `render_*` functions for each class +3. Export in `__init__.py` +4. Update documentation + +Example: + +```python +# opto/utils/display/terminal.py +def render_user_turn(turn): + """Render UserTurn with ANSI colors for terminal""" + from colorama import Fore, Style + return f"{Fore.CYAN}šŸ‘¤ User:{Style.RESET_ALL} {turn.content.to_text()}" +``` + diff --git a/opto/utils/display/__init__.py b/opto/utils/display/__init__.py new file mode 100644 index 00000000..ce4ef555 --- /dev/null +++ b/opto/utils/display/__init__.py @@ -0,0 +1,27 @@ +""" +Display utilities for rendering Opto objects in various formats. 
+ +This module provides rendering functions for different output formats: +- Jupyter notebooks (HTML) +- Terminal/CLI +- Markdown export + +Usage: + from opto.utils.display import render_for_jupyter + html = render_for_jupyter(user_turn) +""" + +from .jupyter import ( + render_user_turn, + render_assistant_turn, + render_chat, + render_content_block_list, +) + +__all__ = [ + 'render_user_turn', + 'render_assistant_turn', + 'render_chat', + 'render_content_block_list', +] + diff --git a/opto/utils/display/jupyter.py b/opto/utils/display/jupyter.py new file mode 100644 index 00000000..393e7fa1 --- /dev/null +++ b/opto/utils/display/jupyter.py @@ -0,0 +1,543 @@ +""" +Jupyter notebook HTML rendering for Opto objects. + +This module handles all HTML generation for displaying Opto objects +in Jupyter notebooks with glassmorphism styling. +""" + +import html as html_module +import uuid +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from opto.utils.backbone import UserTurn, AssistantTurn, Chat, ContentBlockList + +from .themes import get_theme + + +def _escape(text: str) -> str: + """HTML escape helper.""" + return html_module.escape(str(text)) + + +def _escape_with_linebreaks(text: str) -> str: + """HTML escape with proper newline and tab handling.""" + # First escape HTML + escaped = html_module.escape(str(text)) + # Convert newlines to
tags + escaped = escaped.replace('\n', '
') + # Convert tabs to 4 spaces (visible) + escaped = escaped.replace('\t', '    ') + return escaped + + +def _render_image_block(block) -> str: + """Render an image content block.""" + parts = [] + parts.append('
') + + # Get image source - prioritize URL, then base64 data + img_src = None + if hasattr(block, 'image_url') and block.image_url: + img_src = block.image_url + elif hasattr(block, 'image_data') and block.image_data: + # Base64 encoded image + media_type = getattr(block, 'media_type', 'image/jpeg') + img_src = f"data:{media_type};base64,{block.image_data}" + + theme = get_theme() + max_height = theme['common']['image_max_height'] + + if img_src: + parts.append(f'') + else: + # Fallback to placeholder if no image source + media_type = getattr(block, 'media_type', 'image/jpeg') + parts.append('
šŸ–¼ļø') + parts.append(f'Image ({media_type})
') + + parts.append('
') + return ''.join(parts) + + +def _render_content_blocks(blocks, block_id: str) -> tuple: + """ + Render content blocks inline and generate expandable detail view. + + Returns: + tuple: (inline_html, detail_html, num_blocks, num_images) + """ + from opto.utils.backbone import TextContent, ImageContent + + inline_parts = [] + detail_parts = [] + num_images = 0 + + for block in blocks: + if isinstance(block, TextContent): + inline_parts.append(_escape_with_linebreaks(block.text)) + block_text = str(block) + elif isinstance(block, ImageContent): + inline_parts.append(_render_image_block(block)) + num_images += 1 + block_text = str(block) # Show full repr in detail view + else: + # Unknown block type + block_text = str(block) + inline_parts.append(_escape(block_text)) + + # Add to detail view + block_type = type(block).__name__ + detail_parts.append(f''' +
+ + {block_type}: + + + {_escape(block_text)} + +
+ ''') + + return ''.join(inline_parts), ''.join(detail_parts), len(blocks), num_images + + +def render_user_turn(turn: 'UserTurn') -> str: + """Render a UserTurn as Jupyter HTML.""" + theme = get_theme() + user_theme = theme['user'] + common = theme['common'] + + # Generate unique ID + turn_id = str(uuid.uuid4())[:8] + + parts = [] + + # Main turn container + parts.append(f''' +
+ ''') + + # Role header + parts.append(f''' +
+ {user_theme['icon']} + User +
+ ''') + + # Content + parts.append('
') + + # Render content blocks + inline_html, detail_html, num_blocks, num_images = _render_content_blocks(turn.content, turn_id) + parts.append(inline_html) + parts.append('
') + + # Metadata badges + parts.append('
') + + # Content blocks badge + block_label = f"{num_blocks} content block{'s' if num_blocks != 1 else ''}" + if num_images > 0: + block_label += f" ({num_images} image{'s' if num_images != 1 else ''})" + + parts.append(f''' + + {block_label} ā–¼ + + ''') + + # Tools badge + tools = getattr(turn, 'tools', []) + if tools: + parts.append(f''' + + šŸ”§ {len(tools)} tool{'s' if len(tools) != 1 else ''} available + + ''') + + # Settings badges + temperature = getattr(turn, 'temperature', None) + if temperature is not None: + parts.append(f''' + + temperature: {temperature} + + ''') + + parts.append('
') + + # Content blocks detail (expandable) + parts.append(f''' + ') + + # Close main container + parts.append('
') + + return ''.join(parts) + + +def render_assistant_turn(turn: 'AssistantTurn') -> str: + """Render an AssistantTurn as Jupyter HTML.""" + theme = get_theme() + assistant_theme = theme['assistant'] + common = theme['common'] + + # Generate unique ID + turn_id = str(uuid.uuid4())[:8] + + parts = [] + + # Main turn container + parts.append(f''' +
+ ''') + + # Role header with token badge + parts.append(f''' +
+
+ {assistant_theme['icon']} + Assistant +
+ ''') + + # Token count badge + prompt_tokens = getattr(turn, 'prompt_tokens', None) + completion_tokens = getattr(turn, 'completion_tokens', None) + if prompt_tokens or completion_tokens: + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + parts.append(f''' + + šŸ’° {total_tokens} tokens + + ''') + + parts.append('
') + + # Reasoning section (if present) + reasoning = getattr(turn, 'reasoning', None) + if reasoning: + reasoning_theme = theme['reasoning'] + parts.append(f''' +
+
+ {reasoning_theme['icon']} Reasoning: +
+ {_escape(reasoning)} +
+ ''') + + # Content + parts.append('
') + + # Render content blocks + inline_html, _, _, _ = _render_content_blocks(turn.content, turn_id) + parts.append(inline_html) + parts.append('
') + + # Tool calls section (if present) + if hasattr(turn, 'tool_calls') and turn.tool_calls: + import json + tool_theme = theme['tool_calls'] + + parts.append(f''' +
+
+ {tool_theme['icon']} Tool Calls: +
+ ''') + + for tc in turn.tool_calls: + func_name = tc.function.name if tc.function else "unknown" + args_str = tc.function.arguments if tc.function else "{}" + parts.append(f''' +
+ {_escape(func_name)} + ({_escape(args_str)}) +
+ ''') + + # Show results if available + tool_results = getattr(turn, 'tool_results', []) + matching_results = [tr for tr in tool_results if tr.tool_call_id == tc.id] + if matching_results: + for tr in matching_results: + result_content = str(tr.content)[:200] + if len(str(tr.content)) > 200: + result_content += '...' + parts.append(f''' +
+ Result: {_escape(result_content)} +
+ ''') + + parts.append('
') + + # Metadata badges + parts.append('
') + + model = getattr(turn, 'model', None) + if model: + parts.append(f''' + + model: {_escape(model)} + + ''') + + finish_reason = getattr(turn, 'finish_reason', None) + if finish_reason: + parts.append(f''' + + finish: {_escape(finish_reason)} + + ''') + + parts.append('
') + + # Close main container + parts.append('
') + + return ''.join(parts) + + +def render_content_block_list(blocks: 'ContentBlockList') -> str: + """Render a ContentBlockList as Jupyter HTML.""" + theme = get_theme() + content_theme = theme['content_blocks'] + common = theme['common'] + + # Generate unique ID + list_id = str(uuid.uuid4())[:8] + + parts = [] + + # Main container + parts.append(f''' +
+ ''') + + # Header + parts.append(f''' +
+ {content_theme['icon']} + Content Blocks +
+ ''') + + # Content + parts.append('
') + + # Render content blocks + inline_html, detail_html, num_blocks, num_images = _render_content_blocks(blocks, list_id) + parts.append(inline_html) + parts.append('
') + + # Metadata badges + parts.append('
') + + block_label = f"{num_blocks} block{'s' if num_blocks != 1 else ''}" + if num_images > 0: + block_label += f" ({num_images} image{'s' if num_images != 1 else ''})" + + parts.append(f''' + + {block_label} ā–¼ + + ''') + + parts.append('
') + + # Content blocks detail (expandable) + parts.append(f''' + ') + + # Close main container + parts.append('
') + + return ''.join(parts) + + +def render_chat(chat: 'Chat') -> str: + """Render a Chat as Jupyter HTML.""" + theme = get_theme() + chat_theme = theme['chat'] + system_theme = theme['system_prompt'] + common = theme['common'] + + # Generate unique ID + chat_id = str(uuid.uuid4())[:8] + + parts = [] + + # Add inline styles for collapsed turns functionality + parts.append(f''' + + ''') + + # Main chat container + parts.append(f''' +
+ ''') + + # Chat header + parts.append(f''' +
+
šŸ’¬ Chat History
+
+ ''') + + # Stats badges + num_turns = len(chat.turns) + parts.append(f''' + + {num_turns} turn{'s' if num_turns != 1 else ''} + + ''') + + total_tokens = chat.get_token_count_estimate() + if total_tokens > 0: + parts.append(f''' + + Total: ~{total_tokens} tokens + + ''') + + parts.append('
') + + # System prompt (if present) + if chat.system_prompt: + parts.append(f''' +
+
+ {system_theme['icon']} System Prompt: +
+ {_escape(chat.system_prompt)} +
+ ''') + + # Render turns with middle collapsing for long conversations + COLLAPSE_THRESHOLD = 20 + SHOW_FIRST = 2 + SHOW_LAST = 2 + + if len(chat.turns) > COLLAPSE_THRESHOLD: + # Show first few turns + for turn in chat.turns[:SHOW_FIRST]: + if hasattr(turn, '_repr_html_'): + parts.append(turn._repr_html_()) + + # Collapsed middle section + num_hidden = len(chat.turns) - SHOW_FIRST - SHOW_LAST + parts.append(f''' +
+ ... {num_hidden} more turn{'s' if num_hidden != 1 else ''} ... (click to expand) +
+ ''') + + # Hidden middle turns + parts.append(f'
') + for turn in chat.turns[SHOW_FIRST:-SHOW_LAST]: + if hasattr(turn, '_repr_html_'): + parts.append(turn._repr_html_()) + parts.append('
') + + # Show last few turns + for turn in chat.turns[-SHOW_LAST:]: + if hasattr(turn, '_repr_html_'): + parts.append(turn._repr_html_()) + else: + # Show all turns + for turn in chat.turns: + if hasattr(turn, '_repr_html_'): + parts.append(turn._repr_html_()) + + # Close main container + parts.append('
') + + return ''.join(parts) + diff --git a/opto/utils/display/themes.py b/opto/utils/display/themes.py new file mode 100644 index 00000000..25072751 --- /dev/null +++ b/opto/utils/display/themes.py @@ -0,0 +1,73 @@ +""" +Theme configurations for display rendering. + +Themes define colors, styles, and visual elements for different rendering styles. +""" + +# Glassmorphism theme (default) - inspired by Trace2 branding +GLASSMORPHISM_THEME = { + 'user': { + 'background': 'rgba(236, 254, 255, 0.85)', + 'border': 'rgba(6, 182, 212, 0.3)', + 'border_color': '#06b6d4', + 'text_color': '#0891b2', + 'icon': 'šŸ‘¤', + }, + 'assistant': { + 'background': 'rgba(238, 242, 255, 0.85)', + 'border': 'rgba(99, 102, 241, 0.3)', + 'border_color': '#6366f1', + 'text_color': '#4f46e5', + 'icon': 'šŸ¤–', + }, + 'content_blocks': { + 'background': 'rgba(255, 255, 255, 0.85)', + 'border': 'rgba(158, 158, 158, 0.3)', + 'text_color': '#666', + 'icon': 'šŸ“', + }, + 'system_prompt': { + 'background': 'rgba(250, 250, 250, 0.7)', + 'border': 'rgba(158, 158, 158, 0.2)', + 'text_color': '#757575', + 'icon': 'āš™ļø', + }, + 'chat': { + 'background': 'rgba(255, 255, 255, 0.7)', + 'border': 'rgba(158, 158, 158, 0.2)', + 'title_color': '#424242', + }, + 'reasoning': { + 'background': '#F5F5F5', + 'border': '#E0E0E0', + 'text_color': '#555', + 'icon': 'šŸ’­', + }, + 'tool_calls': { + 'background': '#F5F5F5', + 'border': '#E0E0E0', + 'text_color': '#555', + 'icon': 'šŸ”§', + }, + 'common': { + 'border_radius': '12px', + 'box_shadow': '0 4px 16px rgba(0,0,0,0.06)', + 'backdrop_filter': 'blur(10px)', + 'image_max_height': '400px', + } +} + +# Active theme (can be changed) +ACTIVE_THEME = GLASSMORPHISM_THEME + + +def set_theme(theme_dict): + """Set a custom theme for display rendering.""" + global ACTIVE_THEME + ACTIVE_THEME = theme_dict + + +def get_theme(): + """Get the currently active theme.""" + return ACTIVE_THEME + diff --git a/opto/utils/llm.py b/opto/utils/llm.py index 97377c16..4b8b6444 
100644 --- a/opto/utils/llm.py +++ b/opto/utils/llm.py @@ -1,4 +1,12 @@ -from typing import List, Tuple, Dict, Any, Callable, Union +""" +When MM (multimodal) is enabled, we primarily either use: +1. LiteLLM's response API +2. Google's Interaction API design (not supported by LiteLLM response API at all) +When MM is disabled, for backward compatibility, we use: +1. LiteLLM's completion API +""" + +from typing import List, Tuple, Dict, Any, Callable, Union, Optional import os import time import json @@ -6,11 +14,35 @@ import warnings from .auto_retry import retry_with_exponential_backoff +# Import content/turn types for mm_beta mode. +# Heavy/optional SDKs (openai, google-genai) are imported lazily inside the +# backends that need them, so importing this module never requires them. +from .backbone import AssistantTurn, TextContent, ImageContent + try: import autogen # We import autogen here to avoid the need of installing autogen except ImportError: pass + +def _is_image_generation_model(model_name: str) -> bool: + """Detect if a model is for image generation based on its name. + + Detects: + - OpenAI: gpt-image-1, gpt-image-1.5, gpt-image-1-mini, dall-e-2, dall-e-3 + - Gemini: gemini-2.5-flash-image, gemini-2.5-pro-image, etc. + + Args: + model_name: The name of the model to check + + Returns: + bool: True if the model is an image generation model, False otherwise + """ + if model_name is None: + return False + model_lower = model_name.lower() + return 'image' in model_lower or 'dall-e' in model_lower + class AbstractModel: """Abstract base class for LLM model wrappers with automatic refreshing. @@ -24,6 +56,12 @@ class AbstractModel: reset_freq : int or None, optional Number of seconds after which to refresh the model. If None, the model is never refreshed. + mm_beta : bool, optional + If True, returns AssistantTurn objects with rich multimodal content. + If False (default), returns raw API responses in legacy format. 
+ model_name : str or None, optional + The name of the model being used (e.g., "gpt-4o", "claude-3-5-sonnet-latest"). + If None, no model name is stored. Attributes ---------- @@ -31,13 +69,21 @@ class AbstractModel: The factory function for creating model instances. reset_freq : int or None Refresh frequency in seconds. + mm_beta : bool + Whether to use multimodal beta mode. + model_name : str or None + The name of the model being used. + is_image_model : bool + Whether the model is for image generation (auto-detected from model name). + model : Any Property that returns the current model instance. Methods ------- __call__(*args, **kwargs) - Execute the model, refreshing if needed. + Execute the model, refreshing if needed. Returns AssistantTurn if mm_beta=True, + otherwise returns raw API response. Notes ----- @@ -45,8 +91,9 @@ class AbstractModel: 1. **Automatic Refreshing**: Recreates the model instance periodically to prevent issues with long-running connections. 2. **Serialization**: Supports pickling by recreating the model on load. - 3. **Consistent Interface**: Ensures responses are available at - `response['choices'][0]['message']['content']`. + 3. **Response Formats**: + - Legacy (mm_beta=False): `response['choices'][0]['message']['content']` + - Multimodal (mm_beta=True): AssistantTurn object with .content, .tool_calls, etc. Subclasses should override the `model` property to customize behavior. @@ -56,32 +103,58 @@ class AbstractModel: LiteLLM : Concrete implementation using LiteLLM """ - def __init__(self, factory: Callable, reset_freq: Union[int, None] = None) -> None: + def __init__(self, factory: Callable, reset_freq: Union[int, None] = None, + mm_beta: bool = False, model_name: Union[str, None] = None) -> None: """ Args: factory: A function that takes no arguments and returns a model that is callable. reset_freq: The number of seconds after which the model should be refreshed. If None, the model is never refreshed. 
+ mm_beta: If True, returns AssistantTurn objects with rich multimodal content. + If False (default), returns raw API responses in legacy format. + model_name: The name of the model being used (e.g., "gpt-4o", "claude-3-5-sonnet-latest"). + If None, no model name is stored. """ self.factory = factory self._model = self.factory() self.reset_freq = reset_freq self._init_time = time.time() + self.mm_beta = mm_beta + self.model_name = model_name # Overwrite this `model` property when subclassing. @property def model(self): """When self.model is called, text responses should always be available at `response['choices'][0]['message']['content']`""" return self._model + + @property + def is_image_model(self) -> bool: + """Check if this model is for image generation based on model name. + + Returns True if the model name contains 'image' or 'dall-e', False otherwise. + """ + return _is_image_generation_model(self.model_name) # This is the main API def __call__(self, *args, **kwargs) -> Any: """ The call function handles refreshing the model if needed. 
+ + Returns: + If mm_beta=False: Raw completion API response (backward compatible) + If mm_beta=True: AssistantTurn object with parsed multimodal content """ if self.reset_freq is not None and time.time() - self._init_time > self.reset_freq: self._model = self.factory() self._init_time = time.time() - return self.model(*args, **kwargs) + + response = self.model(*args, **kwargs) + + # Parse to AssistantTurn if mm_beta mode is enabled + if self.mm_beta: + return AssistantTurn(response) + + return response def __getstate__(self): state = self.__dict__.copy() @@ -151,7 +224,8 @@ class AutoGenLLM(AbstractModel): >>> response = llm(messages=[{"role": "user", "content": "Hello"}]) """ - def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_freq: Union[int, None] = None) -> None: + def __init__(self, config_list: List = None, filter_dict: Dict = None, + reset_freq: Union[int, None] = None, mm_beta: bool = False) -> None: if config_list is None: try: config_list = autogen.config_list_from_json("OAI_CONFIG_LIST") @@ -163,8 +237,11 @@ def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_fre if filter_dict is not None: config_list = autogen.filter_config(config_list, filter_dict) + # Extract model name from config_list if available + model_name = config_list[0].get('model') if config_list and len(config_list) > 0 else None + factory = lambda *args, **kwargs: self._factory(config_list) - super().__init__(factory, reset_freq) + super().__init__(factory, reset_freq, mm_beta=mm_beta, model_name=model_name) @classmethod def _factory(cls, config_list): @@ -270,16 +347,55 @@ class LiteLLM(AbstractModel): This is an LLM backend supported by LiteLLM library. https://docs.litellm.ai/docs/completion/input + https://docs.litellm.ai/docs/response_api + https://docs.litellm.ai/docs/image_generation To use this, set the credentials through the environment variable as instructed in the LiteLLM documentation. 
For convenience, you can set the default model name through the environment variable TRACE_LITELLM_MODEL. - When using Azure models via token provider, you can set the Azure token - provider scope through the environment variable AZURE_TOKEN_PROVIDER_SCOPE. + + Azure OpenAI Authentication: + Two authentication methods are supported for Azure OpenAI: + + 1. API Key Authentication (Recommended for most users): + Set these environment variables: + - AZURE_API_KEY: Your Azure OpenAI API key + - AZURE_API_BASE: Your Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com) + - AZURE_API_VERSION: API version (e.g., 2024-08-01-preview) + - TRACE_LITELLM_MODEL: Model name with azure/ prefix (e.g., azure/o4-mini) + + Do NOT set AZURE_TOKEN_PROVIDER_SCOPE for this method. + + 2. Azure AD Credential Authentication (For enterprise users): + Set AZURE_TOKEN_PROVIDER_SCOPE (e.g., https://cognitiveservices.azure.com/.default) + to use Azure Identity credential-based authentication. + This method does NOT use AZURE_API_KEY. + + This class now supports storing default completion parameters (like temperature, + top_p, max_tokens, etc.) that will be used for all calls unless overridden. + + Text Generation: + When mm_beta=True, the Responses API is used for rich multimodal content. + When mm_beta=False (default), the Completion API is used for backward compatibility. + + See: https://docs.litellm.ai/docs/response_api + + Image Generation: + Automatically detects image generation models (containing 'image' or 'dall-e' in name). + Uses litellm.image_generation() API for models like: + - gpt-image-1, gpt-image-1.5, gpt-image-1-mini + - dall-e-2, dall-e-3 + + Image models require a single string prompt: + llm = LLM(model="gpt-image-1.5") + result = llm(prompt="A serene mountain landscape") + + Check llm.is_image_model to determine if a model is for image generation. 
@classmethod
def _factory(cls, model_name: str, default_params: dict, mm_beta: bool,
             max_retries=1, base_delay=1.0):
    """Build the callable that performs one LiteLLM request with retries.

    Three kinds of callables can be returned, chosen in this order:

    1. Image generation (``litellm.image_generation``) when the model name
       is recognised by ``_is_image_generation_model``. The callable takes a
       single string ``prompt`` plus keyword overrides.
    2. Responses API (``litellm.responses``) when ``mm_beta`` is True. A
       ``messages`` keyword, or failing that the first positional argument,
       is forwarded as ``input`` unless ``input`` was given explicitly.
    3. Completion API (``litellm.completion``) otherwise; positional
       arguments are forwarded after the model name.

    For ``azure/`` models, extra authentication parameters are added:
    a bearer-token provider when ``AZURE_TOKEN_PROVIDER_SCOPE`` is set,
    otherwise whichever of ``AZURE_API_KEY`` / ``AZURE_API_BASE`` /
    ``AZURE_API_VERSION`` are set and not already in ``default_params``.

    Keyword precedence for every request is
    ``azure parameters < default_params < per-call kwargs``.

    Args:
        model_name: LiteLLM model identifier (e.g. ``"gpt-4o"``, ``"azure/o4-mini"``).
        default_params: Keyword arguments applied to every call unless
            overridden per call. The dict is read at call time, not copied.
        mm_beta: Select the Responses API instead of the Completion API.
        max_retries: Forwarded to ``retry_with_exponential_backoff``.
        base_delay: Forwarded to ``retry_with_exponential_backoff``.

    Returns:
        A callable that issues the request and returns LiteLLM's response.
    """
    import warnings

    import litellm

    is_azure = model_name.startswith('azure/')

    def with_retry(request, operation_name):
        # Single place that applies the retry policy to a zero-arg request.
        return retry_with_exponential_backoff(
            request,
            max_retries=max_retries,
            base_delay=base_delay,
            operation_name=operation_name,
        )

    # For Azure models, set the global litellm api_version once as a fallback
    # (workaround for potential litellm.responses API issues).
    if is_azure:
        azure_api_version = os.environ.get('AZURE_API_VERSION')
        if azure_api_version and not hasattr(litellm, '_azure_api_version_set'):
            try:
                litellm.api_version = azure_api_version
                litellm._azure_api_version_set = True  # Mark to avoid setting multiple times
            except Exception as exc:
                # Best effort only, but no longer silent and no longer
                # swallowing KeyboardInterrupt/SystemExit.
                warnings.warn(f"Could not set litellm.api_version: {exc}")

    if _is_image_generation_model(model_name):
        def image_wrapper(prompt, **kwargs):
            assert isinstance(prompt, str), (
                f"Image generation requires a single string prompt. "
                f"Got {type(prompt).__name__}. "
                f"Usage: llm(prompt='your prompt here')"
            )
            return with_retry(
                lambda: litellm.image_generation(
                    model=model_name, prompt=prompt, **{**default_params, **kwargs}
                ),
                "LiteLLM_image_generation",
            )
        return image_wrapper

    # Authentication parameters that are fixed for the lifetime of this callable.
    fixed_params = {}
    if is_azure:
        azure_token_provider_scope = os.environ.get('AZURE_TOKEN_PROVIDER_SCOPE', None)
        if azure_token_provider_scope is not None:
            # Azure AD credential-based authentication
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
            fixed_params['azure_ad_token_provider'] = get_bearer_token_provider(
                DefaultAzureCredential(), azure_token_provider_scope
            )
        else:
            # Azure API key authentication - explicitly pass Azure env vars if available
            for param, env_var in (
                ('api_key', 'AZURE_API_KEY'),
                ('api_base', 'AZURE_API_BASE'),
                ('api_version', 'AZURE_API_VERSION'),
            ):
                env_value = os.environ.get(env_var)
                if param not in default_params and env_value:
                    fixed_params[param] = env_value

    if mm_beta:
        # Responses API: model is passed by keyword and the prompt is called `input`.
        def responses_wrapper(*args, **kwargs):
            if 'input' not in kwargs:
                if 'messages' in kwargs:
                    kwargs['input'] = kwargs.pop('messages')
                elif args:
                    # Mirror completion mode, where the first positional
                    # argument carries the conversation.
                    kwargs['input'] = args[0]
            return with_retry(
                lambda: litellm.responses(
                    model=model_name, **{**fixed_params, **default_params, **kwargs}
                ),
                "LiteLLM_responses",
            )
        return responses_wrapper

    # Completion API: model is the first positional argument.
    return lambda *args, **kwargs: with_retry(
        lambda: litellm.completion(
            model_name, *args, **{**fixed_params, **default_params, **kwargs}
        ),
        "LiteLLM_completion",
    )
class GoogleGenAILLM(AbstractModel):
    """LLM backend built on Google's GenAI SDK (``google-genai``).

    https://ai.google.dev/gemini-api/docs/text-generation
    https://ai.google.dev/gemini-api/docs/image-generation

    Text requests go through ``client.models.generate_content`` and image
    requests through ``client.models.generate_images``.

    To use this, set the GEMINI_API_KEY environment variable with your API key.
    For convenience, you can set the default model name through the environment
    variable TRACE_GOOGLE_GENAI_MODEL (``gemini-2.5-flash`` is used otherwise).

    Supported models:
        - Text: gemini-2.5-flash, gemini-2.5-pro, gemini-2.5-flash-lite
        - Image: gemini-2.5-flash-image, gemini-2.5-pro-image

    Extra keyword arguments given to the constructor (temperature,
    max_output_tokens, thinking_budget, ...) are stored in ``default_params``
    and applied to every call unless overridden per call.

    Text Generation:
        ``messages`` may be OpenAI-style dicts (``{'role', 'content'}``) or
        Gemini-native dicts (``{'role', 'parts'}``), e.g. from
        ``ConversationHistory.to_gemini_format()``. ``max_tokens`` is accepted
        as an alias of ``max_output_tokens``.

        Example:
            from opto.utils.llm import LLM
            from opto.utils.backbone import ConversationHistory, UserTurn, AssistantTurn

            # Initialize LLM
            llm = LLM(model="gemini-2.5-flash")

            # Create conversation history
            history = ConversationHistory()
            history.system_prompt = "You are a helpful assistant."
            history.add_user_turn(UserTurn().add_text("What is AI?"))

            # Convert to Gemini format and call LLM
            messages = history.to_gemini_format()
            response = llm(messages=messages, max_tokens=100)

            # Parse response
            at = AssistantTurn(response)
            print(at.get_text())

    Image Generation:
        Models whose name contains 'image' are treated as image models and
        require a single string prompt:
            llm = LLM(model="gemini-2.5-flash-image")
            result = llm(prompt="A serene mountain landscape", number_of_images=2)

        Check llm.is_image_model to determine if a model is for image generation.
    """

    # Keyword arguments forwarded to ``types.GenerateContentConfig``; anything
    # else passed to a text call is dropped.
    _CONTENT_CONFIG_KEYS = frozenset({
        'temperature', 'max_output_tokens', 'top_p', 'top_k',
        'stop_sequences', 'candidate_count', 'presence_penalty',
        'frequency_penalty', 'response_mime_type', 'response_schema',
    })
    # Keyword arguments forwarded to ``types.GenerateImagesConfig``.
    _IMAGE_CONFIG_KEYS = frozenset({
        'number_of_images', 'aspect_ratio', 'safety_filter_level',
    })

    def __init__(self, model: Union[str, None] = None, reset_freq: Union[int, None] = None,
                 cache=True, mm_beta: bool = False, max_retries: int = 1,
                 base_delay: float = 1.0, **default_params) -> None:
        """
        Args:
            model: Gemini model name. Falls back to TRACE_GOOGLE_GENAI_MODEL,
                then to 'gemini-2.5-flash'.
            reset_freq: Seconds after which the client is rebuilt; None disables refreshing.
            cache: Stored on the instance as ``self.cache``.
            mm_beta: If True, calls return AssistantTurn objects.
            max_retries: Forwarded to ``retry_with_exponential_backoff``.
            base_delay: Forwarded to ``retry_with_exponential_backoff``.
            **default_params: Default generation parameters for every call.
        """
        if model is None:
            model = os.environ.get('TRACE_GOOGLE_GENAI_MODEL', 'gemini-2.5-flash')

        self.model_name = model
        self.cache = cache
        self.default_params = default_params  # Store default generation parameters
        factory = lambda: self._factory(self.model_name, self.default_params,
                                        max_retries=max_retries, base_delay=base_delay)
        super().__init__(factory, reset_freq, mm_beta=mm_beta, model_name=model)

    @staticmethod
    def _split_messages(messages, system_instruction):
        """Turn a non-empty ``messages`` list into ``(contents, system_instruction)``.

        OpenAI-style messages (first non-system entry has 'content' but no
        'parts') are converted with ``_gemini_messages_to_contents``.
        Gemini-native messages are passed through, with a leading system entry
        split off. An explicitly supplied ``system_instruction`` always wins
        over one found in the messages.
        """
        first_non_system = next(
            (m for m in messages if m.get('role') != 'system'), None
        )
        is_openai_format = (
            first_non_system is not None and
            'content' in first_non_system and
            'parts' not in first_non_system
        )
        if is_openai_format:
            contents, extracted_sys = _gemini_messages_to_contents(messages)
            if system_instruction is None:
                system_instruction = extracted_sys
            return contents, system_instruction

        # Already in Gemini native format (from history.to_gemini_format())
        if messages[0].get('role') == 'system':
            if system_instruction is None:
                system_instruction = messages[0].get('content')
            return messages[1:], system_instruction
        return messages, system_instruction

    @classmethod
    def _factory(cls, model_name: str, default_params: dict,
                 max_retries: int = 1, base_delay: float = 1.0):
        """Create the retrying callable that talks to the Google GenAI client.

        ``default_params`` is only read, never modified: it is the instance's
        ``default_params`` dict and the factory runs again on every refresh,
        so removing keys from it here would lose them for later clients.

        Returns:
            A callable ``f(*args, **kwargs)`` that merges ``default_params``
            with the per-call keyword arguments (per-call wins) and returns
            the SDK response.
        """
        from google import genai
        from google.genai import types

        # Get API key from environment variable
        api_key = os.environ.get('GEMINI_API_KEY')
        if api_key:
            client = genai.Client(api_key=api_key)
        else:
            # Try without API key (will use default credentials or fail gracefully)
            client = genai.Client()

        if _is_image_generation_model(model_name):
            # NOTE(review): Google's docs describe generate_images for Imagen
            # models; confirm it accepts gemini-*-image model ids.
            def generate_image(prompt, **kwargs):
                assert isinstance(prompt, str), (
                    f"Image generation requires a single string prompt. "
                    f"Got {type(prompt).__name__}. "
                    f"Usage: llm(prompt='your prompt here')"
                )
                # https://ai.google.dev/gemini-api/docs/image-generation
                # Options travel inside a GenerateImagesConfig, not as direct
                # keyword arguments of generate_images.
                image_options = {
                    k: v for k, v in kwargs.items() if k in cls._IMAGE_CONFIG_KEYS
                }
                config = types.GenerateImagesConfig(**image_options) if image_options else None
                return client.models.generate_images(
                    model=model_name,
                    prompt=prompt,
                    config=config,
                )

            return lambda *args, **kwargs: retry_with_exponential_backoff(
                lambda: generate_image(*args, **{**default_params, **kwargs}),
                max_retries=max_retries,
                base_delay=base_delay,
                operation_name=f"{model_name}_image_gen"
            )

        def generate(*args, **kwargs):
            # `kwargs` is a fresh dict on every call, so popping is safe.
            # system_instruction belongs at config level, not among the kwargs.
            system_instruction = kwargs.pop('system_instruction', None)
            messages = kwargs.pop('messages', None)
            contents = kwargs.pop('contents', None)

            if messages:
                contents, system_instruction = cls._split_messages(messages, system_instruction)

            # Fall back to the first positional argument; convert it when it
            # is a list of OpenAI-format dicts.
            if contents is None and args:
                raw = args[0]
                if (
                    isinstance(raw, list) and raw and
                    isinstance(raw[0], dict) and
                    'content' in raw[0] and 'parts' not in raw[0]
                ):
                    contents, extracted_sys = _gemini_messages_to_contents(raw)
                    if system_instruction is None:
                        system_instruction = extracted_sys
                else:
                    contents = raw

            # Map max_tokens to max_output_tokens for Google GenAI
            if 'max_tokens' in kwargs:
                kwargs['max_output_tokens'] = kwargs.pop('max_tokens')

            config_fields = {}
            # Thinking config for Gemini 2.5+ models, from the defaults or the call.
            if 'thinking_budget' in kwargs:
                config_fields['thinking_config'] = types.ThinkingConfig(
                    thinking_budget=kwargs.pop('thinking_budget')
                )
            if system_instruction:
                config_fields['system_instruction'] = system_instruction
            # Keep only parameters GenerateContentConfig is known to accept.
            config_fields.update(
                (k, v) for k, v in kwargs.items() if k in cls._CONTENT_CONFIG_KEYS
            )

            return client.models.generate_content(
                model=model_name,
                contents=contents,
                config=types.GenerateContentConfig(**config_fields)
            )

        return lambda *args, **kwargs: retry_with_exponential_backoff(
            lambda: generate(*args, **{**default_params, **kwargs}),
            max_retries=max_retries,
            base_delay=base_delay,
            operation_name=f"{model_name}"
        )

    @property
    def model(self):
        """
        Wrapper that injects the model name into calls.

        Example:
            response = llm(contents="How does AI work?")
        """
        return lambda *args, **kwargs: self._model(*args, model=self.model_name, **kwargs)
config=types.GenerateContentConfig(**{**config_params_with_system, **config_kwargs}) + ) + + return response + + return lambda *args, **kwargs: retry_with_exponential_backoff( + lambda: api_func(model_name, *args, **{**default_params, **kwargs}), + max_retries=max_retries, + base_delay=base_delay, + operation_name=f"{model_name}" + ) - llm_default = LLM(profile="default") # gpt-4o-mini - llm_premium = LLM(profile="premium") # gpt-4 - llm_cheap = LLM(profile="cheap") # gpt-4o-mini - llm_fast = LLM(profile="fast") # gpt-3.5-turbo-mini - llm_reasoning = LLM(profile="reasoning") # o1-mini + @property + def model(self): + """ + Wrapper that injects the model name into calls. + + Example: + response = llm(contents="How does AI work?") + """ + return lambda *args, **kwargs: self._model(model=self.model_name, *args, **kwargs) - You can override those built-in profiles: +# --------------------------------------------------------------------------- +# Helper to convert OpenAI-style messages into Gemini 'contents' format. +# --------------------------------------------------------------------------- - LLMFactory.register_profile("default", "LiteLLM", model="gpt-4o", temperature=0.5) - LLMFactory.register_profile("premium", "LiteLLM", model="o1-preview", max_tokens=8000) - LLMFactory.register_profile("cheap", "LiteLLM", model="gpt-3.5-turbo", temperature=0.9) - LLMFactory.register_profile("fast", "LiteLLM", model="gpt-3.5-turbo", max_tokens=500) - LLMFactory.register_profile("reasoning", "LiteLLM", model="o1-preview") +def _gemini_messages_to_contents(messages): + """Convert a standard messages list to Gemini REST API ``contents`` format. - An Example of using Different Backends + Returns (contents, system_instruction) where system_instruction is a str + extracted from any ``role="system"`` message, or None. 
+ """ + contents = [] + system_instruction = None + + for msg in messages: + role = msg.get('role', 'user') + content = msg.get('content', '') + + if role == 'system': + if isinstance(content, list): + texts = [ + item['text'] for item in content + if isinstance(item, dict) and item.get('type') == 'text' + ] + system_instruction = '\n'.join(texts) + else: + system_instruction = str(content) + continue + + gemini_role = 'model' if role == 'assistant' else 'user' + + if isinstance(content, str): + parts = [{'text': content}] + elif isinstance(content, list): + parts = [ + {'text': item['text']} + for item in content + if isinstance(item, dict) and item.get('type') == 'text' + ] + if not parts: + parts = [{'text': str(content)}] + else: + parts = [{'text': str(content)}] + + contents.append({'role': gemini_role, 'parts': parts}) + + return contents, system_instruction - # Register custom profiles for different use cases - LLMFactory.register_profile("advanced_reasoning", "LiteLLM", model="o1-preview", max_tokens=4000) - LLMFactory.register_profile("claude_sonnet", "LiteLLM", model="claude-3-5-sonnet-latest", temperature=0.3) - LLMFactory.register_profile("custom_server", "CustomLLM", model="llama-3.1-8b") - # Use in different contexts - reasoning_llm = LLM(profile="advanced_reasoning") # For complex reasoning - claude_llm = LLM(profile="claude_sonnet") # For Claude responses - local_llm = LLM(profile="custom_server") # For local deployment +# Registry of available backends +_LLM_REGISTRY = { + "LiteLLM": LiteLLM, + "AutoGen": AutoGenLLM, + "CustomLLM": CustomLLM, + "GoogleGenAI": GoogleGenAILLM, +} - # Single LLM optimizer with custom profile - optimizer1 = OptoPrime(parameters, llm=LLM(profile="advanced_reasoning")) +class LLMFactory: + """Factory for creating LLM instances with named profiles. + + Profiles allow you to save and reuse LLM configurations with specific settings. 
+ Each profile can include any LiteLLM-supported parameters like model, temperature, + top_p, max_tokens, etc. + + The default profile uses 'gpt-4o' with standard settings. + + Basic Usage: + # Use default model (gpt-4o) + llm = LLM() + + # Specify a model directly + llm = LLM(model="gpt-4o") + + # Use a named profile + llm = LLM(profile="my_profile") + + Creating Custom Profiles: + # Register a profile with full LiteLLM configuration + LLMFactory.create_profile( + "creative_writer", + backend="LiteLLM", + model="gpt-4o", + temperature=0.9, + top_p=0.95, + max_tokens=2000, + presence_penalty=0.6 + ) + + # Register a reasoning profile + LLMFactory.create_profile( + "deep_thinker", + backend="LiteLLM", + model="o1-preview", + max_completion_tokens=8000 + ) + + # Register a profile with specific formatting + LLMFactory.create_profile( + "json_responder", + backend="LiteLLM", + model="gpt-4o-mini", + temperature=0.3, + response_format={"type": "json_object"} + ) - # Multi-LLM optimizer with multiple profiles - optimizer2 = OptoPrimeMulti(parameters, llm_profiles=["cheap", "premium", "claude_sonnet"], generation_technique="multi_llm") + Using Profiles: + # Use your custom profile + llm = LLM(profile="creative_writer") + + # In optimizers + optimizer = OptoPrime(parameters, llm=LLM(profile="deep_thinker")) + + Profile Management: + # List all available profiles + profiles = LLMFactory.list_profiles() + + # Get profile configuration + config = LLMFactory.get_profile_info("creative_writer") + + # Override existing profile + LLMFactory.create_profile("default", "LiteLLM", model="gpt-4o", temperature=0.5) + + Supported LiteLLM Parameters: + See https://docs.litellm.ai/docs/completion/input for full list: + - model: Model name (required) + - temperature: Sampling temperature (0-2) + - top_p: Nucleus sampling parameter + - max_tokens: Maximum tokens to generate + - max_completion_tokens: Upper bound for completion tokens + - presence_penalty: Penalize new tokens 
based on presence + - frequency_penalty: Penalize new tokens based on frequency + - stop: Stop sequences (string or list) + - stream: Enable streaming responses + - response_format: Output format specification + - seed: Deterministic sampling seed + - tools: Function calling tools + - tool_choice: Control function calling behavior + - logprobs: Return log probabilities + - top_logprobs: Number of most likely tokens to return + - n: Number of completions to generate + - and many more... """ - # Default profiles for different use cases + # Default profile - just gpt-4o with no opinionated settings _profiles = { - 'default': {'backend': 'LiteLLM', 'params': {'model': 'gpt-4o-mini'}}, - 'premium': {'backend': 'LiteLLM', 'params': {'model': 'gpt-4'}}, - 'cheap': {'backend': 'LiteLLM', 'params': {'model': 'gpt-4o-mini'}}, - 'fast': {'backend': 'LiteLLM', 'params': {'model': 'gpt-3.5-turbo-mini'}}, - 'reasoning': {'backend': 'LiteLLM', 'params': {'model': 'o1-mini'}}, + 'default': {'backend': 'LiteLLM', 'params': {'model': 'gpt-4o'}}, } - @classmethod - def get_llm(cls, profile: str = 'default') -> AbstractModel: - """Get an LLM instance for the specified profile.""" + def get_llm(cls, profile: str = 'default', model: str = None, mm_beta: bool = False, **kwargs) -> AbstractModel: + """Get an LLM instance for the specified profile or model. + + Args: + profile: Name of the profile to use. Defaults to 'default'. + model: Model name to use directly. If provided, overrides profile. + mm_beta: If True, returns AssistantTurn objects with rich multimodal content. + If False (default), returns raw API responses in legacy format. + **kwargs: Additional parameters to pass to the backend (e.g., temperature, top_p). + These override profile settings if both are specified. + + Returns: + An LLM instance configured according to the profile/model and parameters. 
+ + Examples: + # Use default profile + llm = LLMFactory.get_llm() + + # Use specific model + llm = LLMFactory.get_llm(model="gpt-4o") + + # Use named profile + llm = LLMFactory.get_llm(profile="creative_writer") + + # Use model with custom parameters + llm = LLMFactory.get_llm(model="gpt-4o", temperature=0.7, max_tokens=1000) + + # Override profile settings + llm = LLMFactory.get_llm(profile="creative_writer", temperature=0.5) + + # Use mm_beta mode for multimodal responses + llm = LLMFactory.get_llm(model="gpt-4o", mm_beta=True) + """ + # If model is specified directly, create a simple config + if model is not None: + backend = kwargs.pop('backend', None) + + # Determine backend with priority: + # 1. Explicit backend kwarg (always wins) + # 2. Gemini model name -> GoogleGenAI + # 3. Default -> LiteLLM + if backend is not None: + backend_cls = _LLM_REGISTRY[backend] + elif model.startswith('gemini'): + # Gemini models default to GoogleGenAILLM backend + backend_cls = _LLM_REGISTRY['GoogleGenAI'] + # Strip 'gemini/' prefix if present (LiteLLM format: gemini/gemini-pro) + if model.startswith('gemini/'): + model = model[len('gemini/'):] + else: + # Default to LiteLLM for other models + backend_cls = _LLM_REGISTRY['LiteLLM'] + + params = {'model': model, 'mm_beta': mm_beta, **kwargs} + return backend_cls(**params) + # Otherwise use profile if profile not in cls._profiles: - raise ValueError(f"Unknown profile '{profile}'. Available profiles: {list(cls._profiles.keys())}") + raise ValueError( + f"Unknown profile '{profile}'. Available profiles: {list(cls._profiles.keys())}. " + f"Use LLMFactory.create_profile() to create custom profiles, or pass model= directly." 
+ ) - config = cls._profiles[profile] + config = cls._profiles[profile].copy() backend_cls = _LLM_REGISTRY[config['backend']] - return backend_cls(**config['params']) + + # Merge profile params with any override kwargs + params = config['params'].copy() + params['mm_beta'] = mm_beta + params.update(kwargs) + + return backend_cls(**params) @classmethod - def register_profile(cls, name: str, backend: str, **params): - """Register a new LLM profile.""" + def create_profile(cls, name: str, backend: str = 'LiteLLM', **params): + """Register a new LLM profile with custom configuration. + + Args: + name: Profile name to register. + backend: Backend to use ('LiteLLM', 'AutoGen', 'CustomLLM', or 'GoogleGenAI'). Defaults to 'LiteLLM'. + **params: Configuration parameters for the backend. For LiteLLM, this can include + any parameters from https://docs.litellm.ai/docs/completion/input + + Examples: + # Simple profile with just a model + LLMFactory.create_profile("gpt4", model="gpt-4o") + + # Profile with temperature and token settings + LLMFactory.create_profile( + "creative", + model="gpt-4o", + temperature=0.9, + max_tokens=2000 + ) + + # Profile with advanced settings + LLMFactory.create_profile( + "structured_json", + model="gpt-4o-mini", + temperature=0.3, + response_format={"type": "json_object"}, + max_tokens=1500, + top_p=0.9 + ) + """ + if backend not in _LLM_REGISTRY: + raise ValueError( + f"Unknown backend '{backend}'. Valid options: {list(_LLM_REGISTRY.keys())}" + ) cls._profiles[name] = {'backend': backend, 'params': params} @classmethod def list_profiles(cls): - """List all available profiles.""" + """List all available profile names.""" return list(cls._profiles.keys()) @classmethod def get_profile_info(cls, profile: str = None): - """Get information about a profile or all profiles.""" + """Get configuration information about one or all profiles. + + Args: + profile: Profile name to get info for. If None, returns all profiles. 
+ + Returns: + Dictionary with profile configuration(s). + """ if profile: return cls._profiles.get(profile) return cls._profiles @@ -446,10 +1103,12 @@ class DummyLLM(AbstractModel): def __init__(self, callable, - reset_freq: Union[int, None] = None) -> None: + reset_freq: Union[int, None] = None, + mm_beta: bool = False, + model_name: Union[str, None] = None) -> None: # self.message = message self.callable = callable - super().__init__(self._factory, reset_freq) + super().__init__(self._factory, reset_freq, mm_beta=mm_beta, model_name=model_name) def _factory(self): @@ -465,29 +1124,220 @@ def __init__(self, content): class Response: def __init__(self, content): self.choices = [Choice(content)] + self.content = content # for the AssistantTurn API return lambda *args, **kwargs: Response(self.callable(*args, **kwargs)) - class LLM: """ A unified entry point for all supported LLM backends. - Usage: - # pick by env var (default: LiteLLM) - llm = LLM() - # or override explicitly - llm = LLM(backend="AutoGen", config_list=my_configs) - # or use predefined profiles - llm = LLM(profile="premium") # Use premium model - llm = LLM(profile="cheap") # Use cheaper model - llm = LLM(profile="reasoning") # Use reasoning/thinking model + The LLM class provides a simple interface for creating language model instances. + By default, it uses gpt-4o through LiteLLM unless TRACE_LITELLM_MODEL is set. 
+ + Basic Usage: + # Use default model (gpt-4o, unless TRACE_LITELLM_MODEL is set) + llm = LLM() + + # Specify a model directly (highest priority) + llm = LLM(model="gpt-4o") + llm = LLM(model="claude-3-5-sonnet-latest") + llm = LLM(model="o1-preview") + + # Use Azure OpenAI via environment variable + os.environ["TRACE_LITELLM_MODEL"] = "azure/o4-mini" + llm = LLM() # Automatically uses Azure with proper auth + + # Add LiteLLM parameters + llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=2000) + llm = LLM(model="gpt-4o-mini", temperature=0.3, top_p=0.9) + + Image Generation: + # OpenAI image models (auto-detected by 'image' or 'dall-e' in name) + img_llm = LLM(model="gpt-image-1.5") + print(img_llm.is_image_model) # True + result = img_llm(prompt="A serene mountain landscape at sunset") + + # With additional parameters + img_llm = LLM(model="gpt-image-1", size="1024x1024", quality="hd") + result = img_llm(prompt="A futuristic cityscape") + + # DALL-E models + dalle = LLM(model="dall-e-3") + result = dalle(prompt="A cat astronaut in space", size="1024x1792") + + # Gemini image models + gemini_img = LLM(model="gemini-2.5-flash-image") + result = gemini_img(prompt="Abstract art", number_of_images=2) + + # Check if model generates images + if llm.is_image_model: + result = llm(prompt="Your prompt here") + else: + result = llm(messages=[{"role": "user", "content": "Your message"}]) + + Using Multimodal Beta Mode: + # Enable mm_beta for rich AssistantTurn responses + llm = LLM(model="gpt-4o", mm_beta=True) + response = llm(messages=[{"role": "user", "content": "Hello"}]) + # response is now an AssistantTurn object with .content, .tool_calls, etc. + + # Legacy mode (default, mm_beta=False) + llm = LLM(model="gpt-4o") + response = llm(messages=[{"role": "user", "content": "Hello"}]) + # response is raw API response: response.choices[0].message.content + + Using System Messages: + + # LiteLLM (OpenAI, Anthropic, etc.) 
- Use messages array with role="system" + llm = LLM(model="gpt-4o-mini", mm_beta=True) + response = llm(messages=[ + {"role": "system", "content": "You are a helpful math tutor."}, + {"role": "user", "content": "What is 2+2?"} + ]) + print(response.get_text()) # AssistantTurn object + + # LiteLLM Legacy mode (mm_beta=False) + llm = LLM(model="gpt-4o-mini") + response = llm(messages=[ + {"role": "system", "content": "You are a pirate assistant."}, + {"role": "user", "content": "Hello!"} + ]) + print(response.choices[0].message.content) # Raw API response + + # Google Gemini - Use system_instruction parameter (not in messages array) + llm = LLM(backend="GoogleGenAI", model="gemini-2.5-flash-image", mm_beta=True) + response = llm( + "Hello there", + system_instruction="You are a helpful assistant." + ) + print(response.get_text()) # AssistantTurn object + + # Gemini with messages format (system_instruction separate from messages) + llm = LLM(backend="GoogleGenAI", model="gemini-2.5-flash-image", mm_beta=True) + response = llm( + messages=[ + {"role": "user", "content": "What is your purpose?"} + ], + system_instruction="You are a creative writing instructor." 
+ ) + + # Our Gemini wrapper also automatically extracts the system instruction from the messages array if not passed explicitly + messages = [ + {"role": "system", "content": "You are a Shakespearean poet."}, + {"role": "user", "content": "Tell me about the sun."} + ] + response1 = llm(messages=messages) + messages.append({"role": "assistant", "content": response1.get_text()}) + messages.append({"role": "user", "content": "And the moon?"}) + response2 = llm(messages=messages) # System message still applies + + Using Named Profiles: + # Use a saved profile + llm = LLM(profile="my_custom_profile") + + # Create profiles with LLMFactory + LLMFactory.create_profile("creative", model="gpt-4o", temperature=0.9) + llm = LLM(profile="creative") + + Using Different Backends: + # Explicitly specify backend (default: LiteLLM) + llm = LLM(backend="AutoGen", config_list=my_configs) + llm = LLM(backend="CustomLLM", model="llama-3.1-8b") + llm = LLM(backend="GoogleGenAI", model="gemini-2.5-flash-image") + + # Or set via environment variable + # export TRACE_DEFAULT_LLM_BACKEND=AutoGen + llm = LLM() + + Model Selection Priority: + The LLM class follows this priority order for determining which model to use: + + 1. Explicit model argument: LLM(model="gpt-4o") + 2. TRACE_LITELLM_MODEL environment variable: + # export TRACE_LITELLM_MODEL="azure/o4-mini" + llm = LLM() # Uses azure/o4-mini + 3. Named profile (default='default'): LLM(profile="my_profile") + 4. 
Backend-specific defaults + + This means setting TRACE_LITELLM_MODEL will be honored even when using + default initialization, making it ideal for Azure and custom endpoints: + + # Azure OpenAI setup + os.environ["TRACE_LITELLM_MODEL"] = "azure/o4-mini" + os.environ["AZURE_API_KEY"] = "your-key" + os.environ["AZURE_API_BASE"] = "https://your-resource.openai.azure.com" + os.environ["AZURE_API_VERSION"] = "2024-08-01-preview" + + # Now LLM() automatically uses Azure + llm = LLM(mm_beta=True) # Uses azure/o4-mini with proper Azure auth + + Examples with LiteLLM Parameters: + # Structured output + llm = LLM( + model="gpt-4o-mini", + response_format={"type": "json_object"}, + temperature=0.3 + ) + + # High creativity + llm = LLM( + model="gpt-4o", + temperature=0.9, + top_p=0.95, + presence_penalty=0.6 + ) + + # Deterministic responses + llm = LLM( + model="gpt-4o-mini", + temperature=0, + seed=42 + ) + + Key Differences Between Backends: + LiteLLM (OpenAI, Anthropic, etc.): + - System message: Include in messages array with role="system" + - Format: messages=[{"role": "system", "content": "..."}] + - Works with: OpenAI, Anthropic, Cohere, etc. + + Google Gemini: + - System instruction: Pass as system_instruction parameter + - Format: system_instruction="You are a helpful assistant." + - Separate from messages array + - Works with: gemini-2.5-flash, gemini-2.5-pro, etc. 
+ + See Also: + - LLMFactory: For managing named profiles + - AssistantTurn: Returned when mm_beta=True + - https://docs.litellm.ai/docs/completion/input: Full list of LiteLLM parameters + - https://ai.google.dev/gemini-api/docs/system-instructions: Gemini system instructions """ - def __new__(cls, *args, profile: str = None, backend: str = None, **kwargs): - # New: if profile is specified, use LLMFactory + def __new__(cls, model: str = None, profile: str = 'default', backend: str = None, + mm_beta: bool = False, **kwargs): + + if _is_image_generation_model(model): + mm_beta = True + + # Priority 1: If model is specified, use LLMFactory with model + if model: + if backend is not None: + kwargs['backend'] = backend + return LLMFactory.get_llm(model=model, mm_beta=mm_beta, **kwargs) + + # Priority 2: Check if TRACE_LITELLM_MODEL is set (honor user's explicit env config) + env_model = os.environ.get('TRACE_LITELLM_MODEL') + if env_model is not None: + if backend is not None: + kwargs['backend'] = backend + return LLMFactory.get_llm(model=env_model, mm_beta=mm_beta, **kwargs) + + # Priority 3: If profile is specified, use LLMFactory if profile: - return LLMFactory.get_llm(profile) - # Decide which backend to use + return LLMFactory.get_llm(profile=profile, mm_beta=mm_beta, **kwargs) + + # Priority 4: Use backend-specific instantiation (for AutoGen, CustomLLM, etc.) + # This path is for when neither profile nor model is specified name = backend or os.getenv("TRACE_DEFAULT_LLM_BACKEND", "LiteLLM") try: backend_cls = _LLM_REGISTRY[name] @@ -495,4 +1345,5 @@ def __new__(cls, *args, profile: str = None, backend: str = None, **kwargs): raise ValueError(f"Unknown LLM backend: {name}. 
" f"Valid options are: {list(_LLM_REGISTRY)}") # Instantiate and return the chosen subclass - return backend_cls(*args, **kwargs) \ No newline at end of file + kwargs['mm_beta'] = mm_beta + return backend_cls(**kwargs) \ No newline at end of file diff --git a/setup.py b/setup.py index dbd60be5..394d4046 100644 --- a/setup.py +++ b/setup.py @@ -11,9 +11,11 @@ install_requires = [ "graphviz>=0.20.1", "pytest", - "litellm==1.75.0", + "litellm==1.80.8", + "google-genai", "black", "scikit-learn", + "pillow", "tensorboardX", "tensorboard" ] diff --git a/tests/llm_optimizers_tests/test_optoprime_v3.py b/tests/llm_optimizers_tests/test_optoprime_v3.py new file mode 100644 index 00000000..f124c5ec --- /dev/null +++ b/tests/llm_optimizers_tests/test_optoprime_v3.py @@ -0,0 +1,510 @@ +import os +import pytest +from opto.trace import GRAPH +from opto.utils.llm import LLM + +from opto.trace import node, bundle +from opto.optimizers.optoprime_v3 import ( + OptoPrimeV3, OptimizerPromptSymbolSet2, ProblemInstance, + OptimizerPromptSymbolSet, value_to_image_content +) +from opto.utils.backbone import TextContent, ImageContent + +# You can override for temporarily testing a specific optimizer ALL_OPTIMIZERS = [TextGrad] # [OptoPrimeMulti] ALL_OPTIMIZERS = [OptoPrime] + +# Tests that issue real LLM calls are opt-in: set RUN_LIVE_LLM_TESTS=1 to run +# them. CI runs against a text-only stub that cannot satisfy the multimodal +# optimizer steps, so they are skipped there. 
+SKIP_REASON = "Live LLM test; set RUN_LIVE_LLM_TESTS=1 to run" +HAS_CREDENTIALS = os.environ.get("RUN_LIVE_LLM_TESTS") == "1" +llm = LLM() + + +@pytest.fixture(autouse=True) +def clear_graph(): + """Reset the graph before each test""" + GRAPH.clear() + yield + GRAPH.clear() + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +def test_response_extraction(): + pass + + +def test_tag_template_change(): + num_1 = node(1, trainable=True) + num_2 = node(2, trainable=True, description="<=5") + result = num_1 + num_2 + optimizer = OptoPrimeV3([num_1, num_2], use_json_object_format=False, + ignore_extraction_error=False, + include_example=True, + optimizer_prompt_symbol_set=OptimizerPromptSymbolSet2()) + + optimizer.zero_feedback() + optimizer.backward(result, 'make this number bigger') + + summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + # system_prompt is a string, user_prompt is a ContentBlockList + system_prompt = optimizer.replace_symbols(system_prompt, optimizer.prompt_symbols) + + # Convert ContentBlockList to text for symbol replacement + user_prompt_text = "".join(block.text for block in user_prompt if isinstance(block, TextContent)) + user_prompt_text = optimizer.replace_symbols(user_prompt_text, optimizer.prompt_symbols) + + assert """""" in system_prompt, "Expected tag to be present in system_prompt" + assert """""" in user_prompt_text, "Expected tag to be present in user_prompt" + + print(system_prompt) + print(user_prompt_text) + + +@bundle() +def transform(num): + """Add number""" + return num + 1 + + +@bundle(trainable=True) +def multiply(num): + return num * 5 + + +def test_function_repr(): + num_1 = node(1, trainable=False) + + result = multiply(transform(num_1)) + optimizer = OptoPrimeV3([multiply.parameter], use_json_object_format=False, + ignore_extraction_error=False, + include_example=True) + + optimizer.zero_feedback() + optimizer.backward(result, 'make this number bigger') + + 
summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + system_prompt = optimizer.replace_symbols(system_prompt, optimizer.prompt_symbols) + # Convert ContentBlockList to text for symbol replacement + user_prompt_text = "".join(block.text for block in user_prompt if isinstance(block, TextContent)) + user_prompt_text = optimizer.replace_symbols(user_prompt_text, optimizer.prompt_symbols) + + function_repr = """ + +def multiply(num): + return num * 5 + + +The code should start with: +def multiply(num): + +""" + + assert function_repr in user_prompt_text, "Expected function representation to be present in user_prompt" + +def test_big_data_truncation(): + num_1 = node("**2", trainable=True) + + list_1 = node("12345691912338" * 10, trainable=False) + + result = list_1 + num_1 + + optimizer = OptoPrimeV3([num_1], use_json_object_format=False, + ignore_extraction_error=False, + include_example=True, initial_var_char_limit=10) + + optimizer.zero_feedback() + optimizer.backward(result, 'compute the expression') + + summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + system_prompt = optimizer.replace_symbols(system_prompt, optimizer.prompt_symbols) + # Convert ContentBlockList to text for symbol replacement + user_prompt_text = "".join(block.text for block in user_prompt if isinstance(block, TextContent)) + user_prompt_text = optimizer.replace_symbols(user_prompt_text, optimizer.prompt_symbols) + + truncated_repr = """1234569191...(skipped due to length limit)""" + + assert truncated_repr in user_prompt_text, "Expected truncated list representation to be present in user_prompt" + +def test_extraction_pipeline(): + num_1 = node(1, trainable=True) + num_2 = node(2, trainable=True, description="<=5") + result = num_1 + num_2 + optimizer = OptoPrimeV3([num_1, num_2], use_json_object_format=False, + ignore_extraction_error=False, + include_example=True, + 
optimizer_prompt_symbol_set=OptimizerPromptSymbolSet2()) + + optimizer.zero_feedback() + optimizer.backward(result, 'make this number bigger') + + summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + # Verify construct_prompt returns expected types + assert isinstance(system_prompt, str) + assert isinstance(user_prompt, list) + + # Test extraction from a mock response + response = """ +The instruction suggests that the output, `add0`, needs to be made bigger than it currently is (3). The code performs an addition of `int0` and `int1` to produce `add0`. To increase `add0`, we can increase the values of `int0` or `int1`, or both. Given that `int1` has a constraint of being less than or equal to 5, we can set `int0` to a higher value, since it has no explicit constraint. By adjusting `int0` to a higher value, the output can be made larger in accordance with the feedback. + + + +int0 + +5 + + + + +int1 + +5 + +""" + suggestion = optimizer.extract_llm_suggestion(response) + + assert 'reasoning' in suggestion, "Expected 'reasoning' in suggestion" + assert 'variables' in suggestion, "Expected 'variables' in suggestion" + assert 'int0' in suggestion['variables'], "Expected 'int0' variable in suggestion" + assert 'int1' in suggestion['variables'], "Expected 'int1' variable in suggestion" + assert suggestion['variables']['int0'] == '5', "Expected int0 to be incremented to 5" + assert suggestion['variables']['int1'] == '5', "Expected int1 to be incremented to 5" + + +# ==================== Multimodal / Content Block Tests ==================== + +def test_problem_instance_text_only(): + """Test that ProblemInstance with text-only content works correctly.""" + from opto.utils.backbone import ContentBlockList + symbol_set = OptimizerPromptSymbolSet() + + instance = ProblemInstance( + instruction="Test instruction", + code="y = add(x=a, y=b)", + documentation="[add] Adds two numbers", + variables=ContentBlockList("5"), + 
inputs=ContentBlockList("3"), + others=ContentBlockList(), + outputs=ContentBlockList("8"), + feedback="Result should be 10", + context="Some context", + optimizer_prompt_symbol_set=symbol_set + ) + + # Test __repr__ returns string + text_repr = str(instance) + assert "Test instruction" in text_repr + assert "y = add(x=a, y=b)" in text_repr + assert "Result should be 10" in text_repr + assert "Some context" in text_repr + + # Test to_content_blocks returns list + blocks = instance.to_content_blocks() + assert isinstance(blocks, list) + assert len(blocks) > 0 + assert all(isinstance(b, (TextContent, ImageContent)) for b in blocks) + + # Test has_images returns False for text-only + assert not instance.has_images() + + +def test_problem_instance_with_content_blocks(): + """Test ProblemInstance with ContentBlockList fields containing images.""" + from opto.utils.backbone import ContentBlockList + symbol_set = OptimizerPromptSymbolSet() + + # Create content blocks with an image + variables_blocks = ContentBlockList([ + TextContent(text=""), + ImageContent(image_url="https://example.com/test.jpg"), + TextContent(text="") + ]) + + instance = ProblemInstance( + instruction="Analyze the image", + code="result = analyze(img)", + documentation="[analyze] Analyzes an image", + variables=variables_blocks, + inputs=ContentBlockList(), + others=ContentBlockList(), + outputs=ContentBlockList("cat"), + feedback="Result should be 'dog'", + context=None, + optimizer_prompt_symbol_set=symbol_set + ) + + # Test __repr__ handles content blocks (should show [IMAGE] placeholder) + text_repr = str(instance) + assert "Analyze the image" in text_repr + assert "[IMAGE]" in text_repr + + # Test to_content_blocks includes the image + blocks = instance.to_content_blocks() + assert isinstance(blocks, list) + + # Find the ImageContent block + image_blocks = [b for b in blocks if isinstance(b, ImageContent)] + assert len(image_blocks) == 1 + assert image_blocks[0].image_url == 
"https://example.com/test.jpg" + + # Test has_images returns True + assert instance.has_images() + + +def test_problem_instance_mixed_content(): + """Test ProblemInstance with mixed text and image content in multiple fields.""" + from opto.utils.backbone import ContentBlockList + symbol_set = OptimizerPromptSymbolSet() + + # Variables with image + variables_blocks = ContentBlockList([ + TextContent(text="Hello\n"), + TextContent(text=""), + ImageContent(image_data="base64data", media_type="image/png"), + TextContent(text="") + ]) + + # Inputs with image + inputs_blocks = ContentBlockList([ + TextContent(text=""), + ImageContent(image_url="https://example.com/ref.png"), + TextContent(text="") + ]) + + instance = ProblemInstance( + instruction="Compare images", + code="result = compare(img, reference)", + documentation="[compare] Compares two images", + variables=variables_blocks, + inputs=inputs_blocks, + others=ContentBlockList(), + outputs=ContentBlockList("0.8"), + feedback="Similarity should be higher", + context="Context text", + optimizer_prompt_symbol_set=symbol_set + ) + + # Test has_images + assert instance.has_images() + + # Test to_content_blocks + blocks = instance.to_content_blocks() + image_blocks = [b for b in blocks if isinstance(b, ImageContent)] + assert len(image_blocks) == 2 # One from variables, one from inputs + + +def test_value_to_image_content_url(): + """Test value_to_image_content with URL strings.""" + # Valid image URL + result = value_to_image_content("https://example.com/image.jpg") + assert result is not None + assert isinstance(result, ImageContent) + assert result.image_url == "https://example.com/image.jpg" + + # Non-image URL (no image extension) - is_image returns False for pattern check + result = value_to_image_content("https://example.com/page.html") + assert result is None + + # Non-URL string + result = value_to_image_content("just a regular string") + assert result is None + + +def test_value_to_image_content_base64(): + 
"""Test value_to_image_content with base64 data URLs.""" + # Valid base64 data URL + data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" + result = value_to_image_content(data_url) + assert result is not None + assert isinstance(result, ImageContent) + assert result.image_data == "iVBORw0KGgoAAAANSUhEUg==" + assert result.media_type == "image/png" + + +def test_value_to_image_content_non_image(): + """Test value_to_image_content with non-image values.""" + # Integer + assert value_to_image_content(42) is None + + # List + assert value_to_image_content([1, 2, 3]) is None + + # Dict + assert value_to_image_content({"key": "value"}) is None + + # Regular string + assert value_to_image_content("hello world") is None + + +def test_construct_prompt(): + """Test construct_prompt returns ContentBlockList for multimodal support.""" + num_1 = node(1, trainable=True) + num_2 = node(2, trainable=True) + result = num_1 + num_2 + + optimizer = OptoPrimeV3([num_1, num_2], use_json_object_format=False) + optimizer.zero_feedback() + optimizer.backward(result, 'make this number bigger') + + summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + # system_prompt should be string, user_prompt should be ContentBlockList + assert isinstance(system_prompt, str) + assert isinstance(user_prompt, list) + assert all(isinstance(b, (TextContent, ImageContent)) for b in user_prompt) + + # Check that text content contains expected info + text_parts = [b.text for b in user_prompt if isinstance(b, TextContent)] + full_text = "".join(text_parts) + assert "int0" in full_text or "int1" in full_text + + +def test_repr_node_value_as_content_blocks(): + """Test repr_node_value_as_content_blocks method.""" + num_1 = node(1, trainable=True) + result = num_1 + 1 + + optimizer = OptoPrimeV3([num_1], use_json_object_format=False) + optimizer.zero_feedback() + optimizer.backward(result, 'test') + + # Test with non-image nodes + summary = 
optimizer.summarize() + blocks = optimizer.repr_node_value_as_content_blocks( + summary.variables, + node_tag=optimizer.optimizer_prompt_symbol_set.variable_tag, + value_tag=optimizer.optimizer_prompt_symbol_set.value_tag, + constraint_tag=optimizer.optimizer_prompt_symbol_set.constraint_tag + ) + + assert isinstance(blocks, list) + assert len(blocks) > 0 + assert all(isinstance(b, TextContent) for b in blocks) # No images in this case + + +def test_repr_node_value_compact_as_content_blocks(): + """Test repr_node_value_compact_as_content_blocks method.""" + long_string = "x" * 5000 # Long string that will be truncated + str_node = node(long_string, trainable=True) + result = str_node + "!" + + optimizer = OptoPrimeV3([str_node], use_json_object_format=False, initial_var_char_limit=100) + optimizer.zero_feedback() + optimizer.backward(result, 'test') + + summary = optimizer.summarize() + blocks = optimizer.repr_node_value_compact_as_content_blocks( + summary.inputs, + node_tag=optimizer.optimizer_prompt_symbol_set.node_tag, + value_tag=optimizer.optimizer_prompt_symbol_set.value_tag, + constraint_tag=optimizer.optimizer_prompt_symbol_set.constraint_tag + ) + + # Should be truncated + text_parts = [b.text for b in blocks if isinstance(b, TextContent)] + full_text = "".join(text_parts) + assert "skipped due to length limit" in full_text or len(full_text) < len(long_string) + + +# ==================== Real LLM Call Tests ==================== + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +def test_optimizer_step_real_llm_call(): + """Test a real optimization step with LLM call.""" + # Create a simple optimization problem + greeting = node("Hello", trainable=True, description="A greeting message") + + @bundle() + def make_sentence(word): + """Create a sentence from a word.""" + return f"{word}, how are you today?" 
+ + result = make_sentence(greeting) + + # Create optimizer + optimizer = OptoPrimeV3( + [greeting], + use_json_object_format=False, + ignore_extraction_error=True, + include_example=False, + ) + + # Setup feedback + optimizer.zero_feedback() + optimizer.backward(result, "The greeting should be more formal and professional") + + # Execute optimization step - this makes a real LLM call + update_dict = optimizer.step(verbose=True) + + # Verify the optimizer produced a suggestion + print(f"Update dict: {update_dict}") + + # The LLM should have suggested a new value + # We don't assert specific content since LLM output varies + # but we verify the step completed without error + assert optimizer.log is not None + assert len(optimizer.log) > 0 + + # Check that the log contains the expected structure + last_log = optimizer.log[-1] + assert "system_prompt" in last_log + assert "user_prompt" in last_log + assert "response" in last_log + + print(f"LLM Response: {last_log['response'][:500]}...") + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +def test_optimizer_step_with_content_blocks(): + """Test optimization step using content blocks (multimodal mode).""" + # Create trainable parameters + num_1 = node(5, trainable=True, description="A number to optimize") + num_2 = node(3, trainable=True, description="Another number") + + result = num_1 + num_2 + + # Create optimizer + optimizer = OptoPrimeV3( + [num_1, num_2], + use_json_object_format=False, + ignore_extraction_error=True, + include_example=False, + ) + + # Setup feedback + optimizer.zero_feedback() + optimizer.backward(result, "The sum should be exactly 100") + + # Test that construct_prompt returns ContentBlockList + summary = optimizer.summarize() + system_prompt, user_prompt = optimizer.construct_prompt(summary) + + # Verify content blocks structure + from opto.utils.backbone import ContentBlockList + assert isinstance(user_prompt, ContentBlockList) + assert len(user_prompt) > 0 + + # Verify text is 
merged (should be fewer blocks than if not merged) + text_blocks = [b for b in user_prompt if isinstance(b, TextContent)] + print(f"Number of text blocks after merging: {len(text_blocks)}") + + # Execute the step (this makes a real LLM call) + update_dict = optimizer.step(verbose=True) + + print(f"Update dict: {update_dict}") + + # Verify the step completed + assert optimizer.log is not None + assert len(optimizer.log) > 0 + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +def test_optimizer_multimodal_parameter_update(): + pass \ No newline at end of file diff --git a/tests/unit_tests/test_backbone.py b/tests/unit_tests/test_backbone.py new file mode 100644 index 00000000..ace28cbb --- /dev/null +++ b/tests/unit_tests/test_backbone.py @@ -0,0 +1,702 @@ +""" +Comprehensive tests for optimizer backbone components (Chat, UserTurn, AssistantTurn) +Tests include: truncation strategies, multimodal content, and conversation management + +We need to test a few things: +1. Various use cases of ContentBlock and specialized ones +2. UserTurn, AssistantTurn and conversation manager +3. Multi-modal use of conversation manager, including multi-turn and image as output +""" +import os +import base64 +import pytest +from opto.utils.backbone import ( + Chat, + UserTurn, + AssistantTurn +) + +# These tests make real LLM calls (some with image inputs) against specific +# models. They are opt-in: set RUN_LIVE_LLM_TESTS=1 to run them. CI runs against +# a text-only stub that cannot satisfy them, so they are skipped there. 
SKIP_REASON = "Live LLM test; set RUN_LIVE_LLM_TESTS=1 to run"
HAS_CREDENTIALS = os.environ.get("RUN_LIVE_LLM_TESTS") == "1"

# Public sample images shared by the multi-image tests below.
_FLOWER_IMAGE_URL = "https://images.pexels.com/photos/736230/pexels-photo-736230.jpeg"
_ROSE_IMAGE_URL = "https://images.contentstack.io/v3/assets/bltcedd8dbd5891265b/blt134818d279038650/6668df6434f6fb5cd48aac34/beautiful-flowers-rose.jpeg"


# ============================================================================
# Test Fixtures
# ============================================================================

def create_sample_conversation():
    """Create a sample conversation with four user/assistant rounds."""
    history = Chat(system_prompt="You are a helpful assistant.")

    rounds = [
        ("Hello, what's the weather?", "The weather is sunny today."),
        ("What about tomorrow?", "Tomorrow will be rainy."),
        ("Should I bring an umbrella?", "Yes, definitely bring an umbrella."),
        ("Thanks for the advice!", "You're welcome! Stay dry!"),
    ]
    for user_text, assistant_text in rounds:
        user = UserTurn().add_text(user_text)
        assistant = AssistantTurn().add_text(assistant_text)
        history.add_user_turn(user).add_assistant_turn(assistant)

    return history


def _any_text_block_contains(message, needle):
    """Return True if any content block of ``message`` has ``needle`` in its text.

    Message content is a list of dicts; blocks without a ``text`` key (e.g.
    images) are treated as empty text.
    """
    return any(needle in item.get("text", "") for item in message["content"])


# ============================================================================
# Truncation Tests
# ============================================================================

def test_default_all_history():
    """Test default behavior (n=-1) returns all history"""
    history = create_sample_conversation()

    messages = history.to_messages()

    # Should have: system + 8 turns (4 user + 4 assistant)
    assert len(messages) == 9  # 1 system + 8 messages
    assert messages[0]["role"] == "system"
    assert messages[0]["content"] == "You are a helpful assistant."
    assert messages[-1]["role"] == "assistant"


def test_truncate_from_start():
    """Test truncate_from_start strategy - keeps last N rounds"""
    history = create_sample_conversation()

    # Keep last 2 rounds (4 turns)
    messages = history.to_messages(n=2, truncate_strategy="from_start")

    # Should have: system + 2 rounds (4 turns)
    assert len(messages) == 5  # 1 system + 4 messages
    assert messages[0]["role"] == "system"

    # Should have the last 2 rounds (round 3 and round 4)
    # Round 3: user3 (umbrella question), assistant3 (umbrella answer)
    # Round 4: user4 (thanks), assistant4 (welcome)
    assert messages[1]["role"] == "user"
    assert "umbrella" in messages[1]["content"][0]["text"]
    assert messages[2]["role"] == "assistant"
    # Content is a list of dicts with type and text fields
    assert _any_text_block_contains(messages[2], "umbrella")
    assert messages[3]["role"] == "user"
    assert "Thanks" in messages[3]["content"][0]["text"]
    assert messages[4]["role"] == "assistant"
    assert _any_text_block_contains(messages[4], "welcome")


def test_truncate_from_end():
    """Test truncate_from_end strategy - keeps first N rounds"""
    history = create_sample_conversation()

    # Keep first 2 rounds (4 turns)
    messages = history.to_messages(n=2, truncate_strategy="from_end")

    # Should have: system + 2 rounds (4 turns)
    assert len(messages) == 5  # 1 system + 4 messages
    assert messages[0]["role"] == "system"

    # Should have the first 2 rounds (round 1 and round 2)
    # Round 1: user1 (weather), assistant1 (sunny)
    # Round 2: user2 (tomorrow), assistant2 (rainy)
    assert messages[1]["role"] == "user"
    assert "Hello" in messages[1]["content"][0]["text"]
    assert messages[2]["role"] == "assistant"
    # Content is a list of dicts with type and text fields
    assert _any_text_block_contains(messages[2], "sunny")
    assert messages[3]["role"] == "user"
    assert "tomorrow" in messages[3]["content"][0]["text"]
    assert messages[4]["role"] == "assistant"
    assert _any_text_block_contains(messages[4], "rainy")


def test_truncate_zero_turns():
    """Test truncating to 0 turns"""
    history = create_sample_conversation()

    messages = history.to_messages(n=0, truncate_strategy="from_start")

    # Should only have system message
    assert len(messages) == 1
    assert messages[0]["role"] == "system"


def test_truncate_more_than_available():
    """Test requesting more rounds than available"""
    history = create_sample_conversation()

    # Request 100 rounds but only 4 rounds (8 turns) exist
    messages = history.to_messages(n=100, truncate_strategy="from_start")

    # Should return all available
    assert len(messages) == 9  # 1 system + 8 messages


def test_empty_conversation():
    """Test truncation on empty conversation"""
    history = Chat(system_prompt="Test")

    messages = history.to_messages(n=5)

    assert len(messages) == 1  # Just system
    assert messages[0]["role"] == "system"


def test_to_litellm_format_with_truncation():
    """Test to_litellm_format() also supports truncation"""
    history = create_sample_conversation()

    # n=2 means 2 rounds (4 turns), from_end keeps first 2 rounds
    messages = history.to_litellm_format(n=2, truncate_strategy="from_end")

    # Should have: system + 2 rounds (4 turns)
    assert len(messages) == 5
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
    assert messages[2]["role"] == "assistant"
    assert messages[3]["role"] == "user"
    assert messages[4]["role"] == "assistant"


def test_invalid_strategy():
    """Test that invalid strategy raises error"""
    history = create_sample_conversation()

    with pytest.raises(ValueError, match="Unknown truncate_strategy"):
        history.to_messages(n=2, truncate_strategy="invalid_strategy")


def test_negative_n_values():
    """Test that n=-1 returns all history"""
    history = create_sample_conversation()

    # n=-1 should return all
    messages_all = history.to_messages(n=-1)
    assert len(messages_all) == 9

    # Verify it's the same as not passing n at all
    messages_default = history.to_messages()
    assert len(messages_all) == len(messages_default)


# ============================================================================
# Multimodal / Multi-Image Tests
# ============================================================================

def test_user_turn_multiple_images():
    """Test that a user turn can have multiple images"""
    history = Chat()

    # Create a user turn with text and multiple images (like the OpenAI example)
    user_turn = (UserTurn()
                 .add_text("What are in these images? Is there any difference between them?")
                 .add_image(url=_FLOWER_IMAGE_URL)
                 .add_image(url=_ROSE_IMAGE_URL))

    history.add_user_turn(user_turn)

    # Convert to LiteLLM format
    messages = history.to_litellm_format()

    # Should have 1 message
    assert len(messages) == 1

    user_msg = messages[0]
    assert user_msg["role"] == "user"

    # Content should be a list with 3 items: 1 text + 2 images
    assert len(user_msg["content"]) == 3

    # Check first item is text
    assert user_msg["content"][0]["type"] == "input_text"
    assert user_msg["content"][0]["text"] == "What are in these images? Is there any difference between them?"

    # Check second item is first image
    assert user_msg["content"][1]["type"] == "input_image"
    assert user_msg["content"][1]["image_url"] == _FLOWER_IMAGE_URL

    # Check third item is second image
    assert user_msg["content"][2]["type"] == "input_image"
    assert user_msg["content"][2]["image_url"] == _ROSE_IMAGE_URL


def test_assistant_turn_multiple_images():
    """Test that an assistant turn can also have multiple images (for models that generate images)"""
    history = Chat()

    # Assistant turn with text and multiple images
    assistant_turn = (AssistantTurn()
                      .add_text("Here are two generated images based on your request:")
                      .add_image(url="https://example.com/generated1.png")
                      .add_image(url="https://example.com/generated2.png"))

    history.add_assistant_turn(assistant_turn)

    # Convert to LiteLLM format
    messages = history.to_litellm_format()

    assert len(messages) == 1
    assert messages[0]["role"] == "assistant"

    # Assistant should have text content (in list format)
    assert _any_text_block_contains(messages[0], "Here are two generated images")


def test_mixed_content_types_in_turn():
    """Test mixing text, images, and other content types in a single turn"""
    history = Chat()

    # Create a complex turn with multiple content types
    user_turn = (UserTurn()
                 .add_text("Please analyze these images and this document:")
                 .add_image(url="https://example.com/chart1.png")
                 .add_image(url="https://example.com/chart2.png")
                 .add_text("What patterns do you see?"))

    history.add_user_turn(user_turn)

    messages = history.to_litellm_format()

    assert len(messages) == 1
    user_msg = messages[0]

    # Should have 4 content blocks: text, image, image, text
    assert len(user_msg["content"]) == 4
    assert user_msg["content"][0]["type"] == "input_text"
    assert user_msg["content"][1]["type"] == "input_image"
    assert user_msg["content"][2]["type"] == "input_image"
    assert user_msg["content"][3]["type"] == "input_text"


def test_multiple_images_with_base64():
    """Test multiple images using base64 encoding"""
    history = Chat()

    # Create fake base64 image data
    fake_image_data1 = base64.b64encode(b"fake image 1").decode('utf-8')
    fake_image_data2 = base64.b64encode(b"fake image 2").decode('utf-8')

    user_turn = (UserTurn()
                 .add_text("Compare these two images:")
                 .add_image(data=fake_image_data1, media_type="image/png")
                 .add_image(data=fake_image_data2, media_type="image/jpeg"))

    history.add_user_turn(user_turn)

    messages = history.to_litellm_format()

    assert len(messages) == 1
    user_msg = messages[0]

    # Should have 3 content blocks
    assert len(user_msg["content"]) == 3

    # Check base64 data URLs are properly formatted
    assert user_msg["content"][1]["type"] == "input_image"
    assert user_msg["content"][1]["image_url"].startswith("data:image/png;base64,")

    assert user_msg["content"][2]["type"] == "input_image"
    assert user_msg["content"][2]["image_url"].startswith("data:image/jpeg;base64,")


def test_conversation_with_multiple_multi_image_turns():
    """Test a full conversation where multiple turns each have multiple images"""
    history = Chat(system_prompt="You are a helpful image analysis assistant.")

    # User turn 1: Multiple images
    user1 = (UserTurn()
             .add_text("What's the difference between these flowers?")
             .add_image(url="https://example.com/rose.jpg")
             .add_image(url="https://example.com/tulip.jpg"))
    history.add_user_turn(user1)

    # Assistant response
    assistant1 = AssistantTurn().add_text("The first is a rose with layered petals, the second is a tulip with a cup shape.")
    history.add_assistant_turn(assistant1)

    # User turn 2: More images
    user2 = (UserTurn()
             .add_text("Now compare these landscapes:")
             .add_image(url="https://example.com/mountain.jpg")
             .add_image(url="https://example.com/beach.jpg")
             .add_image(url="https://example.com/forest.jpg"))
    history.add_user_turn(user2)

    messages = history.to_litellm_format()

    # Should have: system + user1 + assistant1 + user2
    assert len(messages) == 4

    # Check user1 has 3 content blocks (1 text + 2 images)
    assert len(messages[1]["content"]) == 3

    # Check user2 has 4 content blocks (1 text + 3 images)
    assert len(messages[3]["content"]) == 4


# ============================================================================
# Integration Tests - Truncation + Multimodal
# ============================================================================

def test_truncate_multimodal_conversation():
    """Test truncation works correctly with multimodal content"""
    history = Chat(system_prompt="You are a vision assistant.")

    # Add several turns with images (5 rounds = 10 turns)
    for i in range(5):
        user = (UserTurn()
                .add_text(f"Analyze image {i}")
                .add_image(url=f"https://example.com/image{i}.jpg"))
        assistant = AssistantTurn().add_text(f"Analysis of image {i}")
        history.add_user_turn(user).add_assistant_turn(assistant)

    # Truncate to last 2 rounds (4 turns)
    messages = history.to_messages(n=2, truncate_strategy="from_start")

    # Should have system + 2 rounds (4 turns)
    assert len(messages) == 5

    # Check that multimodal content is preserved
    assert len(messages[1]["content"]) == 2  # text + image
    assert messages[1]["content"][1]["type"] == "input_image"


# ============================================================================
# Real LLM Call Tests with Images
# ============================================================================

@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON)
def test_real_llm_call_with_multiple_images():
    """Test sending real images to GPT and getting a response.

    This test sends two flower images to GPT-4 Vision and asks it to compare them.
    """
    from opto.utils.llm import LLM

    # Create conversation with images
    history = Chat(system_prompt="You are a helpful assistant that can analyze images.")

    # Create a user turn with text and two real flower images
    user_turn = (UserTurn()
                 .add_text("What are in these images? Is there any difference between them? Please describe each image briefly.")
                 .add_image(url=_FLOWER_IMAGE_URL)
                 .add_image(url=_ROSE_IMAGE_URL))

    history.add_user_turn(user_turn)

    # Get messages in LiteLLM format
    messages = history.to_litellm_format()

    print("\n" + "=" * 80)
    print("REAL LLM CALL WITH MULTIPLE IMAGES")
    print("=" * 80)
    print(f"\nSending {len(user_turn.content)} content blocks (1 text + 2 images)...")

    # Make the LLM call with mm_beta=True for Response API format
    llm = LLM(mm_beta=True)
    response = llm(messages=messages, max_tokens=500)

    # response is an AssistantTurn object
    response_content = response.to_text()

    print("\n📷 User Query:")
    print(" What are in these images? Is there any difference between them?")
    print("\n🤖 GPT Response:")
    print("-" * 40)
    print(response_content)
    print("-" * 40)

    # Store assistant response in history
    history.add_assistant_turn(response)

    # Verify we got a meaningful response
    assert response_content is not None
    assert len(response_content) > 50  # Should have some substantial content

    # The response should mention something about flowers/images
    response_lower = response_content.lower()
    assert any(word in response_lower for word in ["flower", "image", "picture", "rose", "pink", "red", "petal"]), \
        f"Response doesn't seem to describe the flower images: {response_content[:200]}..."

    print("\n✅ Successfully received and validated GPT response about the images!")


def _run_flower_follow_up_conversation():
    """Run and verify the two-turn flower conversation against a live LLM.

    Turn 1 sends two flower images with a question; turn 2 asks a text-only
    follow-up that can only be answered using the context of turn 1. Shared by
    the two multi-turn tests below, which were previously verbatim copies of
    the same steps.
    """
    from opto.utils.llm import LLM

    history = Chat(system_prompt="You are a helpful assistant that can analyze images.")
    llm = LLM(mm_beta=True)

    print("\n" + "=" * 80)
    print("MULTI-TURN CONVERSATION WITH IMAGES")
    print("=" * 80)

    # Turn 1: Send images and ask about them
    user_turn1 = (UserTurn()
                  .add_text("What type of flowers are shown in these images?")
                  .add_image(url=_FLOWER_IMAGE_URL)
                  .add_image(url=_ROSE_IMAGE_URL))

    history.add_user_turn(user_turn1)
    messages = history.to_litellm_format()

    print("\n📷 Turn 1 - User:")
    print(" What type of flowers are shown in these images? [+ 2 images]")

    # With mm_beta=True the call returns an AssistantTurn that can be stored
    # in the history directly.
    first_turn = llm(messages=messages, max_tokens=300)
    first_text = first_turn.to_text()

    print("\n🤖 Turn 1 - Assistant:")
    print(f" {first_text[:200]}...")

    history.add_assistant_turn(first_turn)

    # Turn 2: Follow-up question (no new images, but context from previous turn)
    user_turn2 = UserTurn().add_text("Which of these flowers would be better for a romantic gift and why?")
    history.add_user_turn(user_turn2)

    messages = history.to_litellm_format()

    print("\n📷 Turn 2 - User:")
    print(" Which of these flowers would be better for a romantic gift and why?")

    second_turn = llm(messages=messages, max_tokens=300)
    second_text = second_turn.to_text()

    print("\n🤖 Turn 2 - Assistant:")
    print(f" {second_text[:200]}...")

    # Verify responses
    assert first_text is not None and len(first_text) > 20
    assert second_text is not None and len(second_text) > 20

    # Turn 2 should reference the context from turn 1
    second_lower = second_text.lower()
    assert any(word in second_lower for word in ["flower", "rose", "romantic", "gift", "love"]), \
        "Turn 2 response doesn't seem to reference the flower context"

    print("\n✅ Multi-turn conversation with images completed successfully!")


@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON)
def test_real_llm_multi_turn_with_images():
    """Test a multi-turn conversation with images.

    First turn: Ask about images
    Second turn: Follow-up question about the same images
    """
    _run_flower_follow_up_conversation()


@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON)
def test_real_llm_multi_turn_with_images_updated_assistant_turn():
    """Test a multi-turn conversation with images, storing the returned AssistantTurn.

    First turn: Ask about images
    Second turn: Follow-up question about the same images
    """
    _run_flower_follow_up_conversation()


@pytest.mark.skipif(not os.environ.get("GEMINI_API_KEY"), reason="No GEMINI_API_KEY found")
def test_real_google_genai_multi_turn_with_images_updated():
    """Test multi-turn conversation with images using Google Gemini image generation model"""
    from opto.utils.llm import LLM

    print("\n" + "=" * 80)
    print("Testing Multi-turn Conversation with Gemini Image Generation")
    print("=" * 80)

    # Initialize conversation history
    history = Chat()
    history.system_prompt = "You are a helpful assistant that can generate and discuss images."

    # Use a Gemini model that supports image generation
    model = "gemini-2.5-flash-image"
    llm = LLM(model=model, mm_beta=True)

    print("=" * 80)

    # Turn 1: Ask to generate an image
    user_turn1 = UserTurn().add_text("Generate an image of a serene mountain landscape at sunrise with a lake in the foreground.")

    history.add_user_turn(user_turn1)

    print("\n📷 Turn 1 - User:")
    print(" Generate an image of a serene mountain landscape at sunrise with a lake in the foreground.")

    # For image generation models, pass the prompt directly instead of messages
    prompt = user_turn1.content.to_text()
    response1 = llm(prompt=prompt, max_tokens=300)
    at = AssistantTurn(response1)

    print("\n🤖 Turn 1 - Assistant:")
    print(f" {at.to_text()[:200] if at.to_text() else '[Image generated]'}...")

    history.add_assistant_turn(at)

    # Turn 2: Follow-up question about the generated image
    user_turn2 = UserTurn().add_text("Can you describe the colors and mood of the image you just generated?")
    history.add_user_turn(user_turn2)

    messages = history.to_gemini_format()

    print("\n📷 Turn 2 - User:")
    print(" Can you describe the colors and mood of the image you just generated?")

    response2 = llm(messages=messages, max_tokens=300)
    at2 = AssistantTurn(response2)
    response2_content = at2.to_text()

    print("\n🤖 Turn 2 - Assistant:")
    print(f" {response2_content[:200]}...")

    # Verify responses
    assert at.content is not None and len(at.content) > 0
    assert response2_content is not None and len(response2_content) > 20

    # Turn 2 should reference the context from turn 1
    response2_lower = response2_content.lower()
    assert any(word in response2_lower for word in ["mountain", "sunrise", "lake", "color", "mood", "landscape"]), \
        "Turn 2 response doesn't seem to reference the image generation context"

    print("\n✅ Multi-turn conversation with Gemini image generation completed successfully!")


# ==== Testing the Automatic Raw Response Parsing into AssistantTurn ===
@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON)
def test_automatic_openai_raw_response_parsing_into_assistant_turn():
    """Test that a raw OpenAI Responses-API text reply is parsed into an AssistantTurn."""
    import litellm

    # Simple OpenAI text generation
    response = litellm.responses(
        model="openai/gpt-4o",
        input="Hello, how are you?"
    )
    assistant_turn = AssistantTurn(response)

    # The wording of a live model reply is not deterministic, so check that
    # text was parsed into the first block rather than matching a literal word.
    assert len(assistant_turn.content) > 0, "Parsed AssistantTurn has no content blocks"
    assert assistant_turn.content[0].text.strip(), "First content block has no text"

    print(assistant_turn)


@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON)
def test_automatic_openai_multimodal_raw_response_parsing_into_assistant_turn():
    """Test that a raw OpenAI image-generation reply is parsed into an AssistantTurn."""
    import litellm

    # OpenAI models require tools parameter for image generation
    response = litellm.responses(
        model="openai/gpt-4o",
        input="Generate a futuristic city at sunset and describe it in a sentence.",
        tools=[{"type": "image_generation"}]
    )

    assistant_turn = AssistantTurn(response)
    print(assistant_turn)

    # Without an assertion this test could only fail on an exception.
    assert assistant_turn.content is not None and len(assistant_turn.content) > 0, \
        "Parsed AssistantTurn has no content blocks"


@pytest.mark.skipif(not os.environ.get("GEMINI_API_KEY"), reason="No GEMINI_API_KEY found")
def test_automatic_google_generate_content_raw_response_parsing_into_assistant_turn():
    """Test that a raw google-genai generate_content reply is parsed into an AssistantTurn."""
    from google import genai

    client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

    response = client.models.generate_content(
        model="gemini-2.5-flash-image",
        contents="A kawaii-style sticker of a happy red panda wearing a tiny bamboo hat. It's munching on a green bamboo leaf. The design features bold, clean outlines, simple cel-shading, and a vibrant color palette. The background must be white.",
    )

    assistant_turn = AssistantTurn(response)
    print(assistant_turn)

    # Guard the index below so a short reply fails with a readable message
    # instead of an IndexError.
    assert len(assistant_turn.content) > 1, \
        f"Expected at least 2 content blocks, got {len(assistant_turn.content)}"
    assert not assistant_turn.content[1].is_empty()


if __name__ == '__main__':
    # Manual demo: generate an image with Gemini through LiteLLM and save
    # every returned image to the current directory.
    import litellm

    # Gemini image generation models don't require tools parameter
    response = litellm.responses(
        model="gemini/gemini-2.5-flash-image",
        input="Generate a cute cat playing with yarn"
    )

    # Access generated images from output
    for item in response.output:
        if item.type == "image_generation_call":
            # item.result contains pure base64 (no data: prefix)
            image_bytes = base64.b64decode(item.result)

            # Save the image and report the file that was actually written
            output_path = f"generated_{item.id}.png"
            with open(output_path, "wb") as f:
                f.write(image_bytes)

            print(f"Image saved: {output_path}")


# diff --git a/tests/unit_tests/test_llm.py b/tests/unit_tests/test_llm.py
# index 9435bf33..70a0ed80 100644
# --- a/tests/unit_tests/test_llm.py
# +++ b/tests/unit_tests/test_llm.py
# @@ -1,8 +1,24 @@
from opto.utils.llm import LLM, LLMFactory
from opto.optimizers.utils import print_color
import os
import pytest
from opto.utils.backbone import (
    Chat,
    UserTurn,
    AssistantTurn
)

# These tests hit a real LLM provider with specific models (e.g. gpt-4o-mini)
# and multimodal inputs. They are opt-in: set RUN_LIVE_LLM_TESTS=1 to run them.
# CI runs against a text-only stub that cannot satisfy these requirements, so by
# default they are skipped there.
+SKIP_REASON = "Live LLM test; set RUN_LIVE_LLM_TESTS=1 to run" +HAS_CREDENTIALS = os.environ.get("RUN_LIVE_LLM_TESTS") == "1" + + def test_llm_init(): + """Test basic LLM initialization with legacy mode (mm_beta=False)""" if os.path.exists("OAI_CONFIG_LIST") or os.environ.get("TRACE_LITELLM_MODEL") or os.environ.get("OPENAI_API_KEY"): llm = LLM() system_prompt = 'You are a helpful assistant.' @@ -22,3 +38,432 @@ def test_llm_init(): print_color(f'System: {system_prompt}', 'red') print_color(f'User: {user_prompt}', 'blue') print_color(f'LLM: {response}', 'green') + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +class TestLLMMMBetaMode: + """Test suite for LLM class with mm_beta=True and mm_beta=False modes""" + + def test_mm_beta_false_legacy_response_format(self): + """Test that mm_beta=False returns raw API response (legacy format)""" + llm = LLM(mm_beta=False) + messages = [{"role": "user", "content": "Say 'test' and nothing else."}] + + response = llm(messages=messages) + + # Legacy mode should return raw API response with .choices attribute + assert hasattr(response, 'choices'), "Legacy mode should return raw API response" + assert hasattr(response.choices[0], 'message'), "Response should have message attribute" + assert hasattr(response.choices[0].message, 'content'), "Message should have content attribute" + + # Should NOT be an AssistantTurn object + assert not isinstance(response, AssistantTurn), "Legacy mode should not return AssistantTurn" + + content = response.choices[0].message.content + assert isinstance(content, str), "Content should be a string" + assert len(content) > 0, "Content should not be empty" + + print_color(f"āœ“ Legacy mode (mm_beta=False) returns raw API response", 'green') + + def test_mm_beta_true_assistant_turn_response(self): + """Test that mm_beta=True returns AssistantTurn object""" + llm = LLM(mm_beta=True) + messages = [{"role": "user", "content": "Say 'test' and nothing else."}] + + response = 
llm(messages=messages) + + # mm_beta mode should return AssistantTurn object + assert isinstance(response, AssistantTurn), "mm_beta mode should return AssistantTurn object" + + # Check AssistantTurn attributes + assert hasattr(response, 'content'), "AssistantTurn should have content attribute" + assert hasattr(response, 'role'), "AssistantTurn should have role attribute" + assert response.role == "assistant", "Role should be 'assistant'" + + # Content should be accessible + assert response.content is not None, "Content should not be None" + + print_color(f"āœ“ Multimodal mode (mm_beta=True) returns AssistantTurn object", 'green') + + def test_mm_beta_with_explicit_model(self): + """Test mm_beta parameter works with explicit model specification""" + # Test with mm_beta=False + llm_legacy = LLM(model="gpt-4o-mini", mm_beta=False) + messages = [{"role": "user", "content": "Hi"}] + + response_legacy = llm_legacy(messages=messages) + assert hasattr(response_legacy, 'choices'), "Should return raw API response" + assert not isinstance(response_legacy, AssistantTurn), "Should not be AssistantTurn" + + # Test with mm_beta=True + llm_mm = LLM(model="gpt-4o-mini", mm_beta=True) + response_mm = llm_mm(messages=messages) + assert isinstance(response_mm, AssistantTurn), "Should return AssistantTurn" + + print_color(f"āœ“ mm_beta parameter works correctly with explicit model", 'green') + + def test_mm_beta_with_profile(self): + """Test mm_beta parameter works with profile-based instantiation""" + # Create a test profile + LLMFactory.create_profile("test_profile", backend="LiteLLM", model="gpt-4o-mini", temperature=0.7) + + # Test with mm_beta=False + llm_legacy = LLM(profile="test_profile", mm_beta=False) + messages = [{"role": "user", "content": "Hi"}] + + response_legacy = llm_legacy(messages=messages) + assert hasattr(response_legacy, 'choices'), "Profile with mm_beta=False should return raw API response" + + # Test with mm_beta=True + llm_mm = LLM(profile="test_profile", 
mm_beta=True) + response_mm = llm_mm(messages=messages) + assert isinstance(response_mm, AssistantTurn), "Profile with mm_beta=True should return AssistantTurn" + + print_color(f"āœ“ mm_beta parameter works correctly with profiles", 'green') + + def test_mm_beta_with_litellm_parameters(self): + """Test mm_beta works with various LiteLLM parameters""" + # Test with temperature and max_tokens + llm = LLM( + model="gpt-4o-mini", + mm_beta=True, + temperature=0.3, + max_tokens=100 + ) + + messages = [{"role": "user", "content": "Say hello"}] + response = llm(messages=messages) + + assert isinstance(response, AssistantTurn), "Should return AssistantTurn with LiteLLM params" + assert response.content is not None, "Should have content" + + print_color(f"āœ“ mm_beta works with LiteLLM parameters", 'green') + + def test_mm_beta_default_is_false(self): + """Test that mm_beta defaults to False for backward compatibility""" + llm = LLM() # No mm_beta specified + messages = [{"role": "user", "content": "Hi"}] + + response = llm(messages=messages) + + # Default should be legacy mode (mm_beta=False) + assert hasattr(response, 'choices'), "Default should be legacy mode" + assert not isinstance(response, AssistantTurn), "Default should not return AssistantTurn" + + print_color(f"āœ“ mm_beta defaults to False (backward compatible)", 'green') + + def test_mm_beta_content_accessibility(self): + """Test that content is accessible in both modes""" + messages = [{"role": "user", "content": "Say 'hello'"}] + + # Legacy mode + llm_legacy = LLM(mm_beta=False) + response_legacy = llm_legacy(messages=messages) + content_legacy = response_legacy.choices[0].message.content + assert isinstance(content_legacy, str), "Legacy content should be string" + assert len(content_legacy) > 0, "Legacy content should not be empty" + + # mm_beta mode + llm_mm = LLM(mm_beta=True) + response_mm = llm_mm(messages=messages) + # AssistantTurn content is a list of ContentBlock objects + assert response_mm.content 
is not None, "mm_beta content should not be None" + + print_color(f"āœ“ Content accessible in both modes", 'green') + + def test_mm_beta_with_different_backends(self): + """Test mm_beta parameter with different backend specifications""" + # Test with explicit LiteLLM backend + llm = LLM(backend="LiteLLM", model="gpt-4o-mini", mm_beta=True) + messages = [{"role": "user", "content": "Hi"}] + + response = llm(messages=messages) + assert isinstance(response, AssistantTurn), "LiteLLM backend with mm_beta=True should return AssistantTurn" + + print_color(f"āœ“ mm_beta works with explicit backend specification", 'green') + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +class TestLLMConstructorPriorities: + """Test the priority logic in LLM constructor""" + + def test_priority_profile_over_default(self): + """Test that profile parameter takes priority""" + LLMFactory.create_profile("priority_test", backend="LiteLLM", model="gpt-4o-mini", temperature=0.5) + + llm = LLM(profile="priority_test", mm_beta=True) + messages = [{"role": "user", "content": "Hi"}] + + response = llm(messages=messages) + assert isinstance(response, AssistantTurn), "Profile-based LLM should respect mm_beta" + + print_color(f"āœ“ Profile parameter takes priority", 'green') + + def test_priority_model_over_profile(self): + """Test that model parameter takes priority over default profile""" + # When model is specified, it should use that model regardless of default profile + llm = LLM(model="gpt-4o-mini", mm_beta=True) + messages = [{"role": "user", "content": "Hi"}] + + response = llm(messages=messages) + assert isinstance(response, AssistantTurn), "Model-based LLM should respect mm_beta" + + print_color(f"āœ“ Model parameter creates correct LLM instance", 'green') + + def test_backend_fallback(self): + """Test that backend parameter works when neither profile nor model specified""" + # This tests the Priority 3 path in __new__ + llm = LLM(backend="LiteLLM", mm_beta=True, 
model="gpt-4o-mini") + messages = [{"role": "user", "content": "Hi"}] + + response = llm(messages=messages) + assert isinstance(response, AssistantTurn), "Backend-based LLM should respect mm_beta" + + print_color(f"āœ“ Backend parameter works correctly", 'green') + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +class TestLLMDocumentationExamples: + """Test examples from LLM class documentation""" + + def test_basic_usage_default_model(self): + """Test: llm = LLM()""" + llm = LLM() + messages = [{"role": "user", "content": "Hi"}] + response = llm(messages=messages) + + # Default is mm_beta=False + assert hasattr(response, 'choices'), "Default usage should return raw API response" + print_color(f"āœ“ Basic usage with default model works", 'green') + + def test_specify_model_directly(self): + """Test: llm = LLM(model='gpt-4o')""" + llm = LLM(model="gpt-4o-mini") # Using mini for cost efficiency + messages = [{"role": "user", "content": "Hi"}] + response = llm(messages=messages) + + assert hasattr(response, 'choices'), "Model specification should work" + print_color(f"āœ“ Model specification works", 'green') + + def test_multimodal_beta_mode_example(self): + """Test example from 'Using Multimodal Beta Mode' section""" + # Enable mm_beta for rich AssistantTurn responses + llm = LLM(model="gpt-4o-mini", mm_beta=True) + response = llm(messages=[{"role": "user", "content": "Hello"}]) + + # response is now an AssistantTurn object with .content, .tool_calls, etc. 
+ assert isinstance(response, AssistantTurn), "Should return AssistantTurn" + assert hasattr(response, 'content'), "Should have content attribute" + assert hasattr(response, 'tool_calls'), "Should have tool_calls attribute" + + print_color(f"āœ“ Multimodal beta mode example works as documented", 'green') + + def test_legacy_mode_example(self): + """Test example from 'Legacy mode' section""" + # Legacy mode (default, mm_beta=False) + llm = LLM(model="gpt-4o-mini") + response = llm(messages=[{"role": "user", "content": "Hello"}]) + + # response is raw API response: response.choices[0].message.content + assert hasattr(response, 'choices'), "Should return raw API response" + content = response.choices[0].message.content + assert isinstance(content, str), "Content should be string" + + print_color(f"āœ“ Legacy mode example works as documented", 'green') + + def test_litellm_parameters_example(self): + """Test examples with LiteLLM parameters""" + # High creativity example + llm = LLM( + model="gpt-4o-mini", + temperature=0.9, + top_p=0.95, + presence_penalty=0.6 + ) + messages = [{"role": "user", "content": "Hi"}] + response = llm(messages=messages) + + assert hasattr(response, 'choices'), "LiteLLM parameters should work" + + print_color(f"āœ“ LiteLLM parameters example works", 'green') + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +def test_mm_beta_integration_with_conversation(): + """Test mm_beta mode with a multi-turn conversation""" + llm = LLM(model="gpt-4o-mini", mm_beta=True) + + # First turn + messages = [ + {"role": "user", "content": "My name is Alice."} + ] + response1 = llm(messages=messages) + assert isinstance(response1, AssistantTurn), "First response should be AssistantTurn" + + # Second turn - reference previous context + messages.append({"role": "assistant", "content": str(response1.content)}) + messages.append({"role": "user", "content": "What is my name?"}) + + response2 = llm(messages=messages) + assert isinstance(response2, 
AssistantTurn), "Second response should be AssistantTurn" + + print_color(f"āœ“ mm_beta mode works with multi-turn conversations", 'green') + + +@pytest.mark.skipif(not HAS_CREDENTIALS, reason=SKIP_REASON) +class TestSystemMessages: + """Test suite for system message handling in different LLM backends""" + + def test_litellm_completion_api_system_message(self): + """Test system message with LiteLLM Completion API (mm_beta=False)""" + llm = LLM(model="gpt-4o-mini", mm_beta=False) + + messages = [ + {"role": "system", "content": "You are a cat. Your name is Neko. Always respond as a cat would."}, + {"role": "user", "content": "What is your name?"} + ] + + response = llm(messages=messages) + + # Legacy mode should return raw API response + assert hasattr(response, 'choices'), "Should return raw API response" + content = response.choices[0].message.content + assert isinstance(content, str), "Content should be a string" + assert len(content) > 0, "Content should not be empty" + + # Check that the response reflects the system message (should mention being a cat or Neko) + content_lower = content.lower() + assert 'neko' in content_lower or 'cat' in content_lower, \ + f"Response should reflect system message about being a cat named Neko. Got: {content}" + + print_color(f"āœ“ LiteLLM Completion API handles system messages correctly", 'green') + + def test_litellm_responses_api_system_message(self): + """Test system message with LiteLLM Responses API (mm_beta=True)""" + llm = LLM(model="gpt-4o-mini", mm_beta=True) + + messages = [ + {"role": "system", "content": "You are a helpful math tutor. 
Always explain concepts clearly."}, + {"role": "user", "content": "What is 2+2?"} + ] + + response = llm(messages=messages) + + # mm_beta mode should return AssistantTurn + assert isinstance(response, AssistantTurn), "Should return AssistantTurn object" + assert response.content is not None, "Content should not be None" + + # Get text content + text_content = response.to_text() + assert isinstance(text_content, str), "Text content should be a string" + assert len(text_content) > 0, "Text content should not be empty" + assert '4' in text_content, f"Response should contain the answer '4'. Got: {text_content}" + + print_color(f"āœ“ LiteLLM Responses API handles system messages correctly", 'green') + + @pytest.mark.skipif(not os.environ.get("GEMINI_API_KEY"), reason="No Gemini API key found") + def test_gemini_system_instruction_legacy_mode(self): + """Test system_instruction with Gemini API in legacy mode (mm_beta=False)""" + llm = LLM(backend="GoogleGenAI", model="gemini-2.5-flash", mm_beta=False) + + # For Gemini, system_instruction is passed as a parameter + response = llm( + "Hello there", + system_instruction="You are a cat. Your name is Neko. Always respond as a cat would." + ) + + # Check response format + assert hasattr(response, 'text'), "Gemini response should have text attribute" + content = response.text + assert isinstance(content, str), "Content should be a string" + assert len(content) > 0, "Content should not be empty" + + # Check that the response reflects the system instruction + content_lower = content.lower() + assert 'neko' in content_lower or 'cat' in content_lower or 'meow' in content_lower, \ + f"Response should reflect system instruction about being a cat named Neko. 
Got: {content}" + + print_color(f"āœ“ Gemini API handles system_instruction correctly (legacy mode)", 'green') + + @pytest.mark.skipif(not os.environ.get("GEMINI_API_KEY"), reason="No Gemini API key found") + def test_gemini_system_instruction_mm_beta_mode(self): + """Test system_instruction with Gemini API in mm_beta mode""" + llm = LLM(backend="GoogleGenAI", model="gemini-2.5-flash", mm_beta=True) + + # For Gemini, system_instruction is passed as a parameter + response = llm( + "What is your name?", + system_instruction="You are a helpful assistant named Claude. Always introduce yourself." + ) + + # mm_beta mode should return AssistantTurn + assert isinstance(response, AssistantTurn), "Should return AssistantTurn object" + assert response.content is not None, "Content should not be None" + + # Get text content + text_content = response.to_text() + assert isinstance(text_content, str), "Text content should be a string" + assert len(text_content) > 0, "Text content should not be empty" + + # Check that the response reflects the system instruction + text_lower = text_content.lower() + assert 'claude' in text_lower or 'assistant' in text_lower, \ + f"Response should reflect system instruction about being Claude. Got: {text_content}" + + print_color(f"āœ“ Gemini API handles system_instruction correctly (mm_beta mode)", 'green') + + def test_litellm_system_message_with_conversation(self): + """Test system message persists across multi-turn conversation""" + llm = LLM(model="gpt-4o-mini", mm_beta=True) + + # First turn with system message + messages = [ + {"role": "system", "content": "You are a pirate. 
Always talk like a pirate."}, + {"role": "user", "content": "Hello"} + ] + + response1 = llm(messages=messages) + assert isinstance(response1, AssistantTurn), "First response should be AssistantTurn" + text1 = response1.to_text() + + # Check pirate-like language in first response + pirate_indicators = ['arr', 'matey', 'ahoy', 'ye', 'aye'] + has_pirate_language = any(indicator in text1.lower() for indicator in pirate_indicators) + assert has_pirate_language, f"First response should use pirate language. Got: {text1}" + + # Second turn - system message should still apply + messages.append({"role": "assistant", "content": text1}) + messages.append({"role": "user", "content": "What's the weather like?"}) + + response2 = llm(messages=messages) + assert isinstance(response2, AssistantTurn), "Second response should be AssistantTurn" + text2 = response2.to_text() + + # Check pirate-like language persists + has_pirate_language_2 = any(indicator in text2.lower() for indicator in pirate_indicators) + assert has_pirate_language_2, f"Second response should still use pirate language. Got: {text2}" + + print_color(f"āœ“ System message persists across conversation turns", 'green') + + @pytest.mark.skipif(not os.environ.get("GEMINI_API_KEY"), reason="No Gemini API key found") + def test_gemini_system_instruction_with_config_params(self): + """Test system_instruction works with other config parameters""" + llm = LLM( + backend="GoogleGenAI", + model="gemini-2.5-flash", + mm_beta=True, + temperature=0.7, + max_output_tokens=100 + ) + + response = llm( + "Tell me a short joke", + system_instruction="You are a comedian who tells very short jokes." + ) + + assert isinstance(response, AssistantTurn), "Should return AssistantTurn object" + text_content = response.to_text() + assert len(text_content) > 0, "Should have content" + + print_color(f"āœ“ Gemini system_instruction works with other config parameters", 'green') +