LanceDB integrates with PyTorch for training and inference. You store your data in LanceDB and read it into your PyTorch models through a standard torch.utils.data.DataLoader.

Quickstart

The LanceDB Table class implements the PyTorch Dataset interface (torch.utils.data.Dataset), so you can pass a table directly to a PyTorch DataLoader.
Python
import lancedb
import torch
import pyarrow as pa
from lancedb.util import tbl_to_tensor

mem_db = lancedb.connect("memory://")
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

# Any LanceDB table can be used as a PyTorch Dataset
dataloader = torch.utils.data.DataLoader(
    table, batch_size=1024, shuffle=True, collate_fn=tbl_to_tensor
)

for batch in dataloader:
    print(batch)
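The Dataset contract is small. The sketch below illustrates it with a hypothetical `ToyColumnTable` class and a hypothetical `iterate_batches` loop that imitates a single-process DataLoader without shuffling. Recent PyTorch versions (2.0 and later, as far as we know) call a dataset's `__getitems__` method with a whole list of indices when it is defined. A columnar store can therefore fetch a batch in one call and hand it to collate_fn. This shows the interface only and is not how LanceDB's Table is implemented.

```python
class ToyColumnTable:
    """Hypothetical in-memory table following the map-style Dataset contract
    that torch.utils.data.DataLoader relies on. Not LanceDB's implementation."""

    def __init__(self, columns):
        # columns: dict mapping column name -> list of values (all same length)
        self.columns = columns

    def __len__(self):
        # Samplers call len() to know which indices to generate.
        return len(next(iter(self.columns.values())))

    def __getitem__(self, index):
        # Single-row access: the minimum a map-style Dataset must provide.
        return {name: values[index] for name, values in self.columns.items()}

    def __getitems__(self, indices):
        # Batched access: when defined, recent PyTorch versions call this once
        # per batch instead of calling __getitem__ once per row.
        return {name: [values[i] for i in indices]
                for name, values in self.columns.items()}


def iterate_batches(dataset, batch_size, collate_fn):
    """Simplified, single-process DataLoader loop without shuffling."""
    for start in range(0, len(dataset), batch_size):
        indices = list(range(start, min(start + batch_size, len(dataset))))
        yield collate_fn(dataset.__getitems__(indices))


toy = ToyColumnTable({"a": list(range(10))})
for batch in iterate_batches(toy, batch_size=4, collate_fn=lambda b: b):
    print(batch)
# {'a': [0, 1, 2, 3]}
# {'a': [4, 5, 6, 7]}
# {'a': [8, 9]}
```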
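Although the Table class implements the torch.utils.data.Dataset interface, a Permutation of the table is often more flexible. It lets you choose the output format, select a subset of columns, and supply a custom connection factory for DataLoader workers.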
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(permutation)

Output Formats

By default, a Table used as a Dataset emits Arrow data. collate_fn is PyTorch's batching hook: PyTorch calls it to combine the fetched items into one batch. PyTorch's default collate function only knows how to combine tensors, NumPy arrays, numbers, dicts, and lists, so it cannot batch Arrow data. When using a Table directly, pass LanceDB's lancedb.util.tbl_to_tensor helper as collate_fn. It converts numeric Arrow columns into a column-major torch.Tensor with shape (columns, rows).

Permutation works differently. Its default output is a list of Python dicts, which PyTorch's default collate function can batch into a dict of tensors. This is usually more convenient when you are getting started. However, converting from Arrow, Lance's internal representation, to Python dicts carries a significant performance penalty. In short, use a Table with tbl_to_tensor as collate_fn when you want Arrow-to-tensor conversion, and use a Permutation when you want PyTorch's default dict-of-tensors behavior.

To reduce the conversion cost, the Permutation class provides a set of built-in output formats, selected with with_format, that map the Arrow data in different ways. The arrow and polars formats always avoid data copies. The numpy, pandas, and torch_col formats avoid data copies in most cases. The python, python_col, and torch formats all require at least one full copy of the data and are the slowest options.
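The sketch below shows, in plain Python, where the cost of the default Permutation output comes from. It reproduces what PyTorch's default collate does with a list of per-row dicts, which is to regroup the values by key. `collate_rows` is a hypothetical name and is not part of PyTorch or LanceDB. Unlike the real torch.utils.data.default_collate, it does not turn numbers into tensors. Every value must already exist as a separate Python object before this regrouping can happen. That per-value work is what columnar formats such as arrow or torch_col avoid.

```python
def collate_rows(rows):
    """Simplified stand-in for PyTorch's default collate on a list of dicts:
    regroup per-row values into one list per key. The real default_collate
    would additionally turn the numeric "id" list into a tensor."""
    if not rows:
        return {}
    return {key: [row[key] for row in rows] for key in rows[0]}


# Shape of a Permutation's default output: one Python dict per row.
rows = [{"id": 0, "prompt": "hi"}, {"id": 1, "prompt": "hello"}]
print(collate_rows(rows))
# {'id': [0, 1], 'prompt': ['hi', 'hello']}
```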

Using the torch_col format with a torch data loader

The torch_col format is the most efficient way to convert from Arrow to a torch.Tensor. It converts the entire Arrow batch to a single column-major torch.Tensor. Given C columns and R rows, the resulting tensor has shape (C, R). However, this format raises an error if you use a torch.utils.data.DataLoader with the default collation function:
Python
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
This error occurs because the default collation function expects a list of tensors, which it stacks into a batch. It does not expect a single two-dimensional tensor. The torch format produces such a list, but it requires a data copy. To avoid both the error and the copy, specify the torch_col format and also provide a custom collation function.
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table).with_format("torch_col")
dataloader = torch.utils.data.DataLoader(permutation, collate_fn=lambda x: x)
This will now output a single two-dimensional tensor for each batch.
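With torch_col, each batch arrives as one (C, R) tensor. Layers such as torch.nn.Linear, however, expect features in the last dimension, that is (R, C). The sketch below assumes float columns and uses a hypothetical module-level collate function, `to_row_major`, that returns the transpose. `.T` creates a view, so no data is copied. Call `.contiguous()` on the result if a later operation needs contiguous memory, at the cost of a copy. Because it is an ordinary module-level function rather than a lambda, it can also be pickled for spawn-started workers. You would pass `collate_fn=to_row_major` in place of the lambda used above.

```python
import torch


def to_row_major(batch: torch.Tensor) -> torch.Tensor:
    """Collate function for a torch_col batch of shape (C, R).

    Returns an (R, C) view so each row is one sample. Defined at module level
    (not as a lambda) so it can be pickled for spawn-started workers.
    """
    return batch.T


# Stand-in for one torch_col batch: C = 2 columns, R = 3 rows.
col_major = torch.tensor([[0.0, 1.0, 2.0],
                          [10.0, 11.0, 12.0]])
rows = to_row_major(col_major)
print(rows.shape)  # torch.Size([3, 2])

# nn.Linear treats the last dimension as features: 2 columns in, 1 output.
layer = torch.nn.Linear(in_features=2, out_features=1)
print(layer(rows).shape)  # torch.Size([3, 1])
```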

Selecting columns

By default, a Table used as a PyTorch Dataset returns every column in the table. If you only need a subset of columns, you can significantly reduce I/O by reading only the columns you need. The Permutation class supports this through its select_columns method.
Python
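import torch
from lancedb.permutation import Permutation

# Assumes the table has "id" and "prompt" columns.
permutation = Permutation.identity(table).select_columns(["id", "prompt"])
dataloader = torch.utils.data.DataLoader(
    permutation, batch_size=1024, shuffle=True
)

for batch in dataloader:
    # With the default output format, each batch is a dict keyed by column
    # name, so only the selected columns appear.
    print(list(batch.keys()))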

Using multiple DataLoader workers

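Set num_workers > 0 to read from LanceDB in multiple PyTorch worker processes. LanceDB tables and Permutation objects are picklable, so the DataLoader can send them to each worker, which reopens the table after it starts. Prefer the spawn start method when using multiple workers. LanceDB uses internal threads, and forking a process that has running threads can leave locks held in the child process and cause hangs. Because spawn re-imports your main module in every worker, iterate over the DataLoader inside an if __name__ == "__main__": block when running a script. See the performance guide for more multiprocessing guidance.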
Python
import torch
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
)
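The sketch below shows the general pattern behind "picklable, and reopened in the worker". Only the information needed to reopen a resource is pickled, and the live handle is recreated lazily in the receiving process. `ReopenableHandle` is a hypothetical class, and its `_open` method is a placeholder. We are not claiming this mirrors LanceDB's internals. The pickle round trip stands in for what a DataLoader does when it hands the dataset to a spawn-started worker.

```python
import pickle
import threading


class ReopenableHandle:
    """Hypothetical picklable wrapper around a resource that cannot itself be
    pickled (an open connection, a lock). Only the reopen information is
    pickled; the live handle is recreated on first use after unpickling."""

    def __init__(self, name: str):
        self.name = name
        self.open_count = 0
        self._handle = None
        self._lock = threading.Lock()  # a raw lock cannot be pickled

    def _open(self):
        # Placeholder for real reopen logic, e.g. connecting and opening a table.
        self.open_count += 1
        return {"table": self.name}

    def get(self):
        # Open lazily, at most once per process.
        with self._lock:
            if self._handle is None:
                self._handle = self._open()
            return self._handle

    def __getstate__(self):
        # Send only what is needed to reopen; drop the handle and the lock.
        return {"name": self.name}

    def __setstate__(self, state):
        # Runs in the receiving process: start with nothing open.
        self.__init__(state["name"])


parent = ReopenableHandle("my_table")
parent.get()  # opened in the parent process

# Roughly what happens when the dataset is sent to a spawn-started worker.
worker_copy = pickle.loads(pickle.dumps(parent))
print(worker_copy._handle is None)  # True: nothing open until first use
worker_copy.get()
print(worker_copy.open_count)  # 1: reopened inside the "worker"
```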

Remote tables in DataLoader workers

Remote LanceDB Enterprise tables (db://...) work the same way: workers reopen the table from the pickled connection state.
Python
import lancedb
import torch
from lancedb.util import tbl_to_tensor

db = lancedb.connect(
    "db://my-database",
    api_key="sk-...",
    region="us-east-1",
)
table = db.open_table("my_table")

dataloader = torch.utils.data.DataLoader(
    table,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
    collate_fn=tbl_to_tensor,
)
This sends the connection state, including the API key, to each worker. Use a connection factory if credentials should be loaded inside the worker or your client_config contains a non-serializable header_provider.

Providing a custom connection factory

Permutation.with_connection_factory lets each worker reopen the base table with custom logic. The factory takes the table name, returns a LanceDB table, and must be picklable. A function defined at module level works, while a lambda or nested function does not.
Python
import os
import lancedb
import torch
from lancedb.permutation import Permutation

def open_table(name: str):
    db = lancedb.connect(
        "db://my-database",
        api_key=os.environ["LANCEDB_API_KEY"],
        region="us-east-1",
    )
    return db.open_table(name)

table = open_table("my_table")
permutation = (
    Permutation.identity(table)
    .with_connection_factory(open_table)
)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
)
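Sometimes you want to parametrize the factory, for example by database URI or region, without hardcoding values. One option is functools.partial over a module-level function. A partial pickles as long as the function and its bound arguments do, whereas a closure does not. In the sketch below, `open_remote_table`, `make_closure_factory`, and `is_picklable` are hypothetical names. The factory returns its settings instead of a table so the sketch runs without LanceDB. A real body would connect and call open_table as in the open_table function above.

```python
import functools
import pickle


def open_remote_table(name, *, uri, region, api_key_env="LANCEDB_API_KEY"):
    """Hypothetical factory. A real version would do something like:
        db = lancedb.connect(uri, api_key=os.environ[api_key_env], region=region)
        return db.open_table(name)
    Here it returns its settings so the sketch runs without LanceDB."""
    return {"name": name, "uri": uri, "region": region, "api_key_env": api_key_env}


def make_closure_factory(uri):
    # Nested function: pickle cannot look it up by name, so workers can't receive it.
    def factory(name):
        return open_remote_table(name, uri=uri, region="us-east-1")
    return factory


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# The settings travel with the partial; the API key is read inside the worker.
factory = functools.partial(open_remote_table, uri="db://my-database", region="us-east-1")
restored = pickle.loads(pickle.dumps(factory))
print(restored("my_table"))
print(is_picklable(make_closure_factory("db://my-database")))  # False
```

The partial can then be passed as `.with_connection_factory(factory)`, since it is called with the table name like the open_table function.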