Files
NetBirdMSP-Appliance/app/routers/auth.py
twothatit 7793ca3666 feat: add Windows DNS integration and LDAP/AD authentication
Windows DNS (WinRM):
- New dns_service.py: create/delete A-records via PowerShell over WinRM (NTLM)
- Idempotent create (removes existing record first), graceful delete
- DNS failures are non-fatal — deployment continues, error logged
- test-dns endpoint: GET /api/settings/test-dns
- Integrated into deploy_customer() and undeploy_customer()

LDAP / Active Directory auth:
- New ldap_service.py: service-account bind + user search + user bind (ldap3)
- Optional AD group restriction via ldap_group_dn
- Login flow: LDAP first → local fallback (prevents admin lockout)
- LDAP users auto-created with auth_provider="ldap" and role="viewer"
- test-ldap endpoint: GET /api/settings/test-ldap
- reset-password/reset-mfa guards extended to block LDAP users

All credentials (dns_password, ldap_bind_password) encrypted with Fernet.
New DB columns added via backwards-compatible migrations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-21 21:06:51 +01:00

467 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Authentication API endpoints — login, logout, current user, password change, MFA, Azure AD."""
import base64
import io
import logging
import secrets
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies import create_access_token, create_mfa_token, get_current_user, verify_mfa_token
from app.models import SystemConfig, User
from app.services import ldap_service
from app.utils.config import get_system_config
from app.utils.security import (
decrypt_value,
encrypt_value,
generate_totp_secret,
generate_totp_uri,
hash_password,
verify_password,
verify_totp,
)
from app.utils.validators import ChangePasswordRequest, LoginRequest, MfaTokenRequest, MfaVerifyRequest
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that the endpoints below register themselves on.
router = APIRouter()
# Shared rate limiter; applied below through @limiter.limit(...) decorators.
from app.limiter import limiter
@router.post("/login")
@limiter.limit("10/minute")
async def login(request: Request, payload: LoginRequest, db: Session = Depends(get_db)):
    """Log a user in with username and password; may hand off to an MFA step.

    Order of checks:
      1. LDAP, only when it is enabled and a server is configured.
         - ``authenticate_ldap`` returns user info: the matching local ``User``
           row is reused (and marked ``auth_provider="ldap"``) or created with
           role ``viewer``.
         - ``authenticate_ldap`` raises ``ValueError``: HTTP 401.
         - ``authenticate_ldap`` raises ``RuntimeError`` or returns ``None``:
           continue with local authentication.
      2. Local authentication for rows whose ``auth_provider`` is ``"local"``,
         checked with ``verify_password`` against the stored hash.
      3. For local users, if the global ``mfa_enabled`` flag is set, a
         short-lived MFA token is returned instead of an access token.

    The ``request`` argument is not read here; it is present for the
    ``@limiter.limit("10/minute")`` decorator.

    Raises:
        HTTPException: 401 for bad credentials, 403 for a disabled account.
    """

    def _bad_credentials() -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password.",
        )

    def _account_disabled() -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is disabled.",
        )

    config = get_system_config(db)
    user: User | None = None

    # -- Step 1: LDAP (optional) -------------------------------------------
    ldap_configured = config and config.ldap_enabled and config.ldap_server
    if ldap_configured:
        try:
            ldap_info = await ldap_service.authenticate_ldap(
                payload.username, payload.password, config
            )
            if ldap_info is not None:
                ldap_username = ldap_info["username"]
                user = db.query(User).filter(User.username == ldap_username).first()
                if user and not user.is_active:
                    raise _account_disabled()
                if user:
                    # Existing row: make sure it is flagged as LDAP-backed.
                    if user.auth_provider != "ldap":
                        user.auth_provider = "ldap"
                        db.commit()
                else:
                    # First LDAP login: create a local shadow record. The
                    # password hash is derived from a random throwaway token.
                    user = User(
                        username=ldap_username,
                        password_hash=hash_password(secrets.token_urlsafe(32)),
                        email=ldap_info.get("email", ""),
                        is_active=True,
                        role="viewer",
                        auth_provider="ldap",
                    )
                    db.add(user)
                    db.commit()
                    db.refresh(user)
                    logger.info("LDAP user '%s' auto-created with role 'viewer'.", ldap_username)
        except ValueError as exc:
            # LDAP rejected the login (e.g. wrong password or group denied).
            logger.warning("LDAP login failed for '%s': %s", payload.username, exc)
            raise _bad_credentials()
        except RuntimeError as exc:
            # LDAP backend problem: log it and try local authentication.
            logger.error("LDAP server error, falling back to local auth: %s", exc)

    # -- Step 2: local authentication --------------------------------------
    if user is None:
        candidate = db.query(User).filter(User.username == payload.username).first()
        if not candidate or candidate.auth_provider != "local":
            raise _bad_credentials()
        if not verify_password(payload.password, candidate.password_hash):
            raise _bad_credentials()
        user = candidate

    if not user.is_active:
        raise _account_disabled()

    # -- Step 3: MFA hand-off (local users only) ---------------------------
    if user.auth_provider == "local":
        sys_config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
        mfa_required = sys_config and getattr(sys_config, "mfa_enabled", False)
        if mfa_required:
            return {
                "mfa_required": True,
                "mfa_token": create_mfa_token(user.username),
                "totp_setup_needed": not bool(user.totp_enabled),
            }

    access_token = create_access_token(user.username)
    logger.info("User %s logged in (provider: %s).", user.username, user.auth_provider)
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "user": user.to_dict(),
    }
# ---------------------------------------------------------------------------
# MFA endpoints
# ---------------------------------------------------------------------------
@router.post("/mfa/setup")
async def mfa_setup(payload: MfaTokenRequest, db: Session = Depends(get_db)):
    """Generate a new TOTP secret and QR code for first-time MFA setup.

    The caller proves the password step with ``payload.mfa_token``. A fresh
    secret is stored encrypted on the user; ``totp_enabled`` is left untouched
    here and is only switched on by ``/mfa/setup/complete``.

    Enrolment is refused when TOTP is already enabled. Without this check,
    anyone holding only the password (and thus an MFA token) could overwrite
    the existing secret with one they control and walk past the second factor.

    Returns:
        dict with the plain ``secret``, a ``qr_code`` PNG data URI and the
        ``otpauth_uri`` encoded in that QR code.

    Raises:
        HTTPException: 404 if the user no longer exists, 400 if TOTP is
            already set up for this user.
    """
    username = verify_mfa_token(payload.mfa_token)
    user = db.query(User).filter(User.username == username).first()
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")

    # Never replace an active second factor on the strength of the password alone.
    if user.totp_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TOTP is already set up for this user. Use /auth/mfa/verify.",
        )

    # Generate new secret and store encrypted (not yet enabled)
    secret = generate_totp_secret()
    user.totp_secret_encrypted = encrypt_value(secret)
    db.commit()

    # Generate QR code as base64 data URI
    uri = generate_totp_uri(secret, username)
    import qrcode  # imported lazily, as in the original implementation

    img = qrcode.make(uri)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    qr_b64 = base64.b64encode(buf.getvalue()).decode()

    return {
        "secret": secret,
        "qr_code": f"data:image/png;base64,{qr_b64}",
        "otpauth_uri": uri,
    }
@router.post("/mfa/setup/complete")
async def mfa_setup_complete(payload: MfaVerifyRequest, db: Session = Depends(get_db)):
    """Verify the first TOTP code to complete MFA setup, then issue an access token.

    Only valid while enrolment is pending (a secret is stored but
    ``totp_enabled`` is still false). Users who are already enrolled are
    rejected, because this endpoint would otherwise act as a second TOTP
    check that does not carry the rate limit applied to ``/mfa/verify``.

    NOTE(review): this endpoint itself has no ``@limiter.limit``; adding one
    needs a ``request: Request`` parameter, which would change the signature.

    Returns:
        dict with ``access_token``, ``token_type`` and the serialized ``user``.

    Raises:
        HTTPException: 404 if the user no longer exists, 400 if TOTP is
            already enabled or setup was never started, 401 if the code is
            wrong.
    """
    username = verify_mfa_token(payload.mfa_token)
    user = db.query(User).filter(User.username == username).first()
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")

    # Enrolled users must go through the regular verification endpoint.
    if user.totp_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TOTP is already set up for this user. Use /auth/mfa/verify.",
        )

    if not user.totp_secret_encrypted:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TOTP setup not initiated. Call /auth/mfa/setup first.",
        )

    secret = decrypt_value(user.totp_secret_encrypted)
    if not verify_totp(secret, payload.totp_code):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid TOTP code.",
        )

    # First valid code proves the authenticator is configured: activate TOTP.
    user.totp_enabled = True
    db.commit()

    token = create_access_token(user.username)
    logger.info("User %s completed MFA setup and logged in.", user.username)
    return {
        "access_token": token,
        "token_type": "bearer",
        "user": user.to_dict(),
    }
@router.post("/mfa/verify")
@limiter.limit("10/minute")
async def mfa_verify(request: Request, payload: MfaVerifyRequest, db: Session = Depends(get_db)):
    """Check a TOTP code for an already-enrolled user and issue an access token.

    The ``request`` argument is not read here; it is present for the
    ``@limiter.limit("10/minute")`` decorator.

    Raises:
        HTTPException: 404 if the user is gone, 400 if TOTP is not fully set
            up for the user, 401 if the submitted code does not verify.
    """
    username = verify_mfa_token(payload.mfa_token)
    account = db.query(User).filter(User.username == username).first()
    if not account:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")

    enrolled = account.totp_secret_encrypted and account.totp_enabled
    if not enrolled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TOTP is not set up for this user.",
        )

    code_ok = verify_totp(decrypt_value(account.totp_secret_encrypted), payload.totp_code)
    if not code_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid TOTP code.",
        )

    access_token = create_access_token(account.username)
    logger.info("User %s passed MFA verification.", account.username)
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "user": account.to_dict(),
    }
@router.get("/mfa/status")
async def mfa_status(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Report the global MFA switch and the caller's own TOTP enrolment flag."""
    sys_config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
    # A missing config row or a missing mfa_enabled attribute both count as "off".
    global_flag = bool(sys_config and getattr(sys_config, "mfa_enabled", False))
    user_flag = bool(current_user.totp_enabled)
    return {
        "mfa_enabled_global": global_flag,
        "totp_enabled_user": user_flag,
    }
@router.post("/mfa/disable")
async def mfa_disable(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Switch TOTP off for the authenticated user and drop the stored secret."""
    # Clear both the flag and the encrypted secret.
    current_user.totp_enabled = False
    current_user.totp_secret_encrypted = None
    db.commit()

    logger.info("User %s disabled their TOTP.", current_user.username)
    response = {"message": "TOTP disabled successfully."}
    return response
# ---------------------------------------------------------------------------
# Session, profile and password change
# ---------------------------------------------------------------------------
@router.post("/logout")
async def logout(current_user: User = Depends(get_current_user)):
    """Acknowledge a logout.

    Nothing is changed here beyond writing a log line; the client is expected
    to discard its token.
    """
    logger.info("User %s logged out.", current_user.username)
    acknowledgement = {"message": "Logged out successfully."}
    return acknowledgement
@router.get("/me")
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the authenticated caller as produced by ``User.to_dict()``."""
    profile = current_user.to_dict()
    return profile
@router.post("/change-password")
async def change_password(
    payload: ChangePasswordRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Replace the caller's password after re-checking the current one.

    Raises:
        HTTPException: 400 when ``payload.current_password`` does not match
            the stored hash.
    """
    current_ok = verify_password(payload.current_password, current_user.password_hash)
    if not current_ok:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Current password is incorrect.",
        )

    new_hash = hash_password(payload.new_password)
    current_user.password_hash = new_hash
    db.commit()

    logger.info("Password changed for user %s.", current_user.username)
    return {"message": "Password changed successfully."}
# ---------------------------------------------------------------------------
# Azure AD
# ---------------------------------------------------------------------------
class AzureCallbackRequest(BaseModel):
    """Request body for ``POST /azure/callback``."""

    # Authorization code passed to MSAL's acquire_token_by_authorization_code.
    code: str
    # Redirect URI forwarded to MSAL together with the code.
    redirect_uri: str
@router.get("/azure/config")
async def get_azure_config(db: Session = Depends(get_db)):
    """Expose the Azure AD settings the login page needs (no authentication required).

    Returns only ``{"azure_enabled": False}`` when the config row is missing
    or Azure is switched off; otherwise also the tenant and client IDs.
    """
    sys_config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
    if sys_config and sys_config.azure_enabled:
        return {
            "azure_enabled": True,
            "azure_tenant_id": sys_config.azure_tenant_id,
            "azure_client_id": sys_config.azure_client_id,
        }
    return {"azure_enabled": False}
async def _azure_user_in_group(access_token: str, group_id: str, email: str) -> bool:
    """Return True if the token's user is listed in ``group_id`` by Graph ``/me/memberOf``.

    The listing is paged: every page is inspected by following
    ``@odata.nextLink`` until the group is found or the pages run out.
    Reading only the first page would wrongly deny users who belong to
    many groups.

    Args:
        access_token: The user's delegated Graph access token.
        group_id: Object ID of the required group.
        email: Used only in log messages.

    Raises:
        RuntimeError: if Graph answers with a non-200 status, i.e. membership
            could not be determined. Transport errors from httpx propagate.
    """
    import httpx  # imported lazily, as the endpoint did originally

    page_url = "https://graph.microsoft.com/v1.0/me/memberOf"
    auth_header = {"Authorization": f"Bearer {access_token}"}
    async with httpx.AsyncClient(timeout=10) as http:
        while page_url:
            resp = await http.get(page_url, headers=auth_header)
            if resp.status_code != 200:
                logger.warning(
                    "Graph API group check returned %s for user '%s'.",
                    resp.status_code, email,
                )
                raise RuntimeError(f"Graph API returned HTTP {resp.status_code}")
            page = resp.json()
            if any(g.get("id") == group_id for g in page.get("value", [])):
                return True
            page_url = page.get("@odata.nextLink")
    return False


@router.post("/azure/callback")
async def azure_callback(
    payload: AzureCallbackRequest,
    db: Session = Depends(get_db),
):
    """Exchange an Azure AD authorization code for tokens and log the user in.

    Steps:
      1. Require Azure AD to be enabled and fully configured.
      2. Redeem ``payload.code`` through MSAL.
      3. Take the username from the ID token claims (``preferred_username``,
         falling back to ``email``).
      4. If ``azure_allowed_group_id`` is set, require membership of that
         group according to Microsoft Graph.
      5. Find the local ``User`` by that username, or create it with role
         ``viewer`` and ``auth_provider="azure"``, then issue an access token.

    Returns:
        dict with ``access_token``, ``token_type`` and the serialized ``user``.

    Raises:
        HTTPException: 400 for disabled/incomplete configuration or a token
            without a usable email, 401 when the code exchange fails, 403 for
            a user outside the required group or a disabled account, 503 when
            group membership cannot be verified, 500 for anything unexpected.
    """
    config = db.query(SystemConfig).filter(SystemConfig.id == 1).first()
    if not config or not config.azure_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Azure AD authentication is not enabled.",
        )
    if not config.azure_tenant_id or not config.azure_client_id or not config.azure_client_secret_encrypted:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Azure AD is not fully configured.",
        )

    try:
        import msal

        client_secret = decrypt_value(config.azure_client_secret_encrypted)
        authority = f"https://login.microsoftonline.com/{config.azure_tenant_id}"
        msal_app = msal.ConfidentialClientApplication(
            config.azure_client_id,
            authority=authority,
            client_credential=client_secret,
        )
        # NOTE(review): this MSAL call appears to be synchronous and would
        # block the event loop while it talks to Azure - confirm and consider
        # moving it to a worker thread.
        result = msal_app.acquire_token_by_authorization_code(
            payload.code,
            scopes=["User.Read"],
            redirect_uri=payload.redirect_uri,
        )

        if "error" in result:
            logger.warning("Azure AD token exchange failed: %s", result.get("error_description", result["error"]))
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=result.get("error_description", "Azure AD authentication failed."),
            )

        id_token_claims = result.get("id_token_claims", {})
        email = id_token_claims.get("preferred_username") or id_token_claims.get("email", "")
        user_access_token = result.get("access_token", "")

        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Could not determine email from Azure AD token.",
            )

        # -----------------------------------------------------------------
        # Group membership check (Azure AD group whitelist)
        # -----------------------------------------------------------------
        allowed_group_id = getattr(config, "azure_allowed_group_id", None)
        if allowed_group_id:
            # The user's own access token is used for the Graph lookup.
            # Any failure to get a definite answer is a 503, never a silent
            # "not a member".
            try:
                is_member = await _azure_user_in_group(
                    user_access_token, allowed_group_id, email
                )
            except Exception as graph_exc:
                logger.error("Graph API group check failed: %s", graph_exc)
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Could not verify Azure AD group membership. Please try again.",
                ) from graph_exc

            if not is_member:
                logger.warning(
                    "Azure AD login denied for '%s': not a member of required group '%s'.",
                    email, allowed_group_id,
                )
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access denied: you are not a member of the required Azure AD group.",
                )
        else:
            logger.warning(
                "azure_allowed_group_id is not configured. All Azure AD tenant users can log in. "
                "Set azure_allowed_group_id in Settings to restrict access."
            )

        # Find or create user
        user = db.query(User).filter(User.username == email).first()
        if not user:
            user = User(
                username=email,
                # Hash of a random throwaway token, not a user-chosen password.
                password_hash=hash_password(secrets.token_urlsafe(32)),
                email=email,
                is_active=True,
                role="viewer",  # New Azure users start as viewer; promote manually
                auth_provider="azure",
            )
            db.add(user)
            db.commit()
            db.refresh(user)
            logger.info("Azure AD user '%s' auto-created with role 'viewer'.", email)
        elif not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account is disabled.",
            )

        token = create_access_token(user.username)
        logger.info("Azure AD user '%s' logged in.", user.username)
        return {
            "access_token": token,
            "token_type": "bearer",
            "user": user.to_dict(),
        }
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception:
        # Anything unexpected is logged with traceback and hidden behind a generic 500.
        logger.exception("Azure AD authentication error")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Azure AD authentication failed. Please try again or contact support.",
        )