Learning how to type properly

Prime says it like it is with this double entendre - fast on the keyboard, expressive with types.

I’ve only been learning Rust for a few months and it is already my favourite language. Before Rust, I never bothered with type annotations. Frankly, I thought they were obnoxious, especially since they do not improve performance. But Rust’s rich type system has completely converted me into a “type-ist”.

Python’s typing module is nowhere near Rust’s type system, but that doesn’t mean you shouldn’t use it. By embracing typing, we can catch errors early, make code easier to maintain, and make collaboration with team members smoother.

Typing is documentation

Using type hints makes your code clear and expressive. Just reading the function’s type signature should give you a pretty good idea of what the function does and how the input types flow to the output types. There is also no need to write long docstrings to explain what each input type is and what the return type is.
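
As a small illustration of this point, here is a sketch with a made-up function (top_words is hypothetical, not from any library). The signature alone already says most of what a docstring would:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions

from collections import Counter


def top_words(text: str, n: int = 3) -> list[tuple[str, int]]:
    # The signature alone says: text in, up to n (word, count) pairs out.
    counts = Counter(text.lower().split())
    return counts.most_common(n)


print(top_words("the cat and the hat and the bat", n=2))
```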

Typing is speed

Diligently using type signatures for bigger projects can help you stay focused. By leveraging LSP’s capabilities and static type checkers, you get immediate feedback if you’re not using the right types and a simple hover over the function name can remind you how to use the function.

import numpy as np

def mean_of_numpy_arr(np_arr: np.ndarray) -> np.floating:
    return np.mean(np_arr)

arr = [1,2,3,4,5]
mean = mean_of_numpy_arr(arr) 
Even before you run the code, the diagnostics in your editor tell you what is wrong:

Pyright: Argument of type "list[int]" cannot be assigned to parameter "np_arr" of type "ndarray[Unknown, Unknown]" in function "mean_of_numpy_arr"

Why I don’t like Python’s Union type

It is good but not good enough. There is ambiguity around what it actually means. Let’s say your function returns a list containing a mixture of strings and integers. Then list[int | str] is great because it captures the return type accurately.

But let’s consider the case where a function returns either a list of strings or a list of integers. How do we write the return type here?

  1. list[int | str]
    • This is the same type as the previous example. While this works, you probably don’t want to write it like this because it does not convey what your function does.
  2. list[int] | list[str]
    • This is more accurate because it’s saying that the function returns either a list full of ints or a list full of strs.
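
To make the difference concrete, here is a minimal sketch with two hypothetical functions (mixed_tokens and parse_ids are made-up names for illustration):

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions


def mixed_tokens() -> list[int | str]:
    # One list whose elements may be ints and strs side by side.
    return [1, "two", 3]


def parse_ids(raw: list[str], as_int: bool) -> list[int] | list[str]:
    # Either every element is an int or every element is a str, never a mix.
    if as_int:
        return [int(r) for r in raw]
    return list(raw)
```

Note that the second signature still does not say which of the two list types a given call returns.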

While we know the possible return types of a function, this still doesn’t tell us exactly which one to expect. Often the input types determine the output type (counterexample: randomly choosing the return type in the function body). So how do we encode this knowledge for the type checker?

Type narrowing

Type narrowing helps to guide the type checker to understand what the return type is given some conditions. One way to do this is to make use of conditionals and isinstance() . If x is of type str | int and we want to narrow it down, we can check if isinstance(x, int) then do something, likewise for the str variant. But this can look quite ugly as we would have to write these checks to narrow down the type before proceeding to use this value for future steps.
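
A minimal sketch of isinstance() narrowing, using a hypothetical describe function:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions


def describe(x: str | int) -> str:
    if isinstance(x, int):
        # In this branch a type checker treats x as int.
        return f"int with {x.bit_length()} bits"
    # The int case has already returned, so x is narrowed to str here.
    return f"str of length {len(x)}"


print(describe(5))
print(describe("abc"))
```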

There is a much better way to achieve this task using the @overload decorator from the typing module.

from typing import Literal, overload
from typing_extensions import reveal_type
import numpy as np

@overload
def recursive_add(items: list[str]) -> str: ...
@overload
def recursive_add(items: list[int]) -> int: ...
@overload
def recursive_add(items: list[float]) -> float: ...

def recursive_add(items: list[str] | list[int] | list[float]) -> str | int | float:
    if len(items) == 1:
        return items[0]
    return items[0] + recursive_add(items[1:])

Starts to look a lot like Java, doesn’t it? However, here the overloads have no function bodies; they are just alternative type signatures for the same function.

There are some rules to follow to use @overload:

  1. There must be at least 2 @overload function signatures for a function.
  2. The original function must come after all the overloaded variants.
  3. The overloaded variants have no real function bodies. Each one uses the ellipsis literal (...) as a placeholder.
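
Here is a small sketch that follows these rules, using a hypothetical double function. The comments on the calls describe what a type checker should infer, not runtime output:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions

from typing import overload


@overload
def double(value: int) -> int: ...
@overload
def double(value: str) -> str: ...


def double(value: int | str) -> int | str:
    # Single runtime implementation; the stubs above exist only for the type checker.
    return value * 2


n = double(21)    # a type checker should infer int
s = double("ab")  # a type checker should infer str
print(n + 1, s.upper())
```

Because the call sites get a precise type back, n + 1 and s.upper() need no isinstance() checks.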

Be more expressive with Literals

Sometimes we write functions that take in “modes” and perform some operation depending on the mode. For example, this function can do either max or mean pooling. It returns the result as a numpy array if return_as_numpy is set; otherwise it returns a list of integers if the mode is max pooling, or a list of floats if the mode is mean pooling.

def pooling_strategy(vecs: list[list[int]], mode: str, return_as_numpy: bool) -> list[float] | np.ndarray | list[int]:
	if mode == 'max':
		# do the max pooling over the columns and keep the result as list[int]
		pooled = max_pool(vecs)
	elif mode == 'mean':
		# do the mean pooling over the columns and keep the result as list[float]
		pooled = mean_pool(vecs)
	else:
		raise Exception("invalid mode")

	if return_as_numpy:
		return np.array(pooled)
	return pooled

Trying to write overloads to account for the various input combinations and output types will not work:

@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'max', return_as_numpy: bool = True) -> np.ndarray: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'mean', return_as_numpy: bool = True) -> np.ndarray: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'max', return_as_numpy: bool = False) -> list[int]: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'mean', return_as_numpy: bool = False) -> list[float]: ...

mypy is immediately unhappy about this and gives error messages like this:

mypy: Overloaded function signature 3 will never be matched: signature 1's parameter type(s) are the same or broader [misc]
mypy: Overloaded function signature 3 will never be matched: signature 2's parameter type(s) are the same or broader [misc]

The error messages tell us that the second, third and fourth overload variants will never be reached because the first one will always be selected, so the return type will always be a numpy array. This is because @overload works by looking for differences in the input types, not the input values, and a default value such as 'max' or True does not change the declared type.

So we need some way of turning these “modes” into different types, and that’s exactly what Literal does. Literals are useful for encapsulating Enum-like behaviour: there are only so many “modes”. Leaving mode as a str is not clear enough unless there is more documentation on how to use it. Literal also improves type safety, because the type checker will flag invalid values and your editor can suggest the valid ones. Even if you’re not a fan of the @overload decorator, I hope you see the power of Literal.
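
A minimal sketch of Literal on its own, before combining it with overloads. The Mode alias and the pool function are hypothetical names for this example. Literal is only checked statically, so the rejected call is left as a comment:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions

from typing import Literal

Mode = Literal["max", "mean"]  # hypothetical alias for this sketch


def pool(values: list[int], mode: Mode) -> float:
    if mode == "max":
        return float(max(values))
    return sum(values) / len(values)


pool([1, 2, 6], "mean")      # accepted
# pool([1, 2, 6], "median")  # a type checker should reject this: "median" is not a valid Mode
```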

Let’s rewrite this function with Literals and see how the @overloads should be written.

@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['max'], return_as_numpy: Literal[True]) -> np.ndarray: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['mean'], return_as_numpy: Literal[True]) -> np.ndarray: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['max'], return_as_numpy: Literal[False]) -> list[int]: ...

@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['mean'], return_as_numpy: Literal[False]) -> list[float]: ...

def pooling_strategy(vecs: list[list[int]], mode: Literal['max', 'mean'], return_as_numpy: bool) -> list[float] | np.ndarray | list[int]:
	match mode:
		case 'max':
			# do the max pooling over the columns and keep the result as list[int]
			pooled = max_pool(vecs)
		case 'mean':
			# do the mean pooling over the columns and keep the result as list[float]
			pooled = mean_pool(vecs)
		case _:
			raise Exception("invalid mode")

	if return_as_numpy:
		return np.array(pooled)
	return pooled

Now the type checker knows what is going on and the code is a lot more readable as well. Mypy can also remind you to handle all the cases when you match against a Literal type, much like Rust. Notice that Literal can be used for bool values as well!
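
One common way to get that exhaustiveness reminder is an assert_never helper in the final branch. The sketch below defines the helper locally so it runs on older Python versions (Python 3.11+ ships typing.assert_never). pool_column is a hypothetical function, and an if/elif chain is used here, though a match statement narrows the same way:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions

from typing import Literal, NoReturn


def assert_never(value: NoReturn) -> NoReturn:
    # Local helper for this sketch; Python 3.11+ provides typing.assert_never.
    raise AssertionError(f"unhandled value: {value!r}")


def pool_column(column: list[int], mode: Literal["max", "mean"]) -> float:
    if mode == "max":
        return float(max(column))
    elif mode == "mean":
        return sum(column) / len(column)
    else:
        # If a new mode is added to the Literal but not handled above,
        # a type checker should flag this call.
        assert_never(mode)
```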

The need for stubs

I went down this whole rabbit hole of proper type hinting because of SentenceTransformers. Let’s take a look at their encode function. (I have replaced the sections that are not important for this example with ellipses.)

def encode(self, sentences: Union[str, List[str]],
               batch_size: int = 32,
               show_progress_bar: bool = None,
               output_value: str = 'sentence_embedding',
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
        """
        Computes sentence embeddings
        :param sentences: the sentences to embed
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Output a progress bar when encode sentences
        :param output_value:  Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
        :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation
        :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
        :return:
           By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
        """
        ...

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != 'sentence_embedding':
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(sentences, '__len__'): #Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True
            
        ...

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings

There are a bunch of issues here:

  1. We see from the inputs that convert_to_numpy: bool = True, but look at the return type from the docstring: By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned. 🤷🏽‍♂️🤷🏽‍♂️🤷🏽‍♂️
  2. The function inputs output_value, convert_to_tensor and convert_to_numpy are stepping on each other’s toes. Consolidating them into a single return_type Literal might simplify things…
  3. We know that this actually returns numpy arrays by default. That’s fine. But the return type of this function is a Union. Despite the logic in the function body that checks whether convert_to_numpy is True, the caller still gets a Union that has to be narrowed down to a numpy array before it can be passed to other functions that expect a numpy type.

This is a perfect candidate for a refactor that uses overloads:

@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[True],
           convert_to_tensor: Literal[False], device:str, normalize_embeddings: bool) -> np.ndarray: ...

@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[False],
           convert_to_tensor: Literal[True], device:str, normalize_embeddings: bool) -> Tensor: ...

@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[False],
           convert_to_tensor: Literal[False], device:str, normalize_embeddings: bool) -> list[Tensor]: ...

Now we handle all the cases of the return type, aligning the return type logic with the code logic.
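
The signatures above have no default values, so a bare call with only the sentences would not match any of them. Below is a much-simplified sketch of how defaults could be carried into the overloads. It assumes keyword-only flags, and Encoder, Tensor and NDArray are hypothetical stand-in classes so it runs without torch or numpy. It is not the SentenceTransformers API:

```python
from __future__ import annotations  # keeps the annotations valid on older Python versions

from typing import Literal, overload


class Tensor:  # stand-in so the sketch runs without torch
    pass


class NDArray:  # stand-in for np.ndarray
    pass


class Encoder:  # hypothetical, much-simplified model class
    @overload
    def encode(self, sentences: list[str], *, convert_to_numpy: Literal[True] = ...,
               convert_to_tensor: Literal[False] = ...) -> NDArray: ...
    @overload
    def encode(self, sentences: list[str], *, convert_to_numpy: bool = ...,
               convert_to_tensor: Literal[True]) -> Tensor: ...
    @overload
    def encode(self, sentences: list[str], *, convert_to_numpy: Literal[False],
               convert_to_tensor: Literal[False] = ...) -> list[Tensor]: ...

    def encode(self, sentences: list[str], *, convert_to_numpy: bool = True,
               convert_to_tensor: bool = False) -> NDArray | Tensor | list[Tensor]:
        # convert_to_tensor wins over convert_to_numpy, mirroring the logic quoted above.
        if convert_to_tensor:
            return Tensor()
        if convert_to_numpy:
            return NDArray()
        return [Tensor() for _ in sentences]


model = Encoder()
default_out = model.encode(["hello"])                      # a type checker should infer NDArray
stacked = model.encode(["hello"], convert_to_tensor=True)  # a type checker should infer Tensor
```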


The good people of the open source community help to write stubs so that we can enjoy a smooth developer experience. NumPy, for example, has really nice stubs and a typing API to help with writing nicely typed code. One feature that will be useful is Shape Typing, which annotates tensors and arrays with their shapes. This will help us move towards knowing what the shape will be while writing code, as opposed to running the code and finding out. Shape Typing also serves as concise documentation.

Why bother with all this type stuff? I can just run the code and find out.

Yes, in most cases this is not a bad way to go. But what if your code does some serious processing and takes a few seconds to a few minutes to run? That would really slow you down a lot…

Use pyright and mypy to show you the type!

Thanks for reading!

References:

https://mypy.readthedocs.io/en/stable/literal_types.html#exhaustiveness-checking

https://mypy.readthedocs.io/en/stable/more_types.html#type-checking-calls-to-overloads

https://adamj.eu/tech/2021/05/17/python-type-hints-how-to-narrow-types-with-isinstance-assert-literal/