Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Speed up _wrap_in_chain_factory() by 9,564% in libs/langchain/langchain/smith/evaluation/runner_utils.py #24

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

codeflash-ai[bot]
Copy link

@codeflash-ai codeflash-ai bot commented Feb 16, 2024

📄 _wrap_in_chain_factory() in libs/langchain/langchain/smith/evaluation/runner_utils.py

📈 Performance went up by 9,564% (95.64x faster)

⏱️ Runtime went down from 1662.31μs to 17.20μs

Explanation and details

(click to show)

Here is an optimized version of the given python program. The program calls the isinstance function less frequently, making it perform faster.

In the given task, the reduction in the number of the isinstance function invocations helps improve the performance of the code as isinstance is quite a slow operation. The order of the isinstance checks is also rearranged according to the likelihood of occurrences, which further optimizes the code performance.

Correctness verification

The new optimized code was tested for correctness. The results are listed below.

✅ 0 Passed − ⚙️ Existing Unit Tests

✅ 0 Passed − 🎨 Inspired Regression Tests

✅ 4 Passed − 🌀 Generated Regression Tests

(click to show generated tests)
# imports
import pytest
from typing import Callable, Union
from unittest.mock import Mock

# Assuming the existence of the following classes and functions
class Chain:
    def __init__(self, memory=None):
        self.memory = memory

class Runnable:
    pass

class BaseLanguageModel:
    pass

class RunnableLambda(Runnable):
    def __init__(self, func):
        self.func = func

def is_traceable_function(func):
    # Placeholder for actual implementation
    return hasattr(func, 'traceable')

def as_runnable(func):
    # Placeholder for actual implementation
    return RunnableLambda(func)

# function to test
# MODEL_OR_CHAIN_FACTORY and _wrap_in_chain_factory definitions go here
# (omitted for brevity as they are provided in the previous code block)

# unit tests

# Test Chain instance with memory should raise ValueError
def test_chain_with_memory_raises_error():
    chain_with_memory = Chain(memory="stateful_memory")
    with pytest.raises(ValueError) as excinfo:
        _wrap_in_chain_factory(chain_with_memory)
    assert "Cannot directly evaluate a chain with stateful memory" in str(excinfo.value)

# Test Chain instance without memory should return a lambda that returns the chain
def test_chain_without_memory_returns_lambda():
    chain_without_memory = Chain(memory=None)
    result = _wrap_in_chain_factory(chain_without_memory)
    assert callable(result)
    assert result() is chain_without_memory

# Test BaseLanguageModel instance should be returned as is
def test_base_language_model_returns_self():
    model = BaseLanguageModel()
    result = _wrap_in_chain_factory(model)
    assert result is model

# Test Runnable instance should return a lambda that returns the Runnable instance
def test_runnable_returns_lambda():
    runnable_instance = Runnable()
    result = _wrap_in_chain_factory(runnable_instance)
    assert callable(result)
    assert result() is runnable_instance

# Test callable returning Chain or Runnable should be invoked and result handled
def test_callable_returning_chain_or_runnable():
    chain = Chain(memory=None)
    chain_factory = Mock(return_value=chain)
    result = _wrap_in_chain_factory(chain_factory)
    assert callable(result)
    assert isinstance(result(), Chain)

# Test callable returning BaseLanguageModel should return instance directly
def test_callable_returning_base_language_model():
    model = BaseLanguageModel()
    model_factory = Mock(return_value=model)
    result = _wrap_in_chain_factory(model_factory)
    assert result is model

# Test callable returning traceable function should be wrapped in Runnable
def test_callable_returning_traceable_function():
    traceable_func = Mock()
    traceable_func.traceable = True
    result = _wrap_in_chain_factory(traceable_func)
    assert callable(result)
    assert isinstance(result(), RunnableLambda)

# Test callable with TypeError should be wrapped in RunnableLambda
def test_callable_with_type_error_wrapped_in_runnable_lambda():
    def user_func_with_args(arg1, arg2):
        pass
    result = _wrap_in_chain_factory(user_func_with_args)
    assert callable(result)
    wrapped = result()
    assert isinstance(wrapped, RunnableLambda)
    assert wrapped.func == user_func_with_args

# Test non-callable input should be returned as is
def test_non_callable_input_returns_self():
    non_callable = "non_callable_object"
    result = _wrap_in_chain_factory(non_callable)
    assert result == non_callable

# Additional edge case tests

# Test Chain subclass with overridden memory behavior
class CustomChain(Chain):
    @property
    def memory(self):
        return "custom_memory"

def test_custom_chain_raises_error():
    custom_chain = CustomChain()
    with pytest.raises(ValueError) as excinfo:
        _wrap_in_chain_factory(custom_chain)
    assert "Cannot directly evaluate a chain with stateful memory" in str(excinfo.value)

# Test callable that returns different types on subsequent calls
def test_callable_returning_different_types():
    def unpredictable_factory():
        if hasattr(unpredictable_factory, 'counter'):
            unpredictable_factory.counter += 1
        else:
            unpredictable_factory.counter = 1
        return Chain() if unpredictable_factory.counter % 2 == 0 else BaseLanguageModel()
    
    first_result = _wrap_in_chain_factory(unpredictable_factory)
    second_result = _wrap_in_chain_factory(unpredictable_factory)
    
    assert callable(first_result)
    assert isinstance(first_result(), BaseLanguageModel)
    assert callable(second_result)
    assert isinstance(second_result(), Chain)

# Test callable that is also an instance of a handled type
class CallableModel(BaseLanguageModel, Callable):
    def __call__(self):
        return self

def test_callable_model():
    callable_model = CallableModel()
    result = _wrap_in_chain_factory(callable_model)
    assert result is callable_model

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by CodeFlash AI label Feb 16, 2024
@codeflash-ai codeflash-ai bot requested a review from aphexcx February 16, 2024 09:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
⚡️ codeflash Optimization PR opened by CodeFlash AI
Projects
None yet
Development

Successfully merging this pull request may close these issues.

0 participants