1+ from collections .abc import Iterable
2+ from dotwiz import DotWiz
3+ from dataclasses import dataclass
4+ from typing import Union
5+ import itertools
6+ import funcy as fc
7+ import exrex
8+ import magicattr
9+ import numpy as np
10+ import copy
11+ import datasets
12+ import time
13+
def get_column_names(dataset):
    """Return the set of all column names in a Dataset or DatasetDict.

    ``dataset.column_names`` is a plain list for a single Dataset, but a
    mapping ``split -> [columns]`` for a DatasetDict; in the latter case the
    per-split lists are merged into one flat set.
    """
    cn = dataset.column_names
    if isinstance(cn, dict):  # isinstance instead of type()==dict: subclass-safe
        # DatasetDict: union of every split's column list (stdlib flatten,
        # no need for funcy's deep flatten on a list-of-lists-of-str).
        return set(itertools.chain.from_iterable(cn.values()))
    return set(cn)
20+
21+
def sample_dataset(dataset, n=10000, n_eval=1000, seed=0):
    """Cap split sizes in place: 'train' is limited to ``n`` rows, every
    other split to ``n_eval``; a falsy cap (None/0) leaves that split alone.

    Subsampling goes through ``train_test_split`` so it is seeded and
    therefore reproducible. Returns the (mutated) dataset.
    """
    for split in dataset:
        cap = n if split == 'train' else n_eval
        # Skip when no cap is requested or the split already fits.
        if not cap or len(dataset[split]) <= cap:
            continue
        subsampled = dataset[split].train_test_split(train_size=cap, seed=seed)
        dataset[split] = subsampled['train']
    return dataset
28+
class Preprocessing(DotWiz):
    """Base recipe for mapping a raw HF dataset onto a standardized schema.

    Subclasses declare target fields either as strings (the name of an
    existing column to rename) or as callables (a function of the example
    that computes the column). Calling the instance on a DatasetDict applies
    the whole recipe: split normalization, renames, computed columns,
    label casting, and the pre/post hooks.
    """
    # Canonical split names every processed dataset should end up with.
    default_splits = ('train', 'validation', 'test')

    @staticmethod
    def __map_to_target(x, fn=lambda x: None, target=None):
        # datasets.map helper: writes fn(example) into example[target].
        x[target] = fn(x)
        return x

    def load(self):
        """Download the configured dataset from the hub and preprocess it."""
        return self(datasets.load_dataset(self.dataset_name, self.config_name))

    def __call__(self, dataset, max_rows=None, max_rows_eval=None, seed=0):
        """Apply the recipe to ``dataset`` and return the processed result.

        max_rows / max_rows_eval cap the train / eval split sizes
        (see sample_dataset); seed makes the subsampling reproducible.
        """
        dataset = self.pre_process(dataset)

        # manage splits: move the annotated source split (v) onto its
        # canonical name (k); a falsy annotation drops the canonical split.
        for k, v in zip(self.default_splits, self.splits):
            if v and k != v:
                dataset[k] = dataset[v]
                del dataset[v]
            if k in dataset and not v:  # obfuscated label
                del dataset[k]
        dataset = fix_splits(dataset)

        # Discard any split that is not train/validation/test.
        for k in list(dataset.keys()):
            if k not in self.default_splits:
                del dataset[k]
        dataset = sample_dataset(dataset, max_rows, max_rows_eval, seed=seed)

        # field annotated with a string
        # substitutions: existing column name -> declared field name
        # (splits/dataset_name/config_name are recipe options, not columns).
        substitutions = {v: k for k, v in self.to_dict().items()
                         if (k and k not in {'splits', 'dataset_name', 'config_name'}
                             and type(v) == str and k != v)}

        # Drop pre-existing columns that would collide with a rename target
        # (a rename target that is itself being renamed away is kept).
        dataset = dataset.remove_columns([c for c in substitutions.values() if c in dataset['train'].features and c not in substitutions])
        dataset = dataset.rename_columns(substitutions)

        # field annotated with a function: materialize it as a computed
        # column named after the field (hooks themselves are excluded).
        for k in self.to_dict().keys():
            v = getattr(self, k)
            if callable(v) and k not in {"post_process", "pre_process", "load"}:
                dataset = dataset.map(self.__map_to_target,
                                      fn_kwargs={'fn': v, 'target': k})

        # Keep only the declared fields.
        dataset = dataset.remove_columns(
            get_column_names(dataset) - set(self.to_dict().keys()))
        dataset = fix_labels(dataset)
        dataset = fix_splits(dataset)  # again: label mapping changed
        dataset = self.post_process(dataset)
        return dataset
77+
78+
@dataclass
class cat(Preprocessing):
    """Field helper concatenating several source columns with a separator."""
    fields: Union[str, list] = None   # source column names to concatenate
    separator: str = ' '              # inserted after each field's value

    def __call__(self, example=None):
        # Append the separator to every field value; fields are iterated in
        # reverse so the separator/summation below recombines them.
        y = [np.char.array(example[f]) + sep
             for f, sep in zip(self.fields[::-1], itertools.repeat(self.separator))]
        # NOTE(review): sum(*y) passes the first chararray as the iterable and
        # the second as the start value, relying on '+' meaning string
        # concatenation; it can only accept at most two fields (sum takes two
        # arguments) — confirm intended ordering and arity with callers.
        y = list(sum(*y))
        if len(y) == 1:
            # Single-row result: unwrap the one-element list to a bare string.
            y = y[0]
        return y
91+
92+
def pretty(f):
    """Wrap callable factory ``f`` so its products get a readable repr.

    ``pretty(f)(*args)`` builds ``f(*args)`` and returns a DotWiz-based proxy
    that forwards calls to it while exposing the construction argument(s).
    """
    class pretty_f(DotWiz):
        def __init__(self, *args):
            # Name-mangled to _pretty_f__f_arg on the instance; holds the
            # wrapped callable produced by the factory.
            self.__f_arg = f(*args)
            # NOTE(review): each iteration overwrites 'value', so only the
            # last positional argument is kept — confirm single-arg usage.
            for a in args:
                setattr(self, 'value', a)

        def __call__(self, *args, **kwargs):
            # Delegate directly to the wrapped callable.
            return self.__f_arg(*args, **kwargs)

        def __repr__(self):
            # e.g. "constantly(3)" for pretty(fc.constantly)(3)
            return f"{self.__f_arg.__qualname__.split('.')[0]}({self.value})"
    return pretty_f
106+
class dotgetter:
    """Chainable path builder: ``dotgetter().a.b[0]`` records ``'a.b[0]'``.

    Calling the instance resolves the recorded path against an example
    (wrapped in DotWiz) via magicattr. Hashes by path so instances can be
    used as dict keys.
    """

    def __init__(self, path=''):
        self.path = path

    def __bool__(self):
        # A freshly created getter (empty path) is falsy.
        return bool(self.path)

    def __getattr__(self, k):
        # Extend the path with '.k'; lstrip removes the leading dot on the
        # very first hop (empty initial path).
        extended = (self.path + '.' + k).lstrip('.')
        return self.__class__(extended)

    def __getitem__(self, i):
        indexed = '{}[{}]'.format(self.path, i)
        return self.__class__(indexed)

    def __call__(self, example=None):
        # Resolve the accumulated path against the example.
        return magicattr.get(DotWiz(example), self.path)

    def __hash__(self):
        return hash(self.path)
125+
126+
@dataclass
class ClassificationFields(Preprocessing):
    """Target schema for (pairwise) text classification tasks."""
    sentence1: str = 'sentence1'  # first (or only) input text column
    sentence2: str = 'sentence2'  # optional second input text column
    labels: str = 'labels'        # target label column
132+
@dataclass
class Seq2SeqLMFields(Preprocessing):
    """Target schema for sequence-to-sequence language modeling tasks."""
    prompt: str = 'prompt'  # input text column
    output: str = 'output'  # expected generation column
137+
@dataclass
class TokenClassificationFields(Preprocessing):
    """Target schema for token-level classification (e.g. tagging) tasks."""
    tokens: str = 'tokens'  # column with the token sequence
    labels: str = 'labels'  # column with the per-token labels
142+
@dataclass
class MultipleChoiceFields(Preprocessing):
    """Target schema for multiple-choice tasks.

    Choices may be declared either as a fixed iterable (`choices`, expanded
    to choice0, choice1, ...) or as the name of a column holding a variable-
    length list of options (`choices_list`), which is flattened after the
    base preprocessing.
    """
    inputs: str = 'input'         # question / context column
    choices: Iterable = tuple()   # fixed per-dataset choice columns
    labels: str = 'labels'        # index of the correct choice
    choices_list: str = None      # column with a variable-length options list

    def __post_init__(self):
        # Expand `choices` into individual choice{i} fields, then drop the
        # container so it is not treated as a field itself.
        for i, c in enumerate(self.choices):
            setattr(self, f'choice{i}', c)
        delattr(self, 'choices')
        if not self.choices_list:
            delattr(self, 'choices_list')
        # NOTE(review): __call__ still reads self.choices_list after the
        # delattr above — this relies on DotWiz attribute semantics for
        # deleted keys; confirm it yields a falsy value rather than raising.

    def __call__(self, dataset, *args, **kwargs):
        dataset = super().__call__(dataset, *args, **kwargs)
        if self.choices_list:
            # Keep only examples that actually offer a choice (>1 option).
            dataset = dataset.filter(lambda x: 1 < len(x['choices_list']))
            # Truncate every example to the smallest option count across all
            # splits, capped at 5 options.
            n_options = min([len(x) for k in dataset for x in dataset[k]['choices_list']])
            n_options = min(5, n_options)
            dataset = dataset.map(self.flatten, fn_kwargs={'n_options': n_options})
        return dataset

    @staticmethod
    def flatten(x, n_options=None):
        """Turn choices_list + integer label into choice{i} columns, label 0.

        The correct option is moved to position 0 and the remaining slots are
        filled with the first distractors, so labels become uniformly 0.
        """
        n_neg = n_options - 1 if n_options else None
        choices = x['choices_list']
        label = x['labels']
        neg = choices[:label] + choices[label + 1:]  # distractors only
        pos = choices[label]                          # the gold option
        x['labels'] = 0
        x['choices_list'] = [pos] + neg[:n_neg]
        for i, o in enumerate(x['choices_list']):
            x[f'choice{i}'] = o
        del x['choices_list']
        return x
178+
@dataclass
class SharedFields:
    """Options shared by every task recipe (mixed into the task dataclasses)."""
    # NOTE(review): annotated as list but the default is a tuple — harmless
    # since it is never mutated, but the annotation is inaccurate.
    splits: list = Preprocessing.default_splits
    dataset_name: str = None            # HF hub dataset identifier
    config_name: str = None             # HF hub dataset configuration
    pre_process: callable = lambda x: x   # hook applied before the recipe
    post_process: callable = lambda x: x  # hook applied after the recipe
    #language:str="en"
187+
188+
# Concrete task recipes: combine the shared options (splits, dataset/config
# names, pre/post hooks) with each task-specific field schema.

@dataclass
class Classification(SharedFields, ClassificationFields): pass

@dataclass
class MultipleChoice(SharedFields, MultipleChoiceFields): pass

@dataclass
class TokenClassification(SharedFields, TokenClassificationFields): pass

@dataclass
class Seq2SeqLM(SharedFields, Seq2SeqLMFields): pass
200+
# Convenience helpers for declaring Preprocessing fields.
get = dotgetter()                 # chainable accessor, e.g. get.answers.text[0]
constant = pretty(fc.constantly)  # constant(v): callable returning v, readable repr


def regen(x):
    """List every string generated by the regex ``x`` (via exrex).

    A named def instead of a lambda assignment (PEP 8 E731) — same interface.
    """
    return list(exrex.generate(x))
204+
def name(label_name, classes):
    """Build a field function mapping an example's integer label to its name.

    ``name('label', ['neg', 'pos'])`` returns a callable that looks up
    ``example['label']`` in ``classes``.
    """
    def lookup(example):
        return classes[example[label_name]]
    return lookup
207+
def fix_splits(dataset):
    """Normalize a DatasetDict so it ends with train/validation/test splits.

    Mutates and returns ``dataset``. The branch ORDER below is load-bearing:
    each rule can create or delete splits that later rules test for.
    """

    # Single split under a non-standard name: treat it as the train split.
    if len(dataset) == 1 and "train" not in dataset:
        k = list(dataset)[0]
        dataset['train'] = copy.deepcopy(dataset[k])
        del dataset[k]

    # Some datasets ship an extra auxiliary split; it is never used here.
    if 'auxiliary_train' in dataset:
        del dataset['auxiliary_train']

    if 'test' in dataset:  # manage obfuscated labels
        if 'labels' in dataset['test'].features:
            # A single distinct label value across the whole test split means
            # the true labels are hidden (e.g. all -1): drop the split.
            if len(set(fc.flatten(dataset['test'].to_dict()['labels']))) == 1:
                del dataset['test']

    # No train split: carve one out of validation (50/50).
    if 'validation' in dataset and 'train' not in dataset:
        train_validation = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['train'] = train_validation['train']
        dataset['validation'] = train_validation['test']

    # No test split: carve one out of validation (50/50).
    if 'validation' in dataset and 'test' not in dataset:
        validation_test = dataset['validation'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    # No validation split: take 10% of train.
    if 'train' in dataset and 'validation' not in dataset:
        train_val = dataset['train'].train_test_split(train_size=0.90, seed=0)
        dataset['train'] = train_val['train']
        dataset['validation'] = train_val['test']

    # Still no validation (train was absent above): split test 50/50.
    if 'test' in dataset and 'validation' not in dataset:
        validation_test = dataset['test'].train_test_split(0.5, seed=0)
        dataset['validation'] = validation_test['train']
        dataset['test'] = validation_test['test']

    # Only train exists: 90/5/5 three-way split.
    if 'validation' not in dataset and 'test' not in dataset:
        train_val_test = dataset["train"].train_test_split(train_size=0.90, seed=0)
        val_test = train_val_test["test"].train_test_split(0.5, seed=0)
        dataset["train"] = train_val_test["train"]
        dataset["validation"] = val_test["train"]
        dataset["test"] = val_test["test"]

    return dataset
251+
def fix_labels(dataset, label_key='labels'):
    """Cast a string-valued label column to a ``datasets.ClassLabel`` feature.

    Datasets whose labels are already numeric (or token-level lists) are
    returned untouched; the check looks at the first train label only, as in
    the original implementation. String labels are sorted alphabetically,
    except the NLI triple which keeps its conventional order
    (entailment < neutral < contradiction).
    """
    # isinstance instead of `type(...) in [...]`: also accepts subclasses.
    if isinstance(dataset['train'][label_key][0], (int, list, float)):
        return dataset
    # Labels are plain strings here, so no deep flatten is needed; only the
    # train split is consulted (as before).
    labels = set(dataset['train'][label_key])
    nli_order = ['entailment', 'neutral', 'contradiction']
    if labels == set(nli_order):
        # Precomputed rank table instead of rebuilding it per sort key call.
        rank = {lbl: i for i, lbl in enumerate(nli_order)}
        def order(label):
            return rank.get(label, label)
    else:
        order = str
    labels = sorted(labels, key=order)
    dataset = dataset.cast_column(label_key, datasets.ClassLabel(names=labels))
    return dataset
263+
def concatenate_dataset_dict(l):
    """Concatenate a list of DatasetDict objects sharing same splits and columns."""
    splits = l[0].keys()
    merged = {split: datasets.concatenate_datasets([d[split] for d in l])
              for split in splits}
    return datasets.DatasetDict(merged)