|
1 | 1 | import logging |
2 | | -from typing import Any, Dict, Literal, Optional |
| 2 | +from typing import Any, Dict, Optional |
3 | 3 |
|
4 | 4 | import orjson |
5 | 5 | import pandas as pd |
6 | 6 | from haystack import component |
7 | 7 | from jsonschema.exceptions import ValidationError |
8 | | -from pydantic import BaseModel, Field |
| 8 | +from pydantic import BaseModel |
9 | 9 |
|
10 | 10 | logger = logging.getLogger("wren-ai-service") |
11 | 11 |
|
12 | 12 |
|
13 | | -chart_generation_instructions = """ |
14 | | -1. Please check VEGA-LITE SCHEMA SPECIFICATION to make sure the vega-lite schema is valid. |
15 | | -2. The following are examples of several chart types and their corresponding vega-lite schema: |
16 | | -
|
17 | | -a. Chart Type: Candlestick Chart |
18 | | - When to use: When you want to visualize the price changes of a stock or other financial instrument over time. |
19 | | - Vega-Lite Schema in JSON format(only include the necessary part): |
20 | | - { |
21 | | - "encoding": { |
22 | | - "x": { |
23 | | - "field": "date", |
24 | | - "type": "temporal", |
25 | | - "title": "<title>", |
26 | | - "axis": { |
27 | | - "format": "<format>", |
28 | | - "labelAngle": <labelAngle>, |
29 | | - "title": "<x_axis_title>" |
30 | | - } |
31 | | - }, |
32 | | - "y": { |
33 | | - "type": "quantitative", |
34 | | - "scale": {"zero": false}, |
35 | | - "axis": {"title": "<y_axis_title>"} |
36 | | - }, |
37 | | - "color": { |
38 | | - "condition": { |
39 | | - "test": "datum.open < datum.close", |
40 | | - "value": "#06982d" |
41 | | - }, |
42 | | - "value": "#ae1325" |
43 | | - } |
44 | | - }, |
45 | | - "layer": [ |
46 | | - { |
47 | | - "mark": "rule", |
48 | | - "encoding": { |
49 | | - "y": {"field": "low"}, |
50 | | - "y2": {"field": "high"} |
51 | | - } |
52 | | - }, |
53 | | - { |
54 | | - "mark": "bar", |
55 | | - "encoding": { |
56 | | - "y": {"field": "open"}, |
57 | | - "y2": {"field": "close"} |
58 | | - } |
59 | | - } |
60 | | - ] |
61 | | - } |
62 | | -""" |
63 | | - |
64 | | - |
65 | 13 | def load_custom_theme() -> Dict[str, Any]: |
66 | 14 | with open("src/pipelines/generation/utils/theme_powerbi.json", "r") as f: |
67 | 15 | return orjson.loads(f.read()) |
@@ -175,129 +123,7 @@ def read_vega_lite_schema() -> Dict[str, Any]: |
175 | 123 | return vega_lite_schema |
176 | 124 |
|
177 | 125 |
|
178 | | -class ChartSchema(BaseModel): |
179 | | - class ChartType(BaseModel): |
180 | | - type: Literal["bar", "line", "area", "arc"] |
181 | | - |
182 | | - class ChartEncoding(BaseModel): |
183 | | - field: str |
184 | | - type: Literal["ordinal", "quantitative", "nominal"] |
185 | | - title: str |
186 | | - |
187 | | - title: str |
188 | | - mark: ChartType |
189 | | - encoding: ChartEncoding |
190 | | - |
191 | | - |
192 | | -class TemporalChartEncoding(ChartSchema.ChartEncoding): |
193 | | - type: Literal["temporal"] = Field(default="temporal") |
194 | | - timeUnit: str = Field(default="yearmonth") |
195 | | - |
196 | | - |
197 | | -class LineChartSchema(ChartSchema): |
198 | | - class LineChartMark(BaseModel): |
199 | | - type: Literal["line"] = Field(default="line") |
200 | | - |
201 | | - class LineChartEncoding(BaseModel): |
202 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
203 | | - y: ChartSchema.ChartEncoding |
204 | | - color: ChartSchema.ChartEncoding |
205 | | - |
206 | | - mark: LineChartMark |
207 | | - encoding: LineChartEncoding |
208 | | - |
209 | | - |
210 | | -class MultiLineChartSchema(ChartSchema): |
211 | | - class MultiLineChartMark(BaseModel): |
212 | | - type: Literal["line"] = Field(default="line") |
213 | | - |
214 | | - class MultiLineChartTransform(BaseModel): |
215 | | - fold: list[str] |
216 | | - as_: list[str] = Field(alias="as") |
217 | | - |
218 | | - class MultiLineChartEncoding(BaseModel): |
219 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
220 | | - y: ChartSchema.ChartEncoding |
221 | | - color: ChartSchema.ChartEncoding |
222 | | - |
223 | | - mark: MultiLineChartMark |
224 | | - transform: list[MultiLineChartTransform] |
225 | | - encoding: MultiLineChartEncoding |
226 | | - |
227 | | - |
228 | | -class BarChartSchema(ChartSchema): |
229 | | - class BarChartMark(BaseModel): |
230 | | - type: Literal["bar"] = Field(default="bar") |
231 | | - |
232 | | - class BarChartEncoding(BaseModel): |
233 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
234 | | - y: ChartSchema.ChartEncoding |
235 | | - color: ChartSchema.ChartEncoding |
236 | | - |
237 | | - mark: BarChartMark |
238 | | - encoding: BarChartEncoding |
239 | | - |
240 | | - |
241 | | -class GroupedBarChartSchema(ChartSchema): |
242 | | - class GroupedBarChartMark(BaseModel): |
243 | | - type: Literal["bar"] = Field(default="bar") |
244 | | - |
245 | | - class GroupedBarChartEncoding(BaseModel): |
246 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
247 | | - y: ChartSchema.ChartEncoding |
248 | | - xOffset: ChartSchema.ChartEncoding |
249 | | - color: ChartSchema.ChartEncoding |
250 | | - |
251 | | - mark: GroupedBarChartMark |
252 | | - encoding: GroupedBarChartEncoding |
253 | | - |
254 | | - |
255 | | -class StackedBarChartYEncoding(ChartSchema.ChartEncoding): |
256 | | - stack: Literal["zero"] = Field(default="zero") |
257 | | - |
258 | | - |
259 | | -class StackedBarChartSchema(ChartSchema): |
260 | | - class StackedBarChartMark(BaseModel): |
261 | | - type: Literal["bar"] = Field(default="bar") |
262 | | - |
263 | | - class StackedBarChartEncoding(BaseModel): |
264 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
265 | | - y: StackedBarChartYEncoding |
266 | | - color: ChartSchema.ChartEncoding |
267 | | - |
268 | | - mark: StackedBarChartMark |
269 | | - encoding: StackedBarChartEncoding |
270 | | - |
271 | | - |
272 | | -class PieChartSchema(ChartSchema): |
273 | | - class PieChartMark(BaseModel): |
274 | | - type: Literal["arc"] = Field(default="arc") |
275 | | - |
276 | | - class PieChartEncoding(BaseModel): |
277 | | - theta: ChartSchema.ChartEncoding |
278 | | - color: ChartSchema.ChartEncoding |
279 | | - |
280 | | - mark: PieChartMark |
281 | | - encoding: PieChartEncoding |
282 | | - |
283 | | - |
284 | | -class AreaChartSchema(ChartSchema): |
285 | | - class AreaChartMark(BaseModel): |
286 | | - type: Literal["area"] = Field(default="area") |
287 | | - |
288 | | - class AreaChartEncoding(BaseModel): |
289 | | - x: TemporalChartEncoding | ChartSchema.ChartEncoding |
290 | | - y: ChartSchema.ChartEncoding |
291 | | - |
292 | | - mark: AreaChartMark |
293 | | - encoding: AreaChartEncoding |
294 | | - |
295 | | - |
296 | 126 | class ChartGenerationResults(BaseModel): |
297 | 127 | reasoning: str |
298 | | - chart_type: Optional[ |
299 | | - Literal[ |
300 | | - "line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "" |
301 | | - ] |
302 | | - ] = "" # empty string for no chart |
303 | 128 | chart_schema: dict[str, Any] |
| 129 | + chart_type: Optional[str] = "" # deprecated |
0 commit comments