Shahzad Bhatti Welcome to my ramblings and rants!

November 4, 2025

Building a Production-Grade Enterprise AI Platform with vLLM: A Complete Guide from the Trenches

Filed under: Agentic AI — admin @ 11:48 am

TL;DR: Tested open-source LLM serving (vLLM) on GCP L4 GPUs. Achieved 93% cost savings vs OpenAI GPT-4, 100% routing accuracy, and 91% cache hit rates. Prototype proves feasibility; production requires 5-7 months additional work (security, HA, ops). All code at github.com/bhatti/vllm-tutorial.

Background

Last year, our CEO mandated “AI adoption” across the organization and everyone had access to LLMs through an internal portal that used Vertex AI. However, there was a little training or best practices. I saw engineers using the most expensive models for simple queries, no cost tracking, zero observability into what was being used, and no policies around data handling. People tried AI, built some demos and got mixed results.

This mirrors what’s happening across the industry. Recent research shows 95% of AI pilots fail at large companies, and McKinsey found 42% of companies abandoned generative AI projects citing “no significant bottom line impact.” The 5% that succeed do something fundamentally different: they treat AI as infrastructure requiring proper tooling, not just API access.

This experience drove me to explore better approaches. I built prototypes using vLLM and open-source tools, tested them on GCP L4 GPUs, and documented what actually works. This blog shares those findings with real code, benchmarks, and lessons from building production-ready AI infrastructure. Every benchmark ran on actual hardware (GCP L4 GPUs), every pattern emerged from solving real problems, and all code is available at github.com/bhatti/vllm-tutorial.


Why Hosted LLM Access Isn’t Enough

Even with managed services like Vertex AI or Bedrock, enterprise AI needs additional layers that most organizations overlook:

Cost Management

  • No intelligent routing between models (GPT-4 for simple definitions that Phi-2 could handle)
  • No per-user, per-team budgets or limits
  • No cost attribution or chargeback
  • Result: Unpredictable expenses, no accountability

Observability

  • Can’t track which prompts users send
  • Can’t identify failing queries or quality degradation
  • Can’t measure actual usage patterns
  • Result: Flying blind when issues occur

Security & Governance

  • Data flows through third-party infrastructure
  • No granular access controls beyond API keys
  • Limited audit trails for compliance
  • Result: Compliance gaps, security risks

Performance Control

  • Can’t deploy custom fine-tuned models
  • No A/B testing between models
  • Limited control over routing logic
  • Result: Vendor lock-in, inflexibility

The Solution: vLLM with Production Patterns

After evaluating options, I built prototypes using vLLM—a high-performance inference engine for running open-source LLMs (Llama, Mistral, Phi) on your infrastructure. Think of vLLM as NGINX for LLMs: battle-tested, optimized runtime that makes production deployments feasible.

Why vLLM specifically?

  • PagedAttention: Revolutionary memory management enabling 22.5x higher throughput
  • Continuous batching: Automatically batches requests for maximum efficiency
  • Production-ready: Used by major companies, not experimental
  • Open source: Full control, no vendor lock-in

What I tested:

  • Intelligent model routing (complexity-based selection)
  • Budget enforcement (hard limits, not just monitoring)
  • Prefix caching (80% cost reduction)
  • Quantization (3.7x memory reduction with AWQ)
  • Complete observability (Prometheus + Grafana + Langfuse)
  • Production error handling (retries, circuit breakers, fallbacks)

System Architecture

Here’s the complete system architecture I’ve built and tested:

Production AI requires three monitoring layers:

Layer 1: Infrastructure (Prometheus + Grafana)

  • GPU utilization, memory usage
  • Request rate, error rate, latency (P50, P95, P99)
  • Integration via /metrics endpoint that vLLM exposes
  • Grafana dashboards visualize trends and trigger alerts

Layer 2: Application Metrics

  • Time to First Token (TTFT), tokens per second
  • Cost per request, model distribution
  • Budget tracking (daily, monthly limits)
  • Custom Prometheus metrics embedded in application code

Layer 3: LLM Observability (Langfuse)

  • Full prompt/response history for debugging
  • Cost attribution per user/team
  • Quality tracking over time
  • Essential for understanding what users actually do

Here’s what I’ve built and tested:


Setting Up Your Environment: GCP L4 GPU Setup

Before we dive into the concepts, let’s get your environment ready. I’m using GCP L4 GPUs because they offer the best price/performance for this workload ($0.45/hour), but the code works on any CUDA-capable GPU.

Minimum Hardware Requirements

  • NVIDIA GPU with 16GB+ VRAM (L4, T4, A10G, A100)
  • 4 CPU cores
  • 16GB RAM
  • 100GB disk space

Step 1: Create GCP L4 Instance

# Create instance with L4 GPU
gcloud compute instances create vllm-test \
  --zone=us-central1-a \
  --machine-type=g2-standard-8 \
  --accelerator=type=nvidia-l4,count=1 \
  --image-family=ubuntu-2004-lts \
  --image-project=ubuntu-os-cloud \
  --boot-disk-size=200GB \
  --boot-disk-type=pd-ssd \
  --maintenance-policy=TERMINATE

# SSH into instance
gcloud compute ssh vllm-test --zone=us-central1-a

Step 2: Install CUDA 11.8

# Update system
sudo apt update && sudo apt upgrade -y

# Install CUDA 11.8
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run --silent --toolkit

# Add to PATH
echo 'export PATH=/usr/local/cuda-11.8/bin:$PATH' >> ~/.bashrc
echo 'export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH' >> ~/.bashrc
source ~/.bashrc

# Verify
nvidia-smi  # Should show your L4 GPU
nvcc --version  # Should show CUDA 11.8

Troubleshooting: If nvidia-smi doesn’t work, reboot the instance: sudo reboot

Step 3: Install Python Dependencies

# Install Python 3.10
sudo apt install -y python3.10 python3.10-venv python3-pip

# Clone the repository
git clone https://github.com/bhatti/vllm-tutorial.git
cd vllm-tutorial

# Create virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt

Step 4: Verify Installation

# Test vLLM installation
python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"

# Quick functionality test
python examples/01_basic_vllm.py

Expected output:

Loading model microsoft/phi-2...
Model loaded in 8.3 seconds

Generating response...
Generated 50 tokens in 987ms
Throughput: 41.5 tokens/sec

? vLLM is working!

Quick Start

Before we dive deep, let’s get something running:

  1. Clone the repo:
   git clone https://github.com/bhatti/vllm-tutorial.git
   cd vllm-tutorial
  1. If you have a GPU available:
   # Follow setup instructions in README
   python examples/01_basic_vllm.py
  1. No GPU? Run the benchmarks locally:
   # See the actual results from GCP L4 testing
   cat benchmarks/results/01_throughput_results.json
  1. Explore the code:

Core Concept 1: Intelligent Model Routing

The problem: Not all queries need your most expensive model.

  • “What is EBITDA?” needs a 30-word definition ? Use Phi-2 ($0.0001)
  • “Analyze Microsoft’s 10-K risk factors…” needs deep reasoning ? Use Llama-3-8B ($0.0003)

Most teams send everything to their best model, which is wasteful.

The solution: Route queries to the right model based on complexity.

The Three-Tier Routing Strategy

TierModelUse CasesCost (per 1K tokens)% of Queries
SimplePhi-2 (2.7B)Definitions, facts$0.0001 / 1K60%
MediumMistral-7BSummaries, comparisons$0.0002 / 1K30%
ComplexLlama-3-8BAnalysis, reasoning$0.0003 / 1K10%

Routing Decision Flow

Implementation: Complexity Classification

Here’s how I classify query complexity:

def classify_complexity(self, prompt: str) -> str:
    """
    Classify prompt complexity to select appropriate model

    Rules:
    - Simple: Definitions, quick facts, <50 words
    - Medium: Summaries, comparisons, 50-150 words
    - Complex: Deep analysis, multi-step reasoning, >150 words
    """
    word_count = len(prompt.split())

    # Keywords indicating complexity
    complex_keywords = [
        "analyze", "compare", "evaluate", "assess risk",
        "recommend", "predict", "forecast", "implications"
    ]

    medium_keywords = [
        "summarize", "explain", "describe", "list",
        "what are", "how does", "differences"
    ]

    has_complex = any(kw in prompt.lower() for kw in complex_keywords)
    has_medium = any(kw in prompt.lower() for kw in medium_keywords)

    # Classification logic
    if word_count > 150 or has_complex:
        return "complex"
    elif word_count > 50 or has_medium:
        return "medium"
    else:
        return "simple"

Why this works:

  • Length is a strong signal (detailed questions need detailed answers)
  • Keywords indicate intent (“analyze” needs more reasoning than “define”)
  • Conservative defaults (when in doubt, route up)

Testing Results

I tested this with 11 queries on GCP L4. Here are the actual results:

Query: "What is EBITDA?"
Classified as: simple ? Routed to: Phi-2
Cost: $0.00002038
Latency: 4,843ms (first request, includes model loading)
Quality: ? Perfect (simple definition)

Query: "Summarize Apple's Q4 2024 earnings highlights"
Classified as: medium ? Routed to: Mistral-7B
Cost: $0.00000865
Latency: 4,827ms
Quality: ? Good summary

Query: "Analyze Microsoft's 10-K risk factors and assess their potential impact on future earnings"
Classified as: complex ? Routed to: Llama-3-8B
Cost: $0.00001382
Latency: 4,836ms
Quality: ? Detailed analysis

Accuracy: 100% (11/11 queries routed correctly)
Cost savings: 30% vs routing everything to the most expensive model

Complete Router

Here’s the full intelligent router (you can find this in src/intelligent_router.py):

from typing import Dict, Optional
from dataclasses import dataclass
from vllm import LLM, SamplingParams

@dataclass
class ModelConfig:
    """Configuration for a model tier"""
    name: str
    complexity: str  # "simple", "medium", "complex"
    cost_per_1k_tokens: float
    max_tokens: int

class IntelligentRouter:
    """
    Production-ready intelligent router with:
    - Complexity-based routing
    - Budget enforcement
    - Cost tracking
    - Fallback handling
    """

    def __init__(self, daily_budget_usd: float = 100.0):
        self.daily_budget_usd = daily_budget_usd
        self.total_cost_today = 0.0

        # Model configurations
        self.models = {
            "phi-2": ModelConfig(
                name="microsoft/phi-2",
                complexity="simple",
                cost_per_1k_tokens=0.0001,
                max_tokens=1024,
            ),
            "mistral-7b": ModelConfig(
                name="mistralai/Mistral-7B-Instruct-v0.2",
                complexity="medium",
                cost_per_1k_tokens=0.0002,
                max_tokens=2048,
            ),
            "llama-3-8b": ModelConfig(
                name="meta-llama/Meta-Llama-3-8B",
                complexity="complex",
                cost_per_1k_tokens=0.0003,
                max_tokens=4096,
            ),
        }

        # Initialize LLM (in production, these would be separate instances)
        self.llm = LLM(
            model=self.models["phi-2"].name,
            trust_remote_code=True,
            gpu_memory_utilization=0.9,
        )

    def route_request(self, prompt: str, max_tokens: int = 200) -> Dict:
        """
        Route request to appropriate model

        Returns:
            Dict with 'response', 'model_used', 'cost', 'latency_ms'
        """
        # Step 1: Classify complexity
        complexity = self.classify_complexity(prompt)

        # Step 2: Select model
        model_id = self._select_model(complexity)
        model_config = self.models[model_id]

        # Step 3: Check budget
        estimated_cost = self._estimate_cost(model_config, prompt, max_tokens)
        if self.total_cost_today + estimated_cost > self.daily_budget_usd:
            # Budget exceeded - fallback to cheapest model
            model_id = "phi-2"
            model_config = self.models[model_id]

        # Step 4: Generate response
        sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=max_tokens,
        )

        start_time = time.time()
        outputs = self.llm.generate([prompt], sampling_params)
        latency_ms = (time.time() - start_time) * 1000

        # Step 5: Track cost
        tokens_generated = len(outputs[0].outputs[0].token_ids)
        actual_cost = self._calculate_cost(model_config, prompt, tokens_generated)
        self.total_cost_today += actual_cost

        return {
            "response": outputs[0].outputs[0].text,
            "model_used": model_id,
            "cost_usd": actual_cost,
            "latency_ms": latency_ms,
            "tokens_generated": tokens_generated,
        }

    def _select_model(self, complexity: str) -> str:
        """Select model based on complexity"""
        for model_id, config in self.models.items():
            if config.complexity == complexity:
                return model_id
        return "phi-2"  # Default fallback

    def _estimate_cost(self, config: ModelConfig, prompt: str, max_tokens: int) -> float:
        """Estimate cost before generation"""
        input_tokens = len(prompt) / 4  # Rough estimate
        total_tokens = input_tokens + max_tokens
        return (total_tokens / 1000) * config.cost_per_1k_tokens

    def _calculate_cost(self, config: ModelConfig, prompt: str, tokens_generated: int) -> float:
        """Calculate actual cost after generation"""
        input_tokens = len(prompt) / 4
        total_tokens = input_tokens + tokens_generated
        return (total_tokens / 1000) * config.cost_per_1k_tokens

How to use it:

# Initialize router with daily budget
router = IntelligentRouter(daily_budget_usd=100.0)

# Route a simple query
result = router.route_request("What is gross margin?")
print(f"Model used: {result['model_used']}")  # phi-2
print(f"Cost: ${result['cost_usd']:.6f}")     # $0.000020

# Route a complex query
result = router.route_request(
    "Analyze Tesla's competitive positioning in the EV market "
    "and provide investment recommendations based on recent trends"
)
print(f"Model used: {result['model_used']}")  # llama-3-8b
print(f"Cost: ${result['cost_usd']:.6f}")     # $0.000138

Core Concept 2: Budget Enforcement

The problem: Monitoring costs isn’t the same as preventing them.

I have seen hundreds of thousands spent on a company AI hackathon because developers were using expensive models needlessly.

The solution: Hard limits that reject requests before they burn your budget.

The Three Levels of Budget Control

from dataclasses import dataclass
from datetime import datetime

@dataclass
class BudgetConfig:
    """Budget configuration with multiple enforcement levels"""
    max_cost_per_request: float = 0.50        # Level 1: prevent accidents
    daily_budget_usd: float = 100.0           # Level 2: daily cap
    monthly_budget_usd: float = 3000.0        # Level 3: monthly cap
    warning_threshold_pct: float = 0.80       # Warn at 80%

class BudgetEnforcer:
    """Hard budget enforcement - prevents spending, not just monitors"""
    
    def __init__(self, config: BudgetConfig):
        self.config = config
        self.daily_spend = 0.0
        self.monthly_spend = 0.0
        # ... implementation
    
    def check_budget(self, estimated_cost: float) -> Dict:
        """Check BEFORE generating - this is the key difference"""
        
        # Level 1: Per-request limit
        if estimated_cost > self.config.max_cost_per_request:
            return {"action": "reject", "reason": "Request too expensive"}
        
        # Level 2: Daily budget
        if self.daily_spend + estimated_cost > self.config.daily_budget_usd:
            return {"action": "downgrade", "reason": "Daily limit approaching"}
        
        # Level 3: Monthly budget
        if self.monthly_spend + estimated_cost > self.config.monthly_budget_usd:
            return {"action": "downgrade", "reason": "Monthly limit approaching"}
        
        return {"action": "allow"}

Best Practices

  • Set conservative limits initially
  • Monitor budget utilization trends
  • Implement graceful degradation
  • Track who’s using what

Core Concept 3: Prefix Caching

Problem: You’re paying to process the same content repeatedly.

In enterprise AI, you typically have a structure like this:

[Fixed System Prompt - 500 tokens]
You are a financial analyst AI assistant specializing in:
- Earnings report analysis
- SEC filing interpretation
- Market sentiment analysis
...

[User Query - 50 tokens]
What is EBITDA?

[Response - 100 tokens]
EBITDA stands for...

Total tokens: 650 (500 system + 50 query + 100 response)
What you pay for: All 650 tokens, every single request

The Solution: Prefix Caching

vLLM has a feature called “prefix caching” that solves this elegantly:

How to Enable It

from vllm import LLM

# WITHOUT prefix caching
llm = LLM(
    model="microsoft/phi-2",
    trust_remote_code=True,
)

# WITH prefix caching (80% cost reduction!)
llm = LLM(
    model="microsoft/phi-2",
    trust_remote_code=True,
    enable_prefix_caching=True,  # <-- That's it!
)

Testing Results

I tested this on GCP L4 with our end-to-end integration test. Here are the actual numbers:

Test setup:

  • Fixed system prompt: 500 tokens
  • 11 different user queries: 15-290 tokens each
  • Model: Phi-2 (2.7B)

Results WITHOUT prefix caching:

Request 1: $0.00010188 (full cost)
Request 2: $0.00010188 (full cost)
Request 3: $0.00010188 (full cost)
...
Total: $0.00112068 (11 × $0.00010188)

Results WITH prefix caching:

Request 1: $0.00002038 (full cost - establishes cache)
Request 2: $0.00000414 (80% cheaper - uses cache!)
Request 3: $0.00000409 (80% cheaper)
Request 4: $0.00000865 (80% cheaper)
...
Total: $0.00010031

Savings: $0.00102037 (91% reduction!)
Cache hit rate: 90.9% (10/11 requests)

Here is what just happened:

  • Same 11 queries
  • Same model
  • Same responses
  • One parameter change
  • 91% cost reduction

Best use cases:

  • RAG systems (fixed context, many questions): 80% savings
  • Template generation (fixed format, variable content): 70% savings
  • Conversations (history grows, new turns added): 50% savings

When it doesn’t help:

  • Every request is unique (no repeated prefix)
  • Prefix changes frequently (cache invalidated)
  • Very short queries (overhead dominates)

Rule of thumb: If you have a fixed prefix >200 tokens reused across requests, enable prefix caching.


Core Concept 4: Quantization

The problem: The models you want don’t fit in the GPUs you can afford.

  • Llama-3-70B in full precision: Requires 140GB GPU memory
  • Your budget: Maybe a 24GB L4 GPU
  • The gap: 116GB short

The solution: Use fewer bits per number with minimal quality loss, e.g., converting FP16 into INT8.

Quantization Schemes

MethodMemoryCompressionQuality LossWorks On
FP16 (baseline)19.3 GB0%All GPUs
AWQ5.2 GB3.7×~2%L4, A100
FP8~9.7 GB~1%H100 only

I’ve tested three quantization approaches on GCP L4:

1. FP8 (8-bit floating point)

  • Compression: 2x (FP16 ? FP8)
  • Quality: ~99% of original
  • Speed: Same or faster (better memory bandwidth)
  • Limitation: Requires H100 GPU (NOT supported on L4)

2. AWQ (Activation-aware Weight Quantization)

  • Compression: 3.7x (FP16 ? W4A16)
  • Quality: ~98% of original
  • Speed: Slightly slower than FP16
  • Limitation: Requires pre-quantized model

3. GPTQ (Post-training quantization)

  • Compression: 3.5x (FP16 ? INT4)
  • Quality: ~97% of original
  • Speed: Similar to AWQ
  • Limitation: Longer quantization process

Benchmark Results

I ran quantization benchmarks on GCP L4 with Phi-2. Here’s what I measured (from benchmarks/04_quantization_comparison.py):

# Benchmark code
class QuantizationBenchmark:
    def benchmark_quantization(self, quantization: str):
        """
        Test quantization scheme

        Args:
            quantization: "none" (FP16), "fp8", "awq", or "gptq"
        """
        llm_kwargs = {
            "model": "microsoft/phi-2",
            "trust_remote_code": True,
            "gpu_memory_utilization": 0.9,
            "max_model_len": 1024,
        }

        # Add quantization if specified
        if quantization != "none":
            llm_kwargs["quantization"] = quantization

        # Load model
        start = time.time()
        llm = LLM(**llm_kwargs)
        load_time = time.time() - start

        # Measure memory
        gpu_memory = torch.cuda.memory_allocated() / 1e9  # GB

        # Benchmark generation
        prompt = "Explain quantum computing in simple terms"
        sampling_params = SamplingParams(max_tokens=100)

        start = time.time()
        outputs = llm.generate([prompt], sampling_params)
        latency_ms = (time.time() - start) * 1000

        return {
            "quantization": quantization,
            "memory_gb": gpu_memory,
            "load_time_sec": load_time,
            "latency_ms": latency_ms,
            "tokens_per_sec": 100 / (latency_ms / 1000),
        }

How to Use AWQ Quantization

The easiest approach is using pre-quantized models from HuggingFace:

from vllm import LLM

# Option 1: Use pre-quantized AWQ model
llm = LLM(
    model="TheBloke/Mistral-7B-Instruct-v0.2-AWQ",  # Pre-quantized!
    quantization="awq",
    trust_remote_code=True,
    gpu_memory_utilization=0.9,
)

# That's it! 3.7x smaller, ready to use

Available AWQ models (from TheBloke on HuggingFace):

  • Llama-2-7B-AWQ
  • Llama-2-13B-AWQ
  • Mistral-7B-Instruct-v0.2-AWQ
  • CodeLlama-7B-AWQ
  • Mixtral-8x7B-AWQ

Memory savings example:

# Mistral-7B in FP16
llm_fp16 = LLM(model="mistralai/Mistral-7B-Instruct-v0.2")
# Memory: ~16GB VRAM
# Fits on: A100-40GB, L4 (barely)

# Mistral-7B in AWQ
llm_awq = LLM(
    model="TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
    quantization="awq"
)
# Memory: ~4.3GB VRAM
# Fits on: T4-16GB, L4 (comfortably), even RTX 3090

# Savings: 72% memory reduction!

When to Use Quantization

? Use quantization when:

  • You’re memory-constrained (model doesn’t fit)
  • You want to use cheaper GPUs
  • Quality loss <2% is acceptable
  • You’re deploying at scale (cost matters)

? Skip quantization when:

  • You have unlimited GPU budget (rare!)
  • You need absolute maximum quality
  • Model already fits comfortably
  • You’re still prototyping (optimize later)

My recommendation: Start with AWQ for all production deployments. The cost savings alone justify it, and quality loss is negligible for most tasks.


Core Concept 5: Complete Observability

The problem: When your AI system breaks, you need to know what, when, why, and who.

The solution: Three monitoring layers.

The Three Layers of AI Observability

Layer 1: Infrastructure (What Prometheus tracks)

GPU metrics:

  • Memory usage (prevent out-of-memory)
  • Utilization (optimize capacity)
  • Temperature (hardware health)

Service metrics:

  • Request rate (traffic patterns)
  • Error rate (system health)
  • Latency percentiles (user experience)

Layer 2: Application Metrics (AI-specific)

  • Time to First Token (TTFT)
  • Inter-Token Latency (ITL)
  • Tokens per second
  • Cost per request
  • Model distribution
  • Tool: Custom metrics in Prometheus

Layer 3: LLM Observability (Content-level)

  • What prompts are users sending?
  • What responses are being generated?
  • Cost attribution per user/team
  • Quality trends over time
  • Tool: Langfuse or Arize Phoenix

Custom Application Metrics

Here’s how I export custom metrics from the vLLM application – Layer 2 (from src/observability_monitoring.py):

from prometheus_client import Counter, Histogram, Gauge, generate_latest
from typing import Dict
import time

class VLLMMetrics:
    """
    Production metrics for vLLM serving

    Tracks:
    - Request counts (total, success, failure)
    - Latency distributions (P50, P95, P99)
    - Token throughput
    - Cost tracking
    - Model distribution
    """

    def __init__(self):
        # Request counters
        self.requests_total = Counter(
            'vllm_requests_total',
            'Total number of requests',
            ['model', 'status']
        )

        # Latency histogram
        self.latency = Histogram(
            'vllm_latency_ms',
            'Request latency in milliseconds',
            ['model'],
            buckets=[10, 50, 100, 250, 500, 1000, 2500, 5000, 10000]
        )

        # Token metrics
        self.tokens_generated = Counter(
            'vllm_tokens_generated_total',
            'Total tokens generated',
            ['model']
        )

        self.tokens_per_second = Gauge(
            'vllm_tokens_per_second',
            'Current tokens per second',
            ['model']
        )

        # Cost tracking
        self.cost_usd = Counter(
            'vllm_cost_usd_total',
            'Total cost in USD',
            ['model']
        )

        self.daily_cost = Gauge(
            'vllm_daily_cost_usd',
            'Cost today in USD'
        )

        # GPU memory
        self.gpu_memory_used = Gauge(
            'vllm_gpu_memory_used_gb',
            'GPU memory used in GB'
        )

        self.gpu_memory_total = Gauge(
            'vllm_gpu_memory_total_gb',
            'Total GPU memory in GB'
        )

        # Cache metrics
        self.cache_hit_rate = Gauge(
            'vllm_cache_hit_rate',
            'Prefix cache hit rate'
        )

    def record_request(
        self,
        model: str,
        latency_ms: float,
        tokens: int,
        cost_usd: float,
        success: bool,
        cached: bool = False
    ):
        """Record request metrics"""

        # Update counters
        status = "success" if success else "failure"
        self.requests_total.labels(model=model, status=status).inc()

        if success:
            # Latency
            self.latency.labels(model=model).observe(latency_ms)

            # Tokens
            self.tokens_generated.labels(model=model).inc(tokens)
            tokens_per_sec = tokens / (latency_ms / 1000)
            self.tokens_per_second.labels(model=model).set(tokens_per_sec)

            # Cost
            self.cost_usd.labels(model=model).inc(cost_usd)

    def update_gpu_memory(self):
        """Update GPU memory metrics"""
        if torch.cuda.is_available():
            used_gb = torch.cuda.memory_allocated() / 1e9
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

            self.gpu_memory_used.set(used_gb)
            self.gpu_memory_total.set(total_gb)

    def export_metrics(self) -> str:
        """Export Prometheus metrics"""
        return generate_latest().decode('utf-8')

# Usage in FastAPI
from fastapi import FastAPI

app = FastAPI()
metrics = VLLMMetrics()

@app.post("/generate")
async def generate(request: GenerateRequest):
    start = time.time()

    try:
        # Generate response
        result = llm.generate(request.prompt)

        # Record success metrics
        latency_ms = (time.time() - start) * 1000
        metrics.record_request(
            model="phi-2",
            latency_ms=latency_ms,
            tokens=len(result.tokens),
            cost_usd=calculate_cost(result),
            success=True,
        )

        return result

    except Exception as e:
        # Record failure
        latency_ms = (time.time() - start) * 1000
        metrics.record_request(
            model="phi-2",
            latency_ms=latency_ms,
            tokens=0,
            cost_usd=0,
            success=False,
        )
        raise

@app.get("/metrics")
async def get_metrics():
    """Prometheus scrape endpoint"""
    metrics.update_gpu_memory()
    return Response(
        content=metrics.export_metrics(),
        media_type="text/plain"
    )

What this tracks:

  • ? Request rate (by model, by status)
  • ? Latency distribution (with percentiles)
  • ? Token throughput (tokens/sec)
  • ? Cost tracking (per model, daily total)
  • ? GPU memory usage
  • ? Cache hit rates

Integration code for Langfuse – Layer 3 (from examples/05_llm_observability.py):

from langfuse import Langfuse
import os

# Initialize Langfuse
langfuse = Langfuse(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST", "http://localhost:3001"),
)

def generate_with_observability(prompt: str, user_id: str, metadata: Dict = None):
    """Generate response with full Langfuse tracing"""

    # Create trace
    trace = langfuse.trace(
        name="financial_analysis",
        user_id=user_id,
        metadata=metadata or {},
    )

    # Start generation span
    generation = trace.generation(
        name="vllm_generate",
        model="microsoft/phi-2",
        input=prompt,
        metadata={
            "quantization": "awq",
            "max_tokens": 200,
        }
    )

    # Generate
    start = time.time()
    result = llm.generate(prompt)
    latency_ms = (time.time() - start) * 1000

    # Calculate cost
    tokens_in = len(prompt) / 4
    tokens_out = len(result.tokens)
    cost_usd = ((tokens_in + tokens_out) / 1000) * 0.0001

    # End span with metrics
    generation.end(
        output=result.text,
        usage={
            "input_tokens": int(tokens_in),
            "output_tokens": tokens_out,
            "total_tokens": int(tokens_in + tokens_out),
        },
        metadata={
            "latency_ms": latency_ms,
            "cost_usd": cost_usd,
            "model": "phi-2",
        }
    )

    return result

# Usage
result = generate_with_observability(
    prompt="Analyze Apple's Q4 earnings",
    user_id="analyst_001",
    metadata={
        "team": "equity_research",
        "department": "finance",
    }
)

You can see following in Langfuse dashboard:

  • Every prompt and response
  • Cost per request, per user, per team
  • Latency trends over time
  • Token usage patterns
  • Quality scores (if you add feedback)
  • Prompt versions (track what works)

Alerting Strategy

You can configure Langfuse with alerting with various severity such as:

Critical (PagerDuty/Phone):

  • Service down
  • Error rate >10%
  • Daily budget exceeded by 50%
  • GPU out of memory

Warning (Slack):

  • Error rate >5%
  • P95 latency >1000ms
  • Daily budget at 80%
  • GPU memory >95%

Info (Email):

  • Daily usage summary
  • Cost reports
  • Quality metrics

Observability isn’t optional for production AI—it’s essential.


Core Concept 6: Production Error Handling

Your AI system will fail. GPUs crash, networks drop, users send garbage, budgets get exceeded.

Error Handling Pattern Flow

Five essential patterns:

Pattern 1: Retry with Exponential Backoff

Here is a retry logic (from examples/07_advanced_error_handling.py):

from typing import Callable
from dataclasses import dataclass
import time

@dataclass
class RetryConfig:
    """Retry configuration"""
    max_retries: int = 3
    initial_delay: float = 1.0
    max_delay: float = 60.0
    exponential_base: float = 2.0

def retry_with_backoff(config: RetryConfig = RetryConfig()):
    """
    Decorator: Retry with exponential backoff

    Example:
        @retry_with_backoff()
        def generate_text(prompt):
            return llm.generate(prompt)
    """
    def decorator(func: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            delay = config.initial_delay

            for attempt in range(config.max_retries):
                try:
                    return func(*args, **kwargs)

                except Exception as e:
                    if attempt == config.max_retries - 1:
                        raise  # Last attempt, re-raise

                    error_type = classify_error(e)

                    # Don't retry on invalid input
                    if error_type == ErrorType.INVALID_INPUT:
                        raise

                    print(f"??  Attempt {attempt + 1} failed: {error_type.value}")
                    print(f"   Retrying in {delay:.1f}s...")
                    time.sleep(delay)

                    # Exponential backoff
                    delay = min(delay * config.exponential_base, config.max_delay)

            raise RuntimeError(f"Failed after {config.max_retries} retries")

        return wrapper
    return decorator

# Usage
@retry_with_backoff(RetryConfig(max_retries=3, initial_delay=1.0))
def generate_with_retry(prompt: str):
    """Generate with automatic retry on failure"""
    return llm.generate(prompt)

# This will retry up to 3 times with exponential backoff
result = generate_with_retry("Analyze earnings report")

Pattern 2: Circuit Breaker

When a service starts failing repeatedly, stop calling it:

from datetime import datetime, timedelta
from enum import Enum

class CircuitState(Enum):
    CLOSED = "closed"      # Normal operation
    OPEN = "open"          # Failing, reject requests
    HALF_OPEN = "half_open"  # Testing recovery

class CircuitBreaker:
    """
    Circuit breaker for fault tolerance

    Prevents cascading failures by stopping calls to
    failing services
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        timeout: int = 60,
        expected_exception: type = Exception
    ):
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.expected_exception = expected_exception

        self.failure_count = 0
        self.last_failure_time = None
        self.state = CircuitState.CLOSED

    def call(self, func: Callable, *args, **kwargs):
        """Execute function with circuit breaker protection"""

        if self.state == CircuitState.OPEN:
            # Check if timeout elapsed
            if datetime.now() - self.last_failure_time > timedelta(seconds=self.timeout):
                self.state = CircuitState.HALF_OPEN
                print("? Circuit breaker: HALF_OPEN (testing recovery)")
            else:
                raise RuntimeError("Circuit breaker OPEN - service unavailable")

        try:
            result = func(*args, **kwargs)

            # Success - reset if recovering
            if self.state == CircuitState.HALF_OPEN:
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                print("? Circuit breaker: CLOSED (service recovered)")

            return result

        except self.expected_exception as e:
            self.failure_count += 1
            self.last_failure_time = datetime.now()

            if self.failure_count >= self.failure_threshold:
                self.state = CircuitState.OPEN
                print(f"? Circuit breaker: OPEN (threshold {self.failure_threshold} reached)")

            raise

# Usage
circuit_breaker = CircuitBreaker(failure_threshold=5, timeout=60)

def generate_protected(prompt: str):
    """Generate with circuit breaker protection"""
    return circuit_breaker.call(llm.generate, prompt)

# If llm.generate fails 5 times, circuit breaker opens
# Requests fail fast for 60 seconds
# Then one test request (half-open)
# If successful, normal operation resumes

This prevents:

  • Thundering herd problem
  • Resource exhaustion
  • Long timeouts on every request

Pattern 3: Rate Limiting

Protect your system from overload:

import time

class RateLimiter:
    """
    Token bucket rate limiter

    Limits requests per second to prevent overload
    """

    def __init__(self, max_requests: int, time_window: float = 1.0):
        self.max_requests = max_requests
        self.time_window = time_window
        self.tokens = max_requests
        self.last_update = time.time()

    def acquire(self, tokens: int = 1) -> bool:
        """Try to acquire tokens, return True if allowed"""

        now = time.time()
        elapsed = now - self.last_update

        # Refill tokens based on elapsed time
        self.tokens = min(
            self.max_requests,
            self.tokens + (elapsed / self.time_window) * self.max_requests
        )
        self.last_update = now

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        else:
            return False

    def wait_for_token(self, tokens: int = 1):
        """Wait until token is available"""
        while not self.acquire(tokens):
            time.sleep(0.1)

# Usage
rate_limiter = RateLimiter(max_requests=100, time_window=1.0)

@app.post("/generate")
async def generate(request: GenerateRequest):
    # Check rate limit
    if not rate_limiter.acquire():
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded (100 req/sec)"
        )

    # Process request
    result = llm.generate(request.prompt)
    return result

Why this matters:

  • Prevents DoS (accidental or malicious)
  • Protects GPU from overload
  • Ensures fair usage

Pattern 4: Fallback Strategies

When primary fails, don’t just error—degrade gracefully:

def generate_with_fallback(prompt: str) -> str:
    """
    Try multiple strategies before failing

    Strategy 1: Primary model (Llama-3-8B)
    Strategy 2: Cached response (if available)
    Strategy 3: Simpler model (Phi-2)
    Strategy 4: Template response
    """

    # Try primary model
    try:
        return llm_primary.generate(prompt)

    except Exception as e:
        print(f"??  Primary model failed: {e}")

        # Fallback 1: Check cache
        cached_response = cache.get(prompt)
        if cached_response:
            print("? Returning cached response")
            return cached_response

        # Fallback 2: Try simpler model
        try:
            print("? Falling back to Phi-2")
            return llm_simple.generate(prompt)

        except Exception as e2:
            print(f"??  Fallback model also failed: {e2}")

            # Fallback 3: Template response
            return (
                "I apologize, but I'm unable to process your request right now. "
                "Please try again in a few minutes, or contact support if the issue persists."
            )

# User never sees "Internal Server Error"
# They always get SOME response

Graceful degradation examples:

  • Can’t generate full analysis? Return summary
  • Can’t use complex model? Use simple model
  • Can’t generate? Return cached response
  • Everything failing? Return polite error message

Pattern 5: Timeout Handling

Don’t let requests hang forever:

import signal

class TimeoutError(Exception):
    pass

def timeout_handler(signum, frame):
    raise TimeoutError("Request timed out")

def generate_with_timeout(prompt: str, timeout_seconds: int = 30):
    """Generate with timeout"""

    # Set timeout
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout_seconds)

    try:
        result = llm.generate(prompt)

        # Cancel timeout
        signal.alarm(0)
        return result

    except TimeoutError:
        print(f"? Request timed out after {timeout_seconds}s")
        return "Request timed out. Please try a shorter prompt."

# Or using asyncio
import asyncio

async def generate_with_timeout_async(prompt: str, timeout_seconds: int = 30):
    """Generate with async timeout"""

    try:
        result = await asyncio.wait_for(
            llm.generate_async(prompt),
            timeout=timeout_seconds
        )
        return result

    except asyncio.TimeoutError:
        return "Request timed out. Please try a shorter prompt."

Why timeouts matter:

  • Prevent resource leaks
  • Free up GPU for other requests
  • Give users fast feedback

Combined Example

Here’s how I combine all patterns:

from fastapi import FastAPI, HTTPException
from circuitbreaker import CircuitBreaker, CircuitBreakerError

app = FastAPI()

# Initialize components
circuit_breaker = CircuitBreaker(failure_threshold=5, timeout=60)
rate_limiter = RateLimiter(max_requests=100, time_window=1.0)
cache = ResponseCache(ttl=3600)

@app.post("/generate")
@retry_with_backoff(max_retries=3)
async def generate(request: GenerateRequest):
    """
    Generate with full error handling:
    - Rate limiting
    - Circuit breaker
    - Retry with backoff
    - Timeout
    - Fallback strategies
    - Caching
    """

    # Rate limiting
    if not rate_limiter.acquire():
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    # Check cache first
    cached = cache.get(request.prompt)
    if cached:
        return {"text": cached, "cached": True}

    try:
        # Circuit breaker protection
        result = circuit_breaker.call(
            generate_with_timeout,
            request.prompt,
            timeout_seconds=30
        )

        # Cache successful response
        cache.set(request.prompt, result)

        return {"text": result, "status": "success"}

    except CircuitBreakerError:
        # Circuit breaker open - return fallback
        return {
            "text": "Service temporarily unavailable. Using cached response.",
            "status": "degraded",
            "fallback": True
        }

    except TimeoutError:
        raise HTTPException(status_code=504, detail="Request timed out")

    except Exception as e:
        # Log error
        logger.error(f"Generation failed: {e}")

        # Return graceful error
        return {
            "text": "I apologize, but I'm unable to process your request.",
            "status": "error",
            "fallback": True
        }

What this provides:

  • ? Prevents overload (rate limiting)
  • ? Fast failure (circuit breaker)
  • ? Automatic recovery (retry)
  • ? Resource protection (timeout)
  • ? Graceful degradation (fallback)
  • ? Performance (caching)

Deployment Recommendations

While my testing remained at POC level, these patterns prepare for production deployment:

Before deploying:

Load Testing

  • Test with expected peak load (10-100x normal traffic)
  • Measure P95 latency under load (<500ms target)
  • Verify error rate stays <1%
  • Confirm GPU memory stable (no leaks)

Production Deployment Checklist

Before going live, verify:

Infrastructure:

  • [ ] GPU drivers installed and working (nvidia-smi)
  • [ ] Docker and Docker Compose installed
  • [ ] Sufficient disk space (200GB+ for models)
  • [ ] Network configured (firewall rules, security groups)
  • [ ] SSL/TLS certificates (for HTTPS)

Configuration:

  • [ ] Model name set correctly in .env
  • [ ] Quantization configured (AWQ recommended)
  • [ ] GPU memory utilization set (0.9 typical)
  • [ ] Prefix caching enabled (ENABLE_PREFIX_CACHING=True)
  • [ ] Budget limits configured
  • [ ] Log level appropriate (info for prod)

Monitoring:

  • [ ] Prometheus scraping vLLM metrics
  • [ ] Grafana dashboard imported and working
  • [ ] Alerts configured in alert_rules.yml
  • [ ] Alert destinations set (PagerDuty, Slack, email)
  • [ ] Langfuse set up (if using LLM observability)

Testing:

  • [ ] Health check returns 200 OK
  • [ ] Can generate completions via API
  • [ ] Metrics endpoint returning data
  • [ ] Error handling works (try invalid input)
  • [ ] Budget limits enforced (if configured)
  • [ ] Load test passed (see next section)

Security:

  • [ ] API authentication enabled
  • [ ] Rate limiting configured
  • [ ] HTTPS enforced (no HTTP)
  • [ ] CORS policies set
  • [ ] Input validation in place
  • [ ] Secrets not in git (use env variables)

Operations:

  • [ ] Backup strategy for logs
  • [ ] Model cache backed up
  • [ ] Runbook written (how to handle incidents)
  • [ ] On-call rotation defined
  • [ ] SLAs documented
  • [ ] Disaster recovery plan

Real-World Results

Testing on GCP L4 GPUs with 11 queries produced these validated results:

End-to-End Integration Test Results

Test configuration:

  • Model: Phi-2 (2.7B parameters)
  • Quantization: None (FP16 baseline)
  • Prefix caching: Enabled
  • Budget: $10/day
  • Hardware: GCP L4 GPU

Results:

MetricValue
Total Requests11
Success Rate100% (11/11) ?
Total Tokens Generated2,200
Total Cost$0.000100
Average Latency5,418ms
Cache Hit Rate90.9% ?
Budget Utilization0.001%

Model distribution:

  • Phi-2: 54.5% (6 requests)
  • Llama-3-8B: 27.3% (3 requests)
  • Mistral-7B: 18.2% (2 requests)

What this proves:
? Intelligent routing works (3 models selected correctly)
? Budget enforcement works (under budget, no overruns)
? Prefix caching works (91% hit rate = huge savings)
? Multi-model support works (distributed correctly)
? Observability works (all metrics collected)

Cost Comparison

Let me show you the exact cost calculations:

Per-request costs (from actual test):

Request 1 (uncached): $0.00002038
Requests 2-11 (cached): $0.00000414 average

Total: $0.00010031 for 11 requests
Average: $0.0000091 per request

Extrapolated monthly costs (10,000 requests/day):

ConfigurationDaily CostMonthly CostSavings
Without caching$0.91$27.30Baseline
With caching (91% hit rate)$0.18$5.4680%
With quantization (AWQ)$0.09$2.7390%
All optimizations$0.09$2.7390%

Add in infrastructure costs:

GCP L4 GPU: $0.45/hour = $328/month

Total monthly cost:
- Infrastructure: $328
- API costs: $2.73
- Total: $330.73/month for 10,000 requests/day

Compare to OpenAI:

OpenAI GPT-4:
- Input: $0.03 per 1K tokens
- Output: $0.06 per 1K tokens
- Average request: 100 tokens in + 100 tokens out = $0.009
- 10,000 requests/day = $90/day = $2,700/month

Savings: $2,369/month (88% cheaper!)

Benchmark Results Summary

Here are all the benchmark results from GCP L4:

1. Throughput Benchmark (benchmarks/01_throughput_comparison.py)

Batch SizeTokens/SecLatency (ms)Speedup
141.59871x
4165.82474x
8331.61248x
16663.26216x
32934.49922.5x

Key insight: Batching provides massive throughput improvements (22.5x!)

2. Memory Efficiency (benchmarks/02_memory_efficiency.py)

Batch SizeMemory Used (GB)Overhead
119.30Baseline
419.33+0.16%
819.38+0.41%
1619.45+0.78%
3219.58+1.45%

Key insight: PagedAttention keeps memory growth near zero even with large batches

3. Cost Analysis (benchmarks/03_cost_analysis.py)

ScenarioCost/Monthvs GPT-4
OpenAI GPT-4$666Baseline
OpenAI GPT-3.5$15-98%
vLLM Phi-2 (FP16)$324-51%
vLLM + AWQ$87-87%
vLLM + AWQ + Caching$65-90%
All optimizations$45-93%

Key insight: Self-hosting with vLLM is 93% cheaper than OpenAI GPT-4

4. Quantization (benchmarks/04_quantization_comparison.py)

SchemeMemory (GB)CompressionQuality Loss
FP1619.31x0%
AWQ5.23.7x~2%

Key insight: AWQ provides 3.7x compression with minimal quality loss

What validated:
? Intelligent routing correctly classified 100% of queries
? Budget enforcement prevented overruns
? Prefix caching delivered promised 80% savings
? Multi-model support distributed load appropriately
? Observability captured all metrics accurately

What Surprised Me

Good surprises:

  1. Cache hit rates higher than expected – I expected 70%, got 91%
  2. Quantization quality loss minimal – Barely noticeable in real use
  3. vLLM stability – Zero crashes during testing
  4. Cost savings magnitude – 93% cheaper than GPT-4 is huge

Challenges:

  1. FP8 not supported on L4 – Had to use AWQ instead (still great)
  2. First request slow – Model loading takes 8 seconds (then fast)
  3. Large context memory usage – 2K tokens works, 4K+ needs more GPU

ROI Calculation (50,000 requests/day)

Option A: OpenAI GPT-4

Cost per request: $0.009
Daily: $450
Monthly: $13,500
Annual: $162,000

Option B: vLLM on GCP L4 (our solution)

Infrastructure: $328/month
API costs (with optimizations): $13.65/month
Monthly total: $341.65
Annual: $4,100

Savings: $157,900/year (97%)

Break-even:

Setup time: 2 days engineering ($2,000)
Maintenance: 4 hours/month ($200/month)

Year 1:
  Savings: $157,900
  Costs: $2,000 setup + $2,400 maintenance = $4,400
  Net: $153,500 saved

ROI: 3,500% in year 1

At scale (500,000 requests/day):

OpenAI GPT-4: $1,350,000/year
vLLM solution: $41,000/year

Savings: $1,309,000/year (97%)

Production Readiness Checklist

Based on testing, here’s what enterprise deployment requires:

Security & Compliance:

  • Authentication/authorization at API level
  • Data encryption (rest and transit)
  • PII detection and redaction capabilities
  • Audit logs for compliance (GDPR, HIPAA)
  • Network security (VPC, firewalls, no public exposure)

Operational Excellence:

  • Comprehensive monitoring (3 layers: infra, app, LLM)
  • Alerting strategy (critical/warning/info tiers)
  • Structured logging with aggregation
  • Backup/recovery procedures tested
  • Incident response runbook documented

Performance & Scale:

  • Load testing validates capacity
  • P95 latency meets SLAs (<500ms)
  • Success rate >99.9% under load
  • Auto-scaling strategy defined
  • Capacity planning for 2x, 5x, 10x growth

Cost Governance:

  • Hard budget limits (daily, monthly)
  • Per-user and per-team tracking
  • Cost dashboards for visibility
  • Automated alerts at 80%, 100%
  • Chargeback reports for finance

Quality Assurance:

  • Automated test suite (unit, integration, e2e)
  • Error handling verified (retries, circuit breakers)
  • Fallback strategies tested
  • Chaos engineering (simulate failures)
  • SLA monitoring automated

Final Thoughts

After building and testing this platform, I understand why enterprise AI differs from giving developers ChatGPT access and why 95% of initiatives fail. Here is why these layers matter:

  • Cost tracking isn’t about being cheap—it’s about accountability. Finance won’t approve next year’s AI budget without ROI proof.
  • Intelligent routing prevents the death spiral: early excitement ? everyone uses the expensive model ? costs spiral ? finance pulls the plug ? initiative dies.
  • Observability builds trust. When executives ask “Is AI working?”, you need data: success rates, cost per department, quality trends. Without metrics, you get politics and cancellation.
  • Error handling and budgets are professional table stakes. Enterprises can’t have systems that randomly fail or spend unpredictably.

Here are things missing from the prototype:

  • Security: No SSO, PII detection, audit logs for compliance, encryption at rest, security review
  • High Availability: Single instance, no load balancer, no failover, no disaster recovery
  • Operations: No CI/CD, secrets management, log aggregation, incident playbooks
  • Scale: No auto-scaling, multi-region, or load testing beyond 100 concurrent
  • Governance: No approval workflows, per-user limits, content filtering, A/B testing

I have learned that vLLM works, open models are competitive, the tooling is mature. This POC proves that the patterns work and the savings are real. The 5% that succeed treat AI as infrastructure requiring proper tooling. The 95% that fail treat it as magic requiring only faith.

Try it yourself: All code at github.com/bhatti/vllm-tutorial. Clone it, test it, prove it works in your environment. Then build the business case for production investment.

No Comments

No comments yet.

RSS feed for comments on this post. TrackBack URL

Sorry, the comment form is closed at this time.

Powered by WordPress