# PyTorch Integration

> Learn how to use LanceDB with PyTorch for training and inference.

LanceDB integrates with PyTorch for both training and inference: store your data in LanceDB, then feed it to your PyTorch models through a standard `DataLoader`.

## Quickstart

The `Table` class in LanceDB implements the PyTorch
[Dataset](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) interface,
so you can pass a LanceDB table directly to a PyTorch `DataLoader`.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import lancedb
import torch
import pyarrow as pa
from lancedb.util import tbl_to_tensor

mem_db = lancedb.connect("memory://")
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

# Any LanceDB table can be used as a PyTorch Dataset
dataloader = torch.utils.data.DataLoader(
    table, batch_size=1024, shuffle=True, collate_fn=tbl_to_tensor
)

for batch in dataloader:
    print(batch)
```

Although the `Table` class in LanceDB implements the `torch.utils.data.Dataset` interface, you may find that using
a table [Permutation](/training/) is more flexible.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(permutation)
```

## Output Formats

By default, a `Table` data loader will emit Arrow data. `collate_fn` is PyTorch's batching hook: PyTorch calls it to
turn the fetched items into one batch. PyTorch's default collate function only knows how to combine tensors, NumPy
arrays, numbers, dicts, and lists, so it does not accept Arrow data directly. When using a `Table` directly, pass
LanceDB's `lancedb.util.tbl_to_tensor` helper as PyTorch's `collate_fn`; it converts numeric Arrow columns into a
column-major `torch.Tensor` with shape `(columns, rows)`.

`Permutation` works differently: its default output is a list of Python dicts, which PyTorch's default collate function
can batch into a dict of tensors. This is usually more convenient when you are getting started. However, there is a
significant performance penalty converting from Arrow, Lance's internal representation, to this default format. Use a
direct `Table` with `collate_fn` when you want Arrow-to-tensor conversion, or a `Permutation` when you want the default
PyTorch dict-of-tensors behavior.
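PyTorch's default collate function can be called on its own to see what a dict-of-tensors batch looks like. The two rows below are made-up sample data standing in for the per-row dicts of the default `python` format.

```python
import torch
from torch.utils.data import default_collate

# Made-up rows shaped like the default `python` format: one dict per row.
rows = [
    {"id": 1, "prompt": "a cat"},
    {"id": 2, "prompt": "a dog"},
]

batch = default_collate(rows)
# Numeric fields are combined into tensors; strings stay a plain Python list.
print(batch["id"])      # tensor([1, 2])
print(batch["prompt"])  # ['a cat', 'a dog']
```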

To reduce the cost of converting from Arrow, the `Permutation` class provides a set of built-in formats, selected with
`with_format`, that map the Arrow data in different ways. The `arrow` and `polars` formats always avoid data copies,
and the `numpy`, `pandas`, and `torch_col` formats avoid them in most cases. The `python`, `python_col`, and `torch`
formats each require at least one full copy of the data and are the slowest options.

### Using the torch\_col format with a torch data loader

The `torch_col` format is the most efficient way to convert from Arrow to a `torch.Tensor`.  It will convert the
entire Arrow batch to a *column-major* `torch.Tensor`.  In other words, given C columns and R rows, the resulting
Tensor will have shape `(C, R)`.  However, this format generates an error if you are using a
`torch.utils.data.DataLoader` with the default collation function:

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
```

This error occurs because the default collation function expects a list of tensors, which it then stacks; it does not
expect a single two-dimensional tensor. The `torch` format outputs such a list, but that format requires a data copy.
To avoid both the error and the copy, specify the `torch_col` format and also provide a custom collation function.
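The failure can be reproduced with PyTorch alone, using a small tensor as a stand-in for one `torch_col` batch. `identity_collate` is an illustrative name; any function that returns its argument unchanged works. Defining it with `def` rather than `lambda` keeps it picklable, which matters when `spawn`-based workers receive a copy of the `collate_fn`.

```python
import torch
from torch.utils.data import default_collate

# Stand-in for one torch_col batch: 2 columns x 5 rows, column-major.
batch = torch.arange(10).reshape(2, 5)

# default_collate expects a list of samples. Given a single tensor, it
# ends up calling torch.stack on that tensor, which raises a TypeError.
try:
    default_collate(batch)
    error = None
except TypeError as exc:
    error = str(exc)


def identity_collate(batch):
    # Return the already-batched tensor unchanged.
    return batch


passed_through = identity_collate(batch)
print(error)
```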

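```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import torch
from lancedb.permutation import Permutation

def identity_collate(batch):
    # Pass each torch_col batch through unchanged; a def (unlike a lambda) pickles for spawn workers.
    return batch

permutation = Permutation.identity(table).with_format("torch_col")
dataloader = torch.utils.data.DataLoader(permutation, collate_fn=identity_collate)
```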

This will now output a single two-dimensional tensor for each batch.
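Many models expect row-major input of shape `(rows, columns)`. A transpose gives that layout as a view without copying the data; this pure-PyTorch sketch uses a small stand-in tensor for a `torch_col` batch.

```python
import torch

# Stand-in for one torch_col batch: C=3 columns, R=4 rows.
col_major = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# .T swaps the two axes without copying: the view shares col_major's storage.
row_major_view = col_major.T

# Operations that need contiguous memory can call .contiguous(), which copies.
row_major = row_major_view.contiguous()
```

Calling `.contiguous()` reintroduces a copy, so it is only worth doing when a downstream operation requires contiguous memory.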

## Selecting columns

By default, the `Table` class will return all columns in the table when used as input to PyTorch. If you only need
a subset of columns, you can significantly reduce your I/O requirements by selecting only the columns you need.  The
`Permutation` class provides a `select_columns` method that provides this functionality.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
from lancedb.permutation import Permutation

# Assumes `table` has `id` and `prompt` columns.
permutation = Permutation.identity(table).select_columns(["id", "prompt"])
dataloader = torch.utils.data.DataLoader(
    permutation, batch_size=1024, shuffle=True
)

for batch in dataloader:
    # With the default format, each batch is a dict keyed by column name.
    print(batch.keys())
```

## Using multiple DataLoader workers

Set `num_workers > 0` to read from LanceDB in multiple PyTorch worker processes. LanceDB tables and `Permutation` objects are picklable, so each worker reopens the table after it starts.

Prefer the `spawn` start method when using multiple workers: LanceDB uses internal threads, and forking a process while other threads are running can leave the child process deadlocked. See [the performance guide](/performance) for more multiprocessing guidance.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import torch
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
)
```

### Remote tables in DataLoader workers

Remote LanceDB Enterprise tables (`db://...`) work the same way: workers reopen the table from the pickled connection state.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import lancedb
import torch
from lancedb.util import tbl_to_tensor

db = lancedb.connect(
    "db://my-database",
    api_key="sk-...",
    region="us-east-1",
)
table = db.open_table("my_table")

dataloader = torch.utils.data.DataLoader(
    table,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
    collate_fn=tbl_to_tensor,
)
```

<Note>
  This sends the connection state, including the API key, to each worker. Use a connection factory if credentials should be loaded inside the worker or your `client_config` contains a non-serializable `header_provider`.
</Note>

### Providing a custom connection factory

`Permutation.with_connection_factory` lets each worker reopen the base table with custom logic. The factory takes the table name, returns a LanceDB table, and must be picklable.
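Before handing a factory to `spawn`-based workers, you can check that it survives a `pickle` round trip. This sketch uses only the standard library, and its `open_table` body is a hypothetical placeholder: a function defined at module level pickles by reference, while lambdas and functions nested inside other functions do not.

```python
import pickle


def open_table(name: str):
    # Hypothetical placeholder; a real factory would connect and open `name`.
    return name


def make_nested_factory(uri: str):
    # Returns a function defined inside another function (not picklable).
    def factory(name: str):
        return (uri, name)
    return factory


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


module_level_ok = is_picklable(open_table)
lambda_ok = is_picklable(lambda name: name)
nested_ok = is_picklable(make_nested_factory("db://my-database"))
```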

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import os
import lancedb
import torch
from lancedb.permutation import Permutation

def open_table(name: str):
    db = lancedb.connect(
        "db://my-database",
        api_key=os.environ["LANCEDB_API_KEY"],
        region="us-east-1",
    )
    return db.open_table(name)

table = open_table("my_table")
permutation = (
    Permutation.identity(table)
    .with_connection_factory(open_table)
)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
)
```
