diff --git a/app/api/serializers.py b/app/api/serializers.py
index eddeda1dd..ed61f8205 100644
--- a/app/api/serializers.py
+++ b/app/api/serializers.py
@@ -728,6 +728,11 @@ class PromptGenerationRequestSerializer(serializers.Serializer):
     prompt = serializers.CharField(min_length=1, max_length=10000)
 
 
+class PromptStreamingRequestSerializer(serializers.Serializer):
+    conversation = serializers.ListField()
+    system_prompt = serializers.CharField(required=False)
+
+
 # Used for incoming requests to copy a widget instance. Does NOT map to a model.
 class WidgetInstanceCopyRequestSerializer(serializers.Serializer):
     new_name = serializers.ModelField(
diff --git a/app/api/urls/api_urls.py b/app/api/urls/api_urls.py
index aadd51bfb..d9bd3fa7d 100644
--- a/app/api/urls/api_urls.py
+++ b/app/api/urls/api_urls.py
@@ -38,5 +38,6 @@
     # AI generation
     path("generate/qset/", generation.GenerateQsetView.as_view()),
     path("generate/from_prompt/", generation.GenerateFromPromptView.as_view()),
+    path("generate/streaming/", generation.GenerateStreamingResponseView.as_view()),
    path("lti//instances/", LtiWidgetInstancesInCourseView.as_view()),
 ]
diff --git a/app/api/views/generation.py b/app/api/views/generation.py
index 8cb99d2ad..d151a8cb9 100644
--- a/app/api/views/generation.py
+++ b/app/api/views/generation.py
@@ -1,15 +1,20 @@
+import logging
 import re
 
-from rest_framework.response import Response
-from rest_framework.views import APIView
-
 from api.permissions import CanCreateWidgetInstances
 from api.serializers import (
-    QsetGenerationRequestSerializer,
     PromptGenerationRequestSerializer,
+    PromptStreamingRequestSerializer,
+    QsetGenerationRequestSerializer,
 )
-from core.services.generator_service import GenerationUtil
-from core.message_exception import MsgNoLogin, MsgInvalidInput, MsgFailure
+from core.message_exception import MsgFailure, MsgInvalidInput, MsgNoLogin
+from django.http import StreamingHttpResponse
+from generation.core import GenerationCore
+from generation.factory import GenerationDriverFactory
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+logger = logging.getLogger(__name__)
 
 
 class GenerateQsetView(APIView):
@@ -18,7 +23,7 @@ class GenerateQsetView(APIView):
 
     def post(self, request):
         # Check if generation is available
-        if not GenerationUtil.is_enabled():
+        if not GenerationCore.is_enabled():
             raise MsgFailure(
                 msg="AI generation is not enabled on this instance of Materia"
             )
@@ -52,19 +57,20 @@ def post(self, request):
         if num_questions > 32:
             num_questions = 32
 
-        # Generate qset
-        result = GenerationUtil.generate_qset(
+        # Get the appropriate driver and generate qset
+        driver = GenerationDriverFactory.get_driver()
+        result = driver.generate_qset(
             widget=widget,
-            instance=widget_instance,
             topic=topic,
             num_questions=num_questions,
             build_off_existing=build_off_existing,
+            instance=widget_instance,
         )
 
         # Return generated qset
         return Response(
             {
-                **result,
+                "qset": result,
                 "title": topic,
             }
         )
@@ -76,7 +82,7 @@ class GenerateFromPromptView(APIView):
 
     def post(self, request):
         # Check if generation is available
-        if not GenerationUtil.is_enabled():
+        if not GenerationCore.is_enabled():
             raise MsgFailure(
                 msg="AI generation is not enabled on this instance of Materia"
             )
@@ -87,11 +93,40 @@ def post(self, request):
 
         prompt = request_serializer.validated_data["prompt"]
 
-        # Perform generation
-        result = GenerationUtil.generate_from_prompt(prompt)
+        # Get the appropriate driver and perform generation
+        driver = GenerationDriverFactory.get_driver()
+        result = driver.query_sync(prompt)
+
         return Response(
             {
                 "success": True,
                 "response": result,
             }
         )
+
+
+class GenerateStreamingResponseView(APIView):
+    http_method_names = ["post"]
+    permission_classes = [CanCreateWidgetInstances]
+
+    def post(self, request):
+        # Check if generation is available
+        if not GenerationCore.is_enabled():
+            raise MsgFailure(
+                msg="AI generation is not enabled on this instance of Materia"
+            )
+
+        request_serializer = PromptStreamingRequestSerializer(data=request.data)
+        request_serializer.is_valid(raise_exception=True)
+
+        messages = request_serializer.validated_data["conversation"]
+        # system_prompt is optional on the serializer, so use .get() to avoid a KeyError
+        system_prompt = request_serializer.validated_data.get("system_prompt")
+
+        driver = GenerationDriverFactory.get_driver()
+        response = StreamingHttpResponse(
+            driver.generate_prompt_stream(messages, system_prompt),
+            content_type="text/event-stream",
+        )
+        response["Cache-Control"] = "no-cache"
+        response["X-Accel-Buffering"] = "no"
+        return response
diff --git a/app/core/services/boto_session_service.py b/app/core/services/boto_session_service.py
new file mode 100644
index 000000000..8c006c97e
--- /dev/null
+++ b/app/core/services/boto_session_service.py
@@ -0,0 +1,39 @@
+import logging
+
+import boto3
+from django.conf import settings
+
+logger = logging.getLogger(__name__)
+
+
+class BotoSessionService:
+
+    # TODO: move client-related credential configs to another configuration location
+    # instead of just s3?
+
+    @staticmethod
+    def get_session():
+        # Configure credentials depending on whether we're providing them from env or Amazon's IMDSv2 service
+        # IMDS is HIGHLY recommended for prod usage on AWS
+        session = None
+        s = settings.AWS_SETTINGS
+        if s["credential_provider"] == "imds":
+            # Credentials are sourced from the EC2 instance's IAM role
+            session = boto3.Session()
+        elif s["credential_provider"] == "env":
+            session_config = {
+                "region_name": s["region"],
+            }
+            if s["profile"] is not None:
+                session_config["profile_name"] = s["profile"]
+            else:
+                session_config["aws_access_key_id"] = s["key"]
+                session_config["aws_secret_access_key"] = s["secret_key"]
+                session_config["aws_session_token"] = s["session_token"]
+            session = boto3.Session(**session_config)
+        else:
+            raise Exception(
+                "boto3: Failed to determine credential provider. Did you set the appropriate environment variable?"
+            )
+
+        return session
diff --git a/app/core/services/generator_service.py b/app/core/services/generator_service.py
deleted file mode 100644
index 129b6170c..000000000
--- a/app/core/services/generator_service.py
+++ /dev/null
@@ -1,346 +0,0 @@
-import json
-import logging
-from datetime import datetime
-
-from core.models import Widget, WidgetInstance
-from django.conf import settings
-from openai import NOT_GIVEN, OpenAI, OpenAIError
-from openai.types.chat import ChatCompletion
-from core.message_exception import (
-    MsgInvalidInput,
-    MsgNotFound,
-    MsgFailure,
-    MsgException,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class GenerationUtil:
-    client = None
-
-    @staticmethod
-    def generate_qset(
-        widget: Widget,
-        topic: str,
-        num_questions: int,
-        build_off_existing: bool,
-        include_images: bool = False,
-        instance: WidgetInstance = None,
-    ) -> dict:
-        # Check if generation is enabled
-        if not GenerationUtil.is_enabled():
-            raise MsgFailure(msg="Generation is not enabled")
-
-        # Check if image generation is allowed. Overrides what parameter says.
- if not settings.AI_GENERATION["ALLOW_IMAGES"]: - include_images = False - - # Get demo for widget - widget_demo_id = widget.metadata.get("demo") - widget_demo = WidgetInstance.objects.filter(id=widget_demo_id).first() - if widget_demo is None: - raise MsgNotFound() - - # Prepare a few variables - about = widget.metadata.get("about") - qset_version = 1 - - # Grab custom prompt from the widget engine if it's available - custom_engine_prompt = ( - widget.metadata["custom_engine_prompt"] - if "custom_engine_prompt" in widget.metadata - else None - ) - - # Start logging time - start_time = datetime.now() - time_elapsed_seconds = 0 - - ################### - # Assemble prompt # - ################### - - # build_off_existing is set. Append questions to an existing qset. The instance must have been previously saved. - if build_off_existing: - # Validate instance - qset = instance.get_latest_qset() - if instance is None: - raise MsgInvalidInput( - msg="Requires a previously saved instance to build from" - ) - if not qset.data: - raise MsgFailure(msg="No existing question set found") - if qset.version: - qset_version = qset.version - - qset_encoded = json.dumps(qset.get_data()) - - prompt_text = ( - f"{widget.name} is a 'widget', an interactive piece of educational web content described " - f"as:'{about}'. Using the exact same json format of the following question set, without " - f"changing any field keys or data types and without changing any of the existing questions, " - f"generate {num_questions} more questions and add them to the existing question set. The " - f"name of this particular instance of {widget.name} is {instance.name} and the new " - f"questions must be based on this topic: '{topic}'. Return only the JSON for the resulting " - f"question set." - ) - - # Generate image descriptions, if requested - if include_images: - prompt_text += ( - " In every asset or assets object in each new question, add a field titled " - "'description' that best describes the image within the answer or question's context, " - "unless otherwise specified later on in this prompt. Do not generate descriptions " - "that would violate OpenAI's image generation safety system and do not use real " - "names. IDs must be null." - ) - else: - prompt_text += ( - " Leave the asset field empty or otherwise equivalent to asset fields in questions " - "with no associated asset. IDs must be null." - ) - - # Insert custom engine prompt, if it exists - if custom_engine_prompt: - prompt_text += ( - f"Lastly, the following instructions apply to the {widget.name} widget specifically, " - f"and supersede earlier instructions where applicable: {custom_engine_prompt}" - ) - - # Insert existing qset - prompt_text += f"\n{qset_encoded}" - - # Building a brand new qset - else: - # Validate/process demo - qset = widget_demo.get_latest_qset() - if not qset: - raise MsgNotFound( - msg="Unable to locate demo question set for widget engine" - ) - if qset.version: - qset_version = qset.version - - qset_encoded = json.dumps(qset.get_data()) - - prompt_text = ( - f"{widget.name} is a 'widget', an interactive piece of educational web content described " - f"as: '{about}'. The following is a 'demo' question set for the widget titled " - f"'{widget_demo.name}'. Using the same json format as the demo question set, and without " - f"changing any field keys or data types, return only the JSON for a question set based on " - f"this topic: '{topic}'. Ignore the topic of the demo contents entirely. 
Replace the " - f"relevant field values with generated values. Generate a total {num_questions} of " - f"questions. IDs must be NULL." - ) - - # Generate image descriptions, if requested - if include_images: - prompt_text += ( - " In every asset or assets object in each new question, add a field titled " - "'description' that best describes the image within the answer or question's context, " - "unless otherwise specified later on in this prompt. Do not generate descriptions " - "that would violate OpenAI's image generation safety system and do not use real " - "names. IDs must be null." - ) - else: - prompt_text += ( - "Asset fields associated with media (image, audio, or video) should be left blank. For " - "text assets, or if the 'materiaType' of an asset is 'text', create a field titled " - "'value' with the text inside the asset object." - ) - - # Insert custom engine prompt, if it exists - if custom_engine_prompt: - prompt_text += ( - f" Lastly, the following instructions apply to the {widget.name} widget specifically, " - f"and supersede earlier instructions where applicable: {custom_engine_prompt}" - ) - - # Insert qset - prompt_text += f"\n{qset_encoded}" - - # Send the prompt to the generative AI provider - try: - result = GenerationUtil._query(prompt_text, "json") - time_elapsed_seconds = datetime.now().timestamp() - start_time.timestamp() - except MsgException as e: - logger.error( - f"Error generating question set:\n" - f"- Widget: {widget.name}\n" - f"- Date: {datetime.now()}\n" - f"- Time to complete (seconds): {time_elapsed_seconds}\n" - f"- Number of questions asked to generate: {num_questions}", - exc_info=True, - ) - raise e - - # A qset was received - decode it - content = result.choices[0].message.content - qset = json.loads(content) - logger.info("Generated question set received: %s", qset) - - # Log - if settings.AI_GENERATION["LOG_STATS"]: - logger.debug( - f"Successfully generated question set:\n" - f"- Widget: {widget.name}\n" - f"- Date: {datetime.now()}\n" - f"- Time to complete (seconds): {time_elapsed_seconds}\n" - f"- Number of questions asked to generate: {num_questions}\n" - f"- Included images: {include_images}\n" - f"- Prompt tokens: {result.usage.prompt_tokens}\n" - f"- Completion tokens: {result.usage.completion_tokens}\n" - f"- Total tokens: {result.usage.total_tokens}\n" - ) - - # Done! - return { - "qset": qset, - "version": qset_version, - } - - @staticmethod - def generate_from_prompt(prompt: str) -> str: - # Check if generation is enabled - if not GenerationUtil.is_enabled(): - raise MsgFailure(msg="Generation is not enabled") - - # Do query - try: - result = GenerationUtil._query(prompt, "text") - except MsgException as e: - logger.error( - "GENERATION UTIL: Error while generation prompt:\n - Prompt: %s", - prompt, - exc_info=True, - ) - raise e - - return result.choices[0].message.content - - @staticmethod - def is_enabled() -> bool: - return bool(settings.AI_GENERATION["ENABLED"]) - - @staticmethod - def _query(prompt: str, response_format: str = "json") -> ChatCompletion: - # Get client - client = GenerationUtil._get_client() - if client is None: - raise MsgFailure(msg="Failed to initialize generation client") - - # Process response format - # TODO in the future, maybe look into supporting json_schema? 
it looks neat - match response_format: - case "json": - response_format = {"type": "json_object"} - case "text": - response_format = {"type": "text"} - case _: - response_format = NOT_GIVEN - - try: - completion = client.chat.completions.create( - model=settings.AI_GENERATION["MODEL"], - messages=[{"role": "user", "content": prompt}], - max_tokens=16000, - frequency_penalty=0, - presence_penalty=0, - temperature=1, - top_p=1, - response_format=response_format, - ) - except OpenAIError: - logger.error( - "GENERATION ERROR: Client threw an error while attempting completion.", - exc_info=True, - ) - raise MsgFailure(msg="Client threw an error while attempting completion") - except Exception: - logger.error( - "GENERATION ERROR: Unknown error occurred while attempting completion.", - exc_info=True, - ) - raise MsgFailure(msg="Unknown error occurred while attempting completion") - - # Check for refusal - message = completion.choices[0].message - if message.refusal is not None: - logger.error( - "GENERATION ERROR: Provider actively refused to run completion. Reason given: '%s'", - message.refusal, - ) - raise MsgFailure(msg="Provider actively refused to run completion") - - return completion - - @staticmethod - def _get_client(): - # Initialize the client if not loaded - if GenerationUtil.client is None and GenerationUtil.is_enabled(): - GenerationUtil.client = GenerationUtil._initialize_client() - return GenerationUtil.client - - @staticmethod - def _initialize_client(): - client = None - - # Check if provider is provided (lol) - if not settings.AI_GENERATION["PROVIDER"]: - logger.error( - "GENERATION ERROR: Question generation provider config missing." - ) - return None - - # Set up based on type of provider - # AZURE OPENAI - if settings.AI_GENERATION["PROVIDER"] == "azure_openai": - api_key = settings.AI_GENERATION["API_KEY"] - endpoint = settings.AI_GENERATION["ENDPOINT"] - model = settings.AI_GENERATION["MODEL"] - - if not api_key or not endpoint or not model: - logger.error( - "GENERATION ERROR: Azure OpenAI question generation configs missing." - ) - return None - - try: - client = OpenAI( - api_key=api_key, - base_url=endpoint + "openai/v1", - ) - except Exception: - logger.error( - "GENERATION ERROR: Failed to initialize Azure OpenAI client", - exc_info=True, - ) - - # OPENAI - elif settings.AI_GENERATION["PROVIDER"] == "openai": - api_key = settings.AI_GENERATION["API_KEY"] - model = settings.AI_GENERATION["MODEL"] - - if not api_key or not model: - logger.error( - "GENERATION ERROR: OpenAI Platform question generation configs missing." - ) - return None - - try: - client = OpenAI(api_key=api_key) - except Exception: - logger.error( - "GENERATION ERROR: Failed to initialize OpenAI client", - exc_info=True, - ) - - # NOT A SUPPORTED PROVIDER - else: - logger.error( - "GENERATION ERROR: Question generation provider config invalid." 
-            )
-            return None
-
-        return client
diff --git a/app/generation/__init__.py b/app/generation/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app/generation/bedrock.py b/app/generation/bedrock.py
new file mode 100644
index 000000000..93d56c42c
--- /dev/null
+++ b/app/generation/bedrock.py
@@ -0,0 +1,139 @@
+import json
+import logging
+import threading
+from typing import Generator
+
+from core.message_exception import MsgFailure
+from core.models import Widget, WidgetInstance
+from core.services.boto_session_service import BotoSessionService
+from django.conf import settings
+from generation.core import GenerationCore, GenerationDriver
+
+logger = logging.getLogger(__name__)
+
+_bedrock_cache_lock = threading.Lock()
+_bedrock_client_cache = None
+_bedrock_resource_cache = None
+_bedrock_cache_config = None
+
+
+class BedrockGenerationDriver(GenerationDriver):
+
+    @staticmethod
+    def get_client():
+        global _bedrock_client_cache, _bedrock_resource_cache, _bedrock_cache_config, _bedrock_cache_lock
+        s = settings.AI_GENERATION
+
+        # TODO what settings should be hashed to the config?
+        current_config = s.get("MODEL")
+
+        # Thread-safe cache check and initialization
+        with _bedrock_cache_lock:
+            # If config changed, invalidate the cached client so it gets rebuilt
+            if _bedrock_cache_config != current_config:
+                _bedrock_cache_config = current_config
+                _bedrock_client_cache = None
+                _bedrock_resource_cache = None
+
+            # Return cached client/resource if available
+            if _bedrock_client_cache is not None:
+                return _bedrock_client_cache
+
+            try:
+                session = BotoSessionService.get_session()
+            except Exception:
+                logger.error("Boto3: Failed to create session", exc_info=True)
+                raise
+
+            _bedrock_client_cache = session.client(
+                "bedrock-runtime", region_name="us-east-1"
+            )
+            return _bedrock_client_cache
+
+    @staticmethod
+    def query_sync(prompt: str, response_format: str = "json") -> str:
+
+        # Get Bedrock client
+        client = BedrockGenerationDriver.get_client()
+
+        request_body = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "messages": [{"role": "user", "content": prompt}],
+            "max_tokens": 4096,
+        }
+
+        response = client.invoke_model(
+            modelId=settings.AI_GENERATION["MODEL"], body=json.dumps(request_body)
+        )
+
+        response_body = json.loads(response["body"].read())
+        generated_text = response_body.get("content", [{}])[0].get("text", "")
+
+        return generated_text
+
+    @staticmethod
+    def query_streaming(
+        messages: list, system_prompt=None
+    ) -> Generator[str, None, None]:
+
+        client = BedrockGenerationDriver.get_client()
+
+        request_body = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "messages": messages,
+            "max_tokens": 4096,
+        }
+
+        if system_prompt is not None:
+            request_body["system"] = system_prompt
+
+        response = client.invoke_model_with_response_stream(
+            modelId=settings.AI_GENERATION["MODEL"], body=json.dumps(request_body)
+        )
+
+        for event in response["body"]:
+            chunk = json.loads(event["chunk"]["bytes"])
+            if chunk.get("type") == "content_block_delta":
+                delta = chunk.get("delta", {})
+                if delta.get("type") == "text_delta":
+                    yield f"data: {json.dumps({'text': delta['text']})}\n\n"
+
+        yield "data: [DONE]\n\n"
+
+    @staticmethod
+    def generate_prompt_stream(messages: list, system_prompt: str | None):
+        yield from BedrockGenerationDriver.query_streaming(messages, system_prompt)
+
+    @staticmethod
+    def generate_qset(
+        widget: Widget,
+        topic: str,
+        num_questions: int,
+        build_off_existing: bool,
+        instance: WidgetInstance = None,
+    ) -> dict:
+
+        prompt = GenerationCore.generate_qset_prompt(
+            widget, topic, num_questions, build_off_existing, instance
+        )
+
+        try:
+            result = BedrockGenerationDriver.query_sync(prompt)
+            parsed = json.loads(result)
+
+            return parsed
+
+        except Exception:
+            logger.error(
+                "Failed to generate qset from prompt", exc_info=True
+            )
+
+            raise MsgFailure()
+
+    @staticmethod
+    def generate_from_prompt(prompt: str) -> str:
+        try:
+            result = BedrockGenerationDriver.query_sync(prompt)
+        except Exception as e:
+            logger.error("Generation failure for prompt %s", prompt, exc_info=True)
+            raise e
+
+        return result
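For reference, the frames yielded by query_streaming/generate_prompt_stream above are plain SSE lines — "data: {\"text\": ...}" chunks followed by a terminating "data: [DONE]" — which is what the frontend helper added later in this diff parses. A minimal server-side consumer might look like the sketch below (sketch only; it assumes Django settings are loaded and AI_GENERATION points at the bedrock provider, and the conversation content is illustrative):

    import json

    from generation.factory import GenerationDriverFactory

    driver = GenerationDriverFactory.get_driver()
    conversation = [{"role": "user", "content": "Summarize the water cycle in one sentence."}]

    full_text = ""
    for frame in driver.generate_prompt_stream(conversation, system_prompt=None):
        payload = frame[len("data: "):].strip()  # frames look like "data: {...}\n\n"
        if payload == "[DONE]":
            break
        full_text += json.loads(payload).get("text", "")

    print(full_text)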
diff --git a/app/generation/core.py b/app/generation/core.py
new file mode 100644
index 000000000..1c52a85e8
--- /dev/null
+++ b/app/generation/core.py
@@ -0,0 +1,157 @@
+import json
+from abc import ABC, abstractmethod
+
+from core.message_exception import MsgFailure, MsgInvalidInput, MsgNotFound
+from core.models import Widget, WidgetInstance
+from django.conf import settings
+
+
+class GenerationCore:
+
+    @staticmethod
+    def is_enabled() -> bool:
+        return bool(settings.AI_GENERATION["ENABLED"])
+
+    @staticmethod
+    def generate_qset_prompt(
+        widget: Widget,
+        topic: str,
+        num_questions: int,
+        build_off_existing: bool,
+        instance: WidgetInstance = None,
+    ) -> str:
+
+        if not GenerationCore.is_enabled():
+            raise MsgFailure(msg="Generation is not enabled")
+
+        # Get demo for widget
+        widget_demo_id = widget.metadata.get("demo")
+        widget_demo = WidgetInstance.objects.filter(id=widget_demo_id).first()
+        if widget_demo is None:
+            raise MsgNotFound()
+
+        # Prepare a few variables
+        about = widget.metadata.get("about")
+        # qset_version = 1
+
+        # Grab custom prompt from the widget engine if it's available
+        custom_engine_prompt = (
+            widget.metadata["custom_engine_prompt"]
+            if "custom_engine_prompt" in widget.metadata
+            else None
+        )
+
+        # # Start logging time
+        # start_time = datetime.now()
+        # time_elapsed_seconds = 0
+
+        if build_off_existing:
+            if instance is None:
+                raise MsgInvalidInput(
+                    msg="Requires a previously saved instance to build from"
+                )
+            qset = instance.get_latest_qset()
+            if not qset.data:
+                raise MsgFailure(msg="No existing question set found")
+            # if qset.version:
+            #     qset_version = qset.version
+
+            qset_encoded = json.dumps(qset.get_data())
+
+            prompt_text = (
+                f"{widget.name} is a 'widget', an interactive piece of educational web content described "
+                f"as: '{about}'. Using the exact same json format of the following question set, without "
+                f"changing any field keys or data types and without changing any of the existing questions, "
+                f"generate {num_questions} more questions and add them to the existing question set. The "
+                f"name of this particular instance of {widget.name} is {instance.name} and the new "
+                f"questions must be based on this topic: '{topic}'. Return only the JSON for the resulting "
+                f"question set."
+                " Leave the asset field empty or otherwise equivalent to asset fields in questions "
+                "with no associated asset. IDs must be null."
+            )
+
+        else:
+            # Validate/process demo
+            qset = widget_demo.get_latest_qset()
+            if not qset:
+                raise MsgNotFound(
+                    msg="Unable to locate demo question set for widget engine"
+                )
+            # if qset.version:
+            #     qset_version = qset.version
+
+            qset_encoded = json.dumps(qset.get_data())
+
+            prompt_text = (
+                f"{widget.name} is a 'widget', an interactive piece of educational web content described "
+                f"as: '{about}'. The following is a 'demo' question set for the widget titled "
+                f"'{widget_demo.name}'. Using the same json format as the demo question set, and without "
+                f"changing any field keys or data types, return only the JSON for a question set based on "
+                f"this topic: '{topic}'. Ignore the topic of the demo contents entirely. Replace the "
+                f"relevant field values with generated values. Generate a total of {num_questions} "
+                f"questions. IDs must be NULL."
+                " Asset fields associated with media (image, audio, or video) should be left blank. For "
+                "text assets, or if the 'materiaType' of an asset is 'text', create a field titled "
+                "'value' with the text inside the asset object."
+                " Please return ONLY the json object without any formatting or qualifying characters such as "
+                "backticks. The json object is ideally interpretable by python's json library without any "
+                "kind of string sanitization."
+            )
+
+        # Insert custom engine prompt, if it exists
+        if custom_engine_prompt:
+            prompt_text += (
+                f" Lastly, the following instructions apply to the {widget.name} widget specifically, "
+                f"and supersede earlier instructions where applicable: {custom_engine_prompt}"
+            )
+
+        # Insert qset
+        prompt_text += f"\n{qset_encoded}"
+
+        return prompt_text
+
+
+class GenerationDriver(ABC):
+    """Abstract base class defining the interface all generation drivers must implement"""
+
+    @staticmethod
+    @abstractmethod
+    def get_client():
+        """
+        Get or create the provider-specific client.
+        Returns: Provider client instance
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def query_sync(prompt: str, response_format: str = "json") -> str:
+        """
+        Perform a synchronous (non-streaming) query for the given prompt.
+        Returns: the result text from the model.
+        """
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def generate_qset(
+        widget: Widget,
+        topic: str,
+        num_questions: int,
+        build_off_existing: bool,
+        instance: WidgetInstance = None,
+    ) -> dict:
+        """
+        Generate a question set using the provider's API.
+
+        Args:
+            widget: The widget to generate questions for
+            topic: The topic to generate questions about
+            num_questions: Number of questions to generate
+            build_off_existing: Whether to add to existing questions
+            instance: Existing widget instance (required if build_off_existing=True)
+
+        Returns:
+            dict: Generated question set data
+        """
+        pass
diff --git a/app/generation/factory.py b/app/generation/factory.py
new file mode 100644
index 000000000..ffc63c9d1
--- /dev/null
+++ b/app/generation/factory.py
@@ -0,0 +1,61 @@
+import logging
+from importlib import import_module
+
+from core.message_exception import MsgFailure
+from django.conf import settings
+from generation.core import GenerationDriver
+
+logger = logging.getLogger(__name__)
+
+
+class GenerationDriverFactory:
+
+    _drivers = {
+        "bedrock": "generation.bedrock.BedrockGenerationDriver",
+    }
+
+    _cached_driver = None
+    _cached_provider = None
+
+    @classmethod
+    def get_driver(cls) -> GenerationDriver:
+
+        provider = settings.AI_GENERATION.get("PROVIDER")
+
+        if cls._cached_driver is not None and cls._cached_provider == provider:
+            return cls._cached_driver
+
+        if not provider:
+            logger.error("AI_GENERATION PROVIDER not configured")
+            raise MsgFailure(msg="AI generation provider not configured")
+
+        # Normalize provider name (case-insensitive)
+        provider_key = provider.lower()
+
+        # Validate provider is supported
+        if provider_key not in cls._drivers:
+            logger.error(f"Unknown AI generation provider: {provider}")
+            raise MsgFailure(msg=f"Unknown AI generation provider: {provider}")
+
+        try:
+            driver_path = cls._drivers[provider_key]
+            module_path, class_name = driver_path.rsplit(".", 1)
+            module = import_module(module_path)
+            driver_class = getattr(module, class_name)
+
+            cls._cached_driver = driver_class
+            cls._cached_provider = provider
+
+            return driver_class
+
+        except (ImportError, AttributeError) as e:
+            logger.error(
+                f"Failed to load generation driver for {provider}: {e}", exc_info=True
+            )
+            raise MsgFailure(msg="Failed to initialize AI generation driver")
+
+    @classmethod
+    def clear_cache(cls):
+        """Clear the cached driver instance. Useful for testing or config changes."""
+        cls._cached_driver = None
+        cls._cached_provider = None
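For orientation on how another provider would plug into this factory: a driver subclasses GenerationDriver, implements the same static methods as the Bedrock driver, and gets one more entry in _drivers. A minimal sketch using a hypothetical test-only driver follows (the module path, class name, and canned values are illustrative, not part of this patch):

    # generation/fake.py (hypothetical)
    import json
    from typing import Generator

    from generation.core import GenerationCore, GenerationDriver


    class FakeGenerationDriver(GenerationDriver):
        """Returns canned output; useful for exercising the views without a real provider."""

        @staticmethod
        def get_client():
            return None  # no external client needed

        @staticmethod
        def query_sync(prompt: str, response_format: str = "json") -> str:
            return json.dumps({"items": []}) if response_format == "json" else "ok"

        @staticmethod
        def generate_prompt_stream(messages: list, system_prompt=None) -> Generator[str, None, None]:
            # Mimic the SSE frame shape the Bedrock driver emits
            yield f"data: {json.dumps({'text': 'ok'})}\n\n"
            yield "data: [DONE]\n\n"

        @staticmethod
        def generate_qset(widget, topic, num_questions, build_off_existing, instance=None) -> dict:
            prompt = GenerationCore.generate_qset_prompt(
                widget, topic, num_questions, build_off_existing, instance
            )
            return json.loads(FakeGenerationDriver.query_sync(prompt))

    # Registration would be one more entry in GenerationDriverFactory._drivers, e.g.:
    #     "fake": "generation.fake.FakeGenerationDriver"

Calling GenerationDriverFactory.clear_cache() between tests keeps the cached driver class from leaking across provider settings.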
diff --git a/app/materia/settings/aws.py b/app/materia/settings/aws.py
new file mode 100644
index 000000000..2a79afb54
--- /dev/null
+++ b/app/materia/settings/aws.py
@@ -0,0 +1,20 @@
+import os
+
+AWS_SETTINGS = {
+    "key": os.environ.get(
+        "AWS_ACCESS_KEY_ID", os.environ.get("ASSET_STORAGE_S3_KEY", "KEY")
+    ),
+    "secret_key": os.environ.get(
+        "AWS_SECRET_ACCESS_KEY",
+        os.environ.get("ASSET_STORAGE_S3_SECRET", "SECRET"),
+    ),
+    "session_token": os.environ.get("AWS_SESSION_TOKEN"),
+    "credential_provider": os.environ.get(
+        "AWS_CREDENTIAL_PROVIDER",
+        os.environ.get("ASSET_STORAGE_S3_CREDENTIAL_PROVIDER", "env"),
+    ),
+    "region": os.environ.get(
+        "AWS_REGION", os.environ.get("ASSET_STORAGE_S3_REGION", "us-east-1")
+    ),
+    "profile": os.environ.get("AWS_PROFILE", None),
+}
diff --git a/app/materia/settings/base.py b/app/materia/settings/base.py
index 899d9e198..cdcf47893 100644
--- a/app/materia/settings/base.py
+++ b/app/materia/settings/base.py
@@ -8,6 +8,7 @@
 from sentry_sdk.integrations.django import DjangoIntegration
 
 from .apps import *  # noqa: F401, F403
+from .aws import *  # noqa: F401, F403
 from .css import *  # noqa: F401, F403
 from .db import *  # noqa: F401, F403
diff --git a/app/materia/settings/generation.py b/app/materia/settings/generation.py
index 4477efc97..427c10b4f 100644
--- a/app/materia/settings/generation.py
+++ b/app/materia/settings/generation.py
@@ -4,9 +4,9 @@
 AI_GENERATION = {
     "ENABLED": ValidatorUtil.validate_bool(os.environ.get("GENERATION_ENABLED"), False),
-    "ALLOW_IMAGES": ValidatorUtil.validate_bool(
-        os.environ.get("GENERATION_ALLOW_IMAGES"), False
-    ),
+    # "ALLOW_IMAGES": ValidatorUtil.validate_bool(
+    #     os.environ.get("GENERATION_ALLOW_IMAGES"), False
+    # ),
     "PROVIDER": os.environ.get("GENERATION_API_PROVIDER"),
     "ENDPOINT": os.environ.get("GENERATION_API_ENDPOINT"),
     "API_KEY": os.environ.get("GENERATION_API_KEY"),
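For orientation, the effective settings for the Bedrock path would resolve to something like the following (illustrative only; the model id is a placeholder and none of these literals are defaults introduced by this change):

    # Sketch of resolved settings, assuming GENERATION_ENABLED=true,
    # GENERATION_API_PROVIDER=bedrock, and IMDS-based credentials on an EC2 host.
    AI_GENERATION = {
        "ENABLED": True,
        "PROVIDER": "bedrock",  # must match a key in GenerationDriverFactory._drivers
        "MODEL": "your-bedrock-model-id",  # placeholder; the driver sends Anthropic-messages request bodies
        "ENDPOINT": None,  # unused by the Bedrock driver
        "API_KEY": None,  # unused by the Bedrock driver
    }

    AWS_SETTINGS = {
        "credential_provider": "imds",  # or "env" to use key/secret_key/session_token or a named profile
        "region": "us-east-1",
        "key": None,
        "secret_key": None,
        "session_token": None,
        "profile": None,
    }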
diff --git a/app/storage/s3.py b/app/storage/s3.py
index 3d4d9bd44..f81830b86 100644
--- a/app/storage/s3.py
+++ b/app/storage/s3.py
@@ -4,8 +4,8 @@
 import tempfile
 import threading
 
-import boto3
 import botocore
+from core.services.boto_session_service import BotoSessionService
 from django.conf import settings
 from django.http import HttpResponseNotFound, HttpResponseRedirect
@@ -92,23 +92,7 @@ def get_s3(get_client=False):
     # Create new session and client/resource
     # We only cache the client/resource objects created from it
     try:
-        # Configure credentials depending on whether we're providing them from env or Amazon's IMDSv2 service
-        # IMDS is HIGHLY recommended for prod usage on AWS
-        session = None
-        if s["credential_provider"] == "imds":
-            # Credentials are sourced from the EC2 instance's IAM role
-            session = boto3.Session()
-        elif s["credential_provider"] == "env":
-            session_config = {
-                "region_name": s["region"],
-                "aws_access_key_id": s["key"],
-                "aws_secret_access_key": s["secret_key"],
-            }
-            session = boto3.Session(**session_config)
-        else:
-            raise Exception(
-                "S3: Failed to determine credential provider. Did you set the appropriate environment variable?"
-            )
+        session = BotoSessionService.get_session()
     except Exception:
         logger.error("S3: Failed to create S3 session.", exc_info=True)
         raise
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 23bc49288..f7f269dfd 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -47,6 +47,7 @@ services:
       - ../app:/var/www/html
       - uploaded_media:/var/www/html/media/
       - ../public/widget:/var/www/html/staticfiles/widget/:rw
+      - ${HOME}/.aws:/home/www-data/.aws:ro
     command: /wait_for_it.sh mysql:3306 -t 15 -- gunicorn materia.wsgi:application --bind 0.0.0.0:8001 --reload --workers 4 --threads 2 --timeout 150 --log-level info --access-logfile - --error-logfile -
 
   mysql:
diff --git a/src/components/hooks/usePlayerPromptStream.jsx b/src/components/hooks/usePlayerPromptStream.jsx
new file mode 100644
index 000000000..23e2c9a96
--- /dev/null
+++ b/src/components/hooks/usePlayerPromptStream.jsx
@@ -0,0 +1,16 @@
+import { useMutation } from 'react-query'
+import { apiStreamingResponseGenerate } from '../../util/api'
+
+export default function usePlayerPromptStream() {
+	return useMutation(
+		(variables) => apiStreamingResponseGenerate(variables),
+		{
+			onSuccess: (data, variables) => {
+				variables.successFunc(data)
+			},
+			onError: (err, variables) => {
+				variables.errorFunc(err)
+			}
+		}
+	)
+}
\ No newline at end of file
diff --git a/src/components/widget-player.jsx b/src/components/widget-player.jsx
index c45c611f4..bf0744108 100644
--- a/src/components/widget-player.jsx
+++ b/src/components/widget-player.jsx
@@ -6,6 +6,7 @@ import { player } from './materia-constants'
 import Alert from './alert'
 import usePlayStorageDataSave from './hooks/usePlayStorageDataSave'
 import usePlayLogSave from './hooks/usePlayLogSave'
+import usePlayerPromptStream from './hooks/usePlayerPromptStream'
 import LoadingIcon from './loading-icon'
 import './widget-player.scss'
 
@@ -111,6 +112,8 @@ const WidgetPlayer = ({instanceId, playId, minHeight=0, minWidth=0,showFooter=tr
 	const savePlayLog = usePlayLogSave()
 	const saveStorage = usePlayStorageDataSave()
 
+	const sendPlayerPromptStream = usePlayerPromptStream()
+
 	const previewPlayId = useMemo(() => {
 		if (!isPreview) return null
 		return crypto.randomUUID() // Generate a random preview play ID
@@ -360,6 +363,8 @@
 				return _setHeight(msg.data[0])
 			case 'setVerticalScroll':
 				return _setVerticalScroll(msg.data[0])
+			case 'generationStreamingRequest':
+				return _submitGenerationRequestForPlayer(msg.data[0], msg.data[1])
 			case 'initialize':
 				break
 			default:
@@ -552,6 +557,25 @@
 		window.scrollTo(0, calculatedLocation)
 	}
 
+	const _submitGenerationRequestForPlayer = (conversation, systemPrompt = "") => {
+		sendPlayerPromptStream.mutate({
+			request: {
+				conversation: conversation,
+				systemPrompt: systemPrompt
+			},
+			onChunk: (chunk, fullText) => {
+				_sendToWidget('promptResponse', [fullText, false])
+			},
+			successFunc: (result) => {
+				_sendToWidget('promptResponse', [result.response, true])
+			},
+			errorFunc: (err) => {
+				console.log('submit generation error!', err)
+				_sendToWidget('promptRejection', [])
+			}
+		})
+	}
+
 	const _onLoadFail = msg => setAlert({
 		msg: msg,
 		title: 'Failure!',
diff --git a/src/materia/materia.enginecore.js b/src/materia/materia.enginecore.js
index 10e55c801..1b9ad2cb4 100644
--- a/src/materia/materia.enginecore.js
+++ b/src/materia/materia.enginecore.js
@@ -15,6 +15,11 @@ Namespace('Materia').Engine = (() => {
 			_mediaUrl = msg.data[3]
 			_initWidget(msg.data[0], msg.data[1])
 			break
+		case 'promptResponse':
+			_promptResponse(msg.data[0], msg.data[1])
+			break
+		case 'promptRejection':
+			_promptRejection()
+			break
 		default:
 			throw new Error(`Error: Engine Core received unknown post message: ${msg.type}`)
 			break
@@ -31,6 +36,14 @@ Namespace('Materia').Engine = (() => {
 		_instance = instance
 	}
 
+	const _promptResponse = (response, complete=false) => {
+		_widgetClass.promptStreamingResponse(response, complete)
+	}
+
+	const _promptRejection = () => {
+		_widgetClass.promptStreamingRejection()
+	}
+
 	const start = (widgetClass) => {
 		// setup the postmessage listener
 		addEventListener('message', _onPostMessage, false)
@@ -100,6 +113,10 @@ Namespace('Materia').Engine = (() => {
 	const escapeScriptTags = (text) => text.replace(//g, '>')
 
+	const submitGenerationRequest = (conversation, systemPrompt) => {
+		_sendPostMessage('generationStreamingRequest', [conversation, systemPrompt])
+	}
+
 	return {
 		start,
 		addLog,
@@ -113,5 +130,6 @@
 		setHeight, // allows the widget to resize its iframe container to fit the height of its contents
 		setVerticalScroll, // allows the widget to scroll the page to a specific location
 		escapeScriptTags,
+		submitGenerationRequest,
 	}
 })()
diff --git a/src/util/api.js b/src/util/api.js
index a0c0b54b2..a52fe2409 100644
--- a/src/util/api.js
+++ b/src/util/api.js
@@ -724,6 +724,76 @@ export const apiWidgetPromptGenerate = (prompt) => {
 	return handleRequest(methods.POST, `/api/generate/from_prompt/`, {prompt})
 }
 
+/**
+ * Sends a conversation to the backend and streams the response back to the caller,
+ * invoking a callback with each received chunk of text as it arrives.
+ * Unlike other API methods in this file, this function does not use the `handleRequest`
+ * helper and instead directly interfaces with the Fetch API to handle a
+ * Server-Sent Events (SSE) stream.
+ * @param {Object} params - The parameters for the streaming request.
+ * @param {Object[]} params.request.conversation - The conversation history to send to the backend.
+ * @param {string} [params.request.systemPrompt] - The system prompt to send to the backend.
+ * @param {Function} [params.onChunk] - Optional callback invoked with each received text chunk.
+ *   Receives two arguments: the latest chunk of text, and the full accumulated text so far.
+ * @returns {Promise<{response: string}>} - Resolves with an object containing the full
+ *   accumulated response text once the stream is complete.
+ * @throws {Error} - Throws an HTTP error if the response status is not OK.
+ */
+export const apiStreamingResponseGenerate = async (params) => {
+	const conversation = params.request.conversation
+	const system_prompt = params.request.systemPrompt
+	const onChunk = params.onChunk // Extract the onChunk callback
+
+	const response = await fetch('/api/generate/streaming/', {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			'X-CSRFToken': getCSRFToken(),
+		},
+		body: JSON.stringify({conversation, system_prompt}),
+	})
+
+	if (!response.ok) {
+		throw new Error(`HTTP error ${response.status}`)
+	}
+
+	const reader = response.body.getReader()
+	const decoder = new TextDecoder()
+	let buffer = ''
+	let fullText = ''
+
+	while (true) {
+		const { done, value } = await reader.read()
+		if (done) break
+
+		buffer += decoder.decode(value, { stream: true })
+		const lines = buffer.split('\n')
+		buffer = lines.pop() // Keep incomplete line in buffer
+
+		for (const line of lines) {
+			if (line.startsWith('data: ')) {
+				const eventData = line.slice(6)
+				if (eventData === '[DONE]') {
+					return { response: fullText }
+				}
+				try {
+					const parsed = JSON.parse(eventData)
+					if (parsed.text) {
+						fullText += parsed.text
+						if (onChunk) {
+							onChunk(parsed.text, fullText)
+						}
+					}
+				} catch (e) {
+					// Ignore parse errors for non-JSON lines
+				}
+			}
+		}
+	}
+
+	return { response: fullText }
+}
+
 /**
  * Takes a widget ID, returns whether an update is available and what version
  * is new when applicable