diff --git a/app.py b/app.py index b79f378..6f6b248 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,13 @@ from cryptography.hazmat.primitives import serialization from flask import Flask, request, jsonify from models import db, VAPIDKey, Subscription -from py_vapid import Vapid from pywebpush import webpush, WebPushException +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec + +import binascii import base64 import json import os @@ -16,16 +20,24 @@ def create_app(config_name): def generate_and_save_vapid_keys(): - vapid = Vapid() try: - vapid.generate_keys() - private_key = vapid.private_pem().decode('utf-8').strip() - public_key = vapid.public_pem().decode('utf-8').strip() + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + public_key_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint + ) + public_key_base64 = base64.b64encode(public_key_bytes).decode() + + private_key_hex = binascii.hexlify(private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + )).decode() except Exception as e: print(f"Error generating VAPID keys: {e}") - key = VAPIDKey(public_key=public_key, private_key=private_key) + key = VAPIDKey(public_key=public_key_base64, private_key=private_key_hex) db.session.add(key) db.session.commit() @@ -37,48 +49,58 @@ def create_app(config_name): def send_push_notification(subscription_info, message, vapid_key): try: + private_key_bytes = bytes.fromhex(vapid_key.private_key) + webpush( subscription_info=subscription_info, data=json.dumps(message), - vapid_private_key=vapid_key.private_key, + vapid_private_key=private_key_bytes, vapid_claims={ "sub": "mailto:your-email@example.com" } ) + except WebPushException as e: print(f"Failed to send notification: {e}") + def is_valid_base64_url(s): return re.match(r'^[A-Za-z0-9_-]*$', s) is not None + def is_valid_server_key(server_key): if not is_valid_base64_url(server_key): return False return len(server_key) == 88 - @app.route('/web-push/vapid', methods=['GET']) - def get_vapid(): - key = VAPIDKey.query.first() - pem_bytes = key.private_key.encode("utf-8") - private_key = serialization.load_pem_private_key(pem_bytes, password=None) - public_key = private_key.public_key() - public_key_der = public_key.public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo - ) + + @app.route('/web-push/clear_subscriptions', methods=['POST']) + def clear_subscriptions(): + try: + Subscription.query.delete() + return jsonify(message='Subscriptions cleared') - base64_data = base64.urlsafe_b64encode(public_key_der) - base64_str = base64_data.decode('utf-8') - base64_str = base64_str.rstrip('=') + except Exception as e: + return jsonify(error=f'Error clearing subscriptions: {str(e)}'), 500 - if is_valid_server_key(base64_str): - return jsonify(vapidKey=base64_str) - else: - return jsonify(error=f'VAPID server key is not valid. {len(base64_str)} {is_valid_base64_url(base64_str)}'), 422 + @app.route('/web-push/regenerate_vapid_keys', methods=['POST']) + def regenerate_vapid_keys(): + try: + # Generate new VAPID keys + VAPIDKey.query.delete() + generate_and_save_vapid_keys() + return jsonify(message='VAPID keys regenerated successfully') - return jsonify(error='No VAPID keys found'), 404 + except Exception as e: + return jsonify(error=f'Error regenerating VAPID keys: {str(e)}'), 500 + + + @app.route('/web-push/vapid', methods=['GET']) + def get_vapid(): + key = VAPIDKey.query.first() + return jsonify(vapidKey=key.public_key) @app.route('/web-push/subscribe', methods=['POST']) @@ -128,4 +150,3 @@ def create_app(config_name): initialize() return app -