Last active
April 30, 2021 09:56
-
-
Save fjetter/25fae963c70c9b756b591213244af96a to your computer and use it in GitHub Desktop.
A draft for a possible new multi-layered spilling interface for dask.distributed
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
from collections import defaultdict | |
from typing import Callable, Dict, List, MutableMapping, Set | |
from dask.sizeof import sizeof | |
class Data:
    """Toy payload for this draft: something that can be opened as a stream."""

    async def open(self) -> io.IOBase:
        # An in-memory text stream is all the draft needs.
        stream = io.StringIO()
        return stream

    async def close(self):
        # Nothing to release for the in-memory stand-in.
        pass
class DataProcessor:
    """One archival layer (compression, disk spill, remote store, ...).

    The base implementation is a no-op pass-through: ``store`` and
    ``retrieve`` hand the data back unchanged.
    """

    async def store(self, data: "Data") -> "Data":
        # Compress, spill, remote store, etc.
        return data

    async def retrieve(self, data: "Data") -> "Data":
        # Restore original
        return data
class CompressionProcessor(DataProcessor):
    """Draft placeholder for an in-memory compression layer (no-op for now)."""
class FileProcessor(DataProcessor):
    """Draft placeholder for a disk-spilling layer (no-op for now)."""
class RemoteStorageProcessor(DataProcessor):
    """Draft placeholder for a remote-storage layer (no-op for now)."""
class DataProxy:
    """Wraps a piece of data and moves it through a stack of processors.

    ``current_level`` counts how many processors have been applied:
    0 means the raw in-memory data, ``max_level`` means fully archived.
    """

    def __init__(self, data: "Data", processors: "List[DataProcessor]"):
        # Indicate how many levels this data can be pushed down
        self.max_level = len(processors)
        # This is the size *in-memory*. For in-memory data this would be the raw
        # amount, for remote storage this may be just the size of the key, etc.
        # We probably want to store the raw size as well for tracking but that's
        # not the point of this draft
        self.size = sizeof(data)
        self.processors = processors
        # We keep a refcount to avoid duplicates and avoid spilling currently
        # used data
        self.refcount = 0
        self.current_level = 0
        self.raw = data

    async def archive(self):
        """Push the data one level down, if idle and not already at the bottom.

        Returns the (possibly unchanged) current level.
        """
        # Only allow archivation (spill, compression, etc.) if data is not
        # currently in use, and only if there is a processor left to apply —
        # without the bound check this raised IndexError once fully archived.
        if self.refcount == 0 and self.current_level < self.max_level:
            processor = self.processors[self.current_level]
            self.raw = await processor.store(self.raw)
            self.size = sizeof(self.raw)
            self.current_level += 1
        return self.current_level

    async def unarchive(self):
        """Undo the most recent archival step. Returns the new current level."""
        if self.current_level:
            # archive() increments *after* storing, so the processor that
            # produced the current representation lives at index
            # current_level - 1. Indexing with current_level itself used the
            # wrong processor and raised IndexError at max_level.
            processor = self.processors[self.current_level - 1]
            self.raw = await processor.retrieve(self.raw)
            self.size = sizeof(self.raw)
            self.current_level -= 1
        return self.current_level

    async def open(self) -> io.IOBase:
        """Fully unarchive the data, pin it via refcount, and open it."""
        self.refcount += 1
        while self.current_level != 0:
            await self.unarchive()
        return await self.raw.open()

    async def close(self):
        self.refcount -= 1
        await self.raw.close()
def proxy_factory(data: Data) -> DataProxy:
    """Wrap ``data`` in a :class:`DataProxy` with the default processor stack."""
    # Maybe we want to have different processors per object type
    if not isinstance(data, object):
        # Unreachable today (everything is an object); kept as a guard for
        # future type-specific dispatch.
        raise RuntimeError()
    stack = [
        CompressionProcessor(),
        FileProcessor(),
    ]
    return DataProxy(data, processors=stack)
# Type alias: a level is an index into a DataProxy's processor stack;
# level 0 is the raw, in-memory representation.
Level = int
class SmartMutableBuffer(MutableMapping):
    """A mapping of keys to :class:`DataProxy` objects that can spill data.

    Values are wrapped by *factory* into proxies on insertion. Per-level
    bookkeeping (``size_by_level`` / ``keys_by_level``) is maintained on
    every mutation so :meth:`balance` can decide what to archive.

    Note: the original draft omitted ``__delitem__``/``__iter__``/``__len__``,
    leaving the class abstract and impossible to instantiate.
    """

    def __init__(
        self,
        data: "Dict[str, Data]",
        # Quoted annotations: ``int | Dict`` syntax and forward references
        # would otherwise be evaluated at class-definition time.
        target: "Union[int, Dict[Level, int]]",
        factory: "Callable[[Data], DataProxy]",
    ):
        self._data: "Dict[str, DataProxy]" = {}
        self.factory = factory
        # We may want to limit targets by level, e.g. if we can move data
        # through different layers of memory. That's a use case for GPUs, I
        # believe. Standard users would probably simply define an integer
        if isinstance(target, int):
            # This is not accurate. If we have multiple processors which are
            # using actual memory, we might need to refine this target a bit.
            # But this should give the idea
            self.target = {0: target}
        else:
            self.target = target
        self.size_by_level: "Dict[Level, int]" = defaultdict(int)
        self.keys_by_level: "Dict[Level, Set[str]]" = defaultdict(set)
        for k, value in data.items():
            self[k] = value

    def __getitem__(self, k: str) -> "DataProxy":
        return self._data[k]

    def __setitem__(self, k: str, v: "Data") -> None:
        # Overwriting a key must first retire the old proxy's accounting,
        # otherwise size_by_level/keys_by_level drift upwards forever.
        if k in self._data:
            del self[k]
        proxy = self.factory(v)
        self._data[k] = proxy
        self.size_by_level[proxy.current_level] += proxy.size
        self.keys_by_level[proxy.current_level].add(k)

    def __delitem__(self, k: str) -> None:
        # Keeps per-level bookkeeping consistent on removal.
        proxy = self._data.pop(k)
        self.size_by_level[proxy.current_level] -= proxy.size
        self.keys_by_level[proxy.current_level].discard(k)

    def __iter__(self):
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    async def balance(self):
        """Balance all layers such that they fall below their targets, if possible."""
        # It may also be necessary to refactor this to allow for explicit
        # evict calls. c.f. memory.target vs memory.spill
        # This could be further optimized with various LRU / size / whatever
        # policies.
        for level, target in self.target.items():
            # If we are breaching the desired target for a given level, start
            # archiving and move data to lower levels
            while self.size_by_level[level] > target:
                if not self.keys_by_level[level]:
                    break
                k = self.keys_by_level[level].pop()
                proxy = self._data[k]
                self.size_by_level[level] -= proxy.size
                new_level = await proxy.archive()
                self.size_by_level[new_level] += proxy.size
                self.keys_by_level[new_level].add(k)
                if new_level == level:
                    # The proxy could not be archived (in use, or already at
                    # its deepest level); stop instead of spinning on the
                    # same key forever.
                    break
            # If we are otherwise way below the target, we may unarchive
            # stuff to keep data in hot storage since eventually we'll need
            # it again. The original draft spun in an empty `while: pass`
            # loop here; left as an explicit TODO instead.
            # TODO: unarchive proxies from colder levels while
            #       self.size_by_level[level] < target * 0.7
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment