"""Obtain, refresh and persist API authentication tokens."""

# NOTE: the import order below is intentionally the file's original order.
# ``from constants import *`` must stay *before* ``import logging`` and
# ``import datetime`` so that a same-named symbol exported by ``constants``
# can never shadow those modules.
import logging as log
from typing import Type

import requests
from sqlalchemy.orm import Session
from sqlalchemy import select, desc, delete

import constants
from db.helpers import get_engine
from db.models import ApiToken
from constants import *
import logging
import datetime

logger = logging.getLogger(__name__)


def get_auth_token() -> str:
    """Return a usable access token, creating or refreshing one if needed.

    Steps:
      1. Purge long-expired tokens.
      2. If no token is stored, generate a new one.
      3. If at least one stored token has an unexpired access token, keep the
         first one, delete the other active ones and return it.
      4. Otherwise keep the first expired token, delete the others and
         refresh the kept one.

    Returns:
        The access token string.
    """
    delete_old_tokens()

    all_tokens: list[ApiToken] = get_all_tokens()
    if not all_tokens:
        return generate_token().access_token

    # NOTE(review): naive local time is compared against the stored
    # expiration; confirm the column holds naive local datetimes.
    now = datetime.datetime.now()
    active_tokens: list[ApiToken] = []
    tokens_to_refresh: list[ApiToken] = []
    for token in all_tokens:
        if token.date_expiration_access_token > now:
            active_tokens.append(token)
        else:
            tokens_to_refresh.append(token)

    with Session(get_engine()) as session:
        if active_tokens:
            logger.info("There are active tokens. Returning first and deleting the rest")
            kept, *surplus = active_tokens
            for extra in surplus:
                logger.info("More than 1 active token. Deleting auth token.")
                session.delete(extra)
            # Without this commit the deletions above are discarded when the
            # session closes.
            session.commit()
            return kept.access_token

        # all_tokens is non-empty and none is active, so there is at least
        # one token to refresh here.
        kept, *surplus = tokens_to_refresh
        for extra in surplus:
            logger.info(
                "More than 1 token to refresh. Deleting token expired in %s",
                extra.date_expiration_refresh_token,
            )
            session.delete(extra)
        session.commit()

    return refresh_token(kept).access_token


def delete_old_tokens() -> None:
    """Delete tokens whose refresh token expired more than one day ago."""
    cutoff = datetime.datetime.now() - datetime.timedelta(days=1)
    with Session(get_engine()) as session:
        logger.info("Deleting old tokens.")
        # The statement must be executed to have any effect, and "old" means
        # the refresh expiration lies *before* the cutoff.
        session.execute(
            delete(ApiToken).where(ApiToken.date_expiration_refresh_token < cutoff)
        )
        session.commit()


def refresh_token(token: ApiToken) -> ApiToken:
    """Exchange ``token``'s refresh token for a new token and persist it.

    The old ``token`` row is deleted once the new one has been stored.

    Args:
        token: Stored token whose ``refresh_token`` value is sent to the API.

    Returns:
        The newly persisted token.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    response = requests.get(
        f"{constants.api_url}/auth",
        # Let requests URL-encode the value instead of interpolating it raw.
        params={"refresh_token": token.refresh_token},
    )
    response.raise_for_status()

    refreshed_token = persist_auth_token(response.json())
    delete_token(token)
    return refreshed_token


def get_all_tokens() -> list[ApiToken]:
    """Return every stored token."""
    with Session(get_engine()) as session:
        return session.query(ApiToken).all()


def generate_token() -> ApiToken:
    """Request a brand-new token from the API and persist it.

    Returns:
        The newly persisted token.

    Raises:
        requests.HTTPError: If the API does not answer with 200 or 201.
    """
    body = {
        "consumer_key": constants.api_consumer_key,
        "consumer_secret": constants.api_consumer_secret,
        "code": constants.api_code,
    }
    response = requests.post(f"{constants.api_url}/auth", json=body)
    if response.status_code not in (200, 201):
        raise requests.HTTPError(
            f"Token generation failed with HTTP status {response.status_code}",
            response=response,
        )

    # The response body carries the access and refresh tokens, so it is
    # deliberately not written to the log.
    logger.info("Generated tray access token (HTTP %s)", response.status_code)
    return persist_auth_token(response.json())


def persist_auth_token(api_token: dict) -> ApiToken:
    """Store the token described by an API response payload.

    Args:
        api_token: Decoded JSON payload; the keys read from it are listed in
            the constructor call below.

    Returns:
        The stored ``ApiToken``. Its attributes stay readable after the
        session is closed.

    Raises:
        KeyError: If an expected key is missing from ``api_token``.
    """
    # expire_on_commit=False keeps the loaded attribute values on the
    # instance after commit; with the default, reading e.g. ``access_token``
    # on the returned (detached) object raises DetachedInstanceError.
    with Session(get_engine(), expire_on_commit=False) as session:
        token = ApiToken(
            message=api_token["message"],
            code=api_token["code"],
            access_token=api_token["access_token"],
            refresh_token=api_token["refresh_token"],
            date_expiration_access_token=api_token["date_expiration_access_token"],
            date_expiration_refresh_token=api_token["date_expiration_refresh_token"],
            date_activated=api_token["date_activated"],
            api_host=api_token["api_host"],
            store_id=api_token["store_id"],
        )
        session.add(token)
        session.commit()
        return token


def delete_token(token: ApiToken) -> None:
    """Delete a single stored token."""
    with Session(get_engine()) as session:
        logger.info("Deleting token %s", token.id)
        session.delete(token)
        session.commit()


def get_last_api_token() -> ApiToken | None:
    """Return the stored token with the highest ``id``, or ``None`` if none exist."""
    with Session(get_engine()) as session:
        return session.query(ApiToken).order_by(desc(ApiToken.id)).first()