dora.cuda
Helpers for sharing PyTorch CUDA tensors between dora nodes through CUDA IPC handles.
"""Share PyTorch CUDA tensors between dora nodes through CUDA IPC handles.

The sender calls ``torch_to_ipc_buffer`` and sends the returned pyarrow array
and metadata as an output. The receiver rebuilds the handle with
``ipc_buffer_to_ipc_handle`` and maps the shared device memory as a tensor
with the ``open_ipc_handle`` context manager.
"""

import json
from contextlib import contextmanager
from typing import ContextManager

import pyarrow as pa

# Make sure to install torch with cuda
import torch
from numba.cuda import to_device

# Make sure to install numba with cuda
from numba.cuda.cudadrv.devicearray import DeviceNDArray
from numba.cuda.cudadrv.devices import get_context
from numba.cuda.cudadrv.driver import IpcHandle


def torch_to_ipc_buffer(tensor: torch.Tensor) -> tuple[pa.Array, dict]:
    """Convert a PyTorch CUDA tensor into a pyarrow array holding its CUDA
    IPC handle, plus the metadata needed to reopen it.

    Example use:
    ```python
    torch_tensor = torch.tensor(random_data, dtype=torch.int64, device="cuda")
    ipc_buffer, metadata = torch_to_ipc_buffer(torch_tensor)
    node.send_output("latency", ipc_buffer, metadata)
    ```

    Args:
        tensor: Tensor that already lives on a CUDA device. NOTE: the handle
            refers to this tensor's device memory, so keep ``tensor`` alive
            while consumers use the handle.

    Returns:
        A ``pa.int8()`` array with the raw IPC handle bytes, and a metadata
        dict with ``shape``, ``strides``, ``dtype``, ``size``, ``offset`` and
        a JSON-encoded ``source_info``.

    Raises:
        ValueError: If ``tensor`` is not on a CUDA device.
    """
    if not tensor.is_cuda:
        # numba's ``to_device`` would copy a host tensor into a temporary
        # device allocation owned by ``device_arr``; that allocation is
        # released once this function returns, leaving the exported handle
        # pointing at freed memory.
        raise ValueError(
            "torch_to_ipc_buffer expects a CUDA tensor, got a tensor on "
            f"device {tensor.device}"
        )
    device_arr = to_device(tensor)
    ipch = get_context().get_ipc_handle(device_arr.gpu_data)
    # Unpack the handle's pickling arguments: (cls, handle, size, source_info, offset).
    _, handle, size, source_info, offset = ipch.__reduce__()[1]
    metadata = {
        "shape": device_arr.shape,
        "strides": device_arr.strides,
        "dtype": device_arr.dtype.str,
        "size": size,
        "offset": offset,
        "source_info": json.dumps(source_info),
    }
    return pa.array(handle, pa.int8()), metadata


def ipc_buffer_to_ipc_handle(handle_buffer: pa.Array, metadata: dict) -> IpcHandle:
    """Rebuild a numba CUDA ``IpcHandle`` from a serialized handle buffer.

    Example use:
    ```python
    from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle

    event = node.next()

    ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
    with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
        pass
    ```

    Args:
        handle_buffer: The ``pa.int8()`` handle array produced by
            ``torch_to_ipc_buffer``.
        metadata: The metadata dict produced alongside it; ``size``,
            ``offset`` and the JSON string ``source_info`` are read here.

    Returns:
        An unopened ``IpcHandle``; map it with ``open_ipc_handle``.
    """
    # numba's own unpickling passes a tuple here: a ctypes fixed-size array
    # field accepts a tuple (or an array instance) but not a list.
    handle = tuple(handle_buffer.to_pylist())
    return IpcHandle._rebuild(
        handle,
        metadata["size"],
        json.loads(metadata["source_info"]),
        metadata["offset"],
    )


@contextmanager
def open_ipc_handle(
    ipc_handle: IpcHandle, metadata: dict
) -> ContextManager[torch.Tensor]:
    """Open a CUDA IPC handle and yield a PyTorch tensor over the shared memory.

    The handle is closed when the ``with`` block exits, including on error.

    Example use:
    ```python
    from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle

    event = node.next()

    ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
    with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
        pass
    ```

    Args:
        ipc_handle: An unopened handle, e.g. from ``ipc_buffer_to_ipc_handle``.
        metadata: The metadata dict from ``torch_to_ipc_buffer``; ``shape``,
            ``strides`` and ``dtype`` describe the array layout.
    """
    shape = metadata["shape"]
    strides = metadata["strides"]
    dtype = metadata["dtype"]
    # Open outside the try block. If opening fails there is nothing to close,
    # and calling close() on an unopened handle would raise and mask the
    # original error.
    buffer = ipc_handle.open(get_context())
    try:
        device_arr = DeviceNDArray(shape, strides, dtype, gpu_data=buffer)
        yield torch.as_tensor(device_arr, device="cuda")
    finally:
        ipc_handle.close()
def
torch_to_ipc_buffer(tensor: torch.TensorType) -> tuple[pyarrow.lib.array, dict]:
def torch_to_ipc_buffer(tensor: torch.Tensor) -> tuple[pa.Array, dict]:
    """Convert a PyTorch CUDA tensor into a pyarrow array holding its CUDA
    IPC handle, plus the metadata needed to reopen it.

    Example use:
    ```python
    torch_tensor = torch.tensor(random_data, dtype=torch.int64, device="cuda")
    ipc_buffer, metadata = torch_to_ipc_buffer(torch_tensor)
    node.send_output("latency", ipc_buffer, metadata)
    ```

    Args:
        tensor: Tensor that already lives on a CUDA device. NOTE: the handle
            refers to this tensor's device memory, so keep ``tensor`` alive
            while consumers use the handle.

    Returns:
        A ``pa.int8()`` array with the raw IPC handle bytes, and a metadata
        dict with ``shape``, ``strides``, ``dtype``, ``size``, ``offset`` and
        a JSON-encoded ``source_info``.

    Raises:
        ValueError: If ``tensor`` is not on a CUDA device.
    """
    if not tensor.is_cuda:
        # numba's ``to_device`` would copy a host tensor into a temporary
        # device allocation owned by ``device_arr``; that allocation is
        # released once this function returns, leaving the exported handle
        # pointing at freed memory.
        raise ValueError(
            "torch_to_ipc_buffer expects a CUDA tensor, got a tensor on "
            f"device {tensor.device}"
        )
    device_arr = to_device(tensor)
    ipch = get_context().get_ipc_handle(device_arr.gpu_data)
    # Unpack the handle's pickling arguments: (cls, handle, size, source_info, offset).
    _, handle, size, source_info, offset = ipch.__reduce__()[1]
    metadata = {
        "shape": device_arr.shape,
        "strides": device_arr.strides,
        "dtype": device_arr.dtype.str,
        "size": size,
        "offset": offset,
        "source_info": json.dumps(source_info),
    }
    return pa.array(handle, pa.int8()), metadata
Convert a PyTorch CUDA tensor into a pyarrow array containing its CUDA IPC handle, and return it together with the metadata needed to reopen the handle.
Example use:
torch_tensor = torch.tensor(random_data, dtype=torch.int64, device="cuda")
ipc_buffer, metadata = torch_to_ipc_buffer(torch_tensor)
node.send_output("latency", ipc_buffer, metadata)
def
ipc_buffer_to_ipc_handle( handle_buffer: <cyfunction array>, metadata: dict) -> numba.cuda.cudadrv.driver.IpcHandle:
def ipc_buffer_to_ipc_handle(handle_buffer: pa.Array, metadata: dict) -> IpcHandle:
    """Rebuild a numba CUDA ``IpcHandle`` from a serialized handle buffer.

    Example use:
    ```python
    from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle

    event = node.next()

    ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
    with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
        pass
    ```

    Args:
        handle_buffer: The ``pa.int8()`` handle array produced by
            ``torch_to_ipc_buffer``.
        metadata: The metadata dict produced alongside it; ``size``,
            ``offset`` and the JSON string ``source_info`` are read here.

    Returns:
        An unopened ``IpcHandle``; map it with ``open_ipc_handle``.
    """
    # numba's own unpickling passes a tuple here: a ctypes fixed-size array
    # field accepts a tuple (or an array instance) but not a list.
    handle = tuple(handle_buffer.to_pylist())
    return IpcHandle._rebuild(
        handle,
        metadata["size"],
        json.loads(metadata["source_info"]),
        metadata["offset"],
    )
Convert a pyarrow array containing a serialized CUDA IPC handle back into a numba `IpcHandle`.
Example use:
from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle
event = node.next()
ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
    pass
@contextmanager
def
open_ipc_handle( ipc_handle: numba.cuda.cudadrv.driver.IpcHandle, metadata: dict) -> ContextManager[torch.TensorType]:
@contextmanager
def open_ipc_handle(
    ipc_handle: IpcHandle, metadata: dict
) -> ContextManager[torch.Tensor]:
    """Open a CUDA IPC handle and yield a PyTorch tensor over the shared memory.

    The handle is closed when the ``with`` block exits, including on error.

    Example use:
    ```python
    from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle

    event = node.next()

    ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
    with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
        pass
    ```

    Args:
        ipc_handle: An unopened handle, e.g. from ``ipc_buffer_to_ipc_handle``.
        metadata: The metadata dict from ``torch_to_ipc_buffer``; ``shape``,
            ``strides`` and ``dtype`` describe the array layout.
    """
    shape = metadata["shape"]
    strides = metadata["strides"]
    dtype = metadata["dtype"]
    # Open outside the try block. If opening fails there is nothing to close,
    # and calling close() on an unopened handle would raise and mask the
    # original error.
    buffer = ipc_handle.open(get_context())
    try:
        device_arr = DeviceNDArray(shape, strides, dtype, gpu_data=buffer)
        yield torch.as_tensor(device_arr, device="cuda")
    finally:
        ipc_handle.close()
Open a CUDA IPC handle and yield a PyTorch tensor backed by the shared device memory; the handle is closed when the `with` block exits.
Example use:
from dora.cuda import ipc_buffer_to_ipc_handle, open_ipc_handle
event = node.next()
ipc_handle = ipc_buffer_to_ipc_handle(event["value"], event["metadata"])
with open_ipc_handle(ipc_handle, event["metadata"]) as tensor:
    pass