31 changed files with 822 additions and 260 deletions
@ -0,0 +1,38 @@ |
|||||
|
import time |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from sqlmodel import Field, Session, SQLModel, create_engine |
||||
|
|
||||
|
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
||||
|
|
||||
|
|
||||
|
class User(SQLModel, table=True): |
||||
|
id: int | None = Field(default=None, primary_key=True) |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
def get_session(): |
||||
|
with Session(engine) as session: |
||||
|
yield session |
||||
|
|
||||
|
|
||||
|
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
||||
|
user = session.get(User, user_id) |
||||
|
if not user: |
||||
|
raise HTTPException(status_code=403, detail="Not authorized") |
||||
|
|
||||
|
|
||||
|
def generate_stream(query: str): |
||||
|
for ch in query: |
||||
|
yield ch |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
|
||||
|
@app.get("/generate", dependencies=[Depends(get_user)]) |
||||
|
def generate(query: str): |
||||
|
return StreamingResponse(content=generate_stream(query)) |
||||
@ -0,0 +1,39 @@ |
|||||
|
import time |
||||
|
from typing import Annotated |
||||
|
|
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from sqlmodel import Field, Session, SQLModel, create_engine |
||||
|
|
||||
|
engine = create_engine("postgresql+psycopg://postgres:postgres@localhost/db") |
||||
|
|
||||
|
|
||||
|
class User(SQLModel, table=True): |
||||
|
id: int | None = Field(default=None, primary_key=True) |
||||
|
name: str |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
def get_session(): |
||||
|
with Session(engine) as session: |
||||
|
yield session |
||||
|
|
||||
|
|
||||
|
def get_user(user_id: int, session: Annotated[Session, Depends(get_session)]): |
||||
|
user = session.get(User, user_id) |
||||
|
if not user: |
||||
|
raise HTTPException(status_code=403, detail="Not authorized") |
||||
|
session.close() |
||||
|
|
||||
|
|
||||
|
def generate_stream(query: str): |
||||
|
for ch in query: |
||||
|
yield ch |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
|
||||
|
@app.get("/generate", dependencies=[Depends(get_user)]) |
||||
|
def generate(query: str): |
||||
|
return StreamingResponse(content=generate_stream(query)) |
||||
@ -5,7 +5,7 @@ import jwt |
|||||
from fastapi import Depends, FastAPI, HTTPException, status |
from fastapi import Depends, FastAPI, HTTPException, status |
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -20,7 +20,7 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
} |
} |
||||
} |
} |
||||
@ -46,7 +46,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
||||
|
|
||||
@ -54,11 +54,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -5,7 +5,7 @@ import jwt |
|||||
from fastapi import Depends, FastAPI, HTTPException, status |
from fastapi import Depends, FastAPI, HTTPException, status |
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
from typing_extensions import Annotated |
from typing_extensions import Annotated |
||||
|
|
||||
@ -21,7 +21,7 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
} |
} |
||||
} |
} |
||||
@ -47,7 +47,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
||||
|
|
||||
@ -55,11 +55,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -5,7 +5,7 @@ import jwt |
|||||
from fastapi import Depends, FastAPI, HTTPException, status |
from fastapi import Depends, FastAPI, HTTPException, status |
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -20,7 +20,7 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
} |
} |
||||
} |
} |
||||
@ -46,7 +46,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
||||
|
|
||||
@ -54,11 +54,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -5,7 +5,7 @@ import jwt |
|||||
from fastapi import Depends, FastAPI, HTTPException, status |
from fastapi import Depends, FastAPI, HTTPException, status |
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -20,7 +20,7 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
} |
} |
||||
} |
} |
||||
@ -46,7 +46,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
||||
|
|
||||
@ -54,11 +54,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -4,7 +4,7 @@ import jwt |
|||||
from fastapi import Depends, FastAPI, HTTPException, status |
from fastapi import Depends, FastAPI, HTTPException, status |
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel |
from pydantic import BaseModel |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -19,7 +19,7 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
} |
} |
||||
} |
} |
||||
@ -45,7 +45,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
||||
|
|
||||
@ -53,11 +53,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -9,7 +9,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -24,14 +24,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -58,7 +58,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -69,11 +69,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -9,7 +9,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
from typing_extensions import Annotated |
from typing_extensions import Annotated |
||||
|
|
||||
@ -25,14 +25,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -59,7 +59,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -70,11 +70,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -9,7 +9,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -24,14 +24,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -58,7 +58,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -69,11 +69,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -9,7 +9,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -24,14 +24,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -58,7 +58,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -69,11 +69,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -8,7 +8,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -23,14 +23,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -57,7 +57,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -68,11 +68,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -9,7 +9,7 @@ from fastapi.security import ( |
|||||
SecurityScopes, |
SecurityScopes, |
||||
) |
) |
||||
from jwt.exceptions import InvalidTokenError |
from jwt.exceptions import InvalidTokenError |
||||
from passlib.context import CryptContext |
from pwdlib import PasswordHash |
||||
from pydantic import BaseModel, ValidationError |
from pydantic import BaseModel, ValidationError |
||||
|
|
||||
# to get a string like this run: |
# to get a string like this run: |
||||
@ -24,14 +24,14 @@ fake_users_db = { |
|||||
"username": "johndoe", |
"username": "johndoe", |
||||
"full_name": "John Doe", |
"full_name": "John Doe", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$wagCPXjifgvUFBzq4hqe3w$CYaIb8sB+wtD+Vu/P4uod1+Qof8h+1g7bbDlBID48Rc", |
||||
"disabled": False, |
"disabled": False, |
||||
}, |
}, |
||||
"alice": { |
"alice": { |
||||
"username": "alice", |
"username": "alice", |
||||
"full_name": "Alice Chains", |
"full_name": "Alice Chains", |
||||
"email": "[email protected]", |
"email": "[email protected]", |
||||
"hashed_password": "$2b$12$gSvqqUPvlXP2tfVFaWK1Be7DlH.PKZbv5H8KnzzVgXXbVxpva.pFm", |
"hashed_password": "$argon2id$v=19$m=65536,t=3,p=4$g2/AV1zwopqUntPKJavBFw$BwpRGDCyUHLvHICnwijyX8ROGoiUPwNKZ7915MeYfCE", |
||||
"disabled": True, |
"disabled": True, |
||||
}, |
}, |
||||
} |
} |
||||
@ -58,7 +58,7 @@ class UserInDB(User): |
|||||
hashed_password: str |
hashed_password: str |
||||
|
|
||||
|
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
password_hash = PasswordHash.recommended() |
||||
|
|
||||
oauth2_scheme = OAuth2PasswordBearer( |
oauth2_scheme = OAuth2PasswordBearer( |
||||
tokenUrl="token", |
tokenUrl="token", |
||||
@ -69,11 +69,11 @@ app = FastAPI() |
|||||
|
|
||||
|
|
||||
def verify_password(plain_password, hashed_password): |
def verify_password(plain_password, hashed_password): |
||||
return pwd_context.verify(plain_password, hashed_password) |
return password_hash.verify(plain_password, hashed_password) |
||||
|
|
||||
|
|
||||
def get_password_hash(password): |
def get_password_hash(password): |
||||
return pwd_context.hash(password) |
return password_hash.hash(password) |
||||
|
|
||||
|
|
||||
def get_user(db, username: str): |
def get_user(db, username: str): |
||||
|
|||||
@ -0,0 +1,18 @@ |
|||||
|
from contextlib import AsyncExitStack |
||||
|
|
||||
|
from starlette.types import ASGIApp, Receive, Scope, Send |
||||
|
|
||||
|
|
||||
|
# Used mainly to close files after the request is done, dependencies are closed |
||||
|
# in their own AsyncExitStack |
||||
|
class AsyncExitStackMiddleware: |
||||
|
def __init__( |
||||
|
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack" |
||||
|
) -> None: |
||||
|
self.app = app |
||||
|
self.context_name = context_name |
||||
|
|
||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
||||
|
async with AsyncExitStack() as stack: |
||||
|
scope[self.context_name] = stack |
||||
|
await self.app(scope, receive, send) |
||||
@ -0,0 +1,69 @@ |
|||||
|
from typing import Any |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI, HTTPException |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class CustomError(Exception): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
def catching_dep() -> Any: |
||||
|
try: |
||||
|
yield "s" |
||||
|
except CustomError as err: |
||||
|
raise HTTPException(status_code=418, detail="Session error") from err |
||||
|
|
||||
|
|
||||
|
def broken_dep() -> Any: |
||||
|
yield "s" |
||||
|
raise ValueError("Broken after yield") |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/catching") |
||||
|
def catching(d: Annotated[str, Depends(catching_dep)]) -> Any: |
||||
|
raise CustomError("Simulated error during streaming") |
||||
|
|
||||
|
|
||||
|
@app.get("/broken") |
||||
|
def broken(d: Annotated[str, Depends(broken_dep)]) -> Any: |
||||
|
return {"message": "all good?"} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_catching(): |
||||
|
response = client.get("/catching") |
||||
|
assert response.status_code == 418 |
||||
|
assert response.json() == {"detail": "Session error"} |
||||
|
|
||||
|
|
||||
|
def test_broken_raise(): |
||||
|
with pytest.raises(ValueError, match="Broken after yield"): |
||||
|
client.get("/broken") |
||||
|
|
||||
|
|
||||
|
def test_broken_no_raise(): |
||||
|
""" |
||||
|
When a dependency with yield raises after the yield (not in an except), the |
||||
|
response is already "successfully" sent back to the client, but there's still |
||||
|
an error in the server afterwards, an exception is raised and captured or shown |
||||
|
in the server logs. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as client: |
||||
|
response = client.get("/broken") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == {"message": "all good?"} |
||||
|
|
||||
|
|
||||
|
def test_broken_return_finishes(): |
||||
|
client = TestClient(app, raise_server_exceptions=False) |
||||
|
response = client.get("/broken") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.json() == {"message": "all good?"} |
||||
@ -0,0 +1,130 @@ |
|||||
|
from contextlib import contextmanager |
||||
|
from typing import Any, Generator |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI |
||||
|
from fastapi.responses import StreamingResponse |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.data = ["foo", "bar", "baz"] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[str, None, None]: |
||||
|
for item in self.data: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.get("/data") |
||||
|
def get_data(session: SessionDep) -> Any: |
||||
|
data = list(session) |
||||
|
return data |
||||
|
|
||||
|
|
||||
|
@app.get("/stream-simple") |
||||
|
def get_stream_simple(session: SessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from ["x", "y", "z"] |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
@app.get("/stream-session") |
||||
|
def get_stream_session(session: SessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from session |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
@app.get("/broken-session-data") |
||||
|
def get_broken_session_data(session: BrokenSessionDep) -> Any: |
||||
|
return list(session) |
||||
|
|
||||
|
|
||||
|
@app.get("/broken-session-stream") |
||||
|
def get_broken_session_stream(session: BrokenSessionDep) -> Any: |
||||
|
def iter_data(): |
||||
|
yield from session |
||||
|
|
||||
|
return StreamingResponse(iter_data()) |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_regular_no_stream(): |
||||
|
response = client.get("/data") |
||||
|
assert response.json() == ["foo", "bar", "baz"] |
||||
|
|
||||
|
|
||||
|
def test_stream_simple(): |
||||
|
response = client.get("/stream-simple") |
||||
|
assert response.text == "xyz" |
||||
|
|
||||
|
|
||||
|
def test_stream_session(): |
||||
|
response = client.get("/stream-session") |
||||
|
assert response.text == "foobarbaz" |
||||
|
|
||||
|
|
||||
|
def test_broken_session_data(): |
||||
|
with pytest.raises(ValueError, match="Session closed"): |
||||
|
client.get("/broken-session-data") |
||||
|
|
||||
|
|
||||
|
def test_broken_session_data_no_raise(): |
||||
|
client = TestClient(app, raise_server_exceptions=False) |
||||
|
response = client.get("/broken-session-data") |
||||
|
assert response.status_code == 500 |
||||
|
assert response.text == "Internal Server Error" |
||||
|
|
||||
|
|
||||
|
def test_broken_session_stream_raise(): |
||||
|
# Can raise ValueError on Pydantic v2 and ExceptionGroup on Pydantic v1 |
||||
|
with pytest.raises((ValueError, Exception)): |
||||
|
client.get("/broken-session-stream") |
||||
|
|
||||
|
|
||||
|
def test_broken_session_stream_no_raise(): |
||||
|
""" |
||||
|
When a dependency with yield raises after the streaming response already started |
||||
|
the 200 status code is already sent, but there's still an error in the server |
||||
|
afterwards, an exception is raised and captured or shown in the server logs. |
||||
|
""" |
||||
|
with TestClient(app, raise_server_exceptions=False) as client: |
||||
|
response = client.get("/broken-session-stream") |
||||
|
assert response.status_code == 200 |
||||
|
assert response.text == "" |
||||
@ -0,0 +1,79 @@ |
|||||
|
from contextlib import contextmanager |
||||
|
from typing import Any, Generator |
||||
|
|
||||
|
import pytest |
||||
|
from fastapi import Depends, FastAPI, WebSocket |
||||
|
from fastapi.testclient import TestClient |
||||
|
from typing_extensions import Annotated |
||||
|
|
||||
|
|
||||
|
class Session: |
||||
|
def __init__(self) -> None: |
||||
|
self.data = ["foo", "bar", "baz"] |
||||
|
self.open = True |
||||
|
|
||||
|
def __iter__(self) -> Generator[str, None, None]: |
||||
|
for item in self.data: |
||||
|
if self.open: |
||||
|
yield item |
||||
|
else: |
||||
|
raise ValueError("Session closed") |
||||
|
|
||||
|
|
||||
|
@contextmanager |
||||
|
def acquire_session() -> Generator[Session, None, None]: |
||||
|
session = Session() |
||||
|
try: |
||||
|
yield session |
||||
|
finally: |
||||
|
session.open = False |
||||
|
|
||||
|
|
||||
|
def dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
def broken_dep_session() -> Any: |
||||
|
with acquire_session() as s: |
||||
|
s.open = False |
||||
|
yield s |
||||
|
|
||||
|
|
||||
|
SessionDep = Annotated[Session, Depends(dep_session)] |
||||
|
BrokenSessionDep = Annotated[Session, Depends(broken_dep_session)] |
||||
|
|
||||
|
app = FastAPI() |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws") |
||||
|
async def websocket_endpoint(websocket: WebSocket, session: SessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") |
||||
|
|
||||
|
|
||||
|
@app.websocket("/ws-broken") |
||||
|
async def websocket_endpoint_broken(websocket: WebSocket, session: BrokenSessionDep): |
||||
|
await websocket.accept() |
||||
|
for item in session: |
||||
|
await websocket.send_text(f"{item}") # pragma no cover |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield(): |
||||
|
with client.websocket_connect("/ws") as websocket: |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "foo" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "bar" |
||||
|
data = websocket.receive_text() |
||||
|
assert data == "baz" |
||||
|
|
||||
|
|
||||
|
def test_websocket_dependency_after_yield_broken(): |
||||
|
with pytest.raises(ValueError, match="Session closed"): |
||||
|
with client.websocket_connect("/ws-broken"): |
||||
|
pass # pragma no cover |
||||
Loading…
Reference in new issue