On-Device ML Model Testing
Why On-Device ML Needs Dedicated Testing
Applications increasingly run ML models directly on device (Core ML on iOS, TensorFlow Lite on Android) for features like image classification, text prediction, object detection, and face recognition. These models execute locally without network connectivity, providing faster inference and better privacy.
But on-device ML introduces testing challenges that server-side ML does not have. The model must load within memory constraints, inference must complete within latency budgets, and the model must produce acceptable accuracy on hardware with limited compute power. A model that works perfectly on a server with a GPU may perform poorly on a mid-range phone.
What to Test for On-Device ML
| Test Area | What to Verify | Example Assertion |
|---|---|---|
| Model loading | Model initializes without crash on target devices | Load time < 2s on Tier 1 devices |
| Inference latency | Prediction completes within acceptable time | p95 < 100ms on mid-range devices |
| Accuracy on device | Model produces same results as server-side | Agreement > 99% on validation set |
| Memory footprint | Model does not cause OOM on constrained devices | Peak memory < 100MB |
| Battery impact | Continuous inference does not drain battery | < 5% per hour of active use |
| Fallback behavior | App works when model fails to load | Graceful degradation to server-side API |
| Model update | OTA model updates apply correctly | New model version loads after app restart |
| Concurrent inference | Multiple inference requests do not crash | 10 concurrent requests complete |
Testing Inference Latency
# tests/ml/test_ondevice_model.py
import time
from appium.webdriver.common.appiumby import AppiumBy
def test_image_classification_latency(driver):
    """On-device ML inference must complete within 200ms.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    # Open the classification screen.
    driver.find_element(AppiumBy.ACCESSIBILITY_ID, "ml-classify-btn").click()
    # Inject test image onto the device.
    driver.push_file("/sdcard/test_images/cat.jpg", source_path="test_data/cat.jpg")
    # Fix: use a monotonic clock. time.time() can jump when the wall clock is
    # adjusted (NTP, user change), corrupting the latency measurement.
    start = time.perf_counter()
    driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
    # Wait for result.
    # NOTE(review): find_element returns as soon as the element exists; this
    # timing relies on the result element appearing only after inference
    # completes (or on the driver's implicit wait) -- confirm in the app.
    result = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
    elapsed = time.perf_counter() - start
    assert elapsed < 0.2, f"Inference took {elapsed:.3f}s, exceeding 200ms threshold"
    assert "cat" in result.text.lower()
def test_inference_latency_p95(driver):
    """Measure p95 inference latency across multiple runs.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    latencies = []
    for i in range(20):
        driver.push_file(
            f"/sdcard/test_images/test_{i}.jpg",
            source_path=f"test_data/validation/image_{i}.jpg"
        )
        # Monotonic clock -- time.time() can jump with wall-clock changes.
        start = time.perf_counter()
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
        latencies.append((time.perf_counter() - start) * 1000)  # Convert to ms
        # Reset for next image
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "clear-result").click()
    latencies.sort()
    # Fix: nearest-rank p95 is 1-based rank ceil(0.95 * n). The original
    # latencies[int(len * 0.95)] indexed position 19 for n == 20, i.e. the
    # maximum (p100), making the test needlessly strict and flaky.
    rank = -(-95 * len(latencies) // 100)  # ceiling division without math import
    p95 = latencies[rank - 1]
    assert p95 < 200, f"p95 inference latency is {p95:.0f}ms, exceeding 200ms threshold"
Testing Model Accuracy on Device
The same model may produce slightly different results on different hardware due to floating-point precision differences between GPU, CPU, and NPU inference.
def test_model_accuracy_matches_server(driver):
    """On-device predictions must agree with the labelled validation set.

    NOTE(review): despite the name and the original comments, no server call
    is made here -- agreement is measured against the expected labels bundled
    with the validation set. Either rename the test or add a real server-side
    comparison.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    validation_set = load_validation_set("test_data/validation/")
    # Fix: an empty validation set previously produced ZeroDivisionError
    # instead of a meaningful failure.
    assert validation_set, "Validation set is empty -- nothing to measure"
    agreements = 0
    for image_path, expected_label in validation_set:
        # Get on-device prediction
        driver.push_file("/sdcard/test_images/current.jpg", source_path=image_path)
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
        result = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
        device_prediction = result.text
        # Case-insensitive exact match against the expected label.
        if device_prediction.lower() == expected_label.lower():
            agreements += 1
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "clear-result").click()
    accuracy = agreements / len(validation_set)
    assert accuracy >= 0.99, \
        f"On-device accuracy {accuracy:.2%} is below 99% threshold"
def test_model_handles_edge_case_inputs(driver):
    """Model should handle adversarial and edge case inputs gracefully.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    edge_cases = [
        ("test_data/blank_white.jpg", "no_object_detected"),
        ("test_data/blank_black.jpg", "no_object_detected"),
        ("test_data/tiny_1x1.jpg", "invalid_input"),
        ("test_data/very_large.jpg", "cat"),  # Should still classify correctly
        ("test_data/rotated_90.jpg", "cat"),  # Rotation invariance
        ("test_data/low_quality.jpg", "cat"),  # JPEG quality = 10
    ]
    for image_path, expected_category in edge_cases:
        driver.push_file("/sdcard/test_images/edge.jpg", source_path=image_path)
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
        # Should not crash regardless of input
        result = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
        assert result.is_displayed(), f"Model crashed on {image_path}"
        # Fix: expected_category was collected but never checked, so the test
        # passed on arbitrary wrong output.
        # TODO(review): confirm the result widget renders categories in this
        # form (e.g. "invalid_input" vs "Invalid input") and adjust matching.
        assert expected_category in result.text.lower(), \
            f"{image_path}: expected {expected_category!r}, got {result.text!r}"
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "clear-result").click()
Testing Fallback Behavior
When the on-device model fails to load (corrupted file, unsupported device, insufficient memory), the app must degrade gracefully to a server-side API.
def test_model_fallback_when_unavailable(driver):
    """App should fall back to server-side API when model fails to load.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    app_id = "com.app"
    # Corrupt the on-device model by deleting its file via ADB shell.
    driver.execute_script(
        "mobile: shell",
        {"command": "rm", "args": ["/data/data/com.app/files/ml_model.tflite"]},
    )
    # Cold-restart so the app re-attempts (and fails) the model load.
    driver.terminate_app(app_id)
    driver.activate_app(app_id)
    # The feature must keep working through the server-side path.
    driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
    outcome = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
    assert outcome.is_displayed()
    # The UI should disclose that server-side inference is in use.
    badge = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "server-mode-indicator")
    assert badge.is_displayed()
def test_fallback_to_server_on_low_memory(driver):
    """On low-memory devices, the app should use server-side inference.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    import re

    # Get available memory. Fix: the original captured this output but never
    # used it; parse MemAvailable so failures carry diagnostic context.
    memory_info = driver.execute_script("mobile: shell", {
        "command": "cat",
        "args": ["/proc/meminfo"]
    })
    match = re.search(r"MemAvailable:\s+(\d+)\s*kB", memory_info)
    available_mb = int(match.group(1)) // 1024 if match else None
    # Whether or not the model loads under memory pressure, the feature must
    # still work -- the app is expected to fall back to server-side inference.
    driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
    result = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
    assert result.is_displayed(), \
        f"Classification unavailable (MemAvailable: {available_mb} MB)"
Testing OTA Model Updates
def test_model_update_applies_correctly(driver):
    """Over-the-air model updates should apply on next app launch.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    # Get current model version
    model_info = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "model-version")
    original_version = model_info.text
    # Trigger model update check
    driver.find_element(AppiumBy.ACCESSIBILITY_ID, "check-model-update").click()
    # Wait for download. Fix: dropped the redundant in-function `import time`
    # (the module is already imported at file top).
    # TODO(review): replace this fixed sleep with a poll on a
    # download-complete indicator to cut runtime and flakiness.
    time.sleep(10)
    # Restart app to load new model
    driver.terminate_app("com.app")
    driver.activate_app("com.app")
    # Verify new model version is loaded
    model_info = driver.find_element(AppiumBy.ACCESSIBILITY_ID, "model-version")
    new_version = model_info.text
    # Version should be different (assuming an update was available).
    # In CI, stage a test model update; the "latest" escape hatch keeps the
    # test green when no update is staged -- confirm that label is what the
    # app actually renders.
    assert new_version != original_version or new_version == "latest"
Memory and Battery Impact Testing
def test_memory_usage_during_inference(driver):
    """Continuous inference must not cause memory growth.

    Args:
        driver: Appium WebDriver session attached to the app under test.
    """
    # Get initial memory usage
    initial_memory = get_app_memory(driver)
    # Fix: get_app_memory returns 0 when it cannot parse dumpsys output,
    # which previously caused ZeroDivisionError below instead of a clear
    # failure message.
    assert initial_memory > 0, "Could not read initial app memory usage"
    # Run 50 consecutive inferences
    for _ in range(50):
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classify-image").click()
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "classification-result")
        driver.find_element(AppiumBy.ACCESSIBILITY_ID, "clear-result").click()
    # Get final memory usage
    final_memory = get_app_memory(driver)
    # Memory should not grow significantly (allow 20% for caching)
    memory_growth = (final_memory - initial_memory) / initial_memory
    assert memory_growth < 0.20, \
        f"Memory grew by {memory_growth:.1%} during sustained inference"
def get_app_memory(driver):
    """Return the app's total PSS in MB, or 0 if dumpsys output has no TOTAL line.

    Args:
        driver: Appium WebDriver for an Android session (uses `mobile: shell`).
    """
    import re

    dump = driver.execute_script("mobile: shell", {
        "command": "dumpsys",
        "args": ["meminfo", "com.app", "--short"]
    })
    # dumpsys reports PSS in kB; convert the first TOTAL figure to MB.
    match = re.search(r'TOTAL\s+(\d+)', dump)
    if not match:
        return 0
    return int(match.group(1)) / 1024
On-device ML testing requires a blend of traditional functional testing (does it produce the right output?) and performance testing (does it do it fast enough, within memory constraints, without draining the battery?). The fallback behavior is the most critical test — if the model fails, the user must still be able to use the feature.