So... What's your type?
Learning how to type properly
Prime says it like it is with this double entendre: fast on the keyboard, expressive with types.
as a programmer you should learn how to type properly
- ThePrimeagen (@ThePrimeagen), April 21, 2023
I've only been learning Rust for a few months and it is already my favourite language. Before Rust, I never bothered with type annotations. Frankly, I thought they were obnoxious, especially since they do not improve performance. But Rust's rich type system has completely converted me to a "type-ist".
Python's typing module is nowhere near Rust's type system, but that doesn't mean you shouldn't use it. By embracing typing, we can catch errors early, enhance code maintainability, and foster seamless collaboration among team members.
Typing is documentation
Using type hints makes your code clear and expressive. Just reading the function's type signature should give you a pretty good idea of what the function does and how the input types flow to the output types. There is also no need to write a long docstring to explain what each input type is and what the return type is.
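As a small illustration of a signature doing the job of a docstring, here is a minimal sketch. The function name `count_words` is made up for this example and is not from any library.

```python
from __future__ import annotations


def count_words(sentences: list[str]) -> dict[str, int]:
    # The signature alone says: a list of strings goes in,
    # a word -> count mapping comes out.
    counts: dict[str, int] = {}
    for sentence in sentences:
        for word in sentence.split():
            counts[word] = counts.get(word, 0) + 1
    return counts


print(count_words(["to be", "or not to be"]))
```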
Typing is speed
Diligently using type signatures in bigger projects can help you stay focused. By leveraging LSP's capabilities and static type checkers, you get immediate feedback if you're not using the right types, and a simple hover over the function name can remind you how to use the function.
import numpy as np

def mean_of_numpy_arr(np_arr: np.ndarray) -> np.floating:
    return np.mean(np_arr)

arr = [1, 2, 3, 4, 5]
# A plain list is passed where an ndarray is expected
mean = mean_of_numpy_arr(arr)
Even before you run the code, the inlay hints tell you what is wrong:
```python
Pyright: Argument of type "list[int]" cannot be assigned to parameter "np_arr" of type "ndarray[Unknown, Unknown]" in function "mean_of_numpy_arr"
```
Why I don't like Python's Union type
It is good but not good enough. There is ambiguity around what it actually means. Let's say your function returns a list containing a mixture of strings and integers. Then list[int | str] is great because it captures the return type accurately.
But let's consider the case where a function returns either a list of strings or a list of integers. How do we write the return type here?
- list[int | str]: This is the same type as the previous example. While this can work, you probably don't want to write it like this because it does not convey what your function does.
- list[int] | list[str]: This is more accurate because it says that the function returns either a list full of ints or a list full of strs.
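To make the difference between the two spellings concrete, here is a minimal sketch. The function names `mixed_tokens` and `parse_ids` are made up for illustration.

```python
from __future__ import annotations


def mixed_tokens() -> list[int | str]:
    # One list whose elements may each be an int or a str.
    return [1, "two", 3]


def parse_ids(raw: list[str], as_int: bool) -> list[int] | list[str]:
    # Either every element is an int, or every element is a str; never a mix.
    if as_int:
        return [int(r) for r in raw]
    return raw
```

Note that the second signature still does not say which of the two lists a particular call returns; it only rules out the mixed case.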
While we know the possible return types of a function, that still doesn't tell us exactly which one to expect. Often, the input types determine the output type (counterexample: randomly choosing the return type in the function body). So how do we encode this knowledge for the type checker?
Type narrowing
Type narrowing helps guide the type checker to understand what the type is under some conditions. One way to do this is to make use of conditionals and isinstance(). If x is of type str | int and we want to narrow it down, we can check if isinstance(x, int) and then do something, and likewise for the str variant. But this can look quite ugly, as we would have to write these checks to narrow down the type before using the value in later steps.
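A minimal sketch of isinstance()-based narrowing, using a made-up function `describe`:

```python
from __future__ import annotations


def describe(x: str | int) -> str:
    if isinstance(x, int):
        # Inside this branch the checker treats x as int, so arithmetic is allowed.
        return f"int: {x + 1}"
    # The int case has already returned, so x is narrowed to str here.
    return f"str: {x.upper()}"
```

Every caller that receives a `str | int` value has to repeat a check like this before using it, which is where the ugliness comes from.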
There is a much better way to achieve this using the @overload decorator from the typing module.
from typing import Literal, overload
from typing_extensions import reveal_type
import numpy as np

@overload
def recursive_add(items: list[str]) -> str: ...
@overload
def recursive_add(items: list[int]) -> int: ...
@overload
def recursive_add(items: list[float]) -> float: ...
def recursive_add(items: list[str] | list[int] | list[float]) -> str | int | float:
    # Base case: a single item is its own sum
    if len(items) == 1:
        return items[0]
    return items[0] + recursive_add(items[1:])
Starts to look a lot like Java, doesn't it? However, here the overloads have no function bodies; they are just alternative type signatures for the same function.
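Here is a smaller sketch of the same idea that runs as-is and shows how to ask the type checker which signature it picked. The function `double` is made up for illustration; the `reveal_type` calls sit under `TYPE_CHECKING`, so only mypy or pyright see them and nothing extra needs to be imported at runtime.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, overload


@overload
def double(value: int) -> int: ...
@overload
def double(value: str) -> str: ...
def double(value: int | str) -> int | str:
    # Single runtime implementation shared by both signatures.
    return value * 2


n = double(21)
s = double("ab")

if TYPE_CHECKING:
    # Only seen by the type checker; skipped at runtime.
    reveal_type(n)  # expected to be reported as int
    reveal_type(s)  # expected to be reported as str
```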
There are some rules to follow when using @overload:
- There must be at least 2 @overload function signatures for a function.
- The actual implementation must come after all the overloaded variants.
- The overloaded variants have no function bodies. They use the Ellipsis literal (...) as a placeholder.
Be more expressive with Literals
Sometimes we write functions that take in "modes" and perform some operation depending on the mode. For example, this function can do either max or mean pooling and return the result as a numpy array, as a list of integers if the mode is max pooling, or as a list of floats if the mode is mean pooling.
def pooling_strategy(vecs: list[list[int]], mode: str, return_as_numpy: bool) -> list[float] | np.ndarray | list[int]:
    if mode == 'max':
        # do the max pooling over the columns and keep the result as list[int]
        pooled = max_pool(vecs)
    elif mode == 'mean':
        # do the mean pooling over the columns and keep the result as list[float]
        pooled = mean_pool(vecs)
    else:
        raise Exception("invalid mode")
    if return_as_numpy:
        return np.array(pooled)
    return pooled
Trying to write overloads to account for the various input combinations and output types will not work:
@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'max', return_as_numpy: bool = True) -> np.ndarray: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'mean', return_as_numpy: bool = True) -> np.ndarray: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'max', return_as_numpy: bool = False) -> list[int]: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: str = 'mean', return_as_numpy: bool = False) -> list[float]: ...
mypy is immediately unhappy about this and gives error messages like this:
mypy: Overloaded function signature 3 will never be matched: signature 1's parameter type(s) are the same or broader [misc]
mypy: Overloaded function signature 3 will never be matched: signature 2's parameter type(s) are the same or broader [misc]
The error messages tell us that the second, third and fourth overload variants will never be reached because the first one will always be selected, so the return type will always be a numpy array. This is because the @overload decorator works by looking for differences in the input types, not the input values, and all four signatures take the same (list[list[int]], str, bool) parameter types.
So we need some way of turning these "modes" into different types, and that's exactly what Literal does. Literal types are useful for encapsulating enum-like behaviour. There are only so many "modes". Leaving mode as a str is not clear enough unless there is more documentation on how to use it. Literal also improves type safety because the type checker and your editor can guide you on how to use the function with autocomplete and suggestions. Even if you're not a fan of the @overload decorator, I hope you see the power of Literal.
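A minimal sketch of Literal on its own, without overloads. The alias `Mode` and the function `pool` are made up for illustration.

```python
from __future__ import annotations

from typing import Literal

# Only these two strings are valid modes as far as the type checker is concerned.
Mode = Literal["max", "mean"]


def pool(values: list[float], mode: Mode) -> float:
    if mode == "max":
        return max(values)
    return sum(values) / len(values)


print(pool([1.0, 2.0, 6.0], "mean"))
# pool([1.0, 2.0, 6.0], "median")  # a type checker should flag this: "median" is not a valid Mode
```

At runtime Literal enforces nothing, so the commented-out call would still execute; the protection comes entirely from the static check.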
Let's rewrite this function with Literals and see how the @overload signatures should be written.
@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['max'], return_as_numpy: Literal[True]) -> np.ndarray: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['mean'], return_as_numpy: Literal[True]) -> np.ndarray: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['max'], return_as_numpy: Literal[False]) -> list[int]: ...
@overload
def pooling_strategy(vecs: list[list[int]], mode: Literal['mean'], return_as_numpy: Literal[False]) -> list[float]: ...
def pooling_strategy(vecs: list[list[int]], mode: Literal['max', 'mean'], return_as_numpy: bool) -> list[float] | np.ndarray | list[int]:
    match mode:
        case 'max':
            # do the max pooling over the columns and keep the result as list[int]
            pooled = max_pool(vecs)
        case 'mean':
            # do the mean pooling over the columns and keep the result as list[float]
            pooled = mean_pool(vecs)
        case _:
            raise Exception("invalid mode")
    if return_as_numpy:
        return np.array(pooled)
    return pooled
Now the type checker knows what is going on, and the code is a lot more readable as well. Mypy can also remind you to handle all the cases when you match against a set of Literal values, much like Rust. Notice that Literal can be used for bool values as well!
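The exhaustiveness reminder usually relies on an `assert_never` helper in the fallback branch, as described in the mypy docs listed under References. Below is a minimal sketch under a few assumptions: `max_pool` and `mean_pool` are hypothetical pure-Python stand-ins for the helpers the article leaves undefined, `pool_columns` is a made-up name, and `assert_never` is defined by hand so the block runs on older Python versions (newer versions ship one in `typing`). It uses if/elif, but the same pattern applies to `match`.

```python
from __future__ import annotations

from typing import Literal, NoReturn


def assert_never(value: NoReturn) -> NoReturn:
    # Reached only if a mode is added to the Literal but not handled below.
    raise AssertionError(f"unhandled mode: {value!r}")


def max_pool(vecs: list[list[int]]) -> list[int]:
    # Hypothetical helper: column-wise maximum.
    return [max(col) for col in zip(*vecs)]


def mean_pool(vecs: list[list[int]]) -> list[float]:
    # Hypothetical helper: column-wise mean.
    return [sum(col) / len(col) for col in zip(*vecs)]


def pool_columns(vecs: list[list[int]], mode: Literal["max", "mean"]) -> list[int] | list[float]:
    if mode == "max":
        return max_pool(vecs)
    elif mode == "mean":
        return mean_pool(vecs)
    else:
        # All Literal values are handled above, so the checker sees nothing left here.
        # Adding a new mode without a branch should turn this call into a type error.
        assert_never(mode)
```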
The need for stubs
I went down this whole rabbit hole of proper type hinting because of SentenceTransformers. Let's take a look at their encode function. (I have replaced sections that are not important for this example with ellipses.)
def encode(self, sentences: Union[str, List[str]],
           batch_size: int = 32,
           show_progress_bar: bool = None,
           output_value: str = 'sentence_embedding',
           convert_to_numpy: bool = True,
           convert_to_tensor: bool = False,
           device: str = None,
           normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
    """
    Computes sentence embeddings
    :param sentences: the sentences to embed
    :param batch_size: the batch size used for the computation
    :param show_progress_bar: Output a progress bar when encode sentences
    :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
    :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
    :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
    :param device: Which torch.device to use for the computation
    :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
    :return:
        By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
    """
    ...
    if convert_to_tensor:
        convert_to_numpy = False
    if output_value != 'sentence_embedding':
        convert_to_tensor = False
        convert_to_numpy = False
    input_was_string = False
    if isinstance(sentences, str) or not hasattr(sentences, '__len__'):  # Cast an individual sentence to a list with length 1
        sentences = [sentences]
        input_was_string = True
    ...
    if convert_to_tensor:
        all_embeddings = torch.stack(all_embeddings)
    elif convert_to_numpy:
        all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
    if input_was_string:
        all_embeddings = all_embeddings[0]
    return all_embeddings
There are a bunch of issues here:
- We see from the inputs that convert_to_numpy: bool = True, but look at the return type from the docstring: "By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned." 🤷🏽‍♂️🤷🏽‍♂️🤷🏽‍♂️
- The function inputs output_value, convert_to_tensor and convert_to_numpy are stepping on each other's toes. Consolidating them into a single return_type Literal might simplify things...
- We know that this actually returns numpy arrays by default. That's fine. But the return type of this function is a Union. Despite having some logic in the function body to check if convert_to_numpy is True, it still gives us a Union that we have to narrow down to a numpy array before passing it to other functions that expect a numpy type.
This is a perfect candidate for refactoring with overloads:
@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[True],
           convert_to_tensor: Literal[False],
           device: str,
           normalize_embeddings: bool) -> np.ndarray: ...
@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[False],
           convert_to_tensor: Literal[True],
           device: str,
           normalize_embeddings: bool) -> Tensor: ...
@overload
def encode(self, sentences: str | list[str],
           batch_size: int,
           show_progress_bar: bool,
           output_value: str,
           convert_to_numpy: Literal[False],
           convert_to_tensor: Literal[False],
           device: str,
           normalize_embeddings: bool) -> list[Tensor]: ...
Now we handle all the cases of the return type, aligning the return type logic with the code logic.
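The consolidation idea from the list of issues (a single return_type Literal instead of several booleans) could look roughly like the sketch below. This is not the SentenceTransformers API: the `return_type` parameter, the simplified free-function signature, and the `FakeArray` and `FakeTensor` classes are all assumptions, with the two classes standing in for np.ndarray and torch.Tensor so the block runs without either library.

```python
from __future__ import annotations

from typing import Literal, overload


class FakeArray(list):
    """Stand-in for np.ndarray so the sketch runs without numpy."""


class FakeTensor(list):
    """Stand-in for torch.Tensor so the sketch runs without torch."""


@overload
def encode(sentences: list[str], return_type: Literal["numpy"] = ...) -> FakeArray: ...
@overload
def encode(sentences: list[str], return_type: Literal["tensor"]) -> FakeTensor: ...
@overload
def encode(sentences: list[str], return_type: Literal["list"]) -> list[FakeTensor]: ...
def encode(
    sentences: list[str],
    return_type: Literal["numpy", "tensor", "list"] = "numpy",
) -> FakeArray | FakeTensor | list[FakeTensor]:
    # Placeholder "embedding": one number per sentence (its length).
    rows = [[float(len(s))] for s in sentences]
    if return_type == "numpy":
        return FakeArray(rows)
    if return_type == "tensor":
        return FakeTensor(rows)
    return [FakeTensor(row) for row in rows]
```

With one parameter deciding the output, each overload maps a single Literal value to a single return type, and the default lives on the first overload only.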
The good people of the open source community help to write stubs so that we can enjoy a smooth developer experience. Numpy, for example, has really nice stubs and a typing API to help with writing nicely typed code. One feature that will be useful is Shape Typing: annotating tensors and arrays with their shapes. This will help us move towards knowing what the shape will be while writing code, as opposed to running code and finding out. Shape Typing also serves as concise documentation.
Why bother with all this type stuff? I can just run the code and find out.
Yes, in most cases this is not a bad way to go. But what if your code does some serious processing and takes a few seconds to a few minutes to run? That would really slow you down a lot...
Use pyright and mypy to show you the type!
Thanks for reading!
References:
https://mypy.readthedocs.io/en/stable/literal_types.html#exhaustiveness-checking
https://mypy.readthedocs.io/en/stable/more_types.html#type-checking-calls-to-overloads