Skip to content

Commit 4cb5bd1

Browse files
committed
update
1 parent 154e791 commit 4cb5bd1

File tree

2 files changed

+6
-185
lines changed

2 files changed

+6
-185
lines changed

wren-ai-service/src/pipelines/generation/chart_generation.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,26 @@
1313
ChartDataPreprocessor,
1414
ChartGenerationPostProcessor,
1515
ChartGenerationResults,
16-
chart_generation_instructions,
1716
)
1817

1918
logger = logging.getLogger("wren-ai-service")
2019

2120

2221
def gen_chart_gen_system_prompt() -> str:
23-
return f"""
22+
return """
2423
### TASK ###
2524
2625
You are a data analyst great at generating data visualization using vega-lite! Given the user's question, SQL, sample data and sample column values, you need to think about the best chart type and generate correspondingvega-lite schema in JSON format.
2726
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data and sample column values.
2827
29-
### INSTRUCTIONS ###
30-
31-
{chart_generation_instructions}
32-
3328
### OUTPUT FORMAT ###
3429
3530
Please provide your chain of thought reasoning, and the vega-lite schema in JSON format.
3631
37-
{{
32+
{
3833
"reasoning": <REASON_TO_CHOOSE_THE_SCHEMA_IN_STRING_FORMATTED_IN_LANGUAGE_PROVIDED_BY_USER>,
3934
"chart_schema": <VEGA_LITE_JSON_SCHEMA>
40-
}}
35+
}
4136
"""
4237

4338

Lines changed: 3 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,15 @@
11
import logging
2-
from typing import Any, Dict, Literal, Optional
2+
from typing import Any, Dict, Optional
33

44
import orjson
55
import pandas as pd
66
from haystack import component
77
from jsonschema.exceptions import ValidationError
8-
from pydantic import BaseModel, Field
8+
from pydantic import BaseModel
99

1010
logger = logging.getLogger("wren-ai-service")
1111

1212

13-
chart_generation_instructions = """
14-
1. Please check VEGA-LITE SCHEMA SPECIFICATION to make sure the vega-lite schema is valid.
15-
2. The following are examples of several chart types and their corresponding vega-lite schema:
16-
17-
a. Chart Type: Candlestick Chart
18-
When to use: When you want to visualize the price changes of a stock or other financial instrument over time.
19-
Vega-Lite Schema in JSON format(only include the necessary part):
20-
{
21-
"encoding": {
22-
"x": {
23-
"field": "date",
24-
"type": "temporal",
25-
"title": "<title>",
26-
"axis": {
27-
"format": "<format>",
28-
"labelAngle": <labelAngle>,
29-
"title": "<x_axis_title>"
30-
}
31-
},
32-
"y": {
33-
"type": "quantitative",
34-
"scale": {"zero": false},
35-
"axis": {"title": "<y_axis_title>"}
36-
},
37-
"color": {
38-
"condition": {
39-
"test": "datum.open < datum.close",
40-
"value": "#06982d"
41-
},
42-
"value": "#ae1325"
43-
}
44-
},
45-
"layer": [
46-
{
47-
"mark": "rule",
48-
"encoding": {
49-
"y": {"field": "low"},
50-
"y2": {"field": "high"}
51-
}
52-
},
53-
{
54-
"mark": "bar",
55-
"encoding": {
56-
"y": {"field": "open"},
57-
"y2": {"field": "close"}
58-
}
59-
}
60-
]
61-
}
62-
"""
63-
64-
6513
def load_custom_theme() -> Dict[str, Any]:
6614
with open("src/pipelines/generation/utils/theme_powerbi.json", "r") as f:
6715
return orjson.loads(f.read())
@@ -175,129 +123,7 @@ def read_vega_lite_schema() -> Dict[str, Any]:
175123
return vega_lite_schema
176124

177125

178-
class ChartSchema(BaseModel):
179-
class ChartType(BaseModel):
180-
type: Literal["bar", "line", "area", "arc"]
181-
182-
class ChartEncoding(BaseModel):
183-
field: str
184-
type: Literal["ordinal", "quantitative", "nominal"]
185-
title: str
186-
187-
title: str
188-
mark: ChartType
189-
encoding: ChartEncoding
190-
191-
192-
class TemporalChartEncoding(ChartSchema.ChartEncoding):
193-
type: Literal["temporal"] = Field(default="temporal")
194-
timeUnit: str = Field(default="yearmonth")
195-
196-
197-
class LineChartSchema(ChartSchema):
198-
class LineChartMark(BaseModel):
199-
type: Literal["line"] = Field(default="line")
200-
201-
class LineChartEncoding(BaseModel):
202-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
203-
y: ChartSchema.ChartEncoding
204-
color: ChartSchema.ChartEncoding
205-
206-
mark: LineChartMark
207-
encoding: LineChartEncoding
208-
209-
210-
class MultiLineChartSchema(ChartSchema):
211-
class MultiLineChartMark(BaseModel):
212-
type: Literal["line"] = Field(default="line")
213-
214-
class MultiLineChartTransform(BaseModel):
215-
fold: list[str]
216-
as_: list[str] = Field(alias="as")
217-
218-
class MultiLineChartEncoding(BaseModel):
219-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
220-
y: ChartSchema.ChartEncoding
221-
color: ChartSchema.ChartEncoding
222-
223-
mark: MultiLineChartMark
224-
transform: list[MultiLineChartTransform]
225-
encoding: MultiLineChartEncoding
226-
227-
228-
class BarChartSchema(ChartSchema):
229-
class BarChartMark(BaseModel):
230-
type: Literal["bar"] = Field(default="bar")
231-
232-
class BarChartEncoding(BaseModel):
233-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
234-
y: ChartSchema.ChartEncoding
235-
color: ChartSchema.ChartEncoding
236-
237-
mark: BarChartMark
238-
encoding: BarChartEncoding
239-
240-
241-
class GroupedBarChartSchema(ChartSchema):
242-
class GroupedBarChartMark(BaseModel):
243-
type: Literal["bar"] = Field(default="bar")
244-
245-
class GroupedBarChartEncoding(BaseModel):
246-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
247-
y: ChartSchema.ChartEncoding
248-
xOffset: ChartSchema.ChartEncoding
249-
color: ChartSchema.ChartEncoding
250-
251-
mark: GroupedBarChartMark
252-
encoding: GroupedBarChartEncoding
253-
254-
255-
class StackedBarChartYEncoding(ChartSchema.ChartEncoding):
256-
stack: Literal["zero"] = Field(default="zero")
257-
258-
259-
class StackedBarChartSchema(ChartSchema):
260-
class StackedBarChartMark(BaseModel):
261-
type: Literal["bar"] = Field(default="bar")
262-
263-
class StackedBarChartEncoding(BaseModel):
264-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
265-
y: StackedBarChartYEncoding
266-
color: ChartSchema.ChartEncoding
267-
268-
mark: StackedBarChartMark
269-
encoding: StackedBarChartEncoding
270-
271-
272-
class PieChartSchema(ChartSchema):
273-
class PieChartMark(BaseModel):
274-
type: Literal["arc"] = Field(default="arc")
275-
276-
class PieChartEncoding(BaseModel):
277-
theta: ChartSchema.ChartEncoding
278-
color: ChartSchema.ChartEncoding
279-
280-
mark: PieChartMark
281-
encoding: PieChartEncoding
282-
283-
284-
class AreaChartSchema(ChartSchema):
285-
class AreaChartMark(BaseModel):
286-
type: Literal["area"] = Field(default="area")
287-
288-
class AreaChartEncoding(BaseModel):
289-
x: TemporalChartEncoding | ChartSchema.ChartEncoding
290-
y: ChartSchema.ChartEncoding
291-
292-
mark: AreaChartMark
293-
encoding: AreaChartEncoding
294-
295-
296126
class ChartGenerationResults(BaseModel):
297127
reasoning: str
298-
chart_type: Optional[
299-
Literal[
300-
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
301-
]
302-
] = "" # empty string for no chart
303128
chart_schema: dict[str, Any]
129+
chart_type: Optional[str] = "" # deprecated

0 commit comments

Comments
 (0)