@@ -35,28 +35,6 @@ def from_provolatile(cls, provolatile: str) -> FunctionVolatility:
3535 raise ValueError (f"Invalid volatility: { provolatile } " )
3636
3737
38- # def normalize_arg(arg: str) -> str:
39- # parts = arg.strip().split(maxsplit=1)
40- # if len(parts) == 2:
41- # name, type_str = parts
42- # norm_type = type_map.get(type_str.lower(), type_str.lower())
43- # # Handle array types
44- # if norm_type.endswith("[]"):
45- # base_type = norm_type[:-2]
46- # norm_base_type = type_map.get(base_type, base_type)
47- # norm_type = f"{norm_base_type}[]"
48- #
49- # return f"{name} {norm_type}"
50- # # Handle case where it might just be the type (e.g., from DROP FUNCTION)
51- # type_str = arg.strip()
52- # norm_type = type_map.get(type_str.lower(), type_str.lower())
53- # if norm_type.endswith("[]"):
54- # base_type = norm_type[:-2]
55- # norm_base_type = type_map.get(base_type, base_type)
56- # norm_type = f"{norm_base_type}[]"
57- # return norm_type
58-
59-
6038@dataclass
6139class Function (base .Function ):
6240 """Describes a PostgreSQL function.
@@ -143,7 +121,7 @@ def normalize(self) -> Function:
143121
144122@dataclass
145123class FunctionParam :
146- name : str
124+ name : str | None
147125 type : str
148126 default : Any | None = None
149127 mode : Literal ["i" , "o" , "b" , "v" , "t" ] | None = None
@@ -185,31 +163,40 @@ def from_unknown(
185163 if isinstance (source_param , tuple ):
186164 return cls (* source_param )
187165
188- name , type = source_param .strip ().split (maxsplit = 1 )
166+ try :
167+ name , type = source_param .strip ().split (maxsplit = 1 )
168+ except ValueError :
169+ name = None
170+ type = source_param .strip ()
171+
189172 return cls (name , type )
190173
191174 def normalize (self ) -> FunctionParam :
192175 type = self .type .lower ()
193176 return replace (
194177 self ,
195- name = self .name .lower (),
178+ name = self .name .lower () if self . name is not None else None ,
196179 mode = self .mode or "i" ,
197180 type = type_map .get (type , type ),
198181 default = str (self .default ) if self .default is not None else None ,
199182 )
200183
201184 def to_sql_create (self ) -> str :
202- result = ""
185+ segments = []
203186 if self .mode :
204- result += {"o" : "OUT " , "b" : "INOUT " , "v" : "VARIADIC " , "t" : "TABLE " }.get (
205- self .mode , ""
206- )
187+ modes = {"o" : "OUT " , "b" : "INOUT " , "v" : "VARIADIC " , "t" : "TABLE " }
188+ mode = modes .get (self .mode )
189+ if mode :
190+ segments .append (mode )
191+
192+ if self .name :
193+ segments .append (self .name )
207194
208- result += f" { self . name } { self .type } "
195+ segments . append ( self .type )
209196
210197 if self .default is not None :
211- result += f" DEFAULT { self .default } "
212- return result
198+ segments . append ( f" DEFAULT { self .default } ")
199+ return " " . join ( segments )
213200
214201 def to_sql_drop (self ) -> str :
215202 return self .type
@@ -245,7 +232,7 @@ def from_unknown(
245232 returns_lower = source .lower ().strip ()
246233 if returns_lower .startswith ("table(" ):
247234 table_return_params = [
248- (p .name , p .type ) for p in (parameters or []) if p .mode == "t"
235+ (p .name , p .type ) for p in (parameters or []) if p .name and p . mode == "t"
249236 ]
250237
251238 if not table_return_params :
0 commit comments