Commit fc62b41

Changes: Torch > 1.12.0 compat for __torch_function__ as a classmethod
Adds: pin_memory_device as a field for FakeLoader/DataLoader
Adds: foreach param unit tests for optimizers for torch version >= 1.12
Adds: nightly torch comment option for docker-compose.yml
1 parent 26f1232 commit fc62b41

9 files changed, +147 -77 lines

docker-compose.yml (+1)

@@ -13,6 +13,7 @@ services:
     volumes:
       - .:/data/
     environment:
+      # - LIB_INSTALL_TYPE=.[dev] && pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --upgrade #optionally change this locally to .[dev] and install nighty torch
       - LIB_INSTALL_TYPE=. #optionally change this locally to .[dev] to install dev packages as well

   notebook:

fastai/data/load.py (+7 -5)

@@ -28,9 +28,9 @@ def _fn_noops(self, x=None, *args, **kwargs): return x
     _index_sampler,generator,prefetch_factor = Inf.count,None,2
     dataset_kind = _dataset_kind = _DatasetKind.Iterable

-    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers):
-        self.dataset,self.default,self.worker_init_fn = self,d,_wif
-        store_attr('d,pin_memory,num_workers,timeout,persistent_workers')
+    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers,pin_memory_device):
+        self.dataset,self.default,self.worker_init_fn,self.pin_memory_device = self,d,_wif,pin_memory_device
+        store_attr('d,pin_memory,num_workers,timeout,persistent_workers,pin_memory_device')

     def __iter__(self): return iter(self.d.create_batches(self.d.sample()))

@@ -92,7 +92,8 @@ class DataLoader(GetAttr):
         get_idxs sample shuffle_fn do_batch create_batch'.split()
     _default = 'dataset'
     def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,
-                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, **kwargs):
+                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False,
+                 pin_memory_device='', **kwargs):
         if batch_size is not None: bs = batch_size # PyTorch compatibility
         assert not (bs is None and drop_last)
         if indexed is None: indexed = (hasattr(dataset,'__getitem__')
@@ -107,7 +108,8 @@ def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeo
             print("Due to IPython and Windows limitation, python multiprocessing isn't available now.")
             print("So `number_workers` is changed to 0 to avoid getting stuck")
             num_workers = 0
-        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers)
+        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers,
+                                  pin_memory_device=pin_memory_device)

     def __len__(self):
         if self.n is None: raise TypeError
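
For reference, pin_memory_device mirrors the argument of the same name that torch.utils.data.DataLoader gained in PyTorch 1.12. A minimal sketch of passing it through fastai's DataLoader after this change (assumes PyTorch >= 1.12 and a CUDA device; the toy list dataset is illustrative only):

from fastai.data.load import DataLoader

# Pinned host buffers are associated with cuda:0; pinning only happens while iterating.
dl = DataLoader(list(range(64)), bs=16, pin_memory=True, pin_memory_device='cuda:0')
for b in dl:
    pass  # each batch comes back in host memory pinned for cuda:0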

fastai/distributed.py (+3 -1)

@@ -87,7 +87,9 @@ def __init__(self,dl,rank=None,world_size=None):
             pin_memory=dl.pin_memory, timeout=dl.timeout, shuffle=shuffle, drop_last=dl.drop_last, persistent_workers=dl.persistent_workers)
         self.bs,self.device,self.drop_last,self.dataset,fake,self.num_workers,self.offs,self.pin_memory = \
             attrgetter('bs','device','drop_last','dataset','fake_l','num_workers','offs','pin_memory')(self.dl)
-        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers)
+        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout,
+                                  persistent_workers=fake.persistent_workers,
+                                  pin_memory_device=fake.pin_memory_device)

     def _broadcast(self,t,rank):
         "Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call."

fastai/optimizer.py (+6 -1)

@@ -12,6 +12,7 @@
 # Cell
 #nbdev_comment from __future__ import annotations
 from .torch_basics import *
+from packaging import version

 # Cell
 class _BaseOptimizer():
@@ -378,7 +379,11 @@ def set_item_pg(pg, k, v):
     return pg

 # Cell
-pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom', 'betas__1': 'sqr_mom'}
+pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom',
+                  'betas__1': 'sqr_mom'}
+if version.parse(torch.version.__version__)>version.parse('1.12.0'):
+    # Torch>=1.12 has a foreach param
+    pytorch_hp_map = merge(*(pytorch_hp_map,{'foreach': 'foreach'}))

 # Cell
 def _convert_params(o:list) -> list:
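
The 12_optimizer notebook tests further down exercise the same gate; a condensed sketch of the resulting behaviour (assumes fastai is installed; the strict > '1.12.0' comparison follows the commit's own check):

import torch
from packaging import version
from fastai.optimizer import OptimWrapper

opt = OptimWrapper([torch.tensor([1., 2., 3.])], torch.optim.Adam, lr=1e-2, betas=(0.9, 0.99))
hyper = opt.hypers[0]                      # torch hyperparameter names remapped via pytorch_hp_map
assert hyper['mom'] == 0.9 and hyper['sqr_mom'] == 0.99
if version.parse(torch.__version__) > version.parse('1.12.0'):
    assert 'foreach' in hyper              # surfaced by the new 'foreach' mapping entry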

fastai/torch_core.py (+8 -7)

@@ -353,13 +353,14 @@ def __reduce_ex__(self,proto):
     @classmethod
     def register_func(cls, func, *oks): cls._opt[func].append(oks)

-    def __torch_function__(self, func, types, args=(), kwargs=None):
-        if self.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
-        convert=False
-        if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
-        res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
-        if convert: res = convert(res)
-        if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
+        if is_listy(args[0]) and args[0]: dict_objs = [a for a in args[0] if hasattr(a,'__dict__')]
+        else: dict_objs = [a for a in args if hasattr(a,'__dict__')]
+        if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
+        res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
+        if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)
         return res

     def new_tensor(self, size, dtype=None, device=None, requires_grad=False):
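
A minimal sketch of the behaviour the classmethod rewrite is meant to keep intact: metadata stored on a TensorBase instance still survives torch ops, because set_meta now copies it from the first argument that has a __dict__ rather than from self. The img_size attribute here is just illustrative metadata; assumes fastai is installed.

import torch
from fastai.torch_core import TensorBase

t = TensorBase(torch.randn(3))
t.img_size = 224                      # arbitrary metadata held in the tensor's __dict__
out = t + 1                           # dispatches through TensorBase.__torch_function__
assert isinstance(out, TensorBase)    # subclass preserved by the classmethod dispatch
assert out.img_size == 224            # metadata copied over via set_meta(..., as_copy=True)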

nbs/00_torch_core.ipynb (+63 -29)

Large diffs are not rendered by default.

nbs/02_data.load.ipynb (+26 -23)

@@ -97,9 +97,9 @@
 "    _index_sampler,generator,prefetch_factor = Inf.count,None,2\n",
 "    dataset_kind = _dataset_kind = _DatasetKind.Iterable\n",
 "    \n",
-"    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers):\n",
-"        self.dataset,self.default,self.worker_init_fn = self,d,_wif\n",
-"        store_attr('d,pin_memory,num_workers,timeout,persistent_workers')\n",
+"    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers,pin_memory_device):\n",
+"        self.dataset,self.default,self.worker_init_fn,self.pin_memory_device = self,d,_wif,pin_memory_device\n",
+"        store_attr('d,pin_memory,num_workers,timeout,persistent_workers,pin_memory_device')\n",
 "\n",
 "    def __iter__(self): return iter(self.d.create_batches(self.d.sample()))\n",
 "\n",
@@ -274,7 +274,8 @@
 "        get_idxs sample shuffle_fn do_batch create_batch'.split()\n",
 "    _default = 'dataset'\n",
 "    def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,\n",
-"                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, **kwargs):\n",
+"                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False,\n",
+"                 pin_memory_device='', **kwargs):\n",
 "        if batch_size is not None: bs = batch_size # PyTorch compatibility\n",
 "        assert not (bs is None and drop_last)\n",
 "        if indexed is None: indexed = (hasattr(dataset,'__getitem__')\n",
@@ -289,7 +290,8 @@
 "            print(\"Due to IPython and Windows limitation, python multiprocessing isn't available now.\")\n",
 "            print(\"So `number_workers` is changed to 0 to avoid getting stuck\")\n",
 "            num_workers = 0 \n",
-"        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers)\n",
+"        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers,\n",
+"                                  pin_memory_device=pin_memory_device)\n",
 "\n",
 "    def __len__(self):\n",
 "        if self.n is None: raise TypeError\n",
@@ -423,7 +425,7 @@
 {
 "data": {
 "text/plain": [
-"(#40) [0.6220516703202649,0.38347354268972134,0.36273911288359706,0.4314958642862322,0.48170868503127295,0.1755075234373844,0.26036103499878493,0.16037428323147251,0.8350911770957413,0.4347179239514216...]"
+"(#0) []"
 ]
 },
 "execution_count": null,
@@ -448,7 +450,7 @@
 {
 "data": {
 "text/plain": [
-"(#1) [4]"
+"(#11) [4,4,4,4,4,4,4,4,4,4...]"
 ]
 },
 "execution_count": null,
@@ -468,7 +470,7 @@
 {
 "data": {
 "text/plain": [
-"(#21) [4,4,4,4,4,4,4,4,4,4...]"
+"(#10) [4,4,4,4,4,4,4,4,4,4]"
 ]
 },
 "execution_count": null,
@@ -503,7 +505,7 @@
 {
 "data": {
 "text/plain": [
-"(#2) [0.6192763059885179,0.33021254121031707]"
+"(#7) [0.41917548410987093,0.5197100010284023,0.7706771870574884,0.6479314353871329,0.43661079462005437,0.6094953292136542,0.4985993416362957]"
 ]
 },
 "execution_count": null,
@@ -631,18 +633,18 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"CPU times: user 4.27 ms, sys: 1.05 ms, total: 5.32 ms\n",
-"Wall time: 316 ms\n",
-"CPU times: user 12.7 ms, sys: 11.9 ms, total: 24.5 ms\n",
-"Wall time: 197 ms\n",
-"CPU times: user 14.5 ms, sys: 16.2 ms, total: 30.7 ms\n",
-"Wall time: 127 ms\n"
+"CPU times: user 6.97 ms, sys: 0 ns, total: 6.97 ms\n",
+"Wall time: 309 ms\n",
+"CPU times: user 12.2 ms, sys: 12.8 ms, total: 25 ms\n",
+"Wall time: 277 ms\n",
+"CPU times: user 21.9 ms, sys: 23.9 ms, total: 45.7 ms\n",
+"Wall time: 325 ms\n"
 ]
 },
 {
 "data": {
 "text/plain": [
-"(#26) ['r','c','q','n','j','s','l','p','b','y'...]"
+"(#26) ['i','x','t','y','p','u','j','n','f','k'...]"
 ]
 },
 "execution_count": null,
@@ -677,8 +679,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"CPU times: user 12 ms, sys: 22.3 ms, total: 34.3 ms\n",
-"Wall time: 130 ms\n"
+"CPU times: user 19 ms, sys: 22.1 ms, total: 41 ms\n",
+"Wall time: 295 ms\n"
 ]
 }
 ],
@@ -728,9 +730,9 @@
 {
 "data": {
 "text/plain": [
-"[tensor([16, 14, 5, 1, 39, 49, 10, 40, 7, 36, 28, 42, 32, 24, 43, 46, 4, 3,\n",
-" 11, 48, 26, 35, 15, 25, 23, 8, 44, 47, 0, 34, 21, 17]),\n",
-" tensor([45, 41, 6, 20, 38, 19, 29, 37, 13, 18, 2, 27, 30, 12, 33, 22, 9, 31])]"
+"[tensor([29, 10, 19, 23, 36, 5, 31, 1, 40, 22, 24, 47, 34, 9, 2, 33, 39, 30,\n",
+" 42, 49, 14, 17, 18, 35, 15, 27, 13, 48, 3, 32, 4, 8]),\n",
+" tensor([11, 25, 45, 28, 38, 7, 6, 37, 44, 0, 26, 12, 41, 43, 21, 16, 20, 46])]"
 ]
 },
 "execution_count": null,
@@ -831,6 +833,7 @@
 "Converted 21_vision.learner.ipynb.\n",
 "Converted 22_tutorial.imagenette.ipynb.\n",
 "Converted 23_tutorial.vision.ipynb.\n",
+"Converted 24_tutorial.image_sequence.ipynb.\n",
 "Converted 24_tutorial.siamese.ipynb.\n",
 "Converted 24_vision.gan.ipynb.\n",
 "Converted 30_text.core.ipynb.\n",
@@ -839,7 +842,6 @@
 "Converted 33_text.models.core.ipynb.\n",
 "Converted 34_callback.rnn.ipynb.\n",
 "Converted 35_tutorial.wikitext.ipynb.\n",
-"Converted 36_text.models.qrnn.ipynb.\n",
 "Converted 37_text.learner.ipynb.\n",
 "Converted 38_tutorial.text.ipynb.\n",
 "Converted 39_tutorial.transformers.ipynb.\n",
@@ -858,7 +860,7 @@
 "Converted 71_callback.tensorboard.ipynb.\n",
 "Converted 72_callback.neptune.ipynb.\n",
 "Converted 73_callback.captum.ipynb.\n",
-"Converted 74_callback.azureml.ipynb.\n",
+"Converted 74_huggingface.ipynb.\n",
 "Converted 97_test_utils.ipynb.\n",
 "Converted 99_pytorch_doc.ipynb.\n",
 "Converted dev-setup.ipynb.\n",
@@ -868,6 +870,7 @@
 "Converted migrating_ignite.ipynb.\n",
 "Converted migrating_lightning.ipynb.\n",
 "Converted migrating_pytorch.ipynb.\n",
+"Converted migrating_pytorch_verbose.ipynb.\n",
 "Converted ulmfit.ipynb.\n",
 "Converted index.ipynb.\n",
 "Converted quick_start.ipynb.\n",

nbs/12_optimizer.ipynb (+30 -10)

@@ -28,7 +28,8 @@
 "source": [
 "#|export\n",
 "from __future__ import annotations\n",
-"from fastai.torch_basics import *"
+"from fastai.torch_basics import *\n",
+"from packaging import version"
 ]
 },
 {
@@ -456,7 +457,7 @@
 "text/markdown": [
 "<h4 id=\"Optimizer.step\" class=\"doc_header\"><code>Optimizer.step</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
 "\n",
-"> <code>Optimizer.step</code>()\n",
+"> <code>Optimizer.step</code>(**`closure`**=*`None`*)\n",
 "\n",
 "Standard PyTorch API: Update the stats and execute the steppers in on all parameters that have a grad"
 ],
@@ -865,7 +866,7 @@
 {
 "data": {
 "text/markdown": [
-"<h4 id=\"Optimizer.state_dict\" class=\"doc_header\"><code>Optimizer.state_dict</code><a href=\"__main__.py#L33\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
+"<h4 id=\"Optimizer.state_dict\" class=\"doc_header\"><code>Optimizer.state_dict</code><a href=\"__main__.py#L34\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
 "\n",
 "> <code>Optimizer.state_dict</code>()\n",
 "\n",
@@ -891,7 +892,7 @@
 {
 "data": {
 "text/markdown": [
-"<h4 id=\"Optimizer.load_state_dict\" class=\"doc_header\"><code>Optimizer.load_state_dict</code><a href=\"__main__.py#L37\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
+"<h4 id=\"Optimizer.load_state_dict\" class=\"doc_header\"><code>Optimizer.load_state_dict</code><a href=\"__main__.py#L38\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
 "\n",
 "> <code>Optimizer.load_state_dict</code>(**`sd`**:`dict`)\n",
 "\n",
@@ -943,7 +944,7 @@
 {
 "data": {
 "text/markdown": [
-"<h4 id=\"Optimizer.clear_state\" class=\"doc_header\"><code>Optimizer.clear_state</code><a href=\"__main__.py#L29\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
+"<h4 id=\"Optimizer.clear_state\" class=\"doc_header\"><code>Optimizer.clear_state</code><a href=\"__main__.py#L30\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
 "\n",
 "> <code>Optimizer.clear_state</code>()\n",
 "\n",
@@ -1768,7 +1769,11 @@
 "outputs": [],
 "source": [
 "#|export\n",
-"pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom', 'betas__1': 'sqr_mom'}"
+"pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom',\n",
+"                  'betas__1': 'sqr_mom'}\n",
+"if version.parse(torch.version.__version__)>version.parse('1.12.0'):\n",
+"    # Torch>=1.12 has a foreach param\n",
+"    pytorch_hp_map = merge(*(pytorch_hp_map,{'foreach': 'foreach'}))"
 ]
 },
 {
@@ -1847,6 +1852,10 @@
 "test_eq(tst_sgd.opt.param_groups[0]['params'], [tensor(4,5,6)])\n",
 "#Access to hypers\n",
 "_xtra_hypers = dict(dampening=0., nesterov=False, maximize=False)\n",
+"\n",
+"if version.parse(torch.version.__version__)>version.parse('1.12.0'):\n",
+"    _xtra_hypers = merge(*(_xtra_hypers,dict(foreach=None)))\n",
+"    \n",
 "test_eq(tst_sgd.hypers, [{**sgd.hypers[0], **_xtra_hypers}])\n",
 "#Set hypers\n",
 "tst_sgd.set_hyper('mom', 0.95)\n",
@@ -1912,17 +1921,28 @@
 "#|hide\n",
 "#check it works with tuply hp names like in Adam\n",
 "tst_adam = OptimWrapper([tensor([1,2,3])], torch.optim.Adam, lr=1e-2, betas=(0.9, 0.99))\n",
-"test_eq(tst_adam.hypers, [{\n",
-"    'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, 'amsgrad': False, 'maximize':False}])\n",
+"\n",
+"tst_hypers = {'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, \n",
+"              'amsgrad': False, 'maximize':False}\n",
+"if version.parse(torch.version.__version__)>version.parse('1.12.0'):\n",
+"    tst_hypers = merge(*(tst_hypers,dict(foreach=None)))\n",
+"\n",
+"test_eq(tst_adam.hypers, [tst_hypers])\n",
 "tst_adam.set_hyper('mom', 0.95)\n",
 "test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.99))\n",
 "tst_adam.set_hyper('sqr_mom', 0.9)\n",
 "test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.9))\n",
 "\n",
 "tst_adam = torch.optim.Adam([tensor([1,2,3])], lr=1e-2, betas=(0.9, 0.99))\n",
 "tst_adam = OptimWrapper(opt=tst_adam)\n",
-"test_eq(tst_adam.hypers, [{\n",
-"    'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, 'amsgrad': False, 'maximize':False}])\n",
+"\n",
+"tst_hypers = {'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, 'amsgrad': False, \n",
+"              'maximize':False}\n",
+"\n",
+"if version.parse(torch.version.__version__)>version.parse('1.12.0'):\n",
+"    tst_hypers = merge(*(tst_hypers,dict(foreach=None)))\n",
+"\n",
+"test_eq(tst_adam.hypers, [tst_hypers])\n",
 "tst_adam.set_hyper('mom', 0.95)\n",
 "test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.99))\n",
 "tst_adam.set_hyper('sqr_mom', 0.9)\n",

nbs/20a_distributed.ipynb (+3 -1)

@@ -240,7 +240,9 @@
 "            pin_memory=dl.pin_memory, timeout=dl.timeout, shuffle=shuffle, drop_last=dl.drop_last, persistent_workers=dl.persistent_workers)\n",
 "        self.bs,self.device,self.drop_last,self.dataset,fake,self.num_workers,self.offs,self.pin_memory = \\\n",
 "            attrgetter('bs','device','drop_last','dataset','fake_l','num_workers','offs','pin_memory')(self.dl)\n",
-"        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers)\n",
+"        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, \n",
+"                                  persistent_workers=fake.persistent_workers, \n",
+"                                  pin_memory_device=fake.pin_memory_device)\n",
 "        \n",
 "    def _broadcast(self,t,rank):\n",
 "        \"Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call.\"\n",
